diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0.py
similarity index 91%
rename from finetune_t0_non_causal_decoder.py
rename to finetune_t0.py
index 14650a6e5..a6fbd3a78 100644
--- a/finetune_t0_non_causal_decoder.py
+++ b/finetune_t0.py
@@ -47,6 +47,15 @@ def model_provider(pre_process=True, post_process=True):
     see_memory_usage(f"After Building Model", force=True)
     return model
 
+
+def fast_normalize(loss_mask: torch.Tensor):
+    """
+    Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3]
+    """
+    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
+    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
+    return loss_mask / counts
+
 def get_batch_pipe(data):
     """
     Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -83,17 +92,21 @@ def get_batch_pipe(data):
     )
     # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
     loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
-    loss_on_non_pad_only = (tokens != tokenizer.pad)
+    loss_on_non_pad_only = (labels != tokenizer.pad)
     loss_mask *= loss_on_targets_only * loss_on_non_pad_only
 
     attention_mask = get_packed_attention_mask(
         # Run non-causal decoder
-        is_causal=False,
-        causal_mask=~(causal_mask.bool()),
+        is_causal=not(args.prefixlm),
+        causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
         decoder_is_inputs=decoder_is_inputs.bool(),
         segment_ids=segment_ids.long(),
     )
 
+    if args.norm_target_loss:
+        loss_mask = loss_mask.view(-1)
+        loss_mask = fast_normalize(loss_mask)
+
     if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
 
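For reference, the helper added above can be exercised on its own. The sketch below copies `fast_normalize` verbatim from the new `finetune_t0.py` and feeds it the tensor from its docstring; only the demo scaffolding (imports, example input, print) is added here.

```python
# Standalone demo of the per-target normalization; the helper body is copied
# from the new finetune_t0.py, only the example input and print are added.
import torch

def fast_normalize(loss_mask: torch.Tensor):
    _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    counts = torch.gather(dim=0, index=inverse_indices, input=counts)
    return loss_mask / counts

loss_mask = torch.tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1], dtype=torch.float)
print(fast_normalize(loss_mask))
# tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 1.0000,
#         0.0000, 0.0000, 0.3333, 0.3333, 0.3333])
```

After this rescaling each consecutive run of target tokens sums to 1, so `loss_mask.sum()` counts target sequences rather than tokens, which is the quantity the `CrossEntropy` change in `megatron/model/gpt_model.py` below divides by.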
diff --git a/megatron/arguments.py b/megatron/arguments.py
index c18235a78..a51ac6a33 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -664,6 +664,8 @@ def _add_checkpointing_args(parser):
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true', default=None,
                        help='Do not load rng state when loading checkpoint.')
+    group.add_argument('--reset-progress', action='store_true', default=None,
+                       help='Reset iteration to 0 & do not load args.')
     group.add_argument('--finetune', action='store_true',
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
@@ -928,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--norm-target-loss', action='store_true',
+                       help='Normalize the loss per target. Used for multi-task finetuning with packing.')
     group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                        help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
@@ -935,6 +939,7 @@ def __call__(self, parser, args, values, option_string=None):
                             'This is mostly used for prefix_lm training')
     group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
     group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")
 
     return parser
 
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index dacbec7dc..ad63213b4 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     load_dir = getattr(args, load_arg)
 
     if args.deepspeed:
-        load_optimizer_states = False if args.no_load_optim else True
-        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
+        load_optimizer_states = not args.no_load_optim
+        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_module_only=not load_optimizer_states, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
         if loaded_dir is None:
             print_rank_0('WARNING: could not find the metadata file {} '.format(
                 load_dir))
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
-    if args.finetune or release:
+    if args.finetune or release or args.reset_progress:
         iteration = 0
     else:
         try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             # Check arguments.
             assert args.consumed_train_samples == 0
             assert args.consumed_valid_samples == 0
-            if 'args' in state_dict:
+            if 'args' in state_dict and not args.reset_progress:
                 checkpoint_args = state_dict['args']
                 if not args.universal_checkpoint:
                     check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
     return {
         "padded_vocab_size": args.padded_vocab_size,
         "original_vocab_size": tokenizer.vocab_size,
-    }
\ No newline at end of file
+    }
diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py
index 4edf14207..0ef812544 100644
--- a/megatron/data/decoder_packed_mtf_dataset.py
+++ b/megatron/data/decoder_packed_mtf_dataset.py
@@ -358,7 +358,7 @@ def pack_samples(self, items):
             decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
             decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
             decoder_segment_ids[cur_len: cur_len + total_len] = item_num
-            decoder_is_inputs[cur_len: cur_len + input_token_len] = 1  # inputs
+            decoder_is_inputs[cur_len: cur_len + input_token_len] = True  # inputs
             # targets are already 0 at init, no need to update `decoder_is_inputs`
 
             item_num += 1
@@ -399,7 +399,7 @@ def _build_index_mappings(
     shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'
 
     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
         if (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):
 
@@ -437,15 +437,16 @@ def _build_index_mappings(
             print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
                          ' (seconds): {:4f}'.format(time.time() - start_time))
 
-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
-    assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+    if torch.distributed.is_initialized():
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+        assert counts[0].item() == (
+            torch.distributed.get_world_size() //
+            torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
 
     # Load mappings.
     start_time = time.time()
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index a9e3e2604..3b8cbb960 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -15,7 +15,6 @@
 
 """GPT-2 model."""
 
-from functools import partial
 import torch
 
 from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
             else:
                 average_tokens_per_sample = sequence_length
             expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
+        elif args.norm_target_loss:
+            expected_num_of_target_seqs = loss_mask.sum()
+            loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
+            return loss
         else:
             expected_number_of_tokens = loss_mask.sum()
 
@@ -252,7 +255,8 @@ def _to_float16(inputs):
                                           args.num_layers),
                     layer_number=layer_idx,
                     # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
-                    self_attn_mask_type=attn_mask_type))
+                    self_attn_mask_type=attn_mask_type)
+                )
 
         # Undo data format change
         def undo(x):
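A toy calculation (made-up loss values, not actual Megatron tensors) illustrates the `--norm-target-loss` branch added to `CrossEntropy` above: once the mask has been through `fast_normalize`, every packed target contributes equally to the loss, no matter how many tokens it spans.

```python
import torch

losses = torch.tensor([0.0, 0.0, 2.0, 4.0, 0.0, 3.0])     # per-token LM losses (illustrative)
loss_mask = torch.tensor([0.0, 0.0, 0.5, 0.5, 0.0, 1.0])  # already normalized: a 2-token and a 1-token target
expected_num_of_target_seqs = loss_mask.sum()              # == 2.0, one unit of weight per target
loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
print(loss)  # tensor(3.) == mean of the per-target means: ((2 + 4) / 2 + 3) / 2
```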
diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py
index 09304b1dd..931ae9f4e 100644
--- a/megatron/tokenizer/tokenizer.py
+++ b/megatron/tokenizer/tokenizer.py
@@ -380,7 +380,8 @@ def mask(self):
 
     @property
     def bos(self):
-        raise NotImplementedError("Missing <bos>")
+        candidate = self.tokenizer.bos_token_id
+        return self._check_token_candidate(candidate)
 
     @property
     def eos(self):
diff --git a/megatron/utils.py b/megatron/utils.py
index 893f58dd2..fe8cc8260 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
         - segment_ids: torch.IntTensor [batch_size, sequence_length]
     Returns:
         - attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]
+
+    Input example for the mask examples:
+    att_mask_batch = 1
+    seq_length = 7
+    decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+    segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
+    causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
     """
 
     """Causal Inputs Mask:
-    mask = [[[[1, 1, 0, 0, 0, 0, 0],
-              [1, 1, 0, 0, 0, 0, 0],
+    mask = [[[[1, 1, 0, 1, 1, 0, 0],
+              [1, 1, 0, 1, 1, 0, 0],
               [1, 1, 1, 0, 0, 0, 0],
               [1, 1, 1, 1, 1, 0, 0],
               [1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 1, 1, 1, 0],
-              [0, 0, 0, 0, 0, 0, 0]]]]
+              [0, 0, 0, 0, 0, 0, 1]]]]
     """
 
     segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]
@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
               [0, 0, 0, 1, 1, 0, 0],
               [0, 0, 0, 1, 1, 1, 0],
               [0, 0, 0, 0, 0, 0, 0]]]]
+
+    If is_causal=True:
+    mask = [[[[1, 0, 0, 0, 0, 0, 0],
+              [1, 1, 0, 0, 0, 0, 0],
+              [1, 1, 1, 0, 0, 0, 0],
+              [0, 0, 0, 1, 0, 0, 0],
+              [0, 0, 0, 1, 1, 0, 0],
+              [0, 0, 0, 1, 1, 1, 0],
+              [0, 0, 0, 0, 0, 0, 0]]]]
+
     """
 
-    attention_mask = causal_inputs_mask * padding_mask * segment_mask
-    # Convert attention mask to binary:
-    attention_mask = (attention_mask < 0.5)
+    attention_mask = causal_inputs_mask * padding_mask * segment_mask
 
-    return attention_mask
+    # True for places we do not want to attend to
+    return ~attention_mask
 
 def param_size(parameter):
     return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
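To make the updated docstring example concrete, here is a hedged, self-contained approximation of how the three masks combine for the example inputs above. The `inputs_mask` and `padding_mask` constructions are inferred from the docstring matrices, not copied from `get_packed_attention_mask`, so treat this as an illustration rather than the function's exact internals.

```python
import torch

# Example inputs from the docstring above: two packed examples plus one pad position.
att_mask_batch, seq_length = 1, 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]], dtype=torch.bool)
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(
    att_mask_batch, 1, seq_length, seq_length).bool()

# Any two input (prompt) positions may attend to each other; target positions stay causal.
inputs_mask = decoder_is_inputs[:, None, :, None] & decoder_is_inputs[:, None, None, :]
causal_inputs_mask = causal_mask | inputs_mask
# Cross-example attention is removed by segment_mask; padding (segment id 0 here) is blocked entirely.
padding_mask = (segment_ids != 0)[:, None, :, None] & (segment_ids != 0)[:, None, None, :]
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

attention_mask = causal_inputs_mask & padding_mask & segment_mask
# As in the new return statement, True marks positions we do NOT attend to.
print((~attention_mask)[0, 0].int())
```

For these inputs the combined mask reproduces the rows shown in the docstring example, e.g. row 3 comes out as [0, 0, 0, 1, 1, 0, 0].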
diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py
index 30ec1328f..ce56a23dd 100644
--- a/tests/test_dataloaders.py
+++ b/tests/test_dataloaders.py
@@ -7,7 +7,7 @@
 import deepspeed
 import torch
 
-import finetune_t0_non_causal_decoder
+import finetune_t0
 from megatron import global_vars, get_tokenizer, initialize_megatron, get_args
 from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):
         last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])
 
-    def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
+    def test_finetune_t0_get_batch_pipe(self):
         command_args = get_default_args()
 
         command_args["--position-embedding-type"] = "alibi"
 
@@ -263,7 +263,7 @@ def test_decoder_packed_mtf_dataloader(self):
             special_tokens_ids={tokenizer.pad}
         )
 
-        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data)
+        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0.get_batch_pipe(data)
 
         tokens = tokens.cpu()
         position_ids = position_ids.cpu()
diff --git a/tests/test_model.py b/tests/test_model.py
index 390e8664d..0e6a95554 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -20,7 +20,7 @@ from megatron.training import setup_model_and_optimizer
 
 import pretrain_gpt
 import pretrain_prefix_lm
-import finetune_t0_non_causal_decoder
+import finetune_t0
 
 
 def get_default_args(test_file_dir: str):
@@ -456,7 +456,7 @@ def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_i
             vocab_size=args.padded_vocab_size,
             special_tokens_ids={tokenizer.pad}
         )
-        model, _, _ = setup_model_and_optimizer(finetune_t0_non_causal_decoder.model_provider)
+        model, _, _ = setup_model_and_optimizer(finetune_t0.model_provider)
         model = model[0]
         model._config.train_micro_batch_size_per_gpu = args.micro_batch_size
         model.set_train_batch_size(args.micro_batch_size)
diff --git a/tests/test_training.py b/tests/test_training.py
index 2a031cc2d..686f09ed5 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -557,7 +557,7 @@ def test_training_t0(self):
             --deepspeed-activation-checkpointing
         """.split()
 
-        script = [f"{self.src_dir}/finetune_t0_non_causal_decoder.py"]
+        script = [f"{self.src_dir}/finetune_t0.py"]
         launcher = get_launcher(num_gpus)
 
         cmd = launcher + script + args + ds_args
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index 953378680..f086a689b 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -82,13 +82,18 @@ def encode(self, json_line):
         ids = {}
         for key in self.args.json_keys:
             text = data[key]
+            if self.args.prepend_space:
+                text = f" {text}"
             doc_ids = []
             for sentence in Encoder.splitter.tokenize(text):
                 sentence_ids = Encoder.tokenizer.tokenize(sentence)
                 if len(sentence_ids) > 0:
                     doc_ids.append(sentence_ids)
-            if len(doc_ids) > 0 and self.args.append_eod:
-                doc_ids[-1].append(Encoder.tokenizer.eod)
+            if len(doc_ids) > 0:
+                if self.args.append_eod:
+                    doc_ids[-1].append(Encoder.tokenizer.eod)
+                elif self.args.append_bos:
+                    doc_ids[-1].append(Encoder.tokenizer.bos)
             ids[key] = doc_ids
         return ids, len(json_line)
 
@@ -117,6 +122,10 @@ def get_args():
                        help='Path to the BPE merge file (if necessary).')
     group.add_argument('--append-eod', action='store_true',
                        help='Append an <eod> token to the end of a document.')
+    group.add_argument('--append-bos', action='store_true',
+                       help='Append a bos token to the end of a document.')
+    group.add_argument('--prepend-space', action='store_true',
+                       help='Prepends a space to the beginning of a document')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
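Finally, a small illustrative sketch of how the new `--prepend-space`, `--append-eod`, and `--append-bos` options interact in `encode()`. This is not the real `Encoder` class; `tokenize`, `EOD`, and `BOS` are stand-ins for the actual tokenizer and sentence splitter. The space is prepended to the raw text before sentence splitting, and `--append-eod` takes precedence over `--append-bos` for the document's final token.

```python
from typing import List

EOD, BOS = 1, 2                                   # placeholder special-token ids

def tokenize(sentence: str) -> List[int]:
    return [len(w) + 10 for w in sentence.split()]  # fake, deterministic token ids

def encode_document(text: str, prepend_space: bool, append_eod: bool, append_bos: bool) -> List[List[int]]:
    if prepend_space:
        text = f" {text}"                         # mirrors the new prepend_space branch
    doc_ids = [ids for s in text.split(". ") if (ids := tokenize(s))]
    if doc_ids:
        if append_eod:                            # eod wins when both flags are set
            doc_ids[-1].append(EOD)
        elif append_bos:
            doc_ids[-1].append(BOS)
    return doc_ids

print(encode_document("Input prompt. Target answer", prepend_space=True, append_eod=False, append_bos=True))
```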