Skip to content

11. Triton Multi LoRA | Triton 多租户路由与融合推理

难度: Hard | 标签: Triton, LoRA, Punica, Serving | 目标人群: 核心 Infra 与算子开发

🚀 云端运行环境

本章节的实战代码可以点击以下链接在免费 GPU 算力平台上直接运行:

Open In ColabOpen In Studio (国内推荐:魔搭社区免费实例)

02_PyTorch_Algorithms/09_LoRA_Tutorial.ipynb 中,我们在 PyTorch 层面实现了 LoRA (h=xW+xAB) 的逻辑。 然而,在工业级的大模型推理服务 (Serving) 中,面对并发的多个用户请求,如果每个用户的 prompt 挂载了不同的 LoRA 权重(如用户 A 请求写代码的 LoRA,用户 B 请求翻译的 LoRA),如果将它们拆分并循环执行 PyTorch 的 linear(),通常会降低 GPU 的吞吐量,也不便于利用 Batch 计算。 本节我们将实现 Multi-LoRA (如 S-LoRA / Punica 论文思路) 的底层 Triton 融合算子:通过传入 lora_indices,让每一个 Token 在 SRAM 内按索引读取它对应的 LoRA 权重,完成批量计算。

本节承接 08. FlashAttention、09. PagedAttention 和 10. Quantization 的结果,继续解决多用户并发推理中的 LoRA 权重路由 问题。想象一个 SaaS 推理服务:用户 A 用“代码生成” LoRA,用户 B 用“翻译” LoRA,用户 C 用“摘要” LoRA。传统做法会在切换 LoRA 时反复加载权重,导致 GPU 利用率偏低;Multi-LoRA 则把多个请求合并到一个 Batch 中,通过 Token 级动态路由完成融合推理。如果说 08-09 解决的是 Attention 计算与 KV Cache,10 解决的是权重体积,那么 11 关注的就是 多租户服务如何更高效地复用这些优化

前置

导语: 这一节会把“一个 batch 里多种 LoRA 路由”直接落到 Triton 的 SRAM 访问上。

相关阅读

导语: 如果想先回顾 LoRA 的 PyTorch 版本,可以先看这页,帮助理解后面的指针路由。

Step 1: Multi-LoRA 内存池与 Batch 路由

LoRA 的内存池 (Weight Pool): 我们通常不会为每一个单独的 LoRA 权重分配离散的张量,而是将机器上加载的所有 LoRA 矩阵 A (形状 r \t\times in_features) 拼接成一个连续张量 lora_a_pool,形状为 (num_loras, r, in_features)

Token 级别的细粒度路由 (Token-wise Routing): 假设输入特征 X 的形状为 (batch_size, in_features)。 我们额外传入一个长度为 batch_size 的一维整型数组 lora_indices。它记录了:当前 Batch 中,第 i 个 Token 到底需要使用 lora_a_pool 里的第几个 LoRA 模型。

SRAM 内部的块级并行:

  • Triton pid 处理矩阵 X 的某一行(某个 Token)。
  • 我们直接在内核里根据 lora_idx = tl.load(lora_indices + pid),算出指向内存池中特定 LoRA A 和 B 的偏移量!
  • 一次性读取 Xi 和专属的 Aidx,Bidx,在极速的 SRAM 中完成 Xi×A×B,最后写回 Hi。 这样,原本通常需要分别计算的不同模型请求,可以被合并到一个 Triton Kernel 调用中(Batch Inference)

Step 2: 内存池与 Batch 指针路由

在推理服务器中,往往存在一个大底座模型对应几百个微调的 LoRA 权重。为了避免切换开销,我们将所有的 LoRA 权重放进统一的巨大的显存池中。每个发来的 Token 请求都会带有一个 lora_id。内核利用这个 lora_id 充当偏移指针,直接在同一次前向计算中去抓取不同的权重完成点积。

Step 3: 指针路由代码框架

传入包含所有权重的张量 lora_pool 和整数数组 lora_indices。在内核中,先读取 lora_idx = tl.load(lora_indices_ptr + pid_batch),将该索引乘上权重的 stride,动态确定当前线程块该加载哪一份 LoRA A 和 B,随后做标准的低秩乘加运算。

补充说明:性能对比与边界

这一节的核心卖点是把多个不同 LoRA 请求合并成一次 kernel 调用,所以最好补一个与串行方案的对比口径。 建议后面在验证区展示:

  • 串行 LoRA 推理的时延
  • Multi-LoRA 融合后的时延
  • 最终加速比

