04. Attention MHA GQA | 注意力机制与键值缓存 (MHA / GQA / MQA)
难度: Medium | 标签: 基础架构, PyTorch, 推理优化 | 目标人群: 模型微调与工程部署
🚀 云端运行环境
本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:
欢迎来到 LLM-LeetCode!本节我们将深入解析大语言模型的核心组件:注意力机制,并实现支持 KV Cache 和 GQA (Grouped-Query Attention) 的代码。
Step 1: 核心思想与痛点
在大语言模型中,注意力机制 (Attention) 决定了模型如何“回顾”并提取历史上下文的信息。随着模型层数加深和序列变长,Attention 模块在推理阶段面临极大的性能挑战。
什么是 KV Cache?为什么它是性能瓶颈? 在自回归生成中,每次生成第
个 Token 时,我们需要计算它与前面 个 Token 的相关性。为了避免重复计算前 个 Token 的特征,我们将其投影后的 Key 和 Value 张量缓存(Cache)在显存中,当前步直接拼接读取。 然而,读取巨量的 KV Cache 会面临严重的显存容量瓶颈和内存带宽瓶颈 (Memory-bound),导致推理极慢。
从 MHA 到 GQA:大模型架构的进化
- MHA (Multi-Head Attention): 标准的多头注意力。每个 Query 头都有自己专属的 Key 和 Value 头。其巨大的 KV Cache 让推理寸步难行。
- MQA (Multi-Query Attention): 所有的 Query 头共享同一个 Key 和 Value 头。极大地减少了 KV Cache 的占用,但由于表达能力锐减,模型效果往往打折。
- GQA (Grouped-Query Attention): LLaMA-2/3 采用的折中方案。将 Query 头分组,每组共享一个 Key 和 Value 头。这在模型效果和显存占用之间取得了良好的工程平衡。
Step 2: 核心公式与张量维度
注意力计算公式:
张量维度追踪 (Shape Tracking) - 算法工程师的灵魂: 假设 Batch=B, Seq_len=S, Num_Heads=H, Head_Dim=D
- 线性投影后:
Q形状为[B, S, H * D] - 切分多头后:转置为
[B, H, S, D] - 注意力分数计算:
Q @ K^T->[B, H, S, D] @ [B, H, D, S]->[B, H, S, S] - 乘以 Value:
Scores @ V->[B, H, S, S] @ [B, H, S, D]->[B, H, S, D] - 最后合并多头:转置回
[B, S, H, D]并view成[B, S, H * D]。
Step 3: 工业界源码映射
在真实的工业界代码中,这段逻辑在哪里?
- HuggingFace LLaMA:
transformers/models/llama/modeling_llama.py中的LlamaAttention类。 - vLLM (推理框架): 核心关注它的 PagedAttention 实现,用来解决这里 KV Cache 的显存碎片化问题。
Step 4: 动手实战
要求:请补全下方 GroupedQueryAttention 的 forward 函数中的 TODO 部分,实现:
- 张量的多头切分与 Reshape
- KV Cache 的拼接逻辑
- 注意力分数的计算
import torch
import torch.nn as nn
import mathdef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
将 KV 头复制 n_rep 次,以匹配 Query 头的数量 (GQA/MQA 需要)
"""
batch, num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
class GroupedQueryAttention(nn.Module):
def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_dim = hidden_dim // num_heads
# 定义投影矩阵
self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor = None,
kv_cache: tuple[torch.Tensor, torch.Tensor] = None
):
batch_size, seq_len, _ = x.shape
# 1. 线性投影
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# ==========================================
# TODO 1: Reshape xq, xk, xv 以适配多头注意力计算
# ==========================================
# xq = ???
# xk = ???
# xv = ???
# ==========================================
# TODO 2: 处理 KV Cache
# ==========================================
if kv_cache is not None:
k_cache, v_cache = kv_cache
# xk = ???
# xv = ???
new_kv_cache = (xk, xv)
# 通过 repeat_kv 把 GQA 的 KV 头数扩充到和 Query 数量一致
xk = repeat_kv(xk, self.num_queries_per_kv)
xv = repeat_kv(xv, self.num_queries_per_kv)
# ==========================================
# TODO 3: 计算注意力分数 (Scaled Dot-Product)
# ==========================================
# scores = ???
if attention_mask is not None:
scores = scores + attention_mask
# probs = ???
# output = ???
# ==========================================
# TODO 4: 恢复形状并输出
# [B, H, S, D] -> [B, S, H*D]
# ==========================================
# output = ???
# return self.o_proj(output), new_kv_cache
pass# 运行此单元格以测试你的实现
def test_mha_mqa_gqa():
try:
batch_size, seq_len, hidden_dim, num_heads = 2, 16, 128, 4
# 1. 测试 MHA
print("Testing MHA (Multi-Head Attention)...")
mha = GroupedQueryAttention(hidden_dim, num_heads, num_kv_heads=num_heads)
x = torch.randn(batch_size, seq_len, hidden_dim)
out, _ = mha(x)
assert out.shape == (batch_size, seq_len, hidden_dim), "MHA 输出形状错误!"
# 2. 测试 GQA
print("Testing GQA (Grouped-Query Attention)...")
gqa = GroupedQueryAttention(hidden_dim, num_heads, num_kv_heads=2)
out, _ = gqa(x)
assert out.shape == (batch_size, seq_len, hidden_dim), "GQA 输出形状错误!"
# 3. 测试 KV Cache
print("Testing KV Cache Autoregressive Decoding...")
prefill_len = 5
x_prefill = torch.randn(batch_size, prefill_len, hidden_dim)
_, kv_cache = mha(x_prefill)
x_decode = torch.randn(batch_size, 1, hidden_dim)
out_decode, new_kv_cache = mha(x_decode, kv_cache=kv_cache)
assert new_kv_cache[0].shape == (batch_size, num_heads, prefill_len + 1, hidden_dim // num_heads), "KV Cache 更新错误!"
print("\n✅ All Tests Passed! Attention 算子实现通过测试。")
except NotImplementedError:
print("请先完成 TODO 部分的代码!")
except Exception as e:
print(f"\n❌ 测试失败,请检查张量维度: {e}")
test_mha_mqa_gqa()🛑 STOP HERE 🛑
请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。
参考代码与解析
代码
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
class GroupedQueryAttention(nn.Module):
def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor = None,
kv_cache: tuple[torch.Tensor, torch.Tensor] = None
):
batch_size, seq_len, _ = x.shape
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# TODO 1: Reshape 为多头形式
xq = xq.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
xk = xk.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
# TODO 2: 处理 KV Cache
if kv_cache is not None:
k_cache, v_cache = kv_cache
xk = torch.cat([k_cache, xk], dim=2)
xv = torch.cat([v_cache, xv], dim=2)
new_kv_cache = (xk, xv)
xk = repeat_kv(xk, self.num_queries_per_kv)
xv = repeat_kv(xv, self.num_queries_per_kv)
# TODO 3: 计算注意力分数
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
scores = scores + attention_mask
probs = torch.nn.functional.softmax(scores, dim=-1)
output = torch.matmul(probs, xv)
# TODO 4: 恢复形状
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.o_proj(output), new_kv_cache解析
1. TODO 1 (多头切分与维度转置)
- 切分多头: 使用
view(batch_size, seq_len, num_heads, head_dim)将线性投影后的张量从[B, S, H*D]重塑为[B, S, H, D],其中H是头数,D是每个头的维度。 - 维度转置: 通过
.transpose(1, 2)将形状从[B, S, H, D]转为[B, H, S, D],这是注意力计算的标准格式,方便后续的矩阵乘法。 - GQA 的 KV 头数: 注意
xk和xv使用num_kv_heads而不是num_heads,这是 GQA 的核心区别。例如 LLaMA-2 70B 使用 64 个 Query 头但只有 8 个 KV 头。 - 工程细节: 为什么要 transpose?因为注意力分数计算
Q @ K^T需要在[S, D]和[D, S]维度上进行矩阵乘法,将 heads 维度放在第二个位置可以让 batch 和 heads 维度自动广播。
2. TODO 2 (KV Cache 拼接)
- 自回归生成场景: 在推理时,每次只生成一个新 token,但需要用到之前所有 token 的 Key 和 Value。如果每次都重新计算,时间复杂度是
。 - Cache 机制: 将历史的
k_cache和v_cache与当前步的xk、xv在seq_len维度(dim=2)拼接,形状从[B, H, old_len, D]变为[B, H, old_len+1, D]。 - 显存优化: GQA 的 KV Cache 只需存储
num_kv_heads个头,而不是num_heads个。例如 LLaMA-2 70B 的 KV Cache 显存占用是 MHA 的 1/8。 - 工程陷阱: 必须在
repeat_kv之前进行拼接,否则会重复缓存已扩展的 KV,导致显存浪费。
3. TODO 3 (Scaled Dot-Product Attention)
- 注意力分数计算:
scores = Q @ K^T / sqrt(d_k),其中xk.transpose(2, 3)将[B, H, S, D]转为[B, H, D, S],与xq的[B, H, S, D]相乘得到[B, H, S, S]的注意力矩阵。 - 缩放因子: 除以
sqrt(head_dim)是为了防止点积结果过大导致 softmax 梯度消失。这是 Transformer 原论文的核心设计。 - Mask 机制:
attention_mask通常是一个下三角矩阵(Causal Mask),用-inf填充上三角部分,确保当前 token 只能看到之前的 token。 - Softmax 归一化: 在最后一个维度(
dim=-1)上进行 softmax,将注意力分数转为概率分布。 - 加权求和:
output = probs @ V将注意力权重与 Value 相乘,得到加权后的特征表示。
4. TODO 4 (多头合并与输出投影)
- 维度转置:
.transpose(1, 2)将[B, H, S, D]转回[B, S, H, D]。 - 内存连续性:
.contiguous()确保张量在内存中是连续存储的,这是view操作的前提。如果不调用contiguous(),view可能会报错。 - 合并多头:
.view(batch_size, seq_len, -1)将[B, S, H, D]展平为[B, S, H*D],其中-1自动推断为num_heads * head_dim。 - 输出投影: 通过
o_proj线性层将多头特征映射回hidden_dim,这是标准 Transformer 的最后一步。
进阶思考:GQA 的延迟扩充 (Lazy Expansion)
- 为什么不直接缓存扩充后的 KV? 如果在缓存时就用
repeat_kv扩充,显存占用会和 MHA 一样大,失去了 GQA 的优势。 - 正确做法: 只缓存原始的
num_kv_heads个头,在每次前向传播时临时扩充。虽然增加了计算量,但由于注意力计算是 Memory-bound(受限于显存带宽而非计算速度),这个开销可以忽略。 - 工业实践: vLLM、TensorRT-LLM 等推理框架都采用这种延迟扩充策略,在 70B 模型上可以节省数十 GB 的 KV Cache 显存。
