Pytorch Fully Sharded Data Parallel (FSDP) Integration #147

Open · wants to merge 21 commits into master · showing changes from 11 commits
32 changes: 29 additions & 3 deletions examples/power_limit_optimizer/README.md
@@ -2,8 +2,9 @@

This example will demonstrate how to integrate Zeus with `torchvision` and the ImageNet dataset.

[`train_single.py`](train_single.py) and [`train_dp.py`](train_dp.py) were adapted and simplified from [PyTorch's example training code for ImageNet](https://github.com/pytorch/examples/blob/main/imagenet/main.py).
The former script is for simple single GPU training, whereas the latter is for data parallel training with PyTorch DDP and [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html).
[`train_single.py`](train_single.py) and [`train_dp.py`](train_dp.py) were adapted and simplified from [PyTorch's example training code for ImageNet](https://github.com/pytorch/examples/blob/main/imagenet/main.py). [`train_fsdp.py`](train_fsdp.py) was adapted from [Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html).

[`train_single.py`](train_single.py) is for simple single GPU training, [`train_dp.py`](train_dp.py) is for data parallel training with PyTorch DDP, and [`train_fsdp.py`](train_fsdp.py) is for Fully Sharded Data Parallel training.

## Dependencies

@@ -23,6 +24,17 @@
You just need to download and extract the ImageNet data and mount it to the Docker container.
- [`ZeusMonitor`](http://ml.energy/zeus/reference/monitor/#zeus.monitor.ZeusMonitor): Measures the GPU time and energy consumption of arbitrary code blocks.
- [`GlobalPowerLimitOptimizer`](https://ml.energy/zeus/reference/optimizer/power_limit/#zeus.optimizer.power_limit.GlobalPowerLimitOptimizer): Online-profiles each power limit with `ZeusMonitor` and finds the cost-optimal power limit.
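
Both pieces are used as callbacks around the training loop. A minimal sketch follows; it assumes a standard PyTorch training loop, with `train_loader` and `num_epochs` as placeholders, and uses the same callback methods (`on_epoch_begin`, `on_step_begin`, `on_step_end`, `on_epoch_end`) as the example scripts in this directory:

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

monitor = ZeusMonitor(gpu_indices=[0])    # Measure time and energy of GPU 0.
plo = GlobalPowerLimitOptimizer(monitor)  # Profiles power limits and applies the cost-optimal one.

for epoch in range(num_epochs):
    plo.on_epoch_begin()
    for batch in train_loader:
        plo.on_step_begin()
        # ... forward pass, backward pass, optimizer step ...
        plo.on_step_end()
    plo.on_epoch_end()
```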

## Multi-GPU Distributed Training (PyTorch DDP and FSDP)

When using `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` in a multi-GPU distributed context, launch one instance per local rank (i.e., per GPU on each node), and pass the local rank to `ZeusMonitor` as shown below:

```python
monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
plo = GlobalPowerLimitOptimizer(monitor)
```

Ensure that only one GPU is monitored per `ZeusMonitor`. Internally, `GlobalPowerLimitOptimizer` performs an [All-Reduce](https://pytorch.org/docs/stable/distributed.html) to synchronize before making a power limit decision.
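
Under `torchrun`, each process can read its local rank from the `LOCAL_RANK` environment variable that the launcher sets. A minimal per-rank setup sketch, assuming a `torchrun` launch like the commands below:

```python
import os

import torch
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

local_rank = int(os.environ["LOCAL_RANK"])  # Set by torchrun for every process on the node.
torch.cuda.set_device(local_rank)           # Bind this process to its own GPU.

monitor = ZeusMonitor(gpu_indices=[local_rank])  # Exactly one GPU per ZeusMonitor instance.
plo = GlobalPowerLimitOptimizer(monitor)
```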

## Example command

You can specify the maximum training time slowdown factor (1.0 means no slowdown) by setting `ZEUS_MAX_SLOWDOWN`. The default is set to 1.1 in this example script, meaning the lowest power limit that keeps training time inflation within 10% will be automatically found.
@@ -34,11 +46,25 @@
python train_single.py \
[DATA_DIR] \
--gpu 0 `# Specify the GPU id to use`

# Multi-GPU Data Parallel
# Multi-GPU Distributed Data Parallel
torchrun \
--nnodes 1 \
--nproc_per_node gpu `# Number of processes per node, should be equal to the number of GPUs.` \
`# When set to 'gpu', it means use all the GPUs available.` \
train_dp.py \
[DATA_DIR]

# Multi-GPU Fully Sharded Data Parallel
torchrun \
--nnodes 1 \
--nproc_per_node=gpu `# Number of processes per node, should be equal to the number of GPUs.` \
train_fsdp.py \
--batch-size 64 `# Batch size for training.` \
--test-batch-size 1000 `# Batch size for testing.` \
--epochs 10 `# Number of epochs to train.` \
--lr 1.0 `# Learning rate.` \
--gamma 0.7 `# Learning rate step gamma.` \
--save-model `# Save the trained model.` \
[DATA_DIR]
```

2 changes: 1 addition & 1 deletion examples/power_limit_optimizer/train_dp.py
@@ -201,7 +201,7 @@ def main():
if args.gpu == 0:
callback_set: list[Callback] = [
GlobalPowerLimitOptimizer(
monitor=ZeusMonitor(gpu_indices=None), # All visible GPUs.
                monitor=ZeusMonitor(gpu_indices=[args.gpu]),  # One GPU per process, so monitor only this process's GPU (its local rank).
optimum_selector=MaxSlowdownConstraint(
factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
),
199 changes: 199 additions & 0 deletions examples/power_limit_optimizer/train_fsdp.py
@@ -0,0 +1,199 @@
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output

def train(args, model, rank, world_size, train_loader, optimizer, epoch, plo, sampler=None):
model.train()
ddp_loss = torch.zeros(2).to(rank)
if sampler:
sampler.set_epoch(epoch)
for batch_idx, (data, target) in enumerate(train_loader):
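        # plo.on_step_begin()/on_step_end() mark step boundaries, which GlobalPowerLimitOptimizer
        # uses to profile each power limit and then apply the cost-optimal one.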
plo.on_step_begin()

data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target, reduction='sum')
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(data)

plo.on_step_end()

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

def test(model, rank, world_size, test_loader):
model.eval()
ddp_loss = torch.zeros(3).to(rank)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(rank), target.to(rank)
output = model(data)
ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
ddp_loss[2] += len(data)

dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
test_loss = ddp_loss[0] / ddp_loss[2]
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
100. * ddp_loss[1] / ddp_loss[2]))

def fsdp_main(args):
# If the user wants to explicitly set MASTER_ADDR and MASTER_PORT:
if args.master_addr is not None:
os.environ['MASTER_ADDR'] = args.master_addr
if args.master_port is not None:
os.environ['MASTER_PORT'] = args.master_port

# The following environment variables are provided by torchrun:
# MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE
# We can now initialize the process group using these env variables.
dist.init_process_group(backend="nccl", init_method="env://")

rank = dist.get_rank()
world_size = dist.get_world_size()
    # torchrun provides LOCAL_RANK as an environment variable; fall back to the CLI flag if it is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))

# Set the device using local rank
torch.cuda.set_device(local_rank)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('./data', train=False, transform=transform)

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(local_rank)
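    # No auto_wrap_policy is passed here, so FSDP wraps the entire model as a single unit.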
model = FSDP(model)

optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

init_start_event = torch.cuda.Event(enable_timing=True)
init_end_event = torch.cuda.Event(enable_timing=True)
init_start_event.record()

# Init ZeusMonitor and GPLO
monitor = ZeusMonitor(gpu_indices=[local_rank])
plo = GlobalPowerLimitOptimizer(monitor, profile_steps=200)

for epoch in range(1, args.epochs + 1):
plo.on_epoch_begin()
train(args, model, local_rank, world_size, train_loader, optimizer, epoch, plo, sampler=sampler1)
plo.on_epoch_end()

test(model, local_rank, world_size, test_loader)
scheduler.step()

init_end_event.record()

if rank == 0:
print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
print(f"{model}")

if args.save_model:
dist.barrier()
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")

dist.destroy_process_group()

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch MNIST FSDP with torchrun')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
parser.add_argument('--master_addr', type=str, default=None,
help='Master address for distributed training (optional, otherwise taken from env)')
parser.add_argument('--master_port', type=str, default=None,
help='Master port for distributed training (optional, otherwise taken from env)')
    parser.add_argument('--local-rank', type=int, default=0,
                        help='Fallback local rank; torchrun normally provides LOCAL_RANK via the environment')

args = parser.parse_args()
torch.manual_seed(args.seed)

fsdp_main(args)
1 change: 1 addition & 0 deletions zeus/monitor/energy.py
@@ -187,6 +187,7 @@ def __init__(
except ZeusCPUInitError:
self.cpus = EmptyCPUs()
except ZeusCPUNoPermissionError as err:
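            # Fall back to EmptyCPUs so that `self.cpus` is always defined; the RuntimeError below
            # is raised only if CPU metrics were explicitly requested.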
self.cpus = EmptyCPUs()
if cpu_indices:
raise RuntimeError(
"Root privilege is required to read RAPL metrics. See "
38 changes: 35 additions & 3 deletions zeus/optimizer/power_limit.py

Review comment (Member): Line 265 is now broken, because the previous implementation assumed that `len(zeus_monitor.gpu_indices)` gives the current world size. Let's just switch the default `optimum_selector` to `MaxSlowdownConstraint(factor=1.1)`.

Reply (Collaborator, author): Or we could use `torch.distributed.get_world_size` (and something analogous for JAX) by defining a generic framework function `zeus.framework.get_world_size`. What do you think?

Reply (Member): Nah, I wouldn't bother for this one. Now I think `MaxSlowdownConstraint` is a better default; the original one is from 2022.
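
For context, a sketch of what the suggested selector looks like when passed explicitly at the call site, assuming `MaxSlowdownConstraint` is exported by `zeus.optimizer.power_limit` alongside `GlobalPowerLimitOptimizer` (the same selector `train_dp.py` configures in this PR):

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer, MaxSlowdownConstraint

# Equivalent to making MaxSlowdownConstraint(factor=1.1) the default optimum_selector:
# pick the lowest power limit that keeps training time inflation within 10%.
monitor = ZeusMonitor(gpu_indices=[0])
plo = GlobalPowerLimitOptimizer(
    monitor,
    optimum_selector=MaxSlowdownConstraint(factor=1.1),
)
```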

@@ -28,6 +28,7 @@

from zeus.callback import Callback
from zeus.monitor import ZeusMonitor
from zeus.utils.framework import all_reduce, is_distributed
from zeus.utils.logging import get_logger
from zeus.utils.metric import zeus_cost
from zeus.utils.pydantic_v1 import BaseModel, PositiveInt, PositiveFloat
@@ -201,6 +202,24 @@ class GlobalPowerLimitOptimizer(Callback):
"""Optimizer for the power limit knob.

This optimizer uses the JIT profiling log to determine the optimal power limit.

Non-distributed training (Single GPU or Multi-GPU on a single node):
Launch one instance of `ZeusMonitor` and `GlobalPowerLimitOptimizer`, and have `ZeusMonitor` track all desired GPUs.
For example, to track all GPUs on a single node:
```python
monitor = ZeusMonitor(gpu_indices=None) # monitor all GPUs
plo = GlobalPowerLimitOptimizer(monitor)
```

Distributed training (Multi-GPU on multiple nodes):
`ZeusMonitor` and `GlobalPowerLimitOptimizer` make the assumption that each GPU is monitored by one and only one instance of `ZeusMonitor` to ensure correct functionality.
Therefore, it is recommended to launch one instance of `ZeusMonitor` and `GlobalPowerLimitOptimizer`
per device (per GPU on each node), and pass in the local rank to `ZeusMonitor` as shown below:
```python
monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
plo = GlobalPowerLimitOptimizer(monitor)
```
Internally, `GlobalPowerLimitOptimizer` performs an all-reduce over all devices to compute the optimal power limit.
"""

def __init__(
@@ -255,9 +274,19 @@ def __init__(
# Setup logging.
self.logger = get_logger(type(self).__name__)

gpus = get_gpus(ensure_homogeneous=True)

# Warn if distributed training is enabled with multiple GPUs monitored.
if is_distributed() and len(monitor.gpu_indices) > 1:
self.logger.warning(
"Distributed training is enabled with %d GPUs monitored. "
"For distributed training, it is recommended to monitor only one GPU per `ZeusMonitor` instance "
"since `GlobalPowerLimitOptimizer` performs an all-reduce operation internally over all devices.",
len(monitor.gpu_indices),
)

# Set the range of power limits to explore.
# Assert that supported power limits ranges are uniform across GPUs.
gpus = get_gpus(ensure_homogeneous=True)
pls = []
for index in monitor.gpu_indices:
pls.append(gpus.getPowerManagementLimitConstraints(index))
@@ -387,11 +416,14 @@ def on_step_begin(self) -> None:
"Finished profiling for power limit %d W.",
self.state.current_power_limit // 1000,
)
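            # Record this power limit's profiled measurement, aggregated across monitored devices
            # (and ranks, when distributed): energies are summed, and the step time takes the max
            # so the slowest rank bounds the step time.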

self.measurements.append(
PowerLimitMeasurement(
power_limit=self.state.current_power_limit // 1000,
energy=measurement.total_energy,
time=measurement.time,
energy=all_reduce(
list(measurement.gpu_energy.values()), operation="sum"
),
time=all_reduce([measurement.time], operation="max"),
)
)
# If we're done profiling all power limits, compute the optimal