diff --git a/megatron/training.py b/megatron/training.py index 22ab5f242..21ef13b94 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -380,7 +380,7 @@ def train_step(forward_step_func, data_iterator, assert isinstance(model[0], deepspeed.PipelineEngine), model loss = model[0].train_batch(data_iter=data_iterator) skipped_iter = 0 - grad_norm = 0. + grad_norm = model[0].get_global_grad_norm() num_zeros_in_grad = 0 return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 065f21213..609644d64 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -33,6 +33,7 @@ import os import subprocess + def model_provider(pre_process=True, post_process=True): """Build the model.""" @@ -41,9 +42,10 @@ def model_provider(pre_process=True, post_process=True): args = get_args() with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), - remote_device=None if args.remote_device=='none' else args.remote_device, - config=args.deepspeed_config, - enabled=args.zero_stage==3): + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): if args.deepspeed: model = GPTModelPipe( num_tokentypes=0, @@ -112,6 +114,7 @@ def get_batch(data_iterator): return tokens, labels, loss_mask, attention_mask, position_ids + def get_batch_pipe(data): """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" args = get_args() @@ -139,6 +142,7 @@ def get_batch_pipe(data): return (tokens, position_ids, attention_mask), (labels, loss_mask) + def loss_func(loss_mask, output_tensor): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() @@ -185,10 +189,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): return train_ds, valid_ds, test_ds + def command_exists(cmd): result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) return result.wait() == 0 + def git_ds_info(): from deepspeed.env_report import main as ds_report ds_report()