KV Cache

模型推理

在大模型推理的时候,我们最看重的是两个指标:

  1. 吞吐量: 吞吐量代表了大模型单位时间内处理 Tokens 的数量,这里的 Tokens 一般指输入和输出Tokens数量的总和。在 Infra 条件一样的情况下,吞吐量越大,大模型推理系统的资源利用效率更高,推理的成本也就是更低
  2. 时延: 时延是针对最终用户而言的。时延用户平均收到每个Token所花费的时间,业务通常认为这个数值如果小于50ms,则用户可以收获相对良好的使用体验。

Transformer 模型的显著特性在于,每次推理过程仅生成一个 token 作为输出。该 token 随后与之前生成的所有 tokens 结合,形成下一轮推理的输入。这个过程不断重复,直至生成完整的输出序列。然而,由于每轮的输入仅比上一轮多出一个 token,导致了大量的冗余计算。KV Cache技术的出现正是为了解决这一问题,它通过存储可复用的键值向量,有效避免了这些不必要的重复计算,显著提高了推理的效率。

无 KV Cache 计算流程

  1. 假设输入 token 为 [a, b, c, d],下两个预测的 token 为 e, f。已知的Q、K、V矩阵分别为:
  • Q=WQX=[qa,qb,qc,qd],shape=[N,d]Q=W_QX=[q_a, q_b, q_c, q_d], \text{shape}=[N, d]
  • K=WKX=[ka,kb,kc,kd],shape=[N,d]K=W_KX=[k_a, k_b, k_c, k_d], \text{shape}=[N, d]
  • V=WVX=[va,vb,vc,vd],shape=[N,d]V=W_VX=[v_a, v_b, v_c, v_d], \text{shape}=[N, d]
  1. 计算 Attention 分数,时间复杂度 O(n2)O(n^2)
Attn=Softmax(QKTd) Attn=Softmax(\frac{Q \cdot K^T}{\sqrt{d}})
  1. 结束这一轮后,将生成的 e 并入前面的输入,作为下一轮的输入,此时输入 token 为 [a, b, c, d, e],重新计算得到Q、K、V矩阵,时间复杂度仍为 O(n2)O(n^2)
  • Q=WQX=[qa,qb,qc,qd,qe],shape=[N+1,d]Q=W_QX=[q_a, q_b, q_c, q_d, q_e], \text{shape}=[N+1, d]
  • K=WKX=[ka,kb,kc,kd,ke],shape=[N+1,d]K=W_KX=[k_a, k_b, k_c, k_d, k_e], \text{shape}=[N+1, d]
  • V=WVX=[va,vb,vc,vd,ve],shape=[N+1,d]V=W_VX=[v_a, v_b, v_c, v_d, v_e], \text{shape}=[N+1, d]
  1. 继续计算 Attention 分数,接着预测下一个 token 为 f,时间复杂度仍为 O(n2)O(n^2)

  2. 以此类推…

带 KV Cache 计算流程

分为两个阶段:预填充阶段、解码阶段

仍然假设输入 token 为 [a, b, c, d],下两个预测的 token 为 e, f。

预填充阶段

这一步要计算初始输入 token 的 Attention 分数,时间复杂度也是 O(n2)O(n^2)

  • Q=WQX=[qa,qb,qc,qd],shape=[N,d]Q=W_QX=[q_a, q_b, q_c, q_d], \text{shape}=[N, d]
  • K=WKX=[ka,kb,kc,kd],shape=[N,d]K=W_KX=[k_a, k_b, k_c, k_d], \text{shape}=[N, d]
  • V=WVX=[va,vb,vc,vd],shape=[N,d]V=W_VX=[v_a, v_b, v_c, v_d], \text{shape}=[N, d]
Attn=Softmax(QKTd) Attn=Softmax(\frac{Q \cdot K^T}{\sqrt{d}})

这里会将得到的 K、V 矩阵做缓存,方便在解码阶段使用

解码阶段

通过上一轮 Attention 结果,可以得到了下一个预测 token e。但在下一轮 Attention 时,仅需要计算 token e 和其余已有的 token 的 Attention 分数即可

  1. 首先通过 WQWKWVW_Q、W_K、W_V 得到 qekeveq_e、k_e、v_e,并将 kevek_e、v_e 追加入缓存中的 K、V 矩阵,时间复杂度 O(n)O(n)
  2. qeq_e 与所有缓存的 K_cache(包括 k_a, k_b, k_c, k_d, k_e)计算注意力分数,时间复杂度 O(n)O(n)
  3. 通过 Softmax(Qe@KcacheTd)Softmax(\frac{Q_e @ K_{cache}^T}{\sqrt{d}}) 得到新的注意力分数,再与 V_cache 相乘得出结果,时间复杂度 O(n)O(n)

参考:一文读懂大模型KV Cache


KV Cache
https://guokent.github.io/deeplearning/kvcache/
作者
Kent
发布于
2025年3月8日
许可协议