Attention 算法

自注意力(Self-Attention)

Q=XWQ,K=XWK,V=XWVAttention(Q,K,V)=softmax(QKTdk)VX.shape:[B,T,D],W.shape:[D,D],dk=D \begin{aligned} & Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V \\ & \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T} { \sqrt{d_k} }\right)V \\ & X.\text{shape}: [B, T, D], \quad W.\text{shape}: [D, D], \quad d_k = D \end{aligned}

理论解释为什么 Attention 有效

假设我们有 5 个单词 {cat, milk, it, sweet, hungry}, 每个单词用一个向量表示:
sweet=(...,0,4,...)milk=(...,1,3,...)it=(...,2,2,...)cat=(...,2,2,...)hungry=(...,4,0,...) \begin{aligned} \text{sweet} &= (..., 0, 4, ...)\\ \text{milk} &= (..., 1, 3, ...)\\ \text{it} &= (..., 2, 2, ...)\\ \text{cat} &= (..., 2, 2, ...)\\ \text{hungry} &= (..., 4, 0, ...) \end{aligned}

假设词向量的维度 dim=4,第一列数字(第 1 个头)代表这个单词关于状态的属性,值越高代表该单词与hungry越相关; 第二列数字(第 2 个头)代表这个单词关于味道的属性,值越高代表该单词与sweet越相关。

例子 1

现在让我们考虑 Attention 算法中,计算状态这部分的头。假设我们正在处理一个句子 “The cat drank the milk because it was sweet.”,其中包含了 cat、milk、it、sweet 这四个单词(暂且忽略其余不相关单词,单词按词序组成矩阵)。此时在 Self-Attention 算法中的 Q、K、V 矩阵为:
Q=K=V=[...22......13......22......04...]catmilkitsweet Q=K=V= \begin{bmatrix} ... & 2 & 2 & ...\\ ... & 1 & 3 & ...\\ ... & 2 & 2 & ...\\ ... & 0 & 4 & ... \end{bmatrix} \begin{matrix} \text{cat} \\ \text{milk} \\ \text{it} \\ \text{sweet} \end{matrix}

现在我们计算 Attention 分数(为了方便理解,... 部分用 0 代替):

QKT=catmilkitsweetcat8888milk810812it8888sweet812816Softmax(QKTd)V=[...0.6251.375......0.0891.911......0.6251.375......0.0101.990...]catmilkitsweet Q \cdot K^T= \begin{array}{cccccc} & \text{cat} & \text{milk} & \text{it} & \text{sweet} \\ \text{cat} & 8 & 8 & 8 & 8 \\ \text{milk} & 8 & 10 & 8 & 12 \\ \text{it} & 8 & 8 & 8 & 8 \\ \text{sweet} & 8 & 12 & 8 & 16 \\ \end{array} \\ Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V = \begin{bmatrix} ... & 0.625 & 1.375 & ...\\ ... & 0.089 & 1.911 & ...\\ ... & 0.625 & 1.375 & ...\\ ... & 0.010 & 1.990 & ... \end{bmatrix} \begin{matrix} \text{cat} \\ \text{milk} \\ \text{it} \\ \text{sweet} \end{matrix}

之后,我们得到这 4 个单词的新 embedding 为:
sweet=(...,0,4,...)(...,0.010,1.990,...)milk=(...,1,3,...)(...,0.089,1.911,...)it=(...,2,2,...)(...,0.625,1.375,...)cat=(...,2,2,...)(...,0.625,1.375,...)hungry=(...,4,0,...) \begin{aligned} \text{sweet} &= (..., 0, 4, ...) \rightarrow (..., 0.010, 1.990, ...)\\ \text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.089, 1.911, ...)\\ \text{it} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\ \text{cat} &= (..., 2, 2, ...) \rightarrow (..., 0.625, 1.375, ...)\\ \text{hungry} &= (..., 4, 0, ...) \end{aligned}

通常情况下, 在英文中 it 既可以指代 cat 又可以指代 milk,因此 it 和 cat 的相似度与 it 和 milk 的相似度相同,即:
sim(it, milk)=1×2+3×2=8sim(it, cat)=2×2+2×2=8 \text{sim(it, milk)} = 1 \times 2 + 3 \times 2 = 8 \\ \text{sim(it, cat)} = 2 \times 2 + 2 \times 2 = 8
之后,模型学习了句子 “The cat drank the milk because it was sweet.”,这个句子中 it 指代 milk,通过 Attention 算法后得到了新的词 embedding,这时 it 在词向量表达上更加靠近 milk:
sim(it, milk)=0.0625×0.089+1.375×1.911=2.68sim(it, cat)=0.0625×0.0625+1.375×1.375=2.28 \text{sim(it, milk)} = 0.0625 \times 0.089 + 1.375 \times 1.911 = 2.68 \\ \text{sim(it, cat)} = 0.0625 \times 0.0625 + 1.375 \times 1.375 = 2.28

