Skip to content

Commit

Permalink
[microsoft/Megatron-DeepSpeed sync] Commits including 2021-08-09 (#58)
Browse files Browse the repository at this point in the history
* Use new zero.Init() API (#10)

* query deepspeed global grad norm (#8)

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Shaden Smith <[email protected]>
  • Loading branch information
3 people authored Aug 10, 2021
1 parent effb2fb commit 3c9d748
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def train_step(forward_step_func, data_iterator,
assert isinstance(model[0], deepspeed.PipelineEngine), model
loss = model[0].train_batch(data_iter=data_iterator)
skipped_iter = 0
grad_norm = 0.
grad_norm = model[0].get_global_grad_norm()
num_zeros_in_grad = 0
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad

Expand Down
12 changes: 9 additions & 3 deletions pretrain_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import os
import subprocess


def model_provider(pre_process=True, post_process=True):
"""Build the model."""

Expand All @@ -41,9 +42,10 @@ def model_provider(pre_process=True, post_process=True):

args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device=='none' else args.remote_device,
config=args.deepspeed_config,
enabled=args.zero_stage==3):
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
model = GPTModelPipe(
num_tokentypes=0,
Expand Down Expand Up @@ -112,6 +114,7 @@ def get_batch(data_iterator):

return tokens, labels, loss_mask, attention_mask, position_ids


def get_batch_pipe(data):
"""Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
args = get_args()
Expand Down Expand Up @@ -139,6 +142,7 @@ def get_batch_pipe(data):

return (tokens, position_ids, attention_mask), (labels, loss_mask)


def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
Expand Down Expand Up @@ -185,10 +189,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

return train_ds, valid_ds, test_ds


def command_exists(cmd):
result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
return result.wait() == 0


def git_ds_info():
from deepspeed.env_report import main as ds_report
ds_report()
Expand Down

0 comments on commit 3c9d748

Please sign in to comment.