Multi-Task Learning Tutorial
This tutorial uses the built-in Ali-CCP sample dataset to walk through the multi-task training flow in Torch-RecHub. All code snippets assume you are running from the repository root. They build on each other, so run them in order in a single Python session.
1. Data Preparation
1. Load sample data
```python
import pandas as pd

df_train = pd.read_csv("examples/ranking/data/ali-ccp/ali_ccp_train_sample.csv")
df_val = pd.read_csv("examples/ranking/data/ali-ccp/ali_ccp_val_sample.csv")
df_test = pd.read_csv("examples/ranking/data/ali-ccp/ali_ccp_test_sample.csv")

# Concatenate train / val / test first so feature definitions stay consistent.
train_idx = df_train.shape[0]
val_idx = train_idx + df_val.shape[0]
data = pd.concat([df_train, df_val, df_test], axis=0)

# ctcvr_label is often used in ESMM as the third task: click * conversion
data.rename(columns={"purchase": "cvr_label", "click": "ctr_label"}, inplace=True)
data["ctcvr_label"] = data["cvr_label"] * data["ctr_label"]
```

2. Build dense and sparse features
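In this step every SparseFeature receives data[col].max() + 1 as its vocabulary size. The dependency-free sketch below shows why the + 1 matters. It assumes, as the sample data appears to be, that each sparse column is already encoded as non-negative integer IDs starting at 0. vocab_size and build_embedding_table are hypothetical helpers, not Torch-RecHub APIs.

```python
import random


def vocab_size(column_ids):
    """Embedding rows needed so every 0-based integer ID gets its own row."""
    return max(column_ids) + 1


def build_embedding_table(n_rows, embed_dim, seed=0):
    """Hypothetical stand-in for an embedding table: n_rows x embed_dim small random floats."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.01, 0.01) for _ in range(embed_dim)] for _ in range(n_rows)]


# A toy encoded sparse column; 7 is the largest ID, so 8 rows (IDs 0..7) are needed.
column = [0, 3, 3, 7, 1]
table = build_embedding_table(vocab_size(column), embed_dim=4)
vectors = [table[i] for i in column]  # every ID indexes a valid row
print(len(table), len(vectors[0]))  # 8 4
```

Sizing by the maximum ID over-allocates rows when IDs have gaps, but it never yields an out-of-range lookup. This is also why the train, validation, and test frames are concatenated before the maximum is taken: an ID that appears only in the test split still gets a row.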
```python
from torch_rechub.basic.features import DenseFeature, SparseFeature

# Ali-CCP is mostly sparse features, with a small number of dense columns.
dense_cols = ["D109_14", "D110_14", "D127_14", "D150_14", "D508", "D509", "D702", "D853"]
sparse_cols = [
    col for col in data.columns
    if col not in dense_cols and col not in ["cvr_label", "ctr_label", "ctcvr_label"]
]

# In multi-task learning, all tasks share the same bottom input feature set by default.
features = [SparseFeature(col, data[col].max() + 1, embed_dim=4) for col in sparse_cols] + [
    DenseFeature(col) for col in dense_cols
]
label_cols = ["cvr_label", "ctr_label"]
used_cols = sparse_cols + dense_cols
```

3. Build train / validation / test loaders
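This step does three things:

- It cuts the concatenated frame back apart with train_idx and val_idx.
- It keeps the features as a column-wise dict.
- It stacks the label columns into a 2D matrix with one column per task.

The stdlib sketch below mirrors that bookkeeping on three tiny hand-made row lists (toy data, not Ali-CCP rows). to_xy is a hypothetical helper.

```python
# Toy stand-ins for the three CSV files; each row is (feature_id, purchase, click).
train_rows = [(1, 0, 1), (2, 1, 1), (3, 0, 0), (4, 0, 1)]
val_rows = [(5, 0, 0), (6, 1, 1)]
test_rows = [(7, 0, 1), (8, 0, 0)]

# Same bookkeeping as the tutorial: remember the boundaries, then concatenate.
train_idx = len(train_rows)
val_idx = train_idx + len(val_rows)
data = train_rows + val_rows + test_rows


def to_xy(rows):
    """Column-wise feature dict plus a 2D label matrix, one [cvr, ctr] pair per row."""
    x = {"feature_id": [row[0] for row in rows]}
    y = [[row[1], row[2]] for row in rows]  # order matches label_cols = ["cvr_label", "ctr_label"]
    return x, y


x_train, y_train = to_xy(data[:train_idx])
x_val, y_val = to_xy(data[train_idx:val_idx])
x_test, y_test = to_xy(data[val_idx:])
print(x_val["feature_id"], y_test)  # [5, 6] [[0, 1], [0, 0]]
```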
```python
from torch_rechub.utils.data import DataGenerator

# In multi-task settings, y becomes a 2D label matrix instead of a single label vector.
x_train = {name: data[name].values[:train_idx] for name in used_cols}
y_train = data[label_cols].values[:train_idx]
x_val = {name: data[name].values[train_idx:val_idx] for name in used_cols}
y_val = data[label_cols].values[train_idx:val_idx]
x_test = {name: data[name].values[val_idx:] for name in used_cols}
y_test = data[label_cols].values[val_idx:]

dg = DataGenerator(x_train, y_train)
train_dl, val_dl, test_dl = dg.generate_dataloader(
    x_val=x_val,
    y_val=y_val,
    x_test=x_test,
    y_test=y_test,
    batch_size=1024,
)
```

2. MMOE
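MMOE runs the shared input through several experts and gives each task its own softmax gate that weights those experts. Below is a toy, dependency-free sketch of that mixing step, with hand-picked weights and the task towers omitted. It illustrates the idea and is not Torch-RecHub's MMOE code. In the library, n_expert sets the number of experts, while expert_params and tower_params_list configure the expert and tower MLPs.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(weights, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def relu(v):
    return [max(0.0, a) for a in v]


def mmoe_mix(x, expert_weights, gate_weights):
    """Return one mixed expert representation per task (task towers omitted)."""
    expert_outs = [relu(matvec(w, x)) for w in expert_weights]  # computed once, shared by all tasks
    hidden = len(expert_outs[0])
    mixed = []
    for gate in gate_weights:  # one gate per task
        g = softmax(matvec(gate, x))  # one weight per expert
        mixed.append([sum(g[k] * expert_outs[k][j] for k in range(len(expert_outs))) for j in range(hidden)])
    return mixed


x = [1.0, 2.0]
experts = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]  # expert outputs: [1, 2] and [2, 1]
gates = [
    [[0.0, 0.0], [0.0, 0.0]],  # task 0: equal logits, so uniform weights
    [[1.0, 0.0], [0.0, 0.0]],  # task 1: leans toward expert 0
]
task_inputs = mmoe_mix(x, experts, gates)
print(task_inputs[0])  # [1.5, 1.5]
```

Both tasks read the same expert outputs, but because each task has its own gate they receive different mixtures.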
```python
from torch_rechub.models.multi_task import MMOE
from torch_rechub.trainers import MTLTrainer

# MMOE: shared experts + task-specific gates
model = MMOE(
    features=features,
    task_types=["classification", "classification"],
    n_expert=8,
    expert_params={"dims": [16]},
    tower_params_list=[{"dims": [8]}, {"dims": [8]}],
)
```

Training pattern
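MTLTrainer's early stopping watches the validation score of one task, selected by earlystop_taskid (shown in the Trainer Interface section). It stops after earlystop_patience epochs without improvement. Below is a minimal sketch of such a patience rule. It assumes a higher-is-better metric such as AUC and counts only strict improvements. early_stop is a hypothetical helper, and the trainer's exact bookkeeping may differ.

```python
def early_stop(val_scores, task_id=0, patience=2):
    """Return (stop_epoch, best_epoch); val_scores[e] holds one score per task for epoch e."""
    best, best_epoch, bad_epochs = float("-inf"), -1, 0
    for epoch, scores in enumerate(val_scores):
        if scores[task_id] > best:  # strict improvement resets the counter
            best, best_epoch, bad_epochs = scores[task_id], epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best_epoch
    return len(val_scores) - 1, best_epoch


# Per-epoch validation AUCs as [cvr, ctr] (made-up numbers for illustration).
history = [[0.60, 0.58], [0.63, 0.59], [0.62, 0.61], [0.61, 0.62], [0.64, 0.63]]
print(early_stop(history, task_id=0, patience=2))  # (3, 1)
print(early_stop(history, task_id=1, patience=2))  # (4, 4)
```

Under this rule epoch 0 always counts as an improvement. So with n_epoch=5 and earlystop_patience=5, as configured in this section, the patience cannot run out and all five epochs run.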
```python
import os

import torch

# Seeding here fixes later randomness (e.g. data shuffling); seed before building the model
# if you also want reproducible weight initialization.
torch.manual_seed(2022)

# MTLTrainer does not create model_path automatically.
os.makedirs("./saved/mmoe", exist_ok=True)

mtl_trainer = MTLTrainer(
    model,
    task_types=["classification", "classification"],
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-4},
    n_epoch=5,
    earlystop_patience=5,
    device="cpu",  # Change to "cuda:0" for GPU.
    model_path="./saved/mmoe",
)
mtl_trainer.fit(train_dl, val_dl)

# evaluate() returns a list whose order matches task_types.
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dl)
print(f"Test AUC: {auc}")  # [cvr_auc, ctr_auc]
```

3. PLE
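With n_level=1, PLE reduces to a single extraction layer, often called CGC:

- Each task has its own specific experts.
- All tasks share a pool of shared experts.
- Each task's gate mixes only its own specific experts plus the shared ones.

The toy sketch below shows that gating with made-up expert outputs, matching the n_expert_specific=2, n_expert_shared=1 settings used in this section. cgc_mix is a hypothetical helper. The library additionally has expert MLPs, towers, and stacked levels when n_level > 1.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cgc_mix(specific_outs, shared_outs, gate_logits):
    """Mix one task's candidates: its own specific experts followed by the shared experts."""
    candidates = specific_outs + shared_outs
    if len(gate_logits) != len(candidates):
        raise ValueError("need one gate logit per candidate expert")
    weights = softmax(gate_logits)
    dim = len(candidates[0])
    return [sum(w * c[j] for w, c in zip(weights, candidates)) for j in range(dim)]


# Two specific experts per task and one shared expert (toy expert outputs).
task_a_specific = [[1.0, 0.0], [0.0, 1.0]]
task_b_specific = [[2.0, 2.0], [4.0, 4.0]]
shared = [[3.0, 3.0]]
uniform = [0.0, 0.0, 0.0]  # equal gate logits, so each candidate is weighted 1/3

mix_a = cgc_mix(task_a_specific, shared, uniform)
mix_b = cgc_mix(task_b_specific, shared, uniform)
print(mix_b)
```

Task A's mixture never includes task B's specific experts, whereas in MMOE every gate sees every expert.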
```python
from torch_rechub.models.multi_task import PLE

# PLE is often more stable than MMOE when task differences are larger,
# because it separates shared and task-specific experts.
model = PLE(
    features=features,
    task_types=["classification", "classification"],
    n_level=1,
    n_expert_specific=2,
    n_expert_shared=1,
    expert_params={"dims": [16]},
    tower_params_list=[{"dims": [8]}, {"dims": [8]}],
)
```

Adaptive loss weighting (optional)
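Uncertainty weighting (UWL, after Kendall et al., 2018) gives each task a learnable log-variance s_i. It minimizes the sum over tasks of exp(-s_i) * L_i + s_i, so each task's effective weight exp(-s_i) adapts during training. The sketch below uses that common simplified form with fixed, made-up losses to show where the weights settle. uwl_total and best_log_vars are hypothetical helpers, and Torch-RecHub's exact parameterization is an assumption here.

```python
import math


def uwl_total(losses, log_vars):
    """Simplified uncertainty weighting: sum_i exp(-s_i) * L_i + s_i, with s_i = log(sigma_i^2)."""
    return sum(math.exp(-s) * loss + s for loss, s in zip(losses, log_vars))


def best_log_vars(losses):
    # d/ds [exp(-s) * L + s] = 1 - exp(-s) * L, which is zero at s = log(L).
    return [math.log(loss) for loss in losses]


losses = [0.5, 0.05]  # made-up per-task losses, held fixed for illustration
s_star = best_log_vars(losses)
weights = [math.exp(-s) for s in s_star]  # effective loss weights at the optimum
print([round(w, 6) for w in weights])  # [2.0, 20.0]
```

At the optimum each weight equals 1 / L_i, so the task with the smaller loss is weighted up. Without the + s_i term, the objective could be driven down simply by pushing every weight toward zero.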
```python
# adaptive_params turns on dynamic loss balancing; this example uses UWL.
os.makedirs("./saved/ple", exist_ok=True)
mtl_trainer = MTLTrainer(
    model,
    task_types=["classification", "classification"],
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    adaptive_params={"method": "uwl"},
    n_epoch=5,
    earlystop_patience=5,
    device="cpu",
    model_path="./saved/ple",
)
mtl_trainer.fit(train_dl, val_dl)
```

4. ESMM
ESMM differs from MMOE / PLE in two ways:
- it uses only sparse features, split into a user-side group and an item-side group
- its label matrix has three columns, usually ordered ["cvr_label", "ctr_label", "ctcvr_label"]
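The third label reflects how ESMM models the full impression space. In the original ESMM formulation:

- pCTCVR = pCTR × pCVR.
- The CTR and CTCVR outputs are supervised with click and click-and-purchase labels on all impressions.
- CVR is learned implicitly through the product.

The small sketch below shows how the labels and the three outputs line up. ctcvr_labels and esmm_predictions are hypothetical helpers, and the [cvr, ctr, ctcvr] arrangement simply follows the label order listed above.

```python
def ctcvr_labels(clicks, purchases):
    """An impression is a CTCVR positive only if it was clicked and then purchased."""
    return [c * p for c, p in zip(clicks, purchases)]


def esmm_predictions(p_ctr, p_cvr):
    """Arrange per-sample predictions as [cvr, ctr, ctcvr], with ctcvr = ctr * cvr."""
    return [[cvr, ctr, ctr * cvr] for ctr, cvr in zip(p_ctr, p_cvr)]


clicks = [1, 1, 0]
purchases = [1, 0, 0]
print(ctcvr_labels(clicks, purchases))  # [1, 0, 0]
preds = esmm_predictions(p_ctr=[0.5, 0.4, 0.1], p_cvr=[0.2, 0.5, 0.3])
```

Because both factors are probabilities, the CTCVR prediction can never exceed either the CTR or the CVR prediction.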
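```python
from torch_rechub.models.multi_task import ESMM

# This example assumes the following sparse columns are item-side; the rest are user-side.
item_cols = ["129", "205", "206", "207", "210", "216"]
user_cols = [col for col in sparse_cols if col not in item_cols]
user_features = [SparseFeature(col, data[col].max() + 1, embed_dim=16) for col in user_cols]
item_features = [SparseFeature(col, data[col].max() + 1, embed_dim=16) for col in item_cols]

# Three label columns, and sparse inputs only (no dense columns).
label_cols = ["cvr_label", "ctr_label", "ctcvr_label"]
x_train = {name: data[name].values[:train_idx] for name in sparse_cols}
y_train = data[label_cols].values[:train_idx]
x_val = {name: data[name].values[train_idx:val_idx] for name in sparse_cols}
y_val = data[label_cols].values[train_idx:val_idx]
x_test = {name: data[name].values[val_idx:] for name in sparse_cols}
y_test = data[label_cols].values[val_idx:]

# Rebuild the loaders, since the MMOE / PLE loaders carry dense inputs and only two labels.
dg = DataGenerator(x_train, y_train)
train_dl, val_dl, test_dl = dg.generate_dataloader(
    x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test, batch_size=1024
)
```

```python
# ESMM combines the user and item embeddings, feeds them to a CVR tower and a CTR tower,
# and derives the CTCVR prediction as pCTR * pCVR.
model = ESMM(
    user_features,
    item_features,
    cvr_params={"dims": [16, 8]},
    ctr_params={"dims": [16, 8]},
)
```

5. Trainer Interface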
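```python
import os

from torch_rechub.trainers import MTLTrainer

os.makedirs("./saved/mtl", exist_ok=True)  # MTLTrainer does not create model_path itself
trainer = MTLTrainer(
    model,
    task_types=["classification", "classification"],
    optimizer_params={"lr": 1e-3},
    regularization_params={"embedding_l2": 0.0, "dense_l2": 0.0},  # L2 strengths for embedding / dense weights
    adaptive_params=None,  # Optional: {"method": "uwl"} / {"method": "gradnorm"} / {"method": "metabalance"}
    n_epoch=10,
    earlystop_taskid=0,  # index into task_types of the task whose validation score drives early stopping
    earlystop_patience=10,
    device="cpu",
    model_path="./saved/mtl",
)
```

6. Evaluation and Tuning Suggestions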
1. Evaluation output
```python
scores = mtl_trainer.evaluate(mtl_trainer.model, test_dl)
print(scores)
```

evaluate() returns a list ordered by task_types, for example:
- [cvr_auc, ctr_auc] for the two-task MMOE / PLE setup
- three task scores in the ESMM case
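Each entry in that list is a metric computed on one label column. Below is a dependency-free sketch of the idea for AUC, the metric these examples report. auc and per_task_auc are hypothetical helpers. The trainer itself presumably uses a library AUC routine rather than this O(P x N) pairwise loop.

```python
def auc(labels, scores):
    """Pairwise AUC: share of (positive, negative) pairs ranked correctly, ties counting 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def per_task_auc(y_true, y_pred):
    """One AUC per label column, in column order (i.e. the order of task_types)."""
    n_task = len(y_true[0])
    return [auc([row[i] for row in y_true], [row[i] for row in y_pred]) for i in range(n_task)]


# Toy two-task labels [cvr, ctr] and predictions for four samples (made-up values).
y_true = [[0, 1], [1, 1], [0, 0], [0, 1]]
y_pred = [[0.1, 0.8], [0.7, 0.2], [0.2, 0.3], [0.3, 0.4]]
print(per_task_auc(y_true, y_pred))  # [1.0, 0.6666666666666666]
```

The pairwise definition, with ties worth half a pair, gives the same value as the area under the ROC curve, for example scikit-learn's roc_auc_score.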
2. What to tune first
- MMOE: start with n_expert
- PLE: start with n_level / n_expert_specific / n_expert_shared
- If task imbalance is obvious: try adaptive_params={"method": "uwl"}
- If multi-task AUC is unstable: reduce the learning rate first, then shrink the expert/tower dimensions
7. FAQ
Q1: Why not use from torch_rechub.utils import DataGenerator here?
Because DataGenerator lives in torch_rechub.utils.data, not in the top-level torch_rechub.utils namespace.
Q2: Why use n_epoch instead of n_epochs?
The actual parameter name in MTLTrainer is n_epoch.
Q3: Why is there no evaluate_multi_task() helper?
The framework directly uses MTLTrainer.evaluate(model, data_loader), which returns a list of task scores.
Q4: Why call os.makedirs(...) before training?
MTLTrainer does not create model_path automatically, so the examples create the save directory explicitly.
