自注意力(Self-Attention) Q = X ⋅ W Q , K = X ⋅ W K , V = X ⋅ W V Attention ( Q , K , V ) = softmax ( Q K T d k ) V X . shape : [ B , T , D ] , W . shape : [ D , D ] , d k = D
\begin{aligned}
& Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V \\
& \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T} { \sqrt{d_k} }\right)V \\
& X.\text{shape}: [B, T, D], \quad W.\text{shape}: [D, D], \quad d_k = D
\end{aligned}
Q = X ⋅ W Q , K = X ⋅ W K , V = X ⋅ W V Attention ( Q , K , V ) = softmax ( d k Q K T ) V X . shape : [ B , T , D ] , W . shape : [ D , D ] , d k = D
理论解释为什么 Attention 有效 假设我们有 5 个单词 {cat, milk, it, sweet, hungry}
, 每个单词用一个向量表示:sweet = ( . . . , 0 , 4 , . . . ) milk = ( . . . , 1 , 3 , . . . ) it = ( . . . , 2 , 2 , . . . ) cat = ( . . . , 2 , 2 , . . . ) hungry = ( . . . , 4 , 0 , . . . )
\begin{aligned}
\text{sweet} &= (..., 0, 4, ...)\\
\text{milk} &= (..., 1, 3, ...)\\
\text{it} &= (..., 2, 2, ...)\\
\text{cat} &= (..., 2, 2, ...)\\
\text{hungry} &= (..., 4, 0, ...)
\end{aligned}
sweet milk it cat hungry = ( ... , 0 , 4 , ... ) = ( ... , 1 , 3 , ... ) = ( ... , 2 , 2 , ... ) = ( ... , 2 , 2 , ... ) = ( ... , 4 , 0 , ... )
假设词向量的维度 dim=4,第一列数字(第 1 个头)代表这个单词关于状态
的属性,值越高代表该单词与hungry
越相关; 第二列数字(第 2 个头)代表这个单词关于味道
的属性,值越高代表该单词与sweet
越相关。
例子 1 现在让我们考虑 Attention 算法中,计算状态
这部分的头。假设我们正在处理一个句子 “The cat drank the milk because it was sweet.” ,其中包含了 cat、milk、it、sweet 这四个单词(暂且忽略其余不相关单词,单词按词序组成矩阵)。此时在 Self-Attention 算法中的 Q、K、V 矩阵为:Q = K = V = [ . . . 2 2 . . . . . . 1 3 . . . . . . 2 2 . . . . . . 0 4 . . . ] cat milk it sweet
Q=K=V=
\begin{bmatrix}
... & 2 & 2 & ...\\
... & 1 & 3 & ...\\
... & 2 & 2 & ...\\
... & 0 & 4 & ...
\end{bmatrix}
\begin{matrix}
\text{cat} \\
\text{milk} \\
\text{it} \\
\text{sweet}
\end{matrix}
Q = K = V = ⎣ ⎡ ... ... ... ... 2 1 2 0 2 3 2 4 ... ... ... ... ⎦ ⎤ cat milk it sweet
现在我们计算 Attention 分数(为了方便理解,...
部分用 0 代替):
Q ⋅ K T = cat milk it sweet cat 8 8 8 8 milk 8 10 8 12 it 8 8 8 8 sweet 8 12 8 16 S o f t m a x ( Q ⋅ K T d ) ⋅ V = [ . . . 0.625 1.375 . . . . . . 0.089 1.911 . . . . . . 0.625 1.375 . . . . . . 0.010 1.990 . . . ] cat milk it sweet
Q \cdot K^T=
\begin{array}{cccccc}
& \text{cat} & \text{milk} & \text{it} & \text{sweet} \\
\text{cat} & 8 & 8 & 8 & 8 \\
\text{milk} & 8 & 10 & 8 & 12 \\
\text{it} & 8 & 8 & 8 & 8 \\
\text{sweet} & 8 & 12 & 8 & 16 \\
\end{array}
\\
Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V =
\begin{bmatrix}
... & 0.625 & 1.375 & ...\\
... & 0.089 & 1.911 & ...\\
... & 0.625 & 1.375 & ...\\
... & 0.010 & 1.990 & ...
\end{bmatrix}
\begin{matrix}
\text{cat} \\
\text{milk} \\
\text{it} \\
\text{sweet}
\end{matrix}
Q ⋅ K T = cat milk it sweet cat 8 8 8 8 milk 8 10 8 12 it 8 8 8 8 sweet 8 12 8 16 S o f t ma x ( d Q ⋅ K T ) ⋅ V = ⎣ ⎡ ... ... ... ... 0.625 0.089 0.625 0.010 1.375 1.911 1.375 1.990 ... ... ... ... ⎦ ⎤ cat milk it sweet
之后,我们得到这 4 个单词的新 embedding 为:sweet = ( . . . , 0 , 4 , . . . ) → ( . . . , 0.010 , 1.990 , . . . ) milk = ( . . . , 1 , 3 , . . . ) → ( . . . , 0.089 , 1.911 , . . . ) it = ( . . . , 2 , 2 , . . . ) → ( . . . , 0.625 , 1.375 , . . . ) cat = ( . . . , 2 , 2 , . . . ) → ( . . . , 0.625 , 1.375 , . . . ) hungry = ( . . . , 4 , 0 , . . . )
\begin{aligned}
\text{sweet} &= (..., 0, 4, ...) \rightarrow (..., 0.010, 1.990, ...)\\
\text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.089, 1.911, ...)\\
\text{it} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\
\text{cat} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\
\text{hungry} &= (..., 4, 0, ...)
\end{aligned}
sweet milk it cat hungry = ( ... , 0 , 4 , ... ) → ( ... , 0.010 , 1.990 , ... ) = ( ... , 1 , 3 , ... ) → ( ... , 0.089 , 1.911 , ... ) = ( ... , 2 , 2 , ... ) → ( ... , 0.625 , 1.375 , ... ) = ( ... , 2 , 2 , ... ) → ( ... , 0.625 , 1.375 , ... ) = ( ... , 4 , 0 , ... )
通常情况下, 在英文中 it 既可以指代 cat 又可以指代 milk,因此 it 和 cat 的相似度与 it 和 milk 的相似度相同,即:sim(it, milk) = 1 × 2 + 3 × 2 = 8 sim(it, cat) = 2 × 2 + 2 × 2 = 8
\text{sim(it, milk)} = 1 \times 2 + 3 \times 2 = 8 \\
\text{sim(it, cat)} = 2 \times 2 + 2 \times 2 = 8
sim(it, milk) = 1 × 2 + 3 × 2 = 8 sim(it, cat) = 2 × 2 + 2 × 2 = 8 之后,模型学习了句子 “The cat drank the milk because it was sweet.” ,这个句子中 it 指代 milk,通过 Attention 算法后得到了新的词 embedding,这时 it 在词向量表达上更加靠近 milk:sim(it, milk) = 0.0625 × 0.089 + 1.375 × 1.911 = 2.68 sim(it, cat) = 0.0625 × 0.0625 + 1.375 × 1.375 = 2.28
\text{sim(it, milk)} = 0.0625 \times 0.089 + 1.375 \times 1.911 = 2.68 \\
\text{sim(it, cat)} = 0.0625 \times 0.0625 + 1.375 \times 1.375 = 2.28
sim(it, milk) = 0.0625 × 0.089 + 1.375 × 1.911 = 2.68 sim(it, cat) = 0.0625 × 0.0625 + 1.375 × 1.375 = 2.28
例子 2 再来看另一种情况,对于另一个句子 “The cat drank the milk because it was hungry.” ,这个句子中 it 指代 cat,我们同样运用 Attention 算法,得到新的词 embedding:Q = K = V = [ . . . 2 2 . . . . . . 1 3 . . . . . . 2 2 . . . . . . 4 0 . . . ] cat milk it hungry
Q=K=V=
\begin{bmatrix}
... & 2 & 2 & ...\\
... & 1 & 3 & ...\\
... & 2 & 2 & ...\\
... & 4 & 0 & ...
\end{bmatrix}
\begin{matrix}
\text{cat} \\
\text{milk} \\
\text{it} \\
\text{hungry}
\end{matrix}
Q = K = V = ⎣ ⎡ ... ... ... ... 2 1 2 4 2 3 2 0 ... ... ... ... ⎦ ⎤ cat milk it hungry 计算 Attention 分数:Q ⋅ K T = cat milk it hungry cat 8 8 8 8 milk 8 10 8 4 it 8 8 8 8 hungry 8 4 8 16 S o f t m a x ( Q ⋅ K T d ) ⋅ V = [ . . . 1.125 0.875 . . . . . . 0.609 1.391 . . . . . . 1.125 0.875 . . . . . . 1.999 0.001 . . . ] cat milk it hungry
Q \cdot K^T=
\begin{array}{cccccc}
& \text{cat} & \text{milk} & \text{it} & \text{hungry} \\
\text{cat} & 8 & 8 & 8 & 8 \\
\text{milk} & 8 & 10 & 8 & 4 \\
\text{it} & 8 & 8 & 8 & 8 \\
\text{hungry} & 8 & 4 & 8 & 16 \\
\end{array}
\\
Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V =
\begin{bmatrix}
... & 1.125 & 0.875 & ...\\
... & 0.609 & 1.391 & ...\\
... & 1.125 & 0.875 & ...\\
... & 1.999 & 0.001 & ...
\end{bmatrix}
\begin{matrix}
\text{cat} \\
\text{milk} \\
\text{it} \\
\text{hungry}
\end{matrix}
Q ⋅ K T = cat milk it hungry cat 8 8 8 8 milk 8 10 8 4 it 8 8 8 8 hungry 8 4 8 16 S o f t ma x ( d Q ⋅ K T ) ⋅ V = ⎣ ⎡ ... ... ... ... 1.125 0.609 1.125 1.999 0.875 1.391 0.875 0.001 ... ... ... ... ⎦ ⎤ cat milk it hungry
之后我们得到 4 个单词的新 embedding:hungry = ( . . . , 4 , 0 , . . . ) → ( . . . , 1.999 , 0.001 , . . . ) milk = ( . . . , 1 , 3 , . . . ) → ( . . . , 0.609 , 1.391 , . . . ) it = ( . . . , 2 , 2 , . . . ) → ( . . . , 1.125 , 0.875 , . . . ) cat = ( . . . , 2 , 2 , . . . ) → ( . . . , 1.125 , 0.875 , . . . ) sweet = ( . . . , 0 , 4 , . . . )
\begin{aligned}
\text{hungry} &= (..., 4, 0, ...) \rightarrow (..., 1.999, 0.001, ...)\\
\text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.609, 1.391, ...)\\
\text{it} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\
\text{cat} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\
\text{sweet} &= (..., 0, 4, ...) \\
\end{aligned}
hungry milk it cat sweet = ( ... , 4 , 0 , ... ) → ( ... , 1.999 , 0.001 , ... ) = ( ... , 1 , 3 , ... ) → ( ... , 0.609 , 1.391 , ... ) = ( ... , 2 , 2 , ... ) → ( ... , 1.125 , 0.875 , ... ) = ( ... , 2 , 2 , ... ) → ( ... , 1.125 , 0.875 , ... ) = ( ... , 0 , 4 , ... ) 此时 it 的词向量更加接近 cat:sim(it, milk) = 1.125 × 0.609 + 0.875 × 1.391 = 1.90 sim(it, cat) = 1.125 × 1.125 + 0.875 × 0.875 = 2.03
\text{sim(it, milk)} = 1.125 \times 0.609 + 0.875 \times 1.391 = 1.90 \\
\text{sim(it, cat)} = 1.125 \times 1.125 + 0.875 \times 0.875 = 2.03
sim(it, milk) = 1.125 × 0.609 + 0.875 × 1.391 = 1.90 sim(it, cat) = 1.125 × 1.125 + 0.875 × 0.875 = 2.03
为什么要 scaling? scaling 目的是解决数值稳定性问题,从而提高训练的效率和性能。当 Q,K 的维度 d k d_k d k 很大时, q i , k i q_i,k_i q i , k i 的点积值可能变得很大。点积值越大,输入到 softmax 函数中的数值范围越广,可能会导致以下问题:
softmax 的梯度变得极小 : softmax 函数对大数值非常敏感,极大值会导致其他位置的权重几乎为 0,从而产生数值不稳定性。
模型训练变得困难 : 梯度消失问题会使得模型难以学习。
为什么是 d k d_k d k 而不是其他数? 对于输入特征 X X X ,其元素通常服从均值为 0、方差为 1 的标准正态分布。经过点积计算 Q K T QK^T Q K T 后,由于 q i , k i q_i,k_i q i , k i 的点积是 d k d_k d k 个独立随机变量的和,所以方差会变为 d k d_k d k ,除以 d k \sqrt{d_k} d k 可以让方差重新变为 1。
如果直接除以 d k d_k d k ,方差变为 1 / d k 1/d_k 1/ d k ,分布过于集中,让 softmax 的值趋于均匀分布,会弱化注意力机制的效果
如果除以 d k 3 \sqrt[3]{d_k} 3 d k ,会让方差仍然较大,可能会导致数值不稳定和训练困难
其余 scale 处理
归一化 Attention 分数,放缩到[0,1]范围
s c o r e = Q K T ∣ ∣ Q K T ∣ ∣
score=\frac{QK^T}{||QK^T||}
score = ∣∣ Q K T ∣∣ Q K T
温度参数:d k d_k d k 换成常数
预归一化
Q ′ = Q ∣ ∣ Q ∣ ∣ , K ′ = K ∣ ∣ K ∣ ∣ , V ′ = V
Q'=\frac{Q}{||Q||}, K'=\frac{K}{||K||}, V'=V
Q ′ = ∣∣ Q ∣∣ Q , K ′ = ∣∣ K ∣∣ K , V ′ = V
长序列放缩,长序列导致 Attention 分数范围增大,Softmax 失效
s o f t m a x ( Q K T d k ⋅ T ) V
softmax(\frac{QK^T}{\sqrt{d_k} \cdot \sqrt{T}})V\\
so f t ma x ( d k ⋅ T Q K T ) V
代码手撕 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 import torchimport torch.nn as nnimport mathclass SelfAttention (nn.Module): def __init__ (self, input_dim, dim_k=512 ): super ().__init__() self .norm_factor = 1 / math.sqrt(dim_k) self .q = nn.Linear(input_dim, dim_k) self .k = nn.Linear(input_dim, dim_k) self .v = nn.Linear(input_dim, dim_k) def forward (self, x, mask=None ): """ x.shape: [B, T, D] """ Q, K, V = self .q(x), self .k(x), self .v(x) score = torch.bmm(Q, K.transpose(1 , 2 )) * self .norm_factor if mask is not None : score += mask * -1e9 return torch.bmm(torch.softmax(score, dim=-1 ), V)
多头自注意力(MHA) 为什么要多头? 模型只能学习到一个层面的注意力模式,不能捕捉到输入序列中复杂的多样性关系。仅通过单个头来表示查询、键和值的投影,会限制模型的表达能力。 多头注意力的优势包括:
捕捉不同的上下文信息 :每个注意力头可以专注于不同的上下文信息或关系。例如,一个头可以专注于捕捉远距离词语之间的关系,而另一个头可以专注于局部词语之间的关系。
提高模型的表达能力 :通过并行计算多个注意力分布,模型能够从多个角度理解同一输入,从而获得更丰富的语义信息 。
提升模型的灵活性和鲁棒性 :多头注意力使得模型能够在多个子空间中进行学习,从而减少单一注意力头可能带来的信息损失。
举例:“The quick brown fox jumped over the lazy dog”。我们希望 Transformer 模型能够理解以下关系:
主谓关系 :“fox” 和 “jumped”
定语关系 :“quick” 修饰 “fox”
位置关系 :“over” 与 “jumped”
对模型来说,不同的头用来学习单词之间的不同关系,比如 “jump” 的头 1 用来学习与”fox”的主谓关系,头 2 用来学习与”over”的位置关系。
代码手撕 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 import torchimport torch.nn as nnimport mathclass MultiheadAttention (nn.Module): def __init__ (self, input_dim, dim_k, num_head=8 ): super ().__init__() assert dim_k % num_head == 0 self .dk = dim_k // num_head self .head = num_head self .norm_factor = 1 / math.sqrt(self .dk) self .q = nn.Linear(input_dim, dim_k) self .k = nn.Linear(input_dim, dim_k) self .v = nn.Linear(input_dim, dim_k) def forward (self, x, mask=None ): """ x.shape: [B, T, D] """ batch, seqlen, _ = x.shape Q, K, V = self .q(x), self .k(x), self .v(x) Q, K, V = ( Q.reshape(batch, seqlen, -1 , self .dk).transpose(1 , 2 ), K.reshape(batch, seqlen, -1 , self .dk).transpose(1 , 2 ), V.reshape(batch, seqlen, -1 , self .dk).transpose(1 , 2 ), ) score = torch.matmul(Q, K.transpose(-2 , -1 )) * self .norm_factor if mask is not None : score += mask * -1e9 output = torch.matmul(torch.softmax(score, dim=-1 ), V) output = output.reshape(batch, seqlen, -1 ) return output
多头潜在注意力(MLA) 背景 传统 Transformer 采用 MHA(Multi-Head Attention),但是 KV Cache 会成为推理瓶颈。MQA(Multi-Query Attention) 和 GQA(Grouped-Query Attention) 可以一定程度减少 KV Cache,但效果上不如MHA。DeepSeek-V2 设计了一种称为 MLA(Multi-Head Latent Attention) 的注意力机制。MLA 通过低秩key-value 联合压缩,实现了比 MHA 更好的效果并且需要的 kv cache 要小很多。
解决核心问题:使用 KV Cache 时,随着序列长度变长导致显存不足的问题。
低秩 Key-Value 联合压缩 在 MHA 的单头计算过程中,输入 X X X (维度 N × d e m b N \times d_{emb} N × d e mb )会先送入三个projection矩阵 W Q , W K , W V W_Q, W_K, W_V W Q , W K , W V (维度为 d e m b × d k d_{emb} \times d_k d e mb × d k ) 进行线性变换:Q = X × W Q ( N × d k ) K = X × W K ( N × d k ) V = X × W V ( N × d k )
Q = X \times W_Q \quad (N \times d_k) \\
K = X \times W_K \quad (N \times d_k) \\
V = X \times W_V \quad (N \times d_k)
Q = X × W Q ( N × d k ) K = X × W K ( N × d k ) V = X × W V ( N × d k )
变换完成之后生成的 K 、 V K、V K 、 V 就是传统意义上的 KV Cache。Multi-head 会将 h 个 head 的 Q 、 K 、 V Q、K、V Q 、 K 、 V 矩阵分别拼接在一起,构成 d × d d \times d d × d 的矩阵,其中 d = h × d k d = h \times d_k d = h × d k 。而在 MLA 中没有保留 KV Cache,而是引入一个 C Cache,也即增加了公式:C = X × W C ( N × d e m b ) × ( d e m b × d c )
C = X \times W_C \quad (N \times d_{emb}) \times (d_{emb} \times d_c)
C = X × W C ( N × d e mb ) × ( d e mb × d c )
W C W_C W C 的维度不再是
d e m b × d k d_{emb} \times d_k d e mb × d k ,而是变成了
d e m b × d c d_{emb} \times d_c d e mb × d c ,
d c d_c d c 通常远小雨
d d d ,但比
d k d_k d k 大(DeepSeek-V2 中
d k = 128 d_k=128 d k = 128 ,
d c = 512 d_c=512 d c = 512 )。有了
C C C 之后,为了保证 Attention 计算的正确性,我们可以从概念上再引入 K 和 V,也即得到如下两个关于 KV 的新公式:
K = C × W K ( N × d c ) × ( d c × d k ) V = C × W V ( N × d c ) × ( d c × d k )
K = C \times W_K \quad (N \times d_c) \times (d_c \times d_k) \\
V = C \times W_V \quad (N \times d_c) \times (d_c \times d_k)
K = C × W K ( N × d c ) × ( d c × d k ) V = C × W V ( N × d c ) × ( d c × d k )
此时 W K W_K W K 和 W V W_V W V 的维度就变成了 ( d c × d k ) (d_c \times d_k) ( d c × d k ) ,最终输出的 K 和 V 维度依旧是 ( N × d k ) (N \times d_k) ( N × d k ) 。但是请大家切记,MLA 最终的公式中没有出现 KV,我们现在处于将 MHA 转换到MLA的过程中,引入 K K K 和 V V V 只是方便进行公式推导。MHA 进行 Attention 计算时会将 Q Q Q 和 K T K^T K T 相乘,得到如下公式:Q K T = X × W Q × ( C × W K ) T = X × ( W Q × W K T ) × C T = Q ′ C T
QK^T = X \times W_Q \times (C \times W_K)^T = X \times (W_Q \times W_K^T) \times C^T = Q'C^T
Q K T = X × W Q × ( C × W K ) T = X × ( W Q × W K T ) × C T = Q ′ C T
于是我们可以将 W Q × W K T W_Q \times W_K^T W Q × W K T 当做新的 projection 矩阵 W Q ′ W_Q' W Q ′ (维度 d e m b × d c d_{emb} \times d_c d e mb × d c ),此时 Attention 计算就和 K K K 没什么关系了。类似的,Attention计算中的 V V V 也可以进行消除,从而使得整个 Attention 计算只和 Q ′ 、 C Q'、C Q ′ 、 C 相关。MLA更进一步,Attention 计算完成得到 O O O 之后,还会将其与 projection 矩阵 W O W_O W O 进行相乘,W V W_V W V 也可以与 W O W_O W O 融合在一起得到新的 projection 矩阵 W O ′ W_O' W O ′ 。最终我们可以得到 MLA Attention 部分的计算公式:Q ′ = X × W Q ′ C = X × W C O = A t t e n t i o n ( Q ′ , C ) = s o f t m a x ( Q ′ C T d k ) C O ′ = O × W O ′
Q' = X \times W_Q' \\
C = X \times W_C \\
O = Attention(Q', C) = softmax(\frac{Q' C^T}{\sqrt{d_k}})C \\
O' = O \times W_O'
Q ′ = X × W Q ′ C = X × W C O = A tt e n t i o n ( Q ′ , C ) = so f t ma x ( d k Q ′ C T ) C O ′ = O × W O ′
兼容 RoPE 然而,当下大模型一般会在Attention计算之前将 Q Q Q 和 K K K 添加 RoPE ,这就导致上面的 Q K T QK^T Q K T 过程不再成立。但是我们又不能丢弃 RoPE,为了弥补这个缺陷,MLA 又在 Q Q Q 和 K K K 中额外增加了 d r d_r d r 维度来专门用来存放 RoPE 信息。
Q r o p e = [ X × W C , X × W Q R × R ] = [ X × W Q , Q R ] W Q R : ( d × d r ) R : ( d r × d r ) Q r o p e : ( N × ( d c + d r ) )
Q_{rope} = [X \times W_C, X \times W_{QR} \times R] = [X \times W_Q, QR] \\
W_{QR}: (d \times d_r) \quad R: (d_r \times d_r) \quad Q_{rope}: (N \times (d_c + d_r))
Q ro p e = [ X × W C , X × W QR × R ] = [ X × W Q , QR ] W QR : ( d × d r ) R : ( d r × d r ) Q ro p e : ( N × ( d c + d r ))
其中 W Q R W_{QR} W QR 是新增的一个矩阵,R R R 是 RoPE 矩阵,用 Q R QR QR 来代表 Q Q Q 新增的子矩阵( N × d r ) (N \times d_r) ( N × d r ) 。类似的,K K K 也可以扩展为:K r o p e = [ C × W K , X × W K R × R ] = [ C × W K , K R ] W K R : ( d × d r ) K r o p e : ( N × ( d c + d r ) )
K_{rope} = [C \times W_K, X \times W_{KR} \times R] = [C \times W_K, KR] \\
W_{KR}: (d \times d_r) \quad K_{rope}: (N \times (d_c + d_r))
K ro p e = [ C × W K , X × W K R × R ] = [ C × W K , K R ] W K R : ( d × d r ) K ro p e : ( N × ( d c + d r ))
W K R W_{KR} W K R 也是新增的矩阵,维度也是
d × d r d \times d_r d × d r ,用
K R KR K R 来代表
K K K 新增的子矩阵
( N × d r ) (N \times d_r) ( N × d r ) ,它保存了 RoPE 信息。之后可以扩展后续公式为:
Q K T = [ X × W Q , Q R ] × [ C × W K , K R ] T = [ X × W Q , Q R ] × [ ( C × W K ) T K R T ] = [ X × W Q , Q R ] × [ W K T C T K R T ] = X × ( W C × W K T ) × C T + Q R × K R T = X × W Q ′ × C T + Q R × K R T = Q ′ C T + Q R × K R T
\begin{align*}
QK^T
&= [X \times W_Q, QR] \times [C \times W_K, KR]^T \\
&=[X \times W_Q, QR] \times
\begin{bmatrix}
(C \times W_K)^T \\
KR^T
\end{bmatrix} \\
&=[X \times W_Q, QR] \times
\begin{bmatrix}
W_K^T C^T \\
KR^T
\end{bmatrix} \\
&=X \times (W_C \times W_K^T) \times C^T + QR \times KR^T \\
&=X \times W_Q' \times C^T + QR \times KR^T \\
&=Q'C^T + QR \times KR^T
\end{align*}
Q K T = [ X × W Q , QR ] × [ C × W K , K R ] T = [ X × W Q , QR ] × [ ( C × W K ) T K R T ] = [ X × W Q , QR ] × [ W K T C T K R T ] = X × ( W C × W K T ) × C T + QR × K R T = X × W Q ′ × C T + QR × K R T = Q ′ C T + QR × K R T
于是,在 Q Q Q 和 K K K 新增 d r d_r d r 维度计算 RoPE 信息的情况下,Attention 的计算公式还可以继续复用上面提到的 C Cache 压缩技巧,只需要把第一部分提到的结果与新增的两个子矩阵相乘的结果相加即可。不过我们还需要把 KR 缓存下来,加速后半部分在 Decoding 阶段的计算速度。所以整体 MLA 需要保留 N × d c N \times d_c N × d c 的 C Cache,还需要保留 N × d r N \times d_r N × d r 的 KR Cache,但 d r d_r d r 一般不大,在论文中是 d k / 2 = 64 d_k /2 = 64 d k /2 = 64 。
计算量对比 MLA 相比原始的 MHA 简化了计算公式,压缩了缓存大小,难道天下真有免费的午餐吗?我们仔细对比一下 MHA、MQA、GQA 和 MLA 在推理过程中的模型单层参数量和计算量,如下表。假定:d k = 8 k h e a d = 64 d h = 128 g r o u p = 8 d c = 512 N = 128 k (模型 max token)
\begin{align*}
&d_k=8k \quad head=64 \quad d_h=128 \\
&group=8 \quad d_c=512 \quad N=128k\text{(模型 max token)}
\end{align*}
d k = 8 k h e a d = 64 d h = 128 g ro u p = 8 d c = 512 N = 128 k ( 模型 max token) 在统计计算量的时候只考虑了计算的大头,较小计算量的部分丢弃掉也不会产生大的影响。
Attention
MHA
MQA
GQA
MLA
参数量计算
W Q = W K = W V = d × h × d k W_Q=W_K=W_V\\ =d \times h \times d_k W Q = W K = W V = d × h × d k
W Q = W O = d × h × d k W K = W V = d × d k W_Q=W_O=d \times h \times d_k \\ W_K=W_V=d \times d_k W Q = W O = d × h × d k W K = W V = d × d k
W Q = W O = d × h × d k W K = W V = d × g × d k W_Q=W_O=d \times h \times d_k \\ W_K=W_V=d \times g \times d_k W Q = W O = d × h × d k W K = W V = d × g × d k
W Q = W O = d × h × d c W C = d × d c W_Q=W_O=d \times h \times d_c \\ W_C = d \times d_c W Q = W O = d × h × d c W C = d × d c
参数量
64M * 4 = 256 M
64 M * 2 + 2M = 130 M
64 M 2 + 8M 2 = 144 M
256 M * 2 + 4M = 516 M
缓存计算
N × h × d k × 2 N \times h \times d_k \times 2 N × h × d k × 2
N × d k × 2 N \times d_k \times 2 N × d k × 2
N × g × d k × 2 N \times g \times d_k \times 2 N × g × d k × 2
N × d c N \times d_c N × d c
缓存大小
2G
32M
256M
64M
Prefilling 计算量
256T
256T
256T
1000T
Decoding 计算量
2G
2G
2G
8G
可以看出,MLA在参数量和计算量上都比另外三种 Attention 计算方法要大,但 MLA 效果好于 MHA 的原因也是因为计算量增大,但好处就是大大降低了 Decoding 阶段的缓存大小,从而使得解码速度更快。上面结果没有考虑 RoPE 引入,但参数量和计算量增加 不影响结论。
参考:大模型推理-MLA