Skip to content

04. Attention MHA GQA | 注意力机制与键值缓存 (MHA / GQA / MQA)

难度: Medium | 标签: 基础架构, PyTorch, 推理优化 | 目标人群: 模型微调与工程部署

🚀 云端运行环境

本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:

Open In ColabOpen In Studio (国内推荐:魔搭社区免费实例)

欢迎来到 LLM-LeetCode!本节我们将深入解析大语言模型的核心组件:注意力机制,并实现支持 KV Cache 和 GQA (Grouped-Query Attention) 的代码。

Step 1: 核心思想与痛点

在大语言模型中,注意力机制 (Attention) 决定了模型如何“回顾”并提取历史上下文的信息。随着模型层数加深和序列变长,Attention 模块在推理阶段面临极大的性能挑战。

什么是 KV Cache?为什么它是性能瓶颈? 在自回归生成中,每次生成第 N 个 Token 时,我们需要计算它与前面 N1 个 Token 的相关性。为了避免重复计算前 N1 个 Token 的特征,我们将其投影后的 Key 和 Value 张量缓存(Cache)在显存中,当前步直接拼接读取。

然而,读取巨量的 KV Cache 会面临严重的显存容量瓶颈内存带宽瓶颈 (Memory-bound),导致推理极慢。

从 MHA 到 GQA:大模型架构的进化

  • MHA (Multi-Head Attention): 标准的多头注意力。每个 Query 头都有自己专属的 Key 和 Value 头。其巨大的 KV Cache 让推理寸步难行。
  • MQA (Multi-Query Attention): 所有的 Query 头共享同一个 Key 和 Value 头。极大地减少了 KV Cache 的占用,但由于表达能力锐减,模型效果往往打折。
  • GQA (Grouped-Query Attention): LLaMA-2/3 采用的折中方案。将 Query 头分组,每组共享一个 Key 和 Value 头。这在模型效果和显存占用之间取得了良好的工程平衡

Step 2: 核心公式与张量维度

注意力计算公式:

Attention(Q,K,V)=Softmax(QKTdk)V

张量维度追踪 (Shape Tracking) - 算法工程师的灵魂: 假设 Batch=B, Seq_len=S, Num_Heads=H, Head_Dim=D

  1. 线性投影后:Q 形状为 [B, S, H * D]
  2. 切分多头后:转置为 [B, H, S, D]
  3. 注意力分数计算:Q @ K^T -> [B, H, S, D] @ [B, H, D, S] -> [B, H, S, S]
  4. 乘以 Value:Scores @ V -> [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D]
  5. 最后合并多头:转置回 [B, S, H, D]view[B, S, H * D]

Step 3: 工业界源码映射

在真实的工业界代码中,这段逻辑在哪里?

  • HuggingFace LLaMA: transformers/models/llama/modeling_llama.py 中的 LlamaAttention 类。
  • vLLM (推理框架): 核心关注它的 PagedAttention 实现,用来解决这里 KV Cache 的显存碎片化问题。

Step 4: 动手实战

要求:请补全下方 GroupedQueryAttentionforward 函数中的 TODO 部分,实现:

  1. 张量的多头切分与 Reshape
  2. KV Cache 的拼接逻辑
  3. 注意力分数的计算
