Skip to content

Chapter 2: PyTorch 算法实战 - 完整导学

🎯 本章概览

本章包含 26 道题,覆盖从基础算子到分布式训练的完整算法链路。通过本章学习,你将掌握大模型从算子实现到推理优化的核心技术栈。

学习组划分

为了方便学习,我们将题目按主题分为 5 个学习组:

学习组题目范围主题难度
2A: 基础算子00-04Transformer 组件Easy-Medium
2B: 模型架构05-08模型组装Medium
2C: 训练技术09-11SFT/LoRA/调度器Medium
2D: 对齐技术12-14RLHF/DPOMedium-Hard
2E: 推理优化15-25推理加速Hard

📚 推荐学习路径

路径 1:快速入门

适合: 准备面试、快速了解大模型算法

学习顺序:

  1. 2A: 基础算子(必学,00-04题)→ 理解 Transformer 组件
  2. 2C: 训练技术(09-10 题)→ 掌握 SFT 与 LoRA
  3. 2E: 推理优化(15-16 题)→ 学习 FlashAttention 与 Decoding

路径 2:系统学习

适合: 深入理解、全面掌握大模型底层原理

学习顺序:

  1. 2A: 基础算子 → 理解 Transformer 组件
  2. 2B: 模型架构 → 组装完整模型
  3. 2C: 训练技术 → 掌握微调方法
  4. 2D: 对齐技术 → 理解 RLHF 与 DPO
  5. 2E: 推理优化 → 学习核心推理加速技术

路径 3:专项突破

根据需求选择:

专注训练:

  • 2A(00-04)→ 2C(09-11)→ 2E中的分布式部分(21-25)

专注推理:

  • 2A(00-04)→ 2E(15-20)

专注对齐:

  • 2A(00-04)→ 2C(09-10)→ 2D(12-14)

🔧 题目区占位初始化机制

什么是占位初始化?

在每道题的题目区(STOP HERE 之前),你会看到类似这样的代码:

python
def compute_loss(x, y):
    # ==========================================
    # TODO: 计算损失函数
    # ==========================================
    # loss = ???
    loss = torch.tensor(0.0)  # 占位初始化(错误实现,供测试框架捕获)
    return loss

为什么需要占位初始化?

占位初始化的设计目的是实现"正确失败"(Correct Failure)机制:

  1. 避免 Python 错误:如果没有占位符,代码会因为 NameErrorUnboundLocalError 等错误而无法运行到测试阶段。

  2. 在断言阶段失败:占位符提供了正确的数据类型和形状,但数值是错误的,这样测试会在数值验证阶段失败,而不是在 Python 语法阶段崩溃。

  3. 教育价值

    • 题目区:占位符返回错误的数值 → 测试失败并显示清晰的错误信息(如"期望 0.5,实际 0.0")
    • 答案区:正确实现 → 测试通过
    • 无占位符:代码直接崩溃 → 学习者看不到有意义的错误提示

占位初始化的典型模式

场景占位符示例说明
标量返回loss = torch.tensor(0.0)返回形状正确但数值错误的标量
张量返回output = torch.zeros_like(input)返回形状正确但全零的张量
多返回值return loss, acc确保返回值数量正确,避免解包错误
字典/列表states = {0: {}, 1: {}}提供正确的数据结构但内容为空

示例对比

❌ 没有占位符(会崩溃):

python
def forward(x):
    # TODO: 实现前向传播
    # output = ???
    pass  # 返回 None,导致后续代码崩溃

✅ 有占位符(正确失败):

python
def forward(x):
    # TODO: 实现前向传播
    # output = ???
    output = torch.zeros_like(x)  # 占位初始化
    return output  # 测试会在数值验证阶段失败,显示清晰错误

如何使用占位符学习?

  1. 第一步:阅读题目描述和 TODO 提示
  2. 第二步:尝试自己实现(忽略占位符)
  3. 第三步:运行测试,查看错误信息
  4. 第四步:如果卡住超过 30 分钟,查看答案区的参考实现
  5. 第五步:理解后删除占位符,重新实现

重要提示: 占位符只是为了让测试框架能够运行,你的目标是替换掉占位符,实现正确的逻辑。


📗 2A: 基础算子与 Transformer 组件

🎯 学习目标

完成本组学习后,你将能够:

  • ✅ 理解 Transformer 的基础构建块
  • ✅ 掌握 RMSNorm、SwiGLU、RoPE、Attention 的实现
  • ✅ 能手写这些算子的前向传播
  • ✅ 理解这些算子的设计动机及与传统算子的对比

📚 题目列表 (00-04)

题号题目难度核心知识点
00PyTorch WarmupEasyTensor 操作、自动求导
01RMSNorm TutorialEasy归一化、广播机制
02SwiGLU ActivationEasy激活函数、门控机制
03RoPE TutorialMedium位置编码、旋转矩阵
04Attention MHA/GQAMedium注意力机制、KV 共享

🗺️ 推荐学习顺序

顺序 1:线性学习(推荐初学者)

00 → 01 → 02 → 03 → 04
  • 适合:初学者、打算系统学习的同学
  • 优势:循序渐进,知识连贯

顺序 2:核心优先(适合有基础者)

01 → 04 → 03 → 02 → 00
  • 适合:有一定深度学习基础、时间紧张的同学
  • 优势:优先掌握当前 LLM 最核心的差异化算子(RMSNorm 和 Attention)

📖 详细题目指南

00: PyTorch Warmup

