diff --git a/examples/power_limit_optimizer/README.md b/examples/power_limit_optimizer/README.md
index ea227ed0..ca79eaf6 100644
--- a/examples/power_limit_optimizer/README.md
+++ b/examples/power_limit_optimizer/README.md
@@ -2,8 +2,9 @@
 
 This example will demonstrate how to integrate Zeus with `torchvision` and the ImageNet dataset.
 
-[`train_single.py`](train_single.py) and [`train_dp.py`](train_dp.py) were adapted and simplified from [PyTorch's example training code for ImageNet](https://github.com/pytorch/examples/blob/main/imagenet/main.py).
-The former script is for simple single GPU training, whereas the latter is for data parallel training with PyTorch DDP and [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html).
+[`train_single.py`](train_single.py) and [`train_dp.py`](train_dp.py) were adapted and simplified from [PyTorch's example training code for ImageNet](https://github.com/pytorch/examples/blob/main/imagenet/main.py). [`train_fsdp.py`](train_fsdp.py) was adapted from [Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html).
+
+[`train_single.py`](train_single.py) is for simple single-GPU training, [`train_dp.py`](train_dp.py) is for data parallel training with PyTorch DDP, and [`train_fsdp.py`](train_fsdp.py) is for Fully Sharded Data Parallel (FSDP) training.
 
 ## Dependencies
 
@@ -23,6 +24,17 @@ You just need to download and extract the ImageNet data and mount it to the Dock
 - [`ZeusMonitor`](http://ml.energy/zeus/reference/monitor/#zeus.monitor.ZeusMonitor): Measures the GPU time and energy consumption of arbitrary code blocks.
 - [`GlobalPowerLimitOptimizer`](https://ml.energy/zeus/reference/optimizer/power_limit/#zeus.optimizer.power_limit.GlobalPowerLimitOptimizer): Online-profiles each power limit with `ZeusMonitor` and finds the cost-optimal power limit.
 
+## Multi-GPU Distributed Training (PyTorch DDP and FSDP)
+
+When using `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` in a multi-GPU distributed context, construct one instance of `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` per local rank (i.e., one per GPU on each node), and pass the local rank to `ZeusMonitor` as shown below:
+
+```python
+monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
+plo = GlobalPowerLimitOptimizer(monitor)
+```
+
+Make sure each `ZeusMonitor` instance monitors exactly one GPU. Internally, `GlobalPowerLimitOptimizer` performs an [AllReduce](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html) to aggregate time and energy measurements across all GPUs before making a power limit decision.
+
 ## Example command
 
 You can specify the maximum training time slowdown factor (1.0 means no slowdown) by setting `ZEUS_MAX_SLOWDOWN`. The default is set to 1.1 in this example script, meaning the lowest power limit that keeps training time inflation within 10% will be automatically found.
@@ -34,11 +46,19 @@ python train_single.py \
     [DATA_DIR] \
     --gpu 0                 `# Specify the GPU id to use`
 
-# Multi-GPU Data Parallel
+# Multi-GPU Distributed Data Parallel
 torchrun \
     --nnodes 1 \
     --nproc_per_node gpu    `# Number of processes per node, should be equal to the number of GPUs.` \
                             `# When set to 'gpu', it means use all the GPUs available.` \
     train_dp.py \
     [DATA_DIR]
+
+# Multi-GPU Fully Sharded Data Parallel
+torchrun \
+    --nnodes 1 \
+    --nproc_per_node=gpu    `# Number of processes per node, should be equal to the number of GPUs.` \
+    train_fsdp.py \
+    [DATA_DIR]
 ```
+
diff --git a/examples/power_limit_optimizer/train_dp.py b/examples/power_limit_optimizer/train_dp.py
index a68f0fc4..1fb0c797 100644
--- a/examples/power_limit_optimizer/train_dp.py
+++ b/examples/power_limit_optimizer/train_dp.py
@@ -197,21 +197,19 @@ def main():
         sampler=val_sampler,
     )
 
-    # The rank 0 process will monitor and optimize the power limit of all GPUs.
-    if args.gpu == 0:
-        callback_set: list[Callback] = [
-            GlobalPowerLimitOptimizer(
-                monitor=ZeusMonitor(gpu_indices=None),  # All visible GPUs.
-                optimum_selector=MaxSlowdownConstraint(
-                    factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
-                ),
-                warmup_steps=10,
-                profile_steps=40,
-                pl_step=25,
-            )
-        ]
-    else:
-        callback_set = []
+    # All processes will monitor and optimize the power limit of all GPUs (one process per GPU).
+    callback_set: list[Callback] = [
+        GlobalPowerLimitOptimizer(
+            monitor=ZeusMonitor(gpu_indices=[args.gpu]),  # Each process drives one GPU, so monitor only that GPU (its local rank).
+            optimum_selector=MaxSlowdownConstraint(
+                factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
+            ),
+            warmup_steps=10,
+            profile_steps=40,
+            pl_step=25,
+        )
+    ]
+
     callbacks = CallbackSet(callback_set)
 
     for epoch in range(args.epochs):
diff --git a/examples/power_limit_optimizer/train_fsdp.py b/examples/power_limit_optimizer/train_fsdp.py
new file mode 100644
index 00000000..93006812
--- /dev/null
+++ b/examples/power_limit_optimizer/train_fsdp.py
@@ -0,0 +1,199 @@
+import os
+import argparse
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+from torch.optim.lr_scheduler import StepLR
+
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    CPUOffload,
+    BackwardPrefetch,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+from zeus.monitor import ZeusMonitor
+from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+def train(args, model, rank, world_size, train_loader, optimizer, epoch, plo, sampler=None):
+    model.train()
+    ddp_loss = torch.zeros(2).to(rank)
+    if sampler:
+        sampler.set_epoch(epoch)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        plo.on_step_begin()
+
+        data, target = data.to(rank), target.to(rank)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target, reduction='sum')
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(data)
+
+        plo.on_step_end()
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
+
+def test(model, rank, world_size, test_loader):
+    model.eval()
+    ddp_loss = torch.zeros(3).to(rank)
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(rank), target.to(rank)
+            output = model(data)
+            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
+            ddp_loss[2] += len(data)
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        test_loss = ddp_loss[0] / ddp_loss[2]
+        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
+            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
+            100. * ddp_loss[1] / ddp_loss[2]))
+
+def fsdp_main(args):
+    # If the user wants to explicitly set MASTER_ADDR and MASTER_PORT:
+    if args.master_addr is not None:
+        os.environ['MASTER_ADDR'] = args.master_addr
+    if args.master_port is not None:
+        os.environ['MASTER_PORT'] = args.master_port
+
+    # The following environment variables are provided by torchrun:
+    # MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE
+    # We can now initialize the process group using these env variables.
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))  # torchrun sets LOCAL_RANK for each process
+
+    # Set the device using local rank
+    torch.cuda.set_device(local_rank)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST('./data', train=False, transform=transform)
+
+    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
+    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
+
+    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
+    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
+    cuda_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}
+    train_kwargs.update(cuda_kwargs)
+    test_kwargs.update(cuda_kwargs)
+
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+    model = Net().to(local_rank)
+    model = FSDP(model)
+
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    init_start_event = torch.cuda.Event(enable_timing=True)
+    init_end_event = torch.cuda.Event(enable_timing=True)
+    init_start_event.record()
+
+    # Init ZeusMonitor and GPLO
+    monitor = ZeusMonitor(gpu_indices=[local_rank])
+    plo = GlobalPowerLimitOptimizer(monitor, profile_steps=200)
+
+    for epoch in range(1, args.epochs + 1):
+        plo.on_epoch_begin()
+        train(args, model, local_rank, world_size, train_loader, optimizer, epoch, plo, sampler=sampler1)
+        plo.on_epoch_end()
+
+        test(model, local_rank, world_size, test_loader)
+        scheduler.step()
+
+    init_end_event.record()
+
+    if rank == 0:
+        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
+        print(f"{model}")
+
+    if args.save_model:
+        dist.barrier()
+        states = model.state_dict()
+        if rank == 0:
+            torch.save(states, "mnist_cnn.pt")
+
+    dist.destroy_process_group()
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='PyTorch MNIST FSDP with torchrun')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
+                        help='learning rate (default: 1.0)')
+    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
+                        help='Learning rate step gamma (default: 0.7)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    parser.add_argument('--master_addr', type=str, default=None,
+                        help='Master address for distributed training (optional, otherwise taken from env)')
+    parser.add_argument('--master_port', type=str, default=None,
+                        help='Master port for distributed training (optional, otherwise taken from env)')
+    parser.add_argument('--local-rank', type=int, default=0,
+                        help='Fallback local rank if the LOCAL_RANK environment variable is not set by torchrun')
+
+    args = parser.parse_args()
+    torch.manual_seed(args.seed)
+
+    fsdp_main(args)
diff --git a/zeus/optimizer/power_limit.py b/zeus/optimizer/power_limit.py
index 0ac2c2d1..95b6727a 100644
--- a/zeus/optimizer/power_limit.py
+++ b/zeus/optimizer/power_limit.py
@@ -28,6 +28,7 @@
 
 from zeus.callback import Callback
 from zeus.monitor import ZeusMonitor
+from zeus.utils.framework import all_reduce, is_distributed
 from zeus.utils.logging import get_logger
 from zeus.utils.metric import zeus_cost
 from zeus.utils.pydantic_v1 import BaseModel, PositiveInt, PositiveFloat
@@ -201,6 +202,23 @@ class GlobalPowerLimitOptimizer(Callback):
     """Optimizer for the power limit knob.
 
     This optimizer uses the JIT profiling log to determine the optimal power limit.
+
+    ## Usage with distributed data parallelism
+
+    The global power limit optimizer expects one process to control each GPU used for training.
+    For instance, `torchrun` will automatically spawn one process for each GPU on the node.
+    Correspondingly, the [`ZeusMonitor`][zeus.monitor.energy.ZeusMonitor] instance passed in
+    should be monitoring **one GPU**: the one being managed by the current process. The index of
+    this GPU would typically match the local rank of the process. In the case of PyTorch, users would have
+    called `torch.cuda.set_device` early on, so `torch.cuda.current_device` will give you the GPU index.
+    `GlobalPowerLimitOptimizer` will internally do an AllReduce across all GPUs to aggregate
+    time and energy measurements, and then select the globally optimal power limit.
+
+
+    ```python
+    monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
+    plo = GlobalPowerLimitOptimizer(monitor)
+    ```
     """
 
     def __init__(
@@ -255,9 +273,19 @@ def __init__(
         # Setup logging.
         self.logger = get_logger(type(self).__name__)
 
+        gpus = get_gpus(ensure_homogeneous=True)
+
+        # Warn if distributed training is enabled with multiple GPUs monitored.
+        if is_distributed() and len(monitor.gpu_indices) > 1:
+            self.logger.warning(
+                "Distributed training is enabled with %d GPUs monitored. "
+                "For distributed training, it is recommended to monitor only one GPU per `ZeusMonitor` instance "
+                "since `GlobalPowerLimitOptimizer` performs an all-reduce operation internally over all devices.",
+                len(monitor.gpu_indices),
+            )
+
         # Set the range of power limits to explore.
         # Assert that supported power limits ranges are uniform across GPUs.
-        gpus = get_gpus(ensure_homogeneous=True)
         pls = []
         for index in monitor.gpu_indices:
             pls.append(gpus.getPowerManagementLimitConstraints(index))
@@ -387,11 +415,16 @@ def on_step_begin(self) -> None:
                 "Finished profiling for power limit %d W.",
                 self.state.current_power_limit // 1000,
             )
+
             self.measurements.append(
                 PowerLimitMeasurement(
                     power_limit=self.state.current_power_limit // 1000,
-                    energy=measurement.total_energy,
-                    time=measurement.time,
+                    energy=sum(
+                        all_reduce(
+                            list(measurement.gpu_energy.values()), operation="sum"
+                        )
+                    ),
+                    time=max(all_reduce([measurement.time], operation="max")),
                 )
             )
             # If we're done profiling all power limits, compute the optimal
diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py
index b0550154..6b9da07a 100644
--- a/zeus/utils/framework.py
+++ b/zeus/utils/framework.py
@@ -102,3 +102,58 @@ def sync_execution(
         return
 
     raise RuntimeError("No framework is available.")
+
+
+def all_reduce(
+    object: list[int] | list[float], operation: Literal["sum", "max"]
+) -> list[int] | list[float]:
+    """Reduce objects from all replicas through the specified operation.
+
+    If the current execution is not distributed, the object is returned as is.
+    """
+    if torch_is_available(ensure_cuda=False):
+        torch = MODULE_CACHE["torch"]
+
+        # if torch.distributed is not available or not initialized, return the object as is
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return object
+
+        # wrap object in a tensor
+        tensor = torch.Tensor(object).cuda()
+
+        # determine operation
+        if operation == "sum":
+            torch_op = torch.distributed.ReduceOp.SUM
+        elif operation == "max":
+            torch_op = torch.distributed.ReduceOp.MAX
+        else:
+            raise ValueError(f"all_reduce unsupported operation: {operation}")
+
+        torch.distributed.all_reduce(tensor, op=torch_op)
+        return tensor.cpu().tolist()
+
+    if jax_is_available():
+        # Check if not distributed
+        jax = MODULE_CACHE["jax"]
+        # if jax is not distributed, return the object as is
+        if jax.process_count() == 1:
+            return object
+
+        # TODO: Implement JAX distributed all-reduce logic.
+        raise NotImplementedError("JAX distributed is not supported yet.")
+
+    raise RuntimeError("No framework is available.")
+
+
+def is_distributed() -> bool:
+    """Check if the current execution is distributed across multiple devices."""
+    if torch_is_available(ensure_cuda=False):
+        torch = MODULE_CACHE["torch"]
+        return torch.distributed.is_available() and torch.distributed.is_initialized()
+    if jax_is_available():
+        jax = MODULE_CACHE["jax"]
+        return jax.device_count() > 1
+    raise RuntimeError("No framework is available.")
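For reference, below is a minimal, self-contained sketch of the per-rank usage pattern that the README section and the `GlobalPowerLimitOptimizer` docstring above describe. It is illustrative only and not part of the patch; it assumes a `torchrun` launch (which sets `LOCAL_RANK`) and uses the `zeus.utils.framework.all_reduce` helper added here to show how per-GPU energy is summed and time is max-reduced across ranks, so every rank arrives at the same global numbers and therefore the same power limit.

```python
# Illustrative sketch only (not part of the patch): one process per GPU, launched with torchrun.
import os

import torch
import torch.distributed as dist

from zeus.monitor import ZeusMonitor
from zeus.utils.framework import all_reduce

dist.init_process_group("nccl")             # torchrun provides MASTER_ADDR/PORT, RANK, WORLD_SIZE
local_rank = int(os.environ["LOCAL_RANK"])  # torchrun sets LOCAL_RANK for each local process
torch.cuda.set_device(local_rank)

# Each rank monitors only its own GPU, as recommended above.
monitor = ZeusMonitor(gpu_indices=[local_rank])

monitor.begin_window("profile")
# ... run a fixed number of training steps at the current power limit ...
measurement = monitor.end_window("profile")

# Aggregate across ranks the same way GlobalPowerLimitOptimizer does internally:
# energy is summed over all GPUs, time is the max over all ranks.
global_energy = sum(all_reduce(list(measurement.gpu_energy.values()), operation="sum"))
global_time = max(all_reduce([measurement.time], operation="max"))

dist.destroy_process_group()
```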