Fix deepspeed prefix-lm (bigscience-workshop#107)
* Fix pretrain prefix lm using deepspeed

* Fix: self._args to args

* First set attn_mask in model and then build model

* Fix: enforce that we pass down tuple instead of generator

* Attention mask does not need to be transposed

* BIGGEST HACK EVER

* Remove BIGGEST HACK

* Skip prefix test as PP>1 doesn't work yet on deepspeed

* Unskip prefix test

* Merge branch 'main' into thomas/fix_deepspeed_prefix
thomasw21 authored Oct 10, 2021
1 parent: b5098e6 · commit: da31db6
Showing 4 changed files with 132 additions and 11 deletions.
16 changes: 14 additions & 2 deletions megatron/model/gpt_model.py
@@ -207,9 +207,17 @@ def _to_float16(inputs):
                              tied_weight_attr='word_embeddings_weight'))
 
         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            if hasattr(args, 'attn_mask'):
+                self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            else:
+                # EmbeddingPipe returns attention mask as well
+                self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            if hasattr(args, 'attn_mask'):
+                self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            else:
+                # EmbeddingPipe returns attention mask as well
+                self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))
 
         for layer_idx in range(args.num_layers):
             self.specs.append(
@@ -222,6 +230,10 @@ def _to_float16(inputs):
                     self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))
 
 
+        if not hasattr(args, 'attn_mask'):
+            # We drop attention mask from the pipeline
+            self.specs.append(lambda x: x[0])
+
         # Undo data format change
         self.specs.append(lambda x: x.transpose(0, 1).contiguous())
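When args.attn_mask is not precomputed, EmbeddingPipe forwards a (hidden_states, attention_mask) tuple, so the transpose lambdas above rebuild a tuple instead of operating on a bare tensor. One of the squashed commits stresses that a real tuple, not a generator, must be passed down the pipeline. Below is a minimal standalone sketch of that distinction, for illustration only (shapes and names are made up; this is not code from the patch):

import torch

# Stand-in for what EmbeddingPipe forwards when args.attn_mask is absent:
# (hidden_states, attention_mask); shapes are arbitrary for this sketch.
x = (torch.randn(4, 2, 8), torch.ones(1, 1, 4, 4, dtype=torch.bool))

# Star-unpacking builds a concrete tuple, which downstream pipeline stages can consume.
as_tuple = (x[0].transpose(0, 1).contiguous(), *x[1:])

# A parenthesized comprehension instead yields a lazy generator object,
# not a container of tensors, which is what the fix guards against.
as_generator = (t.contiguous() for t in x)

print(type(as_tuple))      # <class 'tuple'>
print(type(as_generator))  # <class 'generator'>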
1 change: 0 additions & 1 deletion megatron/model/language_model.py
@@ -290,7 +290,6 @@ def forward(self, inputs, **kwargs):
         if hasattr(self._args, 'attn_mask'):
             return embeddings
         else:
-            assert False
             return embeddings, attention_mask
16 changes: 8 additions & 8 deletions pretrain_gpt.py
@@ -48,14 +48,6 @@ def model_provider(pre_process=True, post_process=True):
                              enabled=args.zero_stage == 3,
                              mpu=mpu):
         if args.deepspeed:
-            model = GPTModelPipe(
-                num_tokentypes=0,
-                parallel_output=True
-            )
-            # This is a hack to give us a reference to get_batch_pipe from within training.py
-            # We need to call model.set_batch_fn after deepspeed.initialize
-            model._megatron_batch_fn = get_batch_pipe
-
             # Precompute the attention mask and store it in args. This avoids having to
             # pipeline it as an activation during training. The mask is constant, and thus
             # we can reuse it.
@@ -73,6 +65,14 @@
             # must be bool or the training crashes expecting bool, but getting Half
             args.attn_mask = attention_mask.to(torch.bool)
             args.attn_mask_original = attention_mask.to(torch.bool)
+
+            model = GPTModelPipe(
+                num_tokentypes=0,
+                parallel_output=True
+            )
+            # This is a hack to give us a reference to get_batch_pipe from within training.py
+            # We need to call model.set_batch_fn after deepspeed.initialize
+            model._megatron_batch_fn = get_batch_pipe
         else:
             model = GPTModel(
                 num_tokentypes=0,
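The reorder above matters because GPTModelPipe checks hasattr(args, 'attn_mask') while its layer specs are being built, so the mask has to be stored in args before the model is constructed ("First set attn_mask in model and then build model" in the squashed commits). As a rough illustration of what precomputing a constant mask as a bool tensor amounts to, assuming a plain causal mask and made-up shapes rather than the exact helper pretrain_gpt.py uses:

import torch
from types import SimpleNamespace

# Hypothetical stand-in for Megatron's parsed arguments.
args = SimpleNamespace(seq_length=128)

# A constant lower-triangular (causal) mask, broadcastable over batch and heads.
causal = torch.tril(torch.ones(args.seq_length, args.seq_length))
attention_mask = causal.view(1, 1, args.seq_length, args.seq_length)

# must be bool or the training crashes expecting bool, but getting Half
args.attn_mask = attention_mask.to(torch.bool)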
110 changes: 110 additions & 0 deletions tests/test_training.py
@@ -266,3 +266,113 @@ def test_training_all(self, variation):
         # test tensorboard (1 file from the first run, plus 1 now)
         tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
         self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
+
+    def test_training_prefix_lm_all(self):
+        # all in one test
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False)
+
+        pp_size, tp_size, dp_size = get_3d_dimensions()
+        num_gpus = pp_size * tp_size * dp_size
+
+        n_samples = 200  # about 37 iterations
+        exit_interval = 20  # some samples in the first half and then some more in the 2nd half after resume
+        args = f"""
+            --tensor-model-parallel-size {tp_size}
+            --pipeline-model-parallel-size {pp_size}
+            --distributed-backend nccl
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --seq-length 128
+            --max-position-embeddings 1024
+            --micro-batch-size 1
+            --rampup-batch-size 2 2 {n_samples}
+            --global-batch-size 16
+            --train-samples {n_samples}
+            --loss-on-targets-only
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-4
+            --lr-warmup-samples 5
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --fp16
+            --log-interval 5
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --glu-activation geglu
+            --exit-interval {exit_interval}
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --codecarbon-dir {output_dir}/codecarbon
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed_config {self.test_file_dir_str}/ds_config.json
+            --zero-stage 1
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        script = [f"{src_dir}/pretrain_prefix_lm.py"]
+        launcher = get_launcher(num_gpus)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        # 1. test training from scratch (no checkpoint)
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test deepspeed is running
+        self.assertIn("DeepSpeed info", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test there should be no checkpoint this round
+        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard (1 file from the first run, plus 1 now)
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
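Assuming the repository's pytest-based test harness (an assumption; the runner configuration is not part of this diff), the new test can be selected by name from Python, for example:

# Hypothetical local invocation; it needs enough GPUs for the pp/tp/dp layout
# returned by get_3d_dimensions(), plus the tiny GPT-2 test data and ds_config.json.
import pytest

pytest.main(["tests/test_training.py", "-k", "test_training_prefix_lm_all"])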