例子 2

再来看另一种情况,对于另一个句子 “The cat drank the milk because it was hungry.”,这个句子中 it 指代 cat,我们同样运用 Attention 算法,得到新的词 embedding:
Q=K=V=[...22......13......22......40...]catmilkithungry Q=K=V= \begin{bmatrix} ... & 2 & 2 & ...\\ ... & 1 & 3 & ...\\ ... & 2 & 2 & ...\\ ... & 4 & 0 & ... \end{bmatrix} \begin{matrix} \text{cat} \\ \text{milk} \\ \text{it} \\ \text{hungry} \end{matrix}
计算 Attention 分数:
QKT=catmilkithungrycat8888milk81084it8888hungry84816Softmax(QKTd)V=[...1.1250.875......0.6091.391......1.1250.875......1.9990.001...]catmilkithungry Q \cdot K^T= \begin{array}{cccccc} & \text{cat} & \text{milk} & \text{it} & \text{hungry} \\ \text{cat} & 8 & 8 & 8 & 8 \\ \text{milk} & 8 & 10 & 8 & 4 \\ \text{it} & 8 & 8 & 8 & 8 \\ \text{hungry} & 8 & 4 & 8 & 16 \\ \end{array} \\ Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V = \begin{bmatrix} ... & 1.125 & 0.875 & ...\\ ... & 0.609 & 1.391 & ...\\ ... & 1.125 & 0.875 & ...\\ ... & 1.999 & 0.001 & ... \end{bmatrix} \begin{matrix} \text{cat} \\ \text{milk} \\ \text{it} \\ \text{hungry} \end{matrix}

之后我们得到 4 个单词的新 embedding:
hungry=(...,4,0,...)(...,1.999,0.001,...)milk=(...,1,3,...)(...,0.609,1.391,...)it=(...,2,2,...)(...,1.125,0.875,...)cat=(...,2,2,...)(...,1.125,0.875,...)sweet=(...,0,4,...) \begin{aligned} \text{hungry} &= (..., 4, 0, ...) \rightarrow (..., 1.999, 0.001, ...)\\ \text{milk} &= (..., 1, 3, ...) \rightarrow (..., 0.609, 1.391, ...)\\ \text{it} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\ \text{cat} &= (..., 2, 2, ...) \rightarrow (..., 1.125, 0.875, ...)\\ \text{sweet} &= (..., 0, 4, ...) \\ \end{aligned}
此时 it 的词向量更加接近 cat:
sim(it, milk)=1.125×0.609+0.875×1.391=1.90sim(it, cat)=1.125×1.125+0.875×0.875=2.03 \text{sim(it, milk)} = 1.125 \times 0.609 + 0.875 \times 1.391 = 1.90 \\ \text{sim(it, cat)} = 1.125 \times 1.125 + 0.875 \times 0.875 = 2.03

为什么要 scaling?

scaling 目的是解决数值稳定性问题,从而提高训练的效率和性能。当 Q,K 的维度 dkd_k 很大时, qi,kiq_i,k_i 的点积值可能变得很大。点积值越大,输入到 softmax 函数中的数值范围越广,可能会导致以下问题:

  • softmax 的梯度变得极小: softmax 函数对大数值非常敏感,极大值会导致其他位置的权重几乎为 0,从而产生数值不稳定性。
  • 模型训练变得困难: 梯度消失问题会使得模型难以学习。

为什么是 dkd_k 而不是其他数?

对于输入特征 XX,其元素通常服从均值为 0、方差为 1 的标准正态分布。经过点积计算 QKTQK^T 后,由于 qi,kiq_i,k_i 的点积是 dkd_k 个独立随机变量的和,所以方差会变为 dkd_k,除以 dk\sqrt{d_k} 可以让方差重新变为 1。

  • 如果直接除以 dkd_k,方差变为 1/dk1/d_k,分布过于集中,让 softmax 的值趋于均匀分布,会弱化注意力机制的效果
  • 如果除以 dk3\sqrt[3]{d_k},会让方差仍然较大,可能会导致数值不稳定和训练困难