边界条件上,当前实现默认 num_loras > 0R > 0;如果要继续增强鲁棒性,可以在宿主侧先做参数校验,而不是把异常分支塞进 kernel 主干。

Step 4: 动手实战

要求:请补全下方 fused_multi_lora_kernel。我们需要根据传入的 lora_indices,正确地在三维张量 lora_a_poollora_b_pool 中计算偏移量,并执行点积。为了简化,这里假设秩 r 较小(如 8 或 16),且输入只进行列并行分块(行方向通常可以一次性放入 SRAM)。

python
try:
    import triton
except ModuleNotFoundError:
    try:
        import google.colab  # type: ignore
    except Exception:
        raise
    import subprocess, sys
    print('Installing Triton for Part 3...')
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'triton'])
    import triton

import torch
import triton
import triton.language as tl
python
@triton.jit
def fused_multi_lora_kernel(
    x_ptr, out_ptr,
    lora_a_pool_ptr, lora_b_pool_ptr,
    lora_indices_ptr,
    M, IN_DIM, OUT_DIM, R: tl.constexpr,
    stride_x_m, stride_x_in,
    stride_out_m, stride_out_dim,
    stride_a_pool, stride_a_r, stride_a_in,
    stride_b_pool, stride_b_out, stride_b_r,
    BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)

    # ==========================================
    # TODO 1: 读取当前 Token 的 LoRA 索引
    # ==========================================
    # lora_idx = ???

    # ==========================================
    # TODO 2: 计算内存池中该 LoRA 的基地址偏移
    # ==========================================
    # a_pool_base_ptr = ???
    # b_pool_base_ptr = ???

    # ==========================================
    # TODO 3: 计算 x @ A,得到中间激活 h_r
    # ==========================================
    # acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
    # h_r = tl.zeros((R,), dtype=tl.float32)
    # num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
    # for k in range(num_k_blocks):
    #     ...

    # ==========================================
    # TODO 4: 计算 h_r @ B,得到最终输出
    # ==========================================
    # offs_r = tl.arange(0, R)
    # b_ptrs = ???
    # b_val = ???
    # acc += ???

    # ==========================================
    # TODO 5: 将结果写回显存
    # ==========================================
    # out_ptrs = ???
    # tl.store(...)
    
def _question_placeholder_11():
    raise NotImplementedError("请完成 TODO 1-5")

_question_placeholder_11()

def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
    M, IN_DIM = x.shape
    num_loras, R, _ = lora_a_pool.shape
    _, OUT_DIM, _ = lora_b_pool.shape
    
    out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
    
    BLOCK_IN = 64
    BLOCK_OUT = 64
    
    grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
    
    fused_multi_lora_kernel[grid](
        x, out, 
        lora_a_pool, lora_b_pool, 
        lora_indices,
        M, IN_DIM, OUT_DIM, R,
        x.stride(0), x.stride(1),
        out.stride(0), out.stride(1),
        lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
        lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
        BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
    )
    return out
python
# 测试并验证 Multi-LoRA 路由的正确性
def test_multi_lora():
    if not torch.cuda.is_available():
        print("⏭️ 忽略测试:无 GPU。")
        return
        
    try:
        torch.manual_seed(42)
        batch_size = 4
        in_dim = 128
        out_dim = 256
        num_loras = 3 # 内存池中有 3 个不同的 LoRA
        r = 16
        
        x = torch.randn(batch_size, in_dim, device='cuda')
        
        # 构造内存池
        # A_pool: (3, 16, 128)
        # B_pool: (3, 256, 16)
        lora_a_pool = torch.randn(num_loras, r, in_dim, device='cuda')
        lora_b_pool = torch.randn(num_loras, out_dim, r, device='cuda')
        
        # 构造复杂的请求路由 (Token 0用LoRA_2, Token 1用LoRA_0...)
        lora_indices = torch.tensor([2, 0, 1, 2], device='cuda', dtype=torch.int32)
        
        # 1. 纯 PyTorch 参考计算 (为了便于对照,使用较直观的 for 循环拼接)
        out_ref = torch.zeros(batch_size, out_dim, device='cuda')
        for i in range(batch_size):
            idx = lora_indices[i].item()
            # 提取专属权重
            A = lora_a_pool[idx] # (r, in_dim)
            B = lora_b_pool[idx] # (out_dim, r)
            
            # x[i]: (1, in_dim)
            # x_i @ A.T @ B.T
            h_r = x[i].unsqueeze(0) @ A.T # (1, r)
            y_i = h_r @ B.T               # (1, out_dim)
            out_ref[i] = y_i.squeeze(0)
            
        # 2. Triton 单算子块级并行计算
        out_tri = triton_multi_lora_forward(x, lora_a_pool, lora_b_pool, lora_indices)
        
        # 3. 验证结果
        diff = torch.max(torch.abs(out_ref - out_tri))
        print(f"最大误差: {diff.item():.6e}")
        assert diff < 1e-3, "Triton Multi-LoRA 路由或计算结果不正确!"
        
        print("✅ Multi-LoRA 路由融合推理验证通过。")
        print("该算子实现了 Token 级别的动态路由,支持 Batch 内多模型并发推理。")
        
    except NotImplementedError:
        print("请先完成 TODO 代码!")
    except Exception as e:
        print(f"❌ 测试失败: {e}")

