Skip to content

向量检索封装

Torch-RecHub 提供了统一的向量检索接口,支持三种主流的近似最近邻(ANN)检索库:AnnoyFAISSMilvus。通过标准化的 Builder-Indexer 模式,用户可以方便地在不同检索后端之间切换。

快速开始

python
import torch
from torch_rechub.serving import builder_factory

# 准备嵌入向量
item_embeddings = torch.randn(1000, 64, dtype=torch.float32)  # 1000个物品,64维
user_embeddings = torch.randn(10, 64, dtype=torch.float32)    # 10个用户,64维

# 创建 Builder 并构建索引
builder = builder_factory("faiss", index_type="Flat", metric="L2")

with builder.from_embeddings(item_embeddings) as indexer:
    # 查询 Top-K
    ids, distances = indexer.query(user_embeddings, top_k=10)
    # 保存索引
    indexer.save("item.index")

# 加载已有索引
with builder.from_index_file("item.index") as indexer:
    ids, distances = indexer.query(user_embeddings, top_k=10)

核心概念

Builder-Indexer 模式

  • Builder:负责索引的构建配置,通过 builder_factory 工厂函数创建
  • Indexer:负责具体的查询和保存操作,通过 Builder 的上下文管理器获取

工厂函数

python
from torch_rechub.serving import builder_factory

builder = builder_factory(model, **builder_config)
参数类型说明
modelstr检索后端名称:"annoy""faiss""milvus"
**builder_configdict传递给具体 Builder 的配置参数

Annoy

Annoy(Approximate Nearest Neighbors Oh Yeah)是 Spotify 开源的近似最近邻搜索库,特点是内存友好、支持文件内存映射。

安装

bash
pip install annoy

⚠️ 注意:Annoy 是 C++ 库,安装时需要编译。如果你的系统没有 C++ 编译环境(如 Windows 上缺少 Visual Studio Build Tools),可能会遇到编译错误。

解决方案

  • Linux/macOS:确保安装了 gcc/g++clang
  • Windows:安装 Visual Studio Build Tools,或直接下载预编译的 wheel 文件:

参数说明

python
builder = builder_factory(
    "annoy",
    d=64,                    # 向量维度(必需)
    metric="angular",        # 距离度量
    n_trees=10,              # 树的数量
    threads=-1,              # 构建时的线程数
    searchk=-1,              # 搜索时检查的节点数
)
参数类型默认值说明
dint必需向量维度
metricstr"angular"距离度量:"angular"(余弦)、"euclidean"(欧氏)、"dot"(点积)
n_treesint10构建的树数量,越多精度越高但构建越慢
threadsint-1构建线程数,-1 表示使用所有可用核心
searchkint-1搜索时检查的节点数,-1 表示 n_trees * top_k

使用示例

python
import torch
from torch_rechub.serving import builder_factory

item_embeddings = torch.randn(1000, 64, dtype=torch.float32)
user_embeddings = torch.randn(10, 64, dtype=torch.float32)

# 使用余弦相似度
builder = builder_factory(
    "annoy",
    d=64,
    metric="angular",
    n_trees=50,
    searchk=100,
)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, distances = indexer.query(user_embeddings, top_k=10)
    indexer.save("annoy.index")
    
print(f"召回物品ID: {ids}")
print(f"距离: {distances}")

FAISS

FAISS(Facebook AI Similarity Search)是 Meta 开源的高性能相似性搜索库,支持 GPU 加速和多种索引类型。

安装

bash
# CPU 版本
pip install faiss-cpu

# GPU 版本(需要 CUDA)
pip install faiss-gpu

支持的索引类型

索引类型说明适用场景
Flat暴力搜索,精确结果小规模数据(< 10万)
HNSW基于图的近似搜索中等规模,高召回率要求
IVF倒排索引,聚类后搜索大规模数据

参数说明

Flat 索引

python
builder = builder_factory(
    "faiss",
    index_type="Flat",       # 索引类型
    metric="L2",             # 距离度量
)
参数类型默认值说明
index_typestr"Flat"索引类型
metricstr"L2"距离度量:"L2"(欧氏距离)、"IP"(内积)

