2.2 Linear Attention# 和诸多实用的序列建模方案一样,研究人员也十分关注线性模型的构建,以此提升长序列的处理速度。线性模型有多种定义形式,序列模型中常用的通用形式如下:
z i = f ( a ⋅ z i − 1 + b ⋅ s i ) (4) \mathbf{z}_i = f(a \cdot \mathbf{z}_{i-1} + b \cdot \mathbf{s}_i) \tag{4} z i = f ( a ⋅ z i − 1 + b ⋅ s i ) ( 4 ) 其中 s i \mathbf{s}_i s i 表示模型在第 i i i 步的中间状态,z i \mathbf{z}_i z i 表示截至第 i i i 步的历史状态汇总。显然这是一种循环模型:第 i i i 步的输出仅依赖当前步输入与上一步输出。和主流神经网络设计一致,线性计算后会再接变换函数 f ( ⋅ ) f(\cdot) f ( ⋅ ) ,它可以是激活函数或前馈网络。需要注意:仅当 f ( ⋅ ) f(\cdot) f ( ⋅ ) 为线性函数时,上式才是标准线性模型。引入 f ( ⋅ ) f(\cdot) f ( ⋅ ) 能提升建模灵活性,但若其为非线性形式,该模型便不再属于线性模型范畴。
只需对当前输入与上一时刻状态做线性变换,即可将 式 (4) 拓展为标准循环神经网络模型:z i = f ( z i − 1 W z + s i W s ) \mathbf{z}_i = f(\mathbf{z}_{i-1}\mathbf{W}_z + \mathbf{s}_i\mathbf{W}_s) z i = f ( z i − 1 W z + s i W s ) 。由此可很方便地将循环神经网络及其变体融入 Transformer,构建混合模型。
实际上,研究人员更倾向于构建线性注意力模型,在保证全局序列建模能力的同时,提升整体运行效率。实现该目标的一大难点在于,标准自注意力本身并非线性形式。先回顾自注意力的表达式:
A t t s e l f = A ⋅ V = ψ ( Q ⋅ K ⊤ ) ⋅ V \begin{align*}
\mathrm{Att}_{\mathrm{self}} &= \mathbf{A} \cdot \mathbf{V} \\
&= \psi(\mathbf{Q} \cdot \mathbf{K}^\top) \cdot \mathbf{V} \tag{5}
\end{align*} Att self = A ⋅ V = ψ ( Q ⋅ K ⊤ ) ⋅ V ( 5 ) 其中 ψ ( ⋅ ) \psi(\cdot) ψ ( ⋅ ) 是一个由缩放、指数运算、掩码和归一化组成的复合函数(即 ψ ( a ) = Normalize ( Mask ( exp ( a d ) ) ) \psi(a) = \text{Normalize}(\text{Mask}(\exp(\frac{a}{\sqrt{d}}))) ψ ( a ) = Normalize ( Mask ( exp ( d a ))) )。由于 ψ ( ⋅ ) \psi(\cdot) ψ ( ⋅ ) 是复杂的非线性函数,无法直接简化计算,通常需要分别执行两次矩阵乘法(一次在 ψ ( ⋅ ) \psi(\cdot) ψ ( ⋅ ) 内部,一次在外部)。这就要求显式存储所有键值对,并在处理每个查询时逐一访问,因此自注意力模型的计算复杂度会随序列长度 n n n 呈二次增长。
尽管自注意力中键与值是成对存在的,但二者在计算中是分步使用的。一种巧妙的解法是将查询与键值交互解耦,使上下文信息的编码独立于具体的查询。其核心思路是:通过对查询和键应用特征映射函数 ϕ ( ⋅ ) \phi(\cdot) ϕ ( ⋅ ) ,绕过非线性的 Softmax 运算。我们将 Q \mathbf{Q} Q 和 K \mathbf{K} K 变换为:Q ′ = ϕ ( Q ) ∈ R n × d ′ , K ′ = ϕ ( K ) ∈ R n × d ′ \mathbf{Q}' = \phi(\mathbf{Q}) \in \mathbb{R}^{n \times d'}, \quad \mathbf{K}' = \phi(\mathbf{K}) \in \mathbb{R}^{n \times d'} Q ′ = ϕ ( Q ) ∈ R n × d ′ , K ′ = ϕ ( K ) ∈ R n × d ′ ,此时,注意力机制可重新表示为:
A t t s e l f ≡ ψ ′ ( Q ′ ⋅ K ′ ⊤ ) ⋅ V = D − 1 ( Q ′ ⋅ K ′ ⊤ ) ⋅ V = D − 1 ( Q ′ ⋅ ( K ′ ⊤ ⋅ V ) ) \begin{align*}
\mathrm{Att}_{\mathrm{self}} &\equiv \psi'(\mathbf{Q}' \cdot \mathbf{K}'^\top) \cdot \mathbf{V} \\ &= \mathbf{D}^{-1} (\mathbf{Q}' \cdot \mathbf{K}'^\top) \cdot \mathbf{V} \\
&= \mathbf{D}^{-1} \left( \mathbf{Q}' \cdot (\mathbf{K}'^\top \cdot \mathbf{V}) \right) \tag{6}
\end{align*} Att self ≡ ψ ′ ( Q ′ ⋅ K ′⊤ ) ⋅ V = D − 1 ( Q ′ ⋅ K ′⊤ ) ⋅ V = D − 1 ( Q ′ ⋅ ( K ′⊤ ⋅ V ) ) ( 6 ) 其中 D \mathbf{D} D 是对角归一化矩阵,且 ψ ′ ( Z ) = D − 1 Z \psi'(\mathbf{Z}) = \mathbf{D}^{-1}\mathbf{Z} ψ ′ ( Z ) = D − 1 Z 。在该变换空间中,查询-键点积无需再经过指数形式的 Softmax 归一化,仅需乘以 D − 1 \mathbf{D}^{-1} D − 1 完成归一化即可。由于此时核心运算为纯矩阵乘法,我们可以利用矩阵乘法的结合律改变计算顺序。
由此可以得到一种高效的自回归形式:首先通过 K ′ ⊤ ⋅ V \mathbf{K}'^\top \cdot \mathbf{V} K ′⊤ ⋅ V 对键和值进行聚合,再让查询对这个全局表示进行注意力计算。由于 K ′ ⊤ ⋅ V = ∑ j = 1 n k ′ j ⊤ ⋅ v j \mathbf{K}'^\top \cdot \mathbf{V} = \sum_{j=1}^{n} \mathbf{k}'{_j}^\top \cdot \mathbf{v}_j K ′⊤ ⋅ V = ∑ j = 1 n k ′ j ⊤ ⋅ v j ,我们可以将 K ′ ⊤ ⋅ V \mathbf{K}'^\top \cdot \mathbf{V} K ′⊤ ⋅ V 写成式 (4) 的形式,如下所示:
μ j = μ j − 1 + k ′ j ⊤ ⋅ v j (7) \mu_j = \mu_{j-1} + \mathbf{k}'{_j}^\top \cdot \mathbf{v}_j \tag{7} μ j = μ j − 1 + k ′ j ⊤ ⋅ v j ( 7 ) 其中 μ j ∈ R d ′ × d \mu_j \in \mathbb{R}^{d' \times d} μ j ∈ R d ′ × d 是一个变量,它每次累加一项 k ′ j ⊤ ⋅ v j \mathbf{k}'{_j}^\top \cdot \mathbf{v}_j k ′ j ⊤ ⋅ v j 。同理,我们可以定义另一个变量 ν j ∈ R d ′ \nu_j \in \mathbb{R}^{d'} ν j ∈ R d ′
ν j = ν j − 1 + k ′ j ⊤ (8) \nu_j = \nu_{j-1} + \mathbf{k}'{_j}^\top \tag{8} ν j = ν j − 1 + k ′ j ⊤ ( 8 ) 据此,第 j j j 个查询对应的自注意力输出可表示为:
A t t s e l f , j = q j ′ ⋅ μ n q j ′ ⋅ ν n (9) \mathrm{Att}_{\mathrm{self},j} = \frac{\mathbf{q}'_j \cdot \mu_n}{\mathbf{q}'_j \cdot \nu_n} \tag{9} Att self , j = q j ′ ⋅ ν n q j ′ ⋅ μ n ( 9 ) 显然,该模型实现了线性时间复杂度,因为 μ n \mu_n μ n 和 ν n \nu_n ν n 随序列长度 n n n 呈线性增长。在该模型的简单实现中,只需维护 μ j \mu_j μ j 和 ν j \nu_j ν j 两个状态变量即可。每当遇到新的查询时,我们更新 μ j \mu_j μ j 和 ν j \nu_j ν j ,然后计算:A t t s e l f , j = q j ′ ⋅ μ j q j ′ ⋅ ν j \mathrm{Att}_{\mathrm{self},j} = \frac{\mathbf{q}'_j \cdot \mu_j}{\mathbf{q}'_j \cdot \nu_j} Att self , j = q j ′ ⋅ ν j q j ′ ⋅ μ j 。
对线性注意力模型的一个直接扩展,是让式 (7), (8) 能够以不同权重组合各项。例如,我们可以将 μ j \mu_j μ j 和 ν j \nu_j ν j 重新定义为:
μ j = a ⋅ μ j − 1 + ( 1 − a ) ⋅ k j ′ ⊤ ⋅ v j ν j = a ⋅ ν j − 1 + ( 1 − a ) ⋅ k j ′ ⊤ \begin{align*}
\mu_j &= a \cdot \mu_{j-1} + (1 - a) \cdot \mathbf{k}'^\top_j \cdot \mathbf{v}_j \tag{10} \\
\nu_j &= a \cdot \nu_{j-1} + (1 - a) \cdot \mathbf{k}'^\top_j \tag{11}
\end{align*} μ j ν j = a ⋅ μ j − 1 + ( 1 − a ) ⋅ k j ′⊤ ⋅ v j = a ⋅ ν j − 1 + ( 1 − a ) ⋅ k j ′⊤ ( 10 ) ( 11 ) 并照常对参数 a a a 进行训练。此外,我们也可以将 a a a 视作门控变量,使用额外的神经网络来计算其取值。
我们已经了解了为注意力机制设计线性模型的通用思路。这类模型的核心设计选择是移除基于 Softmax 的归一化,从而得到基于模型各类中间状态的线性表示形式。这一思路催生了多种近年提出的自注意力替代方案。尽管这些系统的架构各不相同,但其底层模型均具有式 (4) 所描述的类似形式。需要注意的是,通过使用循环模型的通用形式,我们不再局限于标准的 QKV 注意力机制;相反,我们可以为查询、键和值赋予新的含义与形式。
此处的讨论也与之前介绍的记忆模型相关。从记忆的角度来看,键和值可以被视为上下文的编码。因此,在上述线性注意力模型中,我们构建了一个记忆系统,其中仅用两个简单变量 μ j \mu_j μ j 和 ν j \nu_j ν j 就可以表示位置 j j j 之前的所有上下文信息。这形成了一种固定长度的记忆,在实际应用中非常有用。
2.3 State-Space Models# 在控制系统中,状态空间模型 (SSMs ) 是一类系统的表示形式,其输入与输出通过若干状态变量(简称状态)关联,系统的动态特性则由这些状态的一阶微分方程描述。作为一个简单示例,我们考虑一个以状态空间形式表示的连续时间线性时不变系统。本节统一采用行向量约定:所有变量均表示为行向量,线性变换作用于向量右侧:
d z ( t ) d t = z ( t ) ⋅ A + s ( t ) ⋅ B o ( t ) = z ( t ) ⋅ C + s ( t ) ⋅ D \begin{align*} \frac{d\mathbf{z}(t)}{dt}
&= \mathbf{z}(t) \cdot \mathbf{A} + \mathbf{s}(t) \cdot \mathbf{B} \tag{12} \\
\mathbf{o}(t) &= \mathbf{z}(t) \cdot \mathbf{C} + \mathbf{s}(t) \cdot \mathbf{D} \tag{13}
\end{align*} d t d z ( t ) o ( t ) = z ( t ) ⋅ A + s ( t ) ⋅ B = z ( t ) ⋅ C + s ( t ) ⋅ D ( 12 ) ( 13 ) 其中 s ( t ) \mathbf{s}(t) s ( t ) 、o ( t ) \mathbf{o}(t) o ( t ) 和 z ( t ) \mathbf{z}(t) z ( t ) 分别表示时刻 t t t 的输入变量、输出变量和状态变量的值。在一般情况下,这些变量可以具有不同的维度。为简化起见,我们假设 s ( t ) , o ( t ) ∈ R d \mathbf{s}(t), \mathbf{o}(t) \in \mathbb{R}^d s ( t ) , o ( t ) ∈ R d ,且 z ( t ) ∈ R d z \mathbf{z}(t) \in \mathbb{R}^{d_z} z ( t ) ∈ R d z 。式 (12) 被称为状态方程,其中 A ∈ R d z × d z \mathbf{A} \in \mathbb{R}^{d_z \times d_z} A ∈ R d z × d z 是状态矩阵,B ∈ R d × d z \mathbf{B} \in \mathbb{R}^{d \times d_z} B ∈ R d × d z 是输入矩阵。式 (13) 被称为输出方程,其中 C ∈ R d z × d \mathbf{C} \in \mathbb{R}^{d_z \times d} C ∈ R d z × d 是输出矩阵,D ∈ R d × d \mathbf{D} \in \mathbb{R}^{d \times d} D ∈ R d × d 是前馈矩阵。
这些方程描述了从输入 s ( t ) \mathbf{s}(t) s ( t ) 到输出 o ( t ) \mathbf{o}(t) o ( t ) 随时间变化的连续映射关系,因此常被用于处理连续时间序列数据。为了将该模型应用于序列建模,我们需要对上述方程进行修改,以得到离散形式的状态空间表示。假设 { s 0 , s 1 , … , s n } \{\mathbf{s}_0, \mathbf{s}_1, \dots, \mathbf{s}_n\} { s 0 , s 1 , … , s n } 是从 s ( t ) \mathbf{s}(t) s ( t ) 中以时间步长 Δ t \Delta t Δ t 采样得到的输入数据点序列,类似地,定义 { z 0 , z 1 , … , z n } \{\mathbf{z}_0, \mathbf{z}_1, \dots, \mathbf{z}_n\} { z 0 , z 1 , … , z n } 和 { o 0 , o 1 , … , o n } \{\mathbf{o}_0, \mathbf{o}_1, \dots, \mathbf{o}_n\} { o 0 , o 1 , … , o n } 分别为状态向量和输出向量的序列。基于这些符号,我们可以写出离散化的 SSM 的方程:
z t = z t − 1 ⋅ A ‾ + s t ⋅ B ‾ o t = z t ⋅ C ‾ + s t ⋅ D ‾ \begin{align*}
\mathbf{z}_t &= \mathbf{z}_{t-1} \cdot \overline{\mathbf{A}} + \mathbf{s}_t \cdot \overline{\mathbf{B}} \tag{14} \\
\mathbf{o}_t &= \mathbf{z}_t \cdot \overline{\mathbf{C}} + \mathbf{s}_t \cdot \overline{\mathbf{D}} \tag{15}
\end{align*} z t o t = z t − 1 ⋅ A + s t ⋅ B = z t ⋅ C + s t ⋅ D ( 14 ) ( 15 ) 这种形式的 SSM 定义了一个带有残差连接的 RNN。式 (14) 描述了一个循环单元,它读取第 t t t 步的输入和第 t − 1 t-1 t − 1 步的状态,且不使用任何激活函数。式 (15) 描述了一个输出层,它将状态 z t \mathbf{z}_t z t 的线性变换与残差映射 s t \mathbf{s}_t s t 相加。
参数 A ‾ \overline{\mathbf{A}} A 、B ‾ \overline{\mathbf{B}} B 、C ‾ \overline{\mathbf{C}} C 和 D ‾ \overline{\mathbf{D}} D 可以通过多种不同方式从 A \mathbf{A} A 、B \mathbf{B} B 、C \mathbf{C} C 和 D \mathbf{D} D 推导得出,具体取决于式 (12) 如何由式 (14) 近似表示。一种时间离散化方法是双线性变换 ,另一种方法是采用零阶保持 (ZOH ) 离散化。
式 (14) 的循环形式使得在一系列离散时间步上计算状态和输出变得十分简便。我们可以将 z t \mathbf{z}_t z t 和 o t \mathbf{o}_t o t 以前馈方式展开:
很容易得到:
z t = ∑ i = 0 t s i ⋅ B ‾ ⋅ A ‾ t − i o t = ∑ i = 0 t s i ⋅ B ‾ ⋅ A ‾ t − i ⋅ C ‾ + s t ⋅ D ‾ \begin{align*}
\mathbf{z}_t &= \sum_{i=0}^{t} \mathbf{s}_i \cdot \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{t-i} \tag{16} \\
\mathbf{o}_t &= \sum_{i=0}^{t} \mathbf{s}_i \cdot \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{t-i} \cdot \overline{\mathbf{C}} + \mathbf{s}_t \cdot \overline{\mathbf{D}} \tag{17}
\end{align*} z t o t = i = 0 ∑ t s i ⋅ B ⋅ A t − i = i = 0 ∑ t s i ⋅ B ⋅ A t − i ⋅ C + s t ⋅ D ( 16 ) ( 17 ) 显然,式 (17) 的右侧可以被解释为一个卷积项与一个线性项的和。鉴于
∑ i = 0 t s i ⋅ B ‾ ⋅ A ‾ t − i ⋅ C ‾ = [ s 0 s 1 … s t ] ⋅ [ B ‾ ⋅ A ‾ t ⋅ C ‾ B ‾ ⋅ A ‾ t − 1 ⋅ C ‾ … B ‾ ⋅ C ‾ ] ⊤ \begin{align*}
\sum_{i=0}^{t} \mathbf{s}_i \cdot \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{t-i} \cdot \overline{\mathbf{C}}
={}&
\begin{bmatrix} \mathbf{s}_0 & \mathbf{s}_1 & \dots & \mathbf{s}_t \end{bmatrix} \cdot \\
&
\begin{bmatrix} \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{t} \cdot \overline{\mathbf{C}} & \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{t-1} \cdot \overline{\mathbf{C}} & \dots & \overline{\mathbf{B}} \cdot \overline{\mathbf{C}} \end{bmatrix}^\top \tag{18}
\end{align*} i = 0 ∑ t s i ⋅ B ⋅ A t − i ⋅ C = [ s 0 s 1 … s t ] ⋅ [ B ⋅ A t ⋅ C B ⋅ A t − 1 ⋅ C … B ⋅ C ] ⊤ ( 18 ) 我们定义一个参数化的卷积滤波器:
W ssm = [ B ‾ ⋅ A ‾ n max ⋅ C ‾ B ‾ ⋅ A ‾ n max − 1 ⋅ C ‾ … B ‾ ⋅ C ‾ ] (19) \mathbf{W}_{\text{ssm}} = \begin{bmatrix} \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{n_{\max}} \cdot \overline{\mathbf{C}} & \overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^{n_{\max}-1} \cdot \overline{\mathbf{C}} & \ldots & \overline{\mathbf{B}} \cdot \overline{\mathbf{C}} \end{bmatrix} \tag{19} W ssm = [ B ⋅ A n m a x ⋅ C B ⋅ A n m a x − 1 ⋅ C … B ⋅ C ] ( 19 ) 其中 n max n_{\max} n m a x 为序列的最大长度。此时,对于序列 S = [ s 0 ⋮ s n ] \mathbf{S} = \begin{bmatrix} \mathbf{s}_0 \\ \vdots \\ \mathbf{s}_n \end{bmatrix} S = s 0 ⋮ s n ,状态空间模型的输出可表示为:
O = Conv ( S , W ssm ) + Linear ( S , D ‾ ) (20) \mathbf{O} = \text{Conv}(\mathbf{S}, \mathbf{W}_{\text{ssm}}) + \text{Linear}(\mathbf{S}, \overline{\mathbf{D}}) \tag{20} O = Conv ( S , W ssm ) + Linear ( S , D ) ( 20 ) 其中 Conv ( ⋅ ) \text{Conv}(\cdot) Conv ( ⋅ ) 表示卷积操作,Linear ( ⋅ ) \text{Linear}(\cdot) Linear ( ⋅ ) 表示线性变换操作。对状态空间模型的这种处理方式,使得系统能够通过快速并行卷积算法高效实现。
另一个计算瓶颈在于求解 A ‾ n \overline{\mathbf{A}}^n A n 时需要反复进行矩阵乘法。当 n n n 较大时,该运算不仅计算开销高,还存在数值不稳定的问题。现代状态空间模型普遍采用对角化 方案:通过对状态空间做基变换,使矩阵 A \mathbf{A} A (或变换后的 A ‾ \overline{\mathbf{A}} A )化为对角矩阵。对于参数为 ( A , B , C , D ) (\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}) ( A , B , C , D ) 的状态空间模型,引入可逆变换矩阵 U \mathbf{U} U ,可得到等价模型 ( U A U − 1 , B U − 1 , U C , D ) (\mathbf{U}\mathbf{A}\mathbf{U}^{-1},\mathbf{B}\mathbf{U}^{-1},\mathbf{U}\mathbf{C},\mathbf{D}) ( UA U − 1 , B U − 1 , UC , D ) 。可以证明,两种形式仅为基变换下的不同表达,数学上完全等价。
若对矩阵执行上述变换,且假设 A \mathbf{A} A 可对角化为标准形式 P − 1 Λ P \mathbf{P}^{-1}\Lambda\mathbf{P} P − 1 Λ P (其中 Λ \boldsymbol{\Lambda} Λ 为对角矩阵),便可将状态矩阵约束为对角矩阵,进而得到对角状态空间模型。以卷积形式的滤波器为例,将 A = P − 1 Λ P \mathbf{A} = \mathbf{P}^{-1}\Lambda\mathbf{P} A = P − 1 Λ P 代入后,可把 B ⋅ A ‾ t ⋅ C \mathbf{B} \cdot \overline{\mathbf{A}}^t \cdot \mathbf{C} B ⋅ A t ⋅ C 改写为:
B ‾ ⋅ A ‾ t ⋅ C ‾ = B ‾ ⋅ ( P − 1 Λ P ) t ⋅ C ‾ = B ‾ ⋅ ( P − 1 Λ P ) ⋅ ( P − 1 Λ P ) ⋯ ( P − 1 Λ P ) ⋅ C ‾ = ( B ‾ ⋅ P − 1 ) ⋅ Λ t ⋅ ( P ⋅ C ‾ ) \begin{align*}
\overline{\mathbf{B}} \cdot \overline{\mathbf{A}}^t \cdot \overline{\mathbf{C}} &= \overline{\mathbf{B}} \cdot (\mathbf{P}^{-1}\Lambda\mathbf{P})^t \cdot \overline{\mathbf{C}} \\
&= \overline{\mathbf{B}} \cdot (\mathbf{P}^{-1}\Lambda\mathbf{P}) \cdot (\mathbf{P}^{-1}\Lambda\mathbf{P}) \cdots (\mathbf{P}^{-1}\Lambda\mathbf{P}) \cdot \overline{\mathbf{C}} \\
&= (\overline{\mathbf{B}} \cdot \mathbf{P}^{-1}) \cdot \Lambda^t \cdot (\mathbf{P} \cdot \overline{\mathbf{C}}) \tag{21}
\end{align*} B ⋅ A t ⋅ C = B ⋅ ( P − 1 Λ P ) t ⋅ C = B ⋅ ( P − 1 Λ P ) ⋅ ( P − 1 Λ P ) ⋯ ( P − 1 Λ P ) ⋅ C = ( B ⋅ P − 1 ) ⋅ Λ t ⋅ ( P ⋅ C ) ( 21 ) 由于 Λ \Lambda Λ 是对角矩阵,我们只需将其所有元素取 t t t 次幂,即可高效计算出 Λ t \Lambda^t Λ t 。这样一来,我们就得到了计算成本更低的模型,在该模型中:
A ‾ ′ = Λ B ‾ ′ = B ‾ ⋅ P − 1 C ‾ ′ = P ⋅ C ‾ D ‾ ′ = D ‾ \begin{align*}
\overline{\mathbf{A}}' &= \Lambda \tag{22} \\
\overline{\mathbf{B}}' &= \overline{\mathbf{B}} \cdot \mathbf{P}^{-1} \tag{23} \\
\overline{\mathbf{C}}' &= \mathbf{P} \cdot \overline{\mathbf{C}} \tag{24} \\
\overline{\mathbf{D}}' &= \overline{\mathbf{D}} \tag{25}
\end{align*} A ′ B ′ C ′ D ′ = Λ = B ⋅ P − 1 = P ⋅ C = D ( 22 ) ( 23 ) ( 24 ) ( 25 ) 将 SSM 应用于 Transformer 的方式十分简便:只需按照式 (14), (15) ,用状态空间模型子层替换原有自注意力子层即可。前文已经说明,状态空间模型与卷积神经网络、循环神经网络均存在紧密关联。在序列建模任务中,既可以像循环神经网络那样逐序列串行处理,也能仿照卷积神经网络实现并行运算。由此催生了一种全新范式:训练阶段采用卷积网络的并行模式,借助高效的并行算法加速训练;预测阶段则转化为串行更新形式,依托类循环网络结构实现高效推理。
SSM 的形式体系为序列建模搭建了统一框架。借助并行与串行两种互补视角,我们可针对不同硬件条件与延迟要求进行优化。近期多款模型均受这种双重特性启发,成功融合了 Transformer 的并行计算能力与循环神经网络的高效推理优势。