From b87a03cc0838e2b92ceccef3ac7f3a77287e4254 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Thu, 1 Feb 2024 13:52:17 -0600
Subject: [PATCH] Track additional metrics with W&B in `megatron/training.py`

---
 megatron/training.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index c7ac2574fb..29d2181d46 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1039,7 +1039,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         tokens_per_gpu_per_second = tokens_per_sec / args.world_size
         tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / args.data_parallel_size
         if wandb is not None and getattr(wandb, 'run', None) is not None:
-            tput = {
+            assert wandb.run is not None
+            wandb_metrics = {
                 'throughput/iteration-time': elapsed_time_per_iteration,  # 1000 ms / s
                 'throughput/samples_per_sec': samples_per_sec,
                 'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica,
@@ -1050,8 +1051,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 'throughput/tflops': tflops,
                 'throughput/approx_params_in_billions': approx_parameters_in_billions,
                 'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration,
+                'throughput/iteration': iteration,
             }
-            wandb.run.log(tput)
+            if loss_dict is not None:
+                wandb_metrics |= {
+                    f'loss/{k}': v for k, v in loss_dict.items()
+                }
+                wandb_metrics |= {'loss/iteration': iteration}
         if writer:
             if args.log_timers_to_tensorboard:
                 writer.add_scalar('iteration-time/iteration-time',
@@ -1060,6 +1066,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                                   elapsed_time_per_iteration, args.consumed_train_samples)
                 writer.add_scalar('iteration-time/iteration-time vs tokens',
                                   elapsed_time_per_iteration, args.consumed_train_tokens)
+        if wandb is not None and getattr(wandb, 'run', None) is not None:
+            wandb_metrics |= {
+                'iteration': iteration,
+                'iteration_time': elapsed_time_per_iteration,
+                'iteration_time_vs_tokens': (
+                    elapsed_time_per_iteration
+                    / args.consumed_train_tokens
+                ),
+                'iteration_time_vs_samples': (
+                    elapsed_time_per_iteration
+                    / args.consumed_train_samples
+                ),
+            }
+        if wandb is not None and getattr(wandb, 'run', None) is not None:
+            wandb.log(wandb_metrics)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
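
The change above accumulates throughput, loss, and iteration-time metrics
into a single `wandb_metrics` dict and emits one `wandb.log(...)` call per
logging interval, so every key lands on the same W&B step. Below is a
minimal, self-contained sketch of that pattern; the function name
`log_training_metrics` and the example values are hypothetical stand-ins
for illustration, not code from the patch, and the dict-merge operator
`|=` assumes Python 3.9+.

    # Sketch: batched W&B logging, guarded so it degrades to a no-op
    # when wandb is not installed or wandb.init() was never called.
    try:
        import wandb
    except ImportError:
        wandb = None

    def log_training_metrics(iteration, elapsed_time_per_iteration,
                             samples_per_sec, loss_dict=None):
        # Same guard as the patch: require both the module and an active run.
        if wandb is None or getattr(wandb, 'run', None) is None:
            return
        wandb_metrics = {
            'throughput/iteration-time': elapsed_time_per_iteration,
            'throughput/samples_per_sec': samples_per_sec,
            'throughput/iteration': iteration,
        }
        if loss_dict is not None:
            # Prefix loss keys so they group under a 'loss/' section in W&B.
            wandb_metrics |= {f'loss/{k}': v for k, v in loss_dict.items()}
            wandb_metrics |= {'loss/iteration': iteration}
        # A single wandb.log() call per interval keeps all metrics on one step.
        wandb.log(wandb_metrics)

    # Example call with made-up values:
    log_training_metrics(iteration=100, elapsed_time_per_iteration=1.25,
                         samples_per_sec=512.0, loss_dict={'lm loss': 2.31})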