11. Triton Multi LoRA | Triton 多租户路由与融合推理
难度: Hard | 标签: Triton, LoRA, Punica, Serving | 目标人群: 核心 Infra 与算子开发
🚀 云端运行环境
本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:
在 02_PyTorch_Algorithms/09_LoRA_Tutorial.ipynb 中,我们在 PyTorch 层面实现了 LoRA (linear(),通常会降低 GPU 的吞吐量,也不便于利用 Batch 计算。 本节我们将实现 Multi-LoRA (如 S-LoRA / Punica 论文思路) 的底层 Triton 融合算子:通过传入 lora_indices,让每一个 Token 在 SRAM 内按索引读取它对应的 LoRA 权重,完成批量计算。
本节承接 08. FlashAttention、09. PagedAttention 和 10. Quantization 的结果,继续解决多用户并发推理中的 LoRA 权重路由 问题。想象一个 SaaS 推理服务:用户 A 用“代码生成” LoRA,用户 B 用“翻译” LoRA,用户 C 用“摘要” LoRA。传统做法会在切换 LoRA 时反复加载权重,导致 GPU 利用率偏低;Multi-LoRA 则把多个请求合并到一个 Batch 中,通过 Token 级动态路由完成融合推理。如果说 08-09 解决的是 Attention 计算与 KV Cache,10 解决的是权重体积,那么 11 关注的就是 多租户服务如何更高效地复用这些优化。
前置
导语: 这一节会把“一个 batch 里多种 LoRA 路由”直接落到 Triton 的 SRAM 访问上。
相关阅读
导语: 如果想先回顾 LoRA 的 PyTorch 版本,可以先看这页,帮助理解后面的指针路由。
Step 1: Multi-LoRA 内存池与 Batch 路由
LoRA 的内存池 (Weight Pool): 我们通常不会为每一个单独的 LoRA 权重分配离散的张量,而是将机器上加载的所有 LoRA 矩阵 A (形状
r \t\times in_features) 拼接成一个连续张量lora_a_pool,形状为(num_loras, r, in_features)。
Token 级别的细粒度路由 (Token-wise Routing): 假设输入特征
的形状为 (batch_size, in_features)。 我们额外传入一个长度为batch_size的一维整型数组lora_indices。它记录了:当前 Batch 中,第个 Token 到底需要使用 lora_a_pool里的第几个 LoRA 模型。
SRAM 内部的块级并行:
- Triton
pid处理矩阵的某一行(某个 Token)。 - 我们直接在内核里根据
lora_idx = tl.load(lora_indices + pid),算出指向内存池中特定 LoRA A 和 B 的偏移量!- 一次性读取
和专属的 ,在极速的 SRAM 中完成 ,最后写回 。 这样,原本通常需要分别计算的不同模型请求,可以被合并到一个 Triton Kernel 调用中(Batch Inference)。
Step 2: 内存池与 Batch 指针路由
在推理服务器中,往往存在一个大底座模型对应几百个微调的 LoRA 权重。为了避免切换开销,我们将所有的 LoRA 权重放进统一的巨大的显存池中。每个发来的 Token 请求都会带有一个 lora_id。内核利用这个 lora_id 充当偏移指针,直接在同一次前向计算中去抓取不同的权重完成点积。
Step 3: 指针路由代码框架
传入包含所有权重的张量 lora_pool 和整数数组 lora_indices。在内核中,先读取 lora_idx = tl.load(lora_indices_ptr + pid_batch),将该索引乘上权重的 stride,动态确定当前线程块该加载哪一份 LoRA A 和 B,随后做标准的低秩乘加运算。
补充说明:性能对比与边界
这一节的核心卖点是把多个不同 LoRA 请求合并成一次 kernel 调用,所以最好补一个与串行方案的对比口径。 建议后面在验证区展示:
- 串行 LoRA 推理的时延
- Multi-LoRA 融合后的时延
- 最终加速比
边界条件上,当前实现默认 num_loras > 0 且 R > 0;如果要继续增强鲁棒性,可以在宿主侧先做参数校验,而不是把异常分支塞进 kernel 主干。
Step 4: 动手实战
要求:请补全下方 fused_multi_lora_kernel。我们需要根据传入的 lora_indices,正确地在三维张量 lora_a_pool 和 lora_b_pool 中计算偏移量,并执行点积。为了简化,这里假设秩 r 较小(如 8 或 16),且输入只进行列并行分块(行方向通常可以一次性放入 SRAM)。
try:
import triton
except ModuleNotFoundError:
try:
import google.colab # type: ignore
except Exception:
raise
import subprocess, sys
print('Installing Triton for Part 3...')
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'triton'])
import triton
import torch
import triton
import triton.language as tl@triton.jit
def fused_multi_lora_kernel(
x_ptr, out_ptr,
lora_a_pool_ptr, lora_b_pool_ptr,
lora_indices_ptr,
M, IN_DIM, OUT_DIM, R: tl.constexpr,
stride_x_m, stride_x_in,
stride_out_m, stride_out_dim,
stride_a_pool, stride_a_r, stride_a_in,
stride_b_pool, stride_b_out, stride_b_r,
BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
# ==========================================
# TODO 1: 读取当前 Token 的 LoRA 索引
# ==========================================
# lora_idx = ???
# ==========================================
# TODO 2: 计算内存池中该 LoRA 的基地址偏移
# ==========================================
# a_pool_base_ptr = ???
# b_pool_base_ptr = ???
# ==========================================
# TODO 3: 计算 x @ A,得到中间激活 h_r
# ==========================================
# acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
# h_r = tl.zeros((R,), dtype=tl.float32)
# num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
# for k in range(num_k_blocks):
# ...
# ==========================================
# TODO 4: 计算 h_r @ B,得到最终输出
# ==========================================
# offs_r = tl.arange(0, R)
# b_ptrs = ???
# b_val = ???
# acc += ???
# ==========================================
# TODO 5: 将结果写回显存
# ==========================================
# out_ptrs = ???
# tl.store(...)
def _question_placeholder_11():
raise NotImplementedError("请完成 TODO 1-5")
_question_placeholder_11()
def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
M, IN_DIM = x.shape
num_loras, R, _ = lora_a_pool.shape
_, OUT_DIM, _ = lora_b_pool.shape
out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
BLOCK_IN = 64
BLOCK_OUT = 64
grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
fused_multi_lora_kernel[grid](
x, out,
lora_a_pool, lora_b_pool,
lora_indices,
M, IN_DIM, OUT_DIM, R,
x.stride(0), x.stride(1),
out.stride(0), out.stride(1),
lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
)
return out# 测试并验证 Multi-LoRA 路由的正确性
def test_multi_lora():
if not torch.cuda.is_available():
print("⏭️ 忽略测试:无 GPU。")
return
try:
torch.manual_seed(42)
batch_size = 4
in_dim = 128
out_dim = 256
num_loras = 3 # 内存池中有 3 个不同的 LoRA
r = 16
x = torch.randn(batch_size, in_dim, device='cuda')
# 构造内存池
# A_pool: (3, 16, 128)
# B_pool: (3, 256, 16)
lora_a_pool = torch.randn(num_loras, r, in_dim, device='cuda')
lora_b_pool = torch.randn(num_loras, out_dim, r, device='cuda')
# 构造复杂的请求路由 (Token 0用LoRA_2, Token 1用LoRA_0...)
lora_indices = torch.tensor([2, 0, 1, 2], device='cuda', dtype=torch.int32)
# 1. 纯 PyTorch 参考计算 (为了便于对照,使用较直观的 for 循环拼接)
out_ref = torch.zeros(batch_size, out_dim, device='cuda')
for i in range(batch_size):
idx = lora_indices[i].item()
# 提取专属权重
A = lora_a_pool[idx] # (r, in_dim)
B = lora_b_pool[idx] # (out_dim, r)
# x[i]: (1, in_dim)
# x_i @ A.T @ B.T
h_r = x[i].unsqueeze(0) @ A.T # (1, r)
y_i = h_r @ B.T # (1, out_dim)
out_ref[i] = y_i.squeeze(0)
# 2. Triton 单算子块级并行计算
out_tri = triton_multi_lora_forward(x, lora_a_pool, lora_b_pool, lora_indices)
# 3. 验证结果
diff = torch.max(torch.abs(out_ref - out_tri))
print(f"最大误差: {diff.item():.6e}")
assert diff < 1e-3, "Triton Multi-LoRA 路由或计算结果不正确!"
print("✅ Multi-LoRA 路由融合推理验证通过。")
print("该算子实现了 Token 级别的动态路由,支持 Batch 内多模型并发推理。")
except NotImplementedError:
print("请先完成 TODO 代码!")
except Exception as e:
print(f"❌ 测试失败: {e}")
test_multi_lora()🛑 STOP HERE 🛑
请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。
参考代码与解析
代码
import torch
import triton
import triton.language as tl
@triton.jit
def fused_multi_lora_kernel(
x_ptr, out_ptr,
lora_a_pool_ptr, lora_b_pool_ptr,
lora_indices_ptr,
M, IN_DIM, OUT_DIM, R: tl.constexpr,
stride_x_m, stride_x_in,
stride_out_m, stride_out_dim,
stride_a_pool, stride_a_r, stride_a_in,
stride_b_pool, stride_b_out, stride_b_r,
BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
# TODO 1: 读取当前 Token 的 LoRA 索引
lora_idx = tl.load(lora_indices_ptr + pid_m)
# TODO 2: 计算内存池中该 LoRA 的基地址偏移
a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool
b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool
acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
h_r = tl.zeros((R,), dtype=tl.float32)
# TODO 3: 计算 x @ A,得到中间激活 h_r
num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
for k in range(num_k_blocks):
offs_in = k * BLOCK_IN + tl.arange(0, BLOCK_IN)
x_ptrs = x_ptr + pid_m * stride_x_m + offs_in * stride_x_in
x_val = tl.load(x_ptrs, mask=offs_in < IN_DIM, other=0.0)
offs_r = tl.arange(0, R)
a_ptrs = a_pool_base_ptr + offs_r[:, None] * stride_a_r + offs_in[None, :] * stride_a_in
a_val = tl.load(a_ptrs, mask=offs_in[None, :] < IN_DIM, other=0.0)
h_r += tl.sum(x_val[None, :] * a_val, axis=1)
# TODO 4: 计算 h_r @ B,得到最终输出
offs_r = tl.arange(0, R)
b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r
b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0)
acc += tl.sum(h_r[None, :] * b_val, axis=1)
# TODO 5: 将结果写回显存
out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim
tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM)
def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
M, IN_DIM = x.shape
num_loras, R, _ = lora_a_pool.shape
_, OUT_DIM, _ = lora_b_pool.shape
out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
BLOCK_IN = 64
BLOCK_OUT = 64
grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
fused_multi_lora_kernel[grid](
x, out,
lora_a_pool, lora_b_pool,
lora_indices,
M, IN_DIM, OUT_DIM, R,
x.stride(0), x.stride(1),
out.stride(0), out.stride(1),
lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
)
return out解析
1. TODO 1: 读取当前 Token 的 LoRA 索引
- 实现方式:python
lora_idx = tl.load(lora_indices_ptr + pid_m) - 关键点:每个 Token 都有自己的 LoRA 选择,
pid_m负责定位当前 Token 的路由信息。 - 技术细节:
lora_indices是一个长度为batch_size的整型数组,记录了每个 Token 应该使用哪组 LoRA 权重。
2. TODO 2: 计算内存池中该 LoRA 的基地址偏移
- 实现方式:python
a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool - 关键点:通过指针偏移直接跳转到对应 LoRA 的权重区间,而不是做额外的查找或拷贝。
- 技术细节:
stride_a_pool和stride_b_pool表示内存池第一维的步长,乘上lora_idx后就能定位到对应 LoRA 的起始位置。
3. TODO 3: 计算 x @ A,得到中间激活 h_r
- 实现方式:python
for k in range(num_k_blocks): offs_in = k * BLOCK_IN + tl.arange(0, BLOCK_IN) x_val = tl.load(x_ptrs, mask=offs_in < IN_DIM, other=0.0) a_val = tl.load(a_ptrs, mask=offs_in[None, :] < IN_DIM, other=0.0) h_r += tl.sum(x_val[None, :] * a_val, axis=1) - 关键点:
x @ A是低秩分解中的第一步,需要按输入维度分块累加,避免一次性展开全部计算。 - 技术细节:
A的形状是(R, IN_DIM),因此需要用二维索引同时定位低秩维和输入维。tl.cdiv(IN_DIM, BLOCK_IN)负责决定分块次数。mask=offs_in[None, :] < IN_DIM可以防止尾块越界访问。
4. TODO 4: 计算 h_r @ B,得到最终输出
- 实现方式:python
b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0) acc += tl.sum(h_r[None, :] * b_val, axis=1) - 关键点:
B的形状是(OUT_DIM, R),需要把中间激活投影回输出维度。 - 技术细节:
tl.sum(..., axis=1)在低秩维度上做归约,得到每个输出位置的累加值。- 这里的计算完全复用前一步的中间激活,不需要再回到 HBM 读取额外中间结果。
5. TODO 5: 将结果写回显存
- 实现方式:python
out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM) - 关键点:写回输出时必须保留边界保护,避免超出
OUT_DIM的无效写入。 - 技术细节:
acc.to(out_ptr.dtype.element_ty)保证输出张量的类型和上层 PyTorch 接口一致。
工程优化要点
- 内存池设计:把所有 LoRA 权重放进统一内存池,减少频繁分配和释放的开销。
- Token 级路由:每个 Token 独立选择 LoRA 权重,支持更细粒度的多租户推理。
- 指针偏移优化:通过
stride计算偏移,避免复杂索引带来的额外开销。 - 分块计算:对
IN_DIM和OUT_DIM分块,提升 SRAM 利用率和并行度。 - 低秩分解:利用
R << IN_DIM, OUT_DIM的结构,把一次大投影拆成两次小投影。 - Batch 并行:不同 Token 的路由彼此独立,适合在 GPU 上并行执行。
- 工业应用:该算子是 S-LoRA、Punica 等多租户推理框架的核心组件。
性能对比建议
- 建议补一个串行 LoRA 参考实现,与 Multi-LoRA 融合版本对照。
- 至少展示三项:串行时延、融合时延、加速比。
- 如果要继续扩展,可再补 batch size 变化趋势和 HBM 读写分析。
推理优化主线收束
Multi-LoRA 是推理优化主线的收束点:
- RoPE → 前处理融合
- FlashAttention → Attention 核心加速
- PagedAttention → 显存碎片管理
- Quantization → 显存体积压缩
- Multi-LoRA → 多租户动态路由
至此,07-11 的两条主线已经覆盖完毕:
- 07-09:Attention 优化(让单次计算更快)
- 10-11:推理服务优化(让系统承载更多用户)
接下来进入 项目篇(12-14):把所有这些算子集成到一个完整的 Llama3 Block 中,用真实数据验证整体性能收益。
