diff --git a/sample_workloads/lit-gpt-demo/LitGPT.Dockerfile b/sample_workloads/lit-gpt-demo/LitGPT.Dockerfile index 06a2c656..f2f22e31 100644 --- a/sample_workloads/lit-gpt-demo/LitGPT.Dockerfile +++ b/sample_workloads/lit-gpt-demo/LitGPT.Dockerfile @@ -18,6 +18,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.c && apt-get update -y && apt-get install google-cloud-cli -y COPY scripts /workspace/scripts +COPY utilities /workspace/pretrain/utilities COPY openwebtext_trainer.py /workspace/pretrain/ ENTRYPOINT ["/bin/bash", "/workspace/scripts/litgpt_container_entrypoint.sh"] diff --git a/sample_workloads/lit-gpt-demo/helm/templates/litgpt.yaml b/sample_workloads/lit-gpt-demo/helm/templates/litgpt.yaml index 49b5dcbd..d4607c9d 100644 --- a/sample_workloads/lit-gpt-demo/helm/templates/litgpt.yaml +++ b/sample_workloads/lit-gpt-demo/helm/templates/litgpt.yaml @@ -164,6 +164,8 @@ spec: value: "{{$root.Values.workload.warmupIters}}" - name: MAX_ITERS value: "{{$root.Values.workload.maxIters}}" + - name: COLLECT_NSYS_PROFILE + value: "{{$root.Values.workload.collectNsysProfile}}" - name: CLUSTER_TYPE value: GKE volumeMounts: diff --git a/sample_workloads/lit-gpt-demo/helm/values.yaml b/sample_workloads/lit-gpt-demo/helm/values.yaml index 24638314..bb93b740 100644 --- a/sample_workloads/lit-gpt-demo/helm/values.yaml +++ b/sample_workloads/lit-gpt-demo/helm/values.yaml @@ -19,4 +19,4 @@ workload: microBatchSize: 6 warmupIters: 10 maxIters: 1000 - \ No newline at end of file + collectNsysProfile: 'no' # Set to 'yes' for profiles \ No newline at end of file diff --git a/sample_workloads/lit-gpt-demo/openwebtext_trainer.py b/sample_workloads/lit-gpt-demo/openwebtext_trainer.py index def3f415..352560e8 100644 --- a/sample_workloads/lit-gpt-demo/openwebtext_trainer.py +++ b/sample_workloads/lit-gpt-demo/openwebtext_trainer.py @@ -14,10 +14,19 @@ from lightning.pytorch.strategies import FSDPStrategy, XLAStrategy from torch.utils.data import DataLoader, IterableDataset +import torch.multiprocessing as mp +import nvtx + # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +mp.set_start_method("spawn", force=True) +import utilities.monitor_collectives + +utilities.monitor_collectives.shunt_torch_communication() + + from lit_gpt import Config from lit_gpt.model import GPT, Block from lit_gpt.speed_monitor import SpeedMonitorCallback, estimate_flops, measure_flops @@ -57,6 +66,8 @@ def __init__(self, config: Config) -> None: self.config = config self.module: Optional[torch.nn.Module] = None self.measured_flops: Optional[int] = None + self.nsys_profile_step_multiple = 5 + self.backward_nvtx_range = None def configure_model(self) -> None: self.module = GPT(self.config) @@ -66,9 +77,14 @@ def configure_optimizers(self) -> torch.optim.Optimizer: return torch.optim.AdamW( self.module.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False ) + + def on_train_epoch_start(self) -> None: + print("Resetting max memory allocation") + torch.cuda.reset_peak_memory_stats() def on_fit_start(self) -> None: trainer = self.trainer + with torch.device("meta"): meta_model = GPT(self.module.config) # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 
@@ -88,7 +104,39 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: for optimizer in self.trainer.strategy.optimizers: for param_group in optimizer.param_groups: param_group["lr"] = lr + + global_batch_idx = batch_idx / gradient_accumulation_steps + if ( + global_batch_idx > 0 + and global_batch_idx % self.nsys_profile_step_multiple == 0 + ): + print(f"Starting Nsys profiling") + torch.cuda.cudart().cudaProfilerStart() + + + def on_train_batch_end( + self, outputs, batch: Any, batch_idx: int, unused: int = 0 + ) -> None: + global_batch_idx = batch_idx // gradient_accumulation_steps + global_batch_offset = batch_idx % gradient_accumulation_steps + is_last_microbatch = global_batch_offset == gradient_accumulation_steps - 1 + + if ( + global_batch_idx > 1 + and global_batch_idx % self.nsys_profile_step_multiple == 0 + and is_last_microbatch + ): + self.print(f"Stopping Nsys profiling") + torch.cuda.cudart().cudaProfilerStop() + if is_last_microbatch: + self.print(f"HEARTBEAT: {global_batch_idx=}, {batch_idx=}") + self.print( + f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + ) + sys.stdout.flush() + sys.stderr.flush() + @nvtx.annotate(color='green') def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: input_ids, targets = batch logits = self.module(input_ids) @@ -96,6 +144,17 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True) return loss + def on_before_backward(self, loss): + self.backward_nvtx_range = nvtx.start_range(message="backward", color="red") + + def on_after_backward(self): + if self.backward_nvtx_range: + nvtx.end_range(self.backward_nvtx_range) + + @nvtx.annotate(color='orange') + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + optimizer.step(closure=optimizer_closure) + def validation_step(self, batch: Any, batch_idx: int) -> None: input_ids, targets = batch logits = self.module(input_ids) @@ -104,68 +163,70 @@ def validation_step(self, batch: Any, batch_idx: int) -> None: def main(devices: int = 1, precision: Optional[str] = None, tpu: bool = False) -> None: - precision = precision or get_default_supported_precision(training=True, tpu=tpu) - - if devices > 1: - if tpu: - # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. - devices = "auto" - strategy = XLAStrategy(sync_module_states=False) + cm = torch.autograd.profiler.emit_nvtx() + with cm: + precision = precision or get_default_supported_precision(training=True, tpu=tpu) + + if devices > 1: + if tpu: + # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
+ devices = "auto" + strategy = XLAStrategy(sync_module_states=False) + else: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy={Block}, + # the argument is not available in the Trainer strategy, but it's the default anyways + # state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) else: - strategy = FSDPStrategy( - auto_wrap_policy={Block}, - activation_checkpointing_policy={Block}, - # the argument is not available in the Trainer strategy, but it's the default anyways - # state_dict_type="full", - limit_all_gathers=True, - cpu_offload=False, - ) - else: - strategy = "auto" - - logger = step_csv_logger(out_dir, name, cls=CSVLogger, flush_logs_every_n_steps=log_interval) - speed_monitor = SpeedMonitorCallback( - length_fn=lambda batch: batch[0].size(1), batch_size=micro_batch_size, window_size=10, time_unit="seconds" - ) - model_checkpoint = ModelCheckpoint(dirpath=out_dir, every_n_train_steps=save_interval, save_last=True, verbose=True) - trainer = L.Trainer( - devices=devices, - strategy=strategy, - precision=precision, - logger=logger, - callbacks=[speed_monitor, model_checkpoint], - max_steps=max_iters, - max_epochs=1, - limit_val_batches=eval_iters, - accumulate_grad_batches=gradient_accumulation_steps, - log_every_n_steps=log_interval, - val_check_interval=eval_interval, - num_nodes=num_nodes - ) - - L.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) - - trainer.print(hparams) - - if trainer.global_rank == 0: - out_dir.mkdir(parents=True, exist_ok=True) - - config = Config.from_name(model_name) - trainer.print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - model = LightningGPTModule(config) - trainer.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - - train_data = Dataset(str(data_dir / "train.bin"), config.block_size) - val_data = Dataset(str(data_dir / "val.bin"), config.block_size) - train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2) - val_dataloader = DataLoader(val_data, batch_size=micro_batch_size, num_workers=2) - - t0 = time.perf_counter() - trainer.fit(model, train_dataloader, val_dataloader, ckpt_path="last") - trainer.print(f"Training time: {(time.perf_counter()-t0):.2f}s") - if trainer.strategy.root_device.type == "cuda": - trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + strategy = "auto" + + logger = step_csv_logger(out_dir, name, cls=CSVLogger, flush_logs_every_n_steps=log_interval) + speed_monitor = SpeedMonitorCallback( + length_fn=lambda batch: batch[0].size(1), batch_size=micro_batch_size, window_size=10, time_unit="seconds" + ) + model_checkpoint = ModelCheckpoint(dirpath=out_dir, every_n_train_steps=save_interval, save_last=True, verbose=True) + trainer = L.Trainer( + devices=devices, + strategy=strategy, + precision=precision, + logger=logger, + callbacks=[speed_monitor, model_checkpoint], + max_steps=max_iters, + max_epochs=1, + limit_val_batches=eval_iters, + accumulate_grad_batches=gradient_accumulation_steps, + log_every_n_steps=log_interval, + val_check_interval=eval_interval, + num_nodes=num_nodes + ) + + L.seed_everything(1337, workers=True) # same seed for every process to init model (FSDP) + + trainer.print(hparams) + + if trainer.global_rank == 0: + out_dir.mkdir(parents=True, exist_ok=True) + + config = Config.from_name(model_name) + trainer.print(f"Loading model with {config.__dict__}") + t0 = time.perf_counter() + model = 
LightningGPTModule(config) + trainer.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") + + train_data = Dataset(str(data_dir / "train.bin"), config.block_size) + val_data = Dataset(str(data_dir / "val.bin"), config.block_size) + train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2) + val_dataloader = DataLoader(val_data, batch_size=micro_batch_size, num_workers=2) + + t0 = time.perf_counter() + trainer.fit(model, train_dataloader, val_dataloader, ckpt_path="last") + trainer.print(f"Training time: {(time.perf_counter()-t0):.2f}s") + if trainer.strategy.root_device.type == "cuda": + trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") class Dataset(IterableDataset): diff --git a/sample_workloads/lit-gpt-demo/scripts/litgpt_container_entrypoint.sh b/sample_workloads/lit-gpt-demo/scripts/litgpt_container_entrypoint.sh index 1999904f..906dc715 100644 --- a/sample_workloads/lit-gpt-demo/scripts/litgpt_container_entrypoint.sh +++ b/sample_workloads/lit-gpt-demo/scripts/litgpt_container_entrypoint.sh @@ -12,6 +12,7 @@ set -o pipefail : "${GCS_DATA_BUCKET:?Must set GCS_DATA_BUCKET}" : "${DATA_DIR:?Must set DATA_DIR}" : "${CLUSTER_TYPE:='GKE'}" +: "${COLLECT_NSYS_PROFILE:='no'}" export EXPERIMENT_LOCAL_DIR=/experiment/${EXPERIMENT_ROOT_DIR} @@ -33,6 +34,9 @@ export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) LOG_DIR=$EXPERIMENT_LOCAL_DIR/training_logs mkdir -p $LOG_DIR +PROFILING_DIR=$EXPERIMENT_LOCAL_DIR/nsys_profiles +mkdir -p $PROFILING_DIR + OUT_DIR=$EXPERIMENT_LOCAL_DIR/out mkdir -p $OUT_DIR @@ -145,7 +149,6 @@ fi PIDS=() - CPU_SETS=( "0-7,104-111" "8-15,112-119" "16-23,120-127" "24-31,128-135" "52-59,156-163" "60-67,164-171" "68-75,172-179" "76-83,180-187" ) for ((LOCAL_RANK=0; LOCAL_RANK <= $((GPUS_PER_NODE - 1)); LOCAL_RANK++)); do @@ -161,6 +164,10 @@ for ((LOCAL_RANK=0; LOCAL_RANK <= $((GPUS_PER_NODE - 1)); LOCAL_RANK++)); do fi CMD_PREFIX="numactl --membind=$MEMBIND_NUMA_NODE --physcpubind $CPUS" + if [[ "${COLLECT_NSYS_PROFILE:="no"}" == "yes" ]]; then + echo "Collecting nsys profile" + CMD_PREFIX="${CMD_PREFIX} nsys profile --sample=none --trace=cuda,nvtx -o $PROFILING_DIR/node_${NODE_RANK:?}_local_rank_${LOCAL_RANK} --capture-range=cudaProfilerApi --capture-range-end=repeat:${PROFILE_REPS:=5} --export sqlite " + fi RANK=$RANK LOCAL_RANK=$LOCAL_RANK \ $CMD_PREFIX \ diff --git a/sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py b/sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py new file mode 100644 index 00000000..b24e645d --- /dev/null +++ b/sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py @@ -0,0 +1,594 @@ +"""A utility to trace torch.distributed calls. + +Traces torch.distributed collectives before dispatch. In particular, logs the +collective kind (all_reduce, all_to_all, ..), message size (10 MB), and which +GPU devices are participating ([0, 1, 6, 7]). These are logged as NVTX markers +by NVIDIA Nsight, as well as printed to stdout. By default, we only log +cross-node collective communications. + +To assist with computing the effective bandwidth of a collective, a nominal +expression is provided in the doc string of each 'traced_'. This +also requires extracting the timings of the corresponding NCCL kernels. + +Typical usage example: + + import utilities.monitor_collectives + utilities.monitor_collectives.shunt_torch_communication() + +When running a workload, also define TORCH_DISTRIBUTED_TRACING to be one of +'ALL' or 'CROSSNODE'. 
See `should_rank_record_comm` for added details. +""" + + +import functools +import inspect +import io +import json +import os +import pickle +import sys +from datetime import datetime +import calendar +import uuid + +import nvtx +import torch.cuda +import torch.distributed + + +_TRACE_MODE = None + + +# Note: By default, we only target tracing *cross-node* communications. +# See 'should_rank_record_comm' +def shunt_torch_communication(): + _identify_trace_mode() + if _TRACE_MODE == 'none': + if int(os.environ.get("RANK", "0")) == 0: + print('Tracing torch.distributed collectives disabled.', flush=True) + return + + _shunt_torch_communication_objects() + _shunt_torch_communication_calls() + + if int(os.environ.get("RANK", "0")) == 0: + print('NVTX and print tracing of torch.distributed collectives enabled.', + flush=True) + print(f"{_GPU_SERIAL=}, {_VM_ID=}") + + if not _SHOULD_PRINT: + print('Collectives are traced but will not be printed to stdout', flush=True) + + +def _identify_trace_mode(): + global _TRACE_MODE + _TRACE_MODE = os.environ.get('TORCH_DISTRIBUTED_TRACING', 'CROSSNODE') + _TRACE_MODE = _TRACE_MODE.lower() + + global _SHOULD_PRINT + _SHOULD_PRINT = os.environ.get('TORCH_DISTRIBUTED_TRACING_PRINT', 'False') + _SHOULD_PRINT = _SHOULD_PRINT.lower() in ['true', '1', 't', 'y', 'yes'] + + global _GPU_SERIAL + _GPU_SERIAL = os.environ.get("GPU_SERIAL", "unknown") + global _VM_ID + _VM_ID = os.environ.get("VM_ID", "unknown") + + +# Each wrapper should match format 'traced_' +def _shunt_torch_communication_calls(): + """Replaces torch.distributed. with a traced version. + """ + target_collectives = [ + 'barrier', + 'broadcast_object_list', + 'broadcast', + 'gather', + 'scatter', + 'reduce', + 'reduce_scatter', + 'reduce_scatter_tensor', + 'all_reduce', + 'all_gather', + 'all_gather_into_tensor', + 'all_to_all', + 'all_to_all_single', + 'batch_isend_irecv', + 'isend', + 'irecv', + 'send', + 'recv', + ] + + this_module = sys.modules[__name__] + for collective in target_collectives: + original_fn = getattr(torch.distributed, collective) + replaced_fn = getattr(this_module, 'traced_' + collective) + setattr(torch.distributed, 'untraced_' + collective, original_fn) + setattr(torch.distributed, collective, replaced_fn) + + +def _shunt_torch_communication_objects(): + original_p2p = torch.distributed.P2POp + setattr(torch.distributed, 'UntracedP2POp', original_p2p) + setattr(torch.distributed, 'P2POp', _TracedP2POp) + + +# Each 'traced_' defines a 'message_size' to compute B/W. +# Ref https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_barrier(group=None, async_op=False, device_ids=None): + """Intercepts invocations of torch.distributed.barrier. + """ + if _should_rank_record_comm(group): + _emit_call_description('barrier', message_size=1, group=group) + + return torch.distributed.untraced_barrier(group, async_op, device_ids) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_broadcast_object_list(object_list, src=0, group=None, device=None): + """Intercepts invocations of torch.distributed.broadcast_object_list. + + Converts objects to tensor data using the pickle library. Then conducts a + torch.distributed.broadcast call. + """ + + if _should_rank_record_comm(group, root_rank=src): + message_size = 0 + for obj in object_list: + # Note: This computation is sadly redundant with underlying call :( + # For now we don't expect this invocation to be in critical path. 
+ buf = io.BytesIO() + pickle.Pickler(buf).dump(obj) + message_size += buf.getbuffer().nbytes + _emit_call_description( + 'broadcast_object_list', message_size, group, root_rank=src) + + return torch.distributed.untraced_broadcast_object_list( + object_list, src, group, device) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_broadcast(tensor, src, group=None, async_op=False): + """Intercepts invocations of torch.distributed.broadcast. + + Calculate [Ring-B/W] = [Message Size]/[Kernel Time] for large [Message Size] + + https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf + """ + if _should_rank_record_comm(group, root_rank=src): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('broadcast', message_size, group, root_rank=src) + + return torch.distributed.untraced_broadcast( + tensor, src, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_gather( + tensor, gather_list=None, dst=0, group=None, async_op=False): + """Intercepts invocations of torch.distributed.gather. + + Let T := sum([Receive Kernel Time from Rank i] for i != dst) + Calculate [P2P-B/W] = [Message Size]/T + + Each of (n-1) ranks sends a message to the root. + + Note that any correction factors for the bus bandwidth (e.g. [n-1]/n) depend + on the *definition* of 'Message Size'. In some cases, such as for 'gather', we + define 'Message Size' so as to omit the size of data that is already local + to the destination GPU for the 'gather' operation. In this case, no correction + factor is needed. In NCCL tests, they assume all ranks send equal sized + messages and include this size of data already resident on the destination + GPU. Thus, in there case you see a (n-1)/n correction factor on calculating + the bus bandwidth. In general, the goal of computing the bus bandwidth is + to compare data transfer rates on the bus relative to peak bus bandwidth. + See https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md. + + https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/gather.cu#L54 + https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1040 + """ + if _should_rank_record_comm(group, root_rank=dst, is_ring=False): + message_size = functools.reduce( + lambda sz, x: sz + x.nelement() * x.element_size(), gather_list, 0) + message_size -= tensor.nelement() * tensor.element_size() + + _emit_call_description('gather', message_size, group, root_rank=dst) + + return torch.distributed.untraced_gather( + tensor, gather_list, dst, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_scatter( + tensor, scatter_list=None, src=0, group=None, async_op=False): + """Intercepts invocations of torch.distributed.scatter. + + Let T := sum([Send Kernel Time from Rank i] for i != src) + Calculate [P2P-B/W] = [Message Size]/T + + Each of (n-1) ranks receives a message from the root. + There is no (n-1)/n factor as we factor it in [Message Size]. 
+ + https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/scatter.cu#L50 + https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L1089 + """ + if _should_rank_record_comm(group, root_rank=src, is_ring=False): + message_size = functools.reduce( + lambda sz, x: sz + x.nelement() * x.element_size(), scatter_list, 0) + message_size -= tensor.nelement() * tensor.element_size() + + _emit_call_description('scatter', message_size, group, root_rank=src) + + return torch.distributed.untraced_scatter( + tensor, scatter_list, src, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_reduce( + tensor, dst, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): + """Intercepts invocations of torch.distributed.reduce. + + Calculate [Ring-B/W] = [Message Size]/[Kernel Time] for large [Message Size] + Also see 'traced_broadcast' + """ + if _should_rank_record_comm(group, root_rank=dst): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('reduce', message_size, group, root_rank=dst) + + return torch.distributed.untraced_reduce(tensor, dst, op, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_reduce_scatter( + output, + input_list, + op=torch.distributed.ReduceOp.SUM, + group=None, + async_op=False): + """Intercepts invocations of torch.distributed.reduce_scatter. + + Let n := [Group Size]. + Calculate [Ring-B/W] = (n-1)/n * [Message Size]/[Kernel Time] + Assumes equal tensor sizes. It's the same as first half of ring All-Reduce. + """ + if _should_rank_record_comm(group): + message_size = output.nelement() * output.element_size() + _emit_call_description('reduce_scatter', message_size, group) + + return torch.distributed.untraced_reduce_scatter( + output, input_list, op, group, async_op) + + +# pylint: disable=redefined-builtin,g-doc-args,g-doc-return-or-yield +def traced_reduce_scatter_tensor( + output, + input, + op=torch.distributed.ReduceOp.SUM, + group=None, + async_op=False): + """Intercepts invocations of torch.distributed.reduce_scatter_tensor. + + Similar to 'traced_reduce_scatter' + """ + + if _should_rank_record_comm(group): + message_size = output.nelement() * output.element_size() + _emit_call_description('reduce_scatter', message_size, group) + + return torch.distributed.untraced_reduce_scatter_tensor( + output, input, op, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_all_reduce( + tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): + """Intercepts invocations of torch.distributed.all_reduce. + + Let n := [Group Size] + Calculate [Ring-B/W] = 2(n-1)/n * [Message Size] / [Kernel Time] + + https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf + """ + if _should_rank_record_comm(group): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('all_reduce', message_size, group) + + return torch.distributed.untraced_all_reduce( + tensor, op, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_all_gather(tensor_list, tensor, group=None, async_op=False): + """Intercepts invocations of torch.distributed.all_gather. + + Let n := [Group Size] + Calculate [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time] + Assuming equal tensor sizes. 
+ """ + if _should_rank_record_comm(group): + message_size = functools.reduce( + lambda size, x: size + x.nelement() * x.element_size(), tensor_list, 0) + _emit_call_description('all_gather', message_size, group) + + return torch.distributed.untraced_all_gather( + tensor_list, tensor, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_all_gather_into_tensor( + output_tensor, input_tensor, group=None, async_op=False): + """Intercepts invocations of torch.distributed.all_gather_into_tensor. + + Similar 'traced_all_gather' + """ + if _should_rank_record_comm(group): + message_size = output_tensor.nelement() * output_tensor.element_size() + _emit_call_description('all_gather', message_size, group) + + return torch.distributed.untraced_all_gather_into_tensor( + output_tensor, input_tensor, group, async_op) + + +# Note: The TCP Direct team intends to implement a custom version of AllToAll. +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_all_to_all( + output_tensor_list, input_tensor_list, group=None, async_op=False): + """Intercepts invocations of torch.distributed.all_to_all. + + Let S := sum([Message Size on Rank i] for i = 1..n) where n := [Group Size] + Let T := [End of last Receive last rank] - [Start of first Send first rank] + Calculate [Algo B/W] = S / T. + + There is no n/(n-1) correction factor as we factor it in [Message Size]. + + https://github.com/NVIDIA/nccl-tests/blob/1a5f551ffd6e/src/alltoall.cu#L57 + https://github.com/pytorch/pytorch/blob/bfd995f0d6bf/torch/csrc/cuda/nccl.cpp#L911 + """ + if _should_rank_record_comm(group): + message_size = functools.reduce( + lambda s, x: s + x.nelement() * x.element_size(), input_tensor_list, 0) + + # Omit bytes corresponding to send and receive on the same rank + self_tensor = input_tensor_list[torch.distributed.get_rank(group)] + message_size -= self_tensor.nelement() * self_tensor.element_size() + + _emit_call_description('all_to_all', message_size, group) + + return torch.distributed.untraced_all_to_all( + output_tensor_list, input_tensor_list, group, async_op) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield,redefined-builtin +def traced_all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False): + """Intercepts invocations of torch.distributed.all_to_all_single. + + Similar to 'traced_all_to_all' + """ + if _should_rank_record_comm(group): + self_rank = torch.distributed.get_rank(group) + + if input_split_sizes is not None: + self_slice = input_split_sizes[self_rank] + else: + self_slice = input.size(dim=0) / torch.distributed.get_world_size(group) + + slice_nelement = input.nelement() / input.size(dim=0) + message_size = input.nelement() * input.element_size() + message_size -= self_slice * slice_nelement * input.element_size() + + _emit_call_description('all_to_all_single', message_size, group) + + return torch.distributed.untraced_all_to_all_single( + output, input, output_split_sizes, input_split_sizes, group, async_op) + + +# Note: Each send and receive occurs on indepenent CUDA streams +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_batch_isend_irecv(p2p_op_list): + """Intercepts invocations of torch.distributed.batch_isend_irecv. + + Calculate [P2P-B/W] = [Message Size]/[Kernel Time] for each send and recv. 
+ """ + correlation_id = str(uuid.uuid4()) + for p2p in p2p_op_list: + if _SHOULD_PRINT: + print(f"Num p2p ops in batch: {len(p2p_op_list)}") + if _should_rank_record_comm(p2p.group, peer_rank=p2p.peer, is_ring=False): + api = 'send' if p2p.op == torch.distributed.untraced_isend else 'recv' + + message_size = p2p.tensor.nelement() * p2p.tensor.element_size() + _emit_call_description(api, message_size, group=p2p.group, peer_rank=p2p.peer, correlation_id=correlation_id) + + return torch.distributed.untraced_batch_isend_irecv(p2p_op_list) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_isend(tensor, dst, group=None, tag=0): + """Intercepts invocations of torch.distributed.isend. + + Calculate [P2P-B/W] = [Message Size]/[Kernel Time] + """ + if _should_rank_record_comm(group, peer_rank=dst, is_ring=False): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('send', message_size, group, dst) + + return torch.distributed.untraced_isend(tensor, dst, group, tag) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_irecv(tensor, src=None, group=None, tag=0): + """Intercepts invocations of torch.distributed.irecv. + """ + if _should_rank_record_comm(group, peer_rank=src, is_ring=False): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('recv', message_size, group, src) + + return torch.distributed.untraced_irecv(tensor, src, group, tag) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_send(tensor, dst, group=None, tag=0): + """Intercepts invocations of torch.distributed.send. + """ + if _should_rank_record_comm(group, peer_rank=dst, is_ring=False): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('send', message_size, group, dst) + + return torch.distributed.untraced_send(tensor, dst, group, tag) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def traced_recv(tensor, src=None, group=None, tag=0): + """Intercepts invocations of torch.distributed.recv. + """ + if _should_rank_record_comm(group, peer_rank=src, is_ring=False): + message_size = tensor.nelement() * tensor.element_size() + _emit_call_description('recv', message_size, group, src) + + return torch.distributed.untraced_recv(tensor, src, group, tag) + + +@functools.lru_cache(maxsize=None) +def _should_rank_record_comm( + group=None, peer_rank=None, root_rank=None, is_ring=True): + """Decides whether a given torch.distributed collective should be recorded. + + Args: + group: The torch process group (i.e. participating GPUs) in this collective. + peer_rank: In direct peer to peer operations, the global rank of the peer. + root_rank: The global rank of the root GPU, for collectives with a root. + as_ring: Whether the default NCCL implementation uses a ring algorithm. + Specifying 'peer_rank' and 'is_ring=True' are incompatible. + + Returns: + Whether to record a descriptive NVTX marker, and possibly print a log trace. + """ + if not _is_current_process_in_group(group): + return False + if _TRACE_MODE == 'crossnode' and not _is_crossnode_comm(group, peer_rank): + return False + if not is_ring and root_rank is not None: + return torch.distributed.get_rank() == root_rank + + return True + + +def _is_current_process_in_group(group=None): + return torch.distributed.get_rank(group) >= 0 + + +@functools.lru_cache(maxsize=None) +def _is_crossnode_comm(group=None, peer_rank=None): + """Whether this collective involves communication across nodes. + + Args: + group: The torch process group (i.e. 
participating GPUs) in this collective. + peer: In direct peer to peer operations, the global rank of the peer. + + Returns: + Whether this collective involves communications across nodes. + """ + count_per_node = torch.cuda.device_count() + + if peer_rank is not None: + this_node = int(torch.distributed.get_rank() / count_per_node) + peer_node = int(peer_rank / count_per_node) + return this_node != peer_node + else: + if group is not None: + ranks = torch.distributed.get_process_group_ranks(group=group) + else: + ranks = [*range(torch.distributed.get_world_size())] + + nodes = list(map(lambda rank: int(rank / count_per_node), ranks)) + return any([node != nodes[0] for node in nodes]) + + +def _emit_call_description( + name, message_size, group=None, peer_rank=None, root_rank=None, correlation_id=None): + call_description = _TorchDistributedCallDescriptor( + name, message_size, group, peer_rank, root_rank, correlation_id).to_json() + + nvtx.mark(call_description) + if _should_rank_print(group, peer_rank, root_rank): + print(call_description) + + +class _TorchDistributedCallDescriptor: + """Description of a torch.distributed comm call to be stored as NVTX marker. + """ + + def __init__( + self, name, message_size, group=None, peer_rank=None, root_rank=None, correlation_id=None): + self.name = name + self.rank = torch.distributed.get_rank() + self.source_line = _get_call_source_line() + self.message_size = message_size + self.device = torch.cuda.current_device() + self.timestamp = calendar.timegm(datetime.utcnow().utctimetuple()) + self.gpu_serial = _GPU_SERIAL + self.vm_id = _VM_ID + if group is not None: + self.group_ranks = torch.distributed.get_process_group_ranks(group=group) + if peer_rank is not None: + self.peer_rank = peer_rank + if root_rank is not None: + self.root_rank = root_rank + if correlation_id is not None: + self.correlation_id = correlation_id + + def to_json(self): + return json.dumps(self, default=lambda o: o.__dict__) + + +def _should_rank_print(group=None, peer_rank=None, root_rank=None): + if not _SHOULD_PRINT: + return False + if root_rank is not None: + leader = root_rank + elif group is not None: + leader = torch.distributed.get_global_rank(group, 0) + else: + leader = 0 + + return (peer_rank is not None) or torch.distributed.get_rank() == leader + + +# A fixed depth works for all cases here +def _get_call_source_line(depth=4): + caller = inspect.getframeinfo(inspect.stack()[depth][0]) + return '{}:{}'.format(caller.filename, caller.lineno) + + +# We need to un-hide the original type for 'batch_isend_irecv' due to type +# checks performed by torch.distributed. This is not an issue as by then we +# have already recorded the call. +class _TracedP2POp(torch.distributed.P2POp): + """Used to redirect torch.distributed.i{send,recv} on 'batch_isend_irecv'. + """ + + def __init__(self, op, tensor, peer, group=None, tag=0): + original_op = _get_original_p2p_op(op) + torch.distributed.UntracedP2POp.__init__( + self, original_op, tensor, peer, group, tag) + + def __new__(cls, op, tensor, peer, group=None, tag=0): + original_op = _get_original_p2p_op(op) + return torch.distributed.UntracedP2POp.__new__( + cls, original_op, tensor, peer, group, tag) + + +def _get_original_p2p_op(op): + if op == torch.distributed.isend: + return torch.distributed.untraced_isend + elif op == torch.distributed.irecv: + return torch.distributed.untraced_irecv \ No newline at end of file
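
As a post-processing aid, the nominal bandwidth expressions quoted in the traced_* doc strings above (for example, 2(n-1)/n * [Message Size]/[Kernel Time] for all_reduce) can be evaluated once the matching NCCL kernel time has been read out of the Nsight Systems profile exported by litgpt_container_entrypoint.sh. The sketch below is illustrative only and is not part of the change itself: the helper name effective_bus_bandwidth_gbps and the sample numbers are assumptions, and extracting the kernel time from the exported trace is left to whatever inspection workflow is in use.

# Minimal post-processing sketch (not part of the diff above). Applies the
# correction factors documented in monitor_collectives.py to a message size
# logged by the tracer and a kernel time taken from the Nsight Systems trace.
def effective_bus_bandwidth_gbps(collective: str,
                                 message_size_bytes: int,
                                 kernel_time_s: float,
                                 group_size: int) -> float:
    """Returns an approximate bus bandwidth in GB/s for one traced collective."""
    n = group_size
    if collective == "all_reduce":
        # Ring all-reduce: 2(n-1)/n * [Message Size]/[Kernel Time]
        factor = 2 * (n - 1) / n
    elif collective in ("all_gather", "reduce_scatter"):
        # Single ring pass: (n-1)/n * [Message Size]/[Kernel Time]
        factor = (n - 1) / n
    else:
        # broadcast, reduce, and point-to-point: [Message Size]/[Kernel Time]
        factor = 1.0
    algo_bandwidth = message_size_bytes / kernel_time_s  # bytes per second
    return factor * algo_bandwidth / 1e9


if __name__ == "__main__":
    # Hypothetical example: a 10 MiB all_reduce across 16 ranks whose NCCL
    # kernel took 1.2 ms.
    print(effective_bus_bandwidth_gbps("all_reduce", 10 * 1024 * 1024, 1.2e-3, 16))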