其余 scale 处理

  1. 归一化 Attention 分数,放缩到[0,1]范围

    score=QKTQKT score=\frac{QK^T}{||QK^T||}
  2. 温度参数:dkd_k 换成常数

  3. 预归一化

    Q=QQ,K=KK,V=V Q'=\frac{Q}{||Q||}, K'=\frac{K}{||K||}, V'=V
  4. 长序列放缩,长序列导致 Attention 分数范围增大,Softmax 失效

    softmax(QKTdkT)V softmax(\frac{QK^T}{\sqrt{d_k} \cdot \sqrt{T}})V\\

代码手撕

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
def __init__(self, input_dim, dim_k=512):
super().__init__()
self.norm_factor = 1 / math.sqrt(dim_k)
self.q = nn.Linear(input_dim, dim_k)
self.k = nn.Linear(input_dim, dim_k)
self.v = nn.Linear(input_dim, dim_k)

def forward(self, x, mask=None):
"""
x.shape: [B, T, D]
"""
Q, K, V = self.q(x), self.k(x), self.v(x)
# torch.bmm 输入为 3 维矩阵, 批量相乘, 速度快
# torch.matmul 输入可为多种矩阵, 更灵活
score = torch.bmm(Q, K.transpose(1, 2)) * self.norm_factor
if mask is not None:
score += mask * -1e9
return torch.bmm(torch.softmax(score, dim=-1), V)

多头自注意力(MHA)

为什么要多头?

模型只能学习到一个层面的注意力模式,不能捕捉到输入序列中复杂的多样性关系。仅通过单个头来表示查询、键和值的投影,会限制模型的表达能力。
多头注意力的优势包括:

  • 捕捉不同的上下文信息:每个注意力头可以专注于不同的上下文信息或关系。例如,一个头可以专注于捕捉远距离词语之间的关系,而另一个头可以专注于局部词语之间的关系。
  • 提高模型的表达能力:通过并行计算多个注意力分布,模型能够从多个角度理解同一输入,从而获得更丰富的语义信息
  • 提升模型的灵活性和鲁棒性:多头注意力使得模型能够在多个子空间中进行学习,从而减少单一注意力头可能带来的信息损失。

举例:“The quick brown fox jumped over the lazy dog”。我们希望 Transformer 模型能够理解以下关系:

  • 主谓关系:“fox” 和 “jumped”
  • 定语关系:“quick” 修饰 “fox”
  • 位置关系:“over” 与 “jumped”

对模型来说,不同的头用来学习单词之间的不同关系,比如 “jump” 的头 1 用来学习与”fox”的主谓关系,头 2 用来学习与”over”的位置关系。

代码手撕

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import math

class MultiheadAttention(nn.Module):
def __init__(self, input_dim, dim_k, num_head=8):
super().__init__()
assert dim_k % num_head == 0
self.dk = dim_k // num_head
self.head = num_head
# dk 缩减后放缩因子也要改变
self.norm_factor = 1 / math.sqrt(self.dk)
self.q = nn.Linear(input_dim, dim_k)
self.k = nn.Linear(input_dim, dim_k)
self.v = nn.Linear(input_dim, dim_k)

def forward(self, x, mask=None):
"""
x.shape: [B, T, D]
"""
batch, seqlen, _ = x.shape
Q, K, V = self.q(x), self.k(x), self.v(x)
# (B, T, D)
Q, K, V = (
Q.reshape(batch, seqlen, -1, self.dk).transpose(1, 2),
K.reshape(batch, seqlen, -1, self.dk).transpose(1, 2),
V.reshape(batch, seqlen, -1, self.dk).transpose(1, 2),
)
# (B, H, T, dk)

score = torch.matmul(Q, K.transpose(-2, -1)) * self.norm_factor
if mask is not None:
score += mask * -1e9
output = torch.matmul(torch.softmax(score, dim=-1), V)
output = output.reshape(batch, seqlen, -1)
return output

多头潜在注意力(MLA)

背景

传统 Transformer 采用 MHA(Multi-Head Attention),但是 KV Cache 会成为推理瓶颈。MQA(Multi-Query Attention) 和 GQA(Grouped-Query Attention) 可以一定程度减少 KV Cache,但效果上不如MHA。DeepSeek-V2 设计了一种称为 MLA(Multi-Head Latent Attention) 的注意力机制。MLA 通过低秩key-value 联合压缩,实现了比 MHA 更好的效果并且需要的 kv cache 要小很多。

