Skip to content

02. SwiGLU Activation | 激活函数与门控机制 (SwiGLU Activation)

难度: Easy | 标签: 模型架构, 激活函数 | 目标人群: 模型微调与工程部署

🚀 云端运行环境

本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:

Open In ColabOpen In Studio (国内推荐:魔搭社区免费实例)

在组装 LLaMA-3 的那一节中,我们使用了 SwiGLU 作为 MLP 的激活函数。为什么所有主流大模型(LLaMA, Qwen, Mistral, PaLM)都在抛弃 ReLU/GELU 而转向 SwiGLU? 本节我们将深入推导 SwiGLU 的设计原理,特别是如何调整隐藏层的维度,以保证参数量与标准 Transformer 严格对齐。这是面试中非常经典的架构推导题

相关阅读: 本节使用纯 PyTorch 实现了算法逻辑与数学推导。 如果你想学习工业界如何打破该算子的 Memory Bound (访存瓶颈),请前往 Triton 篇: ../03_CUDA_and_Triton_Kernels/02_Triton_Fused_SwiGLU.ipynb

Step 1: 核心思想与痛点

什么是 GLU (Gated Linear Unit)? 传统 MLP 是 Wdown(σ(Wupx))。 门控机制 (GLU) 引入了“两条路”:一条路做激活(作为门控开关),另一条路保持线性,然后两者逐元素相乘(Hadamard Product)。 公式:GLU(x,W1,W2)=(xW1σ(xW2))Wdown。 这种机制类似于 LSTM 中的遗忘门,极大地增强了模型捕捉复杂模式的能力。

什么是 SwiGLU? 就是把 GLU 中的激活函数 σ 换成了 Swish(即 xSigmoid(βx),在 PyTorch 中 β=1 时等于 SiLU)。

Step 2: 核心数学机制:参数量对齐

典型的面试问题:

“在 GPT-2 中,隐藏层维度通常是输入维度 d 的 4 倍(即 4d)。但在使用 SwiGLU 的 LLaMA 中,为什么隐藏层维度变成了 83d 并向上取整?”

推导过程:

  1. 标准 MLP 参数量: 输入为 d,隐藏层为 h。 有两个投影矩阵(升维 dh,降维 hd)。 总参数量 = 2(d×h)。 当 h=4d 时,总参数量 = 24d2=8d2

  2. SwiGLU MLP 参数量: 输入为 d,隐藏层为 hswiglu。 因为有门控机制,升维阶段需要两个投影矩阵(WgateWup,均是 dhswiglu)。 降维阶段需要一个矩阵(Wdown,是 hswiglud)。 总参数量 = 3(d×hswiglu)

  3. 对齐参数量: 为了使得 SwiGLU 的计算开销(参数量)与原始模型完全相同: 3dhswiglu=8d2 解得:hswiglu=83d

这正是 LLaMA 源码中对中间层维度进行 int(8 * hidden_size / 3) 计算的根本原因。

Step 3: 工业级实现框架与性能陷阱 (Memory Bound)

在理解了 SwiGLU 的基本公式(down_proj(SiLU(gate_proj(x)) * up_proj(x)))和 8/3 维度由来后,如何把它写进真实的训练框架中?

性能陷阱 1:张量并行 (TP) 与内存对齐 在真实的 LLaMA 源码中,除了按 8/3 计算出隐藏层维度,还需要将其向上取整对齐到一个 multiple_of(通常是 256)的倍数。 这不仅是为了让单卡 Tensor Core(通常要求 8-byte 或 32-byte 对齐)跑得更快,更是因为大模型训练会使用张量并行 (Tensor Parallelism)。如果隐藏层维度不能被 GPU 数量整除(例如 TP=8 时,256 的倍数分给 8 张卡,每张卡至少能分到 32 维),在切分权重矩阵时就会发生严重的报错。

性能陷阱 2:访存瓶颈 (Memory Bound) 与矩阵融合 在最朴素的代码实现中,开发者会分别定义并执行 gate_proj(x)up_proj(x)。 由于这两个线性层共享完全相同的输入张量 x,分开计算会导致巨大的输入 x 被 GPU 从全局显存 (HBM) 中读取两次。

工业界解法 (Matrix Fusion): 在 vLLM、Megatron 等主流框架中,标准的做法是将 WgateWup 这两个形状为 [hidden_size, intermediate_size] 的权重矩阵,在初始化时拼接成一个巨大的融合矩阵 gate_up_proj,其形状为 [hidden_size, 2 * intermediate_size]

在前向传播时,输入 x 只需要被读取一次进行一次大规模矩阵乘法。得到的结果再通过 torch.chunk(2, dim=-1) 切分为两半,分别作为 gate 和 up 块。这极大地缓解了内存带宽瓶颈。

Step 4: 动手实战

要求:请补全下方 calculate_intermediate_sizeSwiGLU 模块的代码。

