diff --git a/finetune_t0_non_causal_decoder.py b/finetune_t0_non_causal_decoder.py new file mode 100644 index 000000000..0b649970f --- /dev/null +++ b/finetune_t0_non_causal_decoder.py @@ -0,0 +1,142 @@ +"""Multitask Finetuning T0""" + +from multiprocessing.sharedctypes import Value +import torch + +from megatron import get_args, get_tokenizer, print_rank_0, mpu +from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets +from megatron.enums import PositionEmbeddingType, AttnMaskType +from megatron.model import GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids, get_packed_attention_mask + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage + +try: + from torch.distributed.elastic.multiprocessing.errors import record +except ImportError: + # noop + def record(fn): + return fn + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0("building GPT model ...") + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == "none" else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed: + model = GPTModelPipe( + num_tokentypes=0, + parallel_output=True, + attn_mask_type=AttnMaskType.custom + ) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + else: + raise NotImplementedError("DeepSpeed is required for T0") + + see_memory_usage(f"After Building Model", force=True) + return model + +def get_batch_pipe(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion + + data: + decoder_tokens = [[6, 7, 8, 
3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]] + """ + args = get_args() + tokenizer = get_tokenizer() + + # Broadcast data. + data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64) + data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool) + + # Unpack. + tokens_ = data_b["decoder_token_ids"].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + segment_ids = data_b["decoder_segment_ids"].long()[:, :-1] + decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1] + + # Get the masks and position ids. + causal_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=None, + loss_on_targets_only=False # This is done below + ) + # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding + loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] + loss_on_non_pad_only = (tokens != tokenizer.pad) + loss_mask *= loss_on_targets_only * loss_on_non_pad_only + + attention_mask = get_packed_attention_mask( + # Run non-causal decoder + is_causal=False, + causal_mask=~(causal_mask.bool()), + decoder_is_inputs=decoder_is_inputs.bool(), + segment_ids=segment_ids.long(), + ) + + if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: + raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") + + return (tokens, position_ids, attention_mask), (labels, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + train_ds, valid_ds, test_ds = None, None, None + + tokenizer = get_tokenizer() + + print_rank_0("> building train, validation, and test datasets for T0 ...") + # Option 1 of data loading using 
--data-path + if args.data_path: + # TODO: Not yet compatible with dataset weights (Will break at prefixes, weights = analyze_data_prefix(args.data_path)) + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + seq_length=args.seq_length + 1, + pad_token=tokenizer.pad, + eos_token=tokenizer.eos, + train_valid_test_num_samples=train_val_test_num_samples, + seed=args.seed, + skip_warmup=(not args.mmap_warmup) + ) + else: + raise NotImplementedError("No dataloading argument passed") + + print_rank_0("> finished creating T0 datasets ...") + return train_ds, valid_ds, test_ds + +@record +def main(): + pretrain( + train_valid_test_datasets_provider, + model_provider, + forward_step_func=None, + args_defaults={} + ) + +if __name__ == "__main__": + main() diff --git a/megatron/arguments.py b/megatron/arguments.py index c5e3faefd..302c40cd0 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -557,7 +557,7 @@ def _add_training_args(parser): 'please refer https://github.com/facebookresearch/bitsandbytes.', dest='use_bnb_optimizer') group.add_argument('--dataloader-type', type=str, default=None, - choices=['single', 'cyclic', 'decoder_packed'], + choices=['single', 'cyclic'], help='Single pass vs multiple pass data loader') group.add_argument('--cpu-optimizer', action='store_true', help='Run optimizer on CPU') diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py index e95b1b41b..1e7e86347 100644 --- a/megatron/data/data_samplers.py +++ b/megatron/data/data_samplers.py @@ -15,77 +15,11 @@ """Dataloaders.""" -from functools import partial - -import numpy as np import torch -from megatron import get_args, get_tokenizer +from megatron import get_args from megatron import mpu -from megatron.data.mtf_dataset import MTFDataset - - -def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int): - """ - Greedily packs samples. 
- - Items: - [ - { - 'input_tokens': array([6, 7]), - 'target_tokens': array([8]) - }, - { - 'input_tokens': array([3, 4]), - 'target_tokens': array([5]) - } - ] - - Output: - decoder_target_tokens = [[6, 7, 8, 3, 4, 5, ]]: Concatenation of tokens followed with padding tokens. - decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents. - decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `0` depicts inputs, `1` depicts target. - """ - - decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token) - decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len)) - decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len)) - - batch_num = 0 - # `0` is reserved for padding - item_num = 1 - cur_len = 0 - for token_dict in items: - input_token_len = len(token_dict["input_tokens"]) - target_token_len = len(token_dict["target_tokens"]) - total_len = input_token_len + target_token_len - if cur_len + total_len > max_seq_len: - len_diff = max_seq_len - cur_len - # Padding - if len_diff > 0: - decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token - decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0 - decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0 - batch_num += 1 - assert batch_num < micro_batch_size - item_num = 1 - cur_len = 0 - - decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"] - decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"] - decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num - decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1 # input - decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0 # target - - item_num += 1 - cur_len += total_len - assert cur_len < max_seq_len - - return { - "decoder_target_tokens": decoder_target_tokens, - "decoder_segment_ids": decoder_segment_ids, - 
"decoder_causal_attention": decoder_causal_attention, - } +from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None): @@ -110,16 +44,6 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None): micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) - elif args.dataloader_type == 'decoder_packed': - assert isinstance(dataset, MTFDataset) - batch_sampler = MegatronDecoderPackedText2TextRandomSampler( - sequence_length=args.seq_length + 1, - dataset=dataset, - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) else: raise Exception('{} dataloader type is not supported.'.format( args.dataloader_type)) @@ -127,24 +51,16 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None): if num_workers is None: num_workers = args.num_workers - collate_fn = None - if args.dataloader_type == 'decoder_packed': - assert isinstance(dataset, MTFDataset) - pad_token = get_tokenizer().pad - collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size, - pad_token=pad_token) - # Torch dataloader. 
return torch.utils.data.DataLoader( dataset, batch_sampler=batch_sampler, num_workers=num_workers, generator=torch.Generator().manual_seed(args.seed), - collate_fn=collate_fn, + collate_fn=None, pin_memory=True ) - class MegatronPretrainingSampler: def __init__(self, total_samples, consumed_samples, micro_batch_size, @@ -246,76 +162,3 @@ def __iter__(self): self.consumed_samples += self.micro_batch_times_data_parallel_size yield batch batch = [] - - -class MegatronDecoderPackedText2TextRandomSampler(object): - """ - Converts a two stream dataset with `input_tokens` and `target_tokens` and creates a batch that should be greedily - packed to be passed onto the decoder model. - - To be used with `pack_samples` as collate_fn - """ - - def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size, - data_parallel_rank, data_parallel_size): - # Keep a copy of input params for later use. - self.dataset = dataset - self.sequence_length = sequence_length - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size - - # Sanity checks. 
- assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) - - def __len__(self): - return self.total_samples - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - - # data sharding and random sampling - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ - * self.micro_batch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - - batch = [] - batch_count = 0 - token_lens = 0 - # Last batch if not complete will be dropped. 
- for idx in idx_range: - tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens']) - if token_lens + tok_len > self.sequence_length: - batch_count += 1 - token_lens = 0 - - if batch_count == self.micro_batch_size: - self.consumed_samples += self.micro_batch_times_data_parallel_size - yield batch - batch_count = 0 - batch = [] - else: - token_lens += tok_len - batch.append(idx) diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py new file mode 100644 index 000000000..f504d7f91 --- /dev/null +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -0,0 +1,529 @@ +import os +import time + +import numpy as np +import torch + +from megatron import print_rank_0, mpu +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \ + get_train_valid_test_split_ +from megatron.data.mtf_dataset import MTFDataset +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup +): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets( + data_prefix=data_prefix[0], + data_impl=data_impl, + splits_string=splits_string, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=train_valid_test_num_samples, + seed=seed, + skip_warmup=skip_warmup + ) + # Blending dataset. + else: + + output = get_datasets_weights_and_num_samples(data_prefix=data_prefix, train_valid_test_num_samples=train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. 
+ train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + data_prefix=prefixes[i], + data_impl=data_impl, + splits_string=splits_string, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], + seed=seed, + skip_warmup=skip_warmup + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + all_train_datasets = BlendableDataset(train_datasets, weights) \ + if train_datasets else None + all_valid_datasets = BlendableDataset(valid_datasets, weights) \ + if valid_datasets else None + all_test_datasets = BlendableDataset(test_datasets, weights) \ + if test_datasets else None + + return all_train_datasets, all_valid_datasets, all_test_datasets + + +def build_dataset_group( + dataset_group_name, + paths, + weights, + splits, + data_impl, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup, + train_valid_test +): + ''' + Build a single dataset group corresponding to Option 2 of data loading see arguments.py + a dataset group is passed in the following form + GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 + or alternatively + GIVEN_NAME PATH1 # for a single dataset to be used fully + ''' + + assert train_valid_test in ["train","valid","test"] + + # Single dataset. + if len(paths) == 1: + dataset = _build_single_datasets( + data_prefix=paths[0], + range_string=splits[0], + data_impl=data_impl, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=train_valid_test_num_samples, + seed=seed, + skip_warmup=skip_warmup, + dataset_group_name=dataset_group_name, + train_valid_test=train_valid_test + ) + return dataset + # Blending dataset. 
+ else: + + data_prefix = [] + # data_prefix is of the shape: + # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] + for w,p in zip(weights, paths): + data_prefix += [w,p] + + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + datasets = [] + for i in range(len(prefixes)): + ds = _build_single_datasets( + data_prefix=prefixes[i], + range_string=splits[i], + data_impl=data_impl, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], + seed=seed, + skip_warmup=skip_warmup, + dataset_group_name=dataset_group_name, + train_valid_test=train_valid_test + ) + + datasets.append(ds) + all_datasets = BlendableDataset(datasets, weights) + + return all_datasets + +def _build_single_datasets( + data_prefix, + range_string, + data_impl, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup, + dataset_group_name, + train_valid_test +): + """Build a single dataset""" + + assert train_valid_test in ["train","valid","test"] + index = ["train","valid","test"].index(train_valid_test) + + # Target indexed dataset. + target_indexed_dataset = get_indexed_dataset( + data_prefix=data_prefix, + is_input=False, + data_impl=data_impl, + skip_warmup=skip_warmup + ) + + total_num_of_documents = target_indexed_dataset.sizes.shape[0] + # this corresponds to option2 for data loading on the form + # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 + # splits here is an array of size 2 [start_index, end_index] + splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) + + # Print stats about the splits. 
+ print_rank_0(' > dataset split:') + + print_rank_0(' {}:'.format(dataset_group_name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[0], splits[1], + splits[1] - splits[0])) + + def build_dataset(name): + dataset = None + if splits[1] > splits[0]: + documents = np.arange(start=splits[0], stop=splits[1], + step=1, dtype=np.int32) + dataset = DecoderPackedMTFDataset( + name=name, + data_prefix=data_prefix, + data_impl=data_impl, + skip_warmup=skip_warmup, + documents=documents, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + num_samples=train_valid_test_num_samples[index], + seed=seed + ) + return dataset + + dataset = build_dataset(dataset_group_name) + + return dataset + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup +): + """Build train, valid, and test datasets.""" + + # Target indexed dataset. + target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup) + + total_num_of_documents = target_indexed_dataset.sizes.shape[0] + # splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + # Print stats about the splits. 
+ print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = DecoderPackedMTFDataset( + name=name, + data_prefix=data_prefix, + data_impl=data_impl, + skip_warmup=skip_warmup, + documents=documents, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + num_samples=train_valid_test_num_samples[index], + seed=seed + ) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +class DecoderPackedMTFDataset(torch.utils.data.Dataset): + + def __init__( + self, + name, + data_prefix, + data_impl, + skip_warmup, + documents, + num_samples, + seq_length: int, + pad_token: int, + eos_token: int, + seed, + ): + self.mtf_dataset = MTFDataset(name=name, data_prefix=data_prefix, data_impl=data_impl, skip_warmup=skip_warmup, documents=documents) + + self.pad_token = pad_token + self.seq_length = seq_length + + self.sample_index, self.shuffle_index = _build_index_mappings(name=name, data_prefix=data_prefix, nb_documents=len(documents), mtf_dataset=self.mtf_dataset, num_samples=num_samples, seq_length=seq_length, seed=seed) + + def __len__(self): + return len(self.sample_index) + + def __getitem__(self, idx): + # Get the shuffled index. 
+ start, end = self.sample_index[idx] + mtf_samples_indices = self.shuffle_index[start: end] + # TODO @thomasw21 build a dataset that generates an entire batch instead of a row (allows for more optimization) + items = [self.mtf_dataset[sample_id] for sample_id in mtf_samples_indices] + + return self.pack_samples(items) + + def pack_samples(self, items): + """ + Greedily packs samples. + + Items: + [ + { + 'input_tokens': array([6, 7]), + 'target_tokens': array([8]) + }, + { + 'input_tokens': array([3, 4]), + 'target_tokens': array([5]) + } + ] + + Output: + decoder_tokens = [[6, 7, 8, 3, 4, 5, ]]: Concatenation of tokens followed with padding tokens. + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents. + decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts target. + """ + + decoder_tokens = np.full((self.seq_length,), self.pad_token, dtype=np.int64) + decoder_segment_ids = np.zeros((self.seq_length,), dtype=np.int64) + decoder_is_inputs = np.full((self.seq_length,), False, dtype=bool) + + # `0` is reserved for padding + item_num = 1 + cur_len = 0 + + assert len(items) > 0 + + for token_dict in items: + input_token_len = len(token_dict["input_tokens"]) + target_token_len = len(token_dict["target_tokens"]) + + total_len = input_token_len + target_token_len + + if cur_len + total_len > self.seq_length: + # This should not happen at the indexing should only allow the correct number of items + raise ValueError(f"""Items to be packed do not fit inside a single sample. 
+ current length: {cur_len} + input tokens length: {input_token_len} + target token length: {target_token_len} + expected sequence length: {self.seq_length} + """) + + decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"] + decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"] + decoder_segment_ids[cur_len: cur_len + total_len] = item_num + decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs + # targets are already 0 at init, no need to update `decoder_is_inputs` + + item_num += 1 + cur_len += total_len + assert cur_len <= self.seq_length + + return { + "decoder_token_ids": decoder_tokens, + "decoder_segment_ids": decoder_segment_ids, + "decoder_is_inputs": decoder_is_inputs, + } + + +def _build_index_mappings( + name, + data_prefix, + nb_documents, + mtf_dataset, + num_samples: int, + seq_length: int, + seed, +): + """ + - `shuffle_index` is [num_epoch * len(self.mtf)] + - `sample_index` is [num_sample, 2] (storing the start and end of the sample). We query the sample via `self.shuffle_index[start:end]` + + TODO @thomas21 Instead of loading individually samples, we save the packing one and for all + """ + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}s'.format(seed) + sample_idx_filename = _filename + '_decoder_packed_batch_idx.npy' + shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy' + + # Build the indexed mapping if not exist. 
+ if torch.distributed.get_rank() == 0: + if (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): + + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # iteratively add the entire dataset for every epoch and see if it's enough given current packing strategy + start_time = time.time() + row_offset = 0 + old_sample_start = 0 + epoch = 0 + shuffle_idx = [] + sample_idx = [] + while len(sample_idx) <= num_samples: + new_document_ids = _build_shuffle_idx(nb_documents=nb_documents, np_rng=np_rng) + # Generate a shuffling of the entire dataset + shuffle_idx.append(new_document_ids) + # Packs them into a single sample + new_samples, row_offset, old_sample_start = _build_sample_idx( + mtf_dataset=mtf_dataset, + document_ids=new_document_ids, + seq_length=seq_length, + row_offset=row_offset, + old_sample_start=old_sample_start, + epoch=epoch + ) + sample_idx.extend(new_samples) + epoch += 1 + + shuffle_idx = np.concatenate(shuffle_idx, axis=0) + sample_idx = np.stack(sample_idx, axis=0) + + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load mappings. 
+ start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + + return sample_idx, shuffle_idx + +def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sample_start, epoch): + """Build start and off index of each `full` batch, return that list of batch + start of the unfinished batch""" + row_length = row_offset + + full_samples = [] + current_sample_start = old_sample_start + epoch_offset = epoch * len(document_ids) + + assert epoch_offset >= current_sample_start + for current_sample_end, document_id in enumerate(document_ids): + current_sample_end = epoch_offset + current_sample_end + sample_sizes = mtf_dataset.size(document_id) + + # TODO @thomasw21 figure out if we add tokens + tok_len = sample_sizes["input_tokens"] + sample_sizes["target_tokens"] + + row_length = row_length + tok_len + if row_length > seq_length: + # current sample can't be added and requires to be added in the next one + if current_sample_end > current_sample_start: + full_samples.append(np.asarray([current_sample_start, current_sample_end])) + current_sample_start = current_sample_end + row_length = tok_len + + if tok_len > seq_length: + # TODO @thomasw21 handle the case where a single sample cannot fit inside a row. 
We can + # - silently skip that value [currently implemented] + # - truncate to `seq_length`, and keep the right part + current_sample_start = current_sample_end + 1 # skipping + row_length = 0 + continue + + return full_samples, row_length, current_sample_start + +def _build_shuffle_idx(nb_documents: int, np_rng): + """Build the range [0, dataset_size) and shuffle.""" + dtype_ = np.int64 + + result = np.arange(start=0, stop=nb_documents, step=1, dtype=dtype_) + + # in-place shuffling + np_rng.shuffle(result) + + return result + + +def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool): + if is_input: + field = "inputs" + else: + field = "targets" + + return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup) + + +def get_indexed_dataset_(path, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + start_time = time.time() + indexed_dataset = make_indexed_dataset(path, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index d92a0535b..d0d312544 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -573,6 +573,9 @@ def get(self, idx, offset=0, length=None): def sizes(self): return self._index.sizes + def size(self, index): + return self._index.sizes[index] + @property def doc_idx(self): return self._index.doc_idx diff --git a/megatron/data/mlm_dataset.py b/megatron/data/mlm_dataset.py index 4ac4624b1..dcc66d2c0 100644 --- a/megatron/data/mlm_dataset.py +++ b/megatron/data/mlm_dataset.py @@ -314,7 +314,7 @@ def __init__( indexed_dataset=self.indexed_dataset, num_samples=num_samples, # -1 because GPTDataset will return `seq_length + 1` 
sequences. - seq_length=number_of_raw_tokens - 1, + seq_length=self.number_of_raw_tokens - 1, seed=seed ) @@ -327,12 +327,8 @@ def __init__( assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more" args = get_args() - if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None: - # T5 style - assert self.inputs_length == args.encoder_seq_length - assert self.targets_length == args.decoder_seq_length + 1 - else: - assert self.inputs_length + self.targets_length == args.seq_length + # TODO @thomasw21 check once we merge t5 + assert self.inputs_length + self.targets_length == args.seq_length + 1 def __len__(self): return len(self._gpt_dataset) diff --git a/megatron/data/mtf_dataset.py b/megatron/data/mtf_dataset.py index 044a4ab3a..57f3a779b 100644 --- a/megatron/data/mtf_dataset.py +++ b/megatron/data/mtf_dataset.py @@ -15,242 +15,14 @@ """Multitask Finetune style dataset.""" -import os import time import numpy as np import torch -from megatron import mpu, print_rank_0 -from megatron.data.blendable_dataset import BlendableDataset -from megatron.data.dataset_utils import get_datasets_weights_and_num_samples -from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_ +from megatron import print_rank_0 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset - -def build_train_valid_test_datasets( - data_prefix, - data_impl, - splits_string, - train_valid_test_num_samples, - seed, - skip_warmup -): - """Build train, valid, and test datasets.""" - - # Single dataset. - if len(data_prefix) == 1: - all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets( - data_prefix=data_prefix[0], - data_impl=data_impl, - splits_string=splits_string, - train_valid_test_num_samples=train_valid_test_num_samples, - seed=seed, - skip_warmup=skip_warmup - ) - # Blending dataset. 
- else: - - output = get_datasets_weights_and_num_samples(data_prefix=data_prefix, train_valid_test_num_samples=train_valid_test_num_samples) - prefixes, weights, datasets_train_valid_test_num_samples = output - - # Build individual datasets. - train_datasets = [] - valid_datasets = [] - test_datasets = [] - for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - data_prefix=prefixes[i], - data_impl=data_impl, - splits_string=splits_string, - train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], - seed=seed, - skip_warmup=skip_warmup - ) - if train_ds: - train_datasets.append(train_ds) - if valid_ds: - valid_datasets.append(valid_ds) - if test_ds: - test_datasets.append(test_ds) - - all_train_datasets = BlendableDataset(train_datasets, weights) \ - if train_datasets else None - all_valid_datasets = BlendableDataset(valid_datasets, weights) \ - if valid_datasets else None - all_test_datasets = BlendableDataset(test_datasets, weights) \ - if test_datasets else None - - return all_train_datasets, all_valid_datasets, all_test_datasets - - -def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, - train_valid_test_num_samples, - seed, skip_warmup, train_valid_test): - ''' - Build a single dataset group corresponding to Option 2 of data loading see arguments.py - a dataset group is passed in the following form - GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 - or alternatively - GIVEN_NAME PATH1 # for a single dataset to be used fully - ''' - - assert train_valid_test in ["train","valid","test"] - - # Single dataset. 
- if len(paths) == 1: - dataset = _build_single_datasets( - data_prefix=paths[0], - range_string=splits[0], - data_impl=data_impl, - train_valid_test_num_samples=train_valid_test_num_samples, - seed=seed, - skip_warmup=skip_warmup, - dataset_group_name=dataset_group_name, - train_valid_test=train_valid_test - ) - return dataset - # Blending dataset. - else: - - data_prefix = [] - # data_prefix is of the shape: - # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] - for w,p in zip(weights, paths): - data_prefix += [w,p] - - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) - prefixes, weights, datasets_train_valid_test_num_samples = output - - # Build individual datasets. - datasets = [] - for i in range(len(prefixes)): - ds = _build_single_datasets( - data_prefix=prefixes[i], - range_string=splits[i], - data_impl=data_impl, - train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], - seed=seed, - skip_warmup=skip_warmup, - dataset_group_name=dataset_group_name, - train_valid_test=train_valid_test - ) - - datasets.append(ds) - all_datasets = BlendableDataset(datasets, weights) - - return all_datasets - -def _build_single_datasets( - data_prefix, - range_string, - data_impl, - train_valid_test_num_samples, - seed, - skip_warmup, - dataset_group_name, - train_valid_test -): - """Build a single dataset""" - - assert train_valid_test in ["train","valid","test"] - index = ["train","valid","test"].index(train_valid_test) - - # Target indexed dataset. 
- target_indexed_dataset = get_indexed_dataset( - data_prefix=data_prefix, - is_input=False, - data_impl=data_impl, - skip_warmup=skip_warmup - ) - - total_num_of_documents = target_indexed_dataset.sizes.shape[0] - # this corresponds to option2 for data loading on the form - # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 - # splits here is an array of size 2 [start_index, end_index] - splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) - - # Print stats about the splits. - print_rank_0(' > dataset split:') - - print_rank_0(' {}:'.format(dataset_group_name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[0], splits[1], - splits[1] - splits[0])) - - def build_dataset(name): - dataset = None - if splits[1] > splits[0]: - documents = np.arange(start=splits[0], stop=splits[1], - step=1, dtype=np.int32) - dataset = MTFDataset( - name=name, - data_prefix=data_prefix, - data_impl=data_impl, - skip_warmup=skip_warmup, - documents=documents, - num_samples=train_valid_test_num_samples[index], - seed=seed - ) - return dataset - - dataset = build_dataset(dataset_group_name) - - return dataset - - -def _build_train_valid_test_datasets( - data_prefix, - data_impl, - splits_string, - train_valid_test_num_samples, - seed, - skip_warmup -): - """Build train, valid, and test datasets.""" - - # Target indexed dataset. - target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup) - - total_num_of_documents = target_indexed_dataset.sizes.shape[0] - # splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index] - splits = get_train_valid_test_split_(splits_string, total_num_of_documents) - # Print stats about the splits. 
- print_rank_0(' > dataset split:') - - def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) - - def build_dataset(index, name): - dataset = None - if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) - dataset = MTFDataset( - name=name, - data_prefix=data_prefix, - data_impl=data_impl, - skip_warmup=skip_warmup, - documents=documents, - num_samples=train_valid_test_num_samples[index], - seed=seed - ) - return dataset - - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') - - return (train_dataset, valid_dataset, test_dataset) - - class MTFDataset(torch.utils.data.Dataset): def __init__( @@ -260,14 +32,9 @@ def __init__( data_impl, skip_warmup, documents, - num_samples, - seed, - impossible_token=-100, ): - # Params to store. self.name = name - self.impossible_token = impossible_token # Dataset. self.input_indexed_dataset = get_indexed_dataset(data_prefix, is_input=True, data_impl=data_impl, skip_warmup=skip_warmup) @@ -279,130 +46,26 @@ def __init__( assert np.max(documents) < self.target_indexed_dataset.sizes.shape[0] assert self.input_indexed_dataset.sizes.shape[0] == self.target_indexed_dataset.sizes.shape[0] - # Build index mappings. 
- self.doc_idx, self.shuffle_idx = _build_index_mappings( - name=self.name, - data_prefix=data_prefix, - documents=documents, - num_samples=num_samples, - seed=seed - ) - def __len__(self): - # -1 is due to data structure used to retieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - # return self.doc_idx.shape[0] - 1 - return len(self.doc_idx) + return len(self.input_indexed_dataset) def __getitem__(self, idx): - # Get the shuffled index. - idx = self.shuffle_idx[idx] - input_tokens = self.input_indexed_dataset.get(self.doc_idx[idx]) - target_tokens = self.target_indexed_dataset.get(self.doc_idx[idx]) + input_tokens = self.input_indexed_dataset.get(idx) + target_tokens = self.target_indexed_dataset.get(idx) + + assert len(input_tokens) > 0 + assert len(target_tokens) > 0 return { - 'input_tokens': np.array(input_tokens, dtype=np.int64), - 'target_tokens': np.array(target_tokens, dtype=np.int64), + 'input_tokens': input_tokens, + 'target_tokens': target_tokens, } - -def _build_index_mappings( - name, - data_prefix, - documents, - num_samples, - seed, -): - """Build doc-idx, sample-idx, and shuffle-idx. - doc-idx: is an array (ordered) of documents to be used in training. - shuffle-idx: maps an index into a random index into sample-idx. - """ - # rng state - np_rng = np.random.RandomState(seed=seed) - - # Filename of the index mappings. - _filename = data_prefix - _filename += '_{}_indexmap'.format(name) - _filename += '_{}ns'.format(num_samples) - _filename += '_{}s'.format(seed) - doc_idx_filename = _filename + '_doc_idx.npy' - shuffle_idx_filename = _filename + '_shuffle_idx.npy' - - # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0: - if (not os.path.isfile(doc_idx_filename)) or \ - (not os.path.isfile(shuffle_idx_filename)): - - print_rank_0(' > WARNING: could not find index map files, building ' - 'the indices on rank 0 ...') - - # doc-idx. 
- start_time = time.time() - doc_idx = _build_doc_idx(documents, np_rng) - np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - shuffle_idx = _build_shuffle_idx(doc_idx.shape[0] - 1 , doc_idx.shape[0] - 1, np_rng) - np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) - - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) - - # Load mappings. - start_time = time.time() - print_rank_0(' > loading doc-idx mapping from {}'.format( - doc_idx_filename)) - doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading shuffle-idx mapping from {}'.format( - shuffle_idx_filename)) - shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - - return doc_idx, shuffle_idx - - -def _build_doc_idx(documents, np_rng): - """Build an array with length = number-of-epochs * number-of-dcuments. 
- Each index is mapped to a corresponding document.""" - num_epochs = 1 - doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] - doc_idx[:] = documents - doc_idx = doc_idx.reshape(-1) - doc_idx = doc_idx.astype(np.int32) - np_rng.shuffle(doc_idx) - return doc_idx - - -def _build_shuffle_idx(num_samples, total_size, np_rng): - """Build the range [0, size) and shuffle.""" - print(' > building shuffle index with split [0, {}) and [{}, {}) ' - '...'.format(num_samples, num_samples, total_size), flush=True) - - dtype_ = np.uint32 - if total_size >= (np.iinfo(np.uint32).max - 1): - dtype_ = np.int64 - - shuffle_idx_first = np.arange(start=0, stop=num_samples, - step=1, dtype=dtype_) - np_rng.shuffle(shuffle_idx_first) - if num_samples == total_size: - return shuffle_idx_first - - shuffle_idx_last = np.arange(start=num_samples, stop=total_size, - step=1, dtype=dtype_) - np_rng.shuffle(shuffle_idx_last) - - return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + def size(self, index): + return { + 'input_tokens': self.input_indexed_dataset.size(index), + 'target_tokens': self.target_indexed_dataset.size(index), + } def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool): if is_input: diff --git a/megatron/enums.py b/megatron/enums.py index d9050462a..90d00a071 100644 --- a/megatron/enums.py +++ b/megatron/enums.py @@ -25,8 +25,9 @@ class AttnType(enum.Enum): class AttnMaskType(enum.Enum): padding = 1 - causal = 2 + causal = 2 # Overrides `attention_mask` to be a lower triangular matrix prefix = 3 + custom = 4 # Forces one to pass an `attention_mask` that's 1 if we need to mask. 
Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length] class PositionEmbeddingType(enum.Enum): rotary = 1 diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h index e57fd04c6..013dd8366 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -47,6 +47,22 @@ __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t * template <> __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; @@ -94,16 +110,16 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { /* * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling - */ + */ template __global__ void scaled_softmax_warp_forward( - output_t *dst, + output_t *dst, const input_t *src, - const acc_t scale, - int micro_batch_size, + const acc_t scale, + int micro_batch_size, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; @@ -112,7 +128,7 @@ __global__ void scaled_softmax_warp_forward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; // micro_batch_size might not be a multiple of WARP_BATCH. Check how @@ -192,10 +208,10 @@ __global__ void scaled_softmax_warp_forward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; - } + } } } } @@ -205,18 +221,18 @@ __global__ void scaled_softmax_warp_forward( * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling * 2) Explicit masking - */ + */ template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, + output_t *dst, const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, int element_count, - int pad_batches) + int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -225,7 +241,7 @@ __global__ void scaled_masked_softmax_warp_forward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; int pad_first_batch = 0; if (pad_batches != 1) { // bert style @@ -269,7 +285,7 @@ __global__ void scaled_masked_softmax_warp_forward( if (temp_mask[element] != 1) { elements[i][it + element] = (acc_t)temp_data[element] * scale; } else { - elements[i][it + element] = -10000.0; + elements[i][it + element] = -std::numeric_limits::infinity(); } } } else { @@ -298,7 +314,11 @@ __global__ void scaled_masked_softmax_warp_forward( for (int i = 0; i < WARP_BATCH; ++i) { #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); + if (elements[i][it] <= -std::numeric_limits::infinity()) { + elements[i][it] = 0.0f; + } else { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + } sum[i] += elements[i][it]; } } @@ -314,28 +334,32 @@ __global__ void scaled_masked_softmax_warp_forward( for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; + if (sum[i] == 0.0f) { + copy_zero_vector(dst + i * element_count + it * WARP_SIZE); + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; - } + } } } } template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, + output_t *gradInput, + 
input_t *grad, const input_t *output, - acc_t scale, - int micro_batch_size, + acc_t scale, + int micro_batch_size, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; @@ -344,9 +368,9 @@ __global__ void scaled_masked_softmax_warp_backward( constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. 
int local_batches = micro_batch_size - first_batch; @@ -386,10 +410,10 @@ __global__ void scaled_masked_softmax_warp_backward( for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; } - } + } } } - + acc_t sum[WARP_BATCH]; #pragma unroll for (int i = 0; i < WARP_BATCH; ++i) { @@ -417,7 +441,7 @@ __global__ void scaled_masked_softmax_warp_backward( out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } } } } @@ -439,11 +463,11 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att template void dispatch_scaled_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int query_seq_len, - int key_seq_len, + output_t *dst, + const input_t *src, + const input_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads) { @@ -531,12 +555,12 @@ void dispatch_scaled_softmax_forward( template void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, + output_t *dst, + const input_t *src, const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, + const input_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads, int pad_batches) @@ -625,12 +649,12 @@ void dispatch_scaled_masked_softmax_forward( template void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, int batches, int attn_heads) { diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index a4a788586..e2983a75d 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py 
@@ -124,7 +124,6 @@ class FusedScaleMaskSoftmax(nn.Module): softmax_in_fp32: if true, softmax in performed at fp32 precision. scale: scaling factor used in input tensor scaling. """ - custom_kernel_friendly_attn_mask_type = [AttnMaskType.causal, AttnMaskType.padding] def __init__( self, @@ -189,6 +188,7 @@ def forward_fused_softmax(self, input, mask): if self.attn_mask_type == AttnMaskType.causal: assert sq == sk, "causal mask is only for self attention" + assert mask is None, "Mask is silently ignored due to the use of a custom kernel" # input is 3D tensor (attn_batches, sq, sk) input = input.view(-1, sq, sk) @@ -207,7 +207,14 @@ def forward_torch_softmax(self, input, mask): if self.scale is not None: input = input * self.scale + + if self.attn_mask_type == AttnMaskType.causal: + assert mask is None + mask = torch.ones_like(input, dtype=torch.bool) + mask = torch.triu(mask, diagonal=1, out=mask) + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 31d33a91b..dce77d23d 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -202,7 +202,7 @@ def __init__( self, num_tokentypes=0, parallel_output=True, - prefix_lm=False + attn_mask_type: AttnMaskType = AttnMaskType.causal ): args = get_args() self.parallel_output = parallel_output @@ -252,7 +252,7 @@ def _to_float16(inputs): args.num_layers), layer_number=layer_idx, # TODO: Change naming of class from GPT to something that encapsulate prefix lm. 
- self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal)) + self_attn_mask_type=attn_mask_type)) if not hasattr(args, 'attn_mask'): @@ -314,7 +314,7 @@ def _logits_helper(embedding, lm_output): partition_method = 'type:transformer' super().__init__(layers=self.specs, - loss_fn=get_cross_entropy(is_prefix=prefix_lm), + loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, activation_checkpoint_interval=interval, partition_method=partition_method) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 48401a9f1..03e6faaec 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -333,6 +333,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None, if get_key_value: with torch.no_grad(): + # TODO @thomasw21 Handle case where `attention_mask` is None if layer_past is not None: attention_mask = attention_mask[ ..., @@ -633,17 +634,11 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer): 2) forward(input, **kwargs) -> output When the mask is static over all samples, it is advantageous to cache the mask and avoid communicating it. - - If no mask is provided, the module will query `self._args.attn_mask` - for the mask and only return `super().forward(...)` """ def forward(self, inputs, **kwargs): assert torch.is_tensor(inputs) or isinstance(inputs, tuple) if torch.is_tensor(inputs) or len(inputs) == 1: - # No attention mask forwarded, search for args.attn_mask - if not hasattr(self, '_args'): - self._args = get_args() - hidden_states, attention_mask = inputs, self._args.attn_mask + hidden_states, attention_mask = inputs, None return super().forward(hidden_states, attention_mask, **kwargs) elif len(inputs) == 2: # Attention mask is an activation. 
diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 8c3908a93..18f008de8 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -48,9 +48,9 @@ def attention_mask_func(attention_scores, attention_mask): if actual_seqlen != attention_mask_.size()[2]: # attention_mask has size [1, 1, seqlen, seqlen] attention_mask_ = attention_mask_[:, :, :actual_seqlen, :actual_seqlen].contiguous() - attention_scores.masked_fill_(attention_mask_, -10000.0) + attention_scores.masked_fill_(attention_mask_, torch.finfo(attention_scores.dtype).min) else: - attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min) return attention_scores diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index fcc3ed20d..09304b1dd 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -355,33 +355,46 @@ def detokenize(self, token_ids): @property def eod(self): - return self.tokenizer.eos_token_id + # TODO @thomasw21 might conflict with + return self.eos @property def cls(self): - return self.tokenizer.cls_token_id + candidate = self.tokenizer.cls_token_id + return self._check_token_candidate(candidate) @property def sep(self): - return self.tokenizer.sep_token_id + candidate = self.tokenizer.sep_token_id + return self._check_token_candidate(candidate) @property def pad(self): - return self.tokenizer.pad_token_id + candidate = self.tokenizer.pad_token_id + return self._check_token_candidate(candidate) @property def mask(self): - return self.tokenizer.mask_token_id + candidate = self.tokenizer.mask_token_id + return self._check_token_candidate(candidate) @property - def additional_special_tokens_ids(self): - """ All the additional special tokens you may want to use (list of strings).""" - return self.tokenizer.additional_special_tokens_ids + def bos(self): + raise NotImplementedError("Missing ") @property - def bos_token_id(self): - 
raise NotImplementedError("Missing ") + def eos(self): + # TODO @thomasw21 might conflict with the notion of + candidate = self.tokenizer.eos_token_id + return self._check_token_candidate(candidate) @property - def eos_token_id(self): - raise NotImplementedError("Missing ") + def additional_special_tokens_ids(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self.tokenizer.additional_special_tokens_ids + + @staticmethod + def _check_token_candidate(candidate): + if candidate is None: + raise AttributeError("Token doesn't exist") + return candidate diff --git a/megatron/training.py b/megatron/training.py index bbf6623e3..2d45ca808 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1181,20 +1181,20 @@ def build_train_valid_test_data_iterators( assert dl_type in ['single', 'cyclic'] if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ + train_data_iterator = iter(train_dataloader) if dl_type in ['single'] \ else iter(cyclic_iter(train_dataloader)) else: train_data_iterator = None if valid_dataloaders is not None: - valid_data_iterators = [iter(vdl) if dl_type == 'single' \ + valid_data_iterators = [iter(vdl) if dl_type in ['single'] \ else iter(cyclic_iter(valid_dataloaders)) for vdl in valid_dataloaders] else: valid_data_iterators = [None] * num_valid_ds if test_dataloaders is not None: - test_data_iterators = [iter(tdl) if dl_type == 'single' \ + test_data_iterators = [iter(tdl) if dl_type in ['single'] \ else iter(cyclic_iter(test_dataloaders)) for tdl in test_dataloaders] else: diff --git a/megatron/utils.py b/megatron/utils.py index 98d2f611c..6f3a0fa41 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -250,6 +250,75 @@ def get_ltor_masks_and_position_ids( return attention_mask, loss_mask, position_ids +def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decoder_is_inputs: torch.Tensor, segment_ids: torch.Tensor): + """ + 
Inspired by https://github.com/google-research/t5x/blob/7193407f98a8b18100b71a04ff777238be1682ca/t5x/examples/decoder_only/layers.py#L978 + + Arguments: + - is_causal: determines if the masking should be causal in the `inputs` part + - causal_mask: torch.BoolTensor [batch_size, sequence_length, sequence_length] + - decoder_is_inputs: torch.BoolTensor [batch_size, sequence_length] + - segment_ids: torch.IntTensor [batch_size, sequence_length] + Returns: + - attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length] + """ + + """Causal Inputs Mask: + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1]]]] + """ + assert causal_mask.dtype == torch.bool + assert segment_ids.dtype == torch.long + if is_causal: + causal_inputs_mask = causal_mask + else: + assert decoder_is_inputs.dtype == torch.bool + inputs_mask = decoder_is_inputs[:, None, :, None] * decoder_is_inputs[:, None, None, :] + causal_inputs_mask = causal_mask + inputs_mask + + """Padding Mask: + mask = [[[[1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + """ + padding_mask = (segment_ids != 0)[:, None, :, None] * (segment_ids != 0)[:, None, None, :] + + """Segment Mask: + mask = [[[[1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + """ + segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :] + + """Final Mask: + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + """ + attention_mask = causal_inputs_mask * padding_mask * segment_mask + + # Convert 
attention mask to binary: + attention_mask = (attention_mask < 0.5) + + return attention_mask + def param_size(parameter): return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement() diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 04f1b3b57..fdd4d28be 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -23,6 +23,7 @@ from megatron import get_tokenizer from megatron import mpu from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group +from megatron.enums import AttnMaskType from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices @@ -53,26 +54,13 @@ def model_provider(pre_process=True, post_process=True): enabled=args.zero_stage == 3, mpu=mpu): if args.deepspeed: - # Precompute the attention mask and store it in args. This avoids having to - # pipeline it as an activation during training. The mask is constant, and thus - # we can reuse it. 
- attention_mask = torch.tril(torch.ones( - (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view( - 1, 1, args.seq_length, args.seq_length) - - # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) - if args.fp16: - attention_mask = attention_mask.half() - elif args.bf16: - attention_mask = attention_mask.bfloat16() - - # must be bool or the training crashes expecting bool, but getting Half - args.attn_mask = attention_mask.to(torch.bool) + # Hack @thomasw21 to get fused_softmax.forward_torch_softmax working + args.attn_mask = None model = GPTModelPipe( num_tokentypes=0, - parallel_output=True + parallel_output=True, + attn_mask_type=AttnMaskType.causal ) # This is a hack to give us a reference to get_batch_pipe from within training.py # We need to call model.set_batch_fn after deepspeed.initialize diff --git a/pretrain_prefix_lm.py b/pretrain_prefix_lm.py index 391186e75..c531db863 100644 --- a/pretrain_prefix_lm.py +++ b/pretrain_prefix_lm.py @@ -23,6 +23,7 @@ from megatron import get_tokenizer from megatron import mpu from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group +from megatron.enums import AttnMaskType from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_ @@ -49,7 +50,7 @@ def model_provider(pre_process=True, post_process=True): model = GPTModelPipe( num_tokentypes=0, parallel_output=True, - prefix_lm=True + attn_mask_type=AttnMaskType.prefix ) # This is a hack to give us a reference to get_batch_pipe from within training.py # We need to call model.set_batch_fn after deepspeed.initialize diff --git a/tests/data/gpt2/README.md b/tests/data/gpt2/README.md new file mode 100644 index 000000000..ad8eed839 --- /dev/null +++ b/tests/data/gpt2/README.md @@ -0,0 +1,3 @@ +Dataset used for testing. 
+ +`ag_news_prompt*`: manually generated from dataset available at https://huggingface.co/datasets/TimeRobber/ag_news_classify_question_first_100 \ No newline at end of file diff --git a/tests/data/gpt2/ag_news_prompt_inputs_document.bin b/tests/data/gpt2/ag_news_prompt_inputs_document.bin index b786d6e41..4a7f085de 100644 Binary files a/tests/data/gpt2/ag_news_prompt_inputs_document.bin and b/tests/data/gpt2/ag_news_prompt_inputs_document.bin differ diff --git a/tests/data/gpt2/ag_news_prompt_inputs_document.idx b/tests/data/gpt2/ag_news_prompt_inputs_document.idx index 0d55a0f6f..8af1e3897 100644 Binary files a/tests/data/gpt2/ag_news_prompt_inputs_document.idx and b/tests/data/gpt2/ag_news_prompt_inputs_document.idx differ diff --git a/tests/data/gpt2/ag_news_prompt_targets_document.bin b/tests/data/gpt2/ag_news_prompt_targets_document.bin index 60646247e..ac2ba952c 100644 Binary files a/tests/data/gpt2/ag_news_prompt_targets_document.bin and b/tests/data/gpt2/ag_news_prompt_targets_document.bin differ diff --git a/tests/data/gpt2/ag_news_prompt_targets_document.idx b/tests/data/gpt2/ag_news_prompt_targets_document.idx index 29e152ec6..b0e7d3eae 100644 Binary files a/tests/data/gpt2/ag_news_prompt_targets_document.idx and b/tests/data/gpt2/ag_news_prompt_targets_document.idx differ diff --git a/tests/data/gpt2/generate_ag_news_mtf_dataset.sh b/tests/data/gpt2/generate_ag_news_mtf_dataset.sh new file mode 100644 index 000000000..e6ec6ef75 --- /dev/null +++ b/tests/data/gpt2/generate_ag_news_mtf_dataset.sh @@ -0,0 +1,22 @@ +python -c "from datasets import load_dataset; load_dataset('TimeRobber/ag_news_classify_question_first_100', split='train').to_json('ag_news_classify_question_first_100.jsonl')" + +python tools/preprocess_data.py \ + --input ag_news_classify_question_first_100.jsonl \ + --output-prefix tests/data/gpt2/ag_news_prompt \ + --dataset-impl mmap \ + --json-key targets \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path 
bigscience/tokenizer \ + --append-eod \ + --workers 8 + +python tools/preprocess_data.py \ + --input ag_news_classify_question_first_100.jsonl \ + --output-prefix tests/data/gpt2/ag_news_prompt \ + --dataset-impl mmap \ + --json-key inputs \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path bigscience/tokenizer \ + --workers 8 + +rm ag_news_classify_question_first_100.jsonl diff --git a/tests/ds_config_inference.json b/tests/ds_config_inference.json new file mode 100644 index 000000000..91314429e --- /dev/null +++ b/tests/ds_config_inference.json @@ -0,0 +1,15 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 16, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "zero_allow_untested_optimizer": false, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py index b529e1fd3..30ec1328f 100644 --- a/tests/test_dataloaders.py +++ b/tests/test_dataloaders.py @@ -1,13 +1,17 @@ import itertools -import unittest +import os +import shutil +from typing import Set from unittest.mock import patch import deepspeed +import torch +import finetune_t0_non_causal_decoder from megatron import global_vars, get_tokenizer, initialize_megatron, get_args -from megatron.data import mlm_dataset, mtf_dataset +from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset from megatron.data.data_samplers import build_pretraining_data_loader -from megatron.testing_utils import TestCasePlus, flatten_arguments, mockenv_context +from megatron.testing_utils import TestCasePlus, flatten_arguments, mockenv_context, torch_assert_equal def get_default_args(): @@ -49,6 +53,42 @@ def get_default_args(): # DATA_ARGS } +def get_dummy_mtf_decoder_packed_data(micro_batch_size: int, seq_length: int, vocab_size: int, special_tokens_ids: Set[int]): + seq_length += 1 + + num_segments = 
torch.randint(1, 5, ()) + segment_ids = torch.zeros(micro_batch_size, seq_length, dtype=torch.long) + is_inputs = torch.zeros(micro_batch_size, seq_length, dtype=torch.bool) + for batch_id in range(micro_batch_size): + # - `*2`: Hack so that two consecutive start_new_segments are separated by at least two tokens + # - `+1`: Hack so that start_new_segments is never 0 + start_new_segments = torch.sort(torch.randperm((seq_length - 2) // 2, )[:num_segments]).values * 2 + 1 + segment_ids[batch_id, start_new_segments] = 1 + + end_inputs = [ + torch.randint(low=start_segment, high=end_segment, size=()) + for start_segment, end_segment in zip([0, *start_new_segments], [*start_new_segments, seq_length]) + ] + for end_input, start_segment in zip(end_inputs, [0, *start_new_segments]): + is_inputs[batch_id][start_segment: end_input + 1] = True + + segment_ids = torch.cumsum(segment_ids, dim=-1) + 1 + + tokens = torch.randint(high=vocab_size, size=(micro_batch_size, seq_length), dtype=torch.long) + flatten_token_view = tokens.view(-1,) + for token_id in range(len(flatten_token_view)): + token = flatten_token_view[token_id] + # While the token is a special token, we change that token + while token in special_tokens_ids: + flatten_token_view[token_id] = (token + 1) % vocab_size + token = flatten_token_view[token_id] + + return { + "decoder_token_ids": tokens, + "decoder_segment_ids": segment_ids, + "decoder_is_inputs": is_inputs + } + class TestDataLoading(TestCasePlus): def setUp(self) -> None: super().setUp() @@ -65,10 +105,29 @@ def setUp(self) -> None: MASTER_ADDR="localhost", MASTER_PORT="9994", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1" ) - @unittest.skip("broken test") + def copy_data_to_temp(self, root_dir, prefix): + """copy data to temp, and return paths to temp version""" + src_path = os.path.join(root_dir, prefix) + src_dirname = os.path.dirname(src_path) + + tmp_dir = self.get_auto_remove_tmp_dir() + dest_path = os.path.join(tmp_dir, prefix) + dest_dirname = 
os.path.dirname(dest_path) + os.makedirs(dest_dirname, exist_ok=True) + for folder in os.listdir(src_dirname): + src_folder = os.path.join(src_dirname, folder) + dest_folder = os.path.join(dest_dirname, folder) + if src_folder.startswith(src_path): + if os.path.isdir(src_folder): + shutil.copytree(src_folder, dest_folder) + else: + shutil.copy2(src_folder, dest_folder) + return dest_path + def test_mlm_dataset(self): command_args = get_default_args() - command_args["--data-path"] = f"{self.data_dir}/gpt2/meg-gpt2-openwebtext_text_document" + data_path = self.copy_data_to_temp(self.data_dir, "gpt2/meg-gpt2-openwebtext_text_document") + command_args["--data-path"] = data_path command_args["--noise-density"] = "0.15" command_args["--mean-noise-span-length"] = "3" command_args["--vocab-extra-ids"] = "100" @@ -110,10 +169,10 @@ def test_mlm_dataset(self): self.assertEqual(sample["input_tokens"][-1], tokenizer.sep) self.assertEqual(sample["target_tokens"][-1], tokenizer.sep) - def test_mtf_dataset(self): + def test_decoder_packed_mtf_dataloader(self): command_args = get_default_args() - command_args["--data-path"] = f"{self.data_dir}/gpt2/ag_news_prompt" - command_args["--dataloader-type"] = "decoder_packed" + data_path = self.copy_data_to_temp(self.data_dir, "gpt2/ag_news_prompt") + command_args["--data-path"] = data_path with patch('sys.argv', flatten_arguments(command_args)): with mockenv_context(**self.dist_env_1_gpu): @@ -121,57 +180,40 @@ def test_mtf_dataset(self): initialize_megatron() args = get_args() - train_val_test_num_samples = [ - args.train_iters * args.global_batch_size, - args.eval_iters * args.global_batch_size, - 0 - ] - train_ds, valid_ds, test_ds = mtf_dataset.build_train_valid_test_datasets( - data_prefix=args.data_path, - data_impl=args.data_impl, - splits_string=args.split, - # TODO @thomasw21 figure how that value works - train_valid_test_num_samples=train_val_test_num_samples, - seed=args.seed, - skip_warmup=(not args.mmap_warmup) - ) - - # 
TODO @thomasw21 make sure that input and target are aligned. - - - def test_mtf_packed_dataloader(self): - command_args = get_default_args() - command_args["--data-path"] = f"{self.data_dir}/gpt2/ag_news_prompt" - command_args["--dataloader-type"] = "decoder_packed" - - with patch('sys.argv', flatten_arguments(command_args)): - with mockenv_context(**self.dist_env_1_gpu): - deepspeed.init_distributed() - initialize_megatron() + tokenizer = get_tokenizer() + # Hack: `gpt2` doesn't have a padding token, so we override that value. + tokenizer.tokenizer.pad_token_id = tokenizer.tokenizer.eos_token_id - args = get_args() train_val_test_num_samples = [ args.train_iters * args.global_batch_size, args.eval_iters * args.global_batch_size, 0 ] - train_ds, valid_ds, test_ds = mtf_dataset.build_train_valid_test_datasets( + train_ds, valid_ds, test_ds = decoder_packed_mtf_dataset.build_train_valid_test_datasets( data_prefix=args.data_path, data_impl=args.data_impl, splits_string=args.split, # TODO @thomasw21 figure how that value works train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length + 1, + pad_token=tokenizer.pad, + eos_token=tokenizer.eos, seed=args.seed, skip_warmup=(not args.mmap_warmup) ) - batch_sampler = build_pretraining_data_loader( + batch_iterator = build_pretraining_data_loader( train_ds, consumed_samples=0, num_workers=4 ) last_padding_size = 0 - for i, items in enumerate(batch_sampler): - micro_batch_size, seq_length = items["decoder_target_tokens"].shape + for i, items in enumerate(batch_iterator): + micro_batch_size, seq_length = items["decoder_token_ids"].shape + + # Check dtypes + self.assertEqual(items["decoder_token_ids"].dtype, torch.int64) + self.assertEqual(items["decoder_segment_ids"].dtype, torch.int64) + self.assertEqual(items["decoder_is_inputs"].dtype, torch.bool) # `micro_batch_size` correspond to the one in argument self.assertEqual(micro_batch_size, args.micro_batch_size) @@ -184,7 +226,7 @@ def 
test_mtf_packed_dataloader(self): # `segment_ids` is [1,2,...] self.assertEqual(segment_ids[:-1], list(range(1, len(segment_ids)))) # `0` signify that the tokens are padding - self.assertIn(segment_ids[-1], [0, len(segment_ids) + 1]) + self.assertIn(segment_ids[-1], [0, len(segment_ids)]) original_samples_count += len([segment_id for segment_id in segment_ids if segment_id != 0]) # Test that we actually pack, ie we have more samples than the `batch_size` @@ -197,3 +239,47 @@ def test_mtf_packed_dataloader(self): # update `last_padding_size` last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0]) + + + def test_finetune_t0_non_causal_decoder_get_batch_pipe(self): + command_args = get_default_args() + command_args["--position-embedding-type"] = "alibi" + + with patch('sys.argv', flatten_arguments(command_args)): + with mockenv_context(**self.dist_env_1_gpu): + deepspeed.init_distributed() + initialize_megatron() + + args = get_args() + tokenizer = get_tokenizer() + # Hack: `gpt2` doesn't have a padding token, so we override that value. 
+ tokenizer.tokenizer.pad_token_id = tokenizer.tokenizer.eos_token_id + + # Dummy data + data = get_dummy_mtf_decoder_packed_data( + micro_batch_size=args.micro_batch_size, + seq_length=args.seq_length, + vocab_size=args.padded_vocab_size, + special_tokens_ids={tokenizer.pad} + ) + + (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data) + + tokens = tokens.cpu() + position_ids = position_ids.cpu() + attention_mask = attention_mask.cpu() + labels = labels.cpu() + loss_mask = loss_mask.cpu() + + self.assertEqual(loss_mask.dtype, torch.float) + torch_assert_equal(loss_mask.bool(), ~data["decoder_is_inputs"][:, 1:] * (data["decoder_token_ids"][:, :-1] != tokenizer.pad)) + torch_assert_equal(tokens, data["decoder_token_ids"][:, :-1]) + torch_assert_equal(labels, data["decoder_token_ids"][:, 1:]) + + for batch_id in range(args.micro_batch_size): + segment_cuts = torch.nonzero(data["decoder_segment_ids"][batch_id, 1:] - data["decoder_segment_ids"][batch_id, :-1]) + 1 + for segment_start, segment_end in zip([0, *segment_cuts], [*segment_cuts, args.seq_length]): + self.assertTrue(torch.all(attention_mask[batch_id, 0, segment_start: segment_end, :segment_start])) + self.assertTrue(torch.all(attention_mask[batch_id, 0, segment_start: segment_end, segment_end:])) + + # TODO @thomasw21 make sure that we reset `position_ids` diff --git a/tests/test_model.py b/tests/test_model.py index fa625d764..60c6305ec 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,5 @@ -import unittest from random import randint +from typing import Set from unittest.mock import patch import deepspeed @@ -11,12 +11,16 @@ from packaging import version from megatron import initialize_megatron, get_args, get_tokenizer, global_vars -from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, require_torch_bf16 +from megatron.model.fused_softmax import ScaledMaskedSoftmax +from 
megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, \ + torch_assert_close, require_torch_bf16 from megatron.training import setup_model_and_optimizer -from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe -from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe +import pretrain_gpt +import pretrain_prefix_lm +import finetune_t0_non_causal_decoder -def get_default_args(): + +def get_default_args(test_file_dir: str): """return a dictionary with key as argument name and value as additional arguments""" return { # GPT_ARGS @@ -25,8 +29,8 @@ def get_default_args(): "--num-attention-heads": "4", "--seq-length": "256", "--max-position-embeddings": "256", - "--micro-batch-size": "4", - "--global-batch-size": "8", + "--micro-batch-size": "2", + "--global-batch-size": "2", "--lr-decay-iters": "320000", "--lr-decay-style": "cosine", "--lr": "0.00015", @@ -41,6 +45,7 @@ def get_default_args(): "--clip-grad": "1.0", "--lr-warmup-fraction": ".01", "--fp16": "", + "--inference": "", "--attention-dropout": "0", "--hidden-dropout": "0", @@ -53,6 +58,11 @@ def get_default_args(): "--checkpoint-activations": "", # DATA_ARGS + + # DeepSpeed args + "--deepspeed": "", + "--deepspeed_config": f"{test_file_dir}/ds_config_inference.json", + "--zero-stage": "0", } @@ -61,6 +71,48 @@ def equal_vectors(tensor1, tensor2, dim=-1): return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0 +def iter_out_of_one(one): + return iter([one]) + + +def get_dummy_mtf_decoder_packed_data(micro_batch_size: int, seq_length: int, vocab_size: int, special_tokens_ids: Set[int]): + """Code from `tests/test_dataloaders.py""" + seq_length += 1 + + num_segments = torch.randint(1, 5, ()) + segment_ids = torch.zeros(micro_batch_size, seq_length, dtype=torch.long) + is_inputs = torch.zeros(micro_batch_size, seq_length, dtype=torch.bool) + for batch_id in 
range(micro_batch_size): + # - `*2`: Hack so that two consecutive start_new_segments are separated by at least two tokens + # - `+1`: Hack so that start_new_segments is never 0 (NOTE(review): a first start of 1 makes `high=end_segment - 1` equal `low=0` below, which torch.randint rejects as an empty range; confirm and guard) + start_new_segments = torch.sort(torch.randperm((seq_length - 2) // 2, )[:num_segments]).values * 2 + 1 + segment_ids[batch_id, start_new_segments] = 1 + + end_inputs = [ + torch.randint(low=start_segment, high=end_segment - 1, size=()) + for start_segment, end_segment in zip([0, *start_new_segments], [*start_new_segments, seq_length]) + ] + for end_input, start_segment in zip(end_inputs, [0, *start_new_segments]): + is_inputs[batch_id][start_segment: end_input + 1] = True + + segment_ids = torch.cumsum(segment_ids, dim=-1) + 1 + + tokens = torch.randint(high=vocab_size, size=(micro_batch_size, seq_length), dtype=torch.long) + flatten_token_view = tokens.view(-1,) + for token_id in range(len(flatten_token_view)): + token = flatten_token_view[token_id] + # While the token is a special token, we change that token + while token in special_tokens_ids: + flatten_token_view[token_id] = (token + 1) % vocab_size + token = flatten_token_view[token_id] + + return { + "decoder_token_ids": tokens, + "decoder_segment_ids": segment_ids, + "decoder_is_inputs": is_inputs + } + + class MyTestCase(TestCasePlus): def setUp(self) -> None: super().setUp() @@ -79,7 +131,7 @@ def setUp(self) -> None: def test_gpt(self): """Test causal invariance, ie past token don't depend on future tokens.""" - command_args = get_default_args() + command_args = get_default_args(self.test_file_dir_str) with patch('sys.argv', flatten_arguments(command_args)): with mockenv_context(**self.dist_env_1_gpu): @@ -88,8 +140,10 @@ def test_gpt(self): args = get_args() tokenizer = get_tokenizer() - model, _, _ = setup_model_and_optimizer(gpt_model_provider) + model, _, _ = setup_model_and_optimizer(pretrain_gpt.model_provider) model = model[0] + model._config.train_micro_batch_size_per_gpu = args.micro_batch_size + 
model.set_train_batch_size(args.micro_batch_size) token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length)) @@ -97,18 +151,15 @@ def test_gpt(self): token_ids[token_ids == tokenizer.eod] += 1 token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size - # process batch - input_batch = get_gpt_batch_pipe({"text": token_ids})[0] - # get a modified version of the first batch, we change a specific index changed_index = randint(0, args.seq_length - 2) - input_token_ids_changed = input_batch[0].clone() + token_ids_changed = token_ids.clone() # We increment the token_id by one for that index in order to artificially change the sequence. - input_token_ids_changed[:, changed_index] = \ - (input_token_ids_changed[:,changed_index] + 1) % args.padded_vocab_size + token_ids_changed[:, changed_index] = \ + (token_ids_changed[:, changed_index] + 1) % args.padded_vocab_size - output = model(*input_batch) - output_changed = model(input_token_ids_changed, *input_batch[1:]) + output = model.eval_batch(iter_out_of_one({"text": token_ids}), compute_loss=False) + output_changed = model.eval_batch(iter_out_of_one({"text": token_ids_changed}), compute_loss=False) # All token in past should be unchanged torch_assert_equal(output[:, :changed_index], output_changed[:, :changed_index]) @@ -124,7 +175,7 @@ def test_prefix_lm_reset_attention_mask(self): - Target tokens depend on input tokens. - Input tokens depend on all other input tokens, but never target tokens. 
""" - command_args = get_default_args() + command_args = get_default_args(self.test_file_dir_str) command_args["--reset-attention-mask"] = "" command_args["--loss-on-targets-only"] = "" @@ -136,8 +187,12 @@ def test_prefix_lm_reset_attention_mask(self): args = get_args() tokenizer = get_tokenizer() - model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider) + model, _, _ = setup_model_and_optimizer(pretrain_prefix_lm.model_provider) model = model[0] + model._config.train_micro_batch_size_per_gpu = args.micro_batch_size + model.set_train_batch_size(args.micro_batch_size) + # we preprocess batch_fn manually + model.set_batch_fn(None) token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length)) @@ -146,7 +201,7 @@ def test_prefix_lm_reset_attention_mask(self): token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size # process batch to have non empty prefix - input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids}) + input_batch, (labels, loss_mask), prefix_indices = pretrain_prefix_lm.get_batch_pipe({"text": token_ids}) for batch_id in range(len(prefix_indices)): for id in prefix_indices[batch_id]: @@ -155,7 +210,7 @@ def test_prefix_lm_reset_attention_mask(self): # Make sure that the last prefix token predicts the first token. 
self.assertTrue(loss_mask[batch_id, id -1] == 1) - output = model(*input_batch) + output = model.eval_batch(iter_out_of_one((input_batch, (labels, loss_mask), prefix_indices)), compute_loss=False) ## --------------- CHANGE A TARGET TOKEN --------------------------- # get a modified version of the first batch @@ -170,7 +225,7 @@ def test_prefix_lm_reset_attention_mask(self): token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size # Test change - output_changed_target = model(token_ids_changed_target, *input_batch[1:]) + output_changed_target = model.eval_batch(iter_out_of_one(((token_ids_changed_target, *input_batch[1:]), (labels, loss_mask), prefix_indices)), compute_loss=False) # All token in past should be unchanged torch_assert_equal(output[0, :changed_target_index], output_changed_target[0, :changed_target_index]) @@ -195,7 +250,7 @@ def test_prefix_lm_reset_attention_mask(self): token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1 token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size - output_changed_input = model(token_ids_changed_input, *input_batch[1:]) + output_changed_input = model.eval_batch(iter_out_of_one(((token_ids_changed_input, *input_batch[1:]), (labels, loss_mask), prefix_indices)), compute_loss=False) # All tokens should be changed self.assertFalse( @@ -213,7 +268,7 @@ def test_prefix_lm_wo_reset_attention_mask(self): - Target tokens depend on input tokens. - Input tokens depend on all other input tokens, but never target tokens. 
""" - command_args = get_default_args() + command_args = get_default_args(self.test_file_dir_str) command_args["--loss-on-targets-only"] = "" @@ -223,11 +278,15 @@ def test_prefix_lm_wo_reset_attention_mask(self): initialize_megatron() args = get_args() - model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider) + model, _, _ = setup_model_and_optimizer(pretrain_prefix_lm.model_provider) model = model[0] + model._config.train_micro_batch_size_per_gpu = args.micro_batch_size + model.set_train_batch_size(args.micro_batch_size) + # we preprocess batch_fn manually + model.set_batch_fn(None) token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length)) - input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids}) + input_batch, (labels, loss_mask), prefix_indices = pretrain_prefix_lm.get_batch_pipe({"text": token_ids}) for batch_id in range(len(prefix_indices)): id = prefix_indices[batch_id] @@ -236,13 +295,13 @@ def test_prefix_lm_wo_reset_attention_mask(self): # Make sure that the last prefix token predicts the first token. 
self.assertTrue(loss_mask[batch_id, id -1] == 1) - model(*input_batch) + model.eval_batch(iter_out_of_one((input_batch, (labels, loss_mask), prefix_indices)), compute_loss=False) #TODO: Check all invariants def test_gpt_rotary_embeddings(self): """Test rotary embeddings""" - command_args = get_default_args() + command_args = get_default_args(self.test_file_dir_str) del command_args["--max-position-embeddings"] command_args["--position-embedding-type"] = "rotary" @@ -254,8 +313,10 @@ def test_gpt_rotary_embeddings(self): args = get_args() tokenizer = get_tokenizer() - model, _, _ = setup_model_and_optimizer(gpt_model_provider) + model, _, _ = setup_model_and_optimizer(pretrain_gpt.model_provider) model = model[0] + model._config.train_micro_batch_size_per_gpu = args.micro_batch_size + model.set_train_batch_size(args.micro_batch_size) token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length)) @@ -263,16 +324,13 @@ def test_gpt_rotary_embeddings(self): token_ids[token_ids == tokenizer.eod] += 1 token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size - # process batch - input_batch = get_gpt_batch_pipe({"text": token_ids})[0] - - model(*input_batch) + model.eval_batch(iter_out_of_one({"text": token_ids}), compute_loss=False) #TODO: Check all invariants @require_torch_bf16 def test_fused_layer_norm(self): - command_args = get_default_args() + command_args = get_default_args(self.test_file_dir_str) # Condition to use custom cuda kernel command_args["--bf16"] = "" @@ -308,6 +366,114 @@ def test_fused_layer_norm(self): torch_assert_equal(mfln_output, torch_layer_norm_output) + def test_fused_masked_softmax(self): + command_args = get_default_args(self.test_file_dir_str) + + with patch('sys.argv', flatten_arguments(command_args)): + with mockenv_context(**self.dist_env_1_gpu): + initialize_megatron() + args = get_args() + + dummy_input = torch.randn( + args.micro_batch_size, + args.num_attention_heads, + args.seq_length, + 
args.seq_length, + device="cuda", + dtype=args.params_dtype + ) + dummy_attention_mask = torch.randn( + args.micro_batch_size, + 1, # `args.num_attention_heads` not implemented in our cuda kernel + args.seq_length, + args.seq_length, + device="cuda", + dtype=args.params_dtype + ) < 0 + scale = torch.rand(()) + + fused_scaled_softmax = ScaledMaskedSoftmax + + fused_output = fused_scaled_softmax.apply(dummy_input, dummy_attention_mask, scale) + + # mimick the same via torch + output = scale * dummy_input + output = output.masked_fill(dummy_attention_mask, torch.finfo(args.params_dtype).min) + output = F.softmax(output, dim=-1) + + # Test that the nonzeros are the same with the mask + for i in range(args.num_attention_heads): + torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0])) + # Cuda kernel produces slightly different results + torch_assert_close(fused_output, output) + + + def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_is_not_causal_across_segments(self): + command_args = get_default_args(self.test_file_dir_str) + command_args["--position-embedding-type"] = "alibi" + + with patch('sys.argv', flatten_arguments(command_args)): + with mockenv_context(**self.dist_env_1_gpu): + deepspeed.init_distributed() + initialize_megatron() + + args = get_args() + tokenizer = get_tokenizer() + # Hack: `gpt2` doesn't have a padding token, so we override that value. 
+ tokenizer.tokenizer.pad_token_id = tokenizer.tokenizer.eos_token_id + + data = get_dummy_mtf_decoder_packed_data( + micro_batch_size=args.micro_batch_size, + seq_length=args.seq_length, + vocab_size=args.padded_vocab_size, + special_tokens_ids={tokenizer.pad} + ) + model, _, _ = setup_model_and_optimizer(finetune_t0_non_causal_decoder.model_provider) + model = model[0] + model._config.train_micro_batch_size_per_gpu = args.micro_batch_size + model.set_train_batch_size(args.micro_batch_size) + + output = model.eval_batch(iter_out_of_one(data), compute_loss=False) + + ## --------------- CHANGE A TARGET TOKEN --------------------------- + # change the first token in the first batch to a random value + change_batch_id = 0 + change_token_id = 0 + token_ids_changed = data["decoder_token_ids"].clone() + # We increment the token id on the changed index. + token_ids_changed[change_batch_id, change_token_id] = (token_ids_changed[change_batch_id, change_token_id] + 1) % args.padded_vocab_size + while token_ids_changed[change_batch_id, change_token_id] in {tokenizer.eod, tokenizer.pad}: + token_ids_changed[change_batch_id, change_token_id] = (token_ids_changed[change_batch_id, change_token_id] + 1) % args.padded_vocab_size + + # Test change + output_changed_target = model.eval_batch(iter_out_of_one({**data, "decoder_token_ids": token_ids_changed}), compute_loss=False) + + first_segment_first_batch_id_end = (torch.nonzero(data["decoder_segment_ids"][change_batch_id, 1:] - data["decoder_segment_ids"][change_batch_id, :-1]) + 1)[0] + # Check that values changed in segment 1 of batch_id 0 + self.assertFalse(torch.any( + equal_vectors( + output[change_batch_id, change_token_id:first_segment_first_batch_id_end], + output_changed_target[change_batch_id, change_token_id:first_segment_first_batch_id_end] + ) + )) + # Check that values did not change in other segments of batch_id 0 + torch_assert_equal( + output[change_batch_id, first_segment_first_batch_id_end:], + 
output_changed_target[change_batch_id, first_segment_first_batch_id_end:] + ) + # Check that values did not change in other segments in other batches + non_change_ids = torch.arange(output.shape[0]) != change_batch_id + torch_assert_equal(output[non_change_ids], output_changed_target[non_change_ids]) + + ## --------------- CHANGE THE LAST TOKEN TO A PAD --------------------------- + # change the last token in the first batch to a pad + token_ids_changed_pad = data["decoder_token_ids"].clone() + segment_ids_changed_pad = data["decoder_segment_ids"].clone() + # We overwrite the last token with the pad token and set its segment id to 0. + token_ids_changed_pad[change_batch_id, -1] = tokenizer.pad + segment_ids_changed_pad[change_batch_id, -1] = 0 + + # Test model handles padding correctly + output_changed_pad = model.eval_batch(iter_out_of_one({**data, "decoder_token_ids": token_ids_changed_pad, "decoder_segment_ids": segment_ids_changed_pad}), compute_loss=False) -if __name__ == '__main__': - unittest.main() + self.assertFalse(torch.any(torch.isnan(output_changed_pad))) diff --git a/tests/test_training.py b/tests/test_training.py index 79a43c6a2..6ba1ca534 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -18,6 +18,7 @@ import os import glob import re +import shutil import unittest from pathlib import Path from parameterized import parameterized @@ -80,8 +81,27 @@ def setUp(self): if os.path.exists(meg_lock_file_path): os.unlink(meg_lock_file_path) + def copy_data_to_temp(self, root_dir, prefix): + """copy data to temp, and return paths to temp version""" + src_path = os.path.join(root_dir, prefix) + src_dirname = os.path.dirname(src_path) + + tmp_dir = self.get_auto_remove_tmp_dir() + dest_path = os.path.join(tmp_dir, prefix) + dest_dirname = os.path.dirname(dest_path) + os.makedirs(dest_dirname, exist_ok=True) + for folder in os.listdir(src_dirname): + src_folder = os.path.join(src_dirname, folder) + dest_folder = os.path.join(dest_dirname, folder) + if 
src_folder.startswith(src_path): + if os.path.isdir(src_folder): + shutil.copytree(src_folder, dest_folder) + else: + shutil.copy2(src_folder, dest_folder) + return dest_path + def get_variation_config(self, variation, output_dir, n_samples=None): - data_dir = f"{self.data_dir}/gpt2" + data_dir = self.copy_data_to_temp(self.data_dir,"gpt2") pp_size, tp_size, dp_size = get_3d_dimensions() num_gpus = pp_size * tp_size * dp_size @@ -355,7 +375,8 @@ def test_training_all(self, variation): def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_on_position_frequency): # all in one test src_dir = self.src_dir - data_dir = f"{self.data_dir}/gpt2" + data_dir = self.copy_data_to_temp(self.data_dir,"gpt2") + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) logs_dir = f"{output_dir}/logs" Path(logs_dir).mkdir(parents=True, exist_ok=True) @@ -469,10 +490,122 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_ tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*") self.assertEqual(len(tensorboard_files), 2, "tensorboard files") + def test_training_t0(self): + data_path = self.copy_data_to_temp(self.data_dir, "gpt2/ag_news_prompt") + output_dir = self.get_auto_remove_tmp_dir() + logs_dir = f"{output_dir}/logs" + Path(logs_dir).mkdir(parents=True, exist_ok=True) + + pp_size, tp_size, dp_size = get_3d_dimensions() + num_gpus = pp_size * tp_size * dp_size + + n_samples = 200 # about 37 iterations + exit_interval = 10 # some samples in the first half and then some more in the 2nd half after resume + + args = f""" + --tensor-model-parallel-size {tp_size} + --pipeline-model-parallel-size {pp_size} + --distributed-backend nccl + + --num-layers 2 + --hidden-size 64 + --num-attention-heads 2 + --seq-length 128 + --max-position-embeddings 1024 + --position-embedding-type alibi + --micro-batch-size 1 + --rampup-batch-size 2 2 {n_samples} + --global-batch-size 16 + --train-samples {n_samples} + + 
--optimizer adam + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --lr 1e-4 + --lr-warmup-samples 5 + --clip-grad 1.0 + --weight-decay 1e-1 + --fp16 + + --log-interval 5 + --save-interval 10 + --eval-interval 10 + --eval-iters 5 + --checkpoint-activations + --exit-interval {exit_interval} + --tokenizer-type PretrainedFromHF + --tokenizer-name-or-path bigscience/tokenizer + --log-path {logs_dir} + --save {output_dir}/checkpoints + --load {output_dir}/checkpoints + --data-path {data_path} + --split 90,10,0 + --tensorboard-dir {output_dir}/tensorboard + --tensorboard-queue-size 5 + --log-timers-to-tensorboard + --log-batch-size-to-tensorboard + --log-validation-ppl-to-tensorboard + + --log-level debug + """.split() + + ds_args = f""" + --deepspeed + --deepspeed_config {self.test_file_dir_str}/ds_config.json + --zero-stage 1 + --deepspeed-activation-checkpointing + """.split() + + script = [f"{self.src_dir}/finetune_t0_non_causal_decoder.py"] + launcher = get_launcher(num_gpus) + + cmd = launcher + script + args + ds_args + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + + # 1. test training from scratch (no checkpoint) + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test deepspeed is running + self.assertIn("DeepSpeed info", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # test there should be no checkpoint this round + self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out) + + # test checkpoint saving + self.assertIn("successfully saved checkpoint at iteration", cs.out) + + # test tensorboard + tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*") + self.assertEqual(len(tensorboard_files), 1, "tensorboard files") + + # 2. 
test training from checkpoint: resume + # now do it again, this time resuming from the checkpoint + with CaptureStdout() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + + # test checkpoint loading + self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out) + + # test reports + self.assertIn("consumed samples", cs.out) + + # test checkpoint saving + self.assertIn("successfully saved checkpoint at iteration", cs.out) + + # test tensorboard (1 file from the first run, plus 1 now) + tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*") + self.assertEqual(len(tensorboard_files), 2, "tensorboard files") + @parameterized.expand(["gpt", "prefix", "no_eval"]) def test_mode2_dataloading(self, variation): src_dir = self.src_dir - data_dir = f"{self.data_dir}/gpt2" + data_dir = self.copy_data_to_temp(self.data_dir, "gpt2") output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) logs_dir = f"{output_dir}/logs" Path(logs_dir).mkdir(parents=True, exist_ok=True)