Torch-RecHub is a flexible and extensible recommender system framework built with PyTorch. It aims to simplify research and application of recommendation algorithms by providing common model implementations, data processing tools, and evaluation metrics.
- Modular Design: Easy to add new models, datasets, and evaluation metrics.
- PyTorch-based: Leverages PyTorch's dynamic graph and GPU acceleration capabilities.
- Rich Model Library: Contains various classic and cutting-edge recommendation algorithms.
- Standardized Pipeline: Provides unified data loading, training, and evaluation workflows.
- Easy Configuration: Adjust experiment settings via config files or command-line arguments.
- Reproducibility: Designed to ensure reproducible experimental results.
- Additional Features: Negative sampling, multi-task learning, etc.
- Python 3.8+
- PyTorch 1.7+ (CUDA-enabled version recommended for GPU acceleration)
- NumPy
- Pandas
- SciPy
- Scikit-learn
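Before installing, it can help to confirm which of the requirements listed above are already present. The following is a minimal sketch using only the standard library; the helper name `check_environment` is hypothetical and not part of Torch-RecHub, and the package names are the usual PyPI distribution names.

```python
import sys
from importlib import metadata

# PyPI distribution names for the requirements listed above.
REQUIRED = ["torch", "numpy", "pandas", "scipy", "scikit-learn"]


def check_environment(packages=REQUIRED):
    """Return a dict mapping each distribution name to its installed version, or None."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    print("Python", ".".join(map(str, sys.version_info[:3])))
    for name, version in check_environment().items():
        print(f"{name}: {version or 'not installed'}")
```

This only reports what is installed; it does not compare versions against the minimums above.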
- Stable Version
pip install torch-rechub
- Latest Version (Recommended)
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub
python setup.py install
Install dependencies:
pip install -r requirements.txt
Here's a simple example of training a model (e.g., DSSM) on a MovieLens dataset:
# 1. Prepare data (if preprocessing needed)
# python examples/matching/data/ml-1m/preprocess_ml.py
# 2. Train model
python run_ml_dssm.py
# Or override config with command-line arguments:
# python run_ml_dssm.py --model_name dssm --device 'cuda:0' --learning_rate 0.001 --epoch 50 --batch_size 4096 --weight_decay 0.0001 --save_dir 'saved/dssm_ml-100k'
After training, model files will be saved in the saved/dssm_ml-100k directory (or your configured directory).
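The command-line overrides shown above could be wired up roughly as follows. This is a sketch, not the actual contents of run_ml_dssm.py: the flag names are taken from the example command, while the defaults, types, and the `build_parser` helper are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Build a parser for the flags used in the example command (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Hypothetical training entry point")
    parser.add_argument("--model_name", default="dssm")
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--save_dir", default="saved/")
    return parser


# Flags given on the command line replace the defaults; everything else is unchanged.
args = build_parser().parse_args(
    ["--device", "cuda:0", "--epoch", "50", "--batch_size", "4096"]
)
print(vars(args))
```

Any flag that is omitted keeps its default, which is why a bare `python run_ml_dssm.py` still runs.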
torch-rechub/ # Root directory
├── README.md # Project documentation
├── torch_rechub/ # Core library
│ ├── basic/ # Basic components
│ ├── models/ # Recommendation model implementations
│ │ ├── matching/ # Matching models (DSSM/MIND/GRU4Rec etc.)
│ │ ├── ranking/ # Ranking models (WideDeep/DeepFM/DIN etc.)
│ │ └── multi_task/ # Multi-task models (MMoE/ESMM etc.)
│ ├── trainers/ # Trainers
│ └── utils/ # Utility functions
├── examples/ # Example scripts
│ ├── matching/ # Matching task examples
│ └── ranking/ # Ranking task examples
├── docs/ # Documentation
├── tutorials/ # Jupyter tutorials
├── setup.py # Package installation script
├── mkdocs.yml # MkDocs config file
└── requirements.txt # Project dependencies
The framework currently supports the following recommendation models:
General Recommendation:
- DSSM: Deep Structured Semantic Model
- Wide&Deep: Wide & Deep Learning for Recommender Systems
- FM: Factorization Machines
- DeepFM: Deep Factorization Machine
- ...
Sequential Recommendation:
- DIN: Deep Interest Network
- DIEN: Deep Interest Evolution Network
- BST: Behavior Sequence Transformer
- GRU4Rec: Gated Recurrent Unit for Recommendation
- SASRec: Self-Attentive Sequential Recommendation
- ...
Multi-Interest Recommendation:
- MIND: Multi-Interest Network with Dynamic Routing
- SINE: Sparse-Interest Network for Sequential Recommendation
- ...
Multi-Task Recommendation:
- ESMM: Entire Space Multi-Task Model
- MMoE: Multi-gate Mixture-of-Experts
- PLE: Progressive Layered Extraction
- AITM: Adaptive Information Transfer Multi-task
- ...
The framework provides built-in support or preprocessing scripts for the following common datasets:
- MovieLens
- Amazon
- Criteo
- Avazu
- Census-Income
- BookCrossing
- Ali-ccp
- Yidian
- ...
The expected data format is typically an interaction file containing:
- User ID
- Item ID
- Rating (optional)
- Timestamp (optional)
For specific format requirements, please refer to the example code in the tutorials directory.
You can easily integrate your own datasets by ensuring they conform to the framework's data format requirements or by writing custom data loaders.
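As an illustration of the interaction format described above, the sketch below parses a small CSV and maps raw user and item IDs to contiguous integer indices, which embedding layers typically require. The column names, the sample rows, and the `encode_interactions` helper are all hypothetical; the framework's own preprocessing scripts may differ.

```python
import csv
import io

# Made-up sample rows: user ID, item ID, rating, timestamp.
RAW = """user_id,item_id,rating,timestamp
u42,i7,5,978300760
u42,i9,3,978302109
u13,i7,4,978301968
"""


def encode_interactions(text):
    """Parse interaction rows and replace raw IDs with indices assigned in order of first appearance."""
    user_index, item_index, rows = {}, {}, []
    for record in csv.DictReader(io.StringIO(text)):
        # setdefault returns the existing index, or assigns the next free one.
        u = user_index.setdefault(record["user_id"], len(user_index))
        i = item_index.setdefault(record["item_id"], len(item_index))
        rows.append((u, i, float(record["rating"]), int(record["timestamp"])))
    return rows, user_index, item_index


rows, user_index, item_index = encode_interactions(RAW)
print(rows)
print(user_index, item_index)
```

The size of each index dictionary gives the vocabulary size needed for the corresponding embedding table.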
All model usage examples can be found in the /examples directory.
from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator
dg = DataGenerator(x, y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
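The variable name in the snippet above suggests that the evaluation step reports AUC. As a reminder of what that metric measures, here is a minimal pure-Python sketch: AUC is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one. The `roc_auc` function is written for illustration and is not the framework's implementation.

```python
def roc_auc(labels, scores):
    """AUC via pairwise comparison of positive and negative scores; ties count as 0.5."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Three of the four positive/negative pairs are ranked correctly.
print(roc_auc([1, 0, 1, 0], [0.9, 0.3, 0.4, 0.6]))  # 0.75
```

The pairwise loop is quadratic in the number of examples, so it is only suitable for small checks; library implementations use a sort-based computation instead.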
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer
task_types = ["classification", "classification"]
model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])
mtl_trainer = MTLTrainer(model)
mtl_trainer.fit(train_dataloader, val_dataloader)
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
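For readers new to MMoE, the core idea behind the model used above is that every task has its own gate, which produces softmax weights over a shared pool of experts; each task tower then receives that task's weighted mix of expert outputs. The sketch below shows only this mixing step in pure Python, with made-up numbers and hypothetical helper names. It is not the framework's implementation, and it uses three experts for brevity.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(expert_outputs, gate_logits):
    """Combine expert output vectors using one task's softmax gate weights."""
    weights = softmax(gate_logits)
    dim = len(expert_outputs[0])
    return [
        sum(w * out[d] for w, out in zip(weights, expert_outputs))
        for d in range(dim)
    ]


# Three shared experts, each producing a 2-dimensional output (made-up values).
expert_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# One gate per task: task_a favors the first expert, task_b weights all equally.
task_gate_logits = {"task_a": [2.0, 0.0, 0.0], "task_b": [0.0, 0.0, 0.0]}

task_inputs = {
    task: mix_experts(expert_outputs, logits)
    for task, logits in task_gate_logits.items()
}
print(task_inputs)
```

Because the gates are separate, two tasks can draw on the same experts in different proportions, which is what lets loosely related tasks share parameters without forcing identical representations.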
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import MatchDataGenerator
dg = MatchDataGenerator(x, y)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)
model = DSSM(user_features, item_features, temperature=0.02,
user_params={
"dims": [256, 128, 64],
"activation": 'prelu',
},
item_params={
"dims": [256, 128, 64],
"activation": 'prelu',
})
match_trainer = MatchTrainer(model)
match_trainer.fit(train_dl)
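The `temperature` argument in the DSSM example is small (0.02). One common way such a parameter is used in two-tower models is to divide the cosine similarity of the user and item vectors by it before applying a sigmoid, which sharpens the output. The sketch below illustrates that idea in pure Python; it is an assumption about the general technique, not a statement of how this library computes its scores, and both function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def match_probability(user_vec, item_vec, temperature=0.02):
    """Sigmoid of the temperature-scaled cosine similarity."""
    logit = cosine_similarity(user_vec, item_vec) / temperature
    return 1.0 / (1.0 + math.exp(-logit))


user = [0.6, 0.8]
close_item = [0.6, 0.8]
orthogonal_item = [0.8, -0.6]

print(match_probability(user, close_item))                    # near 1 with a small temperature
print(match_probability(user, close_item, temperature=1.0))   # much softer score
print(match_probability(user, orthogonal_item))               # 0.5 for orthogonal vectors
```

A smaller temperature pushes scores for similar pairs toward 1 and dissimilar pairs toward 0, so small differences in similarity produce larger differences in the loss.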
We welcome all types of contributions! If you'd like to contribute to this project, please follow these steps:
- Fork the repository: Click the "Fork" button in the upper right corner.
- Make your changes: Implement new features or fix bugs.
- Commit changes:
git commit -m "feat: add new feature"
or
git commit -m "fix: fix some issue"
(Following Conventional Commits is preferred.)
- Push to branch:
git push origin
- Create Pull Request: Go back to the original repository page, click "New pull request", compare your branch with the main branch of the main repository, and submit the PR.
Please ensure your PR description clearly explains the changes and their purpose.
We also welcome bug reports and feature suggestions through Issues.
This project is licensed under the MIT License.
If you use this framework in your research or work, please consider citing:
@misc{torch_rechub,
title = {Torch-RecHub},
author = {Datawhale},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/datawhalechina/torch-rechub}},
note = {A PyTorch-based recommender system framework providing easy-to-use and extensible solutions}
}
- Project Lead: morningsky
- GitHub Issues
Last updated: [2025-03-31]