HNSW 索引

python
builder = builder_factory(
    "faiss",
    index_type="HNSW",
    metric="L2",
    m=32,                    # 每个节点的最大邻居数
    efSearch=50,             # 搜索时的候选节点数
)
参数类型默认值说明
mint32每个节点的最大邻居数,越大精度越高
efSearchintNone搜索时的候选节点数,越大精度越高但越慢

IVF 索引

python
builder = builder_factory(
    "faiss",
    index_type="IVF",
    metric="L2",
    nlists=100,              # 聚类中心数量
    nprobe=10,               # 搜索时访问的聚类数
)
参数类型默认值说明
nlistsint100聚类中心数量,建议为 sqrt(n)4*sqrt(n)
nprobeintNone搜索时访问的聚类数,越大精度越高但越慢

使用示例

python
import torch
from torch_rechub.serving import builder_factory

item_embeddings = torch.randn(10000, 128, dtype=torch.float32)
user_embeddings = torch.randn(100, 128, dtype=torch.float32)

# 使用 HNSW 索引
builder = builder_factory(
    "faiss",
    index_type="HNSW",
    metric="IP",  # 内积,适合归一化后的向量
    m=32,
    efSearch=64,
)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, distances = indexer.query(user_embeddings, top_k=20)
    indexer.save("faiss_hnsw.index")

Milvus

Milvus 是一个云原生向量数据库,支持分布式部署和多种索引算法,适合生产环境的大规模向量检索。

安装

bash
pip install pymilvus

注意:使用 Milvus 需要先启动 Milvus 服务。可以使用 Docker 快速启动:

bash
docker run -d --name milvus-standalone -p 19530:19530 milvusdb/milvus:latest

支持的索引类型

索引类型说明适用场景
FLAT暴力搜索小规模数据,精确结果
HNSW基于图的索引中等规模,高召回率
IVF_FLAT倒排索引大规模数据

参数说明

FLAT 索引

python
builder = builder_factory(
    "milvus",
    d=64,                    # 向量维度(必需)
    index_type="FLAT",
    metric="L2",             # 距离度量
)
参数类型默认值说明
dint必需向量维度
metricstr"L2"距离度量:"L2""IP""COSINE"

HNSW 索引

python
builder = builder_factory(
    "milvus",
    d=64,
    index_type="HNSW",
    metric="COSINE",
    m=16,                    # 每个节点的最大邻居数
    ef=50,                   # 搜索时的候选节点数
)
参数类型默认值说明
mint16每个节点的最大邻居数
efintNone搜索时的候选节点数

IVF_FLAT 索引

python
builder = builder_factory(
    "milvus",
    d=64,
    index_type="IVF_FLAT",
    metric="IP",
    nlist=128,               # 聚类中心数量
    nprobe=16,               # 搜索时访问的聚类数
)
参数类型默认值说明
nlistint128聚类中心数量
nprobeintNone搜索时访问的聚类数

使用示例

python
import torch
from torch_rechub.serving import builder_factory

item_embeddings = torch.randn(10000, 64, dtype=torch.float32)
user_embeddings = torch.randn(100, 64, dtype=torch.float32)

# 使用 Milvus HNSW 索引
builder = builder_factory(
    "milvus",
    d=64,
    index_type="HNSW",
    metric="COSINE",
    m=32,
    ef=64,
)

with builder.from_embeddings(item_embeddings) as indexer:
    ids, distances = indexer.query(user_embeddings, top_k=10)
    # 注意:Milvus 索引保存在服务端,不支持本地 save

完整示例:召回模型评估

以下是一个完整的双塔模型向量化评估示例:

python
import collections
import numpy as np
import pandas as pd
import torch
from torch_rechub.serving import builder_factory
from torch_rechub.basic.metric import topk_metrics