test_multi_lora()

🛑 STOP HERE 🛑









请先尝试自己完成代码并跑通测试。
如果你正在 Colab 中运行,并且遇到困难没有思路,可以向下滚动查看参考答案。










参考代码与解析

代码

python
import torch
import triton
import triton.language as tl

@triton.jit
def fused_multi_lora_kernel(
    x_ptr, out_ptr, 
    lora_a_pool_ptr, lora_b_pool_ptr, 
    lora_indices_ptr,
    M, IN_DIM, OUT_DIM, R: tl.constexpr,
    stride_x_m, stride_x_in,
    stride_out_m, stride_out_dim,
    stride_a_pool, stride_a_r, stride_a_in,
    stride_b_pool, stride_b_out, stride_b_r,
    BLOCK_IN: tl.constexpr, BLOCK_OUT: tl.constexpr
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_n = pid_n * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    
    # TODO 1: 读取当前 Token 的 LoRA 索引
    lora_idx = tl.load(lora_indices_ptr + pid_m)
    
    # TODO 2: 计算内存池中该 LoRA 的基地址偏移
    a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool
    b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool
    
    acc = tl.zeros((BLOCK_OUT,), dtype=tl.float32)
    h_r = tl.zeros((R,), dtype=tl.float32)
    
    # TODO 3: 计算 x @ A,得到中间激活 h_r
    num_k_blocks = tl.cdiv(IN_DIM, BLOCK_IN)
    for k in range(num_k_blocks):
        offs_in = k * BLOCK_IN + tl.arange(0, BLOCK_IN)
        
        x_ptrs = x_ptr + pid_m * stride_x_m + offs_in * stride_x_in
        x_val = tl.load(x_ptrs, mask=offs_in < IN_DIM, other=0.0)
        
        offs_r = tl.arange(0, R)
        a_ptrs = a_pool_base_ptr + offs_r[:, None] * stride_a_r + offs_in[None, :] * stride_a_in
        a_val = tl.load(a_ptrs, mask=offs_in[None, :] < IN_DIM, other=0.0)
        
        h_r += tl.sum(x_val[None, :] * a_val, axis=1)
    
    # TODO 4: 计算 h_r @ B,得到最终输出
    offs_r = tl.arange(0, R)
    b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r
    b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0)
    
    acc += tl.sum(h_r[None, :] * b_val, axis=1)
    
    # TODO 5: 将结果写回显存
    out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM)

def triton_multi_lora_forward(x: torch.Tensor, lora_a_pool: torch.Tensor, lora_b_pool: torch.Tensor, lora_indices: torch.Tensor):
    M, IN_DIM = x.shape
    num_loras, R, _ = lora_a_pool.shape
    _, OUT_DIM, _ = lora_b_pool.shape
    
    out = torch.empty((M, OUT_DIM), device=x.device, dtype=x.dtype)
    
    BLOCK_IN = 64
    BLOCK_OUT = 64
    
    grid = (M, triton.cdiv(OUT_DIM, BLOCK_OUT))
    
    fused_multi_lora_kernel[grid](
        x, out, 
        lora_a_pool, lora_b_pool, 
        lora_indices,
        M, IN_DIM, OUT_DIM, R,
        x.stride(0), x.stride(1),
        out.stride(0), out.stride(1),
        lora_a_pool.stride(0), lora_a_pool.stride(1), lora_a_pool.stride(2),
        lora_b_pool.stride(0), lora_b_pool.stride(1), lora_b_pool.stride(2),
        BLOCK_IN=BLOCK_IN, BLOCK_OUT=BLOCK_OUT
    )
    return out

解析

1. TODO 1: 读取当前 Token 的 LoRA 索引

  • 实现方式
    python
    lora_idx = tl.load(lora_indices_ptr + pid_m)
  • 关键点:每个 Token 都有自己的 LoRA 选择,pid_m 负责定位当前 Token 的路由信息。
  • 技术细节lora_indices 是一个长度为 batch_size 的整型数组,记录了每个 Token 应该使用哪组 LoRA 权重。

2. TODO 2: 计算内存池中该 LoRA 的基地址偏移

  • 实现方式
    python
    a_pool_base_ptr = lora_a_pool_ptr + lora_idx * stride_a_pool
    b_pool_base_ptr = lora_b_pool_ptr + lora_idx * stride_b_pool
  • 关键点:通过指针偏移直接跳转到对应 LoRA 的权重区间,而不是做额外的查找或拷贝。
  • 技术细节stride_a_poolstride_b_pool 表示内存池第一维的步长,乘上 lora_idx 后就能定位到对应 LoRA 的起始位置。

3. TODO 3: 计算 x @ A,得到中间激活 h_r

  • 实现方式
    python
    for k in range(num_k_blocks):
        offs_in = k * BLOCK_IN + tl.arange(0, BLOCK_IN)
        x_val = tl.load(x_ptrs, mask=offs_in < IN_DIM, other=0.0)
        a_val = tl.load(a_ptrs, mask=offs_in[None, :] < IN_DIM, other=0.0)
        h_r += tl.sum(x_val[None, :] * a_val, axis=1)
  • 关键点x @ A 是低秩分解中的第一步,需要按输入维度分块累加,避免一次性展开全部计算。
  • 技术细节
    • A 的形状是 (R, IN_DIM),因此需要用二维索引同时定位低秩维和输入维。
    • tl.cdiv(IN_DIM, BLOCK_IN) 负责决定分块次数。
    • mask=offs_in[None, :] < IN_DIM 可以防止尾块越界访问。

4. TODO 4: 计算 h_r @ B,得到最终输出

  • 实现方式
    python
    b_ptrs = b_pool_base_ptr + offs_n[:, None] * stride_b_out + offs_r[None, :] * stride_b_r
    b_val = tl.load(b_ptrs, mask=offs_n[:, None] < OUT_DIM, other=0.0)
    acc += tl.sum(h_r[None, :] * b_val, axis=1)
  • 关键点B 的形状是 (OUT_DIM, R),需要把中间激活投影回输出维度。
  • 技术细节
    • tl.sum(..., axis=1) 在低秩维度上做归约,得到每个输出位置的累加值。
    • 这里的计算完全复用前一步的中间激活,不需要再回到 HBM 读取额外中间结果。

5. TODO 5: 将结果写回显存

  • 实现方式
    python
    out_ptrs = out_ptr + pid_m * stride_out_m + offs_n * stride_out_dim
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=offs_n < OUT_DIM)
  • 关键点:写回输出时必须保留边界保护,避免超出 OUT_DIM 的无效写入。
  • 技术细节acc.to(out_ptr.dtype.element_ty) 保证输出张量的类型和上层 PyTorch 接口一致。

工程优化要点

  • 内存池设计:把所有 LoRA 权重放进统一内存池,减少频繁分配和释放的开销。
  • Token 级路由:每个 Token 独立选择 LoRA 权重,支持更细粒度的多租户推理。
  • 指针偏移优化:通过 stride 计算偏移,避免复杂索引带来的额外开销。
  • 分块计算:对 IN_DIMOUT_DIM 分块,提升 SRAM 利用率和并行度。
  • 低秩分解:利用 R << IN_DIM, OUT_DIM 的结构,把一次大投影拆成两次小投影。
  • Batch 并行:不同 Token 的路由彼此独立,适合在 GPU 上并行执行。
  • 工业应用:该算子是 S-LoRA、Punica 等多租户推理框架的核心组件。

性能对比建议

  • 建议补一个串行 LoRA 参考实现,与 Multi-LoRA 融合版本对照。
  • 至少展示三项:串行时延、融合时延、加速比。
  • 如果要继续扩展,可再补 batch size 变化趋势和 HBM 读写分析。

推理优化主线收束

Multi-LoRA 是推理优化主线的收束点

  • RoPE → 前处理融合
  • FlashAttention → Attention 核心加速
  • PagedAttention → 显存碎片管理
  • Quantization → 显存体积压缩
  • Multi-LoRA → 多租户动态路由

至此,07-11 的两条主线已经覆盖完毕:

  • 07-09:Attention 优化(让单次计算更快)
  • 10-11:推理服务优化(让系统承载更多用户)

接下来进入 项目篇(12-14):把所有这些算子集成到一个完整的 Llama3 Block 中,用真实数据验证整体性能收益。

Released under the MIT License.