
Matching Models

Matching (retrieval) models select a small candidate set from a large item catalog; a downstream ranking model then scores those candidates. Torch-RecHub provides two-tower, sequence-based, and multi-interest retrieval models for different scenarios.

Model Overview (what to pick)

| Model | When to use | Highlights |
|---|---|---|
| DSSM | Classic text/ID matching | Two-tower semantic matching; cosine/dot similarity |
| YoutubeDNN | Large-scale retrieval | User tower over behavior history + item tower |
| MIND | Users with multiple interests | Capsule routing to learn multiple user interest vectors |
| GRU4Rec / SASRec | Sequential retrieval | RNN / Transformer encoder over recent behavior |
| ComirecDR / ComirecSA | Controllable multi-interest retrieval | Dynamic-routing (DR) or self-attentive (SA) multi-interest extraction |
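MIND and the ComiRec variants represent each user by several interest vectors instead of one. A common way to use them at retrieval time is to score every item against all interests and keep the best match per item. The sketch below does this in plain PyTorch; the function name `multi_interest_scores` and the sizes are illustrative assumptions, not Torch-RecHub's code, and real systems often run a separate nearest-neighbor search per interest and merge the results instead.

```python
import torch
import torch.nn.functional as F


def multi_interest_scores(interests: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: score catalog items against a user's K interest vectors.

    interests: (K, d) interest embeddings for one user.
    items:     (N, d) item embeddings for the whole catalog.
    Returns an (N,) tensor holding, for each item, its best match over all K interests.
    """
    sims = items @ interests.T          # (N, K): similarity to every interest
    return sims.max(dim=1).values       # keep the strongest interest per item


torch.manual_seed(0)
interests = F.normalize(torch.randn(4, 16), dim=-1)   # K = 4 interests, d = 16
items = F.normalize(torch.randn(1000, 16), dim=-1)    # N = 1000 catalog items
scores = multi_interest_scores(interests, items)
top_scores, top_ids = torch.topk(scores, k=10)         # 10 retrieved candidates
```

Taking the maximum means an item only needs to match one of the user's interests to be retrieved, which is what keeps niche interests from being averaged away.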

Quick Usage (example: DSSM)

```python
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.basic.features import SparseFeature, DenseFeature

# Feature schemas for the two towers (vocab sizes are illustrative).
user_features = [
    SparseFeature("user_id", vocab_size=10000, embed_dim=32),
    DenseFeature("age", embed_dim=1),
]
item_features = [
    SparseFeature("item_id", vocab_size=100000, embed_dim=32),
    SparseFeature("category", vocab_size=1000, embed_dim=16),
]

model = DSSM(
    user_features=user_features,
    item_features=item_features,
    temperature=0.02,
    user_params={"dims": [256, 128, 64], "dropout": 0.2, "activation": "prelu"},
    item_params={"dims": [256, 128, 64], "dropout": 0.2, "activation": "prelu"},
)

trainer = MatchTrainer(
    model=model,
    mode=0,  # 0: point-wise, 1: pair-wise (BPR), 2: list-wise
    optimizer_params={"lr": 1e-3},
    n_epoch=10,
)
# train_dl is assumed to be a DataLoader of (feature dict, label) batches that
# covers the user and item features above; building it is not shown here.
trainer.fit(train_dl)
```
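To see roughly what the configured model computes in point-wise mode (`mode=0`), here is a simplified plain-PyTorch sketch: each tower is an MLP, both outputs are L2-normalized, their cosine similarity is divided by `temperature`, and a sigmoid turns the result into a click probability. This is not Torch-RecHub's implementation; `TinyDSSM`, the `mlp` helper, and the input sizes (33 = 32-dim `user_id` embedding + 1-dim `age`, 48 = 32-dim `item_id` + 16-dim `category`, assuming embeddings are simply concatenated) are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mlp(in_dim: int, dims: list, dropout: float) -> nn.Sequential:
    """Stack Linear -> PReLU -> Dropout blocks, mirroring the dims/dropout/activation config."""
    layers = []
    for out_dim in dims:
        layers += [nn.Linear(in_dim, out_dim), nn.PReLU(), nn.Dropout(dropout)]
        in_dim = out_dim
    return nn.Sequential(*layers)


class TinyDSSM(nn.Module):
    """Hypothetical simplified two-tower model (not the library class)."""

    def __init__(self, user_in: int, item_in: int, temperature: float = 0.02):
        super().__init__()
        self.user_tower = mlp(user_in, [256, 128, 64], dropout=0.2)
        self.item_tower = mlp(item_in, [256, 128, 64], dropout=0.2)
        self.temperature = temperature

    def forward(self, user_x: torch.Tensor, item_x: torch.Tensor) -> torch.Tensor:
        u = F.normalize(self.user_tower(user_x), dim=-1)   # unit-length user vectors
        v = F.normalize(self.item_tower(item_x), dim=-1)   # unit-length item vectors
        cosine = (u * v).sum(dim=-1)                        # values in [-1, 1]
        return torch.sigmoid(cosine / self.temperature)     # point-wise click probability


torch.manual_seed(0)
model = TinyDSSM(user_in=32 + 1, item_in=32 + 16)
model.eval()  # disable dropout for a deterministic forward pass
probs = model(torch.randn(8, 33), torch.randn(8, 48))
labels = torch.randint(0, 2, (8,)).float()
loss = F.binary_cross_entropy(probs, labels)
```

Because the two towers never see each other's inputs, each can be run on its own, which is what the tower export below relies on.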

Export towers (two-tower serving)

```python
# Export each tower to its own ONNX file so they can be served independently.
trainer.export_onnx("user_tower.onnx", mode="user")
trainer.export_onnx("item_tower.onnx", mode="item")
```
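Exporting the towers separately is what makes two-tower serving cheap: the item tower can be run offline over the whole catalog and its outputs cached, while only the user tower runs per request. The sketch below shows that flow with brute-force search in plain PyTorch. The toy linear towers are stand-ins for the exported ONNX models (an assumption for illustration; loading and running the ONNX files is not shown), and in production an approximate nearest-neighbor index would usually replace `torch.topk`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical stand-ins for the exported user and item towers.
user_tower = nn.Linear(33, 64)
item_tower = nn.Linear(48, 64)

# Offline: embed the whole catalog once, in batches, and cache the result.
item_inputs = torch.randn(10_000, 48)
with torch.no_grad():
    item_table = torch.cat(
        [F.normalize(item_tower(batch), dim=-1) for batch in item_inputs.split(2048)]
    )  # (10000, 64)


def retrieve(user_x: torch.Tensor, k: int = 100) -> torch.Tensor:
    """Online: embed one request's user features and return the top-k item ids."""
    with torch.no_grad():
        u = F.normalize(user_tower(user_x), dim=-1)   # (64,)
        scores = item_table @ u                       # cosine similarity to every item
        return torch.topk(scores, k).indices


candidates = retrieve(torch.randn(33), k=100)
```

The per-request cost is one user-tower forward pass plus a search over cached vectors; the item tower only needs to be rerun when the catalog or model changes.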

Parameter Notes (common fields)

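  • mode (MatchTrainer): training objective; 0 = point-wise, 1 = pair-wise (BPR), 2 = list-wise.
  • user_features / item_features: feature schemas for the user tower and the item tower.
  • temperature: divides the user-item similarity score before the loss; smaller values produce sharper score distributions.
  • user_params / item_params: MLP configuration for each tower (dims, dropout, activation).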
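The three `mode` values correspond to different loss shapes. As a rough plain-PyTorch sketch (the exact loss functions Torch-RecHub uses internally may differ), given temperature-scaled similarity scores: point-wise treats each (user, item) pair as a binary label, pair-wise (BPR) pushes the positive above each sampled negative, and list-wise applies softmax cross-entropy with the positive at index 0. The batch size and the 4 sampled negatives per user are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
temperature = 0.02
batch, n_neg, dim = 8, 4, 64

user = F.normalize(torch.randn(batch, dim), dim=-1)
pos = F.normalize(torch.randn(batch, dim), dim=-1)            # one positive item per user
neg = F.normalize(torch.randn(batch, n_neg, dim), dim=-1)     # sampled negatives

pos_logit = (user * pos).sum(-1) / temperature                    # (batch,)
neg_logit = torch.einsum("bd,bnd->bn", user, neg) / temperature   # (batch, n_neg)

# mode=0, point-wise: binary cross-entropy, positives labeled 1 and negatives 0.
point_loss = F.binary_cross_entropy_with_logits(
    torch.cat([pos_logit, neg_logit.flatten()]),
    torch.cat([torch.ones(batch), torch.zeros(batch * n_neg)]),
)

# mode=1, pair-wise (BPR): the positive should outscore every negative.
pair_loss = -F.logsigmoid(pos_logit.unsqueeze(1) - neg_logit).mean()

# mode=2, list-wise: softmax over [positive, negatives]; the target is index 0.
logits = torch.cat([pos_logit.unsqueeze(1), neg_logit], dim=1)   # (batch, 1 + n_neg)
list_loss = F.cross_entropy(logits, torch.zeros(batch, dtype=torch.long))

# Temperature effect: the same similarities give a sharper softmax at 0.02 than at 1.0.
sims = torch.tensor([0.9, 0.8, 0.1])
sharp = torch.softmax(sims / 0.02, dim=0)
soft = torch.softmax(sims / 1.0, dim=0)
```

With a temperature of 0.02, a cosine gap of 0.1 becomes a logit gap of 5, so the model is pushed hard to separate the positive from near-miss negatives.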

Tips

  • Users with many distinct interests → MIND or the Comirec family.
  • Order of recent behavior matters → GRU4Rec (RNN) or SASRec (Transformer).
  • Large-scale production two-tower retrieval → YoutubeDNN / DSSM.
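For the sequence-based options, the core idea is that the user side encodes the recent behavior sequence and a final hidden state serves as the user embedding for retrieval. The sketch below is a minimal GRU4Rec-style encoder in plain PyTorch; the class name `GRUUserEncoder`, the sizes, right-padding with id 0, and scoring against the shared item embedding table are assumptions for illustration, not Torch-RecHub's implementation. SASRec would replace the GRU with causal self-attention.

```python
import torch
import torch.nn as nn


class GRUUserEncoder(nn.Module):
    """Hypothetical GRU4Rec-style user encoder: item-id sequence -> user vector."""

    def __init__(self, n_items: int, dim: int = 64):
        super().__init__()
        self.item_emb = nn.Embedding(n_items, dim, padding_idx=0)  # id 0 = padding
        self.gru = nn.GRU(dim, dim, batch_first=True)

    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        out, _ = self.gru(self.item_emb(seq))          # (batch, seq_len, dim)
        # Take the hidden state at each sequence's last real (non-padded) position.
        idx = (lengths - 1).clamp(min=0)
        return out[torch.arange(seq.size(0)), idx]     # (batch, dim)


torch.manual_seed(0)
encoder = GRUUserEncoder(n_items=1000)
seq = torch.tensor([[5, 17, 42, 0, 0],      # 3 real interactions, right-padded
                    [8, 9, 10, 11, 12]])    # 5 real interactions
lengths = torch.tensor([3, 5])
user_vecs = encoder(seq, lengths)                     # (2, 64)
scores = user_vecs @ encoder.item_emb.weight.T        # (2, 1000) scores over the catalog
```

Because the GRU reads left to right, the state at the last real position depends only on the actual interactions, so trailing padding does not leak into the user vector.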