
PyTorch optimizations


This is just an overview and collection of references.

Potential optimizations (speed and/or memory):

  • Distributed / multi-GPU training. RETURNN config: check torch_distributed (see also the config sketch after this list).
  • Automatic mixed precision (AMP), e.g. to use float16 (fp16). RETURNN config: torch_amp = "float16".
  • PyTorch scripting and tracing (https://github.com/rwth-i6/returnn/issues/1436).
  • torch.compile (see the sketch below).
  • TorchDynamo (the graph-capture frontend used by torch.compile).
  • torch.optim._multi_tensor.AdamW (multi-tensor optimizer; see the optimizer sketch below).
  • apex.optimizers.FusedAdam (a fused Adam might get integrated into PyTorch, see https://github.com/pytorch/pytorch/issues/71274).
  • Asynchronous data loading and augmentation. RETURNN config: torch_dataloader_opts = {"num_workers": 1}; maybe use together with MultiProcDataset if more workers are needed, see here.
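
As a rough orientation, here is a minimal RETURNN config sketch combining some of the options above. The exact sub-options of torch_distributed are not spelled out here; an empty dict is assumed to enable distributed training with defaults, and using "bfloat16" as an AMP dtype is also an assumption, so check the RETURNN documentation/source for what your version supports.

```python
# Sketch of a RETURNN config fragment (not a complete config).
# torch_distributed = {} is assumed to enable distributed training with default
# options; check the RETURNN documentation/source for the supported sub-options.

backend = "torch"

# Multi-GPU / distributed training (the job then needs to be launched with
# multiple processes, e.g. via torchrun or mpirun):
torch_distributed = {}

# Automatic mixed precision ("bfloat16" may also work on recent hardware):
torch_amp = "float16"

# Asynchronous data loading in a separate worker process:
torch_dataloader_opts = {"num_workers": 1}
```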
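
Independent of RETURNN, torch.compile in plain PyTorch ≥ 2.0 looks as follows; TorchDynamo is the graph-capture frontend it uses, with TorchInductor as the default backend. A minimal sketch with a toy model (the model and shapes are made up for illustration):

```python
import torch

# Toy model just for illustration; torch.compile works on any nn.Module or function.
model = torch.nn.Sequential(
    torch.nn.Linear(80, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 100),
)

# TorchDynamo captures the Python code as a graph; TorchInductor (the default
# backend) then generates fused kernels. The first call triggers compilation.
compiled_model = torch.compile(model)

x = torch.randn(8, 80)
y = compiled_model(x)  # later calls with the same shapes reuse the compiled code
```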
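
Regarding the optimizer variants: torch.optim._multi_tensor.AdamW is the multi-tensor ("foreach") implementation; in recent PyTorch versions the same effect is available via a flag on the standard optimizer, and a fused CUDA kernel (similar in spirit to apex.optimizers.FusedAdam) via fused=True. Whether these flags exist depends on the PyTorch version; a sketch:

```python
import torch

model = torch.nn.Linear(512, 512)

# Multi-tensor ("foreach") implementation: updates many parameter tensors in one
# kernel launch instead of looping over them in Python.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=True)

# Fused CUDA kernel (requires CUDA parameters and a sufficiently recent PyTorch):
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
```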

To find potential bottlenecks in the code:
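
For example, the built-in torch.profiler can be used for this; a minimal sketch independent of RETURNN (the toy model and iteration count are just for illustration):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(1024, 1024)
x = torch.randn(64, 1024)

# Record CPU (and CUDA, if available) activity for a few forward/backward passes.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(5):
        model(x).sum().backward()

# Show the ops that took the most time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```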

References: