Transformer-XL 的源码 [LLM]
2286 words
11 minutes
Transformer-XL 的源码 [LLM]
Image API Error
这节我们来学习 Transformer-XL 的源码,模型的论文内容在之前已有介绍 👉 taffybook.cn/posts/paper/transformer-xl/ ,其核心类如下:
- 自适应词嵌入 — AdaptiveEmbedding
- 位置编码 — PositionalEmbedding
- 注意力层 — RelPartialLearnableMultiHeadAttn
  ├── qkv_net: Linear   # W_q, W_{k,E}, W_v(联合投影)
  ├── r_net: Linear     # W_{k,R}(位置编码投影)
  ├── o_net: Linear     # 输出投影
  └── _rel_shift()      # 相对位置偏移(关键操作)
- decoder — RelPartialLearnableDecoderLayer
- 前馈网络 — PositionwiseFF
- 记忆机制
  ├── init_mems()       # 初始化记忆
  └── _update_mems()    # 更新记忆
- 输出 — ProjectedAdaptiveLogSoftmax
1 词嵌入#
自适应嵌入可以根据词频动态分配嵌入维度来解决传统嵌入的参数效率问题
- 解决低频词维度过多,高频词维度可能不够的问题
class AdaptiveEmbedding(nn.Module):
    """Embedding layer that can use a different width per vocabulary group.

    With ``div_val == 1`` a single embedding table of width ``d_embed`` is
    used (followed by a projection only when ``d_proj != d_embed``).
    Otherwise the vocabulary is split at ``cutoffs`` into consecutive groups;
    group ``i`` gets a table of width ``d_embed // div_val ** i`` and its own
    projection to ``d_proj``.

    Args:
        n_token: Vocabulary size.
        d_embed: Embedding width of the first group.
        d_proj: Width of the returned embeddings.
        cutoffs: List of token-id boundaries between groups; ``n_token`` is
            appended internally as the final boundary.
        div_val: Divisor applied to the embedding width from one group to
            the next.
        sample_softmax: When greater than 0, the single shared table
            (``div_val == 1``) is created with ``sparse=True``.

    Note:
        The projection matrices in ``emb_projs`` are allocated without
        initialization; this class never fills them, so the caller has to
        initialize them before use.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        # Right edge of every group: the given cutoffs plus n_token.
        self.cutoffs = cutoffs + [n_token]
        # Same list with a leading 0, so consecutive pairs are [left, right).
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val
        self.d_proj = d_proj
        self.emb_scale = d_proj ** 0.5
        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()

        if div_val == 1:
            # Every token shares one table of width d_embed.
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            # A projection is only needed when the widths differ.
            if d_proj != d_embed:
                self.emb_projs.append(
                    nn.Parameter(torch.empty(d_proj, d_embed))
                )
        else:
            # cutoff_ends[i] / cutoffs[i] are the left / right edge of group i.
            group_bounds = zip(self.cutoff_ends, self.cutoffs)
            for i, (l_idx, r_idx) in enumerate(group_bounds):
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(
                    nn.Parameter(torch.empty(d_proj, d_emb_i))
                )

    def forward(self, inp):
        """Embed the integer tensor ``inp``.

        Returns a tensor of shape ``(*inp.shape, d_proj)`` scaled by
        ``sqrt(d_proj)``.
        """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            # Any parameter will do; it is only read for dtype and device.
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros(
                [inp_flat.size(0), self.d_proj],
                dtype=param.dtype, device=param.device,
            )
            group_bounds = zip(self.cutoff_ends, self.cutoffs)
            for i, (l_idx, r_idx) in enumerate(group_bounds):
                # Positions whose token id falls into group i.
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(1) keeps the index 1-D even when exactly one
                # position matches; a bare squeeze() would make it 0-dim.
                indices_i = mask_i.nonzero().squeeze(1)

                if indices_i.numel() == 0:
                    continue

                # Convert global token ids to ids local to this group's table.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])
                # Scatter the group's rows back to their flat positions.
                emb_flat.index_copy_(0, indices_i, emb_i)

            # Restore the caller's shape with d_proj as the last dimension.
            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed


# --- Article section heading: 2 Implementing relative positional encoding ---
Transformer-XL 中用到了原版的正弦位置编码,对应的类是 PositionalEmbedding,相对位置编码的实现在注意力中:
2.1 RelMultiHeadAttn 类#
RelMultiHeadAttn 是另外两个相对多头注意力类 RelPartialLearnableMultiHeadAttn 和 RelLearnableMultiHeadAttn 的父类,负责定义它们共用的成员(QKV 投影、输出投影、Dropout、LayerNorm、缩放系数)以及 _rel_shift 方法:
class RelMultiHeadAttn(nn.Module):
    """Shared state and helpers for relative multi-head attention layers.

    Args:
        n_head: Number of attention heads.
        d_model: Model (input/output) width.
        d_head: Width of each head.
        dropout: Dropout probability for ``self.drop``.
        dropatt: Dropout probability applied to attention weights.
        tgt_len, ext_len, mem_len: Accepted but not used in this class.
        pre_lnorm: Stored flag selecting pre- or post-layer-norm; it is only
            recorded here.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # One linear layer produces Q, K and V together.
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        # 1 / sqrt(d_head)
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm

    def _rel_shift(self, x, zero_triu=False):
        """Shift each query row of ``x`` so scores line up by relative offset.

        ``x`` is indexed as ``[qlen, klen, ...]``; trailing dimensions are
        carried through unchanged. The result has the same shape as ``x``.

        Args:
            x: Score tensor with at least four dimensions when
                ``zero_triu`` is used (the mask is broadcast over two
                trailing dimensions).
            zero_triu: When true, multiply the result by a lower-triangular
                mask with diagonal offset ``klen - qlen``.
        """
        qlen, klen = x.size(0), x.size(1)
        trailing = x.size()[2:]

        # Prepend one zero column to every query row.
        zero_pad = torch.zeros((qlen, 1, *trailing),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        # Reinterpret the buffer with the first two sizes swapped; because of
        # the extra column, each original row ends up shifted by a different
        # amount.
        x_padded = x_padded.view(klen + 1, qlen, *trailing)

        # Drop the first row and restore the original shape.
        x = x_padded[1:].view_as(x)

        if zero_triu:
            # new_ones keeps the mask on x's device and in x's dtype; a plain
            # torch.ones() would live on the default device.
            ones = x.new_ones((qlen, klen))
            x = x * torch.tril(ones, klen - qlen)[:, :, None, None]

        return x


# --- Article section heading: 2.2 The RelPartialLearnableMultiHeadAttn class ---
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention that projects a positional encoding with ``r_net``.

    NOTE(review): the listing this class comes from stops after the
    attention scores are scaled. ``attn_mask``, ``dropatt``, ``o_net`` and
    the value heads are not used in the visible portion and ``forward`` has
    no ``return`` here; the remaining steps are not reproduced.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Projection applied to the positional encoding r.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head,
                               bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        """Compute the scaled relative attention scores.

        Args:
            w: Current-segment input, indexed ``[qlen, bsz, ...]``.
            r: Positional encoding, indexed ``[rlen, ...]``.
            r_w_bias: Bias added to the queries for the content term
                (``u`` in the article's notation).
            r_r_bias: Bias added to the queries for the position term
                (``v`` in the article's notation).
            attn_mask: Unused in the visible portion of the listing.
            mems: Optional memory from the previous segment; when given it is
                concatenated in front of ``w`` along dimension 0.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys and values see memory + current segment when memory exists.
        kv_input = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            kv_input = self.layer_norm(kv_input)
        w_heads = self.qkv_net(kv_input)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        # Queries come from the current segment only. Without memory the
        # slice covers the whole tensor, so it is a no-op.
        w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        # Split the last dimension into (n_head, d_head).
        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)

        # Content term: (query + r_w_bias) dotted with the keys over d.
        # out[i][j][b][n] = sum_d(A[i][b][n][d] * B[j][b][n][d])
        rw_head_q = w_head_q + r_w_bias
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))

        # Position term: (query + r_r_bias) dotted with the projected
        # positional encoding, then shifted to relative offsets.
        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))
        BD = self._rel_shift(BD)

        attn_score = AC + BD
        attn_score.mul_(self.scale)


# --- Article section heading: 3 The main class ---
了解了嵌入和相对编码的实现后,我们最后来看 MemTransformerLM 类,记忆机制也在该类中实现。Transformer-XL 是 decoder-only 结构,不包含 encoder
3.1 记忆机制#
def init_mems(self):
    """Return one empty memory tensor per layer plus one, or ``None``.

    ``None`` is returned when ``self.mem_len`` is not positive. Otherwise the
    list holds ``self.n_layer + 1`` zero-length tensors that share dtype and
    device with the first model parameter; index 0 is for the embedding
    output and index ``i`` for the output of decoder layer ``i``.
    """
    if not self.mem_len > 0:
        return None

    # The first parameter is only consulted for its dtype and device.
    ref = next(self.parameters())
    return [
        torch.empty(0, dtype=ref.dtype, device=ref.device)
        for _ in range(self.n_layer + 1)
    ]

13 def _update_mems(self, hids, mems, qlen, mlen):14 # 未使用 mem 不更新15 if mems is None: return None16
17 # mems is not None18 assert len(hids) == len(mems),19
20 # 记忆来自之前的 seg,不需要更新梯度21 with torch.no_grad():22 new_mems = []23 # mlen:旧记忆长度, qlen: 当前段长度24 # 需要保留的记忆的索引边界25 # qlen - self.ext_len,ext_len 是留给下一个段扩展用26 end_idx = mlen + max(0, qlen - self.ext_len)27 beg_idx = max(0, end_idx - self.mem_len)28 for i in range(len(hids)):29 # 新旧记忆拼接,[mlen+qlen, bsz, d_model]30 cat = torch.cat([mems[i], hids[i]], dim=0)31 # 保留 mlen 长度的记忆32 new_mems.append(cat[beg_idx:end_idx].detach())33 return new_mems3.2 类初始化#
class MemTransformerLM(nn.Module):
    """Language model assembled from an adaptive embedding, a stack of
    decoder layers and an output criterion.

    Parts of the constructor are elided in the listing (marked with ``...``);
    attributes such as ``self.attn_type``, ``self.n_layer``, ``self.d_model``,
    ``self.n_head`` and ``self.d_head`` are read elsewhere but their
    assignments are not visible here — presumably they sit in the elided
    part; confirm against the full source.
    """

    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        super().__init__()
        ...  # elided in the listing
        d_embed = d_model if d_embed is None else d_embed
        # Token embedding, projected to d_model.
        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)
        self.drop = nn.Dropout(dropout)
        # Longest key length: target + extension + memory.
        # NOTE(review): the three defaults are None, which this sum cannot
        # handle — callers apparently always pass integers; confirm.
        self.max_klen = tgt_len + ext_len + mem_len
        self.layers = nn.ModuleList()

        if attn_type == 0:  # the default attention
            for _ in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1:  # learnable embeddings
            # Uses RelLearnableDecoderLayer (elided in the listing).
            ...
        elif attn_type in [2, 3]:  # absolute embeddings
            # Uses the plain DecoderLayer (elided in the listing).
            ...

        self.sample_softmax = sample_softmax
        if sample_softmax > 0:
            # Sampled softmax: score only the target plus a random sample of
            # the vocabulary instead of every token.
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                # NOTE(review): the AdaptiveEmbedding class shown in this
                # article defines emb_layers/emb_projs but no `weight`
                # attribute, so this line looks like it would raise
                # AttributeError — confirm against the full source.
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)
        else:
            # Adaptive softmax output layer.
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)
            if tie_weight:
                # Share each output table with the matching embedding table.
                for i, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = self.word_emb.emb_layers[i].weight
            if tie_projs:
                # Share projection matrices where the flag for group i is set.
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        # Selects the attention-mask variant built in _forward.
        self.same_length = same_length
        # Upper bound applied to relative positions (only when > 0).
        self.clamp_len = clamp_len
        self._create_params()

    def _create_params(self):
        """Create the parameters that depend on the attention type."""
        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            # Allocated uninitialized; nothing in this class fills them.
            self.r_w_bias = nn.Parameter(torch.empty(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.empty(self.n_head, self.d_head))
        elif self.attn_type == 1:  # learnable
            ...
        elif self.attn_type == 2:  # absolute standard
            ...
        elif self.attn_type == 3:  # absolute deeper SA
            ...


# --- Article section heading: 3.3 forward() ---
def forward(self, data, target, *mems):
    """Run one segment and return ``[loss]`` followed by the new memory.

    Args:
        data: Input passed to ``self._forward``.
        target: Targets; its first dimension gives the number of trailing
            hidden states that are scored.
        *mems: Memory tensors from the previous call; when none are given,
            ``self.init_mems()`` supplies the initial value.

    Returns:
        ``[loss]`` when ``_forward`` reports no new memory, otherwise
        ``[loss] + new_mems``.
    """
    memory = mems if mems else self.init_mems()

    n_target = target.size(0)
    states, updated_mems = self._forward(data, mems=memory)

    # Only the last n_target positions are scored.
    tail = states[-n_target:]

    use_sampling = self.sample_softmax > 0 and self.training
    if not use_sampling:
        flat_loss = self.crit(tail.view(-1, tail.size(-1)), target.view(-1))
        loss = flat_loss.view(n_target, -1)
    else:
        assert self.tie_weight
        logit = sample_logits(self.word_emb, self.out_layer.bias, target,
                              tail, self.sampler)
        loss = -F.log_softmax(logit, -1)[:, :, 0]

    outputs = [loss]
    if updated_mems is not None:
        outputs = outputs + updated_mems
    return outputs

24
25def _forward(self, dec_inp, mems=None):26
27 # 输入大小的获取28 qlen, bsz = dec_inp.size()29 # 词嵌入30 word_emb = self.word_emb(dec_inp)31
32 mlen = mems[0].size(0) if mems is not None else 033 # klen = 记忆长度 + 当前段长度34 klen = mlen + qlen35
36 # 标准因果掩码,通过 triu 只能看到当前位置及之前的内容37 if self.same_length:38 # qlen × klen 的全 1 矩阵39 all_ones = word_emb.new_ones(qlen, klen)40 # 当前段长度41 mask_len = klen - self.mem_len42 if mask_len > 0:43 # 需要遮蔽的左边列数44 mask_shift_len = qlen - mask_len45 else:46 mask_shift_len = qlen47 # 1+mlen 对角线向右移动,-mask_shift_len 向左移动48 # j - i >= 1 + mlen ,j - i <= -mask_shift_len 的位置不可见49 dec_attn_mask = (torch.triu(all_ones, 1+mlen) + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -150 else:51 # 标准掩码只看过去52 dec_attn_mask = torch.triu(53 word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]54
55 hids = []56 if self.attn_type == 0: # default57 # 倒序的位置序列 [klen-1, klen-2,...,0]58 pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,59 dtype=word_emb.dtype)60 # 截断,让模型专注于近期的位置关系61 if self.clamp_len > 0:62 pos_seq.clamp_(max=self.clamp_len)63
64 # 位置编码65 pos_emb = self.pos_emb(pos_seq)66 # 词嵌入 + dropout67 core_out = self.drop(word_emb)68 # 位置编码 + dropout69 pos_emb = self.drop(pos_emb)70 # 保存嵌入层输出71 hids.append(core_out)72 for i, layer in enumerate(self.layers):73 mems_i = None if mems is None else mems[i]74 core_out = layer(core_out, pos_emb, self.r_w_bias,75 self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)76 hids.append(core_out)77
78 # 处理其它类型注意力79 ...80
81 # 最后一层的 dropout82 core_out = self.drop(core_out)83 # 更新记忆84 new_mems = self._update_mems(hids, mems, mlen, qlen)85 return core_out, new_memsComments
Site Statistics
144
6
9
2,255,454
0 days
0 days ago
2026年6月
Less More
日
一
二
三
四
五
六
Table of Contents