Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tgs metrics #286

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,13 +1032,17 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
tokens_per_sec = samples_per_sec * seq_len
tokens_per_sec_per_replica = tokens_per_sec / args.data_parallel_size
tokens_per_gpu_per_second = tokens_per_sec / args.world_size
tokens_per_gpu_per_second_per_replica = tokens_per_gpu_per_second / args.data_parallel_size
if wandb is not None and getattr(wandb, 'run', None) is not None:
tput = {
'throughput/iteration-time': elapsed_time_per_iteration, # 1000 ms / s
'throughput/samples_per_sec': samples_per_sec,
'throughput/samples_per_sec_per_replica': samples_per_sec_per_replica,
'throughput/tokens_per_sec': tokens_per_sec,
'throughput/tokens_per_sec_per_replica': tokens_per_sec_per_replica,
'throughput/tokens_per_gpu_per_sec': tokens_per_gpu_per_second,
'throughput/tokens_per_gpu_per_sec_per_replica': tokens_per_gpu_per_second_per_replica,
'throughput/tflops': tflops,
'throughput/approx_params_in_billions': approx_parameters_in_billions,
'throughput/elapsed_ms_per_iteration': elapsed_time_per_iteration,
Expand Down Expand Up @@ -1088,6 +1092,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
log_string += ' samples per second: {:.3f} |'.format(samples_per_sec)
log_string += ' tokens per gpu per second (tgs): {:.3f} |'.format(tokens_per_gpu_per_second)
log_string += ' TFLOPs: {:.2f} |'.format(tflops)
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
Expand Down
3 changes: 0 additions & 3 deletions megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,7 @@ def get_parameters_in_billions(model):
return approx_parameters_in_billions*gpus_per_model/(1e9)

def throughput_calculator(model, args, iteration_time, total_iterations):
gpus_per_model = torch.distributed.get_world_size(group = mpu.get_model_parallel_group())
batch_size = args.micro_batch_size * get_num_microbatches() * args.data_parallel_size
samples_per_model = batch_size * args.seq_length
model_replica_count = torch.distributed.get_world_size() / gpus_per_model
approx_parameters_in_billions = None if (model is None) else get_parameters_in_billions(model)
elapsed_time_per_iter = iteration_time/total_iterations
samples_per_second = batch_size / elapsed_time_per_iter
Expand Down