Trainers API Reference
This section provides detailed API documentation for all trainers in Torch-RecHub.
CTRTrainer
CTRTrainer is a general trainer for single-task learning, primarily used for binary classification tasks such as Click-Through Rate (CTR) prediction.
Parameters
- model (nn.Module): Any single-task learning model
- optimizer_fn (torch.optim): PyTorch optimizer function, defaults to torch.optim.Adam
- optimizer_params (dict): Optimizer parameters, defaults to {"lr": 1e-3, "weight_decay": 1e-5}
- scheduler_fn (torch.optim.lr_scheduler): PyTorch learning rate scheduler, e.g., torch.optim.lr_scheduler.StepLR
- scheduler_params (dict): Learning rate scheduler parameters
- n_epoch (int): Number of training epochs
- earlystop_patience (int): Number of consecutive epochs without improvement in validation performance before training stops early, defaults to 10
- device (str): Device to use, either "cpu" or "cuda:0"
- gpus (list): List of GPU IDs, defaults to empty; if it contains at least one ID, the model is wrapped in nn.DataParallel
- loss_mode (bool): Training mode, defaults to True
- model_path (str): Path to save the model, defaults to "./"
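The earlystop_patience setting follows the common patience pattern for early stopping. The sketch below is framework-free and assumes that a higher validation score is better (for example AUC). PatienceStopper and should_stop are hypothetical names chosen for illustration. They are not Torch-RecHub's internal early-stopping helper.

```python
class PatienceStopper:
    """Hypothetical patience-based early stopper (assumes higher score is better)."""

    def __init__(self, patience):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = None
        self.bad_epochs = 0  # consecutive epochs without improvement

    def should_stop(self, epoch, val_score):
        """Record one epoch's validation score; return True once patience runs out."""
        if val_score > self.best_score:
            # New best: a real trainer would also snapshot the model weights here.
            self.best_score = val_score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with simulated per-epoch validation scores.
stopper = PatienceStopper(patience=2)
scores = [0.70, 0.72, 0.71, 0.715, 0.73]
stopped_at = None
for epoch, score in enumerate(scores):
    if stopper.should_stop(epoch, score):
        stopped_at = epoch
        break
```

Here training stops at epoch 3, after two epochs without beating the 0.72 score from epoch 1. The later 0.73 is never reached. This illustrates why patience should be set with the noisiness of the validation metric in mind.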
Main Methods
- train_one_epoch(data_loader, log_interval=10): Train for one epoch
- fit(train_dataloader, val_dataloader=None): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
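For binary CTR prediction, validation quality is commonly measured with ROC AUC. The sketch below assumes that metric and computes it in plain Python using the rank-sum (Mann-Whitney U) formulation, with averaged ranks for tied scores. roc_auc is a hypothetical helper for illustration, not part of the documented API.

```python
def roc_auc(labels, scores):
    """ROC AUC via the Mann-Whitney U statistic; tied scores share an averaged rank."""
    pairs = sorted(zip(scores, labels))
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Extend j over the run of pairs that share the same score.
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # 1-based average rank of the tie group
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    n_pos = sum(1 for _, y in pairs if y == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUC is undefined when only one class is present")
    pos_rank_sum = sum(r for r, (_, y) in zip(ranks, pairs) if y == 1)
    u = pos_rank_sum - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

In practice a library routine such as scikit-learn's roc_auc_score computes the same value. The hand-written version only shows what the number means: the probability that a random positive is scored above a random negative.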
MatchTrainer
MatchTrainer is a trainer for matching/retrieval tasks, supporting multiple training modes.
Parameters
- model (nn.Module): Any matching model
- mode (int): Training mode, options:
  - 0: point-wise
  - 1: pair-wise
  - 2: list-wise
- optimizer_fn (torch.optim): Same as CTRTrainer
- optimizer_params (dict): Same as CTRTrainer
- scheduler_fn (torch.optim.lr_scheduler): Same as CTRTrainer
- scheduler_params (dict): Same as CTRTrainer
- n_epoch (int): Same as CTRTrainer
- earlystop_patience (int): Same as CTRTrainer
- device (str): Same as CTRTrainer
- gpus (list): Same as CTRTrainer
- model_path (str): Same as CTRTrainer
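The three mode values correspond to different training objectives. The sketch below shows one common loss for each mode, written in plain Python over raw model scores:

- binary cross-entropy for point-wise training
- BPR for pair-wise training
- softmax cross-entropy over one positive and several negatives for list-wise training

These are typical choices assumed for illustration. Torch-RecHub's exact loss functions and sample layout for each mode may differ, and all function names here are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pointwise_bce(score, label):
    """mode=0 (point-wise): binary cross-entropy on one (user, item) logit."""
    return softplus(score) - score * label


def pairwise_bpr(pos_score, neg_score):
    """mode=1 (pair-wise): BPR loss, -log(sigmoid(pos_score - neg_score))."""
    return softplus(neg_score - pos_score)


def listwise_softmax_ce(pos_score, neg_scores):
    """mode=2 (list-wise): cross-entropy of a softmax over [positive] + negatives."""
    logits = [pos_score] + list(neg_scores)
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - pos_score
```

The point-wise loss judges each item in isolation. The pair-wise and list-wise losses only care about how the positive item's score compares with the negatives' scores, which is closer to what retrieval ultimately needs.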
Main Methods
- train_one_epoch(data_loader, log_interval=10): Train for one epoch
- fit(train_dataloader, val_dataloader=None): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
- inference_embedding(model, mode, data_loader, model_path): Infer embeddings
  - mode: Either "user" or "item"
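inference_embedding is typically run twice, once with mode "user" and once with mode "item", and the two sets of vectors are then matched against each other. The sketch below shows that matching step as a brute-force search. It assumes inner-product similarity and embeddings already converted to plain Python lists. top_k_items is a hypothetical name, and production systems usually replace the exhaustive scan with an approximate nearest-neighbor index.

```python
import heapq


def top_k_items(user_vec, item_vecs, k):
    """Return indices of the k items with the largest inner product with user_vec."""
    scores = (
        (sum(u * v for u, v in zip(user_vec, item_vec)), idx)
        for idx, item_vec in enumerate(item_vecs)
    )
    # (score, index) tuples compare by score first, so nlargest ranks by similarity.
    return [idx for _, idx in heapq.nlargest(k, scores)]


user = [1.0, 0.0]
items = [[0.2, 0.9], [0.8, 0.1], [0.5, 0.5], [-1.0, 0.0]]
print(top_k_items(user, items, k=2))  # [1, 2]
```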
MTLTrainer
MTLTrainer is a trainer for multi-task learning, supporting various adaptive loss weighting methods.
Parameters
- model (nn.Module): Any multi-task learning model
- task_types (list): List of task types, one per task; supported types are "classification" and "regression"
- optimizer_fn (torch.optim): Same as CTRTrainer
- optimizer_params (dict): Same as CTRTrainer
- scheduler_fn (torch.optim.lr_scheduler): Same as CTRTrainer
- scheduler_params (dict): Same as CTRTrainer
- adaptive_params (dict): Adaptive loss weighting method parameters; supported values:
  - {"method": "uwl"}: Uncertainty Weighted Loss
  - {"method": "metabalance"}: MetaBalance method
  - {"method": "gradnorm", "alpha": 0.16}: GradNorm method
- n_epoch (int): Same as CTRTrainer
- earlystop_taskid (int): Task ID for early stopping, defaults to 0
- earlystop_patience (int): Same as CTRTrainer
- device (str): Same as CTRTrainer
- gpus (list): Same as CTRTrainer
- model_path (str): Same as CTRTrainer
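With {"method": "uwl"}, each task loss is weighted by a learnable uncertainty term. A widely used simplified form of uncertainty weighting learns s_i = log(sigma_i^2) for each task and minimizes sum_i exp(-s_i) * L_i + s_i. The sketch below evaluates that expression for fixed values. It is an assumption about the general technique, not Torch-RecHub's exact formula. During training the s_i would be parameters updated by the optimizer, and uwl_total_loss is a hypothetical name.

```python
import math


def uwl_total_loss(task_losses, log_vars):
    """Combine task losses as sum_i exp(-s_i) * L_i + s_i, with s_i = log(sigma_i^2).

    A task with larger learned uncertainty (larger s_i) gets a smaller weight
    exp(-s_i); the + s_i term keeps s_i from growing without bound.
    """
    if len(task_losses) != len(log_vars):
        raise ValueError("need one log-variance per task")
    return sum(math.exp(-s) * loss + s for loss, s in zip(task_losses, log_vars))


# Two tasks; the second is treated as noisier (sigma^2 = 4), so its loss is down-weighted.
print(uwl_total_loss([0.5, 2.0], [0.0, math.log(4.0)]))
```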
Main Methods
- train_one_epoch(data_loader): Train for one epoch
- fit(train_dataloader, val_dataloader, mode='base', seed=0): Train the model
- evaluate(model, data_loader): Evaluate the model
- predict(model, data_loader): Make predictions
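During training, each task head's output must be scored with a loss that matches its entry in task_types. The plain-Python sketch below makes two assumptions. First, classification tasks use binary cross-entropy on predicted probabilities and regression tasks use mean squared error. Second, the per-task losses are combined by an unweighted mean; the adaptive methods above would replace this step. All helper names are hypothetical.

```python
import math


def bce(probs, labels, eps=1e-7):
    """Mean binary cross-entropy on predicted probabilities, clipped away from 0 and 1."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)


def mse(preds, targets):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


LOSS_BY_TASK_TYPE = {"classification": bce, "regression": mse}


def multitask_loss(task_types, preds_per_task, labels_per_task):
    """Pick each task's loss from task_types; return (per_task_losses, mean_loss)."""
    losses = [
        LOSS_BY_TASK_TYPE[t](preds, labels)
        for t, preds, labels in zip(task_types, preds_per_task, labels_per_task)
    ]
    return losses, sum(losses) / len(losses)


losses, total = multitask_loss(
    ["classification", "regression"],
    [[0.5, 0.5], [1.0, 3.0]],  # predictions for task 0 and task 1
    [[1, 0], [2.0, 3.0]],      # labels for task 0 and task 1
)
```

Keeping the per-task losses separate, rather than summing them immediately, is also what makes per-task loss logging and loss reweighting possible.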
Special Features
- Support for Multiple Adaptive Loss Weighting Methods:
  - UWL (Uncertainty Weighted Loss)
  - MetaBalance
  - GradNorm
- Multi-task Early Stopping:
  - Early stopping based on the validation performance of a specified task (earlystop_taskid)
  - Saves the best model weights based on validation performance
- Support for Multiple Task Type Combinations:
  - Classification tasks
  - Regression tasks
- Training Log Recording:
  - Records the loss of each task
  - Records loss weights (when using adaptive methods)
  - Records performance metrics on the validation set
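GradNorm's alpha parameter (0.16 in the example configuration) controls how strongly the task weights are pulled toward balanced training rates. The sketch below computes the GradNorm targets as described in the original GradNorm method:

1. Each task's loss ratio L_i(t) / L_i(0) is divided by the mean ratio to get a relative inverse training rate r_i.
2. The target gradient norm is mean(G) * r_i ** alpha.
3. The auxiliary loss is the L1 gap between the actual and target norms.

Gradient norms are passed in as plain numbers here; in real training they come from autograd on a shared layer. Torch-RecHub's implementation details may differ, and gradnorm_targets is a hypothetical name.

```python
def gradnorm_targets(grad_norms, initial_losses, current_losses, alpha=0.16):
    """Return (targets, gradnorm_loss) for one GradNorm step.

    grad_norms[i] is the norm of the gradient of w_i * L_i with respect to
    the shared weights; initial_losses and current_losses are L_i(0) and L_i(t).
    """
    n = len(grad_norms)
    loss_ratios = [cur / init for cur, init in zip(current_losses, initial_losses)]
    mean_ratio = sum(loss_ratios) / n
    # r_i > 1 means task i is training more slowly than average.
    inverse_rates = [r / mean_ratio for r in loss_ratios]
    mean_norm = sum(grad_norms) / n
    # Targets are treated as constants: no gradient flows through them.
    targets = [mean_norm * r ** alpha for r in inverse_rates]
    gradnorm_loss = sum(abs(g - t) for g, t in zip(grad_norms, targets))
    return targets, gradnorm_loss


# Task 0 has reduced its loss less than task 1, so it gets the larger target norm.
targets, gn_loss = gradnorm_targets([1.0, 1.0], [1.0, 1.0], [0.9, 0.3])
```

The task weights w_i are then updated by gradient descent on gradnorm_loss and renormalized so that they sum to the number of tasks. With alpha = 0, every target collapses to the mean gradient norm, so the method only equalizes gradient magnitudes.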