From 1c2f0a115675720292797f46ddb1302f589e34d8 Mon Sep 17 00:00:00 2001
From: cokedong <408244909@qq.com>
Date: Fri, 10 Nov 2023 11:20:56 +0800
Subject: [PATCH] Add tgs metrics

---
 megatron/training.py | 5 +++++
 megatron/utils.py    | 3 ---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 98df5bd2b4..961b14e619 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1032,6 +1032,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
         tokens_per_sec = samples_per_sec * seq_len
         tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size
+        tokens_per_gpu_per_second = tokens_per_sec / args.world_size
+        tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / args.data_parallel_size
         if wandb is not None and getattr(wandb, 'run', None) is not None:
             tput = {
                 'throughput/iteration-time': elapsed_time_per_iteration, # 1000 ms / s
@@ -1039,6 +1041,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica,
                 'throughput/tokens_per_sec': tokens_per_sec,
                 'throughput/tokens_per_sec_per_replica': tokens_per_sec_per_replica,
+                'throughput/tokens_per_gpu_per_sec': tokens_per_gpu_per_second,
+                'throughput/tokens_per_gpu_per_sec_per_replica': tokens_per_gpu_per_second_per_replica,
                 'throughput/tflops': tflops,
                 'throughput/approx_params_in_billions': approx_parameters_in_billions,
                 'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration,
@@ -1088,6 +1092,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' number of nan iterations: {:3d} |'.format(
             total_loss_dict[nan_iters_key])
         log_string += ' samples per second: {:.3f} |'.format(samples_per_sec)
+        log_string += ' tokens per gpu per second (tgs): {:.3f} |'.format(tokens_per_gpu_per_second)
         log_string += ' TFLOPs: {:.2f} |'.format(tflops)
         total_loss_dict[advanced_iters_key] = 0
         total_loss_dict[skipped_iters_key] = 0
diff --git a/megatron/utils.py b/megatron/utils.py
index ea8e7a003b..02bd158dbf 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -251,10 +251,7 @@ def get_parameters_in_billions(model):
     return approx_parameters_in_billions*gpus_per_model/(1e9)
 
 def throughput_calculator(model, args, iteration_time, total_iterations):
-    gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
     batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size
-    samples_per_model = batch_size * args.seq_length
-    model_replica_count = torch.distributed.get_world_size() / gpus_per_model
     approx_parameters_in_billions = None if (model is None) else get_parameters_in_billions(model)
     elapsed_time_per_iter = iteration_time/total_iterations
     samples_per_second = batch_size / elapsed_time_per_iter
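
For reference, a minimal standalone sketch of the TGS (tokens per GPU per second) arithmetic this patch wires into training_log. The function and the example numbers below are hypothetical; the formulas mirror the patched code, where the global batch is micro_batch_size * num_microbatches * data_parallel_size and TGS normalizes token throughput by args.world_size.

    # Hypothetical helper illustrating the TGS metrics added above.
    def tgs_metrics(micro_batch_size, num_microbatches, data_parallel_size,
                    seq_len, world_size, elapsed_time_per_iteration):
        # Global batch = micro batch * gradient-accumulation steps * data-parallel replicas
        batch_size = micro_batch_size * num_microbatches * data_parallel_size
        samples_per_sec = batch_size / elapsed_time_per_iteration
        tokens_per_sec = samples_per_sec * seq_len
        # TGS: token throughput divided by the total number of GPUs
        tokens_per_gpu_per_second = tokens_per_sec / world_size
        tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / data_parallel_size
        return tokens_per_gpu_per_second, tokens_per_gpu_per_second_per_replica

    # Example (hypothetical): 8 GPUs, DP=8, micro batch 2, 4 microbatches,
    # seq len 2048, 1.5 s/iteration -> (2*4*8/1.5) * 2048 / 8 ~= 10922.7 tokens/GPU/s
    print(tgs_metrics(2, 4, 8, 2048, 8, 1.5))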