`utils.trainer.TrainLoop` runs the training loop and is compatible with `torch.distributed.run`.

- Complete the `TrainSettings` (or `YourSettings`) class in `config/train.py`.
  - This settings class is compatible with argparse and JSON.
- Complete the `load_data_from_args` function in `data/__init__.py`.
- Complete the `model` package.
- Complete the `create_model_from_config` function in `utils/initialization.py`.
- Complete some methods of the `TrainLoop` class in `utils/trainer.py`.
  - `log_loss_dict` method: logs the dict of loss values.
  - `compute_losses` method: calculates `losses` from `micro_batch` and TrainLoop vars.
  - `backward_from_losses` method: makes a single `loss` from `losses`, and runs `loss.backward()`.
  - `__init__` method: add your extra values to TrainLoop vars if needed.
- Complete `run/train.py` so that it is consistent with all code signatures you modified.
- Modify the settings JSON file, after copying the default train settings with the command
  `python3 -c "from config.train import TrainSettings as T; print(T().json(indent=2))" >> train_config.json`
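
To illustrate what "compatible with argparse and JSON" can mean in practice, here is a minimal sketch using only the standard library. `DemoSettings`, its fields, and the helper functions are hypothetical stand-ins, not the repository's actual `TrainSettings` (whose `.json(indent=2)` call suggests it is built on a different base class). The sketch only shows the round trip between a settings object, a JSON document, and command-line overrides.

```python
import argparse
import json
from dataclasses import asdict, dataclass, fields


@dataclass
class DemoSettings:
    # Hypothetical fields, chosen only for illustration.
    lr: float = 1e-4
    batch_size: int = 32
    data_dir: str = "data"

    def to_json(self, indent=2):
        return json.dumps(asdict(self), indent=indent)

    @classmethod
    def from_json(cls, text):
        return cls(**json.loads(text))


def build_parser(defaults):
    """Create one --flag per settings field, typed like the field's default."""
    parser = argparse.ArgumentParser()
    for f in fields(defaults):
        value = getattr(defaults, f.name)
        parser.add_argument(f"--{f.name}", type=type(value), default=value)
    return parser


def settings_from_cli(argv, json_text=None):
    """JSON values replace the defaults; explicit CLI flags replace both."""
    base = DemoSettings.from_json(json_text) if json_text else DemoSettings()
    args = build_parser(base).parse_args(argv)
    return DemoSettings(**vars(args))
```

With this layout, a JSON file supplies the baseline configuration and any flag passed on the command line overrides the corresponding field for that run.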

After completion, you can run the train script with
`python3 -m run.train --distributed --config_json train_config.json`
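
The `TrainLoop` methods named in the checklist could be filled in roughly as follows. This is a sketch under assumptions, not the repository's implementation: `DemoTrainLoop` is a hypothetical stand-in class, and the method signatures, the `"mse"` loss key, and the `(inputs, targets)` layout of `micro_batch` are guesses made for illustration. The real `TrainLoop` is also responsible for things this sketch leaves out, such as distributed setup and optimizer steps.

```python
import torch
from torch import nn


class DemoTrainLoop:
    """Hypothetical stand-in showing only the methods the checklist names."""

    def __init__(self, model, loss_weights=None):
        self.model = model
        # Extra values stored on the loop, as the `__init__` item suggests.
        self.loss_weights = loss_weights or {}
        self.logged = []

    def compute_losses(self, micro_batch):
        # Assumed batch layout: a tuple of (inputs, targets).
        x, y = micro_batch
        pred = self.model(x)
        # Keep one loss value per sample so it can be logged or reweighted.
        return {"mse": ((pred - y) ** 2).mean(dim=-1)}

    def backward_from_losses(self, losses):
        # Reduce the dict of per-sample losses to one scalar, then backprop.
        loss = sum(
            self.loss_weights.get(name, 1.0) * value.mean()
            for name, value in losses.items()
        )
        loss.backward()
        return loss

    def log_loss_dict(self, losses):
        # Record plain floats; a real loop would pass these to its logger.
        entry = {name: value.mean().item() for name, value in losses.items()}
        self.logged.append(entry)
        return entry


torch.manual_seed(0)
loop = DemoTrainLoop(nn.Linear(4, 2))
batch = (torch.randn(8, 4), torch.randn(8, 2))
losses = loop.compute_losses(batch)
loop.log_loss_dict(losses)
loss = loop.backward_from_losses(losses)
```

Returning a dict of per-sample losses from `compute_losses` keeps logging and the backward pass separate, so `backward_from_losses` is the only place that decides how the individual terms are weighted and reduced.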
@inproceedings{gong2022diffuseq,
author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
booktitle = {International Conference on Learning Representations, ICLR},
title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
year = 2023
}