DOI: 10.48550/arXiv.1803.02155
Date: 2018/04/12
1 问题#
非循环模型无法天然按序遍历输入元素,因此需要额外显式编码位置信息,以表征序列先后关系。
Transformer 与 RNN 和 CNN 相比,并没有显式地对相对或绝对信息进行建模,而是需要在输入中加入绝对位置的表示。
本文提出了一个替代方案,对注意力机制进行扩展,从而有效地考虑相对位置,或句子元素间的距离。
2 符号体系#
每个注意力头在包含 n 个元素的输入 x=(x1,…,xn) 上进行运算,其中 xi∈Rdx,生成大小相同的序列 z=(z1,…,zn),其中 zi∈Rdz。每个输出元素 zi 由经过线性变换后的输入元素加权求和得到:
zi=j=1∑nαij(xjWV)(1)教材中将 xjWV 直接写作 hj
每个权重系数 αij 由 softmax 函数得到:
αij=∑k=1nexpeikexpeij eij 由用于比对两个输入单元的匹配函数计算得到:
eij=dz(xiWQ)(xjWK)⊤(2)
3 改进#
3.1 关系感知自注意力#
本文提出一种自注意力拓展方案,用以建模输入单元间的成对关联;在此思路下,将输入数据建模为带标签、有向、全连接图结构。
输入元素 xi 和 xj 的边通过向量 aijV,aijK∈Rda 表示。aijK、aijV可分别直接用于式(3)(4),无需额外线性变换;边表征在所有注意力头共享参数,实验取 da=dz。
我们对式(1) 进行改写,将边信息传播到子层的输出中:
zi=j=1∑nαij(xjWV+aijV)(3) 该改进对如下任务至关重要:注意力头筛选出的边类型信息可被后续编码器/解码器复用;但实验表明,机器翻译任务中该结构并非必需。
同样对式(2) 进行改写:
eij=dzxiWQ(xjWK+aijK)⊤(4)3.2 相对位置表示#
对线性序列来说,边可用来建模输入单元间的相对位置差信息,最大相对位置在 k 处进行截断。假设超过指定距离后相对位置信息无建模价值,对最大距离进行截断可以让模型对训练集未出现过的超长序列具备泛化能力。因此考虑 2k+1 个独特的边标签:
aijKaijVclip(x,k)=wclip(j−i,k)K=wclip(j−i,k)V=max(−k,min(k,x)) 接下来我们可以学习相对位置表征 wK=(w−kK,…,wkK) 和 wV=(w−kV,…,wkV),其中 wiK,wiV∈Rda。
即 wK∈R(2k+1)×da,从 2k+1 行中挑出索引为 clip(j−i,k) 的一行
3.3 高效实现#
设序列长度为 n、注意力头数 h,通过多头共享相对位置表征,相对位置表示的存储复杂度由 O(hn2da) 降至 O(n2da);此外该表征可跨序列共享。因此自注意力整体空间复杂度从 O(bhndz) 变为 O(bhndz+n2da)。在 da=dz 条件下,空间相对增量由 bhn 决定。
序列长度为 n,(i,j) 的配对个数为 n2
单个 aijK/aijV 的维度是 da
最后乘上注意力头数得到 hn2da,共享后不用乘 h
注意力复杂度:batch×h×n×xiWQ 的维度 =bhndz
相对增量=n2da/bhndz=n/bh
Transformer 利用并行矩阵乘法,高效计算批次内所有序列、注意力头与位置的自注意力。在不使用相对位置表征时,所有 eij 可通过 bh 次 n×dz 与 dz×n 矩阵的并行相乘得到。
引入相对位置后,位置配对不同则表征不同,无法通过单次矩阵乘法一次性算出全部 eij,同时也需要规避相对位置表征的广播运算。将式(4) 拆分为两项即可解决上述两个问题:
eij=dzxiWQ(xjWK)⊤+xiWQ(aijK)⊤(5)第一项和式(2) 相同,可沿用前述常规矩阵乘法计算。对于含相对位置表征的第二项,通过张量变形,能够执行 n 组并行矩阵乘法(bh×dz与dz×n相乘);单次矩阵乘法可算出单个 token 位置、全部批次与注意力头对应的 eij 增量。再次张量变形后即可合并两项结果,该方法同样适用于式(3) 的高效求解。
Q∈Rbh×n×dz 与 A∈Rn×n×dz 做运算时会进行广播,也就是计算 xiWQ(xjWK+aijK)⊤ 时,括号外的 i 需要同时与括号内的 i,j 进行广播
将两项拆开,只需在每个 i 上与 aijK 计算即可
4 结论#
- 数据集:
WMT14 英德 (450 万句对)、英法 (3600 万句对),词表 32768。单卡单批限制 4096 词,整批源 / 目标端各约 25k token
- 训练通用设置:
Adam 优化器、4000 步学习率预热、标签平滑 0.1;解码束搜索 beam=4、长度惩罚 0.6
- 参数:编解码器各 6 层,dx=512,dz=64,h=8,FNN=1024,dropout=0.1
与采用正弦位置编码的原始 Transformer 基线对比:
表 1: 基于 WMT2014 英德、英法翻译任务,在 newstest2014 官方测试集上输出实验结果消融实验: