By Apoorv Khandelwal and Peter Curtin
Automatically distribute PyTorch functions onto multiple machines or GPUs
```bash
pip install torchrunx
```
Requires: Linux (with shared filesystem & SSH access if using multiple machines)
Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
Training code

```python
import os
import torch

def train():
    # Set by the launcher for each worker process
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])

    model = torch.nn.Linear(10, 10).to(local_rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    # One toy training step
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(5, 10))
    labels = torch.randn(5, 10).to(local_rank)
    torch.nn.functional.mse_loss(outputs, labels).backward()
    optimizer.step()

    if rank == 0:
        return model
```
You could also use `transformers.Trainer` (or similar) to automatically handle all the multi-GPU / DDP code above.
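For instance, a `train` function built on `transformers.Trainer` might look roughly like the sketch below. The `ToyModel` and `ToyDataset` are made-up placeholders (substitute your own model and data); the point is that `Trainer` picks up the distributed environment that the launcher provides and wraps the model in DDP for you.

```python
import torch
from transformers import Trainer, TrainingArguments

class ToyDataset(torch.utils.data.Dataset):
    # Tiny random-regression dataset, just to keep the sketch self-contained.
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        return {"x": torch.randn(10), "labels": torch.randn(10)}

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
    def forward(self, x, labels=None):
        out = self.linear(x)
        loss = torch.nn.functional.mse_loss(out, labels)
        return {"loss": loss, "logits": out}

def train():
    trainer = Trainer(
        model=ToyModel(),
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=8),
        train_dataset=ToyDataset(),
    )
    trainer.train()  # Trainer reads RANK / LOCAL_RANK and handles DDP internally
    return trainer.model
```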
```python
import torchrunx as trx

if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "other_node"],
        workers_per_host=2  # number of GPUs
    )

    trained_model = result.rank(0)
    torch.save(trained_model.state_dict(), "model.pth")
```
Whether you have 1 GPU, 8 GPUs, or 8 machines:
Features

- Our `launch()` utility is super Pythonic
    - Return objects from your workers
    - Run `python script.py` instead of `torchrun script.py`
    - Launch multi-node functions, even from Python Notebooks
- Fine-grained control over logging, environment variables, exception handling, etc. (see the exception-handling sketch after this list)
- Automatic integration with SLURM
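For example, because `launch()` is just a Python call, a worker failure can be dealt with in the launcher process like any other error. A minimal sketch (catching a broad `Exception` here is an assumption for illustration; check the docs for the exact exception types raised):

```python
import torchrunx as trx

def flaky_train():
    raise RuntimeError("simulated worker failure")

if __name__ == "__main__":
    try:
        trx.launch(func=flaky_train, hostnames=["localhost"], workers_per_host=2)
    except Exception as e:  # assumption: worker errors surface as an exception in the launcher
        print(f"A worker failed: {e}")
        # retry, fall back, or clean up here -- the launcher process itself keeps running
```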
Robustness

- If you want to run a complex, modular workflow in one script
    - don't parallelize your entire script: just the functions you want! (see the sketch after this list)
    - no worries about memory leaks or OS failures
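A sketch of that pattern, reusing `train` and the `launch()` call from above: only the GPU-heavy training step is parallelized, while the rest of the workflow stays ordinary single-process Python (`evaluate` is a made-up helper, not part of torchrunx):

```python
import torch
import torchrunx as trx

def evaluate(model):
    # Hypothetical single-process step: runs in the launcher, not on the workers.
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        preds = model(torch.randn(5, 10, device=device))
        return torch.nn.functional.mse_loss(preds, torch.randn(5, 10, device=device)).item()

if __name__ == "__main__":
    # Parallelized: the training function runs on every worker across both nodes.
    result = trx.launch(func=train, hostnames=["localhost", "other_node"], workers_per_host=2)
    trained_model = result.rank(0)

    # Not parallelized: plain Python in the launcher process.
    print("eval loss:", evaluate(trained_model))
```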
Convenience

- If you don't want to:
    - set up `dist.init_process_group` yourself (the boilerplate sketched below)
    - manually SSH into every machine and `torchrun --master-ip --master-port ...`, babysit failed processes, etc.
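For reference, the per-process boilerplate that `launch()` spares you looks roughly like this (a generic `torch.distributed` setup, not torchrunx code):

```python
import os
import torch
import torch.distributed as dist

def manual_setup():
    # Without a launcher, every process must agree on a rendezvous address/port
    # (usually passed via environment variables or torchrun flags) ...
    dist.init_process_group(
        backend="nccl",
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
    # ... and bind itself to the right GPU.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```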