解决核心问题:使用 KV Cache 时,随着序列长度变长导致显存不足的问题。

低秩 Key-Value 联合压缩

在 MHA 的单头计算过程中,输入 XX (维度 N×dembN \times d_{emb})会先送入三个projection矩阵 WQ,WK,WVW_Q, W_K, W_V (维度为 demb×dkd_{emb} \times d_k) 进行线性变换:
Q=X×WQ(N×dk)K=X×WK(N×dk)V=X×WV(N×dk) Q = X \times W_Q \quad (N \times d_k) \\ K = X \times W_K \quad (N \times d_k) \\ V = X \times W_V \quad (N \times d_k)

变换完成之后生成的 KVK、V 就是传统意义上的 KV Cache。Multi-head 会将 h 个 head 的 QKVQ、K、V 矩阵分别拼接在一起,构成 d×dd \times d 的矩阵,其中 d=h×dkd = h \times d_k。而在 MLA 中没有保留 KV Cache,而是引入一个 C Cache,也即增加了公式:
C=X×WC(N×demb)×(demb×dc) C = X \times W_C \quad (N \times d_{emb}) \times (d_{emb} \times d_c)

WCW_C 的维度不再是 demb×dkd_{emb} \times d_k,而是变成了 demb×dcd_{emb} \times d_cdcd_c 通常远小雨 dd,但比 dkd_k 大(DeepSeek-V2 中 dk=128d_k=128dc=512d_c=512)。有了 CC 之后,为了保证 Attention 计算的正确性,我们可以从概念上再引入 K 和 V,也即得到如下两个关于 KV 的新公式: K=C×WK(N×dc)×(dc×dk)V=C×WV(N×dc)×(dc×dk) K = C \times W_K \quad (N \times d_c) \times (d_c \times d_k) \\ V = C \times W_V \quad (N \times d_c) \times (d_c \times d_k)

此时 WKW_KWVW_V 的维度就变成了 (dc×dk)(d_c \times d_k),最终输出的 K 和 V 维度依旧是 (N×dk)(N \times d_k)。但是请大家切记,MLA 最终的公式中没有出现 KV,我们现在处于将 MHA 转换到MLA的过程中,引入 KKVV 只是方便进行公式推导。MHA 进行 Attention 计算时会将 QQKTK^T 相乘,得到如下公式:
QKT=X×WQ×(C×WK)T=X×(WQ×WKT)×CT=QCT QK^T = X \times W_Q \times (C \times W_K)^T = X \times (W_Q \times W_K^T) \times C^T = Q'C^T

于是我们可以将 WQ×WKTW_Q \times W_K^T 当做新的 projection 矩阵 WQW_Q' (维度 demb×dcd_{emb} \times d_c),此时 Attention 计算就和 KK 没什么关系了。类似的,Attention计算中的 VV 也可以进行消除,从而使得整个 Attention 计算只和 QCQ'、C 相关。MLA更进一步,Attention 计算完成得到 OO 之后,还会将其与 projection 矩阵 WOW_O 进行相乘,WVW_V 也可以与 WOW_O 融合在一起得到新的 projection 矩阵 WOW_O'。最终我们可以得到 MLA Attention 部分的计算公式:
Q=X×WQC=X×WCO=Attention(Q,C)=softmax(QCTdk)CO=O×WO Q' = X \times W_Q' \\ C = X \times W_C \\ O = Attention(Q', C) = softmax(\frac{Q' C^T}{\sqrt{d_k}})C \\ O' = O \times W_O'

兼容 RoPE

然而,当下大模型一般会在Attention计算之前将 QQKK 添加 RoPE,这就导致上面的 QKTQK^T 过程不再成立。但是我们又不能丢弃 RoPE,为了弥补这个缺陷,MLA 又在 QQKK 中额外增加了 drd_r 维度来专门用来存放 RoPE 信息。

Qrope=[X×WC,X×WQR×R]=[X×WQ,QR]WQR:(d×dr)R:(dr×dr)Qrope:(N×(dc+dr)) Q_{rope} = [X \times W_C, X \times W_{QR} \times R] = [X \times W_Q, QR] \\ W_{QR}: (d \times d_r) \quad R: (d_r \times d_r) \quad Q_{rope}: (N \times (d_c + d_r))

