
PyTorch pipeline with torch.distributed from YAI 8th DongHa Kim


Refeat/distributed-pipeline

PyTorch pipeline with torch.distributed

utils.trainer.TrainLoop runs the training loop and is compatible with torch.distributed.run (the torchrun launcher).
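When a script is started by torch.distributed.run, each worker process learns its place in the job from environment variables the launcher sets, such as RANK, LOCAL_RANK and WORLD_SIZE. The snippet below is a minimal sketch of reading them, with a single-process fallback; DistEnv and dist_env_from_environ are hypothetical names, not part of TrainLoop.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistEnv:
    rank: int        # global rank of this process
    local_rank: int  # rank within the current node (usually the GPU index)
    world_size: int  # total number of processes in the job


def dist_env_from_environ(environ: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Read the variables torchrun sets; default to one process if they are absent."""
    env = os.environ if environ is None else environ
    return DistEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


if __name__ == "__main__":
    info = dist_env_from_environ()
    print(f"rank {info.rank}/{info.world_size}, local rank {info.local_rank}")
    # With torch installed, a script would typically continue with:
    #   torch.cuda.set_device(info.local_rank)
    #   torch.distributed.init_process_group(backend="nccl")
```

Passing an explicit mapping instead of os.environ keeps the helper easy to test outside a launched job.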

All you need to do

  • Complete the TrainSettings (or YourSettings) class in config/train.py.
    • The settings class can be filled from argparse arguments as well as from JSON.
  • Complete the load_data_from_args function in data/__init__.py.
  • Complete the model package.
  • Complete the create_model_from_config function in utils/initialization.py.
  • Complete the following methods of the TrainLoop class in utils/trainer.py:
    • log_loss_dict: logs a dict of loss values.
    • compute_losses: computes the losses from a micro_batch and the TrainLoop attributes.
    • backward_from_losses: combines the losses into a single loss and calls loss.backward().
    • __init__: adds any extra attributes your TrainLoop needs.
  • Update run/train.py so it matches every signature you modified.
  • Copy the default training settings into a JSON file with the command below, then edit that file:
    python3 -c "from config.train import TrainSettings as T; print(T().json(indent=2))" > train_config.json
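The README does not show TrainSettings itself, and the T().json(indent=2) call suggests a pydantic-style model. As a stdlib stand-in, here is a minimal sketch of a settings class that can be built from argparse or round-tripped through JSON. The field names and the methods to_json, from_json, add_arguments and from_args are hypothetical.

```python
import argparse
import json
from dataclasses import asdict, dataclass, fields


@dataclass
class TrainSettings:
    # Hypothetical fields; replace them with whatever your model and data need.
    lr: float = 1e-4
    batch_size: int = 64
    microbatch: int = 16
    data_dir: str = "data"

    def to_json(self, indent: int = 2) -> str:
        return json.dumps(asdict(self), indent=indent)

    @classmethod
    def from_json(cls, text: str) -> "TrainSettings":
        return cls(**json.loads(text))

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        # One --flag per field, typed and defaulted from the dataclass definition.
        defaults = cls()
        for f in fields(cls):
            parser.add_argument(f"--{f.name}", type=f.type, default=getattr(defaults, f.name))

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "TrainSettings":
        return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})
```

Writing TrainSettings().to_json() to train_config.json plays the same role as the copy command above; this sketch only handles plain int, float and str fields.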
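What load_data_from_args returns depends on your dataset, but under torch.distributed each rank usually reads a disjoint shard of it. In PyTorch this is normally done with torch.utils.data.distributed.DistributedSampler; the pure-Python sketch below reproduces its non-shuffled index split (pad by repeating from the start, then take every world_size-th index). The helper name shard_indices is hypothetical.

```python
import math
from typing import List


def shard_indices(num_items: int, rank: int, world_size: int) -> List[int]:
    """Indices this rank should load, padded so every rank gets the same count."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    indices = list(range(num_items))
    per_rank = math.ceil(num_items / world_size)
    total = per_rank * world_size
    # Repeat indices from the start until the total divides evenly across ranks.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]
```

For 10 items on 4 ranks, every rank gets 3 indices and two items are seen twice per epoch; equal shard sizes keep the ranks' gradient all-reduce calls in step.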
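The TrainLoop methods listed above are left for you to fill in. The helpers below are a hypothetical sketch of the pieces they usually need: splitting a batch into micro-batches, combining named losses into one value for backward_from_losses, and averaging logged values for log_loss_dict. combine_losses only uses * and +, so the same code works on scalar torch tensors, after which total.backward() would be called.

```python
from typing import Dict, Iterator, List, Mapping, Sequence, TypeVar

T = TypeVar("T")


def iter_microbatches(batch: Sequence[T], microbatch: int) -> Iterator[Sequence[T]]:
    """Split one batch into consecutive micro-batches (the last one may be smaller)."""
    if microbatch <= 0:
        raise ValueError("microbatch must be positive")
    for start in range(0, len(batch), microbatch):
        yield batch[start:start + microbatch]


def combine_losses(losses: Mapping[str, float], weights: Mapping[str, float]) -> float:
    """Weighted sum of named losses; a loss without a weight counts with weight 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


def mean_loss_dict(history: Mapping[str, List[float]]) -> Dict[str, float]:
    """Average each logged loss over the steps since the last log call, skipping empty ones."""
    return {name: sum(vals) / len(vals) for name, vals in history.items() if vals}
```

Keeping the weighting in one place means compute_losses can return raw terms, while backward_from_losses alone decides how they are mixed.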
    

How to run

Once the steps above are complete, run the training script with:

python3 -m run.train --distributed --config_json train_config.json
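How run/train.py consumes these two flags depends on your code. A minimal sketch of the argument handling, assuming only the flag names from the command above; parse_cli and load_config are hypothetical helpers.

```python
import argparse
import json
from typing import Any, Dict, List, Optional


def parse_cli(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Sketch of a run/train.py command line")
    parser.add_argument("--distributed", action="store_true",
                        help="set up torch.distributed before training")
    parser.add_argument("--config_json", type=str, default=None,
                        help="path to a settings file such as train_config.json")
    return parser.parse_args(argv)


def load_config(path: Optional[str]) -> Dict[str, Any]:
    """Return the settings stored in the JSON file, or an empty dict if no path was given."""
    if path is None:
        return {}
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

The returned dict can then be unpacked into the settings class, for example TrainSettings(**load_config(args.config_json)).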

Citations

@inproceedings{gong2022diffuseq,
  author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
  booktitle = {International Conference on Learning Representations, ICLR},
  title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
  year = 2023
}