python
import torch
import torch.nn as nn
import torch.nn.functional as F
python
def calculate_intermediate_size(hidden_size: int, multiple_of: int = 256):
    """
    计算 LLaMA 风格的 SwiGLU 隐藏层维度
    
    规则:
    1. 取 hidden_size 的 8/3
    2. 为了硬件对齐(如 Tensor Core),通常要求是 multiple_of 的倍数。
       因此将结果除以 multiple_of,向上取整后再乘以 multiple_of。
    """
    # ==========================================
    # TODO 1: 计算理论隐藏层大小 (8/3 * hidden_size)
    # 提示: 注意使用整数除法
    # ==========================================
    # intermediate_size = ???
    
    # ==========================================
    # TODO 2: 向 multiple_of 对齐 (向上取整)
    # 提示: 思考如何利用整除的特性实现向上取整
    # ==========================================
    # aligned_size = ???
    
    # return aligned_size
    pass

class SwiGLU_MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # ==========================================
        # TODO 3: 定义工业级 SwiGLU 的投影矩阵
        # ==========================================
        # self.gate_up_proj = ???
        # self.down_proj = ???
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ==========================================
        # TODO 4: 组装工业级 SwiGLU 前向传播
        # ==========================================
        # return ???
        pass
python
# 运行此单元格以测试你的实现
def test_swiglu():
    try:
        # 1. 测试维度推导函数
        hidden_size = 4096 # LLaMA-7B 的 hidden_size
        
        # 理论值 = 4096 * 8 / 3 = 10922.66 -> 10922
        # 对齐到 256 倍数: 10922 / 256 = 42.66 -> 取 43
        # 43 * 256 = 11008
        
        aligned_size = calculate_intermediate_size(hidden_size, multiple_of=256)
        assert aligned_size == 11008, f"维度计算错误,期望 11008,实际得到 {aligned_size}"
        print(f"✅ 隐藏层维度推导正确!4096 -> {aligned_size}")
        
        # 2. 测试参数量对齐
        # 标准 MLP: 2 * (4096 * 16384) = 134,217,728
        # LLaMA SwiGLU: 3 * (4096 * 11008) = 135,266,304 (因为向上取整,略大一点点)
        mlp = SwiGLU_MLP(hidden_size, aligned_size)
        
        # 检查是否使用了融合矩阵
        assert hasattr(mlp, 'gate_up_proj'), "请实现融合的 gate_up_proj 矩阵!"
        
        total_params = sum(p.numel() for p in mlp.parameters())
        assert total_params == 135266304, f"参数量异常!{total_params}"
        print(f"✅ SwiGLU 实例参数量验证正确:{total_params} 个参数 (包含融合矩阵)")
        
        # 3. 测试前向传播连通性
        x = torch.randn(2, 10, hidden_size)
        out = mlp(x)
        assert out.shape == x.shape, "输出形状不等于输入形状!"
        print("\n✅ All Tests Passed! 你已经掌握了当前大模型最主流激活函数的底层数学逻辑与访存优化!")
        
    except NotImplementedError:
        print("请先完成 TODO 部分的代码!")
    except AssertionError as e:
        print(f"❌ 测试失败: {e}")
    except TypeError as e:
        print("代码可能未完成,导致了操作错误。")
    except Exception as e:
        print(f"❌ 发生异常: {e}")

test_swiglu()

🛑 STOP HERE 🛑









请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。










参考代码与解析

代码

python
def calculate_intermediate_size(hidden_size: int, multiple_of: int = 256):
    # TODO 1 & 2
    intermediate_size = int(hidden_size * 8 / 3)
    aligned_size = ((intermediate_size + multiple_of - 1) // multiple_of) * multiple_of
    return aligned_size

class SwiGLU_MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # TODO 3
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO 4
        gate_up = self.gate_up_proj(x)
        gate, up = torch.chunk(gate_up, 2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

解析

1. TODO 1 & 2 (隐藏层维度计算)

  • 计算理论值: 推导公式 3dhswiglu=8d2 得出理论隐藏层维度应为 83d
  • 硬件对齐: 为了确保在张量并行时权重矩阵能够被所有 GPU 整除(不报错),并且满足 Tensor Core 对内存的对齐要求,必须将其向上取整对齐到 multiple_of 的倍数。

2. TODO 3 (定义矩阵)

  • 矩阵融合: 工业级实现的核心。将 gate_projup_proj 融合成一个完整的 gate_up_proj 线性层(输出维度为 2 * intermediate_size)。
  • 进阶思考: 为什么要合并矩阵?如果分开计算,巨大的输入张量 x 会被 GPU 从全局显存中读取两次。大模型计算受限于显存带宽,合并计算让输入 x 仅被读取一次,显著降低访存开销和算子发射延迟。

3. TODO 4 (前向传播)

  • 计算与切分: 得到 gate_up 结果后,使用 torch.chunk(2, dim=-1) 沿最后一个维度均分为门控块 (gate) 和激活块 (up)。
  • 公式实现: 执行 down_proj(SiLU(gate) * up) 计算,完成 SwiGLU 激活机制。

Released under the MIT License.