09. Triton Fused LoRA | Triton 进阶算子:Multi-LoRA 融合推理与 Batch 内指针路由 (Punica 思想)
难度: Hard | 标签: Triton, LoRA, Punica, Serving | 目标人群: 核心 Infra 与算子开发
🚀 云端运行环境
本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:
在 02_PyTorch_Algorithms/09_LoRA_Tutorial.ipynb 中,我们在 PyTorch 层面实现了 LoRA (linear(),会严重降低 GPU 的吞吐量(无法利用 Batch 计算)。 本节我们将手撕 Multi-LoRA (如 S-LoRA / Punica 论文核心思路) 的底层 Triton 融合算子:通过传入 lora_indices,让每一个 Token 在 SRAM 内自动去内存池中拉取它专属的 LoRA 权重,完成批量计算!
相关阅读: 本节使用 Triton 实现了高阶的 Token 级动态路由与融合推理。 如果你对该算子的基础 PyTorch 层面的权重分解与训练不熟悉,建议先复习 PyTorch 篇:
../02_PyTorch_Algorithms/10_LoRA_Tutorial.ipynb
Step 1: Multi-LoRA 内存池与 Batch 路由
LoRA 的内存池 (Weight Pool): 我们不再为每一个单独的 LoRA 权重分配离散的张量。而是将机器上加载的所有 LoRA 矩阵 A (形状
r \t\times in_features) 拼接成一个巨大的连续张量lora_a_pool,形状为(num_loras, r, in_features)。
Token 级别的细粒度路由 (Token-wise Routing): 假设输入特征
的形状为 (batch_size, in_features)。 我们额外传入一个长度为batch_size的一维整型数组lora_indices。它记录了:当前 Batch 中,第个 Token 到底需要使用 lora_a_pool里的第几个 LoRA 模型。
SRAM 内部的极致并行:
- Triton
pid处理矩阵的某一行(某个 Token)。 - 我们直接在内核里根据
lora_idx = tl.load(lora_indices + pid),算出指向内存池中特定 LoRA A 和 B 的偏移量!- 一次性读取
和专属的 ,在极速的 SRAM 中完成 ,最后写回 。 这样,原本必须串行计算的不同模型请求,被完美地合并在了一个 Triton Kernel 调用中(Batch Inference)!
Step 2: 内存池与 Batch 指针路由
在推理服务器中,往往存在一个大底座模型对应几百个微调的 LoRA 权重。为了避免切换开销,我们将所有的 LoRA 权重放进统一的巨大的显存池中。每个发来的 Token 请求都会带有一个 lora_id。内核利用这个 lora_id 充当偏移指针,直接在同一次前向计算中去抓取不同的权重完成点积。
Step 3: 指针路由代码框架
传入包含所有权重的张量 lora_pool 和整数数组 lora_indices。在内核中,先读取 lora_idx = tl.load(lora_indices_ptr + pid_batch),将该索引乘上权重的 stride,动态确定当前线程块该加载哪一份 LoRA A 和 B,随后做标准的低秩乘加运算。
Step 4: 动手实战
要求:请补全下方 fused_multi_lora_kernel。我们需要根据传入的 lora_indices,正确地在三维张量 lora_a_pool 和 lora_b_pool 中计算偏移量,并执行点积。为了简化,我们假设秩 r 很小(如 8 或 16),且输入只进行列并行分块(行完全塞入 SRAM)。
import torch
import triton
import triton.language as tl@triton.jit
def fused_multi_lora_kernel(
x_ptr, out_ptr,
lora_a_pool_ptr, lora_b_pool_ptr,
lora_indices_ptr,
M, IN_DIM, OUT_DIM, R: tl.constexpr,
stride_x_m, stride_x_in,
stride_out_m, stride_out_dim,
stride_a_pool, stride_a_r, stride_a_in,
stride_b_pool, stride_b_out, stride_b_r,
BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
# ==========================================
# TODO 1: 读取当前 Token 的 LoRA 索引
# ==========================================
# lora_idx = ???
# ==========================================
# TODO 2: 计算内存池中该 LoRA 的基地址偏移
# ==========================================
# a_pool_base_ptr = ???
# b_pool_base_ptr = ???
# ==========================================
# TODO 3: 计算 x @ A,得到中间激活 h_r
# ==========================================
# acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
# h_r = tl.zeros((R,), dtype=tl.float32)
# num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
# for k in range(num_k_blocks):
# ...
# ==========================================
# TODO 4: 计算 h_r @ B,得到最终输出
# ==========================================
# offs_r = tl.arange(0, R)
# b_ptrs = ???
# b_val = ???
# acc += ???
# ==========================================
# TODO 5: 将结果写回显存
# ==========================================
# out_ptrs = ???
# tl.store(...)
raise NotImplementedError("请完成 TODO 1-5")
def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
M, IN_DIM = x.shape
num_loras, R, _ = lora_a_pool.shape
_, OUT_DIM, _ = lora_b_pool.shape
out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
BLOCK_IN = 64
BLOCK_OUT = 64
grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
fused_multi_lora_kernel[grid](
x, out,
lora_a_pool, lora_b_pool,
lora_indices,
M, IN_DIM, OUT_DIM, R,
x.stride(0), x.stride(1),
out.stride(0), out.stride(1),
lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
)
return out# 测试并验证 Multi-LoRA 路由的正确性
def test_multi_lora():
if not torch.cuda.is_available():
print("⏭️ 忽略测试:无 GPU。")
return
try:
torch.manual_seed(42)
batch_size = 4
in_dim = 128
out_dim = 256
num_loras = 3 # 内存池中有 3 个不同的 LoRA
r = 16
x = torch.randn(batch_size, in_dim, device='cuda')
# 构造内存池
# A_pool: (3, 16, 128)
# B_pool: (3, 256, 16)
lora_a_pool = torch.randn(num_loras, r, in_dim, device='cuda')
lora_b_pool = torch.randn(num_loras, out_dim, r, device='cuda')
# 构造复杂的请求路由 (Token 0用LoRA_2, Token 1用LoRA_0...)
lora_indices = torch.tensor([2, 0, 1, 2], device='cuda', dtype=torch.int32)
# 1. 纯 PyTorch 参考计算 (必须使用极其缓慢的 for 循环拼接)
out_ref = torch.zeros(batch_size, out_dim, device='cuda')
for i in range(batch_size):
idx = lora_indices[i].item()
# 提取专属权重
A = lora_a_pool[idx] # (r, in_dim)
B = lora_b_pool[idx] # (out_dim, r)
# x[i]: (1, in_dim)
# x_i @ A.T @ B.T
h_r = x[i].unsqueeze(0) @ A.T # (1, r)
y_i = h_r @ B.T # (1, out_dim)
out_ref[i] = y_i.squeeze(0)
# 2. Triton 单算子极致并行计算
out_tri = triton_multi_lora_forward(x, lora_a_pool, lora_b_pool, lora_indices)
# 3. 验证结果
diff = torch.max(torch.abs(out_ref - out_tri))
print(f"最大误差: {diff.item():.6e}")
assert diff < 1e-3, "Triton Multi-LoRA 路由或计算结果不正确!"
print("✅ Multi-LoRA 路由融合推理验证通过。")
print("该算子实现了 Token 级别的动态路由,支持 Batch 内多模型并发推理。")
except NotImplementedError:
print("请先完成 TODO 代码!")
except Exception as e:
print(f"❌ 测试失败: {e}")
test_multi_lora()🛑 STOP HERE 🛑
请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。
参考代码与解析
代码
import torch
import triton
import triton.language as tl
@triton.jit
def fused_multi_lora_kernel(
x_ptr, out_ptr,
lora_a_pool_ptr, lora_b_pool_ptr,
lora_indices_ptr,
M, IN_DIM, OUT_DIM, R: tl.constexpr,
stride_x_m, stride_x_in,
stride_out_m, stride_out_dim,
stride_a_pool, stride_a_r, stride_a_in,
stride_b_pool, stride_b_out, stride_b_r,
BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
# TODO 1: 读取当前 Token 的 LoRA 索引
lora_idx = tl.load(lora_indices_ptr + pid_m)
# TODO 2: 计算内存池中该 LoRA 的基地址偏移
a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool
b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool
acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
h_r = tl.zeros((R,), dtype=tl.float32)
# TODO 3: 计算 x @ A,得到中间激活 h_r
num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
for k in range(num_k_blocks):
offs_in = k * BLOCK_IN + tl.arange(0, BLOCK_IN)
x_ptrs = x_ptr + pid_m * stride_x_m + offs_in * stride_x_in
x_val = tl.load(x_ptrs, mask=offs_in < IN_DIM, other=0.0)
offs_r = tl.arange(0, R)
a_ptrs = a_pool_base_ptr + offs_r[:, None] * stride_a_r + offs_in[None, :] * stride_a_in
a_val = tl.load(a_ptrs, mask=offs_in[None, :] < IN_DIM, other=0.0)
h_r += tl.sum(x_val[None, :] * a_val, axis=1)
# TODO 4: 计算 h_r @ B,得到最终输出
offs_r = tl.arange(0, R)
b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r
b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0)
acc += tl.sum(h_r[None, :] * b_val, axis=1)
# TODO 5: 将结果写回显存
out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim
tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM)
def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
M, IN_DIM = x.shape
num_loras, R, _ = lora_a_pool.shape
_, OUT_DIM, _ = lora_b_pool.shape
out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
BLOCK_IN = 64
BLOCK_OUT = 64
grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
fused_multi_lora_kernel[grid](
x, out,
lora_a_pool, lora_b_pool,
lora_indices,
M, IN_DIM, OUT_DIM, R,
x.stride(0), x.stride(1),
out.stride(0), out.stride(1),
lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
)
return out解析
1. TODO 1: 读取当前 Token 的 LoRA 索引
- 实现方式:
lora_idx = tl.load(lora_indices_ptr + pid_m) - 关键点:每个 Token(由
pid_m标识)都有自己专属的 LoRA 索引 - 技术细节:
lora_indices是一个长度为batch_size的整型数组,记录了每个 Token 应该使用内存池中的哪个 LoRA 模型
2. TODO 2: 计算内存池中该 LoRA 的基地址偏移
- 实现方式:python
a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool - 关键点:这是 Multi-LoRA 路由的核心逻辑,通过指针偏移实现动态权重选择
- 技术细节:
stride_a_pool和stride_b_pool是内存池第一维的步长,乘以lora_idx后可以直接跳转到对应 LoRA 的起始位置
3. TODO 3: 计算 x @ A,得到中间激活 h_r
- 实现方式:循环遍历
IN_DIM,分块加载x和A,累加点积结果 - 关键点:
- 使用
tl.cdiv(IN_DIM, BLOCK_IN)计算需要的块数 - 对每个块,加载
x的一段和A的对应列 - 使用
tl.sum(x_val[None, :] * a_val, axis=1)计算点积并累加到h_r
- 使用
- 技术细节:
A的形状是(R, IN_DIM),我们用offs_r[:, None]和offs_in[None, :]构造二维索引mask=offs_in[None, :] < IN_DIM确保不会越界访问
4. TODO 4: 计算 h_r @ B,得到最终输出
- 实现方式:python
b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0) acc += tl.sum(h_r[None, :] * b_val, axis=1) - 关键点:
B的形状是(OUT_DIM, R),我们提取对应的输出维度分块 - 技术细节:使用
tl.sum对R维度求和,得到长度为BLOCK_OUT的输出向量
5. TODO 5: 将结果写回显存
- 实现方式:python
out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM) - 关键点:使用
mask保护边界,避免越界写入 - 技术细节:
acc.to(out_ptr.dtype.element_ty)确保数据类型与输出张量一致
工程优化要点
- 内存池设计:将所有 LoRA 权重放入统一的连续内存池,避免频繁的内存分配和释放
- Token 级路由:每个 Token 独立选择 LoRA 权重,实现细粒度的动态路由
- 指针偏移优化:使用
stride计算偏移量,避免复杂的索引计算 - 分块计算:对
IN_DIM和OUT_DIM进行分块,充分利用 SRAM 的高速缓存 - 低秩分解:利用 LoRA 的低秩特性(
R << IN_DIM, OUT_DIM),先计算x @ A再计算h_r @ B,减少计算量 - Batch 并行:不同 Token 的计算完全独立,可以在 GPU 上高度并行
- 内存访问模式:使用
mask保护边界访问,确保内存安全 - 工业应用:该算子是 S-LoRA、Punica 等多租户推理框架的核心组件,可以在单次 kernel 调用中处理多个用户的不同 LoRA 请求
