Trainers API Reference
This section provides detailed API documentation for all trainers in Torch-RecHub.
CTRTrainer
CTRTrainer is a general trainer for single-task learning, primarily used for binary classification tasks such as Click-Through Rate (CTR) prediction.
Parameters
- model (nn.Module): Any single-task learning model
- optimizer_fn (torch.optim): PyTorch optimizer function, defaults to torch.optim.Adam
- optimizer_params (dict): Optimizer parameters, defaults to {"lr": 1e-3, "weight_decay": 1e-5}
- scheduler_fn (torch.optim.lr_scheduler): PyTorch learning rate scheduler, e.g., torch.optim.lr_scheduler.StepLR
- scheduler_params (dict): Learning rate scheduler parameters
- n_epoch (int): Number of training epochs
- earlystop_patience (int): Number of epochs to wait before early stopping when validation performance does not improve, defaults to 10
- device (str): Device to use, either "cpu" or "cuda:0"
- gpus (list): List of GPU IDs, defaults to empty. If it contains at least one ID, the model is wrapped in nn.DataParallel
- loss_mode (bool): Training mode, defaults to True
- model_path (str): Path to save the model, defaults to "./"
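To make the scheduler_fn/scheduler_params pair concrete: with scheduler_fn=torch.optim.lr_scheduler.StepLR, a setting such as scheduler_params={"step_size": 2, "gamma": 0.8} (illustrative values, not defaults) multiplies the learning rate by gamma every step_size scheduler steps. The plain-Python sketch below reproduces that schedule starting from the default lr of 1e-3, assuming (not confirmed here) that the trainer steps the scheduler once per epoch; step_lr_schedule is a hypothetical helper, not part of Torch-RecHub.

```python
def step_lr_schedule(base_lr, step_size, gamma, n_epoch):
    """Learning rate in effect during each epoch under StepLR-style decay.

    Mirrors torch.optim.lr_scheduler.StepLR, assuming scheduler.step() is
    called once at the end of every epoch (an assumption about the trainer).
    """
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(n_epoch)]


# Default optimizer_params plus hypothetical scheduler_params.
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
scheduler_params = {"step_size": 2, "gamma": 0.8}

lrs = step_lr_schedule(optimizer_params["lr"], n_epoch=6, **scheduler_params)
for epoch, lr in enumerate(lrs):
    print(f"epoch {epoch}: lr={lr:.2e}")
```

Inside a trainer, the same dictionaries are presumably unpacked as keyword arguments, along the lines of optimizer_fn(model.parameters(), **optimizer_params) and scheduler_fn(optimizer, **scheduler_params), which is why both are plain dicts.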
Main Methods
- train_one_epoch(data_loader, log_interval=10): Train for one epoch
- fit(train_dataloader, val_dataloader=None): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
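The interplay between fit, train_one_epoch, evaluate, and earlystop_patience can be sketched as a generic loop. This is a minimal illustration, not Torch-RecHub's code: fit_with_early_stopping and its callables are hypothetical stand-ins, and it assumes a validation score where higher is better (such as AUC) and that the best weights seen so far are kept.

```python
import copy


def fit_with_early_stopping(train_one_epoch, evaluate, get_weights, n_epoch, patience):
    """Train/validate loop with patience-based early stopping.

    train_one_epoch, evaluate and get_weights are hypothetical callables
    standing in for a trainer's internals; higher scores are assumed better.
    """
    best_score, best_epoch, best_weights = float("-inf"), -1, None
    epochs_without_improvement = 0
    last_epoch = -1
    for epoch in range(n_epoch):
        last_epoch = epoch
        train_one_epoch(epoch)
        score = evaluate(epoch)
        if score > best_score:
            # New best: snapshot the weights (deep copy, so later training
            # cannot mutate the snapshot) and reset the patience counter.
            best_score, best_epoch = score, epoch
            best_weights = copy.deepcopy(get_weights())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # patience exhausted: stop early
    return best_epoch, best_score, best_weights, last_epoch


# Scripted validation AUCs standing in for evaluate().
val_aucs = [0.70, 0.74, 0.76, 0.75, 0.755, 0.74, 0.80]
weights = {"w": 0}


def train_one_epoch(epoch):
    weights["w"] = epoch  # pretend that training changes the weights


best_epoch, best_auc, best_weights, last_epoch = fit_with_early_stopping(
    train_one_epoch,
    evaluate=lambda epoch: val_aucs[epoch],
    get_weights=lambda: weights,
    n_epoch=len(val_aucs),
    patience=3,
)
print(best_epoch, best_auc, best_weights, last_epoch)
```

With patience=3 the loop stops after epoch index 5, three epochs without improvement over the 0.76 reached at index 2, so the later 0.80 is never seen. A larger patience trades extra training time for robustness to noisy validation curves.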
MatchTrainer
MatchTrainer is a trainer for matching/retrieval tasks, supporting multiple training modes.
Parameters
- model (nn.Module): Any matching model
- mode (int): Training mode, one of:
  - 0: point-wise
  - 1: pair-wise
  - 2: list-wise
- optimizer_fn (torch.optim): Same as CTRTrainer
- optimizer_params (dict): Same as CTRTrainer
- scheduler_fn (torch.optim.lr_scheduler): Same as CTRTrainer
- scheduler_params (dict): Same as CTRTrainer
- n_epoch (int): Same as CTRTrainer
- earlystop_patience (int): Same as CTRTrainer
- device (str): Same as CTRTrainer
- gpus (list): Same as CTRTrainer
- model_path (str): Same as CTRTrainer
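The three training modes differ in what a single training example looks like. The sketch below shows the loss formulations commonly paired with each mode, computed on raw user-item similarity scores: binary cross-entropy for point-wise, a BPR-style loss for pair-wise, and softmax cross-entropy over one positive and several negatives for list-wise. That MatchTrainer uses exactly these losses is an assumption here, and the function names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def pointwise_bce(score, label):
    """mode=0 (assumed): each (user, item) pair is an independent binary example."""
    p = sigmoid(score)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


def pairwise_bpr(pos_score, neg_score):
    """mode=1 (assumed): push the positive item's score above a sampled negative."""
    return -math.log(sigmoid(pos_score - neg_score))


def listwise_softmax(scores, pos_index=0):
    """mode=2 (assumed): softmax cross-entropy over one positive and k negatives."""
    m = max(scores)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[pos_index]


scores = [2.0, 0.5, -1.0]  # positive item first, then two sampled negatives
print(pointwise_bce(scores[0], 1))
print(pairwise_bpr(scores[0], scores[1]))
print(listwise_softmax(scores))
```

With exactly one negative, the list-wise loss reduces to the pair-wise BPR loss, since log(e^a + e^b) - a = -log sigmoid(a - b); adding more negatives per positive is what distinguishes the list-wise setting.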
Main Methods
- train_one_epoch(data_loader, log_interval=10): Train for one epoch
- fit(train_dataloader, val_dataloader=None): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
- inference_embedding(model, mode, data_loader, model_path): Infer embeddings, where mode is either "user" or "item"
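Once user and item embeddings have been inferred, a typical next step (assumed here, not documented above) is to score every item for a user by inner product and keep the top k. The sketch below does this by brute force in plain Python; top_k_items is a hypothetical helper, and the embeddings are assumed to have been converted to lists of floats (for example with tensor.tolist()).

```python
import heapq


def top_k_items(user_emb, item_embs, k):
    """Return the k item ids whose embeddings have the largest inner product
    with user_emb. item_embs maps item id -> embedding (list of floats).

    A full scan like this suits small catalogs only; large catalogs usually
    rely on an approximate nearest-neighbour index instead.
    """
    def score(vec):
        return sum(u * v for u, v in zip(user_emb, vec))

    return heapq.nlargest(k, item_embs, key=lambda item_id: score(item_embs[item_id]))


user = [0.5, 1.0]
items = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [0.6, 0.6], "d": [-1.0, 0.2]}
top = top_k_items(user, items, k=2)
print(top)
```

Here the inner products are 0.5, 1.0, 0.9 and -0.3 for items a to d, so the two best matches are b and c, returned in descending score order.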
MTLTrainer
MTLTrainer is a trainer for multi-task learning, supporting various adaptive loss weighting methods.
Parameters
- model (nn.Module): Any multi-task learning model
- task_types (list): List of task types, one per task; supported values are "classification" and "regression"
- optimizer_fn (torch.optim): Same as CTRTrainer
- optimizer_params (dict): Same as CTRTrainer
- scheduler_fn (torch.optim.lr_scheduler): Same as CTRTrainer
- scheduler_params (dict): Same as CTRTrainer
- adaptive_params (dict): Parameters of the adaptive loss weighting method; supported settings:
  - {"method": "uwl"}: Uncertainty Weighted Loss
  - {"method": "metabalance"}: MetaBalance method
  - {"method": "gradnorm", "alpha": 0.16}: GradNorm method
- n_epoch (int): Same as CTRTrainer
- earlystop_taskid (int): Task ID for early stopping, defaults to 0
- earlystop_patience (int): Same as CTRTrainer
- device (str): Same as CTRTrainer
- gpus (list): Same as CTRTrainer
- model_path (str): Same as CTRTrainer
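The task_types list tells the trainer which loss to apply to each task's output. A minimal sketch, assuming binary cross-entropy for "classification" and mean squared error for "regression", with the per-task losses simply summed when no adaptive weighting is configured; multitask_loss and LOSS_BY_TASK_TYPE are hypothetical names, and predictions are single scalars for readability.

```python
import math


def bce(pred, label):
    """Binary cross-entropy for a predicted probability."""
    eps = 1e-7  # clamp to avoid log(0)
    pred = min(max(pred, eps), 1 - eps)
    return -(label * math.log(pred) + (1 - label) * math.log(1 - pred))


def mse(pred, target):
    """Squared error for a single regression prediction."""
    return (pred - target) ** 2


# Assumed mapping from task type to loss function.
LOSS_BY_TASK_TYPE = {"classification": bce, "regression": mse}


def multitask_loss(task_types, preds, labels):
    """Per-task losses selected by task_types, plus their unweighted sum."""
    losses = [LOSS_BY_TASK_TYPE[t](p, y) for t, p, y in zip(task_types, preds, labels)]
    return losses, sum(losses)


task_types = ["classification", "regression"]
losses, total = multitask_loss(task_types, preds=[0.8, 2.5], labels=[1, 3.0])
print(losses, total)
```

Because the tasks' losses can sit on very different scales (a cross-entropy near 0.2 next to a squared error that grows with the target's units), a plain sum lets one task dominate; this is the motivation for the adaptive_params weighting methods.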
Main Methods
- train_one_epoch(data_loader): Train for one epoch
- fit(train_dataloader, val_dataloader, mode='base', seed=0): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
Special Features
- Support for Multiple Adaptive Loss Weighting Methods:
  - UWL (Uncertainty Weighted Loss)
  - MetaBalance
  - GradNorm
- Multi-task Early Stopping:
  - Early stopping based on the performance of a specified task
  - Saves the best model weights based on validation performance
- Support for Multiple Task Type Combinations:
  - Classification tasks
  - Regression tasks
- Training Log Recording:
  - Records the loss of each task
  - Records loss weights (when using adaptive methods)
  - Records performance metrics on the validation set
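UWL is based on the uncertainty weighting of Kendall et al. (2018). The sketch below uses the widely used simplified form, total = sum_i exp(-s_i) * L_i + s_i with one learnable s_i = log(sigma_i^2) per task; whether Torch-RecHub's "uwl" option uses exactly this form is an assumption. Holding the task losses fixed and running gradient descent on s_i alone shows where the weights settle: the gradient 1 - exp(-s_i) * L_i vanishes at s_i = log L_i, so each task's weight exp(-s_i) tends to 1 / L_i.

```python
import math


def uwl_total(losses, log_vars):
    """Uncertainty-weighted total: sum_i exp(-s_i) * L_i + s_i.

    exp(-s_i) scales task i's loss; the +s_i term penalises inflating s_i
    to ignore a task.
    """
    return sum(math.exp(-s) * l + s for l, s in zip(losses, log_vars))


def uwl_grad(losses, log_vars):
    """Partial derivative of uwl_total with respect to each s_i."""
    return [1.0 - math.exp(-s) * l for l, s in zip(losses, log_vars)]


# Two fixed task losses; optimise only the log-variances s_i.
losses = [0.5, 2.0]
log_vars = [0.0, 0.0]
for _ in range(2000):
    log_vars = [s - 0.05 * g for s, g in zip(log_vars, uwl_grad(losses, log_vars))]

weights = [math.exp(-s) for s in log_vars]
print([round(s, 3) for s in log_vars])  # approaches [log 0.5, log 2.0]
print([round(w, 3) for w in weights])   # approaches [1 / 0.5, 1 / 2.0]
```

In real training the task losses change every step and the s_i are updated jointly with the model, but the fixed point illustrates the design: tasks with consistently larger loss are down-weighted, while the s_i penalty keeps the weights from collapsing to zero.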