Chapter 2: PyTorch 算法实战 - 完整导学
🎯 本章概览
本章包含 26 道题,覆盖从基础算子到分布式训练的完整算法链路。通过本章学习,你将掌握大模型从算子实现到推理优化的核心技术栈。
学习组划分
为了方便学习,我们将题目按主题分为 5 个学习组:
| 学习组 | 题目范围 | 主题 | 难度 |
|---|---|---|---|
| 2A: 基础算子 | 00-04 | Transformer 组件 | Easy-Medium |
| 2B: 模型架构 | 05-08 | 模型组装 | Medium |
| 2C: 训练技术 | 09-11 | SFT/LoRA/调度器 | Medium |
| 2D: 对齐技术 | 12-14 | RLHF/DPO | Medium-Hard |
| 2E: 推理优化 | 15-25 | 推理加速 | Hard |
📚 推荐学习路径
路径 1:快速入门
适合: 准备面试、快速了解大模型算法
学习顺序:
- 2A: 基础算子(必学,00-04题)→ 理解 Transformer 组件
- 2C: 训练技术(09-10 题)→ 掌握 SFT 与 LoRA
- 2E: 推理优化(15-16 题)→ 学习 FlashAttention 与 Decoding
路径 2:系统学习
适合: 深入理解、全面掌握大模型底层原理
学习顺序:
- 2A: 基础算子 → 理解 Transformer 组件
- 2B: 模型架构 → 组装完整模型
- 2C: 训练技术 → 掌握微调方法
- 2D: 对齐技术 → 理解 RLHF 与 DPO
- 2E: 推理优化 → 学习核心推理加速技术
路径 3:专项突破
根据需求选择:
专注训练:
- 2A(00-04)→ 2C(09-11)→ 2E中的分布式部分(21-25)
专注推理:
- 2A(00-04)→ 2E(15-20)
专注对齐:
- 2A(00-04)→ 2C(09-10)→ 2D(12-14)
🔧 题目区占位初始化机制
什么是占位初始化?
在每道题的题目区(STOP HERE 之前),你会看到类似这样的代码:
def compute_loss(x, y):
# ==========================================
# TODO: 计算损失函数
# ==========================================
# loss = ???
loss = torch.tensor(0.0) # 占位初始化(错误实现,供测试框架捕获)
return loss为什么需要占位初始化?
占位初始化的设计目的是实现"正确失败"(Correct Failure)机制:
避免 Python 错误:如果没有占位符,代码会因为
NameError、UnboundLocalError等错误而无法运行到测试阶段。在断言阶段失败:占位符提供了正确的数据类型和形状,但数值是错误的,这样测试会在数值验证阶段失败,而不是在 Python 语法阶段崩溃。
教育价值:
- ✅ 题目区:占位符返回错误的数值 → 测试失败并显示清晰的错误信息(如"期望 0.5,实际 0.0")
- ✅ 答案区:正确实现 → 测试通过
- ❌ 无占位符:代码直接崩溃 → 学习者看不到有意义的错误提示
占位初始化的典型模式
| 场景 | 占位符示例 | 说明 |
|---|---|---|
| 标量返回 | loss = torch.tensor(0.0) | 返回形状正确但数值错误的标量 |
| 张量返回 | output = torch.zeros_like(input) | 返回形状正确但全零的张量 |
| 多返回值 | return loss, acc | 确保返回值数量正确,避免解包错误 |
| 字典/列表 | states = {0: {}, 1: {}} | 提供正确的数据结构但内容为空 |
示例对比
❌ 没有占位符(会崩溃):
def forward(x):
# TODO: 实现前向传播
# output = ???
pass # 返回 None,导致后续代码崩溃✅ 有占位符(正确失败):
def forward(x):
# TODO: 实现前向传播
# output = ???
output = torch.zeros_like(x) # 占位初始化
return output # 测试会在数值验证阶段失败,显示清晰错误如何使用占位符学习?
- 第一步:阅读题目描述和 TODO 提示
- 第二步:尝试自己实现(忽略占位符)
- 第三步:运行测试,查看错误信息
- 第四步:如果卡住超过 30 分钟,查看答案区的参考实现
- 第五步:理解后删除占位符,重新实现
重要提示: 占位符只是为了让测试框架能够运行,你的目标是替换掉占位符,实现正确的逻辑。
📗 2A: 基础算子与 Transformer 组件
🎯 学习目标
完成本组学习后,你将能够:
- ✅ 理解 Transformer 的基础构建块
- ✅ 掌握 RMSNorm、SwiGLU、RoPE、Attention 的实现
- ✅ 能手写这些算子的前向传播
- ✅ 理解这些算子的设计动机及与传统算子的对比
📚 题目列表 (00-04)
| 题号 | 题目 | 难度 | 核心知识点 |
|---|---|---|---|
| 00 | PyTorch Warmup | Easy | Tensor 操作、自动求导 |
| 01 | RMSNorm Tutorial | Easy | 归一化、广播机制 |
| 02 | SwiGLU Activation | Easy | 激活函数、门控机制 |
| 03 | RoPE Tutorial | Medium | 位置编码、旋转矩阵 |
| 04 | Attention MHA/GQA | Medium | 注意力机制、KV 共享 |
🗺️ 推荐学习顺序
顺序 1:线性学习(推荐初学者)
00 → 01 → 02 → 03 → 04- 适合:初学者、打算系统学习的同学
- 优势:循序渐进,知识连贯
顺序 2:核心优先(适合有基础者)
01 → 04 → 03 → 02 → 00- 适合:有一定深度学习基础、时间紧张的同学
- 优势:优先掌握当前 LLM 最核心的差异化算子(RMSNorm 和 Attention)
📖 详细题目指南
00: PyTorch Warmup
学习重点:
- PyTorch Tensor 的维度变换(
permute、reshape、view) - 自定义 autograd 函数的基本范式
常见错误:
- ❌ 维度顺序错误(
permute的参数填错) - ❌ 忘记保存中间结果(
ctx.save_for_backward) - ❌ 梯度形状与输入不匹配
进阶方向:
- 阅读 PyTorch Autograd 源码或探究
einops库的实现原理
01: RMSNorm Tutorial
学习重点:
- RMSNorm 的数学原理:为什么省去减均值步骤不仅算得快,而且不影响(甚至在某些场景下提升)效果?
- 掌握 PyTorch 中的
mean、sqrt以及广播机制(Broadcasting)
常见错误:
- ❌ 忘记
keepdim=True导致除法时维度广播不匹配 - ❌ 加上
的时机不对(放在平方根外围导致数值不稳定)
进阶方向:
- 对比 LayerNorm、RMSNorm、GroupNorm 的异同
- 挑战用 Triton 手写 Fused RMSNorm(参考第三章)
02: SwiGLU Activation
学习重点:
- 了解 GLU(Gated Linear Unit)系列门控机制
- 对比 SwiGLU 相比于 ReLU/GELU 的优势及计算代价(参数量增加)
常见错误:
- ❌ 维度切分时切到了错误的维度
- ❌ 在计算 FFN 时忘记维度扩展比例(如
或 )
03: RoPE Tutorial
学习重点:
- 掌握 RoPE(Rotary Position Embedding)的核心数学思路:用绝对位置的旋转矩阵来实现相对位置的内积
- 理解旋转矩阵的构建(复数形式与矩阵形式)
常见错误:
- ❌ 频率参数
指数项计算错误(base_freq 设置等) - ❌ 复数乘法和
view_as_real/view_as_complex的转换错误
进阶方向:
- 理解 RoPE 如何处理超长序列上下文扩展(如 PI、YaRN 等插值方案)
04: Attention MHA/GQA
学习重点:
- 理解 MHA(Multi-Head Attention)如何拆解特征维度
- 掌握 GQA(Grouped-Query Attention)的 KV Cache 共享机制
- 分析 MHA、MQA、GQA 在显存与质量之间的权衡
常见错误:
- ❌ 多头切分时
reshape与transpose搞反,导致数据布局错乱 - ❌ GQA 重复 KV 头(
repeat_interleave或expand)时的次数计算错误 - ❌ Attention Causal Mask 的广播(Broadcasting)错误,导致信息泄露
📝 2A 组总结与自我检测
自我检测清单
完成本组学习后,你应该能够:
- [ ] 默写/手写 RMSNorm 的前向传播公式
- [ ] 清楚说明 SwiGLU 为什么需要两个线性层投影
- [ ] 解释 RoPE 在内积时是如何保持相对距离信息的
- [ ] 画出 GQA (Grouped-Query Attention) 的内存流向图,并计算它的 KV Cache 占用
下一步
完成 2A 基础算子后,推荐学习:
- 2B: 模型架构 → 学习如何将这 4 种算子像搭积木一样组装成完整的 LLaMA Transformer Block
- Chapter 3-01~07 → 如果你对性能极致压榨感兴趣,可以去第三章学习这些算子的 Triton GPU 融合实现
💡 学习建议
做题技巧
- 先理解再动手:先阅读题目描述和数学公式
- 参考官方实现:对比 HuggingFace、vLLM 的源码实现
- 测试驱动:先跑通测试用例,再进行代码优化
- 查看答案:卡住超过 30 分钟可以看答案,理解后自己重新实现
- 忽略占位符:占位符只是为了测试框架,你的目标是实现正确的逻辑
常见问题
Q: 没有 GPU 能学吗?
- A: 00-13 题可以纯在 CPU 上运行,14-25 题涉及量化与分布式等特性强烈建议使用 GPU
Q: 遇到 Bug 怎么办?
- A: 请在仓库提 Issue,或者查看 GitHub Discussions 里的已有解答
Q: 为什么题目区的代码运行后测试失败?
- A: 这是正常的!题目区包含占位初始化,目的是让你看到清晰的错误信息。你需要替换占位符,实现正确的逻辑
Q: 可以直接删除占位符吗?
- A: 可以!占位符只是为了测试框架能够运行。你可以删除占位符,从头实现自己的逻辑
🗺️ 其他学习组导航
📗 2B: 模型架构
- 题目范围: 05-08
- 学习目标: 从算子到完整模型
- 核心内容: LLaMA3 Block、MoE Router 与 Load Balancing
📗 2C: 训练技术
- 题目范围: 09-11
- 学习目标: 掌握大模型训练的核心技术
- 核心内容: SFT、LoRA、学习率调度
📗 2D: 对齐技术
- 题目范围: 12-14
- 学习目标: 理解 RLHF 与 DPO
- 核心内容: PPO、DPO、Attention 反向传播
📗 2E: 推理优化
- 题目范围: 15-25
- 学习目标: 掌握大模型推理加速技术
- 核心内容: FlashAttention、PagedAttention、量化、分布式通信与并行计算
🎓 结语
本章是整个仓库的核心,涵盖了从基础算子到分布式训练的完整技术栈。建议按照推荐路径循序渐进,遇到困难时善用答案区的参考实现。
记住:占位初始化是你的朋友,不是敌人。它帮助你看到清晰的错误信息,理解自己的实现与正确答案之间的差距。
祝学习愉快!🚀