python
import torch
import torch.nn as nn
import math
python
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    将 KV 头复制 n_rep 次,以匹配 Query 头的数量 (GQA/MQA 需要)
    """
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_dim = hidden_dim // num_heads
        
        # 定义投影矩阵
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)

    def forward(
        self, 
        x: torch.Tensor, 
        attention_mask: torch.Tensor = None, 
        kv_cache: tuple[torch.Tensor, torch.Tensor] = None
    ):
        batch_size, seq_len, _ = x.shape
        
        # 1. 线性投影
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        
        # ==========================================
        # TODO 1: Reshape xq, xk, xv 以适配多头注意力计算
        # ==========================================
        # xq = ???
        # xk = ???
        # xv = ???
        
        
        # ==========================================
        # TODO 2: 处理 KV Cache
        # ==========================================
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            # xk = ???
            # xv = ???
            
        new_kv_cache = (xk, xv)
        
        # 通过 repeat_kv 把 GQA 的 KV 头数扩充到和 Query 数量一致
        xk = repeat_kv(xk, self.num_queries_per_kv)
        xv = repeat_kv(xv, self.num_queries_per_kv)
        
        # ==========================================
        # TODO 3: 计算注意力分数 (Scaled Dot-Product)
        # ==========================================
        # scores = ???
        
        if attention_mask is not None:
            scores = scores + attention_mask
            
        # probs = ???
        # output = ???
        
        
        # ==========================================
        # TODO 4: 恢复形状并输出
        # [B, H, S, D] -> [B, S, H*D]
        # ==========================================
        # output = ???
        
        # return self.o_proj(output), new_kv_cache
        pass
python
# 运行此单元格以测试你的实现
def test_mha_mqa_gqa():
    try:
        batch_size, seq_len, hidden_dim, num_heads = 2, 16, 128, 4
        
        # 1. 测试 MHA
        print("Testing MHA (Multi-Head Attention)...")
        mha = GroupedQueryAttention(hidden_dim, num_heads, num_kv_heads=num_heads)
        x = torch.randn(batch_size, seq_len, hidden_dim)
        out, _ = mha(x)
        assert out.shape == (batch_size, seq_len, hidden_dim), "MHA 输出形状错误!"
        
        # 2. 测试 GQA
        print("Testing GQA (Grouped-Query Attention)...")
        gqa = GroupedQueryAttention(hidden_dim, num_heads, num_kv_heads=2)
        out, _ = gqa(x)
        assert out.shape == (batch_size, seq_len, hidden_dim), "GQA 输出形状错误!"
        
        # 3. 测试 KV Cache
        print("Testing KV Cache Autoregressive Decoding...")
        prefill_len = 5
        x_prefill = torch.randn(batch_size, prefill_len, hidden_dim)
        _, kv_cache = mha(x_prefill)
        
        x_decode = torch.randn(batch_size, 1, hidden_dim)
        out_decode, new_kv_cache = mha(x_decode, kv_cache=kv_cache)
        assert new_kv_cache[0].shape == (batch_size, num_heads, prefill_len + 1, hidden_dim // num_heads), "KV Cache 更新错误!"
        
        print("\n✅ All Tests Passed! Attention 算子实现通过测试。")
    except NotImplementedError:
        print("请先完成 TODO 部分的代码!")
    except Exception as e:
        print(f"\n❌ 测试失败,请检查张量维度: {e}")

test_mha_mqa_gqa()

🛑 STOP HERE 🛑









请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。










参考代码与解析

代码

python
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_dim = hidden_dim // num_heads
        
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)

    def forward(
        self, 
        x: torch.Tensor, 
        attention_mask: torch.Tensor = None, 
        kv_cache: tuple[torch.Tensor, torch.Tensor] = None
    ):
        batch_size, seq_len, _ = x.shape
        
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        
        # TODO 1: Reshape 为多头形式
        xq = xq.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        
        # TODO 2: 处理 KV Cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            xk = torch.cat([k_cache, xk], dim=2)
            xv = torch.cat([v_cache, xv], dim=2)
            
        new_kv_cache = (xk, xv)
        
        xk = repeat_kv(xk, self.num_queries_per_kv)
        xv = repeat_kv(xv, self.num_queries_per_kv)
        
        # TODO 3: 计算注意力分数
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores + attention_mask
            
        probs = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(probs, xv)
        
        # TODO 4: 恢复形状
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        
        return self.o_proj(output), new_kv_cache

解析

1. TODO 1 (多头切分与维度转置)

  • 切分多头: 使用 view(batch_size, seq_len, num_heads, head_dim) 将线性投影后的张量从 [B, S, H*D] 重塑为 [B, S, H, D],其中 H 是头数,D 是每个头的维度。
  • 维度转置: 通过 .transpose(1, 2) 将形状从 [B, S, H, D] 转为 [B, H, S, D],这是注意力计算的标准格式,方便后续的矩阵乘法。
  • GQA 的 KV 头数: 注意 xkxv 使用 num_kv_heads 而不是 num_heads,这是 GQA 的核心区别。例如 LLaMA-2 70B 使用 64 个 Query 头但只有 8 个 KV 头。
  • 工程细节: 为什么要 transpose?因为注意力分数计算 Q @ K^T 需要在 [S, D][D, S] 维度上进行矩阵乘法,将 heads 维度放在第二个位置可以让 batch 和 heads 维度自动广播。

2. TODO 2 (KV Cache 拼接)

  • 自回归生成场景: 在推理时,每次只生成一个新 token,但需要用到之前所有 token 的 Key 和 Value。如果每次都重新计算,时间复杂度是 O(N2)
  • Cache 机制: 将历史的 k_cachev_cache 与当前步的 xkxvseq_len 维度(dim=2)拼接,形状从 [B, H, old_len, D] 变为 [B, H, old_len+1, D]
  • 显存优化: GQA 的 KV Cache 只需存储 num_kv_heads 个头,而不是 num_heads 个。例如 LLaMA-2 70B 的 KV Cache 显存占用是 MHA 的 1/8。
  • 工程陷阱: 必须在 repeat_kv 之前进行拼接,否则会重复缓存已扩展的 KV,导致显存浪费。

3. TODO 3 (Scaled Dot-Product Attention)

  • 注意力分数计算: scores = Q @ K^T / sqrt(d_k),其中 xk.transpose(2, 3)[B, H, S, D] 转为 [B, H, D, S],与 xq[B, H, S, D] 相乘得到 [B, H, S, S] 的注意力矩阵。
  • 缩放因子: 除以 sqrt(head_dim) 是为了防止点积结果过大导致 softmax 梯度消失。这是 Transformer 原论文的核心设计。
  • Mask 机制: attention_mask 通常是一个下三角矩阵(Causal Mask),用 -inf 填充上三角部分,确保当前 token 只能看到之前的 token。
  • Softmax 归一化: 在最后一个维度(dim=-1)上进行 softmax,将注意力分数转为概率分布。
  • 加权求和: output = probs @ V 将注意力权重与 Value 相乘,得到加权后的特征表示。

4. TODO 4 (多头合并与输出投影)

  • 维度转置: .transpose(1, 2)[B, H, S, D] 转回 [B, S, H, D]
  • 内存连续性: .contiguous() 确保张量在内存中是连续存储的,这是 view 操作的前提。如果不调用 contiguous()view 可能会报错。
  • 合并多头: .view(batch_size, seq_len, -1)[B, S, H, D] 展平为 [B, S, H*D],其中 -1 自动推断为 num_heads * head_dim
  • 输出投影: 通过 o_proj 线性层将多头特征映射回 hidden_dim,这是标准 Transformer 的最后一步。

进阶思考:GQA 的延迟扩充 (Lazy Expansion)

  • 为什么不直接缓存扩充后的 KV? 如果在缓存时就用 repeat_kv 扩充,显存占用会和 MHA 一样大,失去了 GQA 的优势。
  • 正确做法: 只缓存原始的 num_kv_heads 个头,在每次前向传播时临时扩充。虽然增加了计算量,但由于注意力计算是 Memory-bound(受限于显存带宽而非计算速度),这个开销可以忽略。
  • 工业实践: vLLM、TensorRT-LLM 等推理框架都采用这种延迟扩充策略,在 70B 模型上可以节省数十 GB 的 KV Cache 显存。

Released under the MIT License.