学习重点:

  • PyTorch Tensor 的维度变换(permutereshapeview
  • 自定义 autograd 函数的基本范式

常见错误:

  • ❌ 维度顺序错误(permute 的参数填错)
  • ❌ 忘记保存中间结果(ctx.save_for_backward
  • ❌ 梯度形状与输入不匹配

进阶方向:

  • 阅读 PyTorch Autograd 源码或探究 einops 库的实现原理

01: RMSNorm Tutorial

学习重点:

  • RMSNorm 的数学原理:为什么省去减均值步骤不仅算得快,而且不影响(甚至在某些场景下提升)效果?
  • 掌握 PyTorch 中的 meansqrt 以及广播机制(Broadcasting)

常见错误:

  • ❌ 忘记 keepdim=True 导致除法时维度广播不匹配
  • ❌ 加上 ϵ 的时机不对(放在平方根外围导致数值不稳定)

进阶方向:

  • 对比 LayerNorm、RMSNorm、GroupNorm 的异同
  • 挑战用 Triton 手写 Fused RMSNorm(参考第三章)

02: SwiGLU Activation

学习重点:

  • 了解 GLU(Gated Linear Unit)系列门控机制
  • 对比 SwiGLU 相比于 ReLU/GELU 的优势及计算代价(参数量增加)

常见错误:

  • ❌ 维度切分时切到了错误的维度
  • ❌ 在计算 FFN 时忘记维度扩展比例(如 d4ddd83dd

03: RoPE Tutorial

学习重点:

  • 掌握 RoPE(Rotary Position Embedding)的核心数学思路:用绝对位置的旋转矩阵来实现相对位置的内积
  • 理解旋转矩阵的构建(复数形式与矩阵形式)

常见错误:

  • ❌ 频率参数 θ 指数项计算错误(base_freq 设置等)
  • ❌ 复数乘法和 view_as_real / view_as_complex 的转换错误

进阶方向:

  • 理解 RoPE 如何处理超长序列上下文扩展(如 PI、YaRN 等插值方案)

04: Attention MHA/GQA

学习重点:

  • 理解 MHA(Multi-Head Attention)如何拆解特征维度
  • 掌握 GQA(Grouped-Query Attention)的 KV Cache 共享机制
  • 分析 MHA、MQA、GQA 在显存与质量之间的权衡

常见错误:

  • ❌ 多头切分时 reshapetranspose 搞反,导致数据布局错乱
  • ❌ GQA 重复 KV 头(repeat_interleaveexpand)时的次数计算错误
  • ❌ Attention Causal Mask 的广播(Broadcasting)错误,导致信息泄露

📝 2A 组总结与自我检测

自我检测清单

完成本组学习后,你应该能够:

  • [ ] 默写/手写 RMSNorm 的前向传播公式
  • [ ] 清楚说明 SwiGLU 为什么需要两个线性层投影
  • [ ] 解释 RoPE 在内积时是如何保持相对距离信息的
  • [ ] 画出 GQA (Grouped-Query Attention) 的内存流向图,并计算它的 KV Cache 占用

下一步

完成 2A 基础算子后,推荐学习:

  • 2B: 模型架构 → 学习如何将这 4 种算子像搭积木一样组装成完整的 LLaMA Transformer Block
  • Chapter 3-01~07 → 如果你对性能极致压榨感兴趣,可以去第三章学习这些算子的 Triton GPU 融合实现

💡 学习建议

做题技巧

  1. 先理解再动手:先阅读题目描述和数学公式
  2. 参考官方实现:对比 HuggingFace、vLLM 的源码实现
  3. 测试驱动:先跑通测试用例,再进行代码优化
  4. 查看答案:卡住超过 30 分钟可以看答案,理解后自己重新实现
  5. 忽略占位符:占位符只是为了测试框架,你的目标是实现正确的逻辑

常见问题

Q: 没有 GPU 能学吗?

  • A: 00-13 题可以纯在 CPU 上运行,14-25 题涉及量化与分布式等特性强烈建议使用 GPU

Q: 遇到 Bug 怎么办?

  • A: 请在仓库提 Issue,或者查看 GitHub Discussions 里的已有解答

Q: 为什么题目区的代码运行后测试失败?

  • A: 这是正常的!题目区包含占位初始化,目的是让你看到清晰的错误信息。你需要替换占位符,实现正确的逻辑

Q: 可以直接删除占位符吗?

  • A: 可以!占位符只是为了测试框架能够运行。你可以删除占位符,从头实现自己的逻辑

🗺️ 其他学习组导航

📗 2B: 模型架构

  • 题目范围: 05-08
  • 学习目标: 从算子到完整模型
  • 核心内容: LLaMA3 Block、MoE Router 与 Load Balancing

📗 2C: 训练技术

  • 题目范围: 09-11
  • 学习目标: 掌握大模型训练的核心技术
  • 核心内容: SFT、LoRA、学习率调度

📗 2D: 对齐技术

  • 题目范围: 12-14
  • 学习目标: 理解 RLHF 与 DPO
  • 核心内容: PPO、DPO、Attention 反向传播

📗 2E: 推理优化

  • 题目范围: 15-25
  • 学习目标: 掌握大模型推理加速技术
  • 核心内容: FlashAttention、PagedAttention、量化、分布式通信与并行计算

🎓 结语

本章是整个仓库的核心,涵盖了从基础算子到分布式训练的完整技术栈。建议按照推荐路径循序渐进,遇到困难时善用答案区的参考实现。

记住:占位初始化是你的朋友,不是敌人。它帮助你看到清晰的错误信息,理解自己的实现与正确答案之间的差距。

祝学习愉快!🚀

Released under the MIT License.