其中 WQRW_{QR} 是新增的一个矩阵,RR 是 RoPE 矩阵,用 QRQR 来代表 QQ 新增的子矩阵(N×dr)(N \times d_r)。类似的,KK 也可以扩展为:
Krope=[C×WK,X×WKR×R]=[C×WK,KR]WKR:(d×dr)Krope:(N×(dc+dr)) K_{rope} = [C \times W_K, X \times W_{KR} \times R] = [C \times W_K, KR] \\ W_{KR}: (d \times d_r) \quad K_{rope}: (N \times (d_c + d_r))

WKRW_{KR} 也是新增的矩阵,维度也是 d×drd \times d_r,用 KRKR 来代表 KK 新增的子矩阵(N×dr)(N \times d_r),它保存了 RoPE 信息。之后可以扩展后续公式为: QKT=[X×WQ,QR]×[C×WK,KR]T=[X×WQ,QR]×[(C×WK)TKRT]=[X×WQ,QR]×[WKTCTKRT]=X×(WC×WKT)×CT+QR×KRT=X×WQ×CT+QR×KRT=QCT+QR×KRT \begin{align*} QK^T &= [X \times W_Q, QR] \times [C \times W_K, KR]^T \\ &=[X \times W_Q, QR] \times \begin{bmatrix} (C \times W_K)^T \\ KR^T \end{bmatrix} \\ &=[X \times W_Q, QR] \times \begin{bmatrix} W_K^T C^T \\ KR^T \end{bmatrix} \\ &=X \times (W_C \times W_K^T) \times C^T + QR \times KR^T \\ &=X \times W_Q' \times C^T + QR \times KR^T \\ &=Q'C^T + QR \times KR^T \end{align*}

于是,在 QQKK 新增 drd_r 维度计算 RoPE 信息的情况下,Attention 的计算公式还可以继续复用上面提到的 C Cache 压缩技巧,只需要把第一部分提到的结果与新增的两个子矩阵相乘的结果相加即可。不过我们还需要把 KR 缓存下来,加速后半部分在 Decoding 阶段的计算速度。所以整体 MLA 需要保留 N×dcN \times d_c 的 C Cache,还需要保留 N×drN \times d_r 的 KR Cache,但 drd_r 一般不大,在论文中是 dk/2=64d_k /2 = 64

计算量对比

MLA 相比原始的 MHA 简化了计算公式,压缩了缓存大小,难道天下真有免费的午餐吗?我们仔细对比一下 MHA、MQA、GQA 和 MLA 在推理过程中的模型单层参数量和计算量,如下表。假定:
dk=8khead=64dh=128group=8dc=512N=128k(模型 max token) \begin{align*} &d_k=8k \quad head=64 \quad d_h=128 \\ &group=8 \quad d_c=512 \quad N=128k\text{(模型 max token)} \end{align*}
在统计计算量的时候只考虑了计算的大头,较小计算量的部分丢弃掉也不会产生大的影响。

Attention MHA MQA GQA MLA
参数量计算 WQ=WK=WV=d×h×dkW_Q=W_K=W_V\\ =d \times h \times d_k WQ=WO=d×h×dkWK=WV=d×dkW_Q=W_O=d \times h \times d_k \\ W_K=W_V=d \times d_k WQ=WO=d×h×dkWK=WV=d×g×dkW_Q=W_O=d \times h \times d_k \\ W_K=W_V=d \times g \times d_k WQ=WO=d×h×dcWC=d×dcW_Q=W_O=d \times h \times d_c \\ W_C = d \times d_c
参数量 64M * 4 = 256 M 64 M * 2 + 2M = 130 M 64 M 2 + 8M 2 = 144 M 256 M * 2 + 4M = 516 M
缓存计算 N×h×dk×2N \times h \times d_k \times 2 N×dk×2N \times d_k \times 2 N×g×dk×2N \times g \times d_k \times 2 N×dcN \times d_c
缓存大小 2G 32M 256M 64M
Prefilling 计算量 256T 256T 256T 1000T
Decoding 计算量 2G 2G 2G 8G

可以看出,MLA在参数量和计算量上都比另外三种 Attention 计算方法要大,但 MLA 效果好于 MHA 的原因也是因为计算量增大,但好处就是大大降低了 Decoding 阶段的缓存大小,从而使得解码速度更快。上面结果没有考虑 RoPE 引入,但参数量和计算量增加不影响结论。


参考:大模型推理-MLA


Attention 算法
https://guokent.github.io/deeplearning/attention/
作者
Kent
发布于
2024年10月3日
许可协议