PyTorch deep learning project made easy.
- Python 3.x
- PyTorch
- Clear folder structure which is suitable for many deep learning projects.
- .json config file support for more convenient parameter tuning.
- Checkpoint saving and resuming.
- Abstract base classes for faster development:
  - BaseTrainer handles checkpoint saving/resuming, training process logging, and more.
  - BaseDataLoader handles batch generation, data shuffling, and validation data splitting.
  - BaseModel provides a basic model summary.
pytorch-template/
│
├── train.py - example main
├── config.json - example config file
│
├── base/ - abstract base classes
│ ├── base_data_loader.py - abstract base class for data loaders
│ ├── base_model.py - abstract base class for models
│ └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│ └── data_loaders.py
│
├── datasets/ - default datasets folder
│
├── logger/ - for training process logging
│ └── logger.py
│
├── model/ - models, losses, and metrics
│ ├── modules/ - submodules of your model
│ ├── loss.py
│ ├── metric.py
│ └── model.py
│
├── saved/ - default checkpoints folder
│
├── trainer/ - trainers
│ └── trainer.py
│
└── utils/
    ├── util.py
    └── ...
The code in this repo is an MNIST example of the template.
Config files are in .json format:
{
"name": "Mnist_LeNet", // training session name
"cuda": true, // use cuda
"data_loader": {
"data_dir": "datasets/", // dataset path
"batch_size": 32, // batch size
"shuffle": true // shuffle data each time __iter__() is called
},
"validation": {
"validation_split": 0.1, // validation data ratio
"shuffle": true // shuffle training data before splitting
},
"optimizer_type": "Adam",
"optimizer": {
"lr": 0.001, // (optional) learning rate
"weight_decay": 0 // (optional) weight decay
},
"loss": "my_loss", // loss function name
"metrics": [ // metrics
"my_metric",
"my_metric2"
],
"trainer": {
"epochs": 1000, // number of training epochs
"save_dir": "saved/", // checkpoints are saved in save_dir/name
"save_freq": 1, // save checkpoints every save_freq epochs
"verbosity": 2, // 0: quiet, 1: per epoch, 2: full
"monitor": "val_loss", // monitor value for best model
"monitor_mode": "min" // "min" if a lower monitored value is better, otherwise "max"
},
"arch": "MnistModel", // model architecture
"model": {} // model configs
}
Add additional configurations if you need them.
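The example config above uses // comments for explanation, but standard JSON does not allow comments, so Python's json.loads raises an error on such a file. As a minimal sketch, assuming you want to keep comments in your config files, the hypothetical helpers strip_json_comments() and parse_config() below remove // comments that appear outside string literals before parsing. This is not the template's own loading code.

```python
import json


def strip_json_comments(text):
    """Remove // line comments that appear outside JSON string literals."""
    lines = []
    for line in text.splitlines():
        in_string = escaped = False
        cut = len(line)
        for i, ch in enumerate(line):
            if in_string:
                # Track escapes so that \" does not end the string.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif line.startswith("//", i):
                cut = i  # everything from here to end of line is a comment
                break
        lines.append(line[:cut])
    return "\n".join(lines)


def parse_config(text):
    """Parse a commented JSON config string into a dict (hypothetical helper)."""
    return json.loads(strip_json_comments(text))


example = """{
    "name": "Mnist_LeNet",  // training session name
    "trainer": {"save_dir": "saved/", "epochs": 1000}  // nested values work too
}"""
config = parse_config(example)
print(config["trainer"]["save_dir"])  # → saved/
```

Tracking string state matters because values such as URLs can legitimately contain //.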
Modify the configurations in the .json config files, then run:
python train.py --config config.json
You can resume from a previously saved checkpoint by:
python train.py --resume path/to/checkpoint
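A minimal sketch of how a train.py entry point might accept the two invocations above, using the standard argparse module. The function name parse_args() and the rule that exactly one of --config and --resume must be given are assumptions for illustration, not taken from the template.

```python
import argparse


def parse_args(argv=None):
    """Parse train.py-style arguments (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="PyTorch Template")
    # Assumption: a run starts either from a config file or from a checkpoint.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--config", type=str, help="path to a .json config file")
    group.add_argument("--resume", type=str, help="path to a saved checkpoint")
    return parser.parse_args(argv)


args = parse_args(["--config", "config.json"])
print(args.config, args.resume)  # → config.json None
```

When resuming, the config can be recovered from the checkpoint itself, since saved checkpoints include a 'config' entry.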
- Writing your own data loader
  - Inherit BaseDataLoader
    BaseDataLoader is similar to torch.utils.data.DataLoader; you can use either of them. BaseDataLoader handles:
    - Generating the next batch
    - Data shuffling
    - Generating a validation data loader by calling BaseDataLoader.split_validation()
  - Implementing abstract methods
    You need to implement these abstract methods:
    - _pack_data(): pack data members into a list of tuples
    - _unpack_data(): unpack packed data
    - _update_data(): update data members
    - _n_samples(): total number of samples
  - DataLoader usage
    BaseDataLoader is an iterator. To iterate through batches:
    for batch_idx, (x_batch, y_batch) in enumerate(data_loader):
        pass
  - Example
    Please refer to data_loader/data_loaders.py for an MNIST data loading example.
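To illustrate how the four abstract methods fit together, here is a simplified, hypothetical re-implementation of BaseDataLoader in plain Python, plus a ToyDataLoader subclass over two parallel lists. It is a sketch under assumptions (the constructor arguments, the shallow copy in split_validation(), and the in-memory lists are all illustrative), not the template's actual code; in a real project the samples would typically be tensors.

```python
import copy
import random
from abc import ABC, abstractmethod


class BaseDataLoader(ABC):
    """Simplified, hypothetical stand-in for the template's BaseDataLoader."""

    def __init__(self, batch_size, shuffle=True, validation_split=0.0):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.validation_split = validation_split
        self._batch_start = 0

    @abstractmethod
    def _pack_data(self):
        """Pack data members into a list of tuples."""

    @abstractmethod
    def _unpack_data(self, packed):
        """Unpack a list of tuples back into data members."""

    @abstractmethod
    def _update_data(self, unpacked):
        """Replace the data members with unpacked data."""

    @abstractmethod
    def _n_samples(self):
        """Return the total number of samples."""

    def __iter__(self):
        # Reshuffle (keeping input/label pairs together) at each new pass.
        if self.shuffle:
            packed = self._pack_data()
            random.shuffle(packed)
            self._update_data(self._unpack_data(packed))
        self._batch_start = 0
        return self

    def __next__(self):
        if self._batch_start >= self._n_samples():
            raise StopIteration
        end = self._batch_start + self.batch_size
        batch = self._pack_data()[self._batch_start:end]
        self._batch_start = end
        return self._unpack_data(batch)

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return -(-self._n_samples() // self.batch_size)

    def split_validation(self):
        """Move a validation_split fraction of samples into a new loader."""
        if self.validation_split == 0:
            return None
        packed = self._pack_data()
        if self.shuffle:  # shuffle before splitting
            random.shuffle(packed)
        n_valid = int(len(packed) * self.validation_split)
        valid_loader = copy.copy(self)
        valid_loader._update_data(self._unpack_data(packed[:n_valid]))
        # The original loader is modified: it keeps only the training part.
        self._update_data(self._unpack_data(packed[n_valid:]))
        return valid_loader


class ToyDataLoader(BaseDataLoader):
    """Hypothetical loader over parallel lists of inputs and labels."""

    def __init__(self, x, y, batch_size, shuffle=True, validation_split=0.0):
        super().__init__(batch_size, shuffle, validation_split)
        self.x, self.y = list(x), list(y)

    def _pack_data(self):
        return list(zip(self.x, self.y))

    def _unpack_data(self, packed):
        return [p[0] for p in packed], [p[1] for p in packed]

    def _update_data(self, unpacked):
        self.x, self.y = unpacked

    def _n_samples(self):
        return len(self.x)


loader = ToyDataLoader(range(10), [i % 2 for i in range(10)],
                       batch_size=4, validation_split=0.2)
valid_loader = loader.split_validation()  # 2 samples; loader keeps 8
for batch_idx, (x_batch, y_batch) in enumerate(loader):
    print(batch_idx, x_batch, y_batch)
```

Because split_validation() moves samples out of the training loader, it is meant to be called once, before training starts.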
- Writing your own trainer
  - Inherit BaseTrainer
    BaseTrainer handles:
    - Training process logging
    - Checkpoint saving
    - Checkpoint resuming
    - Reconfigurable monitored value for saving the current best model
      - Controlled by the configs monitor and monitor_mode: if monitor_mode == 'min', the trainer saves a checkpoint model_best.pth.tar whenever monitor reaches a new minimum
  - Implementing abstract methods
    You need to implement _train_epoch() for your training process. If you need validation, you can also implement _valid_epoch(), as in trainer/trainer.py.
  - Example
    Please refer to trainer/trainer.py for MNIST training.
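The monitor/monitor_mode rule can be sketched without PyTorch. In this hypothetical, simplified BaseTrainer, the saved list stands in for writing files with torch.save(), the checkpoint-epoch{N}.pth.tar naming is an assumption, and ToyTrainer replays fixed validation losses instead of training a model. Its _train_epoch() merges the result of _valid_epoch() into the returned log.

```python
import math
from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Simplified, hypothetical sketch of BaseTrainer's best-model logic."""

    def __init__(self, epochs, save_freq=1, monitor="val_loss", monitor_mode="min"):
        if monitor_mode not in ("min", "max"):
            raise ValueError("monitor_mode must be 'min' or 'max'")
        self.epochs = epochs
        self.save_freq = save_freq
        self.monitor = monitor
        self.monitor_mode = monitor_mode
        self.monitor_best = math.inf if monitor_mode == "min" else -math.inf
        self.saved = []  # checkpoint file names; stands in for torch.save()

    @abstractmethod
    def _train_epoch(self, epoch):
        """Train for one epoch and return a dict of logged values."""

    def _improved(self, value):
        if self.monitor_mode == "min":
            return value < self.monitor_best
        return value > self.monitor_best

    def train(self):
        for epoch in range(1, self.epochs + 1):
            log = self._train_epoch(epoch)
            if epoch % self.save_freq == 0:
                self.saved.append(f"checkpoint-epoch{epoch}.pth.tar")
            if self._improved(log[self.monitor]):
                self.monitor_best = log[self.monitor]
                self.saved.append("model_best.pth.tar")
        return self.monitor_best


class ToyTrainer(BaseTrainer):
    """Hypothetical trainer that replays a fixed sequence of validation losses."""

    def __init__(self, val_losses, **kwargs):
        super().__init__(epochs=len(val_losses), **kwargs)
        self.val_losses = val_losses

    def _train_epoch(self, epoch):
        log = {"loss": 0.0}  # a real trainer would run the training loop here
        return {**log, **self._valid_epoch(epoch)}

    def _valid_epoch(self, epoch):
        return {"val_loss": self.val_losses[epoch - 1]}


trainer = ToyTrainer([0.5, 0.3, 0.4, 0.2])
print(trainer.train())  # → 0.2
```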
- Writing your own model
  - Inherit BaseModel
    BaseModel handles:
    - Inheriting from torch.nn.Module
    - summary(): model summary
  - Implementing abstract methods
    Implement the forward pass method forward().
  - Example
    Please refer to model/model.py for a LeNet example.
If you need to change the loss function or metrics, first import those functions in train.py, then modify "loss" and "metrics" in the .json config files.
You can add multiple metrics in your config files:
"metrics": ["my_metric", "my_metric2"],
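One way train.py could turn the names in "loss" and "metrics" into callables is a dictionary lookup in the module namespace where those functions were imported. The sketch below assumes that approach; my_loss, my_metric, and my_metric2 are hypothetical pure-Python placeholders (mean squared error, exact-match rate, and mean absolute error) rather than the template's definitions, and get_functions() is an illustrative helper name.

```python
def my_loss(output, target):
    """Hypothetical loss: mean squared error over equal-length lists."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)


def my_metric(output, target):
    """Hypothetical metric: fraction of exact matches."""
    return sum(1 for o, t in zip(output, target) if o == t) / len(target)


def my_metric2(output, target):
    """Hypothetical metric: mean absolute error."""
    return sum(abs(o - t) for o, t in zip(output, target)) / len(target)


def get_functions(config, namespace):
    """Look up the loss and metric functions named in the config."""
    loss_fn = namespace[config["loss"]]
    metric_fns = [namespace[name] for name in config["metrics"]]
    return loss_fn, metric_fns


config = {"loss": "my_loss", "metrics": ["my_metric", "my_metric2"]}
# At module level, globals() holds the functions defined or imported above.
loss_fn, metric_fns = get_functions(config, globals())
```

A misspelled name in the config then fails early with a KeyError instead of partway through training.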
If you have additional information to log, merge it with log in _train_epoch() of your trainer class before returning, as shown below:
additional_log = {"gradient_norm": g, "sensitivity": s}
log = {**log, **additional_log}
return log
To split validation data from a data loader, call BaseDataLoader.split_validation(). It returns a validation data loader whose number of samples is determined by the ratio specified in your config file.
Note: the split_validation() method modifies the original data loader.
Note: split_validation() returns None if "validation_split" is set to 0.
You can specify the name of the training session in config files:
"name": "MNIST_LeNet",
The checkpoints will be saved in save_dir/name.
The config file is saved in the same folder.
Note: checkpoints contain:
{
'arch': arch,
'epoch': epoch,
'logger': self.train_logger,
'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'monitor_best': self.monitor_best,
'config': self.config
}
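The checkpoint layout above can be sketched with the standard library. This hypothetical example writes the config and a checkpoint into save_dir/name and then resumes from it; pickle stands in for torch.save()/torch.load(), the file names checkpoint-epoch{N}.pth.tar and config.json are assumptions, and the logger and optimizer entries are omitted for brevity. A real implementation would also restore the model and optimizer state with their load_state_dict() methods.

```python
import json
import os
import pickle
import tempfile


def save_checkpoint(config, epoch, monitor_best, state_dict):
    """Write the config and a checkpoint into save_dir/name (hypothetical helper)."""
    checkpoint_dir = os.path.join(config["trainer"]["save_dir"], config["name"])
    os.makedirs(checkpoint_dir, exist_ok=True)
    # The config file is saved in the same folder as the checkpoints.
    with open(os.path.join(checkpoint_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
    state = {
        "arch": config["arch"],
        "epoch": epoch,
        "state_dict": state_dict,
        "monitor_best": monitor_best,
        "config": config,
    }
    path = os.path.join(checkpoint_dir, f"checkpoint-epoch{epoch}.pth.tar")
    with open(path, "wb") as f:
        pickle.dump(state, f)  # stand-in for torch.save(state, path)
    return path


def resume_checkpoint(path):
    """Return the epoch to continue from, the best monitored value, and the config."""
    with open(path, "rb") as f:
        state = pickle.load(f)  # stand-in for torch.load(path)
    return state["epoch"] + 1, state["monitor_best"], state["config"]


config = {"name": "Mnist_LeNet", "arch": "MnistModel",
          "trainer": {"save_dir": tempfile.mkdtemp()}}
path = save_checkpoint(config, epoch=3, monitor_best=0.25, state_dict={"w": [1.0]})
start_epoch, monitor_best, restored = resume_checkpoint(path)
```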
Feel free to contribute any kind of function or enhancement. The coding style here follows PEP8, and code should pass the Flake8 check before committing.
- Iteration-based training (instead of epoch-based)
- Deprecate BaseDataLoader, use torch.utils.data instead
- Multi-GPU support
- Multiple optimizers and lr scheduler
- Update the example to PyTorch 0.4 (or 1.0)
- TensorboardX or visdom logger support
- Configurable logging layout, checkpoint naming
- Load settings from config files
This project is licensed under the MIT License. See LICENSE for more details.
This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy