Skip to content

21. Quantization Theory and INT4 INT8 Practice | 1A-03. Quantization Theory and INT4/INT8 - 计算练习

难度: Medium | 标签: 量化, 显存计算, 推理加速 | 目标人群: 入门 / 推理优化学习者

🚀 云端运行环境

本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:

Open In ColabOpen In Studio (国内推荐:魔搭社区免费实例)

本练习配套理论文档:21_Quantization_Theory_and_INT4_INT8.md

建议先阅读理论文档,再来完成下面三件事。Notebook 负责把显存、误差和速度差异真正算出来。

学习目标

  • 理解 FP16 / INT8 / INT4 的显存差异
  • 能够手动实现基础的 per-tensor 量化与反量化
  • 学会比较量化前后的误差与推理收益
  • 能够判断什么时候需要 PTQ、QAT 或更保守的方案

练习目标

  • 做一个简单的 per-tensor 量化实现,观察 FP16 / INT8 / INT4 的显存差异
  • 用校准数据跑一次 GPTQ / AWQ 风格的量化流程,看看误差如何变化
  • 对比不同量化格式在同一模型上的推理速度和效果损失

题目区

这一部分先完成练习,再看答案。

  • 先把每个函数自己写出来
  • 再运行测试确认结果
  • 如果卡住,先回看理论页的直觉部分

TODO 概览

  • TODO 1:计算 FP16 / INT8 / INT4 显存
  • TODO 2:实现 per-tensor 量化与反量化
  • TODO 3:比较不同位宽的误差与显存节省

Part 1: 显存与位宽

练习 1.1: 实现模型权重显存计算函数

实现一个函数,计算给定参数量和数据格式的模型权重显存占用。

python
def calculate_weight_memory(num_params_b, dtype):
    """
    计算模型权重的显存占用。

    Args:
        num_params_b: 参数量(单位:B,即十亿)
        dtype: 数据类型,可选 'fp16', 'bf16', 'int8', 'int4'

    Returns:
        memory_gb: 显存占用(单位:GB)
    """
    bytes_per_param = {'fp16': 2, 'bf16': 2, 'int8': 1, 'int4': 0.5}[dtype]
    return num_params_b * bytes_per_param


def test_calculate_weight_memory():
    result = calculate_weight_memory(7, 'fp16')
    assert abs(result - 14.0) < 1e-9
    result = calculate_weight_memory(7, 'int8')
    assert abs(result - 7.0) < 1e-9
    result = calculate_weight_memory(7, 'int4')
    assert abs(result - 3.5) < 1e-9
    print('✅ calculate_weight_memory tests passed')

# 运行测试
test_calculate_weight_memory()

练习 1.2: 对比不同数据格式

使用上面的函数,对比 7B 模型在不同数据格式下的权重显存占用。

python
# TODO: 计算 7B 模型在不同格式下的显存占用
num_params = 7
dtypes = ['fp16', 'bf16', 'int8', 'int4']

print('7B 模型权重显存占用对比:')
print('-' * 40)
for dtype in dtypes:
    memory = calculate_weight_memory(num_params, dtype)
    print(f'{dtype.upper():<6} {memory:>6.1f} GB')

Part 2: 基础量化实现

练习 2.1: 实现 per-tensor 对称量化与反量化

实现一个最简单的量化函数,输入浮点张量,输出量化整数和 scale。

python
import torch


def quantize_per_tensor(x, num_bits=8):
    """
    对张量做对称 per-tensor 量化。

    Returns:
        q: 量化后的整数张量
        scale: 量化比例
    """
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max() / qmax if x.numel() > 0 else torch.tensor(1.0, device=x.device, dtype=x.dtype)
    scale = torch.clamp(scale, min=1e-8)
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale


def dequantize_per_tensor(q, scale):
    return q.to(torch.float32) * scale


def test_quantize_per_tensor():
    x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
    q, scale = quantize_per_tensor(x, 8)
    x_hat = dequantize_per_tensor(q, scale)
    assert q.dtype == torch.int8
    assert x_hat.shape == x.shape
    print('q:', q.tolist())
    print('scale:', float(scale))
    print('x_hat:', x_hat.tolist())
    print('✅ quantize_per_tensor tests passed')

# 运行测试
test_quantize_per_tensor()

练习 2.2: 观察量化误差

比较原始张量和反量化张量的误差。

python
# TODO: 构造一个随机张量,观察 8-bit 和 4-bit 的误差差异
torch.manual_seed(0)
x = torch.randn(1024) * 2

for bits in [8, 4]:
    q, scale = quantize_per_tensor(x, bits)
    x_hat = dequantize_per_tensor(q, scale)
    mse = torch.mean((x - x_hat) ** 2).item()
    max_err = torch.max(torch.abs(x - x_hat)).item()
    print(f'{bits}-bit -> MSE={mse:.6f}, max_err={max_err:.6f}')
python
def test_quantization_practice():
    x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
    q8, s8 = quantize_per_tensor(x, 8)
    x8 = dequantize_per_tensor(q8, s8)
    q4, s4 = quantize_per_tensor(x, 4)
    x4 = dequantize_per_tensor(q4, s4)
    assert q8.dtype == torch.int8 and q4.dtype == torch.int8
    assert x8.shape == x.shape and x4.shape == x.shape
    assert torch.mean((x - x4).abs()) >= torch.mean((x - x8).abs())
    print('✅ 21 Quantization tests passed')

test_quantization_practice()

Part 3: 量化收益与风险

练习 3.1: 比较显存节省与误差

把显存节省和误差放在一起看,判断量化是否值得。

python
# TODO: 对比 7B 模型在 FP16 / INT8 / INT4 下的显存节省
base = calculate_weight_memory(7, 'fp16')
for dtype in ['int8', 'int4']:
    mem = calculate_weight_memory(7, dtype)
    saving = 1 - mem / base
    print(f'{dtype.upper():<4} memory={mem:.1f} GB, saving={saving:.0%}')

练习 3.2: 何时需要更保守的方案

思考:当误差过大、激活离群值明显或任务对精度敏感时,应该如何调整量化策略?

python
# TODO: 写出 2-3 个需要保守量化或 QAT 的场景
scenarios = [
    '小模型生成任务对精度非常敏感',
    '激活值离群值明显,INT4 误差过大',
    '需要保持与基线几乎一致的输出质量'
]
for s in scenarios:
    print('-', s)

Part 4: 自检与扩展

练习 4.1: 小结检查

如果你能回答下面三个问题,就说明这一页基本过关:

  1. FP16 / INT8 / INT4 的显存差异是多少?
  2. 为什么激活值通常比权重更难量化?
  3. 什么时候应该考虑 QAT 而不是继续硬压位宽?
python
def check_understanding():
    questions = [
        '1. FP16 / INT8 / INT4 的显存差异是多少?',
        '2. 为什么激活值通常比权重更难量化?',
        '3. 什么时候应该考虑 QAT 而不是继续硬压位宽?',
    ]
    for q in questions:
        print(q)

check_understanding()

🛑 STOP HERE 🛑

先确保你能独立完成上面的练习,再看参考答案。

参考答案与解析

下面的内容在正式讲解时再补全。当前版本先保留练习结构,方便后续统一填充答案。

Released under the MIT License.