From 90b8f46dd2b3d349a39ddc47ecdccd4e1724d198 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Sun, 10 Jul 2022 13:15:16 +0200
Subject: [PATCH 01/14] Enable loading ckpt for t0 finetuning

---
 finetune_t0_non_causal_decoder.py |  2 +-
 megatron/checkpointing.py         |  2 +-
 megatron/model/gpt_model.py       |  3 ++-
 megatron/utils.py                 | 30 +++++++++++++++++++++++-------
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py
index 14650a6e5..44845baf2 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0_non_causal_decoder.py
@@ -89,7 +89,7 @@ def get_batch_pipe(data):
     attention_mask = get_packed_attention_mask(
         # Run non-causal decoder
         is_causal=False,
-        causal_mask=~(causal_mask.bool()),
+        causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
         decoder_is_inputs=decoder_is_inputs.bool(),
         segment_ids=segment_ids.long(),
     )
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index d9a30f468..4ed6bcc73 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -273,7 +273,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

     if args.deepspeed:
         load_optimizer_states = False if args.no_load_optim else True
-        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
+        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
         if loaded_dir is None:
             print_rank_0('WARNING: could not find the metadata file {} '.format(
                 load_dir))
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index a9e3e2604..6896df7f4 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -252,7 +252,8 @@ def _to_float16(inputs):
                     args.num_layers),
                 layer_number=layer_idx,
                 # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
-                self_attn_mask_type=attn_mask_type))
+                self_attn_mask_type=attn_mask_type)
+            )

         # Undo data format change
         def undo(x):
diff --git a/megatron/utils.py b/megatron/utils.py
index 6f3a0fa41..f160fc8da 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
         - segment_ids: torch.IntTensor [batch_size, sequence_length]
     Returns:
         - attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]
+
+    Input example for the mask examples:
+    att_mask_batch = 1
+    seq_length = 7
+    decoder_is_inputs = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
+    segment_ids = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+    causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
     """
     """Causal Inputs Mask:
-    mask = [[[[1, 1, 0, 0, 0, 0, 0],
-              [1, 1, 0, 0, 0, 0, 0],
+    mask = [[[[1, 1, 0, 1, 1, 0, 0],
+              [1, 1, 0, 1, 1, 0, 0],
               [1, 1, 1, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 0, 0],
               [1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 1, 1, 1, 0],
-              [0, 0, 0, 0, 0, 0, 0]]]]
+              [0, 0, 0, 0, 0, 0, 1]]]]
     """

     segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]
@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
               [0, 0, 0, 1, 1, 0, 0],
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 0, 0, 0, 0]]]]
+
+    If is_causal=True:
+    mask = [[[[1, 0, 0, 0, 0, 0, 0],
+              [1, 1, 0, 0, 0, 0, 0],
+              [1, 1, 1, 0, 0, 0, 0],
+              [0, 0, 0, 1, 0, 0, 0],
+              [0, 0, 0, 1, 1, 0, 0],
+              [0, 0, 0, 1, 1, 1, 0],
+              [0, 0, 0, 0, 0, 0, 0]]]]
+
     """
-    attention_mask = causal_inputs_mask * padding_mask * segment_mask

-    # Convert attention mask to binary:
-    attention_mask = (attention_mask < 0.5)
+    attention_mask = causal_inputs_mask * padding_mask * segment_mask

-    return attention_mask
+    # True for places we do not want to attend to
+    return ~attention_mask

 def param_size(parameter):
     return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()

From abdd703028a9a125dd8d525cbf8e866e40b5e1a4 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 11 Jul 2022 12:40:45 +0200
Subject: [PATCH 02/14] Swap decoder_is_inputs & segment_ids

---
 megatron/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index f160fc8da..2efa760fc 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -265,8 +265,8 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
     Input example for the mask examples:
     att_mask_batch = 1
     seq_length = 7
-    decoder_is_inputs = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
-    segment_ids = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+    decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+    segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
     causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
     """

From 0fcb19c12eb22a3dd56d3bbba91d9cef27eb11ce Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 11 Jul 2022 14:38:36 +0200
Subject: [PATCH 03/14] Add prepend-space arg

---
 tools/preprocess_data.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index 953378680..d3d23186f 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -82,6 +82,8 @@ def encode(self, json_line):
         ids = {}
         for key in self.args.json_keys:
             text = data[key]
+            if self.args.prepend_space:
+                text = " " + text
             doc_ids = []
             for sentence in Encoder.splitter.tokenize(text):
                 sentence_ids = Encoder.tokenizer.tokenize(sentence)
@@ -117,6 +119,8 @@ def get_args():
                        help='Path to the BPE merge file (if necessary).')
     group.add_argument('--append-eod', action='store_true',
                        help='Append an token to the end of a document.')
+    group.add_argument('--prepend-space', action='store_true',
+                       help='Prepends a space to the beginning of a document')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,

From 63daa46f8f4274c3a4a12c8f19341ab28cca134e Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Mon, 11 Jul 2022 15:02:51 +0200
Subject: [PATCH 04/14] Update tools/preprocess_data.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
---
 tools/preprocess_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index d3d23186f..bd8cbdb13 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -83,7 +83,7 @@ def encode(self, json_line):
         for key in self.args.json_keys:
             text = data[key]
             if self.args.prepend_space:
-                text = " " + text
+                text = f" {text}"
             doc_ids = []
             for sentence in Encoder.splitter.tokenize(text):
                 sentence_ids = Encoder.tokenizer.tokenize(sentence)

From 89460c0a3200e3d849096d02088a37087e546307 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 11 Jul 2022 16:33:59 +0200
Subject: [PATCH 05/14] Add helpers & set is_causal to true

---
 finetune_t0_non_causal_decoder.py | 32 +++++++++++++++++++++++++++++--
 megatron/model/gpt_model.py       | 12 ++++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py
index 44845baf2..3d60a2748 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0_non_causal_decoder.py
@@ -47,6 +47,31 @@ def model_provider(pre_process=True, post_process=True):
     see_memory_usage(f"After Building Model", force=True)
     return model

+def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
+    tokenizer = get_tokenizer()
+    import os
+    if os.path.exists("batchoutput.json"):
+        return
+    out = {
+        "tokens": tokens[0,:].tolist(),
+        "detokens": tokenizer.detokenize(tokens[0,:].tolist()),
+        "labels": labels[0,:].tolist(),
+        "attention_mask": attention_mask[0,:].tolist(),
+        "loss_mask": loss_mask[0,:].tolist(),
+    }
+    import json
+    with open('batchoutput.json', 'w') as fp:
+        json.dump(out, fp)
+
+    #if os.path.exists("batchoutput.txt"):
+    #    return
+    #with open("batchoutput.txt", "w", encoding="UTF-8") as f:
+    #    batch_log_string = f"TOKENS\n{tokens[0,:].tolist()}\n\nDETOKENS\n{[tokenizer.detokenize(tokens[0,:])]}\n\n"
+    #    batch_log_string += f"LABELS\n{labels[0,:].tolist()}\n\nttention_mask\n{attention_mask[0,:].tolist()}\n\n"
+    #    batch_log_string += f"LABELS\n{loss_mask[0,:].tolist()}"
+    #    print(batch_log_string)
+    #    f.write(batch_log_string)
+
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -83,17 +108,20 @@ def get_batch_pipe(data):
     )

     # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
     loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
-    loss_on_non_pad_only = (tokens != tokenizer.pad)
+    loss_on_non_pad_only = (labels != tokenizer.pad)
     loss_mask *= loss_on_targets_only * loss_on_non_pad_only

     attention_mask = get_packed_attention_mask(
         # Run non-causal decoder
-        is_causal=False,
+        is_causal=True,
         causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
         decoder_is_inputs=decoder_is_inputs.bool(),
         segment_ids=segment_ids.long(),
     )

+    # Helper script
+    # visualize_model_inputs(tokens, attention_mask, labels, loss_mask)
+
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 6896df7f4..85a046dc2 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -158,6 +158,15 @@ def load_state_dict(self, state_dict, strict=True):
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)

+def visualize_outputs(losses):
+    import os
+    if os.path.exists("losses.txt"):
+        return
+    with open("losses.txt", "w", encoding="UTF-8") as f:
+        batch_log_string = f"LOSESS\n{losses[0,:].tolist()}"
+        print(batch_log_string)
+        f.write(batch_log_string)
+
 def get_cross_entropy(is_prefix: bool):
     def CrossEntropy(output, labels):
@@ -167,6 +176,9 @@ def CrossEntropy(output, labels):

         losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

+        # Helper script
+        # visualize_outputs(losses)
+
         if is_prefix:
             micro_batch_size, sequence_length = loss_mask.shape
             average_tokens_per_sample: torch.Tensor

From a55d2fb5d64a27b11fe3a062de7c041efcf56c43 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 11 Jul 2022 18:05:57 +0200
Subject: [PATCH 06/14] JSON helper scripts

---
 finetune_t0_non_causal_decoder.py | 10 +---------
 megatron/model/gpt_model.py       | 13 +++++++------
 2 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py
index 3d60a2748..e2df88996 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0_non_causal_decoder.py
@@ -54,7 +54,7 @@ def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
         return
     out = {
         "tokens": tokens[0,:].tolist(),
-        "detokens": tokenizer.detokenize(tokens[0,:].tolist()),
+        #"detokens": tokenizer.detokenize(tokens[0,:].tolist()),
         "labels": labels[0,:].tolist(),
         "attention_mask": attention_mask[0,:].tolist(),
         "loss_mask": loss_mask[0,:].tolist(),
@@ -63,14 +63,6 @@ def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
     with open('batchoutput.json', 'w') as fp:
         json.dump(out, fp)

-    #if os.path.exists("batchoutput.txt"):
-    #    return
-    #with open("batchoutput.txt", "w", encoding="UTF-8") as f:
-    #    batch_log_string = f"TOKENS\n{tokens[0,:].tolist()}\n\nDETOKENS\n{[tokenizer.detokenize(tokens[0,:])]}\n\n"
-    #    batch_log_string += f"LABELS\n{labels[0,:].tolist()}\n\nttention_mask\n{attention_mask[0,:].tolist()}\n\n"
-    #    batch_log_string += f"LABELS\n{loss_mask[0,:].tolist()}"
-    #    print(batch_log_string)
-    #    f.write(batch_log_string)

 def get_batch_pipe(data):
     """
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 85a046dc2..4a46814a2 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -160,13 +160,14 @@ def load_state_dict(self, state_dict, strict=True):

 def visualize_outputs(losses):
     import os
-    if os.path.exists("losses.txt"):
+    if os.path.exists("losses.json"):
         return
-    with open("losses.txt", "w", encoding="UTF-8") as f:
-        batch_log_string = f"LOSESS\n{losses[0,:].tolist()}"
-        print(batch_log_string)
-        f.write(batch_log_string)
-
+    out = {
+        "losses": losses[0,:].tolist(),
+    }
+    import json
+    with open('losses.json', 'w') as fp:
+        json.dump(out, fp)

 def get_cross_entropy(is_prefix: bool):
     def CrossEntropy(output, labels):

From 2dfe5d11627e554e11933592c5c78c5428414aef Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 11 Jul 2022 18:06:27 +0200
Subject: [PATCH 07/14] Remove unnec imports

---
 finetune_t0_non_causal_decoder.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py
index e2df88996..54eeb98c2 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0_non_causal_decoder.py
@@ -48,13 +48,11 @@ def model_provider(pre_process=True, post_process=True):
     return model

 def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
-    tokenizer = get_tokenizer()
     import os
     if os.path.exists("batchoutput.json"):
         return
     out = {
         "tokens": tokens[0,:].tolist(),
-        #"detokens": tokenizer.detokenize(tokens[0,:].tolist()),
         "labels": labels[0,:].tolist(),
         "attention_mask": attention_mask[0,:].tolist(),
         "loss_mask": loss_mask[0,:].tolist(),

From ca740f1ee30d85b92894f951565b1d7fcea5d08a Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Tue, 12 Jul 2022 10:53:56 +0200
Subject: [PATCH 08/14] Remove helper scripts

---
 finetune_t0_non_causal_decoder.py | 17 -----------------
 megatron/model/gpt_model.py       | 13 -------------
 2 files changed, 30 deletions(-)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py
index 54eeb98c2..709304c60 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0_non_causal_decoder.py
@@ -47,20 +47,6 @@ def model_provider(pre_process=True, post_process=True):
     see_memory_usage(f"After Building Model", force=True)
     return model

-def visualize_model_inputs(tokens, attention_mask, labels, loss_mask):
-    import os
-    if os.path.exists("batchoutput.json"):
-        return
-    out = {
-        "tokens": tokens[0,:].tolist(),
-        "labels": labels[0,:].tolist(),
-        "attention_mask": attention_mask[0,:].tolist(),
-        "loss_mask": loss_mask[0,:].tolist(),
-    }
-    import json
-    with open('batchoutput.json', 'w') as fp:
-        json.dump(out, fp)
-
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -109,9 +95,6 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )

-    # Helper script
-    # visualize_model_inputs(tokens, attention_mask, labels, loss_mask)
-
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 4a46814a2..6896df7f4 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -158,16 +158,6 @@ def load_state_dict(self, state_dict, strict=True):
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)

-def visualize_outputs(losses):
-    import os
-    if os.path.exists("losses.json"):
-        return
-    out = {
-        "losses": losses[0,:].tolist(),
-    }
-    import json
-    with open('losses.json', 'w') as fp:
-        json.dump(out, fp)

 def get_cross_entropy(is_prefix: bool):
     def CrossEntropy(output, labels):
@@ -177,9 +167,6 @@ def CrossEntropy(output, labels):

         losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

-        # Helper script
-        # visualize_outputs(losses)
-
         if is_prefix:
             micro_batch_size, sequence_length = loss_mask.shape
             average_tokens_per_sample: torch.Tensor

From cb0313ba3e3c2de0b51a1e4ddba317ceb3b0815a Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Wed, 13 Jul 2022 21:16:57 +0200
Subject: [PATCH 09/14] Avoid loading module when not loading optim

---
 megatron/checkpointing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 4ed6bcc73..5bd6a9623 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -272,8 +272,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     load_dir = getattr(args, load_arg)

     if args.deepspeed:
-        load_optimizer_states = False if args.no_load_optim else True
-        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
+        load_optimizer_states = not args.no_load_optim
+        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_module_only=not load_optimizer_states, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
         if loaded_dir is None:
             print_rank_0('WARNING: could not find the metadata file {} '.format(
                 load_dir))

From b62dcafc67c67168447defe1feb6c63f2d6fbb10 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Wed, 13 Jul 2022 21:21:56 +0200
Subject: [PATCH 10/14] Allow not using torch distributed

---
 megatron/data/decoder_packed_mtf_dataset.py | 23 +++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py
index 4edf14207..0ef812544 100644
--- a/megatron/data/decoder_packed_mtf_dataset.py
+++ b/megatron/data/decoder_packed_mtf_dataset.py
@@ -358,7 +358,7 @@ def pack_samples(self, items):
             decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
             decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
             decoder_segment_ids[cur_len: cur_len + total_len] = item_num
-            decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs
+            decoder_is_inputs[cur_len: cur_len + input_token_len] = True # inputs
             # targets are already 0 at init, no need to update `decoder_is_inputs`

             item_num += 1
@@ -399,7 +399,7 @@ def _build_index_mappings(
     shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
         if (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):

@@ -437,15 +437,16 @@ def _build_index_mappings(
         print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
                      ' (seconds): {:4f}'.format(time.time() - start_time))

-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
-    assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+    if torch.distributed.is_initialized():
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+        assert counts[0].item() == (
+            torch.distributed.get_world_size() //
+            torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

     # Load mappings.
     start_time = time.time()

From b15ca2d5a036417435af08348e5327fec5b98009 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Sat, 16 Jul 2022 18:18:09 +0200
Subject: [PATCH 11/14] Add prefixlm arg

---
 finetune_t0_non_causal_decoder.py => finetune_t0.py | 2 +-
 megatron/arguments.py                               | 1 +
 tests/test_dataloaders.py                           | 6 +++---
 tests/test_model.py                                 | 4 ++--
 tests/test_training.py                              | 2 +-
 5 files changed, 8 insertions(+), 7 deletions(-)
 rename finetune_t0_non_causal_decoder.py => finetune_t0.py (99%)

diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0.py
similarity index 99%
rename from finetune_t0_non_causal_decoder.py
rename to finetune_t0.py
index 709304c60..e4af9b91f 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0.py
@@ -89,7 +89,7 @@ def get_batch_pipe(data):
     attention_mask = get_packed_attention_mask(
         # Run non-causal decoder
-        is_causal=True,
+        is_causal=not(args.prefixlm),
         causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
         decoder_is_inputs=decoder_is_inputs.bool(),
         segment_ids=segment_ids.long(),
     )
diff --git a/megatron/arguments.py b/megatron/arguments.py
index cf48d0213..4b63112e9 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -933,6 +933,7 @@ def __call__(self, parser, args, values, option_string=None):
                        'This is mostly used for prefix_lm training')
     group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
     group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")

     return parser

diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 30ec1328f..ce56a23dd 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -7,7 +7,7 @@
 import deepspeed
 import torch

-import finetune_t0_non_causal_decoder
+import finetune_t0
 from megatron import global_vars, get_tokenizer, initialize_megatron, get_args
 from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):

                 last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])

-    def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
+    def test_finetune_t0_get_batch_pipe(self):
         command_args = get_default_args()

         command_args["--position-embedding-type"] = "alibi"
@@ -263,7 +263,7 @@ def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
             special_tokens_ids={tokenizer.pad}
         )

-        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data)
+        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0.get_batch_pipe(data)

         tokens = tokens.cpu()
         position_ids = position_ids.cpu()
diff --git a/tests/test_model.py b/tests/test_model.py
index 390e8664d..0e6a95554 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -20,7 +20,7 @@ from megatron.training import setup_model_and_optimizer

 import pretrain_gpt
 import pretrain_prefix_lm
-import finetune_t0_non_causal_decoder
+import finetune_t0


 def get_default_args(test_file_dir: str):
@@ -456,7 +456,7 @@ def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_i
             vocab_size=args.padded_vocab_size,
             special_tokens_ids={tokenizer.pad}
         )
-        model, _, _ = setup_model_and_optimizer(finetune_t0_non_causal_decoder.model_provider)
+        model, _, _ = setup_model_and_optimizer(finetune_t0.model_provider)
         model = model[0]
         model._config.train_micro_batch_size_per_gpu = args.micro_batch_size
         model.set_train_batch_size(args.micro_batch_size)
diff --git a/tests/test_training.py b/tests/test_training.py
index 2a031cc2d..686f09ed5 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -557,7 +557,7 @@ def test_training_t0(self):
             --deepspeed-activation-checkpointing
         """.split()

-        script = [f"{self.src_dir}/finetune_t0_non_causal_decoder.py"]
+        script = [f"{self.src_dir}/finetune_t0.py"]
         launcher = get_launcher(num_gpus)

         cmd = launcher + script + args + ds_args

From dc8d0abb2d98faf70d823f675062c83be6933899 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 28 Jul 2022 11:52:36 +0200
Subject: [PATCH 12/14] Add bos option

---
 megatron/tokenizer/tokenizer.py | 3 ++-
 tools/preprocess_data.py        | 9 +++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py
index 09304b1dd..931ae9f4e 100644
--- a/megatron/tokenizer/tokenizer.py
+++ b/megatron/tokenizer/tokenizer.py
@@ -380,7 +380,8 @@ def mask(self):

     @property
     def bos(self):
-        raise NotImplementedError("Missing ")
+        candidate = self.tokenizer.bos_token_id
+        return self._check_token_candidate(candidate)

     @property
     def eos(self):
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index bd8cbdb13..f086a689b 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -89,8 +89,11 @@ def encode(self, json_line):
                 sentence_ids = Encoder.tokenizer.tokenize(sentence)
                 if len(sentence_ids) > 0:
                     doc_ids.append(sentence_ids)
-            if len(doc_ids) > 0 and self.args.append_eod:
-                doc_ids[-1].append(Encoder.tokenizer.eod)
+            if len(doc_ids) > 0:
+                if self.args.append_eod:
+                    doc_ids[-1].append(Encoder.tokenizer.eod)
+                elif self.args.append_bos:
+                    doc_ids[-1].append(Encoder.tokenizer.bos)
             ids[key] = doc_ids
         return ids, len(json_line)

@@ -119,6 +122,8 @@ def get_args():
                        help='Path to the BPE merge file (if necessary).')
     group.add_argument('--append-eod', action='store_true',
                        help='Append an token to the end of a document.')
+    group.add_argument('--append-bos', action='store_true',
+                       help='Append a bos token to the end of a document.')
     group.add_argument('--prepend-space', action='store_true',
                        help='Prepends a space to the beginning of a document')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")

From 2699721690da3dd71b83a588c8397355b496cb45 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 15 Aug 2022 20:11:51 +0200
Subject: [PATCH 13/14] Add reset-progress key

---
 megatron/arguments.py     | 2 ++
 megatron/checkpointing.py | 6 +++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 7a90499dc..6622df924 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -664,6 +664,8 @@ def _add_checkpointing_args(parser):
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true', default=None,
                        help='Do not load rng state when loading checkpoint.')
+    group.add_argument('--reset-progress', action='store_true', default=None,
+                       help='Reset iteration to 0 & do not load args.')
     group.add_argument('--finetune', action='store_true',
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 28ebbae1c..ad63213b4 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))

     # Set iteration.
-    if args.finetune or release:
+    if args.finetune or release or args.reset_progress:
         iteration = 0
     else:
         try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             # Check arguments.
             assert args.consumed_train_samples == 0
             assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
+    if 'args' in state_dict and not args.reset_progress:
         checkpoint_args = state_dict['args']
         if not args.universal_checkpoint:
             check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
     return {
         "padded_vocab_size": args.padded_vocab_size,
         "original_vocab_size": tokenizer.vocab_size,
-    }
\ No newline at end of file
+    }

From 1e77844c26c2e488f57e125917ba666d5afcf9c1 Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Thu, 3 Nov 2022 18:38:03 +0100
Subject: [PATCH 14/14] Add option to normalize loss per target (#326)

* Tmp lossseq

* Efficient loss normalization

* Reuse variable

* Simplify division

* Add norm_target_loss arg

* Clarify loss on targets & remove kwarg

* Loss mask is already float

* Move norm to batch pipe

* Reshape loss mask

* Move view
---
 finetune_t0.py              | 12 ++++++++++++
 megatron/arguments.py       |  2 ++
 megatron/model/gpt_model.py |  5 ++++-
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/finetune_t0.py b/finetune_t0.py
index e4af9b91f..a6fbd3a78 100644
--- a/finetune_t0.py
+++ b/finetune_t0.py
@@ -48,6 +48,14 @@ def model_provider(pre_process=True, post_process=True):
     return model

+def fast_normalize(loss_mask: torch.Tensor):
+    """
+    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
+    """
+    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
+    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
+    return loss_mask / counts
+
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -95,6 +103,10 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )

+    if args.norm_target_loss:
+        loss_mask = loss_mask.view(-1)
+        loss_mask = fast_normalize(loss_mask)
+
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 6622df924..a51ac6a33 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -930,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--norm-target-loss', action='store_true',
+                       help='Normalize the loss per target. Used for multi-task finetuning with packing.')
     group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                        help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index 6896df7f4..3b8cbb960 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -15,7 +15,6 @@

 """GPT-2 model."""

-from functools import partial
 import torch

 from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
             else:
                 average_tokens_per_sample = sequence_length
             expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
+        elif args.norm_target_loss:
+            expected_num_of_target_seqs = loss_mask.sum()
+            loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
+            return loss
         else:
             expected_number_of_tokens = loss_mask.sum()
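
For reference, the standalone Python sketch below is not part of the patch series. It replays the per-target loss normalization introduced in PATCH 14: fast_normalize is copied from finetune_t0.py, the example mask comes from that function's docstring, and the final division mirrors the new `elif args.norm_target_loss:` branch in megatron/model/gpt_model.py. The random `losses` tensor and the __main__ wrapper are illustrative stand-ins only, not code from the patches.

import torch


def fast_normalize(loss_mask: torch.Tensor):
    """Copied from finetune_t0.py (PATCH 14): spreads a total weight of 1 over each
    consecutive run of 1s, so every packed target contributes equally to the loss."""
    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    return loss_mask / counts


if __name__ == "__main__":
    # Docstring example: three packed targets of lengths 2, 1 and 3.
    loss_mask = torch.tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1], dtype=torch.float)
    norm = fast_normalize(loss_mask)
    print(norm)  # tensor([0, 0, 0, 0.5, 0.5, 0, 0, 1, 0, 0, 0.333, 0.333, 0.333])

    # After normalization the mask sums to the number of target sequences (3 here),
    # which is exactly the denominator used by the `elif args.norm_target_loss:` branch.
    losses = torch.rand_like(loss_mask)  # stand-in for the per-token cross-entropy values
    expected_num_of_target_seqs = norm.sum()
    loss = torch.sum(losses.view(-1) * norm) / expected_num_of_target_seqs
    print(expected_num_of_target_seqs.item(), loss.item())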