Self-Attention
\begin{aligned}
& Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V \\
& \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T} { \sqrt{d_k} }\right)V \\
& X.\text{shape}: [B, T, D], \quad W.\text{shape}: [D, D], \quad d_k = D
\end{aligned}
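Before the worked example below, here is a minimal shape walk-through of the formulas above in PyTorch; the concrete values of B, T, and D are placeholders chosen only for illustration.

```python
import torch

B, T, D = 2, 5, 4                      # batch, sequence length, model dim (toy values)
X = torch.randn(B, T, D)               # input:   [B, T, D]
W_Q = torch.randn(D, D)                # weights: [D, D], so d_k = D here
W_K = torch.randn(D, D)
W_V = torch.randn(D, D)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V            # each: [B, T, D]
scores = Q @ K.transpose(-2, -1) / D**0.5      # [B, T, T]
out = torch.softmax(scores, dim=-1) @ V        # [B, T, D]
print(out.shape)                               # torch.Size([2, 5, 4])
```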
Why Attention works: an intuitive explanation
Suppose we have 5 words {cat, milk, it, sweet, hungry}, each represented by a vector:
\begin{aligned}
\text{sweet} &= (..., 0, 4, ...)\\
\text{milk} &= (..., 1, 3, ...)\\
\text{it} &= (..., 2, 2, ...)\\
\text{cat} &= (..., 2, 2, ...)\\
\text{hungry} &= (..., 4, 0, ...)
\end{aligned}
Suppose the word-vector dimension is dim = 4. The first column shown (head 1) encodes a "state" attribute: the larger the value, the more related the word is to hungry. The second column shown (head 2) encodes a "taste" attribute: the larger the value, the more related the word is to sweet.
Example 1
Now let us trace how the attention computation acts on these attribute columns. Suppose we are processing the sentence "The cat drank the milk because it was sweet.", which contains the four words cat, milk, it, and sweet (ignore the other, irrelevant words for now; the words are stacked into a matrix in sentence order). The Q, K, V matrices in Self-Attention are then:
Q=K=V=
\begin{bmatrix}
  ... & 2 & 2 & ...\\ 
  ... & 1 & 3 & ...\\ 
  ... & 2 & 2 & ...\\ 
  ... & 0 & 4 & ...
\end{bmatrix}
\begin{matrix}
  \text{cat} \\ 
  \text{milk} \\ 
  \text{it} \\ 
  \text{sweet}
\end{matrix}
Now we compute the attention scores (for readability, the "..." entries are treated as 0):
Q \cdot K^T=
\begin{array}{cccccc}
  & \text{cat} & \text{milk} & \text{it} & \text{sweet} \\
\text{cat} & 8 & 8 & 8 & 8 \\ 
\text{milk} & 8 & 10 & 8 & 12 \\ 
\text{it} & 8 & 8 & 8 & 8 \\ 
\text{sweet} & 8 & 12 & 8 & 16 \\
\end{array}
\\
Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V = 
\begin{bmatrix}
  ... & 0.625 & 1.375 & ...\\ 
  ... & 0.089 & 1.911 & ...\\ 
  ... & 0.625 & 1.375 & ...\\ 
  ... & 0.010 & 1.990 & ...
\end{bmatrix}
\begin{matrix}
  \text{cat} \\ 
  \text{milk} \\ 
  \text{it} \\ 
  \text{sweet}
\end{matrix}
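As a quick sanity check, the raw Q·K^T score matrix above can be reproduced in a few lines of PyTorch, keeping only the two illustrated dimensions and treating the "..." entries as 0 (a toy verification added here, not part of the original derivation):

```python
import torch

# rows: cat, milk, it, sweet; columns: the two illustrated attribute dimensions
X = torch.tensor([[2., 2.],
                  [1., 3.],
                  [2., 2.],
                  [0., 4.]])
print(X @ X.T)   # Q·K^T with Q = K = X
# tensor([[ 8.,  8.,  8.,  8.],
#         [ 8., 10.,  8., 12.],
#         [ 8.,  8.,  8.,  8.],
#         [ 8., 12.,  8., 16.]])
```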
Afterwards, we obtain the new embeddings of these four words:
\begin{aligned}
\text{sweet} &= (..., 0, 4, ...) \rightarrow (..., 0.010, 1.990, ...)\\
\text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.089, 1.911, ...)\\
\text{it} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\
\text{cat} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\
\text{hungry} &= (..., 4, 0, ...) 
\end{aligned}
In general, it in English could refer to either cat or milk, so the similarity between it and cat equals the similarity between it and milk:
\text{sim(it, milk)} = 1 \times 2 + 3 \times 2 = 8 \\
\text{sim(it, cat)}  = 2 \times 2 + 2 \times 2 = 8
In the sentence "The cat drank the milk because it was sweet.", it refers to milk. After the attention computation produces the new embeddings, it moves closer to milk in vector space:
\text{sim(it, milk)} = 0.625 \times 0.089 + 1.375 \times 1.911 = 2.68 \\
\text{sim(it, cat)}  = 0.625 \times 0.625 + 1.375 \times 1.375 = 2.28
Example 2
Now consider the other case, the sentence "The cat drank the milk because it was hungry.", in which it refers to cat. Applying the same attention computation gives new embeddings:
Q=K=V=
\begin{bmatrix}
  ... & 2 & 2 & ...\\ 
  ... & 1 & 3 & ...\\ 
  ... & 2 & 2 & ...\\ 
  ... & 4 & 0 & ...
\end{bmatrix}
\begin{matrix}
  \text{cat} \\ 
  \text{milk} \\ 
  \text{it} \\ 
  \text{hungry}
\end{matrix}
Q \cdot K^T=
\begin{array}{cccccc}
  & \text{cat} & \text{milk} & \text{it} & \text{hungry} \\
\text{cat} & 8 & 8 & 8 & 8 \\ 
\text{milk} & 8 & 10 & 8 & 4 \\ 
\text{it} & 8 & 8 & 8 & 8 \\ 
\text{hungry} & 8 & 4 & 8 & 16 \\
\end{array}
\\
Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V = 
\begin{bmatrix}
  ... & 1.125 & 0.875 & ...\\ 
  ... & 0.609 & 1.391 & ...\\ 
  ... & 1.125 & 0.875 & ...\\ 
  ... & 1.999 & 0.001 & ...
\end{bmatrix}
\begin{matrix}
  \text{cat} \\ 
  \text{milk} \\ 
  \text{it} \\ 
  \text{hungry}
\end{matrix}
We then obtain the new embeddings of the four words:
\begin{aligned}
\text{hungry} &= (..., 4, 0, ...) \rightarrow (..., 1.999, 0.001, ...)\\
\text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.609, 1.391, ...)\\
\text{it} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\
\text{cat} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\
\text{sweet} &= (..., 0, 4, ...) \\
\end{aligned}
With the new embeddings, it is now more similar to cat than to milk:
\text{sim(it, milk)} = 1.125 \times 0.609 + 0.875 \times 1.391 = 1.90 \\
\text{sim(it, cat)}  = 1.125 \times 1.125 + 0.875 \times 0.875 = 2.03
Why scaling?
The purpose of scaling is numerical stability, which improves training efficiency and performance. When the dimension d_k of Q and K is large, the dot products q_i · k_i can take on very large magnitudes, which causes two problems (a short numerical demo follows the list below):
- The softmax gradient becomes extremely small: softmax is very sensitive to large values; one extreme value drives the weights at all other positions toward 0, producing numerical instability.
- Training becomes difficult: the resulting vanishing gradients make the model hard to learn.
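A small numerical illustration of this saturation effect, reusing the unscaled scores from the sweet row of Example 1 as a toy input (the demo and the choice of d_k = 4 are mine):

```python
import torch

scores = torch.tensor([8., 12., 8., 16.])
print(torch.softmax(scores, dim=-1))           # ~[0.0003, 0.018, 0.0003, 0.981] -- nearly one-hot
print(torch.softmax(scores / 4**0.5, dim=-1))  # ~[0.016, 0.116, 0.016, 0.853]  -- noticeably smoother
```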
Why √d_k?
Assume the components of the input features X, and hence of each query q_i and key k_i, are independent with zero mean and unit variance. Each entry of QK^T is then a dot product over d_k terms, so it has mean 0 and variance d_k, i.e. a standard deviation of √d_k. Dividing by √d_k rescales the scores back to unit variance, keeping the softmax in a well-behaved range (a short derivation follows the two alternatives below).
- If we instead divided by d_k, the variance of the scores would shrink to 1/d_k, over-compressing them and making the softmax too flat to discriminate between positions.
- If we divided by the cube root of d_k, the variance would still grow with d_k (as d_k^{1/3}), so the scores would remain unnormalized at large dimensions.
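A minimal derivation behind this argument, under the standard assumption (not spelled out above) that the components of q and k are independent with zero mean and unit variance:
\begin{aligned}
\mathbb{E}[q \cdot k] &= \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0 \\
\mathrm{Var}(q \cdot k) &= \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k \\
\mathrm{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) &= \frac{d_k}{d_k} = 1
\end{aligned}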
 
Other scaling variants
Normalize the attention scores into the [0, 1] range:
    score=\frac{QK^T}{||QK^T||}
Temperature parameter: replace the fixed √d_k with a tunable temperature that controls how sharp the softmax distribution is.
Pre-normalization:
    Q'=\frac{Q}{||Q||}, K'=\frac{K}{||K||}, V'=V
Long-sequence scaling: long sequences enlarge the range of the attention scores and saturate the softmax, so an extra length-dependent factor is added:
    softmax(\frac{QK^T}{\sqrt{d_k} \cdot \sqrt{T}})V\\
Hands-on implementation

```python
import torch
import torch.nn as nn
import math


class SelfAttention(nn.Module):
    def __init__(self, input_dim, dim_k=512):
        super().__init__()
        self.norm_factor = 1 / math.sqrt(dim_k)
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_k)

    def forward(self, x, mask=None):
        """
        x.shape: [B, T, D]
        """
        Q, K, V = self.q(x), self.k(x), self.v(x)
        # scaled dot-product scores: [B, T, T]
        score = torch.bmm(Q, K.transpose(1, 2)) * self.norm_factor
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        return torch.bmm(torch.softmax(score, dim=-1), V)
```
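A quick usage check, continuing from the class above; the input sizes are arbitrary illustrative values:

```python
attn = SelfAttention(input_dim=64, dim_k=512)
x = torch.randn(2, 10, 64)          # [B, T, D]
print(attn(x).shape)                # torch.Size([2, 10, 512])
```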
Multi-Head Attention (MHA)
Why multiple heads?
With a single head, the model can only learn one kind of attention pattern and cannot capture the diverse relationships present in the input sequence. Representing the query, key, and value projections with a single head limits the model's expressive power.
- Capturing different contextual information: each attention head can focus on different context or relations. For example, one head can focus on long-range relations between words while another focuses on local relations.
- Improving expressive power: by computing several attention distributions in parallel, the model can view the same input from multiple angles and obtain richer semantic information.
- Improving flexibility and robustness: multi-head attention lets the model learn in multiple subspaces, reducing the information loss a single head might incur.
Example: "The quick brown fox jumped over the lazy dog". We would like the Transformer to understand the following relations:
- Subject-verb: "fox" and "jumped"
- Modifier: "quick" modifies "fox"
- Positional: "over" and "jumped"
From the model's perspective, different heads learn different relations between words: for "jumped", head 1 might learn the subject-verb relation with "fox", while head 2 learns the positional relation with "over".
Hands-on implementation

```python
import torch
import torch.nn as nn
import math


class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, dim_k, num_head=8):
        super().__init__()
        assert dim_k % num_head == 0
        self.dk = dim_k // num_head          # per-head dimension
        self.head = num_head
        self.norm_factor = 1 / math.sqrt(self.dk)
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_k)

    def forward(self, x, mask=None):
        """
        x.shape: [B, T, D]
        """
        B, T, _ = x.shape
        # project, then split into heads: [B, T, dim_k] -> [B, head, T, dk]
        Q, K, V = self.q(x), self.k(x), self.v(x)
        Q = Q.view(B, T, -1, self.dk).transpose(1, 2)
        K = K.view(B, T, -1, self.dk).transpose(1, 2)
        V = V.view(B, T, -1, self.dk).transpose(1, 2)
        # scaled dot-product attention per head: [B, head, T, T]
        score = torch.matmul(Q, K.transpose(-2, -1)) * self.norm_factor
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        output = torch.matmul(torch.softmax(score, dim=-1), V)
        # merge heads back: [B, head, T, dk] -> [B, T, dim_k]
        output = output.transpose(1, 2).contiguous().view(B, T, -1)
        return output
```
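A usage sketch with a causal (lower-triangular) mask, following the mask convention of the module above where positions with mask value 0 are blocked; the sizes are again illustrative:

```python
mha = MultiheadAttention(input_dim=64, dim_k=512, num_head=8)
x = torch.randn(2, 10, 64)                 # [B, T, D]
causal = torch.tril(torch.ones(10, 10))    # [T, T], broadcasts over batch and heads
print(mha(x, mask=causal).shape)           # torch.Size([2, 10, 512])
```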
Multi-Head Latent Attention (MLA)
Background
The traditional Transformer uses MHA (Multi-Head Attention), but the KV cache becomes an inference bottleneck. MQA (Multi-Query Attention) and GQA (Grouped-Query Attention) reduce the KV cache to some extent, but fall short of MHA in quality. DeepSeek-V2 introduces an attention mechanism called MLA (Multi-Head Latent Attention). Through low-rank joint key-value compression, MLA achieves better results than MHA while requiring a much smaller KV cache.
Core problem it solves: when using a KV cache, memory runs out as the sequence length grows.
 
Low-rank joint Key-Value compression
In a single head of MHA, the input X (of shape N × d_emb) is multiplied by three projection matrices W_Q, W_K, W_V, each of shape d_emb × d_k:
Q = X \times W_Q \quad (N \times d_k) \\
K = X \times W_K \quad (N \times d_k) \\
V = X \times W_V \quad (N \times d_k)
The K and V produced by these projections must be cached for every token (across all heads, the Q, K, V projections together have shape d × d, where d = h × d_k). Instead of generating K and V directly, MLA first compresses X into a joint low-rank latent representation C:
C = X \times W_C \quad (N \times d_{emb}) \times (d_{emb} \times d_c)
The dimension of W_C is no longer d_emb × d_k but d_emb × d_c, where d_c is typically much smaller than d but larger than d_k (in DeepSeek-V2, d_k = 128 and d_c = 512). With C in hand, to keep the attention computation correct we can conceptually reintroduce K and V, giving the following two new formulas for KV:
K = C \times W_K \quad (N \times d_c) \times (d_c \times d_k) \\
V = C \times W_V \quad (N \times d_c) \times (d_c \times d_k)
Here W_K and W_V have shape (d_c × d_k), and the resulting K and V are still (N × d_k). But we never need to materialize K and V explicitly: substituting the expression for K into QK^T gives:
QK^T = X \times W_Q \times (C \times W_K)^T = X \times (W_Q \times W_K^T) \times C^T = Q'C^T
We can therefore merge W_Q × W_K^T into a single new matrix W_Q' (of shape d_emb × d_c), so K no longer needs to exist on its own: attention only needs Q' and C. Similarly, V can be folded away by absorbing W_V into the output projection W_O, giving a merged W_O'. The whole computation becomes:
Q' = X \times W_Q' \\
C = X \times W_C \\
O = Attention(Q', C) = softmax(\frac{Q' C^T}{\sqrt{d_k}})C \\
O' = O \times W_O'
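To make the weight-absorption idea concrete, here is a small single-head sketch (RoPE and the multi-head split are ignored; the weight scales and variable names are mine, not DeepSeek's actual implementation) showing that attending directly over the cached C with the merged matrices gives the same output as first materializing K and V:

```python
import torch

torch.manual_seed(0)
N, d_emb, d_c, d_k = 5, 16, 8, 6                  # toy sizes, single head
X   = torch.randn(N, d_emb)
W_Q = torch.randn(d_emb, d_k) / d_emb**0.5        # small init keeps the scores well-behaved
W_C = torch.randn(d_emb, d_c) / d_emb**0.5
W_K = torch.randn(d_c, d_k) / d_c**0.5
W_V = torch.randn(d_c, d_k) / d_c**0.5
W_O = torch.randn(d_k, d_emb) / d_k**0.5

C = X @ W_C                                       # the only tensor that must be cached

# Naive path: reconstruct K and V from C, then attend as usual.
K, V = C @ W_K, C @ W_V
A = torch.softmax((X @ W_Q) @ K.T / d_k**0.5, dim=-1)
out_naive = (A @ V) @ W_O

# Absorbed path: fold W_K into the query side and W_V into the output side,
# so attention runs directly over the compressed C.
W_Q_prime = W_Q @ W_K.T                           # (d_emb, d_c)
W_O_prime = W_V @ W_O                             # (d_c, d_emb)
A2 = torch.softmax((X @ W_Q_prime) @ C.T / d_k**0.5, dim=-1)
out_absorbed = (A2 @ C) @ W_O_prime

print(torch.allclose(out_naive, out_absorbed, atol=1e-5))   # True
```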
Compatibility with RoPE
However, today's LLMs generally apply RoPE to Q and K before the attention computation, which breaks the QK^T merging above: the position-dependent rotation sits between W_Q and W_K^T and can no longer be absorbed into a single matrix. MLA therefore decouples RoPE by concatenating an extra d_r-dimensional component onto Q and K that carries the positional information:
Q_{rope} = [X \times W_Q, X \times W_{QR} \times R] = [X \times W_Q, QR] \\
W_{QR}: (d \times d_r) \quad R: (d_r \times d_r) \quad Q_{rope}: (N \times (d_c + d_r))
where W_{QR} is a newly added projection matrix, R is the RoPE rotation, and QR denotes the extra (N × d_r) sub-matrix appended to Q that carries the positional information. K is extended in the same way:
K_{rope} = [C \times W_K, X \times W_{KR} \times R] = [C \times W_K, KR] \\
W_{KR}: (d \times d_r) \quad K_{rope}: (N \times (d_c + d_r))
W_{KR} is likewise a new matrix of shape d × d_r, and KR denotes the extra (N × d_r) sub-matrix appended to K that stores the RoPE information. The earlier formula then expands to:
\begin{align*}
QK^T 
&= [X \times W_Q, QR] \times [C \times W_K, KR]^T \\
&=[X \times W_Q, QR] \times 
\begin{bmatrix}
(C \times W_K)^T \\
KR^T
\end{bmatrix} \\
&=[X \times W_Q, QR] \times 
\begin{bmatrix}
W_K^T C^T \\
KR^T
\end{bmatrix} \\
&=X \times (W_Q \times W_K^T) \times C^T + QR \times KR^T \\
&=X \times W_Q' \times C^T + QR \times KR^T \\
&=Q'C^T + QR \times KR^T
\end{align*}
Thus, after appending the d_r-dimensional RoPE components to Q and K, inference only needs to cache C (N × d_c) together with KR (N × d_r); in DeepSeek-V2, d_r is set to d_k / 2 = 64.
Parameter and compute comparison
Compared with the original MHA, MLA simplifies the formulas and shrinks the cache. Is there really such a thing as a free lunch? Let us compare the per-layer parameter count and compute of MHA, MQA, GQA, and MLA during inference, shown in the table below. Assume:
\begin{align*}
&d=8k \quad head=64 \quad d_k=128 \\
&group=8 \quad d_c=512 \quad N=128k \ (\text{model max tokens})
\end{align*}
| Attention | MHA | MQA | GQA | MLA |
| --- | --- | --- | --- | --- |
| Parameter formula | W_Q = W_K = W_V = W_O = d × h × d_k | W_Q = W_O = d × h × d_k; W_K = W_V = d × d_k | W_Q = W_O = d × h × d_k; W_K = W_V = d × g × d_k | W_Q = W_O = d × h × d_c; W_C = d × d_c |
| Parameter count | 64M × 4 = 256M | 64M × 2 + 2M = 130M | 64M × 2 + 8M × 2 = 144M | 256M × 2 + 4M = 516M |
| Cache formula | N × h × d_k × 2 | N × d_k × 2 | N × g × d_k × 2 | N × d_c |
| Cache size | 2G | 32M | 256M | 64M |
| Prefill compute | 256T | 256T | 256T | 1000T |
| Decode compute | 2G | 2G | 2G | 8G |
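The parameter and cache entries of the table can be reproduced from the assumed dimensions with a few lines of arithmetic (a sanity-check script added here; sizes are counted in elements, with M = 2^20 and G = 2^30):

```python
d, h, d_k, g, d_c, N = 8 * 1024, 64, 128, 8, 512, 128 * 1024
M, G = 2**20, 2**30

# per-layer parameter counts (elements)
print("MHA params:", 4 * d * h * d_k / M, "M")                       # 256.0 M
print("MQA params:", (2 * d * h * d_k + 2 * d * d_k) / M, "M")       # 130.0 M
print("GQA params:", (2 * d * h * d_k + 2 * d * g * d_k) / M, "M")   # 144.0 M
print("MLA params:", (2 * d * h * d_c + d * d_c) / M, "M")           # 516.0 M

# per-layer KV-cache sizes at max context length N (elements)
print("MHA cache:", N * h * d_k * 2 / G, "G")   # 2.0 G
print("MQA cache:", N * d_k * 2 / M, "M")       # 32.0 M
print("GQA cache:", N * g * d_k * 2 / M, "M")   # 256.0 M
print("MLA cache:", N * d_c / M, "M")           # 64.0 M
```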
 
 
 
As the table shows, MLA has larger parameter counts and compute than the other three attention variants; indeed, this extra computation is part of why MLA outperforms MHA. The payoff is a dramatically smaller cache in the decoding stage, which makes decoding much faster. The numbers above do not include the RoPE components; adding them increases the parameters and compute slightly but does not change the conclusion.
Reference: 大模型推理-MLA