14. Attention Backward Math | 注意力机制反向传播推导与自定义 Autograd (Attention Backward)
难度: Hard | 标签: 微积分, PyTorch, Autograd | 目标人群: 底层算子开发、高阶算法面试
🚀 云端运行环境
本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:
在第 04 节,我们完成了多头注意力机制(MHA)的前向传播搭建。然而,“大模型训练”与“推理”核心挑战之一在于反向传播 (Backward Pass)。
为什么各大厂商都在持续优化 FlashAttention?因为在训练时,为了计算反向传播的梯度,框架必须在显存中保存尺寸为
在下一节正式进入 FlashAttention 之前,我们必须跨过这座高山:完全搞懂 Attention 反向传播的微积分推导,并抛弃 PyTorch 原生的 .backward(),利用 torch.autograd.Function 写出它的反向梯度计算代码!
这是底层架构岗位的常见考核点。
Step 1: 前向传播回顾与变量定义
为了不打断思路,我们先简洁回顾一下 04 节的单头 Attention 前向公式(省略缩放因子
- 打分矩阵:
- 概率矩阵:
- 最终输出:
张量形状说明:
(序列长度 ,特征维数 )
Step 2: 链式法则逆流而上 (微积分时间)
假设下游的损失函数已经帮我们算好了输出张量
1. 求
2. 求
3. 跨越 Softmax (核心难点) 我们需要从
(注:
4. 求
Step 3: 手撕 PyTorch Autograd Function
现在,把你刚才看到的微积分公式,转化为能够实际运行的代码。我们将继承 torch.autograd.Function。
要求:完成 backward 函数中 TODO 的数学推导代码。你可以使用 ctx.saved_tensors 来获取前向传播时保存的
import torch
import torch.nn.functional as F
import mathclass CustomAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v):
# 1. 缩放点积
d_k = q.size(-1)
scale = 1.0 / math.sqrt(d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# 2. Softmax 获取概率 P
p = F.softmax(scores, dim=-1)
# 3. 乘上 V 得到输出
out = torch.matmul(p, v)
# 保存反向传播需要用到的张量
ctx.save_for_backward(q, k, v, p)
ctx.scale = scale
return out
@staticmethod
def backward(ctx, dout):
# 提取前向保存的张量
q, k, v, p = ctx.saved_tensors
scale = ctx.scale
# ==========================================
# TODO 1: 求 dV
# ==========================================
# dv = ???
dv = torch.zeros_like(v) # 占位初始化
# ==========================================
# TODO 2: 求 dP
# ==========================================
# dp = ???
dp = torch.zeros_like(p) # 占位初始化
# ==========================================
# TODO 3: 穿过 Softmax 求 dS
# ==========================================
# dp_mul_p = ???
# row_sum = ???
# ds = ???
dp_mul_p = torch.zeros_like(p) # 占位初始化
row_sum = torch.zeros(p.shape[:-1] + (1,), device=p.device, dtype=p.dtype) # 占位初始化
ds = torch.zeros_like(p) # 占位初始化
# ==========================================
# TODO 4: 求 dQ 和 dK (别忘了乘以 scale 缩放因子)
# ==========================================
# dq = ???
# dk = ???
dq = torch.zeros_like(q) # 占位初始化
dk = torch.zeros_like(k) # 占位初始化
return dq, dk, dv# 运行此单元格以测试你的实现
def test_attention_backward():
try:
torch.manual_seed(42)
B, N, d = 2, 8, 16
# 随机初始化张量,必须要求梯度
q = torch.randn(B, N, d, dtype=torch.float64, requires_grad=True)
k = torch.randn(B, N, d, dtype=torch.float64, requires_grad=True)
v = torch.randn(B, N, d, dtype=torch.float64, requires_grad=True)
print("1. 测试前向传播是否正常...")
custom_out = CustomAttention.apply(q, k, v)
# 原生 PyTorch 实现
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
ref_out = torch.matmul(F.softmax(scores, dim=-1), v)
assert torch.allclose(custom_out, ref_out), "前向传播结果不一致!"
print("\n2. 进行梯度数值检验 (Gradcheck)...")
test_passed = torch.autograd.gradcheck(CustomAttention.apply, (q, k, v), eps=1e-6, atol=1e-4)
if test_passed:
print("✅ All Tests Passed! Attention 反向传播实现通过测试。")
except NotImplementedError:
print("请先完成 TODO 部分的代码!")
except AssertionError as e:
print(f"❌ 测试失败: {e}")
raise e
except Exception as e:
print(f"❌ 发生异常 (很可能是梯度公式写错了): {e}")
raise e
test_attention_backward()Step 4: 工业界的现实与破局(预告)
看看你刚才写的 ctx.save_for_backward(q, k, v, p)。这行代码在反向传播被调用前,会一直把
如果现在的上下文是
思考题:如果你是底层算法工程师,怎么解决这个问题? 答案预告:不存
!我们在反向传播需要 的时候,拿 和 现场重算一次 (Recomputation)! 通过巧妙的 SRAM 分块加载机制,虽然计算量变大了,但因为避免了把庞大的 写入又读出非常缓慢的 HBM,最终不但不 OOM,速度反而变快了 3 倍!
这就是下一节业界广泛使用的 FlashAttention 所做的事。
🛑 STOP HERE 🛑
请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。
参考代码与解析
代码
class CustomAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v):
d_k = q.size(-1)
scale = 1.0 / math.sqrt(d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
p = F.softmax(scores, dim=-1)
out = torch.matmul(p, v)
ctx.save_for_backward(q, k, v, p)
ctx.scale = scale
return out
@staticmethod
def backward(ctx, dout):
q, k, v, p = ctx.saved_tensors
scale = ctx.scale
# 1. dV = P^T * dO
dv = torch.matmul(p.transpose(-2, -1), dout)
# 2. dP = dO * V^T
dp = torch.matmul(dout, v.transpose(-2, -1))
# 3. dS = P * (dP - row_sum(P * dP))
dp_mul_p = dp * p
row_sum = dp_mul_p.sum(dim=-1, keepdim=True)
ds = p * (dp - row_sum)
# 4. dQ 和 dK
dq = torch.matmul(ds, k) * scale
dk = torch.matmul(ds.transpose(-2, -1), q) * scale
return dq, dk, dv解析
Attention 梯度的核心在于处理 Softmax 的雅可比矩阵。对于
