20. Quantization W8A16 | 模型量化基础: INT8 绝对最大值量化与反量化 (Quantization)
难度: Medium | 标签: 推理优化, 量化, PTQ | 目标人群: 模型微调与工程部署
🚀 云端运行环境
本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:
大模型动辄百亿参数,显存占用极大。为了让 7B 模型能跑在消费级显卡(如 RTX 4090 / 3090)甚至手机端,模型量化 (Quantization) 是不可或缺的技术。 本节我们将实现最基础、也是面试中最常考的 INT8 对称量化 (Symmetric Quantization),理解量化缩放因子 (Scale) 的计算,以及如何在进行矩阵乘法前进行反量化 (Dequantization)。
相关阅读: 本节使用纯 PyTorch 实现了算法逻辑与数学推导。 如果你想学习工业界如何打破该算子的 Memory Bound (访存瓶颈),请前往 Triton 篇:
../03_CUDA_and_Triton_Kernels/11_Triton_Quantization_Support.ipynb
Step 1: 核心思想与概念
什么是量化? 将高精度(如 FP32/FP16,占用 4/2 个字节)的浮点数,映射到低精度(如 INT8,占用 1 个字节)的整数上。这样不仅显存占用直接缩小 2-4 倍,还能利用硬件的整数计算单元(如 INT8 Tensor Core)加速计算。
PTQ 与 QAT 的区别:
- PTQ (Post-Training Quantization,训练后量化):模型已经训练好了。我们只需要拿一小批校准数据(Calibration Data)跑一遍,统计一下激活值的分布,算出缩放因子(Scale),直接对权重转换。本节我们实现的就是 PTQ。
- QAT (Quantization-Aware Training,量化感知训练):在训练时,正向传播模拟量化的误差,反向传播用“直通估计器 (STE)”更新原始的高精度权重。成本极高,但精度损失最小。
Step 2: 代码实现框架
我们需要实现 quantize 和 dequantize 两个函数。在量化时,先计算 Scale,然后将张量除以 Scale,接着进行 torch.round,最后通过 torch.clamp 限制在 torch.int8 数据类型。反量化则是简单的乘以 Scale。
Step 3: 数学公式:绝对最大值量化
这是对称量化最常用的方法。假设我们有一个浮点张量
计算绝对最大值 (Absmax): 找到张量中绝对值最大的元素:
。 计算缩放因子 (Scale): S = \frac{127}{m}$。这个
就代表了“1个单位的 INT8 等于多少个单位的 FP16”。 量化 (Quantize): 将张量乘以缩放因子,然后四舍五入 (Round) 变成整数,并截断 (Clamp) 到 INT8 范围内,防止异常值越界:
反量化 (Dequantize): 在真正做矩阵乘法前(如果是 W8A16 这种 Weight-only 量化),需要把 INT8 恢复成 FP16 参与计算:
Step 4: 动手实战
要求:
- 补全
absmax_quantize函数,实现权重的 INT8 转换并返回scale。 - 补全
W8A16Linear的forward方法。W8A16 意味着权重 (We\r\right) 是 INT8,但激活值 (Activation/Input) 保持 FP16。计算时需要实时反量化。
import torch
import torch.nn as nn
import torch.nn.functional as Fdef absmax_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
将浮点张量 X 量化为 INT8,并返回缩放因子。
Args:
x: 浮点类型的张量
Returns:
x_quant: dtype 为 torch.int8 的量化张量
scale: float 类型的缩放因子
"""
# ==========================================
# TODO 1: 计算张量的绝对最大值 absmax
# ==========================================
# absmax = ???
absmax = torch.tensor(1.0) # 占位初始化
# 避免除以 0 的情况
# if absmax == 0:
# absmax = 1e-8
# ==========================================
# TODO 2: 计算缩放因子 scale (映射到 [-127, 127])
# ==========================================
# scale = ???
scale = torch.tensor(1.0) # 占位初始化
# ==========================================
# TODO 3: 量化过程
# 1. 乘以 scale
# ==========================================
# x_scaled = ???
# x_quant = ???
x_scaled = torch.zeros_like(x) # 占位初始化
x_quant = torch.zeros_like(x, dtype=torch.int8) # 占位初始化
return x_quant, scale
class W8A16Linear(nn.Module):
"""
Weight-only INT8 量化线性层。
在内存中,我们存储的是非常微小的 INT8 权重。
在计算时,我们将权重反量化回 FP16,与同样是 FP16 的输入进行矩阵乘法。
这种方式虽然没有加速计算,但极大地缓解了从内存读取权重的 Memory-bound (带宽高了 2 倍)。
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
# 内存中存储的是 int8 格式
self.register_buffer("weight_int8", torch.zeros((out_features, in_features), dtype=torch.int8))
self.register_buffer("scale", torch.tensor(1.0))
self.bias = nn.Parameter(torch.zeros(out_features))
def from_float(self, linear_layer: nn.Linear):
"""
从高精度的 Linear 层中吸收权重并进行 PTQ 量化
"""
w_quant, scale = absmax_quantize(linear_layer.weight.data)
self.weight_int8.copy_(w_quant)
self.scale.copy_(scale)
if linear_layer.bias is not None:
self.bias.data.copy_(linear_layer.bias.data)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# ==========================================
# TODO 4: 反量化与前向传播
# 1. 将 weight_int8 转换回与输入 x 相同的类型 (如 float32/float16)
# 2. 除以 self.scale 恢复其数值范围
# 3. 使用 F.linear 进行标准的矩阵乘法
# ==========================================
# w_fp = ???
# w_dequant = ???
w_fp = torch.zeros_like(self.weight_int8, dtype=x.dtype) # 占位初始化(错误实现,供测试框架捕获)
w_dequant = torch.zeros_like(w_fp) # 占位初始化(错误实现,供测试框架捕获)
# out = ???
out = torch.zeros(x.shape[:-1] + (self.weight_int8.shape[0],), dtype=x.dtype, device=x.device) # 占位初始化(错误实现,供测试框架捕获)
return out# 测试你的实现
def test_quantization():
try:
torch.manual_seed(42)
# 1. 测试 absmax_quantize
x_fp = torch.tensor([-0.8, 1.5, -3.0, 2.5, 0.0])
# 绝对最大值是 3.0。Scale = 127 / 3.0 = 42.333
# 2.5 * 42.333 = 105.8 -> 106
x_q, scale = absmax_quantize(x_fp)
assert x_q.dtype == torch.int8, "量化后的张量必须是 int8 类型!"
assert torch.allclose(scale, torch.tensor(127.0 / 3.0)), "Scale 计算不正确!"
assert x_q[3].item() == 106, "量化后的四舍五入数值计算不正确!"
print("✅ absmax_quantize 核心算法测试通过!")
# 2. 测试 W8A16 线性层
in_dim, out_dim = 128, 64
batch, seq = 2, 10
# 构建一个原始的 FP32 Linear 层
fp_linear = nn.Linear(in_dim, out_dim)
# 构建我们的 INT8 量化层并吸入权重
q_linear = W8A16Linear(in_dim, out_dim)
q_linear.from_float(fp_linear)
# 验证显存占用 (理论上应该小 4 倍,因为 FP32 是 4 字节,INT8 是 1 字节)
fp_bytes = fp_linear.weight.element_size() * fp_linear.weight.numel()
q_bytes = q_linear.weight_int8.element_size() * q_linear.weight_int8.numel()
assert q_bytes == fp_bytes // 4, "INT8 权重的内存占用必须是 FP32 的四分之一!"
# 验证前向传播结果的误差 (量化必然带来微小误差)
x_input = torch.randn(batch, seq, in_dim)
out_fp = fp_linear(x_input)
out_q = q_linear(x_input)
# 计算余弦相似度,因为经过量化反量化,数值不可能完全一致,只要相似度极高就算成功
cos_sim = F.cosine_similarity(out_fp.flatten(), out_q.flatten(), dim=0)
assert cos_sim > 0.99, f"反量化计算出的张量与原始张量差异过大,相似度仅为: {cos_sim.item():.4f}"
print(f"✅ W8A16Linear 测试通过!输出相似度极高 (Cosine Sim: {cos_sim.item():.4f}),且权重内存缩小 4 倍。")
except NotImplementedError:
print("请先完成 TODO 代码!")
except AttributeError:
print("代码未完成导致变量属性错误。")
raise e
except AssertionError as e:
print(f"❌ 测试失败: {e}")
raise e
except Exception as e:
print(f"❌ 发生未知异常: {e}")
raise e
test_quantization()🛑 STOP HERE 🛑
请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。
参考代码与解析
代码
def absmax_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
将浮点张量 X 量化为 INT8,并返回缩放因子。
"""
# TODO 1: 计算张量的绝对最大值 absmax
absmax = torch.max(torch.abs(x))
# 避免除以 0 的情况
if absmax == 0:
absmax = 1e-8
# TODO 2: 计算缩放因子 scale (映射到 [-127, 127])
scale = 127.0 / absmax
# TODO 3: 量化过程
x_scaled = x * scale
x_quant = torch.clamp(torch.round(x_scaled), -128, 127).to(torch.int8)
return x_quant, scale
class W8A16Linear(nn.Module):
"""
Weight-only INT8 量化线性层。
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.register_buffer("weight_int8", torch.zeros((out_features, in_features), dtype=torch.int8))
self.register_buffer("scale", torch.tensor(1.0))
self.bias = nn.Parameter(torch.zeros(out_features))
def from_float(self, linear_layer: nn.Linear):
"""
从高精度的 Linear 层中吸收权重并进行 PTQ 量化
"""
w_quant, scale = absmax_quantize(linear_layer.weight.data)
self.weight_int8.copy_(w_quant)
self.scale.copy_(scale)
if linear_layer.bias is not None:
self.bias.data.copy_(linear_layer.bias.data)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# TODO 4: 反量化与前向传播
# 1. 将 weight_int8 转换回与输入 x 相同的类型
w_fp = self.weight_int8.to(x.dtype)
# 2. 除以 self.scale 恢复其数值范围
w_dequant = w_fp / self.scale
# 3. 使用 F.linear 进行标准的矩阵乘法
out = F.linear(x, w_dequant, self.bias)
return out解析
1. TODO 1: 计算绝对最大值
- 实现方式:
absmax = torch.max(torch.abs(x)) - 关键点:找到张量中绝对值最大的元素,用于确定量化范围
- 技术细节:需要处理除零情况,当 absmax 为 0 时设为 1e-8
2. TODO 2: 计算缩放因子
- 实现方式:
scale = 127.0 / absmax - 关键点:将浮点数范围映射到 INT8 的 [-127, 127] 区间
- 技术细节:使用 127 而非 128,保持对称性,避免 -128 这个特殊值
3. TODO 3: 量化过程
- 实现方式:
x_scaled = x * scale,x_quant = torch.clamp(torch.round(x_scaled), -128, 127).to(torch.int8) - 关键点:先缩放、再四舍五入、最后截断到有效范围
- 技术细节:使用
torch.clamp防止异常值越界,确保所有值在 [-128, 127] 内
4. TODO 4: 反量化与前向传播
- 实现方式:
w_fp = self.weight_int8.to(x.dtype),w_dequant = w_fp / self.scale,out = F.linear(x, w_dequant, self.bias) - 关键点:在计算前将 INT8 权重恢复为浮点数
- 技术细节:先转换数据类型,再除以 scale 恢复数值范围
工程优化要点
- 显存节省:INT8 权重占用空间是 FP32 的 1/4,FP16 的 1/2
- 带宽优化:W8A16 虽然计算仍用 FP16,但从内存读取权重的带宽需求降低 2 倍
- 精度损失:对称量化通常损失 < 1% 精度,适合大多数推理场景
- Per-tensor vs Per-channel:本实现是 per-tensor 量化,工业界常用 per-channel(每个输出通道独立 scale)以提高精度
- 校准数据:PTQ 需要少量校准数据(通常 128-512 样本)来统计激活值分布
- 工业实践:LLM.int8() 使用混合精度,对异常值(outlier)保持 FP16,其余用 INT8