def match_evaluation(
    user_embedding: torch.Tensor,
    item_embedding: torch.Tensor,
    test_user: dict,
    all_item: dict,
    user_col: str = 'user_id',
    item_col: str = 'item_id',
    raw_id_maps: str = "./raw_id_maps.npy",
    topk: int = 10,
    backend: str = "faiss",
    **backend_kwargs,
):
    """
    使用向量检索进行召回评估
    
    Args:
        user_embedding: 用户嵌入向量 (n_users, dim)
        item_embedding: 物品嵌入向量 (n_items, dim)
        test_user: 测试用户数据字典
        all_item: 全量物品数据字典
        user_col: 用户ID列名
        item_col: 物品ID列名
        raw_id_maps: ID映射文件路径
        topk: 召回数量
        backend: 检索后端 ("annoy", "faiss", "milvus")
        **backend_kwargs: 传递给 builder_factory 的额外参数
    
    Returns:
        评估指标字典
    """
    print(f"使用 {backend} 进行向量化召回评估")
    
    # 1. 创建 Builder
    dim = item_embedding.shape[1]
    
    if backend == "annoy":
        builder = builder_factory("annoy", d=dim, n_trees=10, **backend_kwargs)
    elif backend == "faiss":
        builder = builder_factory("faiss", index_type="Flat", metric="L2", **backend_kwargs)
    elif backend == "milvus":
        builder = builder_factory("milvus", d=dim, index_type="FLAT", metric="L2", **backend_kwargs)
    else:
        raise ValueError(f"不支持的后端: {backend}")
    
    # 2. 确保张量在 CPU 上
    item_embedding = item_embedding.cpu().float()
    user_embedding = user_embedding.cpu().float()
    
    # 3. 加载 ID 映射
    user_map, item_map = np.load(raw_id_maps, allow_pickle=True)
    
    # 4. 构建索引并查询
    match_res = collections.defaultdict(dict)
    
    with builder.from_embeddings(item_embedding) as indexer:
        ids, distances = indexer.query(user_embedding, topk)
        ids_np = ids.numpy()
        
        for i, user_id in enumerate(test_user[user_col]):
            items_idx = ids_np[i]
            predicted_item_ids = all_item[item_col][items_idx]
            match_res[user_map[user_id]] = np.vectorize(item_map.get)(predicted_item_ids)
    
    # 5. 构建 ground truth
    data = pd.DataFrame({user_col: test_user[user_col], item_col: test_user[item_col]})
    data[user_col] = data[user_col].map(user_map)
    data[item_col] = data[item_col].map(item_map)
    user_pos_item = data.groupby(user_col).agg(list).reset_index()
    ground_truth = dict(zip(user_pos_item[user_col], user_pos_item[item_col]))
    
    # 6. 计算指标
    out = topk_metrics(y_true=ground_truth, y_pred=match_res, topKs=[topk])
    return out


# 使用示例
# result = match_evaluation(
#     user_embedding, item_embedding, test_user, all_item,
#     topk=10, backend="faiss", index_type="HNSW", m=32
# )

性能对比与选型建议

特性AnnoyFAISSMilvus
安装难度简单中等需要服务
内存占用中等依赖服务配置
构建速度
查询速度中等
GPU 支持
分布式
适用场景小规模离线中大规模生产环境

选型建议

  • 快速原型/小数据集:使用 Annoy,安装简单,内存友好
  • 中大规模离线计算:使用 FAISS,性能优秀,支持 GPU
  • 生产环境/在线服务:使用 Milvus,支持分布式和动态更新

API 参考

BaseBuilder

python
class BaseBuilder(abc.ABC):
    def from_embeddings(self, embeddings: torch.Tensor) -> ContextManager[BaseIndexer]:
        """从嵌入向量构建索引"""
        
    def from_index_file(self, index_file: FilePath) -> ContextManager[BaseIndexer]:
        """从文件加载索引"""

BaseIndexer

python
class BaseIndexer(abc.ABC):
    def query(self, embeddings: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        查询最近邻
        
        Args:
            embeddings: 查询向量 (n, d)
            top_k: 返回的最近邻数量
            
        Returns:
            ids: 最近邻ID (n, top_k)
            distances: 距离 (n, top_k)
        """
        
    def save(self, file_path: FilePath) -> None:
        """保存索引到文件"""