PyTorch 中的 Transformer 源码 [LLM]
1979 words
10 minutes
PyTorch 中的 Transformer 源码 [LLM]
Image API Error
上节我们自己实现了一个 Transformer,这节我们通过源码来看一下 PyTorch 官方是如何实现 Transformer 的。内容自底向上展开:从最基础的多头注意力组件开始,依次经过 LayerNorm、encoder、decoder,最后到顶层的 Transformer 类。
1. 多头自注意力的实现#
MHA(MultiheadAttention)定义在 pytorch/torch/nn/modules/activation.py 中:
1class MultiheadAttention(Module):2
3 # batch_first 在模块生命周期内不会改变,可以作为一个常量处理4 __constants__ = ["batch_first"]5 # 类型注解:bias_k,v 可以是 Tensor 或 None6 bias_k: torch.Tensor | None7 bias_v: torch.Tensor | None8
9 def __init__(10 self,11 embed_dim,12 num_heads,13 dropout=0.0,14 bias=True,15 add_bias_kv=False,16 add_zero_attn=False,17 kdim=None,18 vdim=None,19 batch_first=False,20 device=None,21 dtype=None,22 ) -> None:23 # 创建一个字典参数,后续通过 **factory_kwargs 直接解包24 factory_kwargs = {"device": device, "dtype": dtype}25 super().__init__()26 self.embed_dim = embed_dim27 self.kdim = kdim if kdim is not None else embed_dim28 self.vdim = vdim if vdim is not None else embed_dim29 # 检查 qkv 的维度是否一样30 self._qkv_same_embed_dim = (self.kdim == embed_dim and self.vdim == embed_dim)31
32 self.num_heads = num_heads33 self.dropout = dropout34 self.batch_first = batch_first35 # 划分子空间36 self.head_dim = embed_dim // num_heads37
38 # qkvo 初始化39 if not self._qkv_same_embed_dim:40 else:41 self.in_proj_weight = Parameter(42 torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)43 )44 self.register_parameter("q_proj_weight", None)45 self.register_parameter("k_proj_weight", None)46 self.register_parameter("v_proj_weight", None)47
48 # qkv 在投影时是否需要 bias (Q = W_q @ X + b_q...)49 if bias:50 self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))51 else:52 self.register_parameter("in_proj_bias", None)53
54 # 与普通的 Linear 层相同,只是名称带有非量化的显式标记55 self.out_proj = NonDynamicallyQuantizableLinear(56 embed_dim, embed_dim, bias=bias, **factory_kwargs57 )58
59 # 是否在 K 和 V 序列的末尾追加一个额外的可学习向量(bias_k / bias_v)60 if add_bias_kv:61 self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))62 self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))63 else:64 self.bias_k = self.bias_v = None65
66 self.add_zero_attn = add_zero_attn67
68 self._reset_parameters()
2. LN 的实现#
2.1 LayerNorm 类#
LN 在 pytorch/torch/nn/modules/normalization.py 中:
1_shape_t = int | list[int] | Size2
3class LayerNorm(Module):4
5 def __init__(6 self,7 # 定义要归一化的维度的大小8 normalized_shape: _shape_t,9 # 公式中的 ε10 eps: float = 1e-5,11 # 使用 γ 和 β12 elementwise_affine: bool = True,13 bias: bool = True,14 device=None,15 dtype=None,16 ) -> None:17 factory_kwargs = {"device": device, "dtype": dtype}18 super().__init__()19 # 给成员赋值20 self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]21 self.eps = eps22 self.elementwise_affine = elementwise_affine23 if self.elementwise_affine:24 self.weight = Parameter(25 torch.empty(self.normalized_shape, **factory_kwargs)26 )27 if bias:28 self.bias = Parameter(29 torch.empty(self.normalized_shape, **factory_kwargs)30 )31 else:32 self.register_parameter("bias", None)33 else:34 self.register_parameter("weight", None)35 self.register_parameter("bias", None)36
37 self.reset_parameters()38
39 def reset_parameters(self) -> None:40 if self.elementwise_affine:41 # γ 初始值是 142 init.ones_(self.weight)43 if self.bias is not None:44 # β 初始值是 045 init.zeros_(self.bias)46
47 def forward(self, input: Tensor) -> Tensor:48 # 将参数传给 functional.layer_norm()49 return F.layer_norm(50 input, self.normalized_shape, self.weight, self.bias, self.eps51 )2.2 functional.py#
pytorch/torch/nn/functional.py 将底层的高效 C++ 实现包装成优雅的 Python API,供 nn.Module 和开发者调用
1def layer_norm(2 input: Tensor,3 normalized_shape: list[int],4 weight: Tensor | None = None,5 bias: Tensor | None = None,6 eps: float = 1e-5,7) -> Tensor:8 r"""Apply Layer Normalization for last certain number of dimensions.9
10 See :class:`~torch.nn.LayerNorm` for details.11 """12 # 如果自定义了张量 (variadic: 可变参数),将调用交给自定义类的 __torch_function__ 方法13 if has_torch_function_variadic(input, weight, bias):14 return handle_torch_function(15 layer_norm,16 (input, weight, bias),17 input,18 normalized_shape,19 weight=weight,20 bias=bias,21 eps=eps,22 )23 # 否则调用 C++ 实现的高性能版本24 return torch.layer_norm(25 input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled26 )2.3 c++ 实现#
LN 计算底层在 pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu 中实现,aten 文件夹是 PyTorch 的核心底层张量计算库,存放所有基础张量操作的 C++/CUDA 实现
1std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(2 const Tensor& input,3 IntArrayRef normalized_shape,4 const std::optional<Tensor>& weight_opt /* optional */,5 const std::optional<Tensor>& bias_opt /* optional */,6 double eps) {7
8 c10::MaybeOwned<Tensor> weight_maybe_owned =9 at::borrow_from_optional_tensor(weight_opt);10 // 解引用获取 Tensor 引用11 const Tensor& weight = *weight_maybe_owned;12
13 c10::MaybeOwned<Tensor> bias_maybe_owned =14 at::borrow_from_optional_tensor(bias_opt);15 const Tensor& bias = *bias_maybe_owned;16
17 // auto 让编译器自动推断变量的类型18 auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias);19 // M_N 的 第一个和第二个元素20 auto M = M_N.first;21 auto N = M_N.second;22 auto X = input.expect_contiguous();23 auto gamma = weight.expect_contiguous();24 auto beta = bias.expect_contiguous();25
26 // 输出张量 Y27 Tensor Y = at::native::empty_like(...);28 // 均值张量和 rstd(标准差的倒数)张量29 Tensor mean = at::empty(...);30 Tensor rstd = at::empty(...);31 // M > 0 时调用 kernel32 if (M > 0) {33 LayerNormKernelImpl(*X, *gamma, *beta, M, N, eps, &Y, &mean, &rstd);34 }35 ...36 // 计算得到的均值和 rstd 通过 view 改变维度与输入对齐进行广播37 mean = mean.view(stat_shape);38 rstd = rstd.view(stat_shape);39 // 返回 mean, rstd 用于 bp40 return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));41}42
43void LayerNormKernelImpl(...) {44 // 根据输入张量的数据类型,自动生成并调用针对该类型的模板代码45 AT_DISPATCH_FLOATING_TYPES_AND2(46 at::ScalarType::Half,47 at::ScalarType::BFloat16,48 X.scalar_type(),49 "LayerNormKernelImpl",50 [&]() {51 using acc_t = acc_type<scalar_t, true>;52 // 执行计算53 LayerNormKernelImplInternal<scalar_t, acc_t>(54 X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);55 });56}57
58// 最后在 LayerNormKernelImplInternal 中调用 RowwiseMomentsCUDAKernel 和 LayerNormForwardCUDAKernel 计算
3. encoder 的实现#
encoder 和 decoder 都定义在 pytorch/torch/nn/modules/transformer.py 中。
3.1 TransformerEncoderLayer 类#
1class TransformerEncoderLayer(Module):2
3 def __init__(4 self,5 d_model: int,6 nhead: int,7 # FFN 的维度8 dim_feedforward: int = 2048,9 dropout: float = 0.1,10 # 激活函数11 activation: str | Callable[[Tensor], Tensor] = F.relu,12 layer_norm_eps: float = 1e-5,13 batch_first: bool = False,14 # Pre-LN 还是 Post-LN15 norm_first: bool = False,16 bias: bool = True,17 device=None,18 dtype=None,19 ) -> None:20 factory_kwargs = {"device": device, "dtype": dtype}21 super().__init__()22 # 自注意力23 self.self_attn = MultiheadAttention(...)24
25 # FFN26 self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)27 self.dropout = Dropout(dropout)28 self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)29
30 # LN31 self.norm_first = norm_first32 # norm1 用于自注意力,norm2 用于 FFN33 self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)34 self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)35 self.dropout1 = Dropout(dropout)36 self.dropout2 = Dropout(dropout)37 # ... (略去激活函数处理细节)38 # 定义激活函数39 self.activation = activation40
41 def forward(42 self,43 src: Tensor,44 # 注意力掩码45 src_mask: Tensor | None = None,46 # pad 掩码47 src_key_padding_mask: Tensor | None = None,48 is_causal: bool = False,49 ) -> Tensor:50 # 将传入的 mask 格式规范化51 src_key_padding_mask = F._canonical_mask(...)52 src_mask = F._canonical_mask(...)53 # 计算逻辑54 x = src55 # Pre-LN56 if self.norm_first:57 # 注意力计算58 x = x + self._sa_block(59 self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal60 )61 x = x + self._ff_block(self.norm2(x))62 else:63 # Post-LN64 x = self.norm1(65 x66 + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)67 )68 # FFN69 x = self.norm2(x + self._ff_block(x))70
71 return x72
73 def _sa_block(...) -> Tensor:74 # 上面定义了 self.self_attn = MultiheadAttention(...)75 x = self.self_attn(76 # 自注意力 qkv 都是 x77 x,78 x,79 x,80 attn_mask=attn_mask,81 key_padding_mask=key_padding_mask,82 need_weights=False,83 is_causal=is_causal,84 )[0]85 return self.dropout1(x)86
87 def _ff_block(self, x: Tensor) -> Tensor:88 x = self.linear2(self.dropout(self.activation(self.linear1(x))))89 return self.dropout2(x)3.2 TransformerEncoder 类#
通过 TransformerEncoder 类对 TransformerEncoderLayer 进行堆叠:
1class TransformerEncoder(Module):2 def __init__(3 self,4 encoder_layer: "TransformerEncoderLayer",5 num_layers: int,6 norm: Module | None = None,7 enable_nested_tensor: bool = True,8 mask_check: bool = True,9 ) -> None:10 super().__init__()11 # 创建 num_layers 个编码器层的深拷贝12 self.layers = _get_clones(encoder_layer, num_layers)13 self.num_layers = num_layers14 self.norm = norm15 self.mask_check = mask_check16
17 def forward(18 self,19 src: Tensor,20 mask: Tensor | None = None,21 src_key_padding_mask: Tensor | None = None,22 is_causal: bool | None = None,23 ) -> Tensor:24
25 seq_len = _get_seq_len(src, batch_first)26
27 for mod in self.layers:28 # 循环调用 model 的 forward 方法29 output = mod(30 output,31 src_mask=mask,32 is_causal=is_causal,33 src_key_padding_mask=src_key_padding_mask_for_layers,34 )35
36 return output
4. decoder 的实现#
TransformerDecoder 堆叠 layer 的方式与 TransformerEncoder 类似,此处略去,下面只看单层 TransformerDecoderLayer 的实现:
1class TransformerDecoderLayer(Module):2
3 def __init__(4 self,5 d_model: int,6 nhead: int,7 # FFN 的中间维度8 dim_feedforward: int = 2048,9 ...10 ) -> None:11 self.self_attn = MultiheadAttention(...)12 self.multihead_attn = MultiheadAttention(...)13
14 # 定义 FFN 和 残差15 # 跟 encoder 相比多一个交叉注意力,因此需要 norm3 和 dropout316
17 def forward(18 self,19 # decoder 的输入 编码序列20 tgt: Tensor,21 # encoder 输出22 memory: Tensor,23 # 屏蔽未来 token24 tgt_mask: Tensor | None = None,25 # 用于交叉注意力26 memory_mask: Tensor | None = None,27 # 屏蔽目标序列的 <pad>28 tgt_key_padding_mask: Tensor | None = None,29 # 屏蔽源序列 <pad>30 memory_key_padding_mask: Tensor | None = None,31 ) -> Tensor:32
33 x = tgt34 if self.norm_first:35 # mmha36 x = x + self._sa_block(37 self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal38 )39 # 交叉注意力40 x = x + self._mha_block(41 self.norm2(x),42 memory,43 memory_mask,44 memory_key_padding_mask,45 memory_is_causal,46 )47 x = x + self._ff_block(self.norm3(x))48 else:49 x = self.norm1(50 x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)51 )52 x = self.norm2(53 x54 + self._mha_block(55 x, memory, memory_mask, memory_key_padding_mask, memory_is_causal56 )57 )58 x = self.norm3(x + self._ff_block(x))59
60 return x61
62 def _mha_block(63 self,64 x: Tensor,65 mem: Tensor,66 attn_mask: Tensor | None,67 key_padding_mask: Tensor | None,68 is_causal: bool = False,69 ) -> Tensor:70 x = self.multihead_attn(71 x,72 # kv 是 encoder 输出73 mem,74 mem,75 attn_mask=attn_mask,76 key_padding_mask=key_padding_mask,77 is_causal=is_causal,78 need_weights=False,79 )[0]80 return self.dropout2(x)5. 组件整合#
Transformer 类是整个模型的顶层组装器,它负责将编码器和解码器组合成一个完整的端到端模型:
1class Transformer(Module):2
3 def __init__(4 self,5 # 默认值6 d_model: int = 512,7 nhead: int = 8,8 num_encoder_layers: int = 6,9 num_decoder_layers: int = 6,10 dim_feedforward: int = 2048,11 dropout: float = 0.1,12 activation: str | Callable[[Tensor], Tensor] = F.relu,13 custom_encoder: Any | None = None,14 custom_decoder: Any | None = None,15 layer_norm_eps: float = 1e-5,16 batch_first: bool = False,17 norm_first: bool = False,18 bias: bool = True,19 device=None,20 dtype=None,21 ) -> None:22
23 # 搭建网络24 if custom_encoder is not None:25 else:26 encoder_layer = TransformerEncoderLayer(27 d_model,28 nhead,29 dim_feedforward,30 dropout,31 activation,32 layer_norm_eps,33 batch_first,34 norm_first,35 bias,36 **factory_kwargs,37 )38 encoder_norm = LayerNorm(39 d_model,40 eps=layer_norm_eps,41 bias=bias,42 **factory_kwargs,43 )44 # 创建 encoder45 self.encoder = TransformerEncoder(46 encoder_layer, num_encoder_layers, encoder_norm47 )48
49 if custom_decoder is not None:50 else:51 decoder_layer = TransformerDecoderLayer(52 d_model,53 nhead,54 dim_feedforward,55 dropout,56 activation,57 layer_norm_eps,58 batch_first,59 norm_first,60 bias,61 **factory_kwargs,62 )63 decoder_norm = LayerNorm(64 d_model,65 eps=layer_norm_eps,66 bias=bias,67 **factory_kwargs,68 )69 # 创建 decoder70 self.decoder = TransformerDecoder(71 decoder_layer, num_decoder_layers, decoder_norm72 )73
74 self._reset_parameters()75
76 self.d_model = d_model77 self.nhead = nhead78
79 self.batch_first = batch_first80
81 def forward(82 self,83 src: Tensor,84 tgt: Tensor,85 src_mask: Tensor | None = None,86 tgt_mask: Tensor | None = None,87 memory_mask: Tensor | None = None,88 src_key_padding_mask: Tensor | None = None,89 tgt_key_padding_mask: Tensor | None = None,90 memory_key_padding_mask: Tensor | None = None,91 src_is_causal: bool | None = None,92 tgt_is_causal: bool | None = None,93 memory_is_causal: bool = False,94 ) -> Tensor:95
96 # encoder 输出, 调用 encoder 的 forward()97 memory = self.encoder(98 src,99 mask=src_mask,100 src_key_padding_mask=src_key_padding_mask,101 is_causal=src_is_causal,102 )103
104 # decoder 输出105 output = self.decoder(106 tgt,107 memory,108 tgt_mask=tgt_mask,109 memory_mask=memory_mask,110 tgt_key_padding_mask=tgt_key_padding_mask,111 memory_key_padding_mask=memory_key_padding_mask,112 tgt_is_causal=tgt_is_causal,113 memory_is_causal=memory_is_causal,114 )115 return outputComments
Site Statistics
144
6
9
2,255,454
0 days
0 days ago
2026年6月
Less More
日
一
二
三
四
五
六
Table of Contents