diff --git a/training/DeepSpeed-Domino/README.md b/training/DeepSpeed-Domino/README.md new file mode 100644 index 000000000..3c1f4040b --- /dev/null +++ b/training/DeepSpeed-Domino/README.md @@ -0,0 +1,86 @@ +# Domino Example + +## Install Dependency Libraries +``` +pip install -r requirements.txt +``` + +## Prepare the Dataset +Follow the instructions from [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset. + +## Execute Domino Training + +To start training, adjust the following parameters in the script as needed: + +- **GPUS_PER_NODE**: Number of GPUs per node. +- **CHECKPOINT_PATH**: Path to the checkpoint, if applicable. +- **VOCAB_FILE**, **MERGE_FILE**, **DATA_PATH**: Paths to the dataset files. +- **--micro-batch-size**: Batch size per GPU. + +### Available Models and Scripts + +| Model | Script | +|------------|--------------------------| +| GPT-3 2.7B | `pretrain_gpt3_2.7b.sh` | +| GPT-3 6.7B | `pretrain_gpt3_6.7b.sh` | +| LLaMA 7B | `pretrain_llama_7b.sh` | +| LLaMA 13B | `pretrain_llama_13b.sh` | + +### Example + +To train the GPT-3 2.7B model, run the following command: + +```bash +bash pretrain_gpt3_2.7b.sh +``` + +The output should look like this: + +``` +training ... +iteration: 1 | loss: 11.318 | iteration time (ms): 2174.0469932556152 +iteration: 2 | loss: 11.307 | iteration time (ms): 1414.4024848937988 +iteration: 3 | loss: 11.323 | iteration time (ms): 1385.9455585479736 +iteration: 4 | loss: 11.310 | iteration time (ms): 1475.5175113677979 +iteration: 5 | loss: 11.306 | iteration time (ms): 1395.7207202911377 +iteration: 6 | loss: 11.315 | iteration time (ms): 1392.2104835510254 +iteration: 7 | loss: 11.314 | iteration time (ms): 1402.6703834533691 +iteration: 8 | loss: 11.309 | iteration time (ms): 1450.613260269165 +iteration: 9 | loss: 11.305 | iteration time (ms): 1473.1688499450684 +iteration: 10 | loss: 11.320 | iteration time (ms): 1398.4534740447998 +[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73015 exits successfully. +[2024-11-04 15:32:30,918] [INFO] [launch.py:351:main] Process 73017 exits successfully. +[2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73014 exits successfully. +[2024-11-04 15:32:30,919] [INFO] [launch.py:351:main] Process 73016 exits successfully. +``` + +## Advanced Usage +You can compile Pytorch and Apex from source for better performance. + +### Compile PyTorch from Source +Compile PyTorch from source could enable JIT script. +``` +git clone -b v2.1.0 https://github.com/pytorch/pytorch.git +git submodule sync +git submodule update --init --recursive +conda install cmake ninja +pip install -r requirements.txt +conda install intel::mkl-static intel::mkl-include +conda install -c pytorch magma-cuda121 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +python setup.py develop + +# Build torchvision +git clone https://github.com/pytorch/vision.git +python setup.py develop +``` + +## Build Apex +``` +git clone https://github.com/NVIDIA/apex +cd apex +# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --config-settings "--build-option=--fast_layer_norm" ./ +# otherwise +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --config-settings "--build-option=--fast_layer_norm" ./ +``` \ No newline at end of file diff --git a/training/DeepSpeed-Domino/domino/__init__.py b/training/DeepSpeed-Domino/domino/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py new file mode 100644 index 000000000..8bc59223a --- /dev/null +++ b/training/DeepSpeed-Domino/domino/arguments.py @@ -0,0 +1,400 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from arguments.py in Megatron-LM + +"""Domino arguments.""" + +import argparse +import os +import types +import math +import torch +import torch.nn.functional as F + +import dataclasses +from dataclasses import dataclass +from typing import Callable +from domino.timer import Timers +from megatron.tokenizer import build_tokenizer + + +_GLOBAL_ARGS = None +_GLOBAL_TOKENIZER = None +_GLOBAL_TIMERS = None + + +def get_args(): + """Return arguments.""" + return _GLOBAL_ARGS + + +def set_args(args): + global _GLOBAL_ARGS + _GLOBAL_ARGS = args + + +def build_tokenizer_g(args): + """Initialize tokenizer.""" + global _GLOBAL_TOKENIZER + _GLOBAL_TOKENIZER = build_tokenizer(args) + return _GLOBAL_TOKENIZER + + +def get_tokenizer(): + """Return tokenizer.""" + return _GLOBAL_TOKENIZER + + +def get_num_microbatches(): + return 1 + + +def init_method_normal(std_dev): + def initialize(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std_dev) + return initialize + + +def scaled_init_method_normal(std_dev, layer_count): + scaled_std_dev = std_dev / math.sqrt(2.0 * layer_count) + def initialize(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=scaled_std_dev) + return initialize + + +def get_timers(): + """Return timers.""" + return _GLOBAL_TIMERS + + +def set_timers(): + """Initialize timers.""" + global _GLOBAL_TIMERS + _GLOBAL_TIMERS = Timers(0, "maxmin") + + +def parse_args(): + """Parse all arguments.""" + parser = argparse.ArgumentParser(description='Domino Arguments', allow_abbrev=False) + parser.add_argument('--num-layers', type=int, default=None, + help='Number of transformer layers.') + parser.add_argument('--hidden-size', type=int, default=None, + help='Tansformer hidden size.') + parser.add_argument('--num-attention-heads', type=int, default=None, + help='Number of transformer attention heads.') + parser.add_argument('--ffn-hidden-size', type=int, default=None, + help='Transformer Feed-Forward Network hidden size. ' + 'This is set to 4*hidden-size if not provided') + parser.add_argument('--seq-length', type=int, default=None, + help='Maximum sequence length to process.') + parser.add_argument('--max-position-embeddings', type=int, default=None, + help='Maximum number of position embeddings to use. ' + 'This is the size of position embedding.') + parser.add_argument('--position-embedding-type', type=str, default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Position embedding type.') + parser.add_argument('--rotary-percent', type=float, default=1.0, + help='Percent of rotary dimension to use, default 100%') + parser.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None, + help='Sequence length interpolation factor for rotary embeddings.') + parser.add_argument('--hidden-dropout', type=float, default=0.1, + help='Dropout probability for hidden state transformer.') + parser.add_argument('--attention-dropout', type=float, default=0.1, + help='Post attention dropout probability.') + parser.add_argument('--no-masked-softmax-fusion', + action='store_false', + help='Disable fusion of query_key_value scaling, ' + 'masking, and softmax.', + dest='masked_softmax_fusion') + parser.add_argument('--tensor-model-parallel-size', type=int, default=1, + help='Degree of tensor model parallelism.') + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + parser.add_argument('--distributed-backend', default='nccl', + choices=['nccl', 'gloo'], + help='Which backend to use for distributed training.') + parser.add_argument('--seed', type=int, default=1234, + help='Random seed used for python, numpy, pytorch, and cuda.') + parser.add_argument('--train-iters', type=int, default=None, + help='Total number of iterations to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + parser.add_argument('--micro-batch-size', type=int, default=None, + help='Batch size per model instance (local batch size). ' + 'Global batch size is local batch size times data ' + 'parallel size times number of micro batches.') + parser.add_argument('--global-batch-size', type=int, default=None, + help='Training batch size. If set, it should be a ' + 'multiple of micro-batch-size times data-parallel-size. ' + 'If this value is None, then ' + 'use micro-batch-size * data-parallel-size as the ' + 'global batch size. This choice will result in 1 for ' + 'number of micro-batches.') + parser.add_argument('--lr', type=float, default=None, + help='Initial learning rate. Depending on decay style ' + 'and initial warmup, the learing rate at each ' + 'iteration would be different.') + parser.add_argument('--min-lr', type=float, default=0.0, + help='Minumum value for learning rate. The scheduler' + 'clip values below this threshold.') + parser.add_argument('--lr-warmup-fraction', type=float, default=None, + help='fraction of lr-warmup-(iters/samples) to use ' + 'for warmup (as a float)') + parser.add_argument('--lr-decay-style', type=str, default='linear', + choices=['constant', 'linear', 'cosine', 'inverse-square-root'], + help='Learning rate decay function.') + parser.add_argument('--lr-decay-iters', type=int, default=None, + help='number of iterations to decay learning rate over,' + ' If None defaults to `--train-iters`') + parser.add_argument('--weight-decay', type=float, default=0.01, + help='Weight decay coefficient for L2 regularization.') + parser.add_argument('--clip-grad', type=float, default=1.0, + help='Gradient clipping based on global L2 norm.') + parser.add_argument('--data-path', nargs='*', default=None, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ... It is used with --split when a ' + 'single dataset used for all three: train, valid ' + 'and test. It is exclusive to the other ' + '--*-data-path args') + parser.add_argument('--split', type=str, default='969, 30, 1', + help='Comma-separated list of proportions for training,' + ' validation, and test split. For example the split ' + '`90,5,5` will use 90%% of data for training, 5%% for ' + 'validation and 5%% for test.') + parser.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file.') + parser.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file.') + parser.add_argument('--data-impl', type=str, default='infer', + choices=['mmap', 'infer'], + help='Implementation of indexed datasets.') + parser.add_argument('--fp16', action='store_true', + help='Run model in fp16 mode.') + parser.add_argument('--bf16', action='store_true', + help='Run model in bfloat16 mode.') + parser.add_argument('--tokenizer-type', type=str, + default='GPT2BPETokenizer', + choices=['BertWordPieceLowerCase', + 'BertWordPieceCase', + 'GPT2BPETokenizer', + 'SentencePieceTokenizer', + 'GPTSentencePieceTokenizer', + 'NullTokenizer'], + help='What type of tokenizer to use.') + parser.add_argument('--make-vocab-size-divisible-by', type=int, default=128, + help='Pad the vocab size to be divisible by this value.' + 'This is added for computational efficieny reasons.') + parser.add_argument('--llama-model', action='store_true', help='Use LLaMA model.') + parser.add_argument('--swiglu', action='store_true', + help='Use gated linear units and SiLU activation instead of default gelu') + parser.add_argument('--add-bias-linear', action='store_true', + help='Enable bias in the linear layers') + parser.add_argument('--normalization', default='LayerNorm', + choices=['LayerNorm', 'RMSNorm'], + help='Which normalization technique to use.', + dest='normalization') + parser.add_argument('--layernorm-epsilon', type=float, default=1e-5, + help='Layer norm epsilon.') + parser.add_argument('--eval-iters', type=int, default=100, + help='Number of iterations to run for evaluation' + 'validation/test for.') + parser.add_argument('--eval-interval', type=int, default=1000, + help='Interval between running evaluation on ' + 'validation set.') + parser.add_argument('--log-interval', type=int, default=100, + help='Report loss and timing interval.') + parser.add_argument('--save-interval', type=int, default=None, + help='Number of iterations between checkpoint saves.') + + args = parser.parse_args() + + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv("WORLD_SIZE", '1')) + + if args.ffn_hidden_size is None: + args.ffn_hidden_size = 4 * args.hidden_size + if args.swiglu: + args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64 + + args.kv_channels = args.hidden_size // args.num_attention_heads + + args.perform_initialization = True + args.apply_residual_connection_post_layernorm = False + args.no_persist_layer_norm = False + + args.activation_func = F.gelu + args.add_bias_linear = True + args.gated_linear_unit = False + if args.swiglu: + args.activation_func = F.silu + args.gated_linear_unit = True + args.bias_gelu_fusion = False + + init_method_std = 0.02 + args.init_method = init_method_normal(init_method_std) + args.output_layer_init_method = scaled_init_method_normal( + init_method_std, args.num_layers) + + args.optimizer = 'adam' + args.adam_beta1 = 0.9 + args.adam_beta2 = 0.999 + args.adam_eps = 1e-8 + args.weight_decay = 0.01 + args.lr_warmup_init = 0.0 + args.lr_decay_style = 'cosine' + args.start_weight_decay = 0.1 + args.end_weight_decay = 0.1 + args.weight_decay_incr_style ='constant' + args.start_weight_decay = args.weight_decay + args.end_weight_decay = args.weight_decay + args.use_checkpoint_opt_param_scheduler = False + args.override_opt_param_scheduler = False + + args.mmap_warmup = False + + args.num_workers = 1 + args.dataloader_type = 'single' + args.train_data_path = None + args.valid_data_path = None + args.test_data_path = None + args.data_cache_path = None + args.train_samples = None + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + args.decoder_seq_length = None + args.reset_position_ids = False + args.reset_attention_mask = False + args.eod_mask_loss = False + args.empty_unused_memory_level = 1 + args.tokenizer_type = 'GPT2BPETokenizer' + + args.loss_scale = 1024 + args.initial_loss_scale = 2**32 + args.min_loss_scale = 1.0 + args.loss_scale_window = 1000 + args.hysteresis = 2 + args.use_distributed_optimizer = False + args.log_num_zeros_in_grad = False + + args.rampup_batch_size = None + # Parameters dtype. + args.accumulate_allreduce_grads_in_fp32 = False + args.params_dtype = torch.float + if args.fp16: + args.params_dtype = torch.half + if args.bf16: + args.params_dtype = torch.bfloat16 + # bfloat16 requires gradient accumulation and all-reduce to + # be done in fp32. + if not args.accumulate_allreduce_grads_in_fp32: + args.accumulate_allreduce_grads_in_fp32 = True + if args.rank == 0: + print('accumulate and all-reduce gradients in fp32 for ' + 'bfloat16 data type.', flush=True) + + args.async_tensor_model_parallel_allreduce = True + args.gradient_accumulation_fusion = True + args.padded_vocab_size = 0 # tokenizer.py + args.model_type = 1 + args.data_parallel_size = 1 + args.DDP_impl = 'local' + args.use_contiguous_buffers_in_local_ddp = True + args.data_parallel_random_init = False + + return args + + +@dataclass +class TransformerConfig(): + """Configuration object for transformers. + """ + sequence_parallel: bool = False + llama_model: bool = False + apply_residual_connection_post_layernorm = False + no_persist_layer_norm = False + + # Initialization + perform_initialization: bool = True + use_cpu_initialization: bool = False + + # Training + fp16: bool = False + bf16: bool = False + params_dtype: torch.dtype = torch.float32 + timers: Callable = None + + # Optimizations + gradient_accumulation_fusion: bool = True + async_tensor_model_parallel_allreduce: bool = True + + # model architecture + num_layers: int = 0 + hidden_size: int = 0 + num_attention_heads: int = 0 + ffn_hidden_size: int = None + kv_channels: int = None + hidden_dropout: float = 0.1 + attention_dropout: float = 0.1 + layernorm_epsilon: float = 1e-5 + layernorm_zero_centered_gamma: bool = False + add_bias_linear: bool = True + swiglu = False + gated_linear_unit: bool = False + activation_func: Callable = F.gelu + bias_gelu_fusion = False + + # initialization + init_method: Callable = None + output_layer_init_method: Callable = None + init_method_std: float = 0.02 + + enable_autocast: bool = False + # autocast_dtype: torch.dtype = None + deallocate_pipeline_outputs: bool = False + no_sync_func: Callable = None + # grad_sync_func: Callable = None + # param_sync_func: Callable = None + + def __post_init__(self): + """ Python dataclass method that is used to modify attributes after initialization. + See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. + """ + if self.fp16 and self.bf16: + raise ValueError( + f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.' + ) + + # if self.num_attention_heads % self.tensor_model_parallel_size != 0: + # raise ValueError( + # f"num_attention_heads ({self.num_attention_heads}) must be a multiple of " + # f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." + # ) + + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + if self.kv_channels is None: + self.kv_channels = self.hidden_size // self.num_attention_heads + + if self.init_method is None: + self.init_method = init_method_normal(self.init_method_std) + + if self.output_layer_init_method is None: + self.output_layer_init_method = scaled_init_method_normal( + self.init_method_std, self.num_layers + ) + +def core_transformer_config_from_args(args): + # Translate args to core transformer configuration + kw_args = {} + for f in dataclasses.fields(TransformerConfig): + if hasattr(args, f.name): + kw_args[f.name] = getattr(args, f.name) + + kw_args['hidden_size'] = args.hidden_size + kw_args['init_method'] = args.init_method + kw_args['output_layer_init_method'] = args.init_method + kw_args['params_dtype'] = args.params_dtype + + return TransformerConfig(**kw_args) diff --git a/training/DeepSpeed-Domino/domino/data/.DS_Store b/training/DeepSpeed-Domino/domino/data/.DS_Store new file mode 100644 index 000000000..d02a1acaa Binary files /dev/null and b/training/DeepSpeed-Domino/domino/data/.DS_Store differ diff --git a/training/DeepSpeed-Domino/domino/data/__init__.py b/training/DeepSpeed-Domino/domino/data/__init__.py new file mode 100644 index 000000000..cd5f898c6 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/data/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/training/DeepSpeed-Domino/domino/data/data_samplers.py b/training/DeepSpeed-Domino/domino/data/data_samplers.py new file mode 100644 index 000000000..d3df72f66 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/data/data_samplers.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import random +import torch +import numpy as np +from torch.utils.data import Dataset + +from domino.arguments import get_args +import domino.parallel_state as mpu + +def build_pretraining_data_loader(dataset, consumed_samples): + """Buld dataloader given an input dataset.""" + + if dataset is None: + return None + args = get_args() + + # Megatron sampler + if args.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + elif args.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + dataset, + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + data_sharding=args.data_sharding) + else: + raise Exception('{} dataloader type is not supported.'.format( + args.dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True) + +class MegatronPretrainingSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class RandomSeedDataset(Dataset): + + def __init__(self, dataset): + args = get_args() + self.base_seed = args.seed + self.curr_seed = args.seed + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def set_epoch(self, epoch): + self.curr_seed = self.base_seed + epoch + + def __getitem__(self, idx): + seed = idx + self.curr_seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + return self.dataset[idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, data_sharding): + # Keep a copy of input params for later use. + self.dataset = dataset + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.data_sharding = data_sharding + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + if isinstance(self.dataset, RandomSeedDataset): + self.dataset.set_epoch(self.epoch) + + # data sharding and random sampling + if self.data_sharding: + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + else: + full_bucket_size = (self.total_samples // self.micro_batch_size) \ + * self.micro_batch_size + full_bucket_offset = current_epoch_samples + g = torch.Generator() + g.manual_seed(self.epoch) + idx_range_total = \ + torch.randperm(full_bucket_size, generator=g).tolist() + idx_range_active = idx_range_total[full_bucket_offset:] + idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/training/DeepSpeed-Domino/domino/data/gpt_dataset.py b/training/DeepSpeed-Domino/domino/data/gpt_dataset.py new file mode 100644 index 000000000..da75d77cf --- /dev/null +++ b/training/DeepSpeed-Domino/domino/data/gpt_dataset.py @@ -0,0 +1,469 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import hashlib +import os +import time +import numpy as np +import torch + +from domino.utils import print_rank_0 +import domino.parallel_state as mpu +from domino.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + return_doc_ids=False, *, + data_cache_path=None): + """Build train, valid, and test datasets.""" + + # Single data path provided for train, valid & test + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + return_doc_ids=False, *, + data_cache_path=None): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, seed, + return_doc_ids, + data_cache_path=data_cache_path) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class GPTDataset(torch.utils.data.Dataset): + + def __init__(self, name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, + return_doc_ids=False, *, + data_cache_path=None): + + self.name = name + self.indexed_dataset = indexed_dataset + self.return_doc_ids = return_doc_ids + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ + _build_index_mappings(self.name, data_prefix, + documents, self.indexed_dataset.sizes, + splits_string, num_samples, seq_length, seed, + data_cache_path=data_cache_path) + + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + doc_ids = [] + if doc_index_f == doc_index_l: + doc_ids.append(self.doc_idx[doc_index_f]) + sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1) + else: + # Otherwise, get the rest of the initial document. + doc_ids.append(self.doc_idx[doc_index_f]) + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + doc_ids.append(self.doc_idx[i]) + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + doc_ids.append(self.doc_idx[doc_index_l]) + sample_list.append(self.indexed_dataset.get( + self.doc_idx[doc_index_l], + length=offset_l + 1)) + sample = np.concatenate(sample_list) + + if self.return_doc_ids: # for retro preprocessing + return {'text': np.array(sample, dtype=np.int64), + 'doc_ids': np.array(doc_ids, dtype=np.int64)} + else: + return {'text': np.array(sample, dtype=np.int64)} + + +def _build_index_mappings(name, data_prefix, documents, sizes, + splits_string, num_samples, seq_length, seed, + *, + data_cache_path): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + '_doc_idx.npy' + sample_idx_filename = desc_hash + '_sample_idx.npy' + shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + 'desc': os.path.join(prefix, desc_filename), + 'doc': os.path.join(prefix, doc_idx_filename), + 'sample': os.path.join(prefix, sample_idx_filename), + 'shuffle': os.path.join(prefix, shuffle_idx_filename) + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break + data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_success = True + + # Build the indexed mapping if not exist. + if build_indices and torch.distributed.get_rank() == 0: + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ + 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = (last_epoch_num_samples < + int(0.80 * num_samples_per_epoch)) + if separate_last_epoch: + string = ' > last epoch number of samples ({}) is smaller '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to True' + else: + string = ' > last epoch number of samples ({}) is larger '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, + num_samples_per_epoch), flush=True) + + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path['desc'], 'wt') as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(idx_path['doc'], doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + np.save(idx_path['sample'], sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_dir})') + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print('the data files are in and can be set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + data_cache_success = False + + counts = torch.cuda.LongTensor([data_cache_success]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + if counts[0].item() != ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() + + # Load mappings. + start_time = time.time() + print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx, desc, desc_hash + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +# def _build_sample_idx(sizes, doc_idx, seq_length, +# num_epochs, tokens_per_epoch): +# """Sample index mapping is a 2D array with sizes +# [number-of-samples + 1, 2] where [..., 0] contains +# the index into `doc_idx` and [..., 1] is the +# starting offset in that document.""" + +# # Total number of samples. For -1 see comments in `_num_epochs`. +# num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length +# sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + +# # Index into sample_idx. +# sample_index = 0 +# # Index into doc_idx. +# doc_idx_index = 0 +# # Begining offset for each document. +# doc_offset = 0 +# # Start with first document and no offset. +# sample_idx[sample_index][0] = doc_idx_index +# sample_idx[sample_index][1] = doc_offset +# sample_index += 1 +# while sample_index <= num_samples: +# # Start with a fresh sequence. +# remaining_seq_length = seq_length + 1 +# while remaining_seq_length != 0: +# # Get the document length. +# doc_id = doc_idx[doc_idx_index] +# doc_length = sizes[doc_id] - doc_offset +# # And add it to the current sequence. +# remaining_seq_length -= doc_length +# # If we have more than a full sequence, adjust offset and set +# # remaining length to zero so we return from the while loop. +# # Note that -1 here is for the same reason we have -1 in +# # `_num_epochs` calculations. +# if remaining_seq_length <= 0: +# doc_offset += (remaining_seq_length + doc_length - 1) +# remaining_seq_length = 0 +# else: +# # Otherwise, start from the begining of the next document. +# doc_idx_index += 1 +# doc_offset = 0 +# # Record the sequence. +# sample_idx[sample_index][0] = doc_idx_index +# sample_idx[sample_index][1] = doc_offset +# sample_index += 1 + +# return sample_idx + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print(' > building shuffle index with split [0, {}) and [{}, {}) ' + '...'.format(num_samples, num_samples, total_size), flush=True) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + diff --git a/training/DeepSpeed-Domino/domino/data/indexed_dataset.py b/training/DeepSpeed-Domino/domino/data/indexed_dataset.py new file mode 100644 index 000000000..e513e7f73 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/data/indexed_dataset.py @@ -0,0 +1,613 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch +from domino.utils import print_rank_0 + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False, multimodal=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup, multimodal) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float64, + 7: np.float32, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + doc_offset = len(self.sizes) + + begin = self.data_offsets[-1] + for data_offset in index.data_offsets[1:]: + self.data_offsets.append(begin + data_offset) + self.sizes.extend(index.sizes) + + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' 0: + # Ensure that different pipeline MP stages get different seeds. + # seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + seed = seed_ + # Ensure different data parallel ranks get different seeds + if data_parallel_random_init: + seed = seed + (10 * mpu.get_data_parallel_rank()) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.device_count() > 0: + model_parallel_cuda_manual_seed(seed) + + # Compile dependencies. + _compile_dependencies() + + # No continuation function + return None + + +def _compile_dependencies(): + + args = get_args() + + # ========================= + # Compile dataset C++ code. + # ========================= + # TODO: move this to ninja + if torch.distributed.get_rank() == 0: + start_time = time.time() + print("> compiling dataset index builder ...") + from megatron.data.dataset_utils import compile_helper + + compile_helper() + print( + ">>> done with dataset index builder. Compilation time: {:.3f} " + "seconds".format(time.time() - start_time), + flush=True, + ) + + # ================== + # Load fused kernels + # ================== + + # Custom kernel constraints check. + seq_len = args.seq_length + attn_batch_size = ( + args.num_attention_heads / args.tensor_model_parallel_size + ) * args.micro_batch_size + # Constraints on sequence length and attn_batch_size to enable warp based + # optimization and upper triangular optimization (for causal mask) + custom_kernel_constraint = ( + seq_len > 16 + and seq_len <= 16384 + and seq_len % 4 == 0 + and attn_batch_size % 4 == 0 + ) + # Print a warning. + if not ( + (args.fp16 or args.bf16) + and custom_kernel_constraint + and args.masked_softmax_fusion + ): + if args.rank == 0: + print( + "WARNING: constraints for invoking optimized" + " fused softmax kernel are not met. We default" + " back to unfused kernel invocations.", + flush=True, + ) + + # Always build on rank zero first. + if torch.distributed.get_rank() == 0: + start_time = time.time() + print("> compiling and loading fused kernels ...", flush=True) + fused_kernels.load(args) + torch.distributed.barrier() + else: + torch.distributed.barrier() + fused_kernels.load(args) + # Simple barrier to make sure all ranks have passed the + # compilation phase successfully before moving on to the + # rest of the program. We think this might ensure that + # the lock is released. + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print( + ">>> done with compiling and loading fused kernels. " + "Compilation time: {:.3f} seconds".format(time.time() - start_time), + flush=True, + ) + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + _warmup_jit_function() + + +def _warmup_jit_function(): + """Compilie JIT functions before the main training steps""" + args = get_args() + if args.bf16: + dtype = torch.bfloat16 + elif args.fp16: + dtype = torch.float16 + else: + dtype = torch.float32 + + # Warmup fused bias+gelu + bias = torch.rand( + args.ffn_hidden_size // args.tensor_model_parallel_size, + dtype=dtype, + device="cuda", + ) + input = torch.rand( + ( + args.seq_length, + args.micro_batch_size, + args.ffn_hidden_size // args.tensor_model_parallel_size, + ), + dtype=dtype, + device="cuda", + ) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + output = bias_gelu(bias, input) + del bias, input, output + + # Warmup fused bias+dropout+add + seq_length = args.seq_length + input = torch.rand( + (seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + residual = torch.rand( + (seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as( + residual + ) + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip( + [False, True], [True, True], [True, True] + ): + input.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + for _ in range(5): + output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate) + del bias, input, residual, output + torch.cuda.empty_cache() diff --git a/training/DeepSpeed-Domino/domino/language_model.py b/training/DeepSpeed-Domino/domino/language_model.py new file mode 100644 index 000000000..2cfb2f9fd --- /dev/null +++ b/training/DeepSpeed-Domino/domino/language_model.py @@ -0,0 +1,193 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from language_model.py in Megatron-LM + +import torch +from torch import einsum, nn +from domino.arguments import get_args +from domino.modules.enums import ModelType +import domino.parallel_state as mpu +from domino.modules.module import DominoModule +from domino.tensor_parallel.comm import GatherFromModelParallelRegion +from domino.tensor_parallel.partition import VocabParallelEmbedding, linear_with_grad_accumulation_and_async_allreduce +from domino.modules.fused_layer_norm import MixedFusedLayerNorm as fused_layer_norm +from domino.modules.fused_func import bias_dropout_add_fused_train, bias_dropout_add_fused_inference, apply_rotary_pos_emb +from domino.tensor_parallel.partition import _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes +from domino.tensor_parallel.partition import ColumnParallelLinear, RowParallelLinearNoComm + +from deepspeed.runtime.domino.transformer import DominoTransformer + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, + bias=None): + """LM logits using word embedding weights.""" + args = get_args() + # Parallel logits. + if args.async_tensor_model_parallel_allreduce: + input_parallel = input_ + model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 + async_grad_allreduce = args.async_tensor_model_parallel_allreduce and model_parallel + + # Matrix multiply. + logits_parallel = linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=word_embeddings_weight, + bias=bias, + gradient_accumulation_fusion=args.gradient_accumulation_fusion, + async_grad_allreduce=async_grad_allreduce, + sequence_parallel=False) + # Gather if needed. + + if parallel_output: + return logits_parallel + + return GatherFromModelParallelRegion.apply(logits_parallel) + + +def get_language_model(config, num_tokentypes, + encoder_attn_mask_type, + pre_process=True, post_process=True): + language_model = TransformerLanguageModel( + config, + encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + pre_process=pre_process, + post_process=post_process + ) + + return language_model + + +class Embedding(DominoModule): + def __init__(self, hidden_dim, vocab_size, max_seq_len, dropout_prob, config): + super(Embedding, self).__init__() + self.hidden_dim = hidden_dim + self.init_method = config.init_method + args = get_args() + self.word_embeddings = VocabParallelEmbedding( + vocab_size, self.hidden_dim, config=config, init_method=config.init_method + ) + self.use_position_embedding = args.position_embedding_type == 'learned_absolute' + if self.use_position_embedding: + self.position_embeddings = torch.nn.Embedding(max_seq_len, self.hidden_dim) + self.init_method(self.position_embeddings.weight) + self.embedding_dropout = torch.nn.Dropout(dropout_prob) + + def forward(self, input_ids, position_ids): + word_embeds = self.word_embeddings(input_ids) + if self.use_position_embedding: + pos_embeds = self.position_embeddings(position_ids) + combined_embeds = word_embeds + pos_embeds + else: + combined_embeds = word_embeds + + combined_embeds = combined_embeds.transpose(0, 1).contiguous() + combined_embeds = self.embedding_dropout(combined_embeds) + + return combined_embeds + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, seq_len_interpolation_factor=None): + super().__init__() + self.seq_len_interpolation_factor = seq_len_interpolation_factor + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, max_seq_len, offset=0): + seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset + if self.seq_len_interpolation_factor is not None: + seq = seq.type_as(self.inv_freq) + seq *= 1 / self.seq_len_interpolation_factor + freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + return emb[:, None, None, :] + + # def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # state_dict.pop(f'{prefix}inv_freq', None) + # return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class TransformerLanguageModel(DominoModule): + def __init__(self, + config, + encoder_attn_mask_type, + num_tokentypes=0, + pre_process=True, + post_process=True): + + args = get_args() + super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=True) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = config.hidden_size + self.num_tokentypes = num_tokentypes + self.init_method = config.init_method + self.encoder_attn_mask_type = encoder_attn_mask_type + self.encoder_hidden_state = None + + if self.pre_process: + self.embedding = Embedding(self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + config) + + self.use_rotary_position_embeddings = \ + args.position_embedding_type == 'rope' + if self.use_rotary_position_embeddings: + self.seq_length = args.seq_length + rotary_dim = args.hidden_size // args.num_attention_heads \ + if args.kv_channels is None else args.kv_channels + if args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * args.rotary_percent) + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim, + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + self.encoder = DominoTransformer( + config, ModelType.encoder_or_decoder, mpu, + fused_layer_norm, _initialize_affine_weight_gpu, + ColumnParallelLinear, RowParallelLinearNoComm, apply_rotary_pos_emb, + bias_dropout_add_fused_train, bias_dropout_add_fused_inference, + self_attn_mask_type=self.encoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + def set_input_tensor(self, input_tensor): + pass + # self.encoder.set_input_tensor(input_tensor[0]) + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, + inference_params=None): + + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids) + else: + encoder_input = None + + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + encoder_out_size = encoder_input.shape + p_batch_size = encoder_out_size[1] // 2 + dtype = encoder_input.dtype + encoder_output_t = torch.empty(encoder_out_size, dtype=dtype, device=torch.cuda.current_device()) + intra_partitions = 2 + encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1) + encoder_outputs = self.encoder( + encoder_inputs, + enc_attn_mask, + rotary_pos_emb=rotary_pos_emb) + encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0] + encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1] + encoder_output = encoder_output_t + + return encoder_output + \ No newline at end of file diff --git a/training/DeepSpeed-Domino/domino/llama_model.py b/training/DeepSpeed-Domino/domino/llama_model.py new file mode 100644 index 000000000..3b929eeeb --- /dev/null +++ b/training/DeepSpeed-Domino/domino/llama_model.py @@ -0,0 +1,91 @@ +import torch +from domino.arguments import get_args +from domino.language_model import parallel_lm_logits +from domino.modules.enums import AttnMaskType +from domino.modules.module import DominoModule +from domino.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy +from domino.language_model import get_language_model + + +def post_language_model_processing(lm_output, labels, logit_weights, parallel_output): + output = parallel_lm_logits(lm_output, logit_weights, parallel_output) + labels = labels.transpose(0, 1).contiguous() + loss = vocab_parallel_cross_entropy(output.float(), labels) + loss = loss.transpose(0, 1).contiguous() + return loss + + +class LLaMAModel(DominoModule): + """LLaMA Language model.""" + + def __init__( + self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + ): + args = get_args() + super(LLaMAModel, self).__init__( + config=config, + share_embeddings_and_output_weights=True) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.padded_vocab_size = args.padded_vocab_size + self.language_model = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self.initialize_word_embeddings() + self.lm_head = torch.nn.Linear( + args.hidden_size, args.padded_vocab_size, bias=False + ) + + def set_input_tensor(self, input_tensor): + self.language_model.set_input_tensor(input_tensor) + + def _causal_lm_process(self, lm_output, labels): + lm_output = lm_output.transpose(0, 1) + logits = self.lm_head(lm_output) + loss = None + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., :-1].contiguous() + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0) + shift_logits = shift_logits.view(-1, self.padded_vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return loss + + def forward( + self, + input_ids, + position_ids, + attention_mask, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, + inference_params=None, + ): + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + retriever_input_ids=retriever_input_ids, + retriever_position_ids=retriever_position_ids, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params, + ) + + if self.post_process: + return self._causal_lm_process(lm_output=lm_output, labels=labels) + else: + return lm_output diff --git a/training/DeepSpeed-Domino/domino/modules/__init__.py b/training/DeepSpeed-Domino/domino/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/domino/modules/distributed.py b/training/DeepSpeed-Domino/domino/modules/distributed.py new file mode 100644 index 000000000..4c69aa9e7 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/distributed.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from distributed.py in Megatron-LM + +import math +import torch +import domino.parallel_state as mpu + + +class FlattenMemory: + + def __init__(self, numel, dtype): + self.numel = numel + data_parallel_world_size = mpu.get_data_parallel_world_size() + self.numel_padded = data_parallel_world_size * \ + int(math.ceil(numel / data_parallel_world_size)) + self.dtype = dtype + self.data = torch.zeros(self.numel_padded, + dtype=self.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + + def get(self, shape, start_index): + end_index = start_index + shape.numel() + assert end_index <= self.numel + buffer_tensor = self.data[start_index:end_index] + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + +class DistributedDataParallel(torch.nn.Module): + + def __init__(self, module, + accumulate_allreduce_grads_in_fp32, + use_contiguous_buffers): + + super(DistributedDataParallel, self).__init__() + + self.module = module + self.accumulate_allreduce_grads_in_fp32 \ + = accumulate_allreduce_grads_in_fp32 + self.use_contiguous_buffers = use_contiguous_buffers + + if self.accumulate_allreduce_grads_in_fp32: + assert self.use_contiguous_buffers + + if not self.use_contiguous_buffers: + self._grad_buffers = None + self._grad_buffer_param_index_map = None + return + + self._grad_buffers = {} + self._grad_buffer_param_index_map = {} + + def _get_buffer_type(param): + return torch.float if \ + self.accumulate_allreduce_grads_in_fp32 else param.dtype + + type_num_elements = {} + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] = type_num_elements.get(dtype, 0) \ + + param.data.nelement() + + # Allocate the memory. + for dtype, num_elements in type_num_elements.items(): + self._grad_buffers[dtype] = FlattenMemory(num_elements, dtype) + + self.grad_accs = [] + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] -= param.data.nelement() + param.main_grad = self._grad_buffers[dtype].get( + param.data.shape, type_num_elements[dtype]) + if dtype not in self._grad_buffer_param_index_map: + self._grad_buffer_param_index_map[dtype] = {} + self._grad_buffer_param_index_map[dtype][param] = ( + type_num_elements[dtype], + type_num_elements[dtype] + param.data.nelement(), + ) + # Backward hook. + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_param_hook(param)) + self.grad_accs.append(grad_acc) + + + def _make_param_hook(self, param): + """Create the all-reduce hook for backprop.""" + # Hook used for back-prop. + def param_hook(*unused): + # Add the gradient to the buffer. + if param.grad is not None: + # The gradient function of linear layers is fused with GEMMs + param.main_grad.add_(param.grad.data) + # Now we can deallocate grad memory. + param.grad = None + return param_hook + + + def zero_grad_buffer(self): + assert self._grad_buffers is not None, 'buffers are not initialized.' + for _, buffer_ in self._grad_buffers.items(): + buffer_.data.zero_() + + + def broadcast_params(self): + for param in self.module.parameters(): + torch.distributed.broadcast(param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group()) + + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) diff --git a/training/DeepSpeed-Domino/domino/modules/enums.py b/training/DeepSpeed-Domino/domino/modules/enums.py new file mode 100644 index 000000000..7ec422a86 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/enums.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from enums.py in Megatron-LM + +import enum + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + +class ModelType(enum.Enum): + encoder_or_decoder = 1 + encoder_and_decoder = 2 diff --git a/training/DeepSpeed-Domino/domino/modules/fused_bias_gelu.py b/training/DeepSpeed-Domino/domino/modules/fused_bias_gelu.py new file mode 100644 index 000000000..5c9f341b5 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/fused_bias_gelu.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from fused_bias_gelu.py in Megatron-LM + +import torch + + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + +bias_gelu_impl = GeLUFunction.apply diff --git a/training/DeepSpeed-Domino/domino/modules/fused_func.py b/training/DeepSpeed-Domino/domino/modules/fused_func.py new file mode 100644 index 000000000..7a38dc855 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/fused_func.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from typing import Optional +import torch + + +class AddDropoutFuseFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input1, bias1, residual1, input2, bias2, residual2, prob, training): + if bias1 is not None and bias2 is not None: + output1, mask1, output2, mask2 = torch._C._nn.native_add_dropout_add_fuse( + input1, bias1, residual1, input2, bias2, residual2, prob, training + ) + else: + output1, mask1, output2, mask2 = torch._C._nn.native_add_dropout_fuse( + input1, residual1, input2, residual2, prob, training + ) + scale = 1.0 / (1.0 - prob) + ctx.save_for_backward(mask1, mask2) + ctx.scale = scale + ctx.with_bias = bias1 is not None and bias2 is not None + return output1, output2 + + @staticmethod + def backward(ctx, grad_output1, grad_output2): + (mask1, mask2) = ctx.saved_tensors + scale = ctx.scale + with_bias = ctx.with_bias + if with_bias: + grad_input1, grad_bias1, grad_residual1, grad_input2, grad_bias2, grad_residual2 = ( + torch._C._nn.native_add_dropout_add_fuse_2(grad_output1, mask1, grad_output2, mask2, scale) + ) + else: + grad_input1, grad_residual1, grad_input2, grad_residual2 = ( + torch._C._nn.native_add_dropout_fuse_2(grad_output1, mask1, grad_output2, mask2, scale) + ) + grad_bias1 = None + grad_bias2 = None + return grad_input1, grad_bias1, grad_residual1, grad_input2, grad_bias2, grad_residual2, None, None + + +class AddDropoutFuse(torch.nn.Module): + def __init__(self): + super(AddDropoutFuse, self).__init__() + + def forward(self, input1, bias1, residual1, input2, bias2, residual2, prob, training): + return AddDropoutFuseFunction.apply(input1, bias1, residual1, input2, bias2, residual2, prob, training) + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin()) + return torch.cat((t, t_pass), dim=-1) \ No newline at end of file diff --git a/training/DeepSpeed-Domino/domino/modules/fused_layer_norm.py b/training/DeepSpeed-Domino/domino/modules/fused_layer_norm.py new file mode 100644 index 000000000..72e6463a9 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/fused_layer_norm.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copied and modified from NVIDIA apex + +import numbers +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +from domino.utils import make_viewless_tensor + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + HAVE_PERSIST_LAYER_NORM = True +except: + HAVE_PERSIST_LAYER_NORM = False + +from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, + no_persist_layer_norm=True): + super(MixedFusedLayerNorm, self).__init__() + + persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, + 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536] + if normalized_shape not in persist_ln_hidden_sizes or \ + not HAVE_PERSIST_LAYER_NORM: + no_persist_layer_norm = True + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + + self.no_persist_layer_norm = no_persist_layer_norm + + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + weight = self.weight + + if self.no_persist_layer_norm: + return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) + else: + output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) + + output = make_viewless_tensor(inp = output, + requires_grad = input.requires_grad, + keep_graph = True) + + return output diff --git a/training/DeepSpeed-Domino/domino/modules/module.py b/training/DeepSpeed-Domino/domino/modules/module.py new file mode 100644 index 000000000..b89bbc21f --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/module.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from module.py in Megatron-LM + +import torch +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from domino.arguments import get_args +import domino.parallel_state as mpu + +_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) +_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + +class DominoModule(torch.nn.Module): + """extensions of torch Module.""" + + def __init__(self, config=None, share_embeddings_and_output_weights=True): + super(DominoModule, self).__init__() + self.config = config + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + + def initialize_word_embeddings(self): + self.share_embeddings_and_output_weights = True + return + + def shared_embedding_or_output_weight(self): + if self.pre_process: + return self.language_model.embedding.word_embeddings.weight + else: + if not self.share_embeddings_and_output_weights: + raise Exception('shared_embedding_or_output_weight() called for last ' + 'stage, but share_embeddings_and_output_weights is false') + return self.word_embeddings.weight + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` + #is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_float16(val): + """Convert fp32 `val` to fp16/bf16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, _FLOAT_TYPES): + val = val.half() + return val + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + """Convert fp16/bf16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + val = val.float() + return val + return conversion_helper(val, float_conversion) + + +class Float16Module(torch.nn.Module): + + def __init__(self, module, args): + super(Float16Module, self).__init__() + self.add_module('module', module.half()) + + def set_input_tensor(self, input_tensor): + return self.module.set_input_tensor(input_tensor) + + def forward(self, *inputs, **kwargs): + if mpu.is_pipeline_first_stage(): + inputs = fp32_to_float16(inputs) + outputs = self.module(*inputs, **kwargs) + if mpu.is_pipeline_last_stage(): + outputs = float16_to_fp32(outputs) + return outputs + diff --git a/training/DeepSpeed-Domino/domino/modules/utils.py b/training/DeepSpeed-Domino/domino/modules/utils.py new file mode 100644 index 000000000..5c5938b90 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/modules/utils.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copied and modified from Megatron-LM + +import math +import torch +from domino.arguments import get_args + +def init_method_normal(std_dev): + def initialize(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std_dev) + return initialize + +def scaled_init_method_normal(std_dev, layer_count): + scaled_std_dev = std_dev / math.sqrt(2.0 * layer_count) + def initialize(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=scaled_std_dev) + return initialize + + +def get_linear_layer(input_dim, output_dim, init_method): + linear_layer = torch.nn.Linear(input_dim, output_dim) + if get_args().perform_initialization: + init_method(linear_layer.weight) + with torch.no_grad(): + linear_layer.bias.zero_() + return linear_layer + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared diff --git a/training/DeepSpeed-Domino/domino/optimizer/__init__.py b/training/DeepSpeed-Domino/domino/optimizer/__init__.py new file mode 100644 index 000000000..74a7e2c97 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer/__init__.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from apex.optimizers import FusedAdam as Adam +from apex.optimizers import FusedSGD as SGD + +from domino.arguments import get_args + +from .distrib_optimizer import DistributedOptimizer +from .grad_scaler import ConstantGradScaler, DynamicGradScaler +from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer + + +def get_param_groups(modules, + no_weight_decay_cond, + scale_lr_cond, + lr_mult): + """creates param groups based on weight decay condition (regularized vs non regularized) + and learning rate scale condition (args.lr vs lr_mult * args.lr) + scale_lr_cond is used during finetuning where head of the network requires a scaled + version of the base learning rate. + """ + wd_no_scale_lr = [] + wd_scale_lr = [] + no_wd_no_scale_lr = [] + no_wd_scale_lr = [] + for module in modules: + for name, param in module.named_parameters(): + if not param.requires_grad: + continue + + if no_weight_decay_cond is not None: + no_wd = no_weight_decay_cond(name, param) + else: + # do not regularize biases nor Norm parameters + no_wd = name.endswith(".bias") or len(param.shape) == 1 + + if scale_lr_cond is not None: + scale_lr = scale_lr_cond(name, param) + else: + scale_lr = False + + if not no_wd and not scale_lr: + wd_no_scale_lr.append(param) + elif not no_wd and scale_lr: + wd_scale_lr.append(param) + elif no_wd and not scale_lr: + no_wd_no_scale_lr.append(param) + else: + no_wd_scale_lr.append(param) + + param_groups = [] + if len(wd_no_scale_lr): + param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0}) + if len(wd_scale_lr): + param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult}) + if len(no_wd_no_scale_lr): + param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0}) + if len(no_wd_scale_lr): + param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult}) + + return param_groups + +def get_megatron_optimizer(model, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0): + args = get_args() + + # Base optimizer. + param_groups = get_param_groups(model, + no_weight_decay_cond, + scale_lr_cond, + lr_mult) + + if args.optimizer == 'adam': + optimizer = Adam(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps) + elif args.optimizer == 'sgd': + optimizer = SGD(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum) + else: + raise Exception('{} optimizer is not supported.'.format( + args.optimizer)) + + # Determine whether the params have main-grad field. + params_have_main_grad = False + if args.DDP_impl == 'local': + params_have_main_grad = True + + # Mixed precision optimizer. + # - Note: both the Float16Optimizer and the DistributedOptimizer inherit + # from the MixedPrecisionOptimizer, which manages any optimizer where + # the model params and main params are distinct. + if args.fp16 or args.bf16 or args.use_distributed_optimizer: + + # Grad scaler: + # if loss-scale is provided, instantiate the constant scaler. + # if we are using fp16 and loss-scale is not present, use a + # dynamic scaler. + # otherwise we are running in bf16 with no loss-scale so + # leave it as None. + grad_scaler = None + + # Constant loss scale. + if args.loss_scale: + grad_scaler = ConstantGradScaler(args.loss_scale) + + # Dynamic loss scale. + else: + if args.fp16: + grad_scaler = DynamicGradScaler( + initial_scale=args.initial_loss_scale, + min_scale=args.min_loss_scale, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=args.loss_scale_window, + hysteresis=args.hysteresis) + + # Megatron optimizer. + opt_ty = DistributedOptimizer \ + if args.use_distributed_optimizer else \ + Float16OptimizerWithFloat16Params + return opt_ty(optimizer, + args.clip_grad, + args.log_num_zeros_in_grad, + params_have_main_grad, + args.use_contiguous_buffers_in_local_ddp, + args.fp16, + args.bf16, + args.params_dtype, + grad_scaler, + model) + + # FP32. + return FP32Optimizer(optimizer, args.clip_grad, + args.log_num_zeros_in_grad, + params_have_main_grad, + args.use_contiguous_buffers_in_local_ddp, + model) diff --git a/training/DeepSpeed-Domino/domino/optimizer/clip_grads.py b/training/DeepSpeed-Domino/domino/optimizer/clip_grads.py new file mode 100644 index 000000000..092e34729 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer/clip_grads.py @@ -0,0 +1,134 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Gradient clipping.""" + +import torch +from torch import inf + +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C + +from domino.modules.module import param_is_not_shared +from domino.tensor_parallel.partition import param_is_not_tensor_parallel_duplicate + +def clip_grad_norm_fp32(parameters, grads_for_norm, + max_norm, norm_type=2, + model_parallel_group=None): + """Clips gradient norm of an iterable of parameters whose gradients + are in fp32. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single + Tensor that will be used for calculating the grad norm. + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + model_parallel_group (group): given the nature of the distributed + optimizer, this is passed as an argument. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + if isinstance(grads_for_norm, torch.Tensor): + grads_for_norm = [grads_for_norm] + + # Grads. + grads = [] + for param in parameters: + if param.grad is not None: + assert param.grad.type() == 'torch.cuda.FloatTensor' + grads.append(param.grad.detach()) + + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + total_norm = 0.0 + + # Calculate norm. + if norm_type == inf: + total_norm = max(grad.abs().max() for grad in grads_for_norm) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=model_parallel_group) + total_norm = total_norm_cuda[0].item() + + else: + if norm_type == 2.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + # Use apex's multi-tensor applier for efficiency reasons. + # Multi-tensor applier takes a function and a list of list + # and performs the operation on that list all in one kernel. + if grads_for_norm: + grad_norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [grads_for_norm], + False # no per-parameter norm + ) + else: + grad_norm = torch.cuda.FloatTensor([0]) + # Since we will be summing across data parallel groups, + # we need the pow(norm-type). + total_norm = grad_norm ** norm_type + + else: + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm, + op=torch.distributed.ReduceOp.SUM, + group=model_parallel_group) + total_norm = total_norm.item() ** (1.0 / norm_type) + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + dummy_overflow_buf = torch.cuda.IntTensor([0]) + multi_tensor_applier(amp_C.multi_tensor_scale, + dummy_overflow_buf, + [grads, grads], + clip_coeff) + + return total_norm + + +def count_zeros_fp32(parameters, model_parallel_group): + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + total_num_zeros = torch.cuda.FloatTensor([0.0]) + for param in parameters: + grad_not_none = param.grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grad = param.grad.detach() + num_zeros = grad.numel() - torch.count_nonzero(grad) + total_num_zeros = num_zeros + total_num_zeros + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(total_num_zeros, + op=torch.distributed.ReduceOp.SUM, + group=model_parallel_group) + + total_num_zeros = total_num_zeros.item() + + return total_num_zeros diff --git a/training/DeepSpeed-Domino/domino/optimizer/distrib_optimizer.py b/training/DeepSpeed-Domino/domino/optimizer/distrib_optimizer.py new file mode 100644 index 000000000..951b9a58e --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer/distrib_optimizer.py @@ -0,0 +1,1022 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Megatron distributed optimizer.""" + + +from apex.optimizers import FusedAdam as Adam +import math +import torch + +from domino.utils import print_rank_0 +import domino.parallel_state as mpu +from domino.tensor_parallel.partition import copy_tensor_model_parallel_attributes +from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper + + +class Range: + """ + A range represents a start and end points for indexing a shard + from a full tensor. + """ + def __init__(self, start, end): + self.start = start + self.end = end + self.size = end - start + def normalize(self, start = 0): + return Range(start, start + self.size) + def __str__(self): + return "%d,%d [%d]" % (self.start, self.end, self.size) + def __len__(self): + return self.end - self.start + + +class DistributedOptimizer(MixedPrecisionOptimizer): + """Distributed optimizer, for all data types (fp16, bf16, and fp32). + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + @classmethod + def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range): + """ + Build mapping from param reference to grad buffer shard ranges. + + This method builds a mapping from parameter references to grad + buffer shard ranges, specific to each data-parallel (DP) rank's + set of 'owned' parameters. Each grad buffer (padded to be an even + multiple of DP-world-size) is conceptually divided into DP-world-size + contiguous regions, where each DP rank 'owns' a contiguous regions. + Ownership in this sense means DP rank is responsible for reducing + the relevant subset of grads, and updating the relevant subset of + params. + + This conceptual partitioning of the grad buffer does NOT respect + parameter boundaries, and as such it is assumed that each created + range references a shard (or subset) of the full parameter. It is + easiest to think of each DP rank as operating (i.e., reducing, + gathering) purely on views into the grad buffer, for all model-to- + main & main-to-model operations. + + This method creates three ranges: + - The param's range within the entire grad buffer (i.e., world index). + - The param's range within the DP rank's local view of the grad buffer. + - The param's range within itself (i.e., its shard). + """ + + # Param range map. + param_world_index_map = model._grad_buffer_param_index_map[dtype] + param_range_map = {} + for param, param_world_indexes in param_world_index_map.items(): + + # Param range. + param_world_start, param_world_end = param_world_indexes + param_local_start = max( + 0, + param_world_start - gbuf_world_range.start) + param_local_end = min( + gbuf_world_range.size, + param_world_end - gbuf_world_range.start) + + # Add param, if within local gbuf range. + if param_local_end > param_local_start: + param_local_range = Range(param_local_start, param_local_end) + param_world_range = param_local_range.normalize( + param_local_start + gbuf_world_range.start) + sub_param_start = max(0, gbuf_world_range.start-param_world_start) + sub_param_range = param_local_range.normalize(sub_param_start) + param_range_map[param] = { + "gbuf_world" : param_world_range, + "gbuf_local" : param_local_range, + "param" : sub_param_range, + } + + return param_range_map + + + @classmethod + def build_model_gbuf_range(cls, model, dtype): + """ + Build mapping between params and their grad buffers. + + This method does the initial setup for the method above. This setup + includes determining the shard ranges into the DDP's grad buffer for + each data-parallel (DP) rank. Each DP rank keeps range info for + all other DP ranks, for the purpose of creating args for + reduce-scatter and all-gather. + """ + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Grad buffer range. + grad_buffer = model._grad_buffers[dtype] + gbuf_size = grad_buffer.numel + max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size)) + + # All world ranges. (i.e., across all data parallel ranks) + gbuf_world_all_ranges = [] + for r in range(data_parallel_world_size): + gbuf_world_start = r * max_gbuf_range_size + gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size) + gbuf_world_range = Range(gbuf_world_start, gbuf_world_end) + gbuf_world_all_ranges.append(gbuf_world_range) + + # Local DP's ranges. + gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank] + gbuf_local_range = gbuf_world_range.normalize() + + # Get each param's ranges. + param_range_map = cls.build_model_gbuf_param_range_map(model, + dtype, + gbuf_world_range) + + # Group into dict. + data = { + "local" : gbuf_local_range, + "world" : gbuf_world_range, + "world_all" : gbuf_world_all_ranges, + "param_map" : param_range_map, + "max_range_size" : max_gbuf_range_size, + } + + return data + + + @classmethod + def build_model_gbuf_range_map(cls, model): + """ + Create param-to-grad-buffer mappings, for grad buffer data types + within a specific virtual model. + """ + return { + dtype : cls.build_model_gbuf_range(model, dtype) + for dtype in model._grad_buffers + } + + + @classmethod + def build_model_param_gbuf_map(cls, model_gbuf_ranges): + """ + Create a reverse of the model_gbuf_ranges, for referencing in + opposite direction. + """ + param_gbuf_map = {} + for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges): + for dtype, gbuf_range_map in model_gbuf_range_map.items(): + for param, param_range_map in gbuf_range_map["param_map"].items(): + param_gbuf_map[param] = (model_index, dtype) + return param_gbuf_map + + + @classmethod + def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges): + """ + Create optimizer groups. + + Given the set of parameter shard ranges that are owned by the current + data-parallel (DP) rank, gather the set of parameters that will be + used (in the method below) to create the current DP's optimizer + groups. + """ + + num_groups = len(param_groups) + + # Param group map. + # World param group map. + # - Store a mapping of for all parameters + # across all DP ranks. This is necessary because it is our first + # cross reference between the DDP mappings and the optimizer group + # parameters. This mapping only for use in the next step of building + # the local mapping over this DP rank's parameters. + world_param_group_map = {} + for group_index, group in enumerate(param_groups): + for param in group["params"]: + assert param.requires_grad + world_param_group_map[param] = group_index + + # Optimizer group ranges & param-group mapping. + # - Build a mapping from groups to their contained parameters, and also + # from parameters to their containing group index and order within + # the group. The group index and order are particularly important for + # saving and loading checkpoints. + local_param_group_map = {} + group_ranges = [ {"params": []} for _ in param_groups ] + for model_gbuf_range_map in model_gbuf_ranges: + for dtype, gbuf_range_map in model_gbuf_range_map.items(): + for param in gbuf_range_map["param_map"]: + group_index = world_param_group_map[param] + group_range = group_ranges[group_index] + group_range["params"].append(param) + local_param_group_map[param] = \ + (group_index, len(group_range["params"]) - 1) + + # Squeeze zero-size group ranges. + for group_index, group_range in enumerate(group_ranges): + group_range["orig_group"] = param_groups[group_index] + group_range["orig_group_idx"] = param_groups[group_index] + + return local_param_group_map, group_ranges + + + @classmethod + def build_model_and_main_param_groups(cls, + model_gbuf_ranges, + param_gbuf_map, + opt_group_ranges): + """ + Create main parameter groups needed for the optimizer step. + + These groups encompass both: 1) groups used by this class, for + reducing/gather, and 2) groups used by the inner optimizer for the + parameter update. Given that the conceptual grad buffer partitioning + (created in earlier method) doesn't respect parameter boundaries, + the optimizer operates on shards of the model parameters, rather than + the full parameters. + """ + + # Parameter groups: + # model_float16_groups: original float16 parameters + # model_fp32_groups: original fp32 parameters + # shard_float16_groups: shards of original float16 parameters + # shard_fp32_groups: shards of original fp32 parameters + # shard_fp32_from_float16_groups: fp32 copy of float16 parameters + model_float16_groups = [] + model_fp32_groups = [] + shard_float16_groups = [] + shard_fp32_groups = [] + shard_fp32_from_float16_groups = [] + + # Allocate (or slice) each group's param shard. + for group_index, group_range in enumerate(opt_group_ranges): + + # Params of this group. + model_float16_params_this_group = [] + model_fp32_params_this_group = [] + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_fp32_from_float16_params_this_group = [] + model_float16_groups.append(model_float16_params_this_group) + model_fp32_groups.append(model_fp32_params_this_group) + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + shard_fp32_from_float16_groups.append( + shard_fp32_from_float16_params_this_group) + + for model_param in group_range["params"]: + + assert model_param.requires_grad + + model_index, dtype = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', + 'torch.cuda.BFloat16Tensor']: + + # Clone model -> main. + shard_model_param = model_param.detach().view(-1) \ + [param_range.start:param_range.end] + shard_main_param = shard_model_param.clone().float() + copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + copy_tensor_model_parallel_attributes( + shard_main_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + shard_main_param.shared = model_param.shared + + # Add to group. + model_float16_params_this_group.append(model_param) + shard_float16_params_this_group.append(shard_model_param) + shard_fp32_from_float16_params_this_group.append(shard_main_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1) \ + [param_range.start:param_range.end] + model_fp32_params_this_group.append(model_param) + shard_fp32_params_this_group.append(shard_model_param) + copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + else: + raise TypeError('Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type())) + + # Update optimizer's params. + group_range["orig_group"]["params"] = [ + *shard_fp32_params_this_group, + *shard_fp32_from_float16_params_this_group, + ] + + return ( + model_float16_groups, + model_fp32_groups, + shard_float16_groups, + shard_fp32_groups, + shard_fp32_from_float16_groups, + ) + + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models): + """ + See top of class definition for argument descriptions. + + The steps in this method create the core mapping between DDP grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard + indexes. This method also updates the optimizer parameter groups + with the newly created shards. + """ + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models) + + # Verify that contiguous buffers are being used. + # - Note: this should already be checked in arguments.py. + assert use_contiguous_buffers_in_local_ddp + assert isinstance(optimizer, Adam), \ + "Only Adam currently supported, due to checkpointing requirements." + + # Model grad buffer ranges. + self.model_gbuf_ranges = [] + for model_index, model in enumerate(self.models): + self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model)) + self.model_param_gbuf_map = \ + self.build_model_param_gbuf_map(self.model_gbuf_ranges) + + # Optimizer ranges. + self.model_param_group_index_map, self.opt_group_ranges = \ + self.build_optimizer_group_ranges(self.optimizer.param_groups, + self.model_gbuf_ranges) + + # Allocate main param shards. + ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, + self.shard_fp32_groups, + self.shard_fp32_from_float16_groups, + ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, + self.model_param_gbuf_map, + self.opt_group_ranges) + + # Initialize param buffers. + # - These are views on the DDP model's grad buffers, that share + # storage & have their own dtype. This is safe because the param + # dtype size is always <= grad dtype size. + self.param_buffers = [] + for model_index, model in enumerate(self.models): + current_param_buffers = {} + for dtype, grad_buffer in model._grad_buffers.items(): + + # Handle older/newer method for getting untyped storage. + try: + storage = grad_buffer.data.storage()._untyped() + except: + storage = grad_buffer.data.storage().untyped() + + # Typed param buffer. + param_buffer = torch.tensor( + storage, + dtype = params_dtype, + device = grad_buffer.data.device) + param_buffer = param_buffer[:grad_buffer.numel_padded] + current_param_buffers[dtype] = param_buffer + self.param_buffers.append(current_param_buffers) + + # Update optimizer groups. + # - Also, leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors. + self.optimizer.param_groups = \ + [ g["orig_group"] for g in self.opt_group_ranges ] + self.optimizer.load_state_dict(self.optimizer.state_dict()) + + + def get_model_param_range_map(self, param): + """ + Given a model param, get the index sub-range of the param that this + data-parallel rank owns. + """ + model_index, dtype = self.model_param_gbuf_map[param] + gbuf_range_map = self.model_gbuf_ranges[model_index][dtype] + param_range_map = gbuf_range_map["param_map"][param] + return param_range_map + + + def get_model_parallel_group(self): + """ + With the distributed optimizer, the model parallel group is the + entire world. + """ + return None + + + def state_dict(self): + """ + The state dict contains all non-DP-rank-dependent (i.e., non-parameter- + related) optimizer variables. The returned state dict can be stored in + the standard model/RNG checkpoint file. The parameter and dependent + optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate + checkpoint file by calling 'save_parameter_state()'. + """ + + state_dict = {} + + # Optimizer state (do not store parameter state here). + state_dict['optimizer'] = { + k : v + for k, v in self.optimizer.state_dict().items() + if k != "state" + } + for param_group in state_dict["optimizer"]["param_groups"]: + del param_group["params"] + + # Grad scaler state. + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + + return state_dict + + + def load_state_dict(self, state_dict): + """Load the state dict. + + As detailed in state_dict(), the state dict contains all non- + parameter-related variables. This method is notably longer than + state_dict(), because the Torch optimizers state has yet to be + allocated at this point, and so we must do a cross referencing between + the optimizers state (and the ordering it expects for parameter state) + and this DP rank's shards. The optimizer at this point does not contain + any tensor dimension information, so we must get these dimensions from + the DP shards mapped during DistributedOptimizer.__init__(). + + The tensor parameter state is loaded via load_parameter_state(), and + so this method also must populate the loaded state dict with dummy + tensor data (i.e., via torch.empty() below). This will be overwritten + during load_parameter_state(). + + ** Note: Torch optimizer's state structure. ** + The Torch optimizer stores its state in two levels. The top level is a + list of groups, where each group contains a list of integer indexes + (corresponding to parameters) that index into a master parameter list + that is shared by all groups. As such, three values are necessary for + maintaining this ordering: + + - group_index : The group to which a parameter belongs. + - group_order : The index of a parameter within its group. + - state_order : The index of a parameter within the shared parameter + list. + """ + + # Get the Torch optimizer's state dict. + # - This 'inner' optimizer at this point is unallocated, and only + # contains an integer odering of parameters within each group, and + # the ordering of parameters within its flattened parameter state + # list. + inner_state_dict = self.optimizer.state_dict() + state_dict_param_groups = [{ + **group, + "params" : list(inner_state_dict["param_groups"][idx]["params"]), + } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])] + + # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below) + # - Real data is overwritten during load_parameter_state(). + state_dict_state = [] + for gbuf_range_maps in self.model_gbuf_ranges: + for gbuf_range_map in gbuf_range_maps.values(): + for model_param, param_range_map in \ + gbuf_range_map["param_map"].items(): + + # Get parameter ordering information (see method docstring + # for details). + group_index, group_order = \ + self.model_param_group_index_map[model_param] + state_order = inner_state_dict["param_groups"] \ + [group_index]["params"][group_order] + + # Allocate dummy tensors. + numel = len(param_range_map["gbuf_world"]) + init_shard = lambda : torch.empty( + (numel,), + dtype=torch.float32, + device=torch.cuda.current_device()) + + state_dict_state.append((state_order, { + "exp_avg" : init_shard(), + "exp_avg_sq" : init_shard(), + })) + + # Sort by state order (see method docstring for details). + state_dict_state.sort(key = lambda s : s[0]) + state_dict_state = {s[0]:s[1] for s in state_dict_state} + + # Optimizer. + self.optimizer.load_state_dict({ + "state" : state_dict_state, + "param_groups" : state_dict_param_groups, + }) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + if self.fp16: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0('***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...') + + + def save_parameter_state(self, filename): + """Save parameter state (i.e., parameter & optimizer tensors). + + This method performs three steps: + - For each DP rank, copy param & optimizer shards to contiguous CPU + buffers. (e.g., one buffer each for main_param, exp_avg, and + exp_avg_sq). + - Gather contiguous buffers on DP rank 0 and concatenate to world + buffers. + - Save world buffers to disk (i.e., distrib_opt.pt). + """ + + # Data parallelism variables. + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group_gloo = mpu.get_data_parallel_group_gloo() + data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) + + # Collect param states. + state = {} + for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): + + # Iterate grad buffers (by data type). + dtype_state = {} + assert len(gbuf_range_maps) == 1, "single dtype supported, for now." + for dtype, gbuf_range_map in gbuf_range_maps.items(): + + # Compute local DP contiguous shard's size. + model = self.models[model_idx] + gbuf_world_numel = model._grad_buffers[dtype].numel_padded + gbuf_local_numel = int(gbuf_world_numel/data_parallel_world_size) + local_shards = {key:torch.empty((gbuf_local_numel,), + dtype=torch.float32, + device="cpu") + for key in ("param", "exp_avg", "exp_avg_sq")} + + # Build contiguous DP rank shards (for param + optim states). + for model_param, param_range_map in \ + gbuf_range_map["param_map"].items(): + + # Main param & optimizer states. + group_index, group_order = \ + self.model_param_group_index_map[model_param] + main_param = self.optimizer.param_groups \ + [group_index]["params"][group_order] + optim_state = self.optimizer.state[main_param] + + tensors = { + "param" : main_param, + **optim_state, + } + + # Copy states into contiguous shard. + gbuf_local_start = param_range_map["gbuf_local"].start + gbuf_local_end = param_range_map["gbuf_local"].end + for key in local_shards: + local_shards[key][gbuf_local_start:gbuf_local_end] \ + .data.copy_(tensors[key].detach().cpu()) + + # Gather contiguous shards on DP rank 0. + world_tensors = {} + for key, send_tensor in local_shards.items(): + + # Gather tensor list. + if data_parallel_rank == 0: + recv_tensors = [torch.empty((gbuf_local_numel,), + dtype=torch.float32, + device="cpu") + for _ in range(data_parallel_world_size)] + else: + recv_tensors = None + + # Gather. + torch.distributed.gather( + send_tensor, + recv_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, + ) + + # Concatenate. + if data_parallel_rank == 0: + world_tensors[key] = torch.cat(recv_tensors) + + # Collect world state. + dtype_state[dtype] = world_tensors + state[model_idx] = dtype_state + + # Save param state. + if data_parallel_rank == 0: + torch.save(state, filename) + + + def load_parameter_state(self, filename): + """Load parameter state (i.e., parameter & optimizer tensors). + + This method performs the reverse of save_parameter_state(): + - Load world buffers from disk (i.e., distrib_opt.pt). + - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP + rank receives its relevant subset of the world buffers). + - For each DP rank, copy param & optimizer shards from contiguous CPU + buffers. (e.g., one buffer each for main_param, exp_avg, and + exp_avg_sq). + """ + + # Data parallelism variables. + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group_gloo = mpu.get_data_parallel_group_gloo() + data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) + + # Load on DP rank 0. + if data_parallel_rank == 0: + loaded_state = torch.load(filename) + + # Scatter tensors to all DP ranks. + for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): + for dtype, gbuf_range_map in gbuf_range_maps.items(): + + # Compute local DP contiguous shard's size. + model = self.models[model_idx] + gbuf_world_numel = model._grad_buffers[dtype].numel_padded + gbuf_local_numel = int(gbuf_world_numel/data_parallel_world_size) + + # Contiguous local shards (received from DP rank 0). + local_shards = {key:torch.empty((gbuf_local_numel,), + dtype=torch.float32, + device="cpu") + for key in ("param", "exp_avg", "exp_avg_sq")} + + # Scatter local shards from DP rank 0. + for key, recv_tensor in local_shards.items(): + + # Scatter tensor list. + if data_parallel_rank == 0: + world_tensor = loaded_state[model_idx][dtype][key] + gbuf_start_idxs = \ + list(range(0, gbuf_world_numel, gbuf_local_numel)) + send_tensors = [world_tensor[i:(i+gbuf_local_numel)] + for i in gbuf_start_idxs] + else: + send_tensors = None + + # Scatter. + torch.distributed.scatter( + recv_tensor, + send_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, + ) + + # Copy local contiguous shards to param/optim shards. + for model_param, param_range_map in \ + gbuf_range_map["param_map"].items(): + + # Main param & optimizer states. + group_index, group_order = \ + self.model_param_group_index_map[model_param] + main_param = self.optimizer.param_groups \ + [group_index]["params"][group_order] + optim_state = self.optimizer.state[main_param] + + tensors = { + "param" : main_param, + **optim_state, + } + + # Copy states into contiguous shard. + gbuf_local_start = param_range_map["gbuf_local"].start + gbuf_local_end = param_range_map["gbuf_local"].end + for key in local_shards: + tensors[key].data.copy_( + local_shards[key][gbuf_local_start:gbuf_local_end]) + + + def zero_grad(self, set_to_none=True): + """ + Zero grads. + + We only need to zero the model related parameters, i.e., + model_float16_groups & model_fp32_groups. We additionally zero + the remaining groups as a memory optimization to reduce + fragmentation; in the case of set_to_none==True, the space + used by this field can be safely deallocated at this point. + """ + for groups in ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, # grad empty/unused here? + self.shard_fp32_groups, # throws grad-access warning + self.shard_fp32_from_float16_groups): + for group in groups: + _zero_grad_group_helper(group, set_to_none) + + + @staticmethod + def get_model_buffer_dp_views(model_buffers): + """ + Get shard views of each of the DDP's param/grad buffers. + + In this nested list, the top level is grouped by the virtual model + index and the buffer's data type. The sub-level is a list of + shards of that buffer, where each shard in the list represents + a contiguous view of the buffer, that is owned by a data-parallel + rank. The shard boundary does not respect parameter boundaries, and + so the elements of some parameters are split across data parallel + ranks. + + Additionally, return references to the entire buffers, for use + in _reduce_scatter_base and _all_gather_base. + """ + + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Buffer views. + view_items = [] + for model_index, buffers in enumerate(model_buffers): + for dtype, buf in buffers.items(): + + assert buf.numel() % data_parallel_world_size == 0 + shard_size = int(buf.numel() / data_parallel_world_size) + buf_views = [buf[(r*shard_size):((r+1)*shard_size)] + for r in range(data_parallel_world_size)] + view_items.append((model_index, dtype, buf, buf_views)) + + return view_items + + + def get_model_grad_buffer_dp_views(self): + return self.get_model_buffer_dp_views([ + {dtype : mem_buffer.data} + for model in self.models + for dtype, mem_buffer in model._grad_buffers.items()]) + + + def get_model_param_buffer_dp_views(self): + return self.get_model_buffer_dp_views(self.param_buffers) + + + def reduce_model_grads(self, args, timers): + """ + Reduce-scatter model grads. + + The DDP's grad buffer is used for the reduce-scatter, and thus no + tensors are dynamically allocated. + + Note: this is a different order of reduction, versus the non- + distributed optimizer, which reduces: 1) layernorm grads, 2) all + grads, 3) embedding grads. + """ + + # All-reduce layer-norm grads (for sequence parallelism). + timers('layernorm-grads-all-reduce', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.allreduce_layernorm_grads(args) + timers('layernorm-grads-all-reduce').stop() + + # All-reduce embedding grads. + timers('embedding-grads-all-reduce', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.allreduce_embedding_grads(args) + timers('embedding-grads-all-reduce').stop() + + # Reduce-scatter setup. + timers('grads-reduce-scatter', log_level=1).start( + barrier=args.barrier_with_L1_time) + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_group = mpu.get_data_parallel_group() + + # Scale grad buffers by '1 / data_parallel_world_size'. + for model in self.models: + for dtype, gbuf in model._grad_buffers.items(): + gbuf.data /= data_parallel_world_size + + # Reduce-scatter all grads. + gbuf_view_items = self.get_model_grad_buffer_dp_views() + for index, (model_index, dtype, gbuf, gbuf_views) \ + in enumerate(gbuf_view_items): + + torch.distributed._reduce_scatter_base( + gbuf_views[data_parallel_rank], + gbuf, + group = data_parallel_group, + ) + + timers('grads-reduce-scatter').stop() + + + + def gather_model_params(self, args, timers): + """ + All-gather updated model params. + + The DDP's param buffer is used for the all-gather, and thus no + tensors are dynamically allocated. After the all-gather, the params + can be copied from the param buffer to the param. + """ + + timers('params-all-gather', log_level=1).start( + barrier=args.barrier_with_L1_time) + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group = mpu.get_data_parallel_group() + + # All-gather updated main params. + # - All param buffer views are guaranteed to have the same num elements + # across all data parallel ranks, due to grad buffer padding that is + # done in distributed.py, and extended to the param buffers. Thus, + # all sub-views will have consistent start/end indexes across data + # parallel ranks. + pbuf_view_items = self.get_model_param_buffer_dp_views() + for index, (model_index, dtype, pbuf, pbuf_views) \ + in enumerate(pbuf_view_items): + + torch.distributed._all_gather_base( + pbuf, + pbuf_views[data_parallel_rank], + group = data_parallel_group, + ) + + # Copy from param buffer to each param. + for model_id, model in enumerate(self.models): + for dtype, param_map in model._grad_buffer_param_index_map.items(): + for param, (buf_start, buf_end) in param_map.items(): + param_buf = self.param_buffers[model_id][dtype] + param_buf_shard = param_buf[buf_start:buf_end] + param.view(-1).detach().copy_(param_buf_shard) + + timers('params-all-gather').stop() + + + def _collect_main_grad_data_for_unscaling(self): + """ + Note: this should be equivalent to the float-16 optimizer's method, + but writtent differently, so the two should be combined. + """ + return [ + param.grad.data + for group in self.optimizer.param_groups + for param in group["params"] + ] + + + def _get_model_and_main_params_data_float16(self): + """ + Get aligned list of model and main params. + """ + model_data = [] + main_data = [] + for model_group, main_group in zip(self.shard_float16_groups, + self.shard_fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data + + + def _copy_model_grads_to_main_grads(self): + """ + Copy model grads to main grads. + + Since this step follows a reduce-scatter through the DDP's grad + buffer, this method is responsible for copying the updated grads + from the grad buffer to the main shard's grad field. + """ + + # Utility method for copying group grads. + def copy_group_grads(model_groups, shard_main_groups): + for model_group, shard_main_group in zip(model_groups, + shard_main_groups): + for model_param, shard_main_param in zip(model_group, + shard_main_group): + + param_range_map = self.get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + + model_grad = model_param.main_grad + shard_model_grad = model_grad.view(-1) \ + [param_range.start:param_range.end] + shard_main_param.grad = shard_model_grad.float() + + # Copy model groups to shard groups. + copy_group_grads(self.model_float16_groups, + self.shard_fp32_from_float16_groups) + copy_group_grads(self.model_fp32_groups, + self.shard_fp32_groups) + + + def _copy_main_params_to_model_params(self): + """ + Copy main params to model params. + + Since this step is followed by an all-gather through the DDP's grad + buffer, this method is responsible for copying the updated params + from the main shards into the correct position in the grad buffer. + """ + + # Utility method for copying group params. + def copy_group_params(shard_main_groups, model_groups): + for shard_main_group, model_group in zip(shard_main_groups, + model_groups): + for shard_main_param, model_param in zip(shard_main_group, + model_group): + + param_range_map = self.get_model_param_range_map(model_param) + world_range = param_range_map["gbuf_world"] + + assert world_range.size == shard_main_param.nelement() + + model_id, dtype = self.model_param_gbuf_map[model_param] + model_param_buffer = self.param_buffers[model_id][dtype] + + shard_model_param = model_param_buffer.view(-1) \ + [world_range.start:world_range.end] + + shard_model_param.data.copy_(shard_main_param) + + # Copy shard groups to model groups. + copy_group_params(self.shard_fp32_from_float16_groups, + self.model_float16_groups) + copy_group_params(self.shard_fp32_groups, + self.model_fp32_groups) + + + def _copy_model_params_to_main_params(self): + """ + Copy model params to main params. + + During finetuning, this method is used to reload the main params from + the model params. This copy does not make use of the grad buffer as + an intermediary. + """ + + # Utility method for copying group params. + def copy_group_params(model_groups, shard_main_groups): + for model_group, shard_main_group in zip(model_groups, + shard_main_groups): + for model_param, shard_main_param in zip(model_group, + shard_main_group): + + param_range_map = self.get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + + shard_model_param = model_param.view(-1) \ + [param_range.start:param_range.end] + shard_main_param.data.copy_(shard_model_param) + + # Copy model groups to shard groups. + copy_group_params(self.model_float16_groups, + self.shard_fp32_from_float16_groups) + copy_group_params(self.model_fp32_groups, + self.shard_fp32_groups) diff --git a/training/DeepSpeed-Domino/domino/optimizer/grad_scaler.py b/training/DeepSpeed-Domino/domino/optimizer/grad_scaler.py new file mode 100644 index 000000000..66f7c907a --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer/grad_scaler.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron grad scaler.""" + +from abc import ABC +from abc import abstractmethod + +import torch + + +class MegatronGradScaler(ABC): + + def __init__(self, initial_scale): + """Initialize scale value with the input initial scale.""" + assert initial_scale > 0.0 + self._scale = torch.cuda.FloatTensor([initial_scale]) + + @property + def scale(self): + return self._scale + + @property + def inv_scale(self): + return self._scale.double().reciprocal().float() + + @abstractmethod + def update(self, found_inf): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state_dict): + pass + + + +class ConstantGradScaler(MegatronGradScaler): + + def update(self, found_inf): + pass + + def state_dict(self): + return dict() + + def load_state_dict(self, state_dict): + pass + + + +class DynamicGradScaler(MegatronGradScaler): + + def __init__(self, initial_scale, min_scale, + growth_factor, backoff_factor, + growth_interval, hysteresis): + """"Grad scaler with dynamic scale that gets adjusted + during training.""" + super(DynamicGradScaler, self).__init__(initial_scale) + + # Lower bound on the scale. + assert min_scale > 0.0 + assert min_scale <= initial_scale + self.min_scale = torch.cuda.FloatTensor([min_scale]) + # Growth and backoff factors for the scale. + assert growth_factor > 1.0 + self.growth_factor = torch.cuda.FloatTensor([growth_factor]) + assert backoff_factor < 1.0 + assert backoff_factor > 0.0 + self.backoff_factor = torch.cuda.FloatTensor([backoff_factor]) + # Interval over which if we don't see any inf/nan, + # we will scale the grad scale by the growth factor. + assert growth_interval > 0 + self.growth_interval = growth_interval + # Number of inf/nans we should see before scaling down + # the grad scale by the backoff factor. + assert hysteresis > 0 + self.hysteresis = hysteresis + + # Trackers. + self._growth_tracker = 0 + self._hysteresis_tracker = self.hysteresis + + + def update(self, found_inf): + + # If we have an inf/nan, growth tracker is set to 0 + # and hysterisis tracker is reduced by 1. + if found_inf: + self._growth_tracker = 0 + self._hysteresis_tracker -= 1 + # Now if we are out of hysteresis count, scale down the loss. + if self._hysteresis_tracker <= 0: + self._scale = torch.max(self._scale * self.backoff_factor, + self.min_scale) + else: + # If there is no nan/inf, increment the growth tracker. + self._growth_tracker += 1 + # If we have had enough consequitive intervals with no nan/inf: + if self._growth_tracker == self.growth_interval: + # Reset the tracker and hysteresis trackers, + self._growth_tracker = 0 + self._hysteresis_tracker = self.hysteresis + # and scale up the loss scale. + self._scale = self._scale * self.growth_factor + + + def state_dict(self): + state_dict = {} + state_dict['scale'] = self._scale + state_dict['growth_tracker'] = self._growth_tracker + state_dict['hysteresis_tracker'] = self._hysteresis_tracker + return state_dict + + + def load_state_dict(self, state_dict): + self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) + self._growth_tracker = state_dict['growth_tracker'] + self._hysteresis_tracker = state_dict['hysteresis_tracker'] diff --git a/training/DeepSpeed-Domino/domino/optimizer/optimizer.py b/training/DeepSpeed-Domino/domino/optimizer/optimizer.py new file mode 100644 index 000000000..1578dcb8c --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer/optimizer.py @@ -0,0 +1,774 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Megatron optimizer.""" + +from abc import ABC +from abc import abstractmethod +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C +import torch +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from domino.utils import print_rank_0 +import domino.parallel_state as mpu +from domino.tensor_parallel.partition import copy_tensor_model_parallel_attributes, param_is_not_tensor_parallel_duplicate +from domino.modules.distributed import DistributedDataParallel as LocalDDP +from domino.modules.module import Float16Module, param_is_not_shared +from domino.utils import unwrap_model + +from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 + + + + +def _zero_grad_group_helper(group, set_to_none): + """Zero out the gradient for a group of parameters. + Note: copied from torch.optim.optimizer.""" + for param in group: + if param.grad is not None: + if set_to_none: + param.grad = None + else: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + +def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): + """Use multi-tensor-applier to copy values from one list to another. + We don't have a blfoat16 implementation so for now if the overflow_buf + is not provided, we default back to simple loop copy to be compatible + with bfloat16.""" + if overflow_buf: + overflow_buf.fill_(0) + # Scaling with factor `1.0` is equivalent to copy. + multi_tensor_applier(amp_C.multi_tensor_scale, + overflow_buf, + [this, that], + 1.0) + else: + for this_, that_ in zip(this, that): + that_.copy_(this_) + + + +class MegatronOptimizer(ABC): + + + def __init__(self, optimizer, clip_grad, + log_num_zeros_in_grad, + params_have_main_grad, + use_contiguous_buffers_in_local_ddp, + models): + + """Input optimizer is the base optimizer for example Adam.""" + self.optimizer = optimizer + assert self.optimizer, 'no optimizer is provided.' + # Set gradient clipping and logging params. + self.clip_grad = clip_grad + self.log_num_zeros_in_grad = log_num_zeros_in_grad + self.params_have_main_grad = params_have_main_grad + self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp + + # 'models' are retained for access to the contiguous grad buffers. + # (see distributed optimizer) + self.models = models + + if self.use_contiguous_buffers_in_local_ddp: + assert self.params_have_main_grad, \ + "use of contiguous buffer requires that params have main grad" + + + def get_parameters(self): + params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + params.append(param) + return params + + + def get_main_grads_for_grad_norm(self): + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params = self.get_parameters() + grads_for_norm = [] + for param in params: + grad = param.grad + grad_not_none = grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grads_for_norm.append(grad) + + return grads_for_norm + + + def get_model_parallel_group(self): + """Default returned here, but the distributed optimizer overrides this.""" + return mpu.get_model_parallel_group() + + + def clip_grad_norm(self, clip_grad): + params = self.get_parameters() + grads_for_norm = self.get_main_grads_for_grad_norm() + return clip_grad_norm_fp32( + params, grads_for_norm, clip_grad, + model_parallel_group=self.get_model_parallel_group()) + + + def count_zeros(self): + params = self.get_parameters() + return count_zeros_fp32(params, + model_parallel_group=self.get_model_parallel_group()) + + + @abstractmethod + def zero_grad(self, set_to_none=True): + pass + + + @abstractmethod + def get_loss_scale(self): + """The output should be a cuda tensor of size 1.""" + pass + + + def scale_loss(self, loss): + """Simple scaling.""" + return self.get_loss_scale() * loss + + + @abstractmethod + def reload_model_params(self): + """Refreshes any internal state from the current model parameters. + Call whenever the parameters are changed outside of the optimizer. + For example, when we load a model from a checkpoint without loading + the optimizer, the model parameters are updated but for fp16 optimizer + with main parameters, the main parameters need to also be updated.""" + pass + + + @abstractmethod + def state_dict(self): + pass + + + @abstractmethod + def load_state_dict(self, state_dict): + pass + + + # Promote state so it can be retrieved or set via + # "optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + + # Promote param_groups so it can be retrieved or set via + # "optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + + @abstractmethod + def step(self, args, timers): + pass + + + def gather_model_params(self, args, timers): + """ + For the case of a non-distributed-optimizer, there is nothing to + do here. + """ + pass + + + def allreduce_word_embedding_grads(self, args): + """ + All-reduce word embedding grads. + + Reduce grads across first and last stages to ensure that word_embeddings + parameters stay in sync. This should only run for models that support + pipelined model parallelism (BERT and GPT-2). + """ + + if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + if mpu.is_pipeline_first_stage(ignore_virtual=True): + unwrapped_model = self.models[0] + elif mpu.is_pipeline_last_stage(ignore_virtual=True): + unwrapped_model = self.models[-1] + else: # We do not support the interleaved schedule for T5 yet. + unwrapped_model = self.models[0] + unwrapped_model = unwrap_model( + unwrapped_model, (torchDDP, LocalDDP, Float16Module)) + + if unwrapped_model.share_embeddings_and_output_weights: + weight = unwrapped_model.shared_embedding_or_output_weight() + if args.DDP_impl == 'local': + grad = weight.main_grad + else: + grad = weight.grad + torch.distributed.all_reduce(grad, group=mpu.get_embedding_group()) + + + def allreduce_position_embedding_grads(self, args): + """ + All-reduce position_embeddings grad across first (encoder) and + split (decoder) stages to ensure that position embeddings parameters + stay in sync. This should only run for T5 models with pipeline + parallelism. + """ + if mpu.is_rank_in_position_embedding_group() and \ + mpu.get_pipeline_model_parallel_world_size() > 1 and \ + args.pipeline_model_parallel_split_rank is not None: + unwrapped_model = self.models[0] + unwrapped_model = unwrap_model( + unwrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert args.DDP_impl == 'local', \ + 'T5 model is only supported with local DDP mode' + grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad + torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group()) + + + def allreduce_embedding_grads(self, args): + """All-reduce both word and position embeddings.""" + self.allreduce_word_embedding_grads(args) + self.allreduce_position_embedding_grads(args) + + + def allreduce_layernorm_grads(self, args): + """All-reduce layernorm grads (for sequence parallelism).""" + + # All-reduce layernorm parameters across model parallel nodes + # when sequence parallelism is used + if mpu.get_tensor_model_parallel_world_size() > 1 and \ + args.sequence_parallel: + grads = [] + for model_module in self.models: + unwrapped_model = unwrap_model( + model_module, (torchDDP, LocalDDP, Float16Module)) + for param in unwrapped_model.parameters(): + if getattr(param, 'sequence_parallel', False): + grad = param.main_grad if args.DDP_impl == 'local' else param.grad + grads.append(grad.data) + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=mpu.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors( + coalesced, grads)): + buf.copy_(synced) + + def reduce_model_grads(self, args, timers): + """All-reduce all grads, and all-reduce embeddings.""" + + # All-reduce layer-norm grads (for sequence parallelism). + timers('layernorm-grads-all-reduce', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.allreduce_layernorm_grads(args) + timers('layernorm-grads-all-reduce').stop() + + # All-reduce if needed. + if args.DDP_impl == 'local': + timers('grads-all-reduce', log_level=1).start( + barrier=args.barrier_with_L1_time) + for model in self.models: + model.allreduce_gradients() + timers('grads-all-reduce').stop() + + # All-reduce embedding grads. + timers('embedding-grads-all-reduce', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.allreduce_embedding_grads(args) + timers('embedding-grads-all-reduce').stop() + + +class MixedPrecisionOptimizer(MegatronOptimizer): + """Base class for both the float-16 and the distributed optimizer. + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + params_dtype: used by distributed optimizer. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, + models): + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + models) + + self.fp16 = fp16 + self.bf16 = bf16 + self.params_dtype = params_dtype + self.grad_scaler = grad_scaler + + # None grad scaler is only supported for bf16. + if self.grad_scaler is None: + assert not self.fp16, 'fp16 expects a grad scaler.' + + # Tensor used to determine if a nan/if has happend. + # Any non-zero value indicates inf/nan. + # Note that we keep this for the cases that grad scaler is none. + # We still record nan/inf if we have a bfloat16 with a grad scaler. + if self.grad_scaler: + self.found_inf = torch.cuda.FloatTensor([0.0]) + + # Dummy tensor needed for apex multi-apply tensor. + # For bfloat, we don't have multi-tensor apply and for now + # we set it to none so the multi-tensor apply gets ignored. + if bf16: + self._dummy_overflow_buf = None + else: + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + # In case grad scaler is not passed, define the unity scale. + if self.grad_scaler is None: + self._scale_one = torch.cuda.FloatTensor([1.0]) + + + def get_loss_scale(self): + if self.grad_scaler is None: + return self._scale_one + return self.grad_scaler.scale + + + def reload_model_params(self): + self._copy_model_params_to_main_params() + + + def _unscale_main_grads_and_check_for_nan(self): + + # Collect main grads. + main_grads = self._collect_main_grad_data_for_unscaling() + + # Reset found inf. + self.found_inf.fill_(0.0) + + # Unscale and set found inf/nan + torch._amp_foreach_non_finite_check_and_unscale_( + main_grads, self.found_inf, self.grad_scaler.inv_scale) + + # Update across all model parallel instances. + torch.distributed.all_reduce(self.found_inf, + op=torch.distributed.ReduceOp.MAX, + group=self.get_model_parallel_group()) + + # Check for nan. + found_inf_flag = (self.found_inf.item() > 0) + + return found_inf_flag + + + @torch.no_grad() + def step(self, args, timers): + + # Copy gradients from model params to main params. + # timers('optimizer-copy-to-main-grad', log_level=1).start( + # barrier=args.barrier_with_L1_time) + self._copy_model_grads_to_main_grads() + # timers('optimizer-copy-to-main-grad').stop() + + # Do unscale, check for inf, and update grad scaler only for + # the case that grad scaler is provided. + if self.grad_scaler: + + # Unscale and check for inf/nan. + # timers('optimizer-unscale-and-check-inf', log_level=1).start( + # barrier=args.barrier_with_L1_time) + found_inf_flag = self._unscale_main_grads_and_check_for_nan() + # timers('optimizer-unscale-and-check-inf').stop() + + # We are done with scaling gradients + # so we can update the loss scale. + self.grad_scaler.update(found_inf_flag) + + # If we found inf/nan, skip the update. + if found_inf_flag: + return False, None, None + + # Clip the main gradients. + # timers('optimizer-clip-main-grad', log_level=1).start( + # barrier=args.barrier_with_L1_time) + grad_norm = None + if self.clip_grad > 0.0: + grad_norm = self.clip_grad_norm(self.clip_grad) + # timers('optimizer-clip-main-grad').stop() + + # Count the zeros in the grads. + # timers('optimizer-count-zeros', log_level=1).start( + # barrier=args.barrier_with_L1_time) + num_zeros_in_grad = self.count_zeros() if \ + self.log_num_zeros_in_grad else None + # timers('optimizer-count-zeros').stop() + + # Step the optimizer. + # timers('optimizer-inner-step', log_level=1).start( + # barrier=args.barrier_with_L1_time) + self.optimizer.step() + # timers('optimizer-inner-step').stop() + + # Update params from main params. + # timers('optimizer-copy-main-to-model-params', log_level=1).start( + # barrier=args.barrier_with_L1_time) + self._copy_main_params_to_model_params() + # timers('optimizer-copy-main-to-model-params').stop() + + # Successful update. + return True, grad_norm, num_zeros_in_grad + + +class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): + """Float16 optimizer for fp16 and bf16 data types. + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models): + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models) + + # ====================== + # main parameter stuff + # ====================== + + # Three groups of parameters: + # float16_groups: original float16 parameters + # fp32_from_float16_groups: fp32 copy of float16 parameters + # fp32_from_fp32_groups: original fp32 parameters + self.float16_groups = [] + self.fp32_from_float16_groups = [] + self.fp32_from_fp32_groups = [] + + # For all the groups in the original optimizer: + for param_group in self.optimizer.param_groups: + float16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_float16_params_this_group = [] + # For all the parameters in this group: + for i, param in enumerate(param_group['params']): + if param.requires_grad: + + # float16 params: + if param.type() in ['torch.cuda.HalfTensor', + 'torch.cuda.BFloat16Tensor']: + float16_params_this_group.append(param) + # Create a copy + main_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + copy_tensor_model_parallel_attributes(main_param, + param) + if hasattr(param, 'shared'): + main_param.shared = param.shared + # Replace the optimizer params with the new fp32 copy. + param_group['params'][i] = main_param + + fp32_from_float16_params_this_group.append(main_param) + # Reset existing state dict key to the new main param. + if param in self.optimizer.state: + self.optimizer.state[main_param] \ + = self.optimizer.state.pop(param) + # fp32 params. + elif param.type() == 'torch.cuda.FloatTensor': + fp32_params_this_group.append(param) + param_group['params'][i] = param + + else: + raise TypeError('Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(param.type())) + + self.float16_groups.append(float16_params_this_group) + self.fp32_from_float16_groups.append( + fp32_from_float16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + + def zero_grad(self, set_to_none=True): + """We only need to zero the model related parameters, i.e., + float16_groups & fp32_from_fp32_groups. We additionally zero + fp32_from_float16_groups as a memory optimization to reduce + fragmentation; in the case of set_to_none==True, the space + used by this field can be safely deallocated at this point.""" + for group in self.float16_groups: + _zero_grad_group_helper(group, set_to_none) + for group in self.fp32_from_float16_groups: + _zero_grad_group_helper(group, set_to_none) + for group in self.fp32_from_fp32_groups: + _zero_grad_group_helper(group, set_to_none) + + + def _collect_main_grad_data_for_unscaling(self): + + main_grads = [] + + # fp32 params from float16 ones. + for main_group in self.fp32_from_float16_groups: + for main_param in main_group: + if main_param.grad is not None: + main_grads.append(main_param.grad.data) + + # Append fp32 parameters. + for main_group in self.fp32_from_fp32_groups: + for main_param in main_group: + if main_param.grad is not None: + main_grads.append(main_param.grad.data) + + return main_grads + + + def _get_model_and_main_params_data_float16(self): + model_data = [] + main_data = [] + for model_group, main_group in zip(self.float16_groups, + self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data + + + def _copy_model_grads_to_main_grads(self): + # This only needs to be done for the float16 group. + for model_group, main_group in zip(self.float16_groups, + self.fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + if self.params_have_main_grad and hasattr(model_param, 'main_grad'): + main_param.grad = model_param.main_grad.float() + else: + if model_param.grad is not None: + main_param.grad = model_param.grad.float() + + # Safe to deallocate model's grad/main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + model_param.grad = None + if self.params_have_main_grad and \ + not self.use_contiguous_buffers_in_local_ddp: + model_param.main_grad = None + + # For fp32 grads, we need to reset the grads to main grad. + if self.params_have_main_grad: + for model_group in self.fp32_from_fp32_groups: + for model_param in model_group: + model_param.grad = model_param.main_grad + + # Safe to de-reference model's main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + if not self.use_contiguous_buffers_in_local_ddp: + model_param.main_grad = None + + + def _copy_main_params_to_model_params(self): + # Only needed for the float16 params. + model_data, main_data = self._get_model_and_main_params_data_float16() + _multi_tensor_copy_this_to_that(this=main_data, that=model_data, + overflow_buf=self._dummy_overflow_buf) + + + def _copy_model_params_to_main_params(self): + # Only needed for the float16 params. + model_data, main_data = self._get_model_and_main_params_data_float16() + _multi_tensor_copy_this_to_that(this=model_data, that=main_data, + overflow_buf=self._dummy_overflow_buf) + + + def state_dict(self): + state_dict = {} + state_dict['optimizer'] = self.optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups + return state_dict + + + def load_state_dict(self, state_dict): + # Optimizer. + optimizer_key = 'optimizer' + if optimizer_key not in state_dict: + optimizer_key = 'optimizer_state_dict' + print_rank_0('***WARNING*** loading optimizer from ' + 'an old checkpoint ...') + self.optimizer.load_state_dict(state_dict[optimizer_key]) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + if self.fp16: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0('***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...') + + # Copy data for the main params. + fp32_from_float16_params_key = 'fp32_from_fp16_params' + if fp32_from_float16_params_key not in state_dict: + fp32_from_float16_params_key = 'fp32_from_fp16' + for current_group, saved_group in zip( + self.fp32_from_float16_groups, + state_dict[fp32_from_float16_params_key]): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data) + + +class FP32Optimizer(MegatronOptimizer): + + def __init__(self, optimizer, clip_grad, + log_num_zeros_in_grad, + params_have_main_grad, + use_contiguous_buffers_in_local_ddp, + models): + + super(FP32Optimizer, self).__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + models) + + self._scale = torch.cuda.FloatTensor([1.0]) + + + def zero_grad(self, set_to_none=True): + """Copied from torch.optim.optimizer""" + for group in self.optimizer.param_groups: + _zero_grad_group_helper(group['params'], set_to_none) + + + def get_loss_scale(self): + """FP32 optimizer does not do any scaling.""" + return self._scale + + + @torch.no_grad() + def step(self, args, timers): + """Clip gradients (if needed) and step the base optimizer. + Always return successful since there is no overflow.""" + + # Copy main_grads to grads. + timers('optimizer-copy-to-main-grad', log_level=1).start( + barrier=args.barrier_with_L1_time) + if self.params_have_main_grad: + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + param.grad = param.main_grad + + # Safe to de-reference model's main_grad after copying. + # (If using contiguous buffers, main_grad's memory should + # persist and therefore should not be deallocated.) + if not self.use_contiguous_buffers_in_local_ddp: + param.main_grad = None + timers('optimizer-copy-to-main-grad').stop() + + # Clip gradients. + timers('optimizer-clip-main-grad', log_level=1).start( + barrier=args.barrier_with_L1_time) + grad_norm = None + if self.clip_grad > 0.0: + grad_norm = self.clip_grad_norm(self.clip_grad) + timers('optimizer-clip-main-grad').stop() + + # count the zeros in the grads + timers('optimizer-count-zeros', log_level=1).start( + barrier=args.barrier_with_L1_time) + num_zeros_in_grad = self.count_zeros() if \ + self.log_num_zeros_in_grad else None + timers('optimizer-count-zeros').stop() + + # Update parameters. + timers('optimizer-inner-step', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.optimizer.step() + timers('optimizer-inner-step').stop() + + # No overflow for FP32 optimizer. + return True, grad_norm, num_zeros_in_grad + + + def reload_model_params(self): + pass + + + def state_dict(self): + return self.optimizer.state_dict() + + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) diff --git a/training/DeepSpeed-Domino/domino/optimizer_param_scheduler.py b/training/DeepSpeed-Domino/domino/optimizer_param_scheduler.py new file mode 100644 index 000000000..1c9aab34c --- /dev/null +++ b/training/DeepSpeed-Domino/domino/optimizer_param_scheduler.py @@ -0,0 +1,234 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from optimizer_param_scheduler.py in Megatron-LM + + +import math +from domino.utils import print_rank_0 + +class OptimizerParamScheduler(object): + """Anneals learning rate and weight decay""" + + def __init__(self, optimizer, init_lr, max_lr, min_lr, + lr_warmup_steps, lr_decay_steps, lr_decay_style, + start_wd, end_wd, wd_incr_steps, wd_incr_style, + use_checkpoint_opt_param_scheduler=True, + override_opt_param_scheduler=False): + + # Class values. + self.optimizer = optimizer + + self.init_lr = init_lr + self.max_lr = float(max_lr) + self.min_lr = min_lr + assert self.min_lr >= 0.0 + assert self.max_lr >= self.min_lr + assert self.init_lr <= self.max_lr + + self.lr_warmup_steps = lr_warmup_steps + self.num_steps = 0 + self.lr_decay_steps = lr_decay_steps + assert self.lr_decay_steps > 0 + assert self.lr_warmup_steps < self.lr_decay_steps + + self.lr_decay_style = lr_decay_style + + self.start_wd = start_wd + self.end_wd = end_wd + assert self.start_wd >= 0.0 + assert self.end_wd >= self.start_wd + self.wd_incr_steps = wd_incr_steps + self.wd_incr_style = wd_incr_style + + self.override_opt_param_scheduler = override_opt_param_scheduler + self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler + if self.override_opt_param_scheduler: + assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\ + 'use-checkpoint are set.' + + # Set the learning rate + self.step(0) + print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style)) + + + def get_wd(self): + """ Weight decay incr functions""" + if self.num_steps > self.wd_incr_steps: + return self.end_wd + + if self.wd_incr_style == 'constant': + assert self.start_wd == self.end_wd + return self.end_wd + + incr_ratio = float(self.num_steps) / float(self.wd_incr_steps) + assert incr_ratio >= 0.0 + assert incr_ratio <= 1.0 + delta_wd = self.end_wd - self.start_wd + + if self.wd_incr_style == 'linear': + coeff = incr_ratio + elif self.wd_incr_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0) + else: + raise Exception('{} weight decay increment style is not supported.'.format( + self.wd_incr_style)) + + return self.start_wd + coeff * delta_wd + + + def get_lr(self): + """Learning rate decay functions from: + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + + # Use linear warmup for the initial part. + if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps: + return ( + self.init_lr + + ( + (self.max_lr - self.init_lr) + * float(self.num_steps) + / float(self.lr_warmup_steps) + ) + ) + + # If the learning rate is constant, just return the initial value. + if self.lr_decay_style == 'constant': + return self.max_lr + + # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`. + if self.num_steps > self.lr_decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + if self.lr_decay_style == 'inverse-square-root': + warmup_steps = max(self.lr_warmup_steps, 1) + num_steps = max(self.num_steps, 1) + lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5) + return max(self.min_lr, lr) + + num_steps_ = self.num_steps - self.lr_warmup_steps + decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = self.max_lr - self.min_lr + + if self.lr_decay_style == 'linear': + coeff = (1.0 - decay_ratio) + elif self.lr_decay_style == 'cosine': + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + else: + raise Exception('{} decay style is not supported.'.format( + self.lr_decay_style)) + + return self.min_lr + coeff * delta_lr + + + def step(self, increment): + """Set lr for all parameters groups.""" + self.num_steps += increment + new_lr = self.get_lr() + new_wd = self.get_wd() + for group in self.optimizer.param_groups: + group['lr'] = new_lr * group.get('lr_mult', 1.0) + group['weight_decay'] = new_wd * group.get('wd_mult', 1.0) + + + def state_dict(self): + state_dict = { + 'max_lr': self.max_lr, + 'lr_warmup_steps': self.lr_warmup_steps, + 'num_steps': self.num_steps, + 'lr_decay_style': self.lr_decay_style, + 'lr_decay_steps': self.lr_decay_steps, + 'min_lr': self.min_lr, + 'start_wd': self.start_wd, + 'end_wd': self.end_wd, + 'wd_incr_style': self.wd_incr_style, + 'wd_incr_steps': self.wd_incr_steps + } + return state_dict + + + def _check_and_set(self, cls_value, sd_value, name): + """Auxiliary function for checking the values in the checkpoint and + setting them.""" + if self.override_opt_param_scheduler: + print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) + return cls_value + + if not self.use_checkpoint_opt_param_scheduler: + assert cls_value == sd_value, \ + f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \ + f'value {sd_value} for {name} do not match' + print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, + name)) + return sd_value + + + def load_state_dict(self, sd): + + if 'start_lr' in sd: + max_lr_ = sd['start_lr'] + else: + max_lr_ = sd['max_lr'] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, + 'learning rate') + + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], + 'minimum learning rate') + + if 'warmup_iter' in sd: + lr_warmup_steps_ = sd['warmup_iter'] + elif 'warmup_steps' in sd: + lr_warmup_steps_ = sd['warmup_steps'] + else: + lr_warmup_steps_ = sd['lr_warmup_steps'] + self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, + lr_warmup_steps_, + 'warmup iterations') + + if 'end_iter' in sd: + lr_decay_steps_ = sd['end_iter'] + elif 'decay_steps' in sd: + lr_decay_steps_ = sd['decay_steps'] + else: + lr_decay_steps_ = sd['lr_decay_steps'] + self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, + 'total number of iterations') + + if 'decay_style' in sd: + lr_decay_style_ = sd['decay_style'] + else: + lr_decay_style_ = sd['lr_decay_style'] + self.lr_decay_style = self._check_and_set(self.lr_decay_style, + lr_decay_style_, + 'learning rate decay style') + + if 'num_iters' in sd: + num_steps = sd['num_iters'] + else: + num_steps = sd['num_steps'] + self.step(increment=num_steps) + + + if 'start_wd' in sd: + self.start_wd = self._check_and_set(self.start_wd, + sd['start_wd'], + "start weight decay") + self.end_wd = self._check_and_set(self.end_wd, + sd['end_wd'], + "end weight decay") + self.wd_incr_steps = self._check_and_set(self.wd_incr_steps, + sd['wd_incr_steps'], + "total number of weight decay iterations") + self.wd_incr_style = self._check_and_set(self.wd_incr_style, + sd['wd_incr_style'], + "weight decay incr style") + + + + + + + + diff --git a/training/DeepSpeed-Domino/domino/parallel_state.py b/training/DeepSpeed-Domino/domino/parallel_state.py new file mode 100644 index 000000000..aceba20d0 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/parallel_state.py @@ -0,0 +1,260 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from parallel_state.py in Megatron-LM + +import operator +from functools import reduce +import torch + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +_TENSOR_MODEL_PARALLEL_GROUP = None +_PIPELINE_MODEL_PARALLEL_GROUP = None +_MODEL_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP = None + +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +_PIPELINE_GLOBAL_RANKS = None +_DATA_PARALLEL_GLOBAL_RANKS = None + +_GLOBAL_MEMORY_BUFFER = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1 +) -> None: + """Initialize model data parallel groups. + """ + + pipeline_model_parallel_size = 1 + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size + ) + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + num_data_parallel_groups: int = world_size // data_parallel_size + + rank = torch.distributed.get_rank() + + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GLOBAL_RANKS + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + for i in range(data_parallel_size): + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks + ] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + + _set_global_memory_buffer() + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(check_initialized=True): + """Get the tensor model parallel group the caller rank belongs to.""" + if check_initialized: + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is not None + ), 'tensor model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is not None + ), 'pipeline_model parallel group is not initialized' + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + return 0 + # global _MPU_PIPELINE_MODEL_PARALLEL_RANK + # if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + # return _MPU_PIPELINE_MODEL_PARALLEL_RANK + # return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + # if not ignore_virtual: + # virtual_pipeline_model_parallel_world_size = ( + # get_virtual_pipeline_model_parallel_world_size() + # ) + # if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( + # virtual_pipeline_model_parallel_world_size - 1 + # ): + # return False + # return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) + return True + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_world_size(group=get_data_parallel_group()) + else: + return 0 + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank(group=get_data_parallel_group()) + else: + return 0 + + +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def get_global_memory_buffer(): + """Return the global GlobalMemoryBuffer object""" + assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' + return _GLOBAL_MEMORY_BUFFER + + +def destroy_global_memory_buffer(): + """Sets the global memory buffer to None""" + global _GLOBAL_MEMORY_BUFFER + _GLOBAL_MEMORY_BUFFER = None + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = None + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = None + global _GLOBAL_MEMORY_BUFFER + _GLOBAL_MEMORY_BUFFER = None diff --git a/training/DeepSpeed-Domino/domino/schedules.py b/training/DeepSpeed-Domino/domino/schedules.py new file mode 100644 index 000000000..aa8dff81a --- /dev/null +++ b/training/DeepSpeed-Domino/domino/schedules.py @@ -0,0 +1,354 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from schedules.py in Megatron-LM + +import contextlib +from typing import Iterator, List, Union +import torch +from torch.autograd.variable import Variable +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +import domino.parallel_state as parallel_state +from domino.modules.enums import ModelType +from domino.utils import get_attr_wrapped_model, get_model_config, get_model_type +from domino.utils import average_losses_across_data_parallel_group + +from megatron.core.pipeline_parallel import p2p_communication + + +def get_forward_backward_func(): + forward_backward_func = forward_backward_no_pipelining + return forward_backward_func + + +def custom_backward(output, grad_output): + '''Directly call C++ autograd engine. + + To make the 'deallocate_output_tensor' (above) optimization work, the C++ + autograd engine must be called directly, bypassing Pytorch's + torch.autograd.backward. Pytorch's 'backward' checks that the output and + grad have the same shape, while C++'s 'backward' does not. + ''' + + assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory" + assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__ + assert isinstance(grad_output, (torch.Tensor, type(None))), ( + "grad_output == '%s'." % type(grad_output).__name__ + ) + + # Handle scalar output + if grad_output is None: + assert output.numel() == 1, "implicit grad requires scalar output." + grad_output = torch.ones_like(output, memory_format=torch.preserve_format,) + + # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] + Variable._execution_engine.run_backward( + tensors=(output,), + grad_tensors=(grad_output,), + keep_graph=False, + create_graph=False, + inputs=tuple(), + allow_unreachable=True, + accumulate_grad=True, + ) + + +def forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data=False, + checkpoint_activations_microbatch=None, +): + """Forward step for passed-in model. + + If first stage, input tensor is obtained from data_iterator, otherwise + passed-in input_tensor is used. + + Returns output tensor.""" + if config.timers is not None: + config.timers('forward-compute', log_level=2).start() + + unwrap_output_tensor = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_output_tensor = True + + set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") + set_input_tensor(input_tensor) + + if config.enable_autocast: + context_manager = torch.autocast("cuda", dtype=config.autocast_dtype) + else: + context_manager = contextlib.nullcontext() + with context_manager: + if checkpoint_activations_microbatch is None: + output_tensor, loss_func = forward_step_func(data_iterator, model) + else: + output_tensor, loss_func = forward_step_func( + data_iterator, model, checkpoint_activations_microbatch + ) + + if parallel_state.is_pipeline_last_stage(): + if not collect_non_loss_data: + output_tensor = loss_func(output_tensor) + loss = output_tensor + averaged_loss = average_losses_across_data_parallel_group([loss]) + loss_reduced = {"lm loss": averaged_loss[0]} + output_tensor = loss / num_microbatches + forward_data_store.append(loss_reduced) + else: + data = loss_func(output_tensor, non_loss_data=True) + forward_data_store.append(data) + + if config.timers is not None: + config.timers('forward-compute').stop() + + if unwrap_output_tensor: + return output_tensor + return [output_tensor] + + +def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + # NOTE: This code currently can handle at most one skip connection. It + # needs to be modified slightly to support arbitrary numbers of skip + # connections. + + if config.timers is not None: + config.timers('backward-compute', log_level=2).start() + + # Retain the grad on the input_tensor. + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. + if output_tensor_grad[0] is None and config.grad_scale_func is not None: + output_tensor[0] = config.grad_scale_func(output_tensor[0]) + + if config.deallocate_pipeline_outputs: + custom_backward(output_tensor[0], output_tensor_grad[0]) + else: + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + + # Collect the grad of the input_tensor. + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + # Handle single skip connection if it exists (encoder_hidden_state in + # model with encoder and decoder). + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.is_pipeline_stage_after_split() + and model_type == ModelType.encoder_and_decoder + ): + if output_tensor_grad[1] is not None: + input_tensor_grad[-1].add_(output_tensor_grad[1]) + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + if config.timers is not None: + config.timers('backward-compute').stop() + + return input_tensor_grad + + +def forward_backward_no_pipelining( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, # unused + micro_batch_size: int, # unused + decoder_seq_length: int = None, # unused + forward_only: bool = False, + collect_non_loss_data: bool = False, +): + """Run forward and backward passes with no pipeline parallelism + (no inter-stage communication). + + Returns dictionary with losses. + + + See get_forward_backward_func() for argument details + """ + + if isinstance(model, list): + assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + + no_sync_func = config.no_sync_func + if no_sync_func is None and isinstance(model, torchDDP): + no_sync_func = model.no_sync + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + + model_type = get_model_type(model) + + forward_data_store = [] + input_tensor, output_tensor_grad = None, None + with no_sync_func(): + for i in range(num_microbatches - 1): + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + return forward_data_store + + +def get_tensor_shapes( + *, + rank: int, + model_type: ModelType, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int, + config, +): + # Determine right tensor sizes (based on position of rank with respect to split + # rank) and model size. + + tensor_shapes = [] + + if config.sequence_parallel: + seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size() + if model_type == ModelType.encoder_and_decoder: + decoder_seq_length = ( + decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size() + ) + + if model_type == ModelType.encoder_and_decoder: + if parallel_state.is_pipeline_stage_before_split(rank): + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size)) + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + return tensor_shapes + + +def recv_forward(tensor_shapes, config): + input_tensors = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + input_tensors.append(None) + else: + input_tensors.append(p2p_communication.recv_forward(tensor_shape, config)) + return input_tensors + + +def recv_backward(tensor_shapes, config): + output_tensor_grads = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + output_tensor_grads.append(None) + else: + output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config)) + return output_tensor_grads + + +def send_forward(output_tensors, tensor_shapes, config): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_forward(output_tensor, config) + + +def send_backward(input_tensor_grads, tensor_shapes, config): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_backward(input_tensor_grad, config) + + +def send_forward_recv_backward(output_tensors, tensor_shapes, config): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + output_tensor_grads = [] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + output_tensor_grads.append(None) + continue + output_tensor_grad = p2p_communication.send_forward_recv_backward( + output_tensor, tensor_shape, config + ) + output_tensor_grads.append(output_tensor_grad) + return output_tensor_grads + + +def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + input_tensors = [] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + input_tensors.append(None) + continue + input_tensor = p2p_communication.send_backward_recv_forward( + input_tensor_grad, tensor_shape, config + ) + input_tensors.append(input_tensor) + return input_tensors diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/__init__.py b/training/DeepSpeed-Domino/domino/tensor_parallel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/comm.py b/training/DeepSpeed-Domino/domino/tensor_parallel/comm.py new file mode 100644 index 000000000..0f7f2d7ca --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/comm.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from comm.py in Megatron-LM + +import torch + +from domino.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import split_tensor_along_last_dim + + +def reduce_tensor(tensor): + if get_tensor_model_parallel_world_size() == 1: + return tensor + + torch.distributed.all_reduce(tensor, group=get_tensor_model_parallel_group()) + return tensor + + +def split_tensor_last_dim(tensor): + world_size = get_tensor_model_parallel_world_size() + if world_size == 1: + return tensor + + tensor_splits = split_tensor_along_last_dim(tensor, world_size) + rank = get_tensor_model_parallel_rank() + return tensor_splits[rank].contiguous() + + +def gather_tensor_last_dim(tensor): + world_size = get_tensor_model_parallel_world_size() + if world_size == 1: + return tensor + + last_dim = tensor.dim() - 1 + rank = get_tensor_model_parallel_rank() + gathered_tensors = [torch.empty_like(tensor) for _ in range(world_size)] + gathered_tensors[rank] = tensor + torch.distributed.all_gather(gathered_tensors, tensor, group=get_tensor_model_parallel_group()) + return torch.cat(gathered_tensors, dim=last_dim).contiguous() + + +class CopyToModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return reduce_tensor(grad_output) + + +class ReduceFromModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return reduce_tensor(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class ScatterToModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return split_tensor_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return gather_tensor_last_dim(grad_output) + + +class GatherFromModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_): + return gather_tensor_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return split_tensor_last_dim(grad_output) diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py new file mode 100644 index 000000000..a87c1f521 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/cross_entropy.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from cross_entropy.py in Megatron-LM + +import torch + +from domino.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import VocabUtility + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, target): + max_logits = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()) + logits = logits - max_logits.unsqueeze(dim=-1) + + partition_vocab_size = logits.size()[-1] + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size(partition_vocab_size, rank, world_size) + + target_mask = (target < vocab_start) | (target >= vocab_end) + adjusted_target = target.clone() - vocab_start + adjusted_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + adjusted_target_1d = adjusted_target.view(-1) + batch_indices = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[batch_indices, adjusted_target_1d].clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + exp_logits = torch.exp(logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + loss = torch.log(sum_exp_logits) - predicted_logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + ctx.save_for_backward(exp_logits, target_mask, adjusted_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + softmax, target_mask, adjusted_target_1d = ctx.saved_tensors + + grad_input = softmax.view(-1, softmax.size()[-1]) + batch_indices = torch.arange(start=0, end=grad_input.size()[0], device=grad_input.device) + softmax_update = 1.0 - target_mask.view(-1).float() + grad_input[batch_indices, adjusted_target_1d] -= softmax_update + grad_input = grad_input.view_as(softmax) + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_parallel_cross_entropy(vocab_parallel_logits, target): + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/data.py b/training/DeepSpeed-Domino/domino/tensor_parallel/data.py new file mode 100644 index 000000000..8cc529521 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/data.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from data.py in Megatron-LM + +import torch + +from domino.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_src_rank, +) + +_MAX_DATA_DIM = 5 + + +def _check_data_types(keys, data, target_dtype): + """Check that all the keys have the same target data type.""" + for key in keys: + assert data[key].dtype == target_dtype, ( + '{} has data type {} which ' + 'is different than {}'.format(key, data[key].dtype, target_dtype) + ) + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if get_tensor_model_parallel_rank() == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast( + sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() + ) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data disctionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) + + # Pack on rank zero. + if get_tensor_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) + + # Broadcast + torch.distributed.broadcast( + flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() + ) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/partition.py b/training/DeepSpeed-Domino/domino/tensor_parallel/partition.py new file mode 100644 index 000000000..b5f3b148b --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/partition.py @@ -0,0 +1,756 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from partition.py in Megatron-LM + +import math +import os +import warnings +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.parameter import Parameter + +from megatron.core.model_parallel_config import ModelParallelConfig +from domino.parallel_state import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_group +from domino.tensor_parallel.comm import ( + CopyToModelParallelRegion, + ReduceFromModelParallelRegion, + GatherFromModelParallelRegion, + ScatterToModelParallelRegion, +) + +# from .random import get_cuda_rng_tracker +from .utils import VocabUtility + +_grad_accum_fusion_available = True +try: + import fused_weight_gradient_mlp_cuda +except ImportError: + _grad_accum_fusion_available = False + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + 'tensor_model_parallel': False, + 'partition_dim': -1, + 'partition_stride': 1, +} + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( + get_tensor_model_parallel_rank() == 0 + ) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. + setattr(tensor, 'tensor_model_parallel', is_parallel) + setattr(tensor, 'partition_dim', dim) + setattr(tensor, 'partition_stride', stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) + + # with get_cuda_rng_tracker().fork(): + init_method(weight) + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + + Keyword Arguments: + config: A megatron.core.ModelParallelConfig object + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method: Callable, + config: ModelParallelConfig, + ): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2.0 + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + ( + self.vocab_start_index, + self.vocab_end_index, + ) = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + # Allocate weights and initialize. + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + _initialize_affine_weight_cpu( + self.weight, + self.num_embeddings, + self.embedding_dim, + self.num_embeddings_per_partition, + 0, + init_method, + params_dtype=config.params_dtype, + ) + else: + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = ReduceFromModelParallelRegion.apply(output_parallel) + return output + + +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + @custom_fwd + def forward( + ctx, + input, + weight, + bias, + gradient_accumulation_fusion, + async_grad_allreduce, + sequence_parallel, + ): + ctx.save_for_backward(input, weight) + ctx.use_bias = bias is not None + ctx.gradient_accumulation_fusion = gradient_accumulation_fusion + ctx.async_grad_allreduce = async_grad_allreduce + ctx.sequence_parallel = sequence_parallel + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, input, group=get_tensor_model_parallel_group() + ) + total_input = all_gather_buffer + else: + total_input = input + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True + ) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Doing gather + slicing during the NeMo forward pass can make this tensor + # not be contiguous. PyTorch only checks if the tensor is contiguous, and only + # clones it if it's not contiguous: + # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view( + grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] + ) + total_input = total_input.view( + total_input.shape[0] * total_input.shape[1], total_input.shape[2] + ) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = torch.distributed.all_reduce( + grad_input, group=get_tensor_model_parallel_group(), async_op=True + ) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # all-reduce is scheduled before the weight gradient computation + + if ctx.sequence_parallel: + assert not ctx.async_grad_allreduce + dim_size = list(input.size()) + sub_grad_input = torch.empty( + dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False + ) + # reduce_scatter + handle = torch.distributed._reduce_scatter_base( + sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True + ) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + if ctx.gradient_accumulation_fusion: + if weight.main_grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( + total_input, grad_output, weight.main_grad + ) + elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( + total_input, grad_output, weight.main_grad + ) + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_grad_accumulation_and_async_allreduce( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + gradient_accumulation_fusion: bool, + async_grad_allreduce: bool, + sequence_parallel: bool, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. + + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + async_grad_allreduce (bool required): Do the allreduce of input + gradients asyncronously with the computation of weight + gradients. If sequence_parallel is True, this must be + False, as no all reduce is performed. + + sequence_parallel (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + """ + args = [ + input, + weight, + bias, + gradient_accumulation_fusion, + async_grad_allreduce, + sequence_parallel, + ] + + if not linear_with_grad_accumulation_and_async_allreduce.warned: + if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if sequence_parallel: + warnings.warn( + "When using sequence parallelism it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + linear_with_grad_accumulation_and_async_allreduce.warned = True + + if async_grad_allreduce: + warnings.warn( + "When using async grad allreduce it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + linear_with_grad_accumulation_and_async_allreduce.warned = True + + return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) + + +linear_with_grad_accumulation_and_async_allreduce.warned = False + + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + + skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed + as a keyword argument `weight` during the forward pass. Note + that this does not affect bias, which will be allocated if + bias is True. Defaults to False. + + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size, + output_size, + *, + config: ModelParallelConfig, + init_method: Callable, + bias=True, + gather_output=False, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + skip_weight_param_allocation: bool = False, + ): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = output_size // world_size + self.skip_bias_add = skip_bias_add + self.config = config + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + if not skip_weight_param_allocation: + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, self.input_size, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) + else: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=0, stride=stride + ) + else: + self.weight = None + + if bias: + if config.use_cpu_initialization: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=config.params_dtype) + ) + else: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + self.async_tensor_model_parallel_allreduce = ( + config.async_tensor_model_parallel_allreduce and world_size > 1 + ) + + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and world_size <= 1: + warnings.warn( + f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + self.sequence_parallel = False + + if config.gradient_accumulation_fusion and not _grad_accum_fusion_available: + raise RuntimeError( + "ColumnParallelLinear was called with gradient_accumulation_fusion set " + "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " + "module is not found. To use gradient_accumulation_fusion you must " + "install APEX with --cpp_ext and --cuda_ext. For example: " + "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" " + "Note that the extension requires CUDA>=11. Otherwise, you must turn off " + "gradient accumulation fusion." + ) + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + + if self.async_tensor_model_parallel_allreduce and self.sequence_parallel: + raise RuntimeError( + "`async_tensor_model_parallel_allreduce` and `sequence_parallel` " + "cannot be enabled at the same time." + ) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): + """Forward of ColumnParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + weight (optional): weight tensor to use, compulsory when + skip_weight_param_allocation is True. + + Returns: + - output + - bias + + """ + if weight is None: + if self.weight is None: + raise RuntimeError( + "weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True." + ) + weight = self.weight + else: + # Check the weight passed in is the correct shape + expected_shape = (self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError( + f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected" + ) + + bias = self.bias if not self.skip_bias_add else None + + if self.async_tensor_model_parallel_allreduce or self.sequence_parallel: + input_parallel = input_ + else: + input_parallel = CopyToModelParallelRegion.apply(input_) + # Matrix multiply. + if not weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + output_parallel = self._forward_impl( + input=input_parallel, + weight=weight, + bias=bias, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel=self.sequence_parallel, + ) + if self.gather_output: + # All-gather across the partitions. + assert not self.sequence_parallel + output = GatherFromModelParallelRegion.apply(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + + +class RowParallelLinearNoComm(torch.nn.Module): + """Linear layer with row parallelism without communication. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool = True, + input_is_parallel: bool = False, + stride: int = 1, + keep_master_weight_for_test: bool = False, + skip_bias_add: bool = False, + ): + super(RowParallelLinearNoComm, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = input_size // world_size + self.skip_bias_add = skip_bias_add + self.config = config + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`") + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=1, stride=stride + ) + if bias: + if config.use_cpu_initialization: + self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype)) + else: + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + assert not self.sequence_parallel + input_parallel = ScatterToModelParallelRegion.apply(input_) + # Matrix multiply. + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + output_parallel = self._forward_impl( + input=input_parallel, + weight=self.weight, + bias=None, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=False, + sequence_parallel=False, + ) + + # All-reduce across all the partitions. + # if self.sequence_parallel: + # output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + # else: + # output_ = reduce_from_tensor_model_parallel_region(output_parallel) + # if not self.skip_bias_add: + # output = output_ + self.bias if self.bias is not None else output_ + # output_bias = None + # else: + # output = output_ + # output_bias = self.bias + + output = output_parallel + output_bias = self.bias + return output, output_bias \ No newline at end of file diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/random.py b/training/DeepSpeed-Domino/domino/tensor_parallel/random.py new file mode 100644 index 000000000..822794ed2 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/random.py @@ -0,0 +1,241 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from random.py in Megatron-LM + +import contextlib + +import torch +from torch import _C +from torch.cuda import _lazy_call +from torch.cuda import device as device_ctx_manager +from torch.utils.checkpoint import detach_variable + +from domino.parallel_state import get_tensor_model_parallel_rank +from domino.utils import safely_set_viewless_tensor_data +from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-tensor-model-parallel regions. + tensor-model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() + # Data parallel gets the original seed. + data_parallel_seed = seed + + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed) + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + @staticmethod + def forward(ctx, run_function, distribute_saved_activations, *args): + ctx.run_function = run_function + ctx.distribute_saved_activations = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + with torch.no_grad(): + outputs = run_function(*args) + + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) + ) + + # Store everything. + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), " + "please use .backward() if possible" + ) + inputs = ctx.saved_tensors + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape) + ) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, distribute_saved_activations, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/training/DeepSpeed-Domino/domino/tensor_parallel/utils.py b/training/DeepSpeed-Domino/domino/tensor_parallel/utils.py new file mode 100644 index 000000000..a4a604fa2 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/tensor_parallel/utils.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from utils.py in Megatron-LM + + +from typing import List, Sequence +import torch +from domino import parallel_state + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): + """ Break a tensor into equal 1D chunks across tensor parallel ranks. + + Returns a Tensor or View with this rank's portion of the data. + + Arguments: + tensor: The tensor to split + + Keyword Arguments: + new_buffer (bool): If True, returns a new Tensor. + If False, returns a view into the existing Tensor. + Default is False + + """ + partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size() + start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() + end_index = start_index + partition_size + if new_buffer: + data = torch.empty( + partition_size, + dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor): + """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor + model parallel ranks. + + Returns a new Tensor with the gathered data. + + Arguments: + tensor: A Tensor or view of this rank's portion of the data. + """ + numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size() + gathered = torch.empty( + numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False + ) + # TODO: This API is experimental in pytorch (as of Feb 2022) and + # this might break in future pytorch releases. We chose this API + # as opposed to torch.distributed.all_gather for efficiency reasons. + # This API calls directly NCCL all-gather versus the former does + # internal copies and can potentially cause slow down. + torch.distributed._all_gather_base( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +class VocabUtility: + """ Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) + + """ + + @staticmethod + def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, rank, world_size: int + ) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int + ) -> Sequence[int]: + per_partition_vocab_size = global_vocab_size // world_size + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size + ) diff --git a/training/DeepSpeed-Domino/domino/timer.py b/training/DeepSpeed-Domino/domino/timer.py new file mode 100644 index 000000000..a9478fa01 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/timer.py @@ -0,0 +1,304 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron timers.""" + +from abc import ABC +from abc import abstractmethod +import time + +import torch + + + +class TimerBase(ABC): + + def __init__(self, name): + self.name = name + + @abstractmethod + def start(self, barrier=False): + pass + + @abstractmethod + def stop(self, barrier=False): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def elapsed(self, reset=True, barrier=False): + pass + + + +class DummyTimer(TimerBase): + + def __init__(self): + super().__init__('dummy timer') + + def start(self, barrier=False): + return + + def stop(self, barrier=False): + return + + def reset(self): + return + + def elapsed(self, reset=True, barrier=False): + raise Exception('dummy timer should not be used to ' + 'calculate elapsed time') + + + +class Timer(TimerBase): + """ + Comment on using `barrier`: If this flag is passed, then all + the caller processes will wait till all reach the timing routine. + It is up to the user to make sure all the ranks in `barrier_group` + call it otherwise, it will result in a hang. + Comment on `barrier_group`: By default it is set to None which + in torch distributed land, it will result in the global communicator. + """ + + def __init__(self, name): + super().__init__(name) + self._elapsed = 0.0 + self._started = False + # Note that None will default to the global process group + self._barrier_group = None + self._start_time = time.time() + + + def set_barrier_group(self, barrier_group): + self._barrier_group = barrier_group + + + def start(self, barrier=False): + """Start the timer.""" + assert not self._started, 'timer has already been started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._start_time = time.time() + self._started = True + + + def stop(self, barrier=False): + """Stop the timer.""" + assert self._started, 'timer is not started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._elapsed += (time.time() - self._start_time) + self._started = False + + + def reset(self): + """Reset timer.""" + self._elapsed = 0.0 + self._started = False + + + def elapsed(self, reset=True, barrier=False): + """Calculate the elapsed time.""" + _started = self._started + # If the timing in progress, end it first. + if self._started: + self.stop(barrier=barrier) + # Get the elapsed time. + _elapsed = self._elapsed + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if _started: + self.start(barrier=barrier) + return _elapsed + + + +class Timers: + """Group of timers.""" + + def __init__(self, log_level, log_option): + self._log_level = log_level + self._log_option = log_option + self._timers = {} + self._log_levels = {} + self._dummy_timer = DummyTimer() + self._max_log_level = 2 + + + def __call__(self, name, log_level=None): + # If the timer has already been set, then check if the log-level + # is provided, it matches the one that the timer was created with. + if name in self._timers: + if log_level is not None: + assert log_level == self._log_levels[name], \ + 'input log level {} does not match already existing '\ + 'log level {} for {} timer'.format( + log_level, self._log_levels[name], name) + return self._timers[name] + # If timer does not exist and no log level is provided, + # set it to the max log level which is 2. + if log_level is None: + log_level = self._max_log_level + assert log_level <= self._max_log_level, \ + 'log level {} is larger than max supported log level {}'.format( + log_level, self._max_log_level) + # Now if the input log level is larger than the one set for + # the timers class, just ignore it and return a dummy timer. + if log_level > self._log_level: + return self._dummy_timer + # Otherwise, initalize the timer and set the level. + self._timers[name] = Timer(name) + self._log_levels[name] = log_level + return self._timers[name] + + + def _get_elapsed_time_all_ranks(self, names, reset, barrier): + """ + Assumptions: + - All the ranks call this function. + - `names` are identical on all ranks. + If the above assumptions are not met, calling this function will + result in hang. + Arguments: + - names: list of timer names + - reset: reset the timer after recording the elapsed time + - barrier: if set, do a global barrier before time measurments + """ + + # First make sure all the callers are in sync. + if barrier: + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Here we can use gather on the rank we want to print the + # timing, however, there is no gather_base support in + # pytorch yet. It is simpler to deal with a single tensor + # and since we are only gathering a small amount of data, + # it should be ok to use all-gather instead of gather. + rank_name_to_time = torch.zeros((world_size, len(names)), + dtype=torch.float, + device=torch.cuda.current_device()) + for i, name in enumerate(names): + if name in self._timers: + # Here we don't need to pass the barrier flag as all + # the processes are already in sync. This avoids the + # issue of different timers having different barrier + # groups inside their class. + rank_name_to_time[rank, i] = self._timers[name].elapsed( + reset=reset) + + # See the note above for why we are not using gather. + torch.distributed._all_gather_base(rank_name_to_time.view(-1), + rank_name_to_time[rank, :].view(-1)) + + return rank_name_to_time + + + def _get_global_min_max_time(self, names, reset, barrier, normalizer): + """Report only min and max times across all ranks.""" + + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + name_to_min_max_time = {} + for i, name in enumerate(names): + rank_to_time = rank_name_to_time[:, i] + # filter out the ones we did not have any timings for + rank_to_time = rank_to_time[rank_to_time > 0.0] + # If the timer exists: + if rank_to_time.numel() > 0: + name_to_min_max_time[name] = ( + rank_to_time.min().item() / normalizer, + rank_to_time.max().item() / normalizer) + return name_to_min_max_time + + + def _get_global_min_max_time_string(self, names, reset, barrier, + normalizer, max_only): + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if not name_to_min_max_time: + return None + output_string = '(min, max) time across ranks (ms):' + for name in name_to_min_max_time: + min_time, max_time = name_to_min_max_time[name] + if max_only: + output_string += '\n {}: {:.2f}'.format( + (name+' ').ljust(48, '.'), max_time) + else: + output_string += '\n {}: ({:.2f}, {:.2f})'.format( + (name+' ').ljust(48, '.'), min_time, max_time) + return output_string + + + def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): + """Report times across all ranks.""" + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + + output_string = 'times across ranks (ms):' + no_reported_timing = True + for i, name in enumerate(names): + not_yet_found = True + for rank in range(torch.distributed.get_world_size()): + if rank_name_to_time[rank, i] > 0: + no_reported_timing = False + if not_yet_found: + not_yet_found = False + output_string += '\n {}:'.format(name) + output_string += '\n rank {:2d}: {:.2f}'.format( + rank, rank_name_to_time[rank, i] / normalizer) + if no_reported_timing: + return None + return output_string + + + def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): + """Log a group of timers.""" + + # Print. + assert normalizer > 0.0 + if self._log_option in ['max', 'minmax']: + max_only = False + if self._log_option == 'max': + max_only = True + output_string = self._get_global_min_max_time_string( + names, reset, barrier, normalizer/1000.0, max_only) + elif self._log_option == 'all': + output_string = self._get_all_ranks_time_string(names, + reset, barrier, + normalizer/1000.0) + else: + raise Exception('unknown timing log option {}'.format( + self._log_option)) + + # If no input rank is provided, log on last rank. + if rank is None: + rank = torch.distributed.get_world_size() - 1 + if rank == torch.distributed.get_rank() and output_string is not None: + print(output_string, flush=True) + + + def write(self, names, writer, iteration, normalizer=1.0, + reset=False, barrier=False): + """Write timers to a tensorboard writer + Note that we only report maximum time across ranks to tensorboard. + """ + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if writer is not None: + for name in name_to_min_max_time: + _, max_time = name_to_min_max_time[name] + writer.add_scalar(name + '-time', max_time, iteration) diff --git a/training/DeepSpeed-Domino/domino/training.py b/training/DeepSpeed-Domino/domino/training.py new file mode 100644 index 000000000..59e253fcf --- /dev/null +++ b/training/DeepSpeed-Domino/domino/training.py @@ -0,0 +1,431 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from training.py in Megatron-LM + +import torch +from domino.arguments import get_args, get_tokenizer, get_num_microbatches, get_timers +from domino.utils import print_rank_0, get_model_config, get_ltor_masks_and_position_ids +import domino.parallel_state as mpu +from domino.tensor_parallel.partition import set_defaults_if_not_set_tensor_model_parallel_attributes +from domino.modules.enums import ModelType +from domino.schedules import get_forward_backward_func +from domino.data.data_samplers import build_pretraining_data_loader +from domino.modules.distributed import DistributedDataParallel as LocalDDP +from domino.modules.module import Float16Module +from domino.optimizer import get_megatron_optimizer +from domino.optimizer_param_scheduler import OptimizerParamScheduler +from domino.initialize import set_jit_fusion_options +from domino.tensor_parallel.data import broadcast_data + + +def is_rank_0(): + # if torch.cuda.current_device() == 0: + if torch.distributed.get_rank() == 0: + return True + + +def forward_step(data_iterator, model): + input_tokens, target_labels, loss_mask, attention_mask, position_ids = prepare_batch(data_iterator) + model_output = model(input_tokens, position_ids, attention_mask, labels=target_labels) + return model_output, lambda output: compute_loss(loss_mask, output) + + +def prepare_batch(data_iterator): + args = get_args() + tokenizer = get_tokenizer() + + data_keys = ['text'] + data_type = torch.int64 + + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + + broadcasted_data = broadcast_data(data_keys, data, data_type) + full_tokens = broadcasted_data['text'].long() + input_tokens = full_tokens[:, :-1].contiguous() + target_labels = full_tokens[:, 1:].contiguous() + + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + input_tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss + ) + + return input_tokens, target_labels, loss_mask, attention_mask, position_ids + + +def compute_loss(loss_mask, model_output): + flattened_output = model_output.view(-1).float() + flattened_loss_mask = loss_mask.view(-1).float() + loss = torch.sum(flattened_output * flattened_loss_mask) / flattened_loss_mask.sum() + return loss + + +def pretrain(base_model, train_ds, valid_ds, test_ds): + args = get_args() + + # Model, optimizer, and learning rate. + model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + base_model, ModelType.encoder_or_decoder) + config = get_model_config(model) + + # Do not use virtual pipeline parallelism for data parallel + train_data_iterator, valid_data_iterator, test_data_iterator \ + = get_dataset_iterator(train_ds, valid_ds, test_ds) + + # Train and eval. + print_rank_0('training ...') + + if args.do_train and args.train_iters > 0: + train(forward_step, + model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, config) + + # if args.do_valid: + # total_loss_dict = evaluate(forward_step, valid_data_iterator, model, config, True) + # print_rank_0(total_loss_dict) + + # if args.do_test: + # total_loss_dict = evaluate(forward_step, test_data_iterator, model, config, True) + # print_rank_0(total_loss_dict) + + +def setup_model_and_optimizer(base_model, + model_type, + no_wd_decay_cond=None, + scale_lr_cond=None): + """Setup model and optimizer.""" + args = get_args() + + model = get_model(base_model, model_type) + + if isinstance(model, list): + models = model + else: + models = [model] + optimizer = get_megatron_optimizer(models, no_wd_decay_cond, scale_lr_cond) + opt_param_scheduler = get_optimizer_param_scheduler(optimizer) + + args.iteration = 0 + + return model, optimizer, opt_param_scheduler + + +def get_model(base_model, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): + """Build the model.""" + args = get_args() + args.model_type = model_type + + # Build model. + model = base_model + model.model_type = model_type + + for param in model.parameters(): + set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on (tensor, pipeline) ' + 'model parallel rank ({}, {}): {}'.format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()])), flush=True) + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if args.fp16 or args.bf16: + model = Float16Module(model, args) + + if wrap_with_ddp: + if args.DDP_impl == 'local': + model = LocalDDP(model, + args.accumulate_allreduce_grads_in_fp32, + args.use_contiguous_buffers_in_local_ddp) + # broad cast params from data parallel src rank to other data parallel ranks + if args.data_parallel_random_init: + model.broadcast_params() + else: + raise NotImplementedError('Unknown DDP implementation specified: ' + '{}. Exiting.'.format(args.DDP_impl)) + return model + + +def get_optimizer_param_scheduler(optimizer): + """Build the learning rate scheduler.""" + args = get_args() + + # Iteration-based training. + # Remove sample-based training. + if args.train_iters: + if args.lr_decay_iters is None: + args.lr_decay_iters = args.train_iters + lr_decay_steps = args.lr_decay_iters * args.global_batch_size + wd_incr_steps = args.train_iters * args.global_batch_size + if args.lr_warmup_fraction is not None: + lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps + else: + lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size + else: + raise Exception( + 'either train-iters or train-samples should be provided.') + + opt_param_scheduler = OptimizerParamScheduler( + optimizer, + init_lr=args.lr_warmup_init, + max_lr=args.lr, + min_lr=args.min_lr, + lr_warmup_steps=lr_warmup_steps, + lr_decay_steps=lr_decay_steps, + lr_decay_style=args.lr_decay_style, + start_wd=args.start_weight_decay, + end_wd=args.end_weight_decay, + wd_incr_steps=wd_incr_steps, + wd_incr_style=args.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=args.override_opt_param_scheduler) + + return opt_param_scheduler + + +def get_dataset_iterator(train_ds, valid_ds, test_ds): + """Build pretraining data iterators.""" + args = get_args() + + # Build loaders. + train_dataloader, valid_dataloader, test_dataloader = \ + get_data_loader(train_ds, valid_ds, test_ds) + + # Build iterators. + dl_type = args.dataloader_type + assert dl_type == 'single' + + if train_dataloader is not None: + train_data_iterator = iter(train_dataloader) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator + + +def get_data_loader(train_ds, valid_ds, test_ds): + """Build pretraining data loaders.""" + args = get_args() + + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + print_rank_0('> building train, validation, and test datasets ...') + + # Backward compatibility, assume fixed batch size. + if args.iteration > 0 and args.consumed_train_samples == 0: + assert args.train_samples is None, \ + 'only backward compatiblity support for iteration-based training' + args.consumed_train_samples = args.iteration * args.global_batch_size + if args.iteration > 0 and args.consumed_valid_samples == 0: + if args.train_samples is None: + args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ + args.eval_iters * args.global_batch_size + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_tensor_model_parallel_rank() == 0: + # Build dataloders. + train_dataloader = build_pretraining_data_loader( + train_ds, args.consumed_train_samples) + valid_dataloader = build_pretraining_data_loader( + valid_ds, args.consumed_valid_samples) + test_dataloader = build_pretraining_data_loader(test_ds, 0) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and args.train_iters > 0 + do_valid = valid_dataloader is not None and args.eval_iters > 0 + do_test = test_dataloader is not None and args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor( + [int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(flags, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + args.do_train = flags[0].item() + args.do_valid = flags[1].item() + args.do_test = flags[2].item() + + return train_dataloader, valid_dataloader, test_dataloader + + +def train(forward_step_func, model, optimizer, opt_param_scheduler, + train_data_iterator, valid_data_iterator, config): + """Train the model function.""" + args = get_args() + timers = get_timers() + + model.train() + + # Iterations. + iteration = args.iteration + + # Setup some training config params + config.grad_scale_func = optimizer.scale_loss + config.timers = None + + timers('ite-time', log_level=0).start(barrier=True) + while iteration < args.train_iters: + args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_step_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) + + iteration += 1 + args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ + args.micro_batch_size * get_num_microbatches() + + ite_time = timers('ite-time').elapsed(barrier=True) + if iteration % args.log_interval == 0 and is_rank_0(): + loss = loss_dict['lm loss'].item() + print( 'iteration: {} | loss: {:.3f} | iteration time (ms): {} '.format(iteration, loss, ite_time*1000.0)) + # loss_scale = optimizer.cur_scale + # lr = optimizer.param_groups[0]['lr'] + # print( 'lr: {} loss scale: {:.1f} |'.format(lr, loss_scale))' + + return iteration + + +def train_step(forward_step_func, data_iterator, + model, optimizer, opt_param_scheduler, config): + """Single training step.""" + args = get_args() + + # Set grad to zero. + if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp: + model.zero_grad_buffer() + optimizer.zero_grad() + + forward_backward_func = get_forward_backward_func() + + losses_reduced = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=False) + + timers = None + # reset timers if necessary + if config.timers is None: + config.timers = timers + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) + + # Update learning rate. + if update_successful: + increment = get_num_microbatches() * \ + args.micro_batch_size * \ + args.data_parallel_size + opt_param_scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0]: + losses_reduced_for_key = [x[key] for x in losses_reduced] + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad + + +def evaluate(forward_step_func, + data_iterator, + model, + config, + verbose=False): + """Evaluation.""" + args = get_args() + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_loss_dict = {} + + # make validation batch size independent from training batch size + eval_batch_size = args.global_batch_size + eval_num_microbatches = eval_batch_size // \ + (args.micro_batch_size * args.data_parallel_size) + + with torch.no_grad(): + iteration = 0 + if verbose: + print_rank_0( + f'Evaluating on {args.eval_iters * eval_batch_size} samples') + while iteration < args.eval_iters: + iteration += 1 + if verbose: + print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}') + + forward_backward_func = get_forward_backward_func() + # Don't care about timing during evaluation + config.timers = None + loss_dicts = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=eval_num_microbatches, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True) + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. + for loss_dict in loss_dicts: + for key in loss_dict: + total_loss_dict[key] = total_loss_dict.get( + key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] + + args.consumed_valid_samples += eval_batch_size + + # Move model back to the train mode. + model.train() + + for key in total_loss_dict: + total_loss_dict[key] /= args.eval_iters * eval_num_microbatches + + return total_loss_dict diff --git a/training/DeepSpeed-Domino/domino/utils.py b/training/DeepSpeed-Domino/domino/utils.py new file mode 100644 index 000000000..af90f6684 --- /dev/null +++ b/training/DeepSpeed-Domino/domino/utils.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import sys +import torch +from torch.nn.parallel import DistributedDataParallel as torchDDP +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C +from domino.arguments import get_args +import domino.parallel_state as mpu +from domino.tensor_parallel.partition import param_is_not_tensor_parallel_duplicate + + +def unwrap_model(model, module_instances=(torchDDP)): + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + +def get_attr_wrapped_model(model, attr, allow_none=True): + """Get an attribute from a wrapped model""" + if isinstance(model, list): + raise RuntimeError("_get_attr_wrapped_model given a list of models") + + if allow_none: + def condition(model, attr): + return not hasattr(model, attr) + else: + def condition(model, attr): + return getattr(model, attr, None) is None + + while condition(model, attr): + if not hasattr(model, "module"): + raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}") + + model = model.module + return getattr(model, attr) + + +def get_model_type(model): + return get_attr_wrapped_model(model, 'model_type') + + +def get_model_config(model): + # args = get_args() + # if args.deepspeed: + # return get_attr_wrapped_model(model.module, 'config', allow_none=False) + return get_attr_wrapped_model(model, 'config', allow_none=False) + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + +def calc_params_l2_norm(model): + """Calculate l2 norm of parameters """ + args = get_args() + if not isinstance(model, list): + model = [model] + # Remove duplicate params. + params_data = [] + for model_ in model: + for param in model_.parameters(): + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if is_not_shared and is_not_tp_duplicate: + if args.bf16: + params_data.append(param.data.float()) + else: + params_data.append(param.data) + # Calculate norm + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [params_data], + False # no per-parameter norm + ) + norm_2 = norm * norm + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group()) + return norm_2.item() ** 0.5 + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def get_ltor_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + + return attention_mask, loss_mask, position_ids + + +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat( + [loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, + group=mpu.get_data_parallel_group()) + averaged_losses = averaged_losses / \ + torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) + + return averaged_losses + + +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty( + (1,), + dtype = inp.dtype, + device = inp.device, + requires_grad = requires_grad, + ) + out.data = inp.data + return out + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. + ''' + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. + + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) + + +def safely_set_viewless_tensor_data(tensor, new_data_tensor): + '''Safely set tensor's '.data' field. + + Check first that the tensor is viewless (i.e., '._base' not set). If not, + raise an exception. + ''' + # assert_viewless_tensor( + # tensor, + # extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s." + # % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape), + # ) + tensor.data = new_data_tensor diff --git a/training/DeepSpeed-Domino/megatron/__init__.py b/training/DeepSpeed-Domino/megatron/__init__.py new file mode 100644 index 000000000..aa99c0665 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from .global_vars import get_args, get_retro_args +from .global_vars import get_current_global_batch_size +from .global_vars import get_num_microbatches +from .global_vars import get_signal_handler +from .global_vars import update_num_microbatches +from .global_vars import get_tokenizer +from .global_vars import get_tensorboard_writer +from .global_vars import get_adlr_autoresume +from .global_vars import get_timers +from .initialize import initialize_megatron + +from .utils import (print_rank_0, + is_last_rank, + print_rank_last) diff --git a/training/DeepSpeed-Domino/megatron/arguments.py b/training/DeepSpeed-Domino/megatron/arguments.py new file mode 100644 index 000000000..ba398aa44 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/arguments.py @@ -0,0 +1,1305 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron arguments.""" + +import argparse +import dataclasses +import json +import os +import torch +import types + +import torch.nn.functional as F +from megatron.global_vars import set_retro_args, get_retro_args +# from tools.retro.utils import get_args_path as get_retro_args_path + +from megatron.core.transformer import TransformerConfig + +def parse_args(extra_args_provider=None, ignore_unknown_args=False): + """Parse all arguments.""" + parser = argparse.ArgumentParser(description='Megatron-LM Arguments', + allow_abbrev=False) + + # Standard arguments. + parser = _add_network_size_args(parser) + parser = _add_regularization_args(parser) + parser = _add_training_args(parser) + parser = _add_initialization_args(parser) + parser = _add_learning_rate_args(parser) + parser = _add_checkpointing_args(parser) + parser = _add_mixed_precision_args(parser) + parser = _add_distributed_args(parser) + parser = _add_validation_args(parser) + parser = _add_data_args(parser) + parser = _add_autoresume_args(parser) + parser = _add_biencoder_args(parser) + parser = _add_vision_args(parser) + parser = _add_logging_args(parser) + parser = _add_inference_args(parser) + parser = _add_transformer_engine_args(parser) + parser = _add_retro_args(parser) + + # Custom arguments. + if extra_args_provider is not None: + parser = extra_args_provider(parser) + + # Parse. + if ignore_unknown_args: + args, _ = parser.parse_known_args() + else: + args = parser.parse_args() + + # Args from environment + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv("WORLD_SIZE", '1')) + + return args + +def validate_args(args, defaults={}): + # Tensor model parallel size. + args.tensor_model_parallel_size = min( + args.tensor_model_parallel_size, args.world_size) + assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ + ' ({}) is not divisible by tensor model parallel size ({})'.format( + args.world_size, args.tensor_model_parallel_size) + # Pipeline model parallel size. + args.pipeline_model_parallel_size = min( + args.pipeline_model_parallel_size, + (args.world_size // args.tensor_model_parallel_size)) + args.transformer_pipeline_model_parallel_size = ( + args.pipeline_model_parallel_size - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_size + ) + # Checks. + model_parallel_size = args.pipeline_model_parallel_size * \ + args.tensor_model_parallel_size + assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\ + ' divisible by tensor parallel size ({}) times pipeline parallel ' \ + 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size) + args.data_parallel_size = args.world_size // model_parallel_size + if args.rank == 0: + print('using world size: {}, data-parallel-size: {}, ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size), flush=True) + if args.pipeline_model_parallel_size > 1: + if args.pipeline_model_parallel_split_rank is not None: + assert args.pipeline_model_parallel_split_rank < \ + args.pipeline_model_parallel_size, 'split rank needs'\ + ' to be less than pipeline model parallel size ({})'.format( + args.pipeline_model_parallel_size) + + # Deprecated arguments + assert args.batch_size is None, '--batch-size argument is no longer ' \ + 'valid, use --micro-batch-size instead' + del args.batch_size + assert args.warmup is None, '--warmup argument is no longer valid, use ' \ + '--lr-warmup-fraction instead' + del args.warmup + assert args.model_parallel_size is None, '--model-parallel-size is no ' \ + 'longer valid, use --tensor-model-parallel-size instead' + del args.model_parallel_size + + if args.checkpoint_activations: + if args.rank == 0: + print('--checkpoint-activations is no longer valid, use --recompute-activations, ' + 'or, for more control, --recompute-granularity and --recompute-method.') + exit() + del args.checkpoint_activations + + if args.recompute_activations: + args.recompute_granularity = 'selective' + del args.recompute_activations + + # Set input defaults. + for key in defaults: + # For default to be valid, it should not be provided in the + # arguments that are passed to the program. We check this by + # ensuring the arg is set to None. + if getattr(args, key, None) is not None: + if args.rank == 0: + print('WARNING: overriding default arguments for {key}:{v} \ + with {key}:{v2}'.format(key=key, v=defaults[key], + v2=getattr(args, key)), + flush=True) + else: + setattr(args, key, defaults[key]) + + # Batch size. + assert args.micro_batch_size is not None + assert args.micro_batch_size > 0 + if args.global_batch_size is None: + args.global_batch_size = args.micro_batch_size * args.data_parallel_size + if args.rank == 0: + print('setting global batch size to {}'.format( + args.global_batch_size), flush=True) + assert args.global_batch_size > 0 + if args.num_layers_per_virtual_pipeline_stage is not None: + assert args.pipeline_model_parallel_size > 2, \ + 'pipeline-model-parallel size should be greater than 2 with ' \ + 'interleaved schedule' + assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ + 'number of layers is not divisible by number of layers per virtual ' \ + 'pipeline stage' + args.virtual_pipeline_model_parallel_size = \ + (args.num_layers // args.transformer_pipeline_model_parallel_size) // \ + args.num_layers_per_virtual_pipeline_stage + else: + args.virtual_pipeline_model_parallel_size = None + + # Parameters dtype. + args.params_dtype = torch.float + if args.fp16: + assert not args.bf16 + args.params_dtype = torch.half + if args.bf16: + assert not args.fp16 + args.params_dtype = torch.bfloat16 + # bfloat16 requires gradient accumulation and all-reduce to + # be done in fp32. + if not args.accumulate_allreduce_grads_in_fp32: + args.accumulate_allreduce_grads_in_fp32 = True + if args.rank == 0: + print('accumulate and all-reduce gradients in fp32 for ' + 'bfloat16 data type.', flush=True) + + if args.rank == 0: + print('using {} for parameters ...'.format(args.params_dtype), + flush=True) + + # If we do accumulation and all-reduces in fp32, we need to have local DDP + # and we should make sure use-contiguous-buffers-in-local-ddp is not off. + if args.accumulate_allreduce_grads_in_fp32: + assert args.DDP_impl == 'local' + assert args.use_contiguous_buffers_in_local_ddp + + # If we use the distributed optimizer, we need to have local DDP + # and we should make sure use-contiguous-buffers-in-local-ddp is on. + if args.use_distributed_optimizer: + assert args.DDP_impl == 'local' + assert args.use_contiguous_buffers_in_local_ddp + + # For torch DDP, we do not use contiguous buffer + if args.DDP_impl == 'torch': + args.use_contiguous_buffers_in_local_ddp = False + + if args.dataloader_type is None: + args.dataloader_type = 'single' + + # Consumed tokens. + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + + # Support for variable sequence lengths across batches/microbatches. + # set it if the dataloader supports generation of variable sequence lengths + # across batches/microbatches. Due to additional communication overhead + # during pipeline parallelism, it should not be set if sequence length + # is constant during training. + args.variable_seq_lengths = False + + # Iteration-based training. + if args.train_iters: + # If we use iteration-based training, make sure the + # sample-based options are off. + assert args.train_samples is None, \ + 'expected iteration-based training' + assert args.lr_decay_samples is None, \ + 'expected iteration-based learning rate decay' + assert args.lr_warmup_samples == 0, \ + 'expected iteration-based learning rate warmup' + assert args.rampup_batch_size is None, \ + 'expected no batch-size rampup for iteration-based training' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_iters == 0, \ + 'can only specify one of lr-warmup-fraction and lr-warmup-iters' + + # Sample-based training. + if args.train_samples: + # If we use sample-based training, make sure the + # iteration-based options are off. + assert args.train_iters is None, \ + 'expected sample-based training' + assert args.lr_decay_iters is None, \ + 'expected sample-based learning rate decay' + assert args.lr_warmup_iters == 0, \ + 'expected sample-based learnig rate warmup' + if args.lr_warmup_fraction is not None: + assert args.lr_warmup_samples == 0, \ + 'can only specify one of lr-warmup-fraction ' \ + 'and lr-warmup-samples' + + if args.num_layers is not None: + assert args.encoder_num_layers is None, \ + 'cannot have both num-layers and encoder-num-layers specified' + args.encoder_num_layers = args.num_layers + else: + assert args.encoder_num_layers is not None, \ + 'either num-layers or encoder-num-layers should be specified' + args.num_layers = args.encoder_num_layers + + # Check required arguments. + required_args = ['num_layers', 'hidden_size', 'num_attention_heads', + 'max_position_embeddings'] + for req_arg in required_args: + _check_arg_is_not_none(args, req_arg) + + # Checks. + if args.ffn_hidden_size is None: + args.ffn_hidden_size = 4 * args.hidden_size + + if args.swiglu: + # reduce the dimnesion for MLP since projections happens on + # two linear layers. this keeps the number of paramters in + # the same ballpark as the counterpart with 4*h size + # we keep it a multiple of 64, which means the actual tensor size + # will be a multiple of 64 / tp_size + args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64 + + if args.kv_channels is None: + assert args.hidden_size % args.num_attention_heads == 0 + args.kv_channels = args.hidden_size // args.num_attention_heads + + if args.seq_length is not None: + assert args.encoder_seq_length is None + args.encoder_seq_length = args.seq_length + else: + assert args.encoder_seq_length is not None + args.seq_length = args.encoder_seq_length + + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + if args.lr is not None: + assert args.min_lr <= args.lr + if args.save is not None: + assert args.save_interval is not None + # Mixed precision checks. + if args.fp16_lm_cross_entropy: + assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' + if args.fp32_residual_connection: + assert args.fp16 or args.bf16, \ + 'residual connection in fp32 only supported when using fp16 or bf16.' + + if args.weight_decay_incr_style == 'constant': + assert args.start_weight_decay is None + assert args.end_weight_decay is None + args.start_weight_decay = args.weight_decay + args.end_weight_decay = args.weight_decay + else: + assert args.start_weight_decay is not None + assert args.end_weight_decay is not None + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + # Persistent fused layer norm. + if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): + args.no_persist_layer_norm = True + if args.rank == 0: + print('Persistent fused layer norm kernel is supported from ' + 'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' + 'Defaulting to no_persist_layer_norm=True') + + # Activation recomputing. + if args.distribute_saved_activations: + assert args.tensor_model_parallel_size > 1, 'can distribute ' \ + 'recomputed activations only across tensor model ' \ + 'parallel groups' + assert args.recompute_granularity == 'full', \ + 'distributed recompute activations is only '\ + 'application to full recompute granularity' + assert args.recompute_method is not None, \ + 'for distributed recompute activations to work you '\ + 'need to use a recompute method ' + assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \ + 'distributed recompute activations are supported for pytorch ' \ + 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ + 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) + + if args.recompute_granularity == 'selective': + assert args.recompute_method is None, \ + 'recompute method is not yet supported for ' \ + 'selective recomputing granularity' + + # disable sequence parallelism when tp=1 + # to avoid change in numerics when + # sequence_parallelism is enabled. + if args.tensor_model_parallel_size == 1: + args.sequence_parallel = False + + # disable async_tensor_model_parallel_allreduce when + # model parallel memory optimization is enabled + if args.sequence_parallel: + args.async_tensor_model_parallel_allreduce = False + + if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if args.sequence_parallel: + raise RuntimeError( + "Using sequence parallelism requires setting the environment variable " + "CUDA_DEVICE_MAX_CONNECTIONS to 1") + if args.async_tensor_model_parallel_allreduce: + raise RuntimeError( + "Using async gradient all reduce requires setting the environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS to 1") + + # Disable bias gelu fusion if we are disabling bias altogether + if not args.add_bias_linear: + args.bias_gelu_fusion = False + + # Retro checks. + if args.retro_add_retriever: + + # Sequence parallelism unsupported. + assert not args.sequence_parallel, \ + "retro currently does not support sequence parallelism." + + # Pipeline parallelism unsupported. + assert args.pipeline_model_parallel_size == 1, \ + "retro currently does not support pipeline parallelism." + + # Load retro args. + # retro_args_path = get_retro_args_path(args.retro_workdir) + retro_args_path = None + assert os.path.exists(retro_args_path), "retro workdir missing args.json" + with open(retro_args_path) as f: + retro_args = types.SimpleNamespace(**json.load(f)) + retro_args.retro_return_doc_ids = args.retro_return_doc_ids + retro_args.retro_gpt_retrieved_length = \ + args.retro_num_retrieved_chunks * \ + retro_args.retro_gpt_chunk_length + set_retro_args(retro_args) + + # Legacy RoPE arguments + if args.use_rotary_position_embeddings: + args.position_embedding_type = 'rope' + + # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now + # don't allow it to keep things simple + if not args.add_position_embedding and args.position_embedding_type != 'rope': + raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type') + + # Print arguments. + _print_args("arguments", args) + retro_args = get_retro_args() + if retro_args and args != retro_args: + _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank)) + + return args + + +def _print_args(title, args): + """Print arguments.""" + if args.rank == 0: + print(f'------------------------ {title} ------------------------', + flush=True) + str_list = [] + for arg in vars(args): + dots = '.' * (48 - len(arg)) + str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print(f'-------------------- end of {title} ---------------------', + flush=True) + + +def _check_arg_is_not_none(args, arg): + assert getattr(args, arg) is not None, '{} argument is None'.format(arg) + +def core_transformer_config_from_args(args): + + # Translate args to core transformer configuration + kw_args = {} + for f in dataclasses.fields(TransformerConfig): + if hasattr(args, f.name): + kw_args[f.name] = getattr(args, f.name) + kw_args['persist_layer_norm'] = not args.no_persist_layer_norm + kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p + kw_args['deallocate_pipeline_outputs'] = True + kw_args['pipeline_dtype'] = args.params_dtype + kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm + if args.swiglu: + kw_args['activation_func'] = F.silu + kw_args['gated_linear_unit'] = True + kw_args['bias_gelu_fusion'] = False + if args.init_method_xavier_uniform: + kw_args['init_method'] = torch.nn.init.xavier_uniform_ + kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ + if args.group_query_attention: + kw_args['num_query_groups'] = args.num_query_groups + else: + kw_args['num_query_groups'] = None + + return TransformerConfig(**kw_args) + +def _add_transformer_engine_args(parser): + group = parser.add_argument_group(title='Transformer-Engine') + + group.add_argument('--fp8-format', default=None, + choices=['e4m3', 'hybrid'], + help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass', + dest='fp8') + group.add_argument('--fp8-margin', type=int, default=0, + help='Scaling margin for fp8', + dest='fp8_margin') + group.add_argument('--fp8-interval', type=int, default=1, + help='Scaling update interval for fp8', + dest='fp8_interval') + group.add_argument('--fp8-amax-history-len', type=int, default=1, + help='Number of steps for which amax history is recorded per tensor', + dest='fp8_amax_history_len') + group.add_argument('--fp8-amax-compute-algo', default='most_recent', + choices=['most_recent', 'max'], + help='Algorithm for computing amax from history', + dest='fp8_amax_compute_algo') + group.add_argument('--no-fp8-wgrad', action='store_false', + help='Execute wgrad in higher precision even for FP8 runs', + dest='fp8_wgrad') + group.add_argument('--transformer-impl', default='local', + choices=['local', 'transformer_engine'], + help='Which Transformer implementation to use.', + dest='transformer_impl') + group.add_argument('--normalization', default='LayerNorm', + choices=['LayerNorm', 'RMSNorm'], + help='Which normalization technique to use.', + dest='normalization') + + return parser + +def _add_inference_args(parser): + group = parser.add_argument_group(title='inference') + + group.add_argument('--inference-batch-times-seqlen-threshold', + type=int, default=512, + help='During inference, if batch-size times ' + 'sequence-length is smaller than this threshold ' + 'then we will not use pipelining, otherwise we will.') + group.add_argument('--max-tokens-to-oom', + type=int, default=12000, + help='Maximum number of tokens during inference' + 'tokens here is # in prompt + # to generate' + 'Allows us to throw an error before OOM crashes server') + group.add_argument('--output-bert-embeddings', action='store_true', + help='Output Bert embeddings (via mean pooling) from ' + 'model, rather than its binary head output or entire ' + 'hidden batch.') + group.add_argument('--bert-embedder-type', default="megatron", + choices=["megatron", "huggingface"], + help='Select either Megatron or Huggingface as the ' + 'Bert embedder.') + + return parser + + +def _add_retro_args(parser): + group = parser.add_argument_group(title='retro') + + group.add_argument('--retro-workdir', default=None, + help='Retro working directory, which contains the ' + 'preprocessed data for for pretraining. This directory ' + 'is built during preprocessing (see ' + 'tools/retro/README.md), and contains subdirectories ' + 'for the chunk database and pretraining neighbors.') + group.add_argument('--retro-add-retriever', + action='store_true', default=False, + help='Add a retriever to the transformer, for use in ' + 'pretraining a Retro model.') + group.add_argument('--retro-cyclic-train-iters', type=int, default=None, + help='Set number of training iterations for cyclic ' + 'Retro training.') + group.add_argument('--retro-encoder-layers', type=int, default=2, + help='Number of layers to use for the retrieval ' + 'encoder.') + group.add_argument('--retro-encoder-hidden-dropout', + type=float, default=0.1, help='Hidden dropout for ' + 'retrieval encoder.') + group.add_argument('--retro-encoder-attention-dropout', + type=float, default=0.1, help='Attention dropout for ' + 'retrieval encoder.') + group.add_argument("--retro-num-neighbors", type=int, default=2, + help='Number of neighbors to retrieve during ' + 'pretraining.') + group.add_argument("--retro-num-retrieved-chunks", type=int, default=2, + help='Number of chunks to retrieve from the retrieval ' + 'database.') + group.add_argument("--retro-return-doc-ids", action="store_true", + help="Turn this on when preprocessing retro data.") + + # Enforce argument naming convention. + for action in group._group_actions: + prefix = action.dest.split("_")[0] + assert prefix == "retro", \ + "Retro args must be prefixed with '--retro-*', for consistent " \ + "styling. Please fix '%s'." % ", ".join(action.option_strings) + + return parser + + +def _add_network_size_args(parser): + group = parser.add_argument_group(title='network size') + + group.add_argument('--num-layers', type=int, default=None, + help='Number of transformer layers.') + group.add_argument('--encoder-num-layers', type=int, default=None, + help='Number of encoder transformer layers.') + group.add_argument('--decoder-num-layers', type=int, default=None, + help='Number of decoder transformer layers.') + group.add_argument('--hidden-size', type=int, default=None, + help='Tansformer hidden size.') + group.add_argument('--ffn-hidden-size', type=int, default=None, + help='Transformer Feed-Forward Network hidden size. ' + 'This is set to 4*hidden-size if not provided') + group.add_argument('--num-attention-heads', type=int, default=None, + help='Number of transformer attention heads.') + group.add_argument('--kv-channels', type=int, default=None, + help='Projection weights dimension in multi-head ' + 'attention. This is set to ' + ' args.hidden_size // args.num_attention_heads ' + 'if not provided.') + group.add_argument('--group-query-attention', action='store_true', + help='Use group-query attention.') + group.add_argument('--num-query-groups', type=int, default=1) + + group.add_argument('--max-position-embeddings', type=int, default=None, + help='Maximum number of position embeddings to use. ' + 'This is the size of position embedding.') + group.add_argument('--position-embedding-type', type=str, default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Position embedding type.') + group.add_argument('--use-rotary-position-embeddings', action='store_true', + help='Use rotary positional embeddings or not. ' + 'Deprecated: use --position-embedding-type') + group.add_argument('--rotary-percent', type=float, default=1.0, + help='Percent of rotary dimension to use, default 100%%') + group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None, + help='Sequence length interpolation factor for rotary embeddings.') + group.add_argument('--no-position-embedding', + action='store_false', + help='Disable position embedding. Deprecated: use --position-embedding-type', + dest='add_position_embedding') + group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, + help='Pad the vocab size to be divisible by this value.' + 'This is added for computational efficieny reasons.') + group.add_argument('--layernorm-epsilon', type=float, default=1e-5, + help='Layer norm epsilon.') + group.add_argument('--apply-layernorm-1p', action='store_true', + help='Adjust LayerNorm weights such that they are centered ' + 'around zero. This improves numerical stability.') + group.add_argument('--apply-residual-connection-post-layernorm', + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' + 'reasons.') + group.add_argument('--squared-relu', action='store_true', + help='Use squared relu activation instead of default gelu') + group.add_argument('--swiglu', action='store_true', + help='Use gated linear units and SiLU activation instead of default gelu') + group.add_argument('--onnx-safe', type=bool, required=False, + help='Use workarounds for known problems with ' + 'Torch ONNX exporter') + group.add_argument('--bert-no-binary-head', action='store_false', + help='Disable BERT binary head.', + dest='bert_binary_head') + group.add_argument('--num-experts', type=int, default=None, + help='Number of Experts in Switch Transformer (None means no Switch)') + group.add_argument('--untie-embeddings-and-output-weights', action='store_true', + help='Untie embeddings and output weights.'), + group.add_argument('--embedding-weights-in-fp32', action='store_true', + help='Cast word embedding weights to fp32 before embedding fwd.'), + return parser + + +def _add_logging_args(parser): + group = parser.add_argument_group(title='logging') + + group.add_argument('--log-params-norm', action='store_true', + help='If set, calculate and log parameters norm.') + group.add_argument('--log-num-zeros-in-grad', action='store_true', + help='If set, calculate and log the number of zeros in gradient.') + group.add_argument('--timing-log-level', type=int, + default=0, choices=range(0,3), + help='Granularity level to measure and report timing. ' + ' 0: report only iteration time and make sure timing ' + ' does not introduce extra overhead.' + ' 1: report timing for operations that are executed ' + ' very limited times (basically once) during ' + ' each iteration (such as gradient all-reduce) ' + ' 2: report timing for operations that migh be ' + ' executed numerous times during each iteration. ' + 'Note that setting the level to 1 or 2 might ' + 'cause increase in iteration time.') + group.add_argument('--no-barrier-with-level-1-timing', action='store_false', + help='If not set, use barrier with level 1 time ' + 'measurements. Note that this is up to the user ' + 'to make sure calling barrier with their timers ' + 'will not result in hangs. This can happen if for ' + 'example the user adds a level 1 timer that is not ' + 'called by all ranks.', + dest='barrier_with_L1_time') + group.add_argument('--timing-log-option', type=str, default='minmax', + choices=['max', 'minmax', 'all'], + help='Options for logging timing:' + ' max: report the max timing across all ranks' + ' minmax: report min and max timings across all ranks' + ' all: report timings of all ranks.') + group.add_argument('--tensorboard-log-interval', type=int, default=1, + help='Report to tensorboard interval.') + group.add_argument('--tensorboard-queue-size', type=int, default=1000, + help='Size of the tensorboard queue for pending events ' + 'and summaries before one of the ‘add’ calls forces a ' + 'flush to disk.') + group.add_argument('--log-timers-to-tensorboard', action='store_true', + help='If set, write timers to tensorboard.') + group.add_argument('--log-batch-size-to-tensorboard', action='store_true', + help='If set, write batch-size to tensorboard.') + group.add_argument('--no-log-learnig-rate-to-tensorboard', + action='store_false', + help='Disable learning rate logging to tensorboard.', + dest='log_learning_rate_to_tensorboard') + group.add_argument('--no-log-loss-scale-to-tensorboard', + action='store_false', + help='Disable loss-scale logging to tensorboard.', + dest='log_loss_scale_to_tensorboard') + group.add_argument('--log-validation-ppl-to-tensorboard', + action='store_true', + help='If set, write validation perplexity to ' + 'tensorboard.') + group.add_argument('--log-memory-to-tensorboard', + action='store_true', + help='Enable memory logging to tensorboard.') + group.add_argument('--log-world-size-to-tensorboard', + action='store_true', + help='Enable world size logging to tensorboard.') + + return parser + + +def _add_regularization_args(parser): + group = parser.add_argument_group(title='regularization') + + group.add_argument('--attention-dropout', type=float, default=0.1, + help='Post attention dropout probability.') + group.add_argument('--hidden-dropout', type=float, default=0.1, + help='Dropout probability for hidden state transformer.') + group.add_argument('--weight-decay', type=float, default=0.01, + help='Weight decay coefficient for L2 regularization.') + group.add_argument('--start-weight-decay', type=float, + help='Initial weight decay coefficient for L2 regularization.') + group.add_argument('--end-weight-decay', type=float, + help='End of run weight decay coefficient for L2 regularization.') + group.add_argument('--weight-decay-incr-style', type=str, default='constant', + choices=['constant', 'linear', 'cosine'], + help='Weight decay increment function.') + group.add_argument('--clip-grad', type=float, default=1.0, + help='Gradient clipping based on global L2 norm.') + group.add_argument('--adam-beta1', type=float, default=0.9, + help='First coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-beta2', type=float, default=0.999, + help='Second coefficient for computing running averages ' + 'of gradient and its square') + group.add_argument('--adam-eps', type=float, default=1e-08, + help='Term added to the denominator to improve' + 'numerical stability') + group.add_argument('--sgd-momentum', type=float, default=0.9, + help='Momentum factor for sgd') + + return parser + + +def _add_training_args(parser): + group = parser.add_argument_group(title='training') + + group.add_argument('--micro-batch-size', type=int, default=None, + help='Batch size per model instance (local batch size). ' + 'Global batch size is local batch size times data ' + 'parallel size times number of micro batches.') + group.add_argument('--batch-size', type=int, default=None, + help='Old batch size parameter, do not use. ' + 'Use --micro-batch-size instead') + group.add_argument('--global-batch-size', type=int, default=None, + help='Training batch size. If set, it should be a ' + 'multiple of micro-batch-size times data-parallel-size. ' + 'If this value is None, then ' + 'use micro-batch-size * data-parallel-size as the ' + 'global batch size. This choice will result in 1 for ' + 'number of micro-batches.') + group.add_argument('--rampup-batch-size', nargs='*', default=None, + help='Batch size ramp up with the following values:' + ' --rampup-batch-size ' + ' ' + ' ' + 'For example:' + ' --rampup-batch-size 16 8 300000 \ ' + ' --global-batch-size 1024' + 'will start with global batch size 16 and over ' + ' (1024 - 16) / 8 = 126 intervals will increase' + 'the batch size linearly to 1024. In each interval' + 'we will use approximately 300000 / 126 = 2380 samples.') + group.add_argument('--recompute-activations', action='store_true', + help='recompute activation to allow for training ' + 'with larger models, sequences, and batch sizes.') + group.add_argument('--recompute-granularity', type=str, default=None, + choices=['full', 'selective'], + help='Checkpoint activations to allow for training ' + 'with larger models, sequences, and batch sizes. ' + 'It is supported at two granularities 1) full: ' + 'whole transformer layer is recomputed, ' + '2) selective: core attention part of the transformer ' + 'layer is recomputed.') + group.add_argument('--distribute-saved-activations', + action='store_true', + help='If set, distribute recomputed activations ' + 'across model parallel group.') + group.add_argument('--recompute-method', type=str, default=None, + choices=['uniform', 'block'], + help='1) uniform: uniformly divide the total number of ' + 'Transformer layers and recompute the input activation of ' + 'each divided chunk at specified granularity, ' + '2) recompute the input activations of only a set number of ' + 'individual Transformer layers per pipeline stage and do the ' + 'rest without any recomputing at specified granularity' + 'default) do not apply activations recompute to any layers') + group.add_argument('--recompute-num-layers', type=int, default=None, + help='1) uniform: the number of Transformer layers in each ' + 'uniformly divided recompute unit, ' + '2) block: the number of individual Transformer layers ' + 'to recompute within each pipeline stage.') + group.add_argument('--profile', action='store_true', + help='Enable nsys profiling. When using this option, nsys ' + 'options should be specified in commandline. An example ' + 'nsys commandline is `nsys profile -s none -t nvtx,cuda ' + '-o --force-overwrite true ' + '--capture-range=cudaProfilerApi ' + '--capture-range-end=stop`.') + group.add_argument('--profile-step-start', type=int, default=10, + help='Gloable step to start profiling.') + group.add_argument('--profile-step-end', type=int, default=12, + help='Gloable step to stop profiling.') + group.add_argument('--profile-ranks', nargs='+', type=int, default=[0], + help='Global ranks to profile.') + + + # deprecated + group.add_argument('--checkpoint-activations', action='store_true', + help='Checkpoint activation to allow for training ' + 'with larger models, sequences, and batch sizes.') + group.add_argument('--train-iters', type=int, default=None, + help='Total number of iterations to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--train-samples', type=int, default=None, + help='Total number of samples to train over all ' + 'training runs. Note that either train-iters or ' + 'train-samples should be provided.') + group.add_argument('--log-interval', type=int, default=100, + help='Report loss and timing interval.') + group.add_argument('--exit-interval', type=int, default=None, + help='Exit the program after the iteration is divisible ' + 'by this value.') + group.add_argument('--exit-duration-in-mins', type=int, default=None, + help='Exit the program after this many minutes.') + group.add_argument('--exit-signal-handler', action='store_true', + help='Dynamically save the checkpoint and shutdown the ' + 'training if SIGTERM is received') + group.add_argument('--tensorboard-dir', type=str, default=None, + help='Write TensorBoard logs to this directory.') + group.add_argument('--no-masked-softmax-fusion', + action='store_false', + help='Disable fusion of query_key_value scaling, ' + 'masking, and softmax.', + dest='masked_softmax_fusion') + group.add_argument('--no-bias-gelu-fusion', action='store_false', + help='Disable bias and gelu fusion.', + dest='bias_gelu_fusion') + group.add_argument('--no-bias-dropout-fusion', action='store_false', + help='Disable bias and dropout fusion.', + dest='bias_dropout_fusion') + group.add_argument('--use-flash-attn', action='store_true', + help='use FlashAttention implementation of attention. ' + 'https://arxiv.org/abs/2205.14135') + group.add_argument('--disable-bias-linear', action='store_false', + help='Disable bias in the linear layers', + dest='add_bias_linear') + group.add_argument('--optimizer', type=str, default='adam', + choices=['adam', 'sgd'], + help='Optimizer function') + group.add_argument('--dataloader-type', type=str, default=None, + choices=['single', 'cyclic'], + help='Single pass vs multiple pass data loader') + group.add_argument('--no-async-tensor-model-parallel-allreduce', + action='store_false', + help='Disable asynchronous execution of ' + 'tensor-model-parallel all-reduce with weight ' + 'gradient compuation of a column-linear layer.', + dest='async_tensor_model_parallel_allreduce') + group.add_argument('--no-persist-layer-norm', action='store_true', + help='Disable using persistent fused layer norm kernel. ' + 'This kernel supports only a set of hidden sizes. Please ' + 'check persist_ln_hidden_sizes if your hidden ' + 'size is supported.') + group.add_argument('--sequence-parallel', action='store_true', + help='Enable sequence parallel optimization.') + group.add_argument('--no-gradient-accumulation-fusion', + action='store_false', + help='Disable fusing gradient accumulation to weight ' + 'gradient computation of linear layers', + dest='gradient_accumulation_fusion') + return parser + + +def _add_initialization_args(parser): + group = parser.add_argument_group(title='initialization') + + group.add_argument('--seed', type=int, default=1234, + help='Random seed used for python, numpy, ' + 'pytorch, and cuda.') + group.add_argument('--data-parallel-random-init', action='store_true', + help='Enable random initialization of params ' + 'across data parallel ranks') + group.add_argument('--init-method-std', type=float, default=0.02, + help='Standard deviation of the zero mean normal ' + 'distribution used for weight initialization.') + group.add_argument('--init-method-xavier-uniform', action='store_true', + help='Enable Xavier uniform parameter initialization') + + return parser + + +def _add_learning_rate_args(parser): + group = parser.add_argument_group(title='learning rate') + + group.add_argument('--lr', type=float, default=None, + help='Initial learning rate. Depending on decay style ' + 'and initial warmup, the learing rate at each ' + 'iteration would be different.') + group.add_argument('--lr-decay-style', type=str, default='linear', + choices=['constant', 'linear', 'cosine', 'inverse-square-root'], + help='Learning rate decay function.') + group.add_argument('--lr-decay-iters', type=int, default=None, + help='number of iterations to decay learning rate over,' + ' If None defaults to `--train-iters`') + group.add_argument('--lr-decay-samples', type=int, default=None, + help='number of samples to decay learning rate over,' + ' If None defaults to `--train-samples`') + group.add_argument('--lr-warmup-fraction', type=float, default=None, + help='fraction of lr-warmup-(iters/samples) to use ' + 'for warmup (as a float)') + group.add_argument('--lr-warmup-iters', type=int, default=0, + help='number of iterations to linearly warmup ' + 'learning rate over.') + group.add_argument('--lr-warmup-samples', type=int, default=0, + help='number of samples to linearly warmup ' + 'learning rate over.') + group.add_argument('--lr-warmup-init', type=float, default=0.0, + help='Initial value for learning rate warmup. The ' + 'scheduler starts warmup from this value.') + group.add_argument('--warmup', type=int, default=None, + help='Old lr warmup argument, do not use. Use one of the' + '--lr-warmup-* arguments above') + group.add_argument('--min-lr', type=float, default=0.0, + help='Minumum value for learning rate. The scheduler' + 'clip values below this threshold.') + group.add_argument('--override-opt_param-scheduler', action='store_true', + help='Reset the values of the scheduler (learning rate,' + 'warmup iterations, minimum learning rate, maximum ' + 'number of iterations, and decay style from input ' + 'arguments and ignore values from checkpoints. Note' + 'that all the above values will be reset.') + group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true', + help='Use checkpoint to set the values of the scheduler ' + '(learning rate, warmup iterations, minimum learning ' + 'rate, maximum number of iterations, and decay style ' + 'from checkpoint and ignore input arguments.') + + return parser + + +def _add_checkpointing_args(parser): + group = parser.add_argument_group(title='checkpointing') + + group.add_argument('--save', type=str, default=None, + help='Output directory to save checkpoints to.') + group.add_argument('--save-interval', type=int, default=None, + help='Number of iterations between checkpoint saves.') + group.add_argument('--no-save-optim', action='store_true', default=None, + help='Do not save current optimizer.') + group.add_argument('--no-save-rng', action='store_true', default=None, + help='Do not save current rng state.') + group.add_argument('--load', type=str, default=None, + help='Directory containing a model checkpoint.') + group.add_argument('--no-load-optim', action='store_true', default=None, + help='Do not load optimizer when loading checkpoint.') + group.add_argument('--no-load-rng', action='store_true', default=None, + help='Do not load rng state when loading checkpoint.') + group.add_argument('--finetune', action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. ' + 'Assumed when loading a release checkpoint.') + group.add_argument('--no-initialization', action='store_false', + help='Do not perform initialization when building model, ' + 'can reduce startup time when definitely loading from a ' + 'checkpoint', + dest='perform_initialization') + group.add_argument('--use-checkpoint-args', action='store_true', + help='Override any command line arguments with arguments ' + 'from the checkpoint') + group.add_argument('--exit-on-missing-checkpoint', action='store_true', + help="If '--load' is set, but checkpoint is not found " + "(e.g., path typo), then exit instead of random " + "initialization.") + + return parser + + +def _add_mixed_precision_args(parser): + group = parser.add_argument_group(title='mixed precision') + + group.add_argument('--fp16', action='store_true', + help='Run model in fp16 mode.') + group.add_argument('--bf16', action='store_true', + help='Run model in bfloat16 mode.') + group.add_argument('--loss-scale', type=float, default=None, + help='Static loss scaling, positive power of 2 ' + 'values can improve fp16 convergence. If None, dynamic' + 'loss scaling is used.') + group.add_argument('--initial-loss-scale', type=float, default=2**32, + help='Initial loss-scale for dynamic loss scaling.') + group.add_argument('--min-loss-scale', type=float, default=1.0, + help='Minimum loss scale for dynamic loss scale.') + group.add_argument('--loss-scale-window', type=float, default=1000, + help='Window over which to raise/lower dynamic scale.') + group.add_argument('--hysteresis', type=int, default=2, + help='hysteresis for dynamic loss scaling') + group.add_argument('--fp32-residual-connection', action='store_true', + help='Move residual connections to fp32.') + group.add_argument('--no-query-key-layer-scaling', action='store_false', + help='Do not scale Q * K^T by 1 / layer-number.', + dest='apply_query_key_layer_scaling') + group.add_argument('--attention-softmax-in-fp32', action='store_true', + help='Run attention masking and softmax in fp32. ' + 'This flag is ignored unless ' + '--no-query-key-layer-scaling is specified.') + group.add_argument('--accumulate-allreduce-grads-in-fp32', + action='store_true', + help='Gradient accumulation and all-reduce in fp32.') + group.add_argument('--fp16-lm-cross-entropy', action='store_true', + help='Move the cross entropy unreduced loss calculation' + 'for lm head to fp16.') + + return parser + + +def _add_distributed_args(parser): + group = parser.add_argument_group(title='distributed') + + group.add_argument('--tensor-model-parallel-size', type=int, default=1, + help='Degree of tensor model parallelism.') + group.add_argument('--pipeline-model-parallel-size', type=int, default=1, + help='Degree of pipeline model parallelism.') + group.add_argument('--pipeline-model-parallel-split-rank', + type=int, default=None, + help='Rank where encoder and decoder should be split.') + group.add_argument('--model-parallel-size', type=int, default=None, + help='Old model parallel argument, do not use. Use ' + '--tensor-model-parallel-size instead.') + group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, + help='Number of layers per virtual pipeline stage') + group.add_argument('--overlap-p2p-communication', + action='store_true', + help='overlap pipeline parallel communication with forward and backward chunks', + dest='overlap_p2p_comm') + group.add_argument('--distributed-backend', default='nccl', + choices=['nccl', 'gloo'], + help='Which backend to use for distributed training.') + group.add_argument('--distributed-timeout-minutes', type=int, default=10, + help='Timeout minutes for torch.distributed.') + group.add_argument('--DDP-impl', default='local', + choices=['local', 'torch'], + help='which DistributedDataParallel implementation ' + 'to use.') + group.add_argument('--no-contiguous-buffers-in-local-ddp', + action='store_false', help='If set, dont use ' + 'contiguous buffer in local DDP.', + dest='use_contiguous_buffers_in_local_ddp') + group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', + help='Use scatter/gather to optimize communication of tensors in pipeline', + dest='scatter_gather_tensors_in_pipeline') + group.add_argument('--use-ring-exchange-p2p', action='store_true', + default=False, help='If set, use custom-built ring exchange ' + 'for p2p communications. Note that this option will require ' + 'a custom built image that support ring-exchange p2p.') + group.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + group.add_argument('--lazy-mpu-init', type=bool, required=False, + help='If set to True, initialize_megatron() ' + 'skips DDP initialization and returns function to ' + 'complete it instead.Also turns on ' + '--use-cpu-initialization flag. This is for ' + 'external DDP manager.' ) + group.add_argument('--use-cpu-initialization', action='store_true', + default=None, help='If set, affine parallel weights ' + 'initialization uses CPU' ) + group.add_argument('--empty-unused-memory-level', default=0, type=int, + choices=[0, 1, 2], + help='Call torch.cuda.empty_cache() each iteration ' + '(training and eval), to reduce fragmentation.' + '0=off, 1=moderate, 2=aggressive.') + group.add_argument('--standalone-embedding-stage', action='store_true', + default=False, help='If set, *input* embedding layer ' + 'is placed on its own pipeline stage, without any ' + 'transformer layers. (For T5, this flag currently only ' + 'affects the encoder embedding.)') + group.add_argument('--use-distributed-optimizer', action='store_true', + help='Use distributed optimizer.') + + return parser + + +def _add_validation_args(parser): + group = parser.add_argument_group(title='validation') + + group.add_argument('--eval-iters', type=int, default=100, + help='Number of iterations to run for evaluation' + 'validation/test for.') + group.add_argument('--eval-interval', type=int, default=1000, + help='Interval between running evaluation on ' + 'validation set.') + group.add_argument('--skip-train', action='store_true', + default=False, help='If set, bypass the training loop, ' + 'optionally do evaluation for validation/test, and exit.') + + return parser + + +def _add_data_args(parser): + group = parser.add_argument_group(title='data and dataloader') + + group.add_argument('--data-path', nargs='*', default=None, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ... It is used with --split when a ' + 'single dataset used for all three: train, valid ' + 'and test. It is exclusive to the other ' + '--*-data-path args') + group.add_argument('--split', type=str, default='969, 30, 1', + help='Comma-separated list of proportions for training,' + ' validation, and test split. For example the split ' + '`90,5,5` will use 90%% of data for training, 5%% for ' + 'validation and 5%% for test.') + group.add_argument('--train-data-path', nargs='*', default=None, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ...') + group.add_argument('--valid-data-path', nargs='*', default=None, + help='Path to the validation dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ...') + group.add_argument('--test-data-path', nargs='*', default=None, + help='Path to the test dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ...') + group.add_argument('--data-cache-path', default=None, + help='Path to a directory to hold cached index files.') + + group.add_argument('--vocab-size', type=int, default=None, + help='Size of vocab before EOD or padding.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file.') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file.') + group.add_argument('--vocab-extra-ids', type=int, default=0, + help='Number of additional vocabulary tokens. ' + 'They are used for span masking in the T5 model') + group.add_argument('--seq-length', type=int, default=None, + help='Maximum sequence length to process.') + group.add_argument('--encoder-seq-length', type=int, default=None, + help='Maximum encoder sequence length to process.' + 'This should be exclusive of --seq-length') + group.add_argument('--decoder-seq-length', type=int, default=None, + help="Maximum decoder sequence length to process.") + group.add_argument('--retriever-seq-length', type=int, default=256, + help='Maximum sequence length for the biencoder model ' + 'for retriever') + group.add_argument('--sample-rate', type=float, default=1.0, + help='sample rate for training data. Supposed to be 0 ' + ' < sample_rate < 1') + group.add_argument('--mask-prob', type=float, default=0.15, + help='Probability of replacing a token with mask.') + group.add_argument('--short-seq-prob', type=float, default=0.1, + help='Probability of producing a short sequence.') + group.add_argument('--mmap-warmup', action='store_true', + help='Warm up mmap files.') + group.add_argument('--num-workers', type=int, default=2, + help="Dataloader number of workers.") + group.add_argument('--tokenizer-type', type=str, + default=None, + choices=['BertWordPieceLowerCase', + 'BertWordPieceCase', + 'GPT2BPETokenizer', + 'SentencePieceTokenizer', + 'GPTSentencePieceTokenizer', + 'NullTokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--tokenizer-model', type=str, default=None, + help='Sentencepiece tokenizer model.') + group.add_argument('--data-impl', type=str, default='infer', + choices=['mmap', 'infer'], + help='Implementation of indexed datasets.') + group.add_argument('--reset-position-ids', action='store_true', + help='Reset posistion ids after end-of-document token.') + group.add_argument('--reset-attention-mask', action='store_true', + help='Reset self attention maske after ' + 'end-of-document token.') + group.add_argument('--eod-mask-loss', action='store_true', + help='Mask loss for the end of document tokens.') + + return parser + + +def _add_autoresume_args(parser): + group = parser.add_argument_group(title='autoresume') + + group.add_argument('--adlr-autoresume', action='store_true', + help='Enable autoresume on adlr cluster.') + group.add_argument('--adlr-autoresume-interval', type=int, default=1000, + help='Intervals over which check for autoresume' + 'termination signal') + + return parser + + +def _add_biencoder_args(parser): + group = parser.add_argument_group(title='biencoder') + + # network size + group.add_argument('--ict-head-size', type=int, default=None, + help='Size of block embeddings to be used in ICT and ' + 'REALM (paper default: 128)') + group.add_argument('--biencoder-projection-dim', type=int, default=0, + help='Size of projection head used in biencoder (paper' + ' default: 128)') + group.add_argument('--biencoder-shared-query-context-model', action='store_true', + help='Whether to share the parameters of the query ' + 'and context models or not') + + # checkpointing + group.add_argument('--ict-load', type=str, default=None, + help='Directory containing an ICTBertModel checkpoint') + group.add_argument('--bert-load', type=str, default=None, + help='Directory containing an BertModel checkpoint ' + '(needed to start ICT and REALM)') + + # data + group.add_argument('--titles-data-path', type=str, default=None, + help='Path to titles dataset used for ICT') + group.add_argument('--query-in-block-prob', type=float, default=0.1, + help='Probability of keeping query in block for ' + 'ICT dataset') + group.add_argument('--use-one-sent-docs', action='store_true', + help='Whether to use one sentence documents in ICT') + group.add_argument('--evidence-data-path', type=str, default=None, + help='Path to Wikipedia Evidence frm DPR paper') + + # training + group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, + default=[], help="Which top-k accuracies to report " + "(e.g. '1 5 20')") + group.add_argument('--retriever-score-scaling', action='store_true', + help='Whether to scale retriever scores by inverse ' + 'square root of hidden size') + + # faiss index + group.add_argument('--block-data-path', type=str, default=None, + help='Where to save/load BlockData to/from') + group.add_argument('--embedding-path', type=str, default=None, + help='Where to save/load Open-Retrieval Embedding' + ' data to/from') + + # indexer + group.add_argument('--indexer-batch-size', type=int, default=128, + help='How large of batches to use when doing indexing ' + 'jobs') + group.add_argument('--indexer-log-interval', type=int, default=1000, + help='After how many batches should the indexer ' + 'report progress') + return parser + + +def _add_vision_args(parser): + group = parser.add_argument_group(title="vision") + + # general vision arguements + group.add_argument('--num-classes', type=int, default=1000, + help='num of classes in vision classificaiton task') + group.add_argument('--img-h', type=int, default=224, + help='Image height for vision classification task') + group.add_argument('--img-w', type=int, default=224, + help='Image height for vision classification task') + group.add_argument('--num-channels', type=int, default=3, + help='Number of channels in input image data') + group.add_argument('--patch-dim', type=int, default=16, + help='patch dimension') + group.add_argument('--classes-fraction', type=float, default=1.0, + help='training with fraction of classes.') + group.add_argument('--data-per-class-fraction', type=float, default=1.0, + help='training with fraction of data per class.') + group.add_argument('--no-data-sharding', action='store_false', + help='Disable data sharding.', + dest='data_sharding') + group.add_argument('--head-lr-mult', type=float, default=1.0, + help='learning rate multiplier for head during finetuning') + + # pretraining type and backbone selection` + group.add_argument('--vision-pretraining', action='store_true', + help='flag to indicate vision pretraining') + group.add_argument('--vision-pretraining-type', type=str, default='classify', + choices=['classify', 'inpaint', 'dino'], + help='pretraining objectives') + group.add_argument('--vision-backbone-type', type=str, default='vit', + choices=['vit', 'mit', 'swin'], + help='backbone types types') + group.add_argument('--swin-backbone-type', type=str, default='tiny', + choices=['tiny', 'base', 'h3'], + help='pretraining objectives') + + # inpainting arguments + group.add_argument('--mask-type', type=str, default='random', + choices=['random', 'row'], + help='mask types') + group.add_argument('--mask-factor', type=float, default=1.0, + help='mask size scaling parameter') + + # dino arguments + group.add_argument('--iter-per-epoch', type=int, default=1250, + help='iterations per epoch') + group.add_argument('--dino-local-img-size', type=int, default=96, + help='Image size for vision classification task') + group.add_argument('--dino-local-crops-number', type=int, default=10, + help='Number of local crops') + group.add_argument('--dino-head-hidden-size', type=int, default=2048, + help='Hidden dimension size in dino head') + group.add_argument('--dino-bottleneck-size', type=int, default=256, + help='Bottle neck dimension in dino head ') + group.add_argument('--dino-freeze-last-layer', type=float, default=1, + help='Freezing last layer weights') + group.add_argument('--dino-norm-last-layer', action='store_true', + help='Disable Norm in last layer.') + group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, + help='warump teacher temperature') + group.add_argument('--dino-teacher-temp', type=float, default=0.07, + help='teacher temperature') + group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, + help='warmup teacher temperaure epochs') + + return parser diff --git a/training/DeepSpeed-Domino/megatron/checkpointing.py b/training/DeepSpeed-Domino/megatron/checkpointing.py new file mode 100644 index 000000000..e88b58513 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/checkpointing.py @@ -0,0 +1,686 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Input/output checkpointing.""" + +import os +import random +import sys +import numpy as np + +import torch + +from megatron import update_num_microbatches +from megatron.core import mpu, tensor_parallel +from .global_vars import get_args +from .utils import (unwrap_model, + print_rank_0) + + +_CHECKPOINT_VERSION = None + + +def set_checkpoint_version(value): + global _CHECKPOINT_VERSION + if _CHECKPOINT_VERSION is not None: + assert _CHECKPOINT_VERSION == value, \ + "checkpoint versions do not match" + _CHECKPOINT_VERSION = value + + +def get_checkpoint_version(): + global _CHECKPOINT_VERSION + return _CHECKPOINT_VERSION + + +def check_checkpoint_args(checkpoint_args): + """Ensure fixed arguments for a model are the same for the input + arguments and the one retrieved from checkpoint.""" + args = get_args() + + def _compare(arg_name, old_arg_name=None, default=None): + if old_arg_name is not None: + ckpt_arg_name = old_arg_name + else: + ckpt_arg_name = arg_name + if default is not None: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default) + else: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name) + args_value = getattr(args, arg_name) + error_message = '{} value from checkpoint ({}) is not equal to the ' \ + 'input argument value ({}).'.format( + arg_name, checkpoint_value, args_value) + assert checkpoint_value == args_value, error_message + + _compare('num_layers') + _compare('hidden_size') + _compare('num_attention_heads') + _compare('add_position_embedding', default=True) + if args.vocab_file: + _compare('max_position_embeddings') + _compare('make_vocab_size_divisible_by') + _compare('padded_vocab_size') + _compare('tokenizer_type') + if args.data_parallel_random_init: + _compare('data_parallel_random_init') + if get_checkpoint_version() < 3.0: + _compare('tensor_model_parallel_size', + old_arg_name='model_parallel_size') + if get_checkpoint_version() >= 3.0: + _compare('tensor_model_parallel_size') + _compare('pipeline_model_parallel_size') + + +def ensure_directory_exists(filename): + """Build filename's path if it does not already exists.""" + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok = True) + + +def get_checkpoint_name(checkpoints_path, iteration, release=False, + pipeline_parallel=None, + tensor_rank=None, pipeline_rank=None): + """Determine the directory name for this rank's checkpoint.""" + if release: + directory = 'release' + else: + directory = 'iter_{:07d}'.format(iteration) + + # Use both the tensor and pipeline MP rank. + if pipeline_parallel is None: + pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) + if tensor_rank is None: + tensor_rank = mpu.get_tensor_model_parallel_rank() + if pipeline_rank is None: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + + # Use both the tensor and pipeline MP rank. If using the distributed + # optimizer, then the optimizer's path must additionally include the + # data parallel rank. + if not pipeline_parallel: + common_path = os.path.join(checkpoints_path, directory, + f'mp_rank_{tensor_rank:02d}') + else: + common_path = os.path.join(checkpoints_path, directory, + f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') + + return os.path.join(common_path, "model_optim_rng.pt") + + +def get_distributed_optimizer_checkpoint_name(model_checkpoint_name): + return os.path.join(os.path.dirname(model_checkpoint_name), + "distrib_optim.pt") + + +def find_checkpoint_rank_0(checkpoints_path, iteration, release=False): + """Finds the checkpoint for rank 0 without knowing if we are using + pipeline parallelism or not. + + Since the checkpoint naming scheme changes if pipeline parallelism + is present, we need to look for both naming schemes if we don't + know if the checkpoint has pipeline parallelism. + """ + + # Look for checkpoint with no pipelining + filename = get_checkpoint_name(checkpoints_path, iteration, release, + pipeline_parallel=False, + tensor_rank=0, pipeline_rank=0) + if os.path.isfile(filename): + return filename + + # Look for checkpoint with pipelining + filename = get_checkpoint_name(checkpoints_path, iteration, release, + pipeline_parallel=True, + tensor_rank=0, pipeline_rank=0) + if os.path.isfile(filename): + return filename + + return None, None + + +def get_checkpoint_tracker_filename(checkpoints_path): + + """Tracker file rescords the latest chckpoint during + training to restart from.""" + return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') + + +def read_metadata(tracker_filename): + # Read the tracker file and either set the iteration or + # mark it as a release checkpoint. + iteration = 0 + release = False + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + release = metastring == 'release' + if not release: + print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( + tracker_filename)) + sys.exit() + assert iteration > 0 or release, 'error parsing metadata file {}'.format( + tracker_filename) + + # Get the max iteration retrieved across the ranks. + if torch.distributed.is_initialized(): + iters_cuda = torch.cuda.LongTensor([iteration]) + torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) + max_iter = iters_cuda[0].item() + + # We should now have all the same iteration. + # If not, print a warning and chose the maximum + # iteration across all ranks. + if iteration != max_iter: + rank = torch.distributed.get_rank() + print('WARNING: on rank {} found iteration {} in the ' + 'metadata while max iteration across the ranks ' + 'is {}, replacing it with max iteration.'.format( + rank, iteration, max_iter), flush=True) + else: + # When loading a checkpoint outside of training (for example, + # when editing it), we might not have torch distributed + # initialized, in this case, just assume we have the latest + max_iter = iteration + return max_iter, release + + +def get_rng_state(): + """ collect rng state across data parallel ranks """ + args = get_args() + rng_state = { + 'random_rng_state': random.getstate(), + 'np_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state(), + 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()} + + rng_state_list = None + if torch.distributed.is_initialized() and \ + mpu.get_data_parallel_world_size() > 1 and \ + args.data_parallel_random_init: + rng_state_list = \ + [None for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + rng_state_list, + rng_state, + group=mpu.get_data_parallel_group()) + else: + rng_state_list = [rng_state] + + return rng_state_list + + +def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): + """Save a model checkpoint.""" + args = get_args() + + # Only rank zero of the data parallel writes to the disk. + model = unwrap_model(model) + + print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( + iteration, args.save)) + + # Collect rng state across data parallel ranks. + rng_state = get_rng_state() + + # Checkpoint name. + checkpoint_name = get_checkpoint_name(args.save, iteration) + + # Save distributed optimizer's custom parameter state. + if args.use_distributed_optimizer: + optim_checkpoint_name = \ + get_distributed_optimizer_checkpoint_name(checkpoint_name) + ensure_directory_exists(optim_checkpoint_name) + optimizer.save_parameter_state(optim_checkpoint_name) + + # Collect args, model, RNG. + if not torch.distributed.is_initialized() \ + or mpu.get_data_parallel_rank() == 0: + + # Arguments, iteration, and model. + state_dict = {} + state_dict['args'] = args + state_dict['checkpoint_version'] = 3.0 + state_dict['iteration'] = iteration + if len(model) == 1: + state_dict['model'] = model[0].state_dict_for_save_checkpoint() + else: + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + state_dict['model%d' % i] = \ + model[i].state_dict_for_save_checkpoint() + + # Optimizer stuff. + if not args.no_save_optim: + if optimizer is not None: + state_dict['optimizer'] = optimizer.state_dict() + if opt_param_scheduler is not None: + state_dict['opt_param_scheduler'] = \ + opt_param_scheduler.state_dict() + + # RNG states. + if not args.no_save_rng: + state_dict["rng_state"] = rng_state + + # Save. + ensure_directory_exists(checkpoint_name) + torch.save(state_dict, checkpoint_name) + + # Wait so everyone is done (necessary) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \ + .format(iteration, args.save)) + + # And update the latest iteration + if not torch.distributed.is_initialized() \ + or torch.distributed.get_rank() == 0: + tracker_filename = get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + + # Wait so everyone is done (not necessary) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def _transpose_first_dim(t, num_splits, num_splits_first, model): + input_shape = t.size() + # We use a self_attention module but the values extracted aren't + # specific to self attention so should work for cross attention as well + while hasattr(model, 'module'): + model = model.module + attention_module = model.language_model.encoder.layers[0].self_attention + hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head + num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition + if num_splits_first: + """[num_splits * np * hn, h] + -->(view) [num_splits, np, hn, h] + -->(tranpose) [np, num_splits, hn, h] + -->(view) [np * num_splits * hn, h] """ + + intermediate_shape = \ + (num_splits, num_attention_heads_per_partition, + hidden_size_per_attention_head) + input_shape[1:] + + t = t.view(*intermediate_shape) + t = t.transpose(0, 1).contiguous() + else: + """[np * hn * num_splits, h] + -->(view) [np, hn, num_splits, h] + -->(tranpose) [np, num_splits, hn, h] + -->(view) [np * num_splits * hn, h] """ + + intermediate_shape = \ + (num_attention_heads_per_partition, + hidden_size_per_attention_head, num_splits) +\ + input_shape[1:] + + t = t.view(*intermediate_shape) + t = t.transpose(1, 2).contiguous() + t = t.view(*input_shape) + + return t + + +def fix_query_key_value_ordering(model, checkpoint_version): + """Fix up query/key/value matrix ordering if checkpoint + version is smaller than 2.0 + """ + if checkpoint_version < 2.0: + if isinstance(model, list): + assert len(model)==1 + model = model[0] + for name, param in model.named_parameters(): + if name.endswith(('.query_key_value.weight', '.query_key_value.bias')): + if checkpoint_version == 0: + fixed_param = _transpose_first_dim(param.data, 3, True, model) + elif checkpoint_version == 1.0: + fixed_param = _transpose_first_dim(param.data, 3, False, model) + else: + print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") + sys.exit() + param.data.copy_(fixed_param) + if name.endswith(('.key_value.weight', '.key_value.bias')): + if checkpoint_version == 0: + fixed_param = _transpose_first_dim(param.data, 2, True, model) + elif checkpoint_version == 1.0: + fixed_param = _transpose_first_dim(param.data, 2, False, model) + else: + print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") + sys.exit() + param.data.copy_(fixed_param) + print_rank_0(" succesfully fixed query-key-values ordering for" + " checkpoint version {}".format(checkpoint_version)) + + +def _load_base_checkpoint(load_dir, rank0=False): + """ Load the base state_dict from the given directory + + If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. + """ + + # Read the tracker file and set the iteration. + tracker_filename = get_checkpoint_tracker_filename(load_dir) + + # If no tracker file, return nothing + if not os.path.isfile(tracker_filename): + if not rank0: + print_rank_0('WARNING: could not find the metadata file {} '.format( + tracker_filename)) + print_rank_0(' will not load any checkpoints and will start from ' + 'random') + return None, "", False + + # Otherwise, read the tracker file and either set the iteration or + # mark it as a release checkpoint. + iteration, release = read_metadata(tracker_filename) + + # Checkpoint. + if rank0: + checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release) + else: + checkpoint_name = get_checkpoint_name(load_dir, iteration, release) + if release: + print_rank_0(f' loading release checkpoint from {load_dir}') + else: + print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}') + + # Load the checkpoint. + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except ModuleNotFoundError: + from megatron.fp16_deprecated import loss_scaler + # For backward compatibility. + if not rank0: + print_rank_0(' > deserializing using the old code structure ...') + sys.modules['fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + state_dict = torch.load(checkpoint_name, map_location='cpu') + sys.modules.pop('fp16.loss_scaler', None) + sys.modules.pop('megatron.fp16.loss_scaler', None) + except BaseException as e: + print_rank_0('could not load the checkpoint') + print_rank_0(e) + sys.exit() + + return state_dict, checkpoint_name, release + + +def load_args_from_checkpoint(args, load_arg='load'): + """Set required arguments from the checkpoint specified in the + arguments. + + Will overwrite arguments that have a non-None default value, but + will leave any arguments that default to None as set. + + Returns the same args NameSpace with the new values added/updated. + + If no checkpoint is specified in args, or if the checkpoint is + there but invalid, the arguments will not be modified + + """ + load_dir = getattr(args, load_arg) + + if load_dir is None: + print_rank_0('No load directory specified, using provided arguments.') + return args + + state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True) + + # Args. + if not state_dict: + print_rank_0('Checkpoint not found to provide arguments, using provided arguments.') + return args + + if 'args' not in state_dict: + print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.') + return args + + checkpoint_args = state_dict['args'] + checkpoint_version = state_dict.get('checkpoint_version', 0) + args.iteration = state_dict['iteration'] + + # One-off conversion for foundation models + if hasattr(checkpoint_args, 'disable_bias_linear'): + setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear')) + + def _set_arg(arg_name, old_arg_name=None, force=False): + if not force and getattr(args, arg_name, None) is not None: + return + + if old_arg_name is not None: + checkpoint_value = getattr(checkpoint_args, old_arg_name, None) + else: + checkpoint_value = getattr(checkpoint_args, arg_name, None) + + if checkpoint_value is not None: + print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint") + setattr(args, arg_name, checkpoint_value) + else: + print_rank_0(f"Checkpoint did not provide arguments {arg_name}") + + _set_arg('num_layers') + _set_arg('hidden_size') + _set_arg('ffn_hidden_size') + _set_arg('seq_length') + _set_arg('num_attention_heads') + _set_arg('kv_channels') + _set_arg('max_position_embeddings') + _set_arg('position_embedding_type', force=True) + _set_arg('add_position_embedding', force=True) + _set_arg('use_rotary_position_embeddings', force=True) + _set_arg('rotary_percent', force=True) + _set_arg('add_bias_linear', force=True) + _set_arg('swiglu', force=True) + _set_arg('untie_embeddings_and_output_weights', force=True) + _set_arg('apply_layernorm_1p', force=True) + _set_arg('tokenizer_type') + _set_arg('padded_vocab_size') + if checkpoint_version < 3.0: + _set_arg('tensor_model_parallel_size', + 'model_parallel_size') + else: + _set_arg('tensor_model_parallel_size', force=True) + _set_arg('pipeline_model_parallel_size', force=True) + _set_arg('virtual_pipeline_model_parallel_size', force=True) + _set_arg('num_layers_per_virtual_pipeline_stage') + return args, checkpoint_args + + +def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): + """Load a model checkpoint and return the iteration. + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` of the checkpoint match the names of + parameters and buffers in model. + """ + args = get_args() + load_dir = getattr(args, load_arg) + + model = unwrap_model(model) + + state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False) + + # Checkpoint not loaded. + if state_dict is None: + + # Conditionally exit at this point. + if args.exit_on_missing_checkpoint: + print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<") + torch.distributed.barrier() + sys.exit() + + # Iteration defaults to 0. + return 0 + + # Set checkpoint version. + set_checkpoint_version(state_dict.get('checkpoint_version', 0)) + + # Set iteration. + if args.finetune or release: + iteration = 0 + else: + try: + iteration = state_dict['iteration'] + except KeyError: + try: # Backward compatible with older checkpoints + iteration = state_dict['total_iters'] + except KeyError: + print_rank_0('A metadata file exists but unable to load ' + 'iteration from checkpoint {}, exiting'.format( + checkpoint_name)) + sys.exit() + + # Check arguments. + assert args.consumed_train_samples == 0 + assert args.consumed_valid_samples == 0 + if 'args' in state_dict and not args.finetune: + checkpoint_args = state_dict['args'] + check_checkpoint_args(checkpoint_args) + args.consumed_train_samples = getattr(checkpoint_args, + 'consumed_train_samples', 0) + update_num_microbatches(consumed_samples=args.consumed_train_samples) + args.consumed_valid_samples = getattr(checkpoint_args, + 'consumed_valid_samples', 0) + else: + print_rank_0('could not find arguments in the checkpoint ...') + + # Model. + if len(model) == 1: + model[0].load_state_dict(state_dict['model'], strict=strict) + else: + for i in range(len(model)): + mpu.set_virtual_pipeline_model_parallel_rank(i) + model[i].load_state_dict(state_dict['model%d' % i], strict=strict) + + # Fix up query/key/value matrix ordering if needed. + checkpoint_version = get_checkpoint_version() + print_rank_0(f' checkpoint version {checkpoint_version}') + fix_query_key_value_ordering(model, checkpoint_version) + + # Optimizer. + if not release and not args.finetune and not args.no_load_optim: + try: + # Load state dict. + if optimizer is not None: + optimizer.load_state_dict(state_dict['optimizer']) + + # Load distributed optimizer's custom parameter state. + if args.use_distributed_optimizer: + tracker_filename = get_checkpoint_tracker_filename(load_dir) + iteration, release = read_metadata(tracker_filename) + model_checkpoint_name = \ + get_checkpoint_name(load_dir, iteration, release) + optim_checkpoint_name = \ + get_distributed_optimizer_checkpoint_name( + model_checkpoint_name) + optimizer.load_parameter_state(optim_checkpoint_name) + + # Load scheduler. + if opt_param_scheduler is not None: + if 'lr_scheduler' in state_dict: # backward compatbility + opt_param_scheduler.load_state_dict(state_dict['lr_scheduler']) + else: + opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler']) + except KeyError: + print_rank_0('Unable to load optimizer from checkpoint {}. ' + 'Specify --no-load-optim or --finetune to prevent ' + 'attempting to load the optimizer state, ' + 'exiting ...'.format(checkpoint_name)) + sys.exit() + else: + if (args.fp16 or args.bf16) and optimizer is not None: + optimizer.reload_model_params() + + # rng states. + if not release and not args.finetune and not args.no_load_rng: + try: + if 'rng_state' in state_dict: + # access rng_state for data parallel rank + if args.data_parallel_random_init: + + rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] + else: + rng_state = state_dict['rng_state'][0] + random.setstate(rng_state['random_rng_state']) + np.random.set_state(rng_state['np_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + # Check for empty states array + if not rng_state['rng_tracker_states']: + raise KeyError + tensor_parallel.get_cuda_rng_tracker().set_states( + rng_state['rng_tracker_states']) + else: # backward compatability + random.setstate(state_dict['random_rng_state']) + np.random.set_state(state_dict['np_rng_state']) + torch.set_rng_state(state_dict['torch_rng_state']) + torch.cuda.set_rng_state(state_dict['cuda_rng_state']) + # Check for empty states array + if not state_dict['rng_tracker_states']: + raise KeyError + tensor_parallel.get_cuda_rng_tracker().set_states( + state_dict['rng_tracker_states']) + except KeyError: + print_rank_0('Unable to load rng state from checkpoint {}. ' + 'Specify --no-load-rng or --finetune to prevent ' + 'attempting to load the rng state, ' + 'exiting ...'.format(checkpoint_name)) + sys.exit() + + # Some utilities want to load a checkpoint without distributed being initialized + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + print_rank_0(f' successfully loaded checkpoint from {args.load} ' + f'at iteration {iteration}') + + return iteration + + +def load_biencoder_checkpoint(model, only_query_model=False, + only_context_model=False, custom_load_path=None): + """ + selectively load retrieval models for indexing/retrieving + from saved checkpoints + """ + + args = get_args() + + model = unwrap_model(model) + + load_path = custom_load_path if custom_load_path is not None else args.load + + tracker_filename = get_checkpoint_tracker_filename(load_path) + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + + checkpoint_name = get_checkpoint_name(load_path, iteration, + args.use_distributed_optimizer, + release=False) + + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + state_dict = torch.load(checkpoint_name, map_location='cpu') + ret_state_dict = state_dict['model'] + + if only_query_model: + ret_state_dict.pop('context_model') + if only_context_model: + ret_state_dict.pop('query_model') + + assert len(model) == 1 + model[0].load_state_dict(ret_state_dict) + torch.distributed.barrier() + + if mpu.get_data_parallel_rank() == 0: + print(' successfully loaded {}'.format(checkpoint_name)) + + return model diff --git a/training/DeepSpeed-Domino/megatron/core/README.md b/training/DeepSpeed-Domino/megatron/core/README.md new file mode 100644 index 000000000..0c8c61738 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/README.md @@ -0,0 +1 @@ +Megatron Core is a library for efficient and scalable training of transformer based models. diff --git a/training/DeepSpeed-Domino/megatron/core/__init__.py b/training/DeepSpeed-Domino/megatron/core/__init__.py new file mode 100644 index 000000000..25a663c0c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/__init__.py @@ -0,0 +1,11 @@ +import megatron.core.parallel_state +import megatron.core.tensor_parallel +import megatron.core.utils + +from .inference_params import InferenceParams +from .model_parallel_config import ModelParallelConfig + +# Alias parallel_state as mpu, its legacy name +mpu = parallel_state + +__all__ = ["parallel_state", "tensor_parallel", "utils", "InferenceParams", "ModelParallelConfig"] diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/__init__.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/__init__.py new file mode 100644 index 000000000..70bc6869b --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +from .core import check_is_distributed_checkpoint +from .mapping import LocalNonpersitentObject, ShardedTensor +from .serialization import load, load_common_state_dict, save diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/core.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/core.py new file mode 100644 index 000000000..f20a0c3a2 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/core.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Optional + +CONFIG_FNAME = 'metadata.json' + + +class CheckpointingException(Exception): + pass + + +@dataclass +class CheckpointingConfig: + """ Documents backends used in the checkpoint. """ + + sharded_backend: str + sharded_backend_version: int = 1 + common_backend: str = 'torch' + common_backend_version: int = 1 + + +def check_is_distributed_checkpoint(checkpoint_dir): + return maybe_load_config(checkpoint_dir) is not None + + +def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: + config_path = Path(checkpoint_dir, CONFIG_FNAME) + if not config_path.exists(): + return None + with config_path.open() as f: + config_dict = json.load(f) + return CheckpointingConfig(**config_dict) + + +def save_config(config: CheckpointingConfig, checkpoint_dir: str): + config_path = Path(checkpoint_dir, CONFIG_FNAME) + with config_path.open('w') as f: + json.dump(asdict(config), f) diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/dict_utils.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/dict_utils.py new file mode 100644 index 000000000..c6baf4f11 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/dict_utils.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Utilities for operating with dicts and lists. """ + +from collections import defaultdict +from typing import Any, Callable, Iterable, Optional, Tuple, Union + +import torch + + +def extract_matching_values( + x: Union[dict, list], predicate: Callable +) -> Tuple[Union[dict, list], Union[dict, list]]: + """ Return matching and nonmatching values. Keeps hierarchy. """ + if isinstance(x, dict): + matching_vals = {} + nonmatching_vals = {} + for k, v in x.items(): + if isinstance(v, (list, dict)): + match, nonmatch = extract_matching_values(v, predicate) + if match: + matching_vals[k] = match + if nonmatch or not v: + nonmatching_vals[k] = nonmatch + elif predicate(v): + matching_vals[k] = v + else: + nonmatching_vals[k] = v + else: + assert isinstance(x, list) + matching_vals = [] + nonmatching_vals = [] + for v in x: + if isinstance(v, (list, dict)) and v: + match, nonmatch = extract_matching_values(v, predicate) + if match: + matching_vals.append(match) + if nonmatch or not v: + nonmatching_vals.append(nonmatch) + elif predicate(v): + matching_vals.append(v) + else: + nonmatching_vals.append(v) + return matching_vals, nonmatching_vals + + +def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: + mismatch = [] + if isinstance(x1, dict) and isinstance(x2, dict): + only_left = [prefix + (k,) for k in x1.keys() - x2.keys()] + only_right = [prefix + (k,) for k in x2.keys() - x1.keys()] + for k in x2.keys() & x1.keys(): + _left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,)) + only_left.extend(_left) + only_right.extend(_right) + mismatch.extend(_mismatch) + elif isinstance(x1, list) and isinstance(x2, list): + only_left = list(range(len(x1) - 1, len(x2) - 1, -1)) + only_right = list(range(len(x1) - 1, len(x2) - 1, -1)) + for i, (v1, v2) in enumerate(zip(x1, x2)): + _left, _right, _mismatch = diff(v1, v2, prefix + (i,)) + only_left.extend(_left) + only_right.extend(_right) + mismatch.extend(_mismatch) + else: + only_left = [] + only_right = [] + if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): + _is_mismatch = not torch.all(x1 == x2) + else: + try: + _is_mismatch = bool(x1 != x2) + except RuntimeError: + _is_mismatch = True + + if _is_mismatch: + mismatch.append((prefix, type(x1), type(x2))) + + return only_left, only_right, mismatch + + +def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4): + print_indent = lambda: print(' ' * indent * len(prefix), end='') + for k, v in d.items(): + if isinstance(v, dict): + print_indent() + print(f'> {k}:') + inspect_keys_types(v, prefix + (k,), indent) + else: + print_indent() + if isinstance(v, torch.Tensor): + print(f'> {k}: {type(v)} of shape {v.shape}') + else: + print(f'> {k}: {type(v)}') + + +def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): + print_indent = lambda: print(' ' * indent * len(prefix), end='') + if isinstance(x, dict): + print() + for k, v in x.items(): + print_indent() + print(f'> {k}: ', end='') + inspect_types(v, prefix + (k,), indent) + elif isinstance(x, list): + print() + for i, v in enumerate(x): + print_indent() + print(f'- {i}: ', end='') + inspect_types(v, prefix + (i,), indent) + else: + if isinstance(x, torch.Tensor): + print(f'Tensor of shape {x.shape}') + else: + try: + x_str = str(x) + except: + x_str = '' + if len(x_str) > 30: + x_str = x_str[:30] + '... (truncated)' + print(f'[{type(x)}]: {x_str}') + + +def nested_values(x: Union[dict, list]): + x_iter = x.values() if isinstance(x, dict) else x + for v in x_iter: + if isinstance(v, (dict, list)): + yield from nested_values(v) + else: + yield v + + +def nested_items_iter(x: Union[dict, list]): + x_iter = x.items() if isinstance(x, dict) else enumerate(x) + for k, v in x_iter: + if isinstance(v, (dict, list)): + yield from nested_items_iter(v) + else: + yield x, k, v + + +def dict_map(f: Callable, d: dict): + for sub_d, k, v in nested_items_iter(d): + sub_d[k] = f(v) + + +def dict_map_with_key(f: Callable, d: dict): + for sub_d, k, v in nested_items_iter(d): + sub_d[k] = f(k, v) + + +def dict_list_map_inplace(f: Callable, x: Union[dict, list]): + if isinstance(x, dict): + for k, v in x.items(): + x[k] = dict_list_map_inplace(f, v) + elif isinstance(x, list): + x[:] = (dict_list_map_inplace(f, v) for v in x) + else: + return f(x) + return x + + +def dict_list_map_outplace(f: Callable, x: Union[dict, list]): + if isinstance(x, dict): + return {k: dict_list_map_outplace(f, v) for k, v in x.items()} + elif isinstance(x, list): + return [dict_list_map_outplace(f, v) for v in x] + else: + return f(x) + + +def merge(x1: dict, x2: dict): + if isinstance(x1, dict) and isinstance(x2, dict): + for k, v2 in x2.items(): + if k not in x1: + x1[k] = v2 + else: + x1[k] = merge(x1[k], v2) + elif isinstance(x1, list) and isinstance(x2, list): + if len(x1) != len(x2): + raise ValueError('Cannot merge two lists with different lengths') + for i, v2 in enumerate(x2): + x1[i] = merge(x1[i], v2) + else: + raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`') + return x1 + + +def map_reduce( + xs: Iterable, + key_fn: Callable = lambda x: x, + value_fn: Callable = lambda x: x, + reduce_fn: Callable = lambda x: x, +) -> dict: + res = defaultdict(list) + for x in xs: + res[key_fn(x)].append(value_fn(x)) + for k in res: + res[k] = reduce_fn(res[k]) + return dict(res) diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/mapping.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/mapping.py new file mode 100644 index 000000000..ec71411d6 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/mapping.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Core library classes. """ + +from dataclasses import dataclass, replace +from itertools import chain +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch + +from .core import CheckpointingException +from .dict_utils import dict_list_map_inplace, dict_list_map_outplace + +# These type definitions are just hints to differentiate a plain model state +# dict (StateDict) from a state dict with tensors replaced with ShardedTensors +# (ShardedStateDict). +StateDict = Dict[str, Any] +ShardedStateDict = Dict[str, Any] +ReplicaId = Union[int, Tuple[int, ...]] + + +@dataclass +class ShardedTensor: + """Represents a mapping between a local tensor and a global tensor. + + Global tensor is assumed to consist of many local tensors distributed + between different processes. + + Attributes: + key: unique identifier of a global tensor + data: local tensor data. Can be None only for consistency validation + dtype: tensor dtype + local_shape: local tensor shape + global_shape: global tensor shape + global_offset: offset of a local tensor in a global tensor, specified + in number of tensor elements + axis_fragmentations: global tensor fragmentation of each axis + replica_id: indicates given local tensor's replication wrt. local + tensors in different processes + prepend_axis_num: number of axes prepended to the local tensor + to reflect global tensor shape. + The behavior is similar to unsqueezing the local tensor. + allow_shape_mismatch: if True, during loading, the global shape of a + stored tensor does not have to match the expected global shape. + Useful for representing tensors with flexible shape, e.g. padded. + flattened_range: specifies a slice that should be applied to a flattened + tensor with `local_shape` in order to get the tensor stored as `data` + """ + + key: str + data: Optional[torch.Tensor] + dtype: torch.dtype + local_shape: Tuple[int, ...] + global_shape: Tuple[int, ...] + global_offset: Tuple[int, ...] + axis_fragmentations: Optional[Tuple[int, ...]] + replica_id: ReplicaId = 0 + prepend_axis_num: int = 0 + allow_shape_mismatch: bool = False + flattened_range: Optional[slice] = None + + def global_slice(self) -> Tuple[Union[int, slice], ...]: + assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num + return tuple( + chain( + (off for off in self.global_offset[: self.prepend_axis_num]), + ( + slice(off, off + sh) + for off, sh in zip( + self.global_offset[self.prepend_axis_num :], self.local_shape + ) + ), + ) + ) + + def global_coordinates(self) -> Tuple[np.ndarray, ...]: + if self.flattened_range is None: + raise CheckpointingException( + f'`global_coordinates` is undefined for' + f' {self.__class__.__name__} without `flattened_range`' + ) + + local_coords = self.local_coordinates() + assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), ( + len(local_coords), + self, + ) + global_coords = tuple( + c + off + for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset) + ) + return global_coords + + def local_coordinates(self) -> Tuple[np.ndarray, ...]: + if self.flattened_range is None: + raise CheckpointingException( + f'`local_coordinates` is undefined for' + f' {self.__class__.__name__} without `flattened_range`' + ) + + # TODO: np.unravel_index? + mask = np.zeros(np.product(self.local_shape), dtype=bool) + mask[self.flattened_range] = True + return np.nonzero(mask.reshape(self.local_shape)) + + def max_allowed_chunks(self) -> Tuple[int, ...]: + chunks = [] + for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations): + if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0: + raise CheckpointingException( + f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}' + ) + axis_chunk_size = axis_sh // axis_fragm + chunks.append(axis_chunk_size) + return tuple(chunks) + + def without_data(self): + return replace(self, data=None) + + @classmethod + def from_rank_offsets( + cls, + key: str, + data: torch.Tensor, + *rank_offsets: Tuple[int, int, int], + replica_id: ReplicaId = 0, + prepend_axis_num: int = 0, + allow_shape_mismatch: bool = False, + ): + """Allows to construct the ShardedTensor given offset specified in process ranks. + Arguments: + key: unique key + data: local tensor data + rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) + says that if global tensor is divided into `axis_fragm` + fragment along `axis` axis, then local tensor data + corresponds to the `axis_rank_offset` chunk. + replica_id: see ShardedTensor + prepend_axis_num: see ShardedTensor + allow_shape_mismatch: see ShardedTensor + """ + global_offset = [0] * (data.ndim + prepend_axis_num) + global_shape = ([1] * prepend_axis_num) + list(data.shape) + axis_fragmentations = [1] * (data.ndim + prepend_axis_num) + _seen_axis = set() + for axis, axis_rank_offset, axis_fragm in rank_offsets: + assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, ( + axis, + axis_rank_offset, + axis_fragm, + ) + assert ( + axis_rank_offset < axis_fragm + ), 'Rank offset must be lower than axis fragmentation' + if axis in _seen_axis: + raise CheckpointingException('Duplicated axis specified') + _seen_axis.add(axis) + + local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num] + global_shape[axis] = axis_fragm * local_axis_shape + global_offset[axis] = axis_rank_offset * local_axis_shape + axis_fragmentations[axis] = axis_fragm + + return cls( + key, + data, + data.dtype, + tuple(data.shape), + tuple(global_shape), + tuple(global_offset), + tuple(axis_fragmentations), + replica_id, + prepend_axis_num, + allow_shape_mismatch, + ) + + def __str__(self): + return f'{self.__class__.__name__}(key=\'{self.key}\')' + + +def is_main_replica(replica_id): + if isinstance(replica_id, int): + return replica_id == 0 + return all(r == 0 for r in replica_id) + + +class LocalNonpersitentObject: + """Object that should not be stored in a checkpoint, but restored locally. + + Wrapping any object inside the state dict with LocalNonpersitentObject + will result in: + - during saving, this object will *not* be stored in the checkpoint + - during loading, a local version of this object will be placed in a state dict + """ + + def __init__(self, obj): + self.obj = obj + + def unwrap(self): + return self.obj + + +@dataclass +class ShardedObject: + """Represents a mapping between a local object and a global object. + + Global object is assumed to consist of many local objects distributed + between different processes. + + NOTE: Contrary to ShardedTensor, it's impossible to change global object + sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor + with atomic arbitrary typed elements. + + Attributes: + key: unique identifier of a global tensor + data: local object data. Can be None only for consistency validation + global_shape: global object shape + global_offset: offset of a local object in a global object, specified + in number of shards + replica_id: indicates local object replication wrt. local + objects in different processes + """ + + key: str + data: object + global_shape: Tuple[int, ...] + global_offset: Tuple[int, ...] + replica_id: ReplicaId = 0 + + def without_data(self): + return replace(self, data=None) + + @property + def unique_key(self): + return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}' + + def __str__(self): + return f'{self.__class__.__name__}(key=\'{self.key}\')' + + +@dataclass +class ShardedTensorFactory: + """ Allows to apply transformations to tensors before/after serialization. + + The essence of those transformations is that they can be applied to + optimizer states the same way they are applied to the model params. + + Builder creates a sub-state-dict out of a tensor before saving, and merger + merges the corresponding state dict after loading. + """ + + key: str + data: torch.Tensor + build_fn: Callable[[str, torch.Tensor], ShardedStateDict] + merge_fn: Callable[[StateDict], torch.Tensor] + + def build(self): + return self.build_fn(self.key, self.data) + + +def apply_factories(sharded_state_dict: ShardedStateDict): + def apply(x): + if isinstance(x, ShardedTensorFactory): + x = x.build() + return x + + dict_list_map_inplace(apply, sharded_state_dict) + + +def apply_factory_merges(x1: StateDict, x2: ShardedStateDict): + if isinstance(x2, ShardedTensorFactory): + return x2.merge_fn(x1) + + # There rest is almost the same as the `merge` function from `dict_utils` + if isinstance(x1, dict) and isinstance(x2, dict): + for k, v2 in x2.items(): + if k not in x1: + raise ValueError('Different dict keys encountered in `apply_factory_merges`') + else: + x1[k] = apply_factory_merges(x1[k], v2) + elif isinstance(x1, list) and isinstance(x2, list): + if len(x1) != len(x2): + raise ValueError('Cannot merge two lists with different lengths') + for i, v2 in enumerate(x2): + x1[i] = apply_factory_merges(x1[i], v2) + else: + raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`') + return x1 diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/optimizer.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/optimizer.py new file mode 100644 index 000000000..4ed35fa6d --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/optimizer.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Optimizer related helpers. """ + +import logging +from copy import deepcopy +from dataclasses import replace +from itertools import chain +from typing import Dict, Iterable, List, Union + +logger = logging.getLogger(__name__) + +import torch + +from .dict_utils import nested_values +from .mapping import ( + LocalNonpersitentObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, +) +from .utils import extract_sharded_tensors, extract_sharded_tensors_and_factories + + +def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]: + param_mappings = {} + for i, param in enumerate(optim_params_iter): + if id(param) not in param_mappings: + param_mappings[id(param)] = i + return param_mappings + + +def get_param_id_to_sharded_param_map( + model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter] +) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: + model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict) + id_to_sharded_param_map = {} + param_to_id_map = get_optim_param_to_id_map(optim_params_iter) + for ten in nested_values(model_sharded_state_dict): + if id(ten.data) in param_to_id_map: + id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten + else: + logger.debug(f'{ten} is not tracked by the optimizer') + + if not id_to_sharded_param_map: + logger.warning( + "Sharded parameters mapping is empty. It means tensors in model state dict" + " do not correspond to tensors in optimizer parameters map." + " Make sure to call state_dict with `keep_vars=True`." + ) + return id_to_sharded_param_map + + +def make_sharded_optimizer_tensor( + model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str +) -> Union[ShardedTensor, ShardedTensorFactory]: + if isinstance(model_param, ShardedTensorFactory): + return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param) + + assert ( + tuple(optim_param.shape) == model_param.local_shape + ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})' + return replace( + model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype + ) + + +def optim_state_to_sharding_state( + optim_state_dict: StateDict, id_to_sharded_param_map: Dict[int, ShardedTensor] +): + sharded_state = {} + for param_id, param_state in optim_state_dict['state'].items(): + sharded_state[param_id] = {} + for state_key, param in param_state.items(): + if param_id in id_to_sharded_param_map: + sharded_state[param_id][state_key] = make_sharded_optimizer_tensor( + id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}' + ) + else: + raise ValueError(f'Param id {param_id} does not match any model sharded param') + + optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups']) + for group in optim_state_dict['param_groups']: + group['params'] = LocalNonpersitentObject(group['params']) + optim_state_dict['state'] = sharded_state diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/serialization.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/serialization.py new file mode 100644 index 000000000..9827e9636 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/serialization.py @@ -0,0 +1,334 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +import logging +import os +from collections import Counter, defaultdict +from itertools import chain +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +import numpy as np +import torch + +from .core import CheckpointingConfig, maybe_load_config, save_config +from .dict_utils import ( + dict_list_map_inplace, + diff, + extract_matching_values, + map_reduce, + merge, + nested_values, +) +from .mapping import ( + CheckpointingException, + ShardedObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, + apply_factories, + apply_factory_merges, + is_main_replica, +) +from .strategies.base import ( + LoadCommonStrategy, + LoadShardedStrategy, + SaveCommonStrategy, + SaveShardedStrategy, + StrategyAction, + get_default_strategy, +) +from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent + +COMMON_STATE_FNAME = 'common.pt' + +logger = logging.getLogger(__name__) + + +def load( + sharded_state_dict: ShardedStateDict, + checkpoint_dir: str, + sharded_strategy: Union[LoadShardedStrategy, None] = None, + common_strategy: Union[LoadCommonStrategy, None] = None, +) -> StateDict: + """Loading entrypoint. + + Arguments: + sharded_state_dict: state dict of the existing model populated with + ShardedTensors. Used as a mapping to determine which parts of + global tensors stored in the checkpoint should be loaded. + checkpoint_dir: directory with the checkpoint + sharded_strategy: configures loading behavior for sharded tensors + common_strategy: configures loading behavior for common data + """ + if common_strategy is not None: + raise NotImplementedError('The only supported common strategy is torch') + + checkpoint_dir = Path(checkpoint_dir) + common_state_dict = load_common_state_dict(checkpoint_dir) + if not sharded_state_dict: + return common_state_dict + + sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir) + merge(common_state_dict, sharded_objects) + + saved_config = maybe_load_config(checkpoint_dir) + if saved_config is None: + raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') + + sh_ten_factories, _ = extract_matching_values( + sharded_state_dict, lambda x: isinstance(x, ShardedTensorFactory) + ) + apply_factories(sharded_state_dict) + sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) + sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict) + dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) + merge(common_state_dict, nonpersistent_state_dict) + + validate_sharding_integrity(nested_values(sharded_state_dict)) + + if sharded_strategy is None: + sharded_strategy = get_default_strategy( + StrategyAction.LOAD_SHARDED, + saved_config.sharded_backend, + saved_config.sharded_backend_version, + ) + else: + # TODO: implement consistency checks here + pass + loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) + + loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories) + + merge(common_state_dict, loaded_state_dict) + return common_state_dict + + +# TODO: implement it as common torch strategy +def load_common_state_dict(checkpoint_dir: Path): + return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu') + + +def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + sharded_objects, sharded_state_dict = extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, ShardedObject) + ) + + def load_sharded_object(sh_obj: ShardedObject): + sh_obj.data = None + load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') + loaded_obj = torch.load(load_path) + return loaded_obj + + return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict + + +def save( + sharded_state_dict: ShardedStateDict, + checkpoint_dir: str, + sharded_strategy: Union[SaveShardedStrategy, None] = None, + common_strategy: Union[SaveCommonStrategy, None] = None, +): + """Saving entrypoint. + + Extracts ShardedTensors from the given state dict. Rank 0 saves the + "regular" part of the checkpoint to common torch file. + The ShardedTensors are saved according to a strategy specified by the + config. + + Arguments: + sharded_state_dict: state dict of the populated with + ShardedTensors. Used as a mapping to determine how local tensors + should be saved as global tensors in the checkpoint. + checkpoint_dir: directory to save the checkpoint to + sharded_strategy: configures sharded tensors saving behavior and backend + common_strategy: configures common data saving behavior and backend + """ + checkpoint_dir = Path(checkpoint_dir) + + if torch.distributed.get_rank() == 0: + if not checkpoint_dir.exists(): + raise CheckpointingException( + f'Checkpoint destination directory does not exist: {checkpoint_dir}' + ) + + if next(checkpoint_dir.iterdir(), None) is not None: + raise CheckpointingException( + f'Checkpoint destination directory ({checkpoint_dir}) is not empty' + ) + + if common_strategy is not None: + raise NotImplementedError('The only supported common strategy is torch') + + if sharded_strategy is None: + sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1) + + apply_factories(sharded_state_dict) + sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) + sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict) + sharded_tensors = list(nested_values(sharded_state_dict)) + validate_sharding_integrity(sharded_tensors) + + _save_common_dict(state_dict, checkpoint_dir, True) + + sharded_strategy.save(sharded_tensors, checkpoint_dir) + save_config( + CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir + ) + + +# TODO: implement it as common torch strategy +def _save_common_dict( + state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False +): + common_state_dict = _extract_and_save_sharded_objects( + state_dict, checkpoint_dir, validate_consistency + ) + if torch.distributed.get_rank() == 0: + torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) + if validate_consistency: + # TODO: implement checking consistency with rank 0 common dict on other ranks + pass + # torch.distributed.barrier() + # if not torch.distributed.get_rank() == 0: + # rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME) + # print(diff(common_state_dict, rank_0_state_dict)) + + +def _extract_and_save_sharded_objects( + state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False +): + sharded_objects, state_dict = extract_matching_values( + state_dict, lambda v: isinstance(v, ShardedObject) + ) + sharded_objects = list(nested_values(sharded_objects)) + if validate_consistency: + validate_objects_sharding_integrity(sharded_objects) + for sh_obj in sharded_objects: + if is_main_replica(sh_obj.replica_id): + save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') + os.makedirs(save_path.parent, exist_ok=True) + torch.save(sh_obj.data, save_path) + return state_dict + + +def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]): + sharding = [ten.without_data() for ten in sharded_tensors] + all_sharding = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_sharding, sharding) + if torch.distributed.get_rank() != 0: + return + + key_shardings = defaultdict(list) + for rank, rank_shardings in enumerate(all_sharding): + for sharding in rank_shardings: + key_shardings[sharding.key].append((rank, sharding)) + for key, shardings in key_shardings.items(): + _validate_sharding_for_key(shardings) + + +def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): + some_rank_shard = rank_sharding[0][1] + global_shape = some_rank_shard.global_shape + local_shape = some_rank_shard.local_shape + dtype = some_rank_shard.dtype + has_flattened_range = some_rank_shard.flattened_range is not None + for rank, sharding in rank_sharding: + assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard) + assert sharding.global_shape == global_shape, ( + sharding.global_shape, + global_shape, + some_rank_shard, + ) + assert sharding.local_shape == local_shape, ( + sharding.local_shape, + local_shape, + some_rank_shard, + ) + assert (sharding.flattened_range is not None) == has_flattened_range, ( + (sharding.flattened_range is not None), + has_flattened_range, + some_rank_shard, + ) + + shard_access_cnt = _compute_shards_access(rank_sharding) + if has_flattened_range: + map_reduce( + rank_sharding, + lambda x: x[1].global_offset, + lambda x: x[1], + _validate_sharding_for_key_flattened, + ) + else: + if not torch.all(shard_access_cnt == 1): + logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') + raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') + + +def _compute_shards_access(rank_sharding): + def chunk_offset(sharding): + assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num + return tuple( + chain( + (off for off in sharding.global_offset[: sharding.prepend_axis_num]), + ( + off // sh + for off, sh in zip( + sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape + ) + ), + ) + ) + + shard_access_cnt = torch.zeros( + rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu' + ) + for rank, sharding in rank_sharding: + if is_main_replica(sharding.replica_id): + shard_access_cnt[chunk_offset(sharding)] += 1 + # TODO: consider validating different replicas too + return shard_access_cnt + + +def _validate_sharding_for_key_flattened(tensors_by_shard): + all_slices = [] + local_shape = tensors_by_shard[0].local_shape + for sharding in tensors_by_shard: + assert sharding.local_shape == local_shape + sharding: ShardedTensor + if not is_main_replica(sharding.replica_id): + # TODO: this checks only saving (and loading replica_id=0) consistency + continue + + all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) + + starts, stops = map(np.asarray, zip(*sorted(all_slices))) + if ( + starts[0] != 0 + or stops[-1] != np.product(local_shape) + or not np.all(starts[1:] == stops[:-1]) + ): + logger.error( + f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' + ) + raise CheckpointingException( + f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}' + ) + + +def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]): + """ Ensure uniqueness of saved objects. """ + local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects] + all_sh_objs = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_sh_objs, local_sh_objs) + if torch.distributed.get_rank() != 0: + return + unique_keys = [ + sh_obj.unique_key + for sh_obj in chain.from_iterable(all_sh_objs) + if is_main_replica(sh_obj.replica_id) + ] + if len(unique_keys) != len(set(unique_keys)): + duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1} + logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}') + raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}') diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/__init__.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/__init__.py new file mode 100644 index 000000000..7177d973c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Various loading and saving strategies """ + +import logging + +logger = logging.getLogger(__name__) + +try: + import tensorstore + import zarr + + from .tensorstore import _import_trigger + from .zarr import _import_trigger +except ImportError: + logger.warning('Zarr-based strategies will not be registered because of missing packages') diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/base.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/base.py new file mode 100644 index 000000000..92ded320f --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/base.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional + +from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict + + +class StrategyAction(Enum): + LOAD_COMMON = 'load_common' + LOAD_SHARDED = 'load_sharded' + SAVE_COMMON = 'save_common' + SAVE_SHARDED = 'save_sharded' + + +default_strategies = defaultdict(dict) + + +def get_default_strategy(action: StrategyAction, backend: str, version: int): + try: + return default_strategies[action.value][(backend, version)] + except KeyError as e: + raise CheckpointingException( + f'Cannot find default strategy for: {(action, backend, version)}' + ) from e + + +class LoadStrategyBase(ABC): + @abstractmethod + def check_backend_compatibility(self, loaded_version): + raise NotImplementedError + + @abstractmethod + def check_version_compatibility(self, loaded_version): + raise NotImplementedError + + +class SaveStrategyBase(ABC): + def __init__(self, backend: str, version: int): + self.backend = backend + self.version = version + + +class LoadCommonStrategy(LoadStrategyBase): + @abstractmethod + def load(self, checkpoint_dir: Path): + raise NotImplementedError + + +class LoadShardedStrategy(LoadStrategyBase): + @abstractmethod + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + raise NotImplementedError + + +class SaveCommonStrategy(SaveStrategyBase): + @abstractmethod + def save(self, common_state_dict: StateDict, checkpoint_dir: Path): + raise NotImplementedError + + +class SaveShardedStrategy(SaveStrategyBase): + @abstractmethod + def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): + raise NotImplementedError diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/tensorstore.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/tensorstore.py new file mode 100644 index 000000000..34355a0f4 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/tensorstore.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Strategies using TensorStore to load and save Zarr arrays. """ + +from functools import partial +from itertools import starmap +from pathlib import Path + +import tensorstore as ts +import torch + +from ..core import CheckpointingException +from ..dict_utils import dict_list_map_inplace +from ..mapping import ShardedStateDict, ShardedTensor +from .base import LoadShardedStrategy, StrategyAction, default_strategies +from .zarr import postprocess_numpy_array + +_import_trigger = None + + +class TensorStoreLoadShardedStrategy(LoadShardedStrategy): + def __init__(self, load_directly_on_device: bool = False): + super().__init__() + self.load_directly_on_device = load_directly_on_device + + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + if torch.distributed.get_rank() == 0: + print(f'Loading distributed checkpoint with {self.__class__.__name__}') + if self.load_directly_on_device: + print(f'Loading distributed checkpoint directly on the GPU') + load_fn = partial( + _load_from_array, + checkpoint_dir=checkpoint_dir, + load_directly_on_device=self.load_directly_on_device, + ) + dict_list_map_inplace(load_fn, sharded_state_dict) + return sharded_state_dict + + def check_backend_compatibility(self, loaded_version): + pass # TODO + + def check_version_compatibility(self, loaded_version): + pass # TODO + + +def merge_global_slice_with_shape(global_slice, actual_shape, key): + def _merge_slice(dim_slice, dim_size): + if isinstance(dim_slice, slice): + assert ( + dim_slice.start < dim_size + ), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})' + if dim_slice.stop > dim_size: + dim_slice = slice(dim_slice.start, dim_size, dim_slice.step) + return dim_slice + + assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key) + return tuple(starmap(_merge_slice, zip(global_slice, actual_shape))) + + +def _load_from_array( + sharded_tensor: ShardedTensor, + checkpoint_dir: Path, + load_directly_on_device: bool = False, + apply_flattened_range: bool = True, +): + x = _load_regular_chunk(sharded_tensor, checkpoint_dir) + ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range) + if load_directly_on_device: + sharded_tensor.data.data.copy_(ten) + return sharded_tensor.data + else: + return ten + + +def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path): + assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor) + spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} + spec['kvstore'] = { + 'driver': 'file', + 'path': str(checkpoint_dir / sharded_tensor.key), + } + try: + arr = ts.open(ts.Spec(spec), open=True).result() + except Exception as e: + raise CheckpointingException( + f'Array {checkpoint_dir / sharded_tensor.key} could not be loaded. Error: {e}' + ) from e + + if sharded_tensor.global_shape == arr.shape: + x = ( + arr[sharded_tensor.global_slice()].read().result() + ) # flattened tensors loading is delayed + elif sharded_tensor.allow_shape_mismatch: + global_slice = merge_global_slice_with_shape( + sharded_tensor.global_slice(), arr.shape, sharded_tensor.key + ) + x = arr[global_slice].read().result() # flattened tensors loading is delayed + else: + _msg = ( + f'Global shape mismatch for loaded ({arr.shape})' + f' and expected ({sharded_tensor.global_shape}) tensor' + f' for key {sharded_tensor.key}' + ) + raise CheckpointingException(_msg) + return x + + +default_strategies[StrategyAction.LOAD_SHARDED.value][ + ('zarr', 1) +] = TensorStoreLoadShardedStrategy() diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/two_stage.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/two_stage.py new file mode 100644 index 000000000..f35fb0a69 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/two_stage.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" 2-stage checkpoint loading. """ +import os +import time +from collections import defaultdict +from dataclasses import dataclass +from functools import partial, wraps +from itertools import chain +from logging import DEBUG, INFO, StreamHandler, getLogger +from operator import attrgetter, itemgetter +from pathlib import Path +from typing import Iterable, List, NamedTuple, Optional, Tuple, Union + +import torch + +from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values +from ..mapping import ShardedStateDict, ShardedTensor, StateDict +from .base import LoadShardedStrategy +from .tensorstore import _load_from_array +from .zarr import flatten_range + +_import_trigger = None + + +timers = defaultdict(list) + +logger = getLogger(__name__) + + +def timed(verbose=True): + def timed_dec(fn): + name = fn.__name__ + + @wraps(fn) + def wrapped(*args, **kwargs): + if verbose: + logger.debug(f'{name} init') + start = time.time() + ret = fn(*args, **kwargs) + took = time.time() - start + if verbose: + logger.debug(f'{name} took {took}s') + timers[name].append(took) + return ret + + return wrapped + + return timed_dec + + +@dataclass +class _ShardedTensorMetadata: + global_rank: int + sharded_tensor_no_data: ShardedTensor + dist_group_rank: Tuple[int] # id of distributed group + dist_group_ranks: Tuple[int] # id of distributed group + data_size: Optional[int] = None # bytes + + +def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): + return ( + sharded_tensor.key, + sharded_tensor.global_offset, + ) + + +class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): + """ Loads one checkpoint replica from storage and broadcasts to other nodes. + + This strategy loads checkpoint from storage on minimal set of nodes + and distributes the checkpoint to other nodes with torch.distributed. + Loading is performed with tensorstore. + + Steps: + 0. (optional) create Gloo distributed groups + 1. Exchange ShardedTensors metadata between all nodes + 2. Align needed tensors within DP groups + 3. For each globally unique tensor: + a) on one of the ranks load it from storage to CPU and move to CUDA + b) allocate CUDA tensor on other ranks + c) broadcast within DP group + d) copy tensor content to the model param location + e) free tensor buffers from a) and b) + + Notes: + 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs + 2. There is a lot of overlap potential between all three steps done for each tensor: + a) loading from storage to numpy + b) moving CPU tensors to CUDA + c) broadcast + + """ + + def __init__(self, data_parallel_group, cpu_transfer=True): + super().__init__() + + self.cpu_transfer = cpu_transfer + self.data_parallel_group_orig = data_parallel_group + self.data_parallel_group = None if cpu_transfer else data_parallel_group + self.dp_group_ranks = tuple( + sorted(torch.distributed.get_process_group_ranks(data_parallel_group)) + ) + self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig) + self.global_rank = torch.distributed.get_rank() + + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): + self.maybe_init_gloo_group() + all_tensors_sorted = self._build_load_plan(sharded_state_dict) + self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) + self.summarize_load_times() + return sharded_state_dict + + def summarize_load_times(self): + torch.distributed.barrier() + logger.info('Checkpoint loading finished. Summary:') + for key, times in sorted(timers.items()): + times_sum = sum(times) + max_times = torch.tensor([times_sum], device='cuda') + avg_times = torch.tensor([times_sum], device='cuda') + torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX) + torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM) + avg_times /= torch.distributed.get_world_size() + if torch.distributed.get_rank() == 0: + logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}') + + @timed(verbose=False) + def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): + logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') + ret = _load_from_array( + ten_meta.sharded_tensor_no_data, + checkpoint_dir, + load_directly_on_device=False, + apply_flattened_range=False, + ) + logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE') + return ret + + @timed() + def maybe_init_gloo_group(self): + if not self.cpu_transfer: + return + all_groups = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) + all_groups = set(tuple(sorted(gr)) for gr in all_groups) + for group_ranks in sorted(all_groups): + gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') + if self.global_rank in group_ranks: + self.data_parallel_group = gloo_pg + assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group) + + def check_backend_compatibility(self, loaded_version): + pass # TODO + + def check_version_compatibility(self, loaded_version): + pass # TODO + + @timed() + def _build_load_plan( + self, sharded_state_dict: ShardedStateDict + ) -> List[_ShardedTensorMetadata]: + local_meta = [ + _ShardedTensorMetadata( + self.global_rank, + sharded_ten.without_data(), + self.dp_group_rank, + self.dp_group_ranks, + ) + for sharded_ten in nested_values(sharded_state_dict) + ] + all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group) + torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group) + all_meta = list(chain.from_iterable(all_meta)) + all_tensors_sorted = self.deduplicate_chunks(all_meta) + return all_tensors_sorted + + @timed() + def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]): + """ Group tensors by chunk and then pick the tensor with the lowest rank. + + NOTE: with proper loading overlap, loading from randomized ranks + (instead of the smallest one) could be beneficial here. + """ + ten_metas = map_reduce( + ten_metas, + key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data), + reduce_fn=partial(min, key=attrgetter('dist_group_rank')), + ) + all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items()))) + return all_metas_sorted + + @timed() + def _exchange_loaded_tensors( + self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir + ): + logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}') + for ten_meta in ten_metas: + + src_rank = torch.distributed.get_global_rank( + self.data_parallel_group, ten_meta.dist_group_rank + ) + + if self.dp_group_rank == ten_meta.dist_group_rank: + exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta) + if not self.cpu_transfer: + exchange_tensor = exchange_tensor.cuda() + else: + # TODO: for non-flattened ranges we could reuse the buffer from the start here + exchange_tensor = torch.empty( + ten_meta.sharded_tensor_no_data.local_shape, + device='cpu' if self.cpu_transfer else 'cuda', + dtype=ten_meta.sharded_tensor_no_data.dtype, + ) + + logger.debug( + f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' + ) + torch.distributed.broadcast( + exchange_tensor, group=self.data_parallel_group, src=src_rank + ) + self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict) + logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done') + + # free buffer memory + exchange_tensor = None + + @timed(verbose=False) + def _distribute_data_to_state_dict( + self, + ten_meta: _ShardedTensorMetadata, + loaded_ten: torch.Tensor, + sharded_state_dict: ShardedStateDict, + ): + tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data) + + def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]): + if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key: + # already filled-in or key not matching + return t + sharded_tensor: ShardedTensor = t + x = loaded_ten + if sharded_tensor.flattened_range is not None: + x = flatten_range(sharded_tensor, x) + + # Reuse existing buffer + sharded_tensor.data.data.copy_(x) + return sharded_tensor.data + + dict_list_map_inplace(_fill_in_data, sharded_state_dict) diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/zarr.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/zarr.py new file mode 100644 index 000000000..4c61f2d97 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/strategies/zarr.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +""" Strategies using Zarr as an underlying format. """ +import os +from functools import partial +from pathlib import Path +from typing import List + +import numpy as np +import torch +import zarr + +from ..core import CheckpointingException +from ..dict_utils import dict_list_map_inplace +from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica +from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies + +numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + +torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()} + + +try: + import tensorstore + + HAS_BFLOAT16 = True + numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16 + torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16') +except ImportError: + HAS_BFLOAT16 = False + +_import_trigger = None + + +class ZarrSaveShardedStrategy(SaveShardedStrategy): + def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): + arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir) + for ten, arr in zip(sharded_tensors, arrays): + _save_to_existing_array(ten, arr) + torch.distributed.barrier() + + +def _create_or_open_zarr_arrays( + sharded_tensors: List[ShardedTensor], checkpoint_dir: Path +) -> List[zarr.Array]: + arrays = [] + for ten in sharded_tensors: + if _should_create_array(ten): + _create_zarr_array(ten, checkpoint_dir) + # TODO: maybe reuse the opened arrays + + torch.distributed.barrier() + for ten in sharded_tensors: + # if is_main_replica(ten.replica_id) and set(ten.global_offset) == {0}: + # continue + open_kwargs = {} + if ten.flattened_range is not None: + open_kwargs['synchronizer'] = zarr.ProcessSynchronizer( + str(checkpoint_dir / f'{ten.key}.sync') + ) + arr = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs) + arrays.append(arr) + return arrays + + +def _should_create_array(ten: ShardedTensor): + return ( + is_main_replica(ten.replica_id) + and set(ten.global_offset) == {0} + and (ten.flattened_range is None or ten.flattened_range.start == 0) + ) + + +def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: zarr.Array): + if not is_main_replica(sharded_tensor.replica_id): + return + x = sharded_tensor.data + x = x.detach().cpu() + torch.cuda.synchronize() + if x.dtype == torch.bfloat16: + x = x.float() + x = x.numpy() + x = x.astype('bfloat16') + else: + x = x.numpy() + + if sharded_tensor.flattened_range is None: + arr[sharded_tensor.global_slice()] = x + else: + arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x) + + +def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): + np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype] + try: + arr = zarr.create( + sharded_tensor.global_shape, + dtype=np_dtype, + store=checkpoint_dir / sharded_tensor.key, + chunks=sharded_tensor.max_allowed_chunks(), + compressor=None, + fill_value=None, + write_empty_chunks=True, + ) + except zarr.errors.ContainsArrayError as e: + raise CheckpointingException( + f'Array {checkpoint_dir / sharded_tensor.key} already exists' + ) from e + + if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'): + arr._dtype = np_dtype + zarray = arr.store['.zarray'] + arr.store['.zarray'] = zarray.replace(b' exp_sh: + assert ( + False + ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}' + else: + pad_args.extend((0, exp_sh - x_sh)) + # TODO: behavior control with envvar is for testing purposes only, remove it + if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)): + return torch.nn.functional.pad(x, pad_args) + + # unsqueeze and squeeze to get shapes supported by cudnn + print(f'Replicating last row for {expected_sharded_ten.key}') + if x.dtype == torch.bfloat16: + return ( + torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate') + .squeeze(0) + .bfloat16() + ) + return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0) + + +# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy() +default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy( + 'zarr', 1 +) diff --git a/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/utils.py b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/utils.py new file mode 100644 index 000000000..f7976f007 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/dist_checkpointing/utils.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + +from typing import Tuple + +from .dict_utils import dict_list_map_inplace, extract_matching_values +from .mapping import ( + LocalNonpersitentObject, + ShardedStateDict, + ShardedTensor, + ShardedTensorFactory, + StateDict, +) + + +def extract_sharded_tensors( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor)) + + +def extract_sharded_tensors_and_factories( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + return extract_matching_values( + sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory)) + ) + + +def extract_sharded_tensors_or_nonpersistent( + sharded_state_dict: ShardedStateDict, +) -> Tuple[ShardedStateDict, StateDict]: + return extract_matching_values( + sharded_state_dict, + lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)), + ) + + +def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str): + def add_prefix(t): + if isinstance(t, ShardedTensor): + t.key = f'{prefix}.{t.key}' + return t + + dict_list_map_inplace(add_prefix, sharded_state_dict) diff --git a/training/DeepSpeed-Domino/megatron/core/enums.py b/training/DeepSpeed-Domino/megatron/core/enums.py new file mode 100644 index 000000000..46e7d3b76 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/enums.py @@ -0,0 +1,10 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import enum + + +class ModelType(enum.Enum): + encoder_or_decoder = 1 + encoder_and_decoder = 2 + retro_encoder = 3 + retro_decoder = 4 diff --git a/training/DeepSpeed-Domino/megatron/core/fusions/__init__.py b/training/DeepSpeed-Domino/megatron/core/fusions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_dropout.py b/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_dropout.py new file mode 100644 index 000000000..e0d6ffbda --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_dropout.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from typing import Optional, Tuple + +import torch + + +def _bias_dropout_add_func(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + # NOTE: Previously, the argument `bias` used to be passed as + # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the + # transformer layer but broadcasting should automatically take care of that. + # Also, looking at broadcasting semantics, `expand_as` and broadcasting + # seem to be identical performance-wise (both just change the view). + + # If we want to train mixed precision, then the output of this function + # should be half precision. However, in AMP O1, the input (residual) is + # in fp32, and it will up-cast the result to fp32, causing pipeline parallel + # GPU communication to hang. Therefore, we need to cast residual to the same + # dtype as x. + residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train( + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, +) -> torch.Tensor: + x, bias = x_with_bias # unpack + return _bias_dropout_add_func(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference( + x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, +) -> torch.Tensor: + x, bias = x_with_bias # unpack + return _bias_dropout_add_func(x, bias, residual, prob, False) + + +def get_bias_dropout_add(training, fused): + def unfused_bias_dropout_add(x_with_bias, residual, prob): + x, bias = x_with_bias # unpack + return _bias_dropout_add_func(x, bias, residual, prob, training) + + if fused: + # jit scripting for a nn.module (with dropout) is not + # triggering the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if training: + return bias_dropout_add_fused_train + else: + return bias_dropout_add_fused_inference + else: + return unfused_bias_dropout_add diff --git a/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_gelu.py b/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_gelu.py new file mode 100644 index 000000000..9c791c180 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/fusions/fused_bias_gelu.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + return ff * g + + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply diff --git a/training/DeepSpeed-Domino/megatron/core/fusions/fused_layer_norm.py b/training/DeepSpeed-Domino/megatron/core/fusions/fused_layer_norm.py new file mode 100644 index 000000000..e4f098424 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/fusions/fused_layer_norm.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import importlib +import numbers + +import torch +from torch.nn import init +from torch.nn.parameter import Parameter + +from megatron.core.utils import make_viewless_tensor + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + + HAVE_PERSIST_LAYER_NORM = True +except: + HAVE_PERSIST_LAYER_NORM = False + +try: + from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction + + HAVE_FUSED_LAYER_NORM = True +except: + HAVE_FUSED_LAYER_NORM = False + + +class FusedLayerNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + eps=1e-5, + persist_layer_norm=True, + sequence_parallel=False, + zero_centered_gamma=False, + ): + super().__init__() + + self.zero_centered_gamma = zero_centered_gamma + + # List of hiddens sizes supported in the persistent layer norm kernel + # If the hidden size is not supported, fall back to the non-persistent + # kernel. + persist_ln_hidden_sizes = [ + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, + ] + if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM: + persist_layer_norm = False + + if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM: + # TODO: Add pytorch only layer norm + raise ValueError(f'Apex must currently be installed to use megatron core.') + + if isinstance(hidden_size, numbers.Integral): + hidden_size = (hidden_size,) + self.hidden_size = torch.Size(hidden_size) + self.eps = eps + self.weight = Parameter(torch.Tensor(*hidden_size)) + self.bias = Parameter(torch.Tensor(*hidden_size)) + self.reset_parameters() + self.persist_layer_norm = persist_layer_norm + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + def reset_parameters(self): + + if self.zero_centered_gamma: + init.zeros_(self.weight) + init.zeros_(self.bias) + else: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + + weight = self.weight + 1 if self.zero_centered_gamma else self.weight + + if self.persist_layer_norm: + output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) + + # Apex's fast layer norm function outputs a 'view' tensor (i.e., has + # a populated '_base' field). This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + output = make_viewless_tensor( + inp=output, requires_grad=input.requires_grad, keep_graph=True + ) + + else: + output = FusedLayerNormAffineFunction.apply( + input, weight, self.bias, self.hidden_size, self.eps + ) + + return output diff --git a/training/DeepSpeed-Domino/megatron/core/fusions/fused_softmax.py b/training/DeepSpeed-Domino/megatron/core/fusions/fused_softmax.py new file mode 100644 index 000000000..56eb2e801 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/fusions/fused_softmax.py @@ -0,0 +1,204 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + + +import torch +import torch.nn as nn + +from megatron.core.transformer.enums import AttnMaskType + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_upper_triang_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_upper_triang_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +class ScaledSoftmax(torch.autograd.Function): + """ + Fused operation which performs following two operations in sequence + 1. Scale the tensor. + 2. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 16 < sk <= 4096 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 4096: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + if mask is not None: + return ScaledMaskedSoftmax.apply(input, mask, scale) + else: + return ScaledSoftmax.apply(input, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/training/DeepSpeed-Domino/megatron/core/inference_params.py b/training/DeepSpeed-Domino/megatron/core/inference_params.py new file mode 100644 index 000000000..287902460 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/inference_params.py @@ -0,0 +1,27 @@ +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_length): + self.max_sequence_length = max_sequence_length + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + + def swap_key_value_dict(self, batch_idx): + "swap between batches" + if len(self.key_value_memory_dict) == 0: + raise ValueError("should not swap when dict in empty") + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] + assert ( + len(batch_idx) == inference_key_memory.shape[1] + ) # make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, + new_inference_value_memory, + ) diff --git a/training/DeepSpeed-Domino/megatron/core/model_parallel_config.py b/training/DeepSpeed-Domino/megatron/core/model_parallel_config.py new file mode 100644 index 000000000..85d3c8e7b --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/model_parallel_config.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Callable, Optional + +import torch + + +@dataclass +class ModelParallelConfig: + """Base configuration for Megatron Core + + Model Parallelism + ----------------- + + tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1. + + pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU + ranks. Defaults to 1. + + virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by + reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. + The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient + Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for + more details. Defaults to None. + + sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by + parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer + Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False. + + Initialization + -------------- + + perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you + know you are going to load values from a checkpoint. + + use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU. + Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False. + + Training + -------- + + fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False. + + bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False. + + params_dtype (torch.dtype): dtype used when intializing the weights. Defaults to torch.float32 + + timers (optional, default=None): TODO + + Optimizations + ------------- + + gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA + extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" + ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. + Defaults to False. + + async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of + tensor-model-parallel all-reduce with weight gradient compuation of a column-linear layer. Defaults to False. + + Pipeline Parallelism + -------------------- + + pipeline_dtype (required): dtype used in p2p communication, usually params_dtype + + grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the + scaled loss. If None, no function is called on the loss. + + enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False. + + autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype. + + variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this + communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it + should only be set if the sequence length varies by microbatch within a global batch. + + num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches + where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window + of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If + None, the checkpoint and recompute will be left up to the forward_step function. + + overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline + parallelism will overlap with computation. Must be False if batch_p2p_comm is true. + + batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False + if overlap_p2p_comm is True. + + batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work + around a bug in older version of PyTorch. + + use_ring_exchange_p2p (bool, default = False): Use custom ring_exchange kernel instead of + torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange. + + deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent + to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used. + + no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel + communication. If the model is an instance of torch.nn.DistributedDataParallel, the default is to use + torch.nn.DistributedDataParallel.no_sync. + + grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer + gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are + to be synchronized. + + param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed + optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be + synchronized. + + """ + + # Model parallelism + tensor_model_parallel_size: int = 1 + pipeline_model_parallel_size: int = 1 + virtual_pipeline_model_parallel_size: Optional[int] = None + sequence_parallel: bool = False + + # Initialization + perform_initialization: bool = True + use_cpu_initialization: bool = False + + # Training + fp16: bool = False + bf16: bool = False + params_dtype: torch.dtype = torch.float32 + timers: Callable = None + + # Optimizations + gradient_accumulation_fusion: bool = False + async_tensor_model_parallel_allreduce: bool = False + + # Pipeline Parallel + pipeline_dtype: torch.dtype = None + grad_scale_func: Callable = None + enable_autocast: bool = False + autocast_dtype: torch.dtype = None + variable_seq_lengths: bool = False + num_microbatches_with_partial_activation_checkpoints: Optional[int] = None + overlap_p2p_comm: bool = False + batch_p2p_comm: bool = True + batch_p2p_sync: bool = True + use_ring_exchange_p2p: bool = False + deallocate_pipeline_outputs: bool = False + no_sync_func: Callable = None + grad_sync_func: Callable = None + param_sync_func: Callable = None + + def __post_init__(self): + """ Python dataclass method that is used to modify attributes after initialization. + See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. + """ + if self.sequence_parallel: + if self.tensor_model_parallel_size <= 1: + raise ValueError("Can not use sequence paralllelism without tensor parallelism") + if self.async_tensor_model_parallel_allreduce: + # sequence_parallelism already does this async + self.async_tensor_model_parallel_allreduce = False + + if self.pipeline_model_parallel_size > 1: + if self.pipeline_dtype is None: + raise ValueError( + "When using pipeline parallelism, pipeline_dtype must be specified" + ) + + if self.autocast_dtype is None: + self.autocast_dtype = self.params_dtype diff --git a/training/DeepSpeed-Domino/megatron/core/models/__init__.py b/training/DeepSpeed-Domino/megatron/core/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/megatron/core/models/common/__init__.py b/training/DeepSpeed-Domino/megatron/core/models/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/megatron/core/models/common/rotary_pos_embedding.py b/training/DeepSpeed-Domino/megatron/core/models/common/rotary_pos_embedding.py new file mode 100644 index 000000000..291b10df7 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/common/rotary_pos_embedding.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import importlib.util + +import torch +from torch import einsum, nn + +__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, seq_len_interpolation_factor=None): + super().__init__() + self.seq_len_interpolation_factor = seq_len_interpolation_factor + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, max_seq_len, offset=0): + seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset + if self.seq_len_interpolation_factor is not None: + seq = seq.type_as(self.inv_freq) + seq *= 1 / self.seq_len_interpolation_factor + freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + return emb[:, None, None, :] + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + state_dict.pop(f'{prefix}inv_freq', None) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin()) + return torch.cat((t, t_pass), dim=-1) diff --git a/training/DeepSpeed-Domino/megatron/core/models/gpt/__init__.py b/training/DeepSpeed-Domino/megatron/core/models/gpt/__init__.py new file mode 100644 index 000000000..2d5eb8674 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/gpt/__init__.py @@ -0,0 +1 @@ +from .gpt_model import GPTModel diff --git a/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_embedding.py b/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_embedding.py new file mode 100644 index 000000000..578ae803c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_embedding.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core import tensor_parallel +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import ( + make_sharded_tensor_for_checkpoint, + make_tp_sharded_tensor_for_checkpoint, +) + + +class GPTEmbedding(MegatronModule): + """Language model embeddings. + + Arguments: + config (TransformerConfig): config object with all necessary configs for TransformerBlock + vocab_size (int): vocabulary size + max_sequence_length (int): maximum size of sequence. This + is used for positional embedding + add_position_embedding (bool): Add a position embedding. + embedding_dropout_prob float): dropout probability for embeddings + """ + + def __init__( + self, + config: TransformerConfig, + vocab_size: int, + max_sequence_length: int, + add_position_embedding: bool, + ): + super().__init__(config=config) + + self.config: TransformerConfig = config + self.vocab_size: int = vocab_size + self.max_sequence_length: int = max_sequence_length + self.add_position_embedding: bool = add_position_embedding + + # Word embeddings (parallel). + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + num_embeddings=self.vocab_size, + embedding_dim=self.config.hidden_size, + init_method=self.config.init_method, + config=self.config, + ) + + # Position embedding (serial). + if self.add_position_embedding: + self.position_embeddings = torch.nn.Embedding( + self.max_sequence_length, self.config.hidden_size + ) + + # Initialize the position embeddings. + if self.config.perform_initialization: + self.config.init_method(self.position_embeddings.weight) + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + + def forward(self, input_ids, position_ids): + # Embeddings. + word_embeddings = self.word_embeddings(input_ids) + if self.add_position_embedding: + position_embeddings = self.position_embeddings(position_ids) + embeddings = word_embeddings + position_embeddings + else: + embeddings = word_embeddings + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.config.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.config.sequence_parallel: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + with tensor_parallel.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def sharded_state_dict(self, prefix=''): + + sharded_state_dict = {} + + word_embeddings_prefix = f'{prefix}word_embeddings.' + word_embeddings_state_dict = self.word_embeddings.state_dict( + prefix=word_embeddings_prefix, keep_vars=True + ) + + sharded_word_embeddings_key = f'{word_embeddings_prefix}weight' + sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint( + tensor=word_embeddings_state_dict[sharded_word_embeddings_key], + key=sharded_word_embeddings_key, + allow_shape_mismatch=True, + ) + sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor + + if self.add_position_embedding: + position_embeddings_prefix = f'{prefix}position_embeddings.' + position_embeddings_state_dict = self.position_embeddings.state_dict( + prefix=position_embeddings_prefix, keep_vars=True + ) + sharded_position_embeddings_key = f'{position_embeddings_prefix}weight' + sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint( + tensor=position_embeddings_state_dict[sharded_position_embeddings_key], + key=sharded_position_embeddings_key, + ) + sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor + + return sharded_state_dict diff --git a/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_model.py b/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_model.py new file mode 100644 index 000000000..f1c304b7a --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/models/gpt/gpt_model.py @@ -0,0 +1,308 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import logging +from typing import Literal, Optional + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.gpt.gpt_embedding import GPTEmbedding +from megatron.core.transformer.enums import AttnMaskType, ModelType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint + + +class GPTModel(MegatronModule): + """Transformer language model. + + Arguments: + config (TransformerConfig): transformer config + + vocab_size (int): vocabulary size + + max_sequence_length (int): maximum size of sequence. This is used for positional embedding + + pre_process (bool): Include embedding layer (used with pipeline parallelism) + post_process (bool): Include an output layer (used with pipeline parallelism) + + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks + + share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are + shared. Defaults to False. + + position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. + Defaults is 'learned_absolute'. + + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + + seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences. + The value must be a float larger than 1.0. Defaults to None. + """ + + def __init__( + self, + config: TransformerConfig, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[float] = None, + ): + super(GPTModel, self).__init__(config=config) + + self.config: TransformerConfig = config + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + # Embeddings. + if self.pre_process: + self.embedding = GPTEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + add_position_embedding=(self.position_embedding_type == 'learned_absolute'), + ) + + # Rotary Position Embeddings + if self.position_embedding_type == 'rope': + rotary_dim = self.config.kv_channels + if rotary_percent < 1.0: + rotary_dim = int(rotary_dim * rotary_percent) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim, seq_len_interpolation_factor) + else: + self.rotary_pos_emb = None + + # Transformer. + self.decoder = TransformerBlock( + config=self.config, + self_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + # Output + if post_process: + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + ) + + if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): + self.initialize_last_stage_with_word_embeddings() + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt' + self.decoder.set_input_tensor(input_tensor[0]) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params=None, + ): + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.rotary_pos_emb is not None: + if inference_params is not None: + rotary_seq_len = inference_params.max_sequence_length + else: + if self.decoder.input_tensor is not None: + rotary_seq_len = self.decoder.input_tensor.size(0) + else: + rotary_seq_len = decoder_input.size(0) + + # Decoder input is split along sequence dimension, but RoPE is applied in tensor parallel region + if self.config.sequence_parallel: + rotary_seq_len *= self.config.tensor_model_parallel_size + + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + # [b s] => [s b] + labels = labels.transpose(0, 1).contiguous() + loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0, 1).contiguous() + return loss + + def shared_embedding_or_output_weight(self): + if self.pre_process: + return self.embedding.word_embeddings.weight + elif self.post_process: + return self.output_layer.weight + return None + + def initialize_last_stage_with_word_embeddings(self): + + # This function just initializes the word embeddings in the final stage + # when we are using pipeline parallelism and sharing word + # embeddings. Nothing to do if we aren't sharing weights or aren't using + # pipeline parallelism. + if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process): + return + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.output_layer.weight.data.fill_(0) + self.output_layer.weight.shared = True + + # Parameters are shared between the word embeddings layers, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + + # Ensure that first and last stages have the same initial parameter + # values. + if torch.distributed.is_initialized(): + if parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + torch.distributed.all_reduce( + weight.data, group=parallel_state.get_embedding_group() + ) + + elif not getattr(GPTModel, "embedding_warning_printed", False): + logging.getLogger(__name__).warning( + "Distributed processes aren't initialized, so the output layer " + "is not initialized with weights from the word embeddings. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong." + ) + GPTModel.embedding_warning_printed = True + + def sharded_state_dict(self, prefix=''): + sharded_state_dict = {} + + if self.pre_process: + embedding_prefix = f'{prefix}embedding.' + embedding_sharded_state_dict = self.embedding.sharded_state_dict( + prefix=embedding_prefix + ) + sharded_state_dict.update(embedding_sharded_state_dict) + + decoder_prefix = f'{prefix}decoder.' + decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix) + sharded_state_dict.update(decoder_sharded_state_dict) + + if self.post_process: + output_layer_prefix = f'{prefix}output_layer.' + output_layer_key = f'{output_layer_prefix}weight' + if self.share_embeddings_and_output_weights: + if not self.pre_process: + # when sharing embeddings with last stage, we need to use the weights from the first stage + # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight + tensor = self.shared_embedding_or_output_weight() + first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + last_stage_word_emb_replica_id = ( + dp_rank + dp_size + ) # copy of first stage embedding + + sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( + tensor=tensor, + key=first_stage_word_emb_key, + replica_id=last_stage_word_emb_replica_id, + allow_shape_mismatch=True, + ) + + sharded_state_dict[output_layer_key] = sharded_output_layer_tensor + + else: + output_layer_state_dict = self.output_layer.state_dict( + prefix=output_layer_prefix, keep_vars=True + ) + output_layer_tensor = output_layer_state_dict[output_layer_key] + # independent output layer + sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( + tensor=output_layer_tensor, + key=output_layer_key, + replica_id=parallel_state.get_data_parallel_rank(), + allow_shape_mismatch=True, + ) + + sharded_state_dict[output_layer_key] = sharded_output_layer_tensor + + return sharded_state_dict diff --git a/training/DeepSpeed-Domino/megatron/core/package_info.py b/training/DeepSpeed-Domino/megatron/core/package_info.py new file mode 100644 index 000000000..47025af14 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/package_info.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + + +MAJOR = 0 +MINOR = 3 +PATCH = 0 +PRE_RELEASE = '' + +# Use the following formatting: (major, minor, patch, pre-release) +VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) + +__shortversion__ = '.'.join(map(str, VERSION[:3])) +__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) + +__package_name__ = 'megatron_core' +__contact_names__ = 'NVIDIA' +__contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email +__homepage__ = ( + 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage +) +__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core' +__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases' +__description__ = ( + 'Megatron Core - a library for efficient and scalable training of transformer based models' +) +__license__ = 'BSD-3' +__keywords__ = ( + 'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch' +) diff --git a/training/DeepSpeed-Domino/megatron/core/parallel_state.py b/training/DeepSpeed-Domino/megatron/core/parallel_state.py new file mode 100644 index 000000000..52e7c46b9 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/parallel_state.py @@ -0,0 +1,651 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Model and data parallel groups.""" + +import os +from typing import Optional + +import torch + +from .utils import GlobalMemoryBuffer + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Inter-layer model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Embedding group. +_EMBEDDING_GROUP = None +# Position embedding group. +_POSITION_EMBEDDING_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_GROUP_GLOO = None +# FP8 amax reduction group. +_AMAX_REDUCTION_GROUP = None + +_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None +_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None + +# These values enable us to change the mpu sizes on the fly. +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_PIPELINE_MODEL_PARALLEL_RANK = None + +# A list of ranks that have a copy of the embedding. +_EMBEDDING_GLOBAL_RANKS = None + +# A list of ranks that have a copy of the position embedding. +_POSITION_EMBEDDING_GLOBAL_RANKS = None + +# A list of global ranks for each pipeline group to ease calculation of the source +# rank when broadcasting from the first or last pipeline stage. +_PIPELINE_GLOBAL_RANKS = None + +# A list of global ranks for each data parallel group to ease calculation of the source +# rank when broadcasting weights from src to all other data parallel ranks +_DATA_PARALLEL_GLOBAL_RANKS = None + +# Memory buffers to avoid dynamic memory allocation +_GLOBAL_MEMORY_BUFFER = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_split_rank: Optional[int] = None, + use_fp8: bool = False, + use_sharp: bool = False, +) -> None: + """Initialize model data parallel groups. + + Arguments: + tensor_model_parallel_size (int, default = 1): + The number of GPUs to split individual tensors across. + + pipeline_model_parallel_size (int, default = 1): + The number of tensor parallel GPU groups to split the + Transformer layers across. For example, if + tensor_model_parallel_size is 4 and + pipeline_model_parallel_size is 2, the model will be split + into 2 groups of 4 GPUs. + + virtual_pipeline_model_parallel_size (int, optional): + The number of stages that each pipeline group will have, + interleaving as necessary. If None, no interleaving is + performed. For example, if tensor_model_parallel_size is 1, + pipeline_model_parallel_size is 4, + virtual_pipeline_model_parallel_size is 2, and there are + 16 transformer layers in the model, the model will be + split into 8 stages with two layers each and each GPU + would get 2 stages as such (layer number starting with 1): + + GPU 0: [1, 2] [9, 10] + GPU 1: [3, 4] [11, 12] + GPU 2: [5, 6] [13, 14] + GPU 3: [7, 8] [15, 16] + + pipeline_model_parallel_split_rank (int, optional): + For models with both an encoder and decoder, the rank in + pipeline to switch between encoder and decoder (i.e. the + first rank of the decoder). This allows the user to set + the pipeline parallel size of the encoder and decoder + independently. For example, if + pipeline_model_parallel_size is 8 and + pipeline_model_parallel_split_rank is 3, then ranks 0-2 + will be the encoder and ranks 3-7 will be the decoder. + + use_fp8 (bool, default = False): + Construct GPU groups needed for FP8 training, namely for + amax reduction across the product of the data-parallel and + tensor-parallel groups. + + use_sharp (bool, default = False): + Set the use of SHARP for the collective communications of + data-parallel process groups. When `True`, run barrier + within each data-parallel process group, which specifies + the SHARP application target groups. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 8 tensor model-parallel groups, 4 pipeline model-parallel groups + and 8 data-parallel groups as: + 8 data_parallel groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 8 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 4 pipeline model-parallel groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size + ) + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + num_data_parallel_groups: int = world_size // data_parallel_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 2: + raise RuntimeError( + "pipeline-model-parallel size should be greater than 2 with interleaved schedule" + ) + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size + + if pipeline_model_parallel_split_rank is not None: + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank + + rank = torch.distributed.get_rank() + + # Build the data-parallel groups. + global _DATA_PARALLEL_GROUP + global _DATA_PARALLEL_GROUP_GLOO + global _DATA_PARALLEL_GLOBAL_RANKS + assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized' + all_data_parallel_group_ranks = [] + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = torch.distributed.new_group(ranks) + group_gloo = torch.distributed.new_group(ranks, backend="gloo") + if rank in ranks: + _DATA_PARALLEL_GROUP = group + _DATA_PARALLEL_GROUP_GLOO = group_gloo + _DATA_PARALLEL_GLOBAL_RANKS = ranks + + # Apply SHARP to DP process groups + if use_sharp: + if rank == 0: + print( + "The number of process groups to use SHARP with depends on the type " + "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 " + "process groups and QM2 supports up to 256 process groups. We apply " + "SHARP to the communications of the data-parallel domain. If the " + "number of data-parallel process groups is larger than the max " + "process groups that the network switch supports, the communication " + "will fall back to non-SHARP operators. To enable SHARP, " + "`#SBATCH_NETWORK=sharp` should be set in the sbatch script." + ) + torch.distributed.barrier( + group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()] + ) + # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups + os.environ["NCCL_SHARP_DISABLE"] = "1" + + # Build the model-parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' + for i in range(data_parallel_size): + ranks = [ + data_parallel_group_ranks[i] + for data_parallel_group_ranks in all_data_parallel_group_ranks + ] + group = torch.distributed.new_group(ranks) + if rank in ranks: + _MODEL_PARALLEL_GROUP = group + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is None + ), 'tensor model parallel group is already initialized' + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups and embedding groups + # (first and last rank in each pipeline model-parallel group). + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is None + ), 'pipeline model parallel group is already initialized' + global _EMBEDDING_GROUP + global _EMBEDDING_GLOBAL_RANKS + assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + global _POSITION_EMBEDDING_GROUP + global _POSITION_EMBEDDING_GLOBAL_RANKS + assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + # Setup embedding group (to exchange gradients between + # first and last stages). + if len(ranks) > 1: + embedding_ranks = [ranks[0], ranks[-1]] + position_embedding_ranks = [ranks[0]] + if pipeline_model_parallel_split_rank is not None: + if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: + embedding_ranks = [ + ranks[0], + ranks[pipeline_model_parallel_split_rank], + ranks[-1], + ] + if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: + position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]] + else: + embedding_ranks = ranks + position_embedding_ranks = ranks + + group = torch.distributed.new_group(embedding_ranks) + if rank in embedding_ranks: + _EMBEDDING_GROUP = group + if rank in ranks: + _EMBEDDING_GLOBAL_RANKS = embedding_ranks + + group = torch.distributed.new_group(position_embedding_ranks) + if rank in position_embedding_ranks: + _POSITION_EMBEDDING_GROUP = group + if rank in ranks: + _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks + + # Build the FP8 groups. + global _AMAX_REDUCTION_GROUP + assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized' + if use_fp8: + amax_group_size: int = tensor_model_parallel_size * data_parallel_size + num_amax_groups: int = world_size // amax_group_size + for i in range(num_amax_groups): + start_rank = i * amax_group_size + end_rank = (i + 1) * amax_group_size + ranks = range(start_rank, end_rank) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _AMAX_REDUCTION_GROUP = group + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. If we end up with a more generic initialization of megatron-core + # we could stick it there + _set_global_memory_buffer() + + +def is_unitialized(): + """Useful for code segments that may be accessed with or without mpu initialization""" + return _DATA_PARALLEL_GROUP is None + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if ( + _TENSOR_MODEL_PARALLEL_GROUP is None + or _PIPELINE_MODEL_PARALLEL_GROUP is None + or _DATA_PARALLEL_GROUP is None + ): + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(check_initialized=True): + """Get the tensor model parallel group the caller rank belongs to.""" + if check_initialized: + assert ( + _TENSOR_MODEL_PARALLEL_GROUP is not None + ), 'tensor model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert ( + _PIPELINE_MODEL_PARALLEL_GROUP is not None + ), 'pipeline_model parallel group is not initialized' + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_data_parallel_group_gloo(): + """Get the data parallel group-gloo the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized' + return _DATA_PARALLEL_GROUP_GLOO + + +def get_embedding_group(): + """Get the embedding group the caller rank belongs to.""" + assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized' + return _EMBEDDING_GROUP + + +def get_position_embedding_group(): + """Get the position embedding group the caller rank belongs to.""" + assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized' + return _POSITION_EMBEDDING_GROUP + + +def get_amax_reduction_group(): + """Get the FP8 amax reduction group the caller rank belongs to.""" + assert _AMAX_REDUCTION_GROUP is not None, 'FP8 amax reduction group is not initialized' + return _AMAX_REDUCTION_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def set_virtual_pipeline_model_parallel_world_size(world_size): + """Set the pipeline model parallel size""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_rank(rank): + """Set pipeline model parallel rank.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def set_pipeline_model_parallel_split_rank(rank): + """Set pipeline model parallel split rank.""" + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: + return _MPU_PIPELINE_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) + + +def get_pipeline_model_parallel_split_rank(): + """Return pipeline model parallel split rank.""" + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + + +def is_pipeline_first_stage(ignore_virtual=False): + """Return True if in the first pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + if ( + get_virtual_pipeline_model_parallel_world_size() is not None + and get_virtual_pipeline_model_parallel_rank() != 0 + ): + return False + return get_pipeline_model_parallel_rank() == 0 + + +def is_pipeline_last_stage(ignore_virtual=False): + """Return True if in the last pipeline model-parallel stage, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = ( + get_virtual_pipeline_model_parallel_world_size() + ) + if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1 + ): + return False + return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) + + +def is_rank_in_embedding_group(ignore_virtual=False): + """Return true if current rank is in embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _EMBEDDING_GLOBAL_RANKS + if ignore_virtual: + return rank in _EMBEDDING_GLOBAL_RANKS + if rank in _EMBEDDING_GLOBAL_RANKS: + if rank == _EMBEDDING_GLOBAL_RANKS[0]: + return is_pipeline_first_stage(ignore_virtual=False) + elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: + return is_pipeline_last_stage(ignore_virtual=False) + else: + return True + return False + + +def is_rank_in_position_embedding_group(): + """Return true if current rank is in position embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _POSITION_EMBEDDING_GLOBAL_RANKS + return rank in _POSITION_EMBEDDING_GLOBAL_RANKS + + +def is_pipeline_stage_before_split(rank=None): + """Return True if pipeline stage executes encoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + + +def is_pipeline_stage_at_split(): + """Return true if pipeline stage executes decoder block and next + stage executes encoder block for a model with both encoder and + decoder.""" + rank = get_pipeline_model_parallel_rank() + return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1) + + +def get_virtual_pipeline_model_parallel_rank(): + """Return the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + + +def set_virtual_pipeline_model_parallel_rank(rank): + """Set the virtual pipeline-parallel rank.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank + + +def get_virtual_pipeline_model_parallel_world_size(): + """Return the virtual pipeline-parallel world size.""" + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the data parallel group.""" + assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized" + return _DATA_PARALLEL_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_first_rank(): + """Return the global rank of the first process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + """Return the global rank of the last process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + + +def get_pipeline_model_parallel_next_rank(): + """Return the global rank that follows the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + """Return the global rank that preceeds the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_world_size(group=get_data_parallel_group()) + else: + return 0 + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank(group=get_data_parallel_group()) + else: + return 0 + + +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def get_global_memory_buffer(): + """Return the global GlobalMemoryBuffer object""" + assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' + return _GLOBAL_MEMORY_BUFFER + + +def destroy_global_memory_buffer(): + """Sets the global memory buffer to None""" + global _GLOBAL_MEMORY_BUFFER + _GLOBAL_MEMORY_BUFFER = None + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None + global _EMBEDDING_GROUP + _EMBEDDING_GROUP = None + global _POSITION_EMBEDDING_GROUP + _POSITION_EMBEDDING_GROUP = None + global _AMAX_REDUCTION_GROUP + _AMAX_REDUCTION_GROUP = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK + _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None + global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE + _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = None + global _MPU_PIPELINE_MODEL_PARALLEL_RANK + _MPU_PIPELINE_MODEL_PARALLEL_RANK = None + global _GLOBAL_MEMORY_BUFFER + _GLOBAL_MEMORY_BUFFER = None diff --git a/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/__init__.py b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/__init__.py new file mode 100644 index 000000000..00cd1ff38 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/__init__.py @@ -0,0 +1 @@ +from .schedules import get_forward_backward_func diff --git a/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/p2p_communication.py b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/p2p_communication.py new file mode 100644 index 000000000..29ee34df8 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/p2p_communication.py @@ -0,0 +1,571 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import operator +from functools import reduce +from typing import Callable, List, Optional, Tuple, Union + +import torch + +from megatron import core +from megatron.core import ModelParallelConfig +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_next_rank, + get_pipeline_model_parallel_prev_rank, + get_pipeline_model_parallel_rank, +) + +# Types +Shape = Union[List[int], torch.Size] + + +def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config): + """Communicate tensor shapes between stages. Used to communicate + tensor shapes before the actual tensor communication happens. + This is required when the sequence lengths across micro batches + are not uniform. + + Takes the following arguments: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + tensor_send_prev: tensor to send to prev rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + recv_next: boolean for whether tensor should be received from + next rank. + Returns: + (recv_prev_shape, recv_next_shape) + """ + + recv_prev_shape_tensor = None + recv_next_shape_tensor = None + send_prev_shape_tensor = None + send_next_shape_tensor = None + if recv_prev: + recv_prev_shape_tensor = torch.empty( + (3), device=torch.cuda.current_device(), dtype=torch.int64 + ) + if recv_next: + recv_next_shape_tensor = torch.empty( + (3), device=torch.cuda.current_device(), dtype=torch.int64 + ) + if tensor_send_prev is not None: + send_prev_shape_tensor = torch.tensor( + tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64 + ) + if tensor_send_next is not None: + send_next_shape_tensor = torch.tensor( + tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64 + ) + + if config.use_ring_exchange_p2p: + torch.distributed.ring_exchange( + tensor_send_prev=send_prev_shape_tensor, + tensor_recv_prev=recv_prev_shape_tensor, + tensor_send_next=send_next_shape_tensor, + tensor_recv_next=recv_next_shape_tensor, + group=get_pipeline_model_parallel_group(), + ) + else: + ops = [] + if send_prev_shape_tensor is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, + send_prev_shape_tensor, + get_pipeline_model_parallel_prev_rank(), + ) + ops.append(send_prev_op) + if recv_prev_shape_tensor is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_shape_tensor, + get_pipeline_model_parallel_prev_rank(), + ) + ops.append(recv_prev_op) + if send_next_shape_tensor is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_shape_tensor, + get_pipeline_model_parallel_next_rank(), + ) + ops.append(send_next_op) + if recv_next_shape_tensor is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_next_shape_tensor, + get_pipeline_model_parallel_next_rank(), + ) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + # should take this out once the bug with batch_isend_irecv is resolved. + torch.cuda.synchronize() + + recv_prev_shape = [0, 0, 0] + if recv_prev_shape_tensor is not None: + recv_prev_shape = recv_prev_shape_tensor.tolist() + + recv_next_shape = [0, 0, 0] + if recv_next_shape_tensor is not None: + recv_next_shape = recv_next_shape_tensor.tolist() + + return recv_prev_shape, recv_next_shape + + +def _batched_p2p_ops( + *, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup +): + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_prev, + get_pipeline_model_parallel_prev_rank(), + group, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv_prev, + get_pipeline_model_parallel_prev_rank(), + group, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor_send_next, + get_pipeline_model_parallel_next_rank(), + group, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv_next, + get_pipeline_model_parallel_next_rank(), + group, + ) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + else: + reqs = [] + return reqs + + +def _p2p_ops( + *, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup +): + reqs = [] + rank = get_pipeline_model_parallel_rank() + if get_pipeline_model_parallel_rank() % 2 == 0: + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group, + ) + reqs.append(send_prev_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group, + ) + reqs.append(recv_next_req) + + else: + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group, + ) + reqs.append(recv_next_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group, + ) + reqs.append(send_prev_req) + return reqs + + +def _communicate( + *, + tensor_send_next: Optional[torch.Tensor], + tensor_send_prev: Optional[torch.Tensor], + recv_prev: bool, + recv_next: bool, + tensor_shape: Shape, + config: ModelParallelConfig, + wait_on_reqs: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + """Communicate tensors between stages. Used as helper method in other + communication methods that are used in megatron/schedules.py. + + Arguments: + tensor_send_next (torch.Tensor, optional): + Tensor to send to next rank (no tensor sent if None) + + tensor_send_prev (torch.Tensor, optional): + Tensor to send to prev rank (no tensor sent if None) + + recv_prev (boolean, required): + whether tensor should be received from previous rank. + + recv_next (boolean, required): + whether tensor should be received from next rank. + + tensor_shape (List[int] or torch.Size, required): + shape of tensor to receive (this method assumes that all + tensors sent and received in a single function call are + the same shape). + + wait_on_reqs (boolean, optional, default=False): + For non-batched p2p communication, wait on each request + before returning. + + Returns: + tuple containing + + - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise. + - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise. + + """ + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + if not config.variable_seq_lengths: + recv_prev_shape = tensor_shape + recv_next_shape = tensor_shape + else: + recv_prev_shape, recv_next_shape = _communicate_shapes( + tensor_send_next, tensor_send_prev, recv_prev, recv_next, config + ) + + if recv_prev: + if config.pipeline_dtype is None: + raise RuntimeError("pipeline_dtype must be provided if recv_prev is True") + if tensor_shape is None: + raise RuntimeError( + "tensor_shape must be specified if recv_prev is True. " + "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" + ) + tensor_recv_prev = torch.empty( + recv_prev_shape, + requires_grad=True, + device=torch.cuda.current_device(), + dtype=config.pipeline_dtype, + ) + if recv_next: + if config.pipeline_dtype is None: + raise RuntimeError("dtype must be provided if recv_next is True") + if tensor_shape is None: + raise RuntimeError( + "tensor_shape must be specified if recv_next is True. " + "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" + ) + tensor_recv_next = torch.empty( + recv_next_shape, + requires_grad=True, + device=torch.cuda.current_device(), + dtype=config.pipeline_dtype, + ) + + # Send tensors in both the forward and backward directions as appropriate. + if config.use_ring_exchange_p2p: + + def _ring_exchange_wrapper(**kwargs): + torch.distributed.ring_exchange(**kwargs) + return [] + + p2p_func = _ring_exchange_wrapper + elif config.batch_p2p_comm: + assert wait_on_reqs + p2p_func = _batched_p2p_ops + else: + p2p_func = _p2p_ops + + reqs = p2p_func( + tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=get_pipeline_model_parallel_group(), + ) + + if wait_on_reqs and len(reqs) > 0: + for req in reqs: + req.wait() + reqs = None + + if config.batch_p2p_comm and config.batch_p2p_sync: + # To protect against race condition when using batch_isend_irecv(). + # User should assert that we have a modern enough PyTorch to not need this + torch.cuda.synchronize() + + return tensor_recv_prev, tensor_recv_next, reqs + + +def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: + """ Receive tensor from previous rank in pipeline (forward receive). + + + See _communicate for argument details. + """ + + if core.parallel_state.is_pipeline_first_stage(): + input_tensor = None + else: + if config.timers is not None: + config.timers('forward-recv', log_level=2).start() + input_tensor, _, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False, + tensor_shape=tensor_shape, + config=config, + ) + if config.timers is not None: + config.timers('forward-recv').stop() + return input_tensor + + +def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: + """Receive tensor from next rank in pipeline (backward receive). + + See _communicate for argument details. + """ + if core.parallel_state.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if config.timers is not None: + config.timers('backward-recv', log_level=2).start() + _, output_tensor_grad, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shape=tensor_shape, + config=config, + ) + if config.timers is not None: + config.timers('backward-recv').stop() + return output_tensor_grad + + +def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None: + """Send tensor to next rank in pipeline (forward send). + + See _communicate for argument details. + """ + + if not core.parallel_state.is_pipeline_last_stage(): + if config.timers is not None: + config.timers('forward-send', log_level=2).start() + _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False, + tensor_shape=None, + config=config, + ) + if config.timers is not None: + config.timers('forward-send').stop() + + +def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None: + """Send tensor to previous rank in pipeline (backward send). + + See _communicate for argument details. + """ + if not core.parallel_state.is_pipeline_first_stage(): + if config.timers is not None: + config.timers('backward-send', log_level=2).start() + _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False, + tensor_shape=None, + config=config, + ) + if config.timers is not None: + config.timers('backward-send').stop() + + +def send_forward_recv_backward( + output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig +) -> torch.Tensor: + """Batched send and recv with next rank in pipeline. + + See _communicate for argument details. + """ + if core.parallel_state.is_pipeline_last_stage(): + output_tensor_grad = None + else: + if config.timers is not None: + config.timers('forward-send-backward-recv', log_level=2).start() + _, output_tensor_grad, _ = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True, + tensor_shape=tensor_shape, + config=config, + ) + if config.timers is not None: + config.timers('forward-send-backward-recv').stop() + return output_tensor_grad + + +def send_backward_recv_forward( + input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig +) -> torch.Tensor: + """Batched send and recv with previous rank in pipeline. + + See _communicate for argument details. + """ + if core.parallel_state.is_pipeline_first_stage(): + input_tensor = None + else: + if config.timers is not None: + config.timers('backward-send-forward-recv', log_level=2).start() + input_tensor, _, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False, + tensor_shape=tensor_shape, + config=config, + ) + if config.timers is not None: + config.timers('backward-send-forward-recv').stop() + return input_tensor + + +def send_forward_recv_forward( + output_tensor: torch.Tensor, + recv_prev: bool, + tensor_shape: Shape, + config: ModelParallelConfig, + overlap_p2p_comm: bool = False, +) -> torch.Tensor: + """Batched recv from previous rank and send to next rank in pipeline. + + See _communicate for argument details. + """ + if config.timers is not None: + config.timers('forward-send-forward-recv', log_level=2).start() + input_tensor, _, wait_handles = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=recv_prev, + recv_next=False, + tensor_shape=tensor_shape, + wait_on_reqs=(not overlap_p2p_comm), + config=config, + ) + if config.timers is not None: + config.timers('forward-send-forward-recv').stop() + if overlap_p2p_comm: + return input_tensor, wait_handles + return input_tensor + + +def send_backward_recv_backward( + input_tensor_grad: torch.Tensor, + recv_next: bool, + tensor_shape: Shape, + config: ModelParallelConfig, + overlap_p2p_comm: bool = False, +) -> torch.Tensor: + """Batched recv from next rank and send to previous rank in pipeline. + + See _communicate for argument details. + """ + if config.timers is not None: + config.timers('backward-send-backward-recv', log_level=2).start() + _, output_tensor_grad, wait_handles = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=recv_next, + tensor_shape=tensor_shape, + wait_on_reqs=(not overlap_p2p_comm), + config=config, + ) + if config.timers is not None: + config.timers('backward-send-backward-recv').stop() + if overlap_p2p_comm: + return output_tensor_grad, wait_handles + return output_tensor_grad + + +def send_forward_backward_recv_forward_backward( + output_tensor: torch.Tensor, + input_tensor_grad: torch.Tensor, + recv_prev: bool, + recv_next: bool, + tensor_shape: Shape, + config: ModelParallelConfig, +) -> torch.Tensor: + """Batched send and recv with previous and next ranks in pipeline. + + See _communicate for argument details. + """ + if config.timers is not None: + config.timers('forward-backward-send-forward-backward-recv', log_level=2).start() + input_tensor, output_tensor_grad, _ = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) + if config.timers is not None: + config.timers('forward-backward-send-forward-backward-recv').stop() + return input_tensor, output_tensor_grad diff --git a/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/schedules.py b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/schedules.py new file mode 100644 index 000000000..6eeb15b5c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/pipeline_parallel/schedules.py @@ -0,0 +1,1254 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import contextlib +from typing import Callable, Iterator, List, Optional, Union + +import torch +from torch.autograd.variable import Variable +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + +from megatron import core +from megatron.core import parallel_state +from megatron.core.enums import ModelType +from megatron.core.pipeline_parallel import p2p_communication +from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type + +# Types +Shape = Union[List[int], torch.Size] + + +def get_forward_backward_func(): + """Retrieves the appropriate forward_backward function given the + configuration of parallel_state. + + Returns a function that will perform all of the forward and + backward passes of the model given the pipeline model parallel + world size and virtual pipeline model parallel world size in the + global parallel_state. + + Note that if using sequence parallelism, the sequence length component of + the tensor shape is updated to original_sequence_length / + tensor_model_parallel_world_size. + + The function returned takes the following arguments: + + forward_step_func (required): A function that takes a data + iterator and a model as its arguments and return the model's + forward output and the loss function. The loss function should + take one torch.Tensor and return a torch.Tensor of loss and a + dictionary of string -> torch.Tensor. + + A third argument, checkpoint_activations_microbatch, indicates + that the activations for this microbatch should be + checkpointed. A None value for this argument indicates that + the default from the configuration should be used. This is + used when the + num_microbatches_with_partial_activation_checkpoints is used. + + For example: + + def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + def forward_step(data_iterator, model): + data, loss_mask = next(data_iterator) + output = model(data) + return output, partial(loss_func, loss_mask) + + + forward_backward_func(forward_step_func=forward_step, ...) + + + data_iterator (required): an iterator over the data, will be + passed as is to forward_step_func. Expected to be a list of + iterators in the case of interleaved pipeline parallelism. + + model (required): the actual model. Expected to be a list of modules in the case of interleaved + pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule. + + num_microbatches (int, required): + The number of microbatches to go through + + seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack + transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths + in the config is True. Otherwise, each microbatch in the current global batch size must use + this sequence length. + + micro_batch_size (int, required): The number of sequences in a microbatch. + + decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack + transformer. This is ignored for a single-stack transformer. + + forward_only (optional, default = False): Perform only the forward step + + collect_non_loss_data (optional, bool, default=False): TODO + + """ + pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() + if pipeline_model_parallel_size > 1: + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + forward_backward_func = forward_backward_pipelining_with_interleaving + else: + forward_backward_func = forward_backward_pipelining_without_interleaving + else: + forward_backward_func = forward_backward_no_pipelining + return forward_backward_func + + +def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): + '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + ''' + if (out is None) or (not deallocate_pipeline_outputs): + return + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." + out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) + + +def custom_backward(output, grad_output): + '''Directly call C++ autograd engine. + + To make the 'deallocate_output_tensor' (above) optimization work, the C++ + autograd engine must be called directly, bypassing Pytorch's + torch.autograd.backward. Pytorch's 'backward' checks that the output and + grad have the same shape, while C++'s 'backward' does not. + ''' + + assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory" + assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__ + assert isinstance(grad_output, (torch.Tensor, type(None))), ( + "grad_output == '%s'." % type(grad_output).__name__ + ) + + # Handle scalar output + if grad_output is None: + assert output.numel() == 1, "implicit grad requires scalar output." + grad_output = torch.ones_like(output, memory_format=torch.preserve_format,) + + # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] + Variable._execution_engine.run_backward( + tensors=(output,), + grad_tensors=(grad_output,), + keep_graph=False, + create_graph=False, + inputs=tuple(), + allow_unreachable=True, + accumulate_grad=True, + ) + + +def forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data=False, + checkpoint_activations_microbatch=None, +): + """Forward step for passed-in model. + + If first stage, input tensor is obtained from data_iterator, otherwise + passed-in input_tensor is used. + + Returns output tensor.""" + if config.timers is not None: + config.timers('forward-compute', log_level=2).start() + + unwrap_output_tensor = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_output_tensor = True + + set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") + set_input_tensor(input_tensor) + + if config.enable_autocast: + context_manager = torch.autocast("cuda", dtype=config.autocast_dtype) + else: + context_manager = contextlib.nullcontext() + with context_manager: + if checkpoint_activations_microbatch is None: + output_tensor, loss_func = forward_step_func(data_iterator, model) + else: + output_tensor, loss_func = forward_step_func( + data_iterator, model, checkpoint_activations_microbatch + ) + + if parallel_state.is_pipeline_last_stage(): + if not collect_non_loss_data: + output_tensor = loss_func(output_tensor) + loss, loss_reduced = output_tensor + output_tensor = loss / num_microbatches + forward_data_store.append(loss_reduced) + else: + data = loss_func(output_tensor, non_loss_data=True) + forward_data_store.append(data) + + if config.timers is not None: + config.timers('forward-compute').stop() + + # If T5 model (or other model with encoder and decoder) + # and in decoder stack, then send encoder_hidden_state + # downstream as well. + model_type = get_model_type(model) + if ( + parallel_state.is_pipeline_stage_after_split() + and model_type == ModelType.encoder_and_decoder + ): + return [output_tensor, input_tensor[-1]] + if unwrap_output_tensor: + return output_tensor + return [output_tensor] + + +def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + # NOTE: This code currently can handle at most one skip connection. It + # needs to be modified slightly to support arbitrary numbers of skip + # connections. + + if config.timers is not None: + config.timers('backward-compute', log_level=2).start() + + # Retain the grad on the input_tensor. + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. + if output_tensor_grad[0] is None and config.grad_scale_func is not None: + output_tensor[0] = config.grad_scale_func(output_tensor[0]) + + if config.deallocate_pipeline_outputs: + custom_backward(output_tensor[0], output_tensor_grad[0]) + else: + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + + # Collect the grad of the input_tensor. + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + # Handle single skip connection if it exists (encoder_hidden_state in + # model with encoder and decoder). + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.is_pipeline_stage_after_split() + and model_type == ModelType.encoder_and_decoder + ): + if output_tensor_grad[1] is not None: + input_tensor_grad[-1].add_(output_tensor_grad[1]) + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + if config.timers is not None: + config.timers('backward-compute').stop() + + return input_tensor_grad + + +def forward_backward_no_pipelining( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, # unused + micro_batch_size: int, # unused + decoder_seq_length: int = None, # unused + forward_only: bool = False, + collect_non_loss_data: bool = False, +): + """Run forward and backward passes with no pipeline parallelism + (no inter-stage communication). + + Returns dictionary with losses. + + + See get_forward_backward_func() for argument details + """ + + if isinstance(model, list): + assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + + no_sync_func = config.no_sync_func + if no_sync_func is None and isinstance(model, torchDDP): + no_sync_func = model.no_sync + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + + model_type = get_model_type(model) + + forward_data_store = [] + input_tensor, output_tensor_grad = None, None + with no_sync_func(): + for i in range(num_microbatches - 1): + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + return forward_data_store + + +def forward_backward_pipelining_with_interleaving( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, +): + """Run interleaved 1F1B schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking" + assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking" + assert isinstance( + data_iterator, list + ), "interleaved pipeline parallelism expected each model chunk to have a data iterator" + + config = get_model_config(model[0]) + if config.overlap_p2p_comm and config.batch_p2p_comm: + raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model): + + def multi_no_sync(): + stack = contextlib.ExitStack() + for chunk in model: + stack.enter_context(chunk.no_sync()) + return stack + + no_sync_func = multi_no_sync + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + # Model chunk IDs with synchronized grads + synchronized_model_chunks = set() + + input_tensors = [[] for _ in range(len(model))] + output_tensors = [[] for _ in range(len(model))] + forward_data_store = [] + if not forward_only: + output_tensor_grads = [[] for _ in range(len(model))] + + pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() + pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() + + if num_microbatches % pipeline_parallel_size != 0: + msg = f'number of microbatches ({num_microbatches}) is not divisible by ' + msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) ' + msg += 'when using interleaved schedule' + raise RuntimeError(msg) + + model_type = get_model_type(model[0]) + if model_type == ModelType.encoder_and_decoder: + raise RuntimeError("Interleaving is not supported with an encoder and decoder model.") + + if decoder_seq_length is not None and decoder_seq_length != seq_length: + raise RuntimeError( + "Interleaving is not supported with a different decoder sequence length." + ) + + tensor_shape = [seq_length, micro_batch_size, config.hidden_size] + if config.sequence_parallel: + tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size() + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + total_num_microbatches = num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = total_num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = total_num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) + num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. + # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if config.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + 1 + + # Synchronize params for first two model chunks + if config.param_sync_func is not None: + config.param_sync_func(model[0].parameters()) + config.param_sync_func(model[1].parameters()) + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the first for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == 0: + return microbatch_id_in_group % pipeline_parallel_size == 0 + else: + return False + + def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the last for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == num_microbatch_groups - 1: + return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 + else: + return False + + def forward_step_helper(microbatch_id, checkpoint_activations_microbatch): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + # launch param synchronization for next model chunk + # Note: Asynchronous communication tends to slow down compute. + # To reduce idling from mismatched microbatch times, we launch + # asynchronous communication at the same time across the + # pipeline-parallel group. + if config.param_sync_func is not None: + param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank + if ( + param_sync_microbatch_id < total_num_microbatches + and is_first_microbatch_for_model_chunk(param_sync_microbatch_id) + ): + param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1 + if 1 < param_sync_chunk_id < num_model_chunks: + config.param_sync_func(model[param_sync_chunk_id].parameters()) + + # forward step + if parallel_state.is_pipeline_first_stage(): + if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]): + input_tensors[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id][-1] + output_tensor = forward_step( + forward_step_func, + data_iterator[model_chunk_id], + model[model_chunk_id], + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + output_tensors[model_chunk_id].append(output_tensor) + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_tensors[model_chunk_id].pop() + output_tensors[model_chunk_id].pop() + + return output_tensor + + def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + # launch grad synchronization (default) + if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id): + enable_grad_sync() + synchronized_model_chunks.add(model_chunk_id) + + if parallel_state.is_pipeline_last_stage(): + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + # launch grad synchronization (custom grad sync) + # Note: Asynchronous communication tends to slow down compute. + # To reduce idling from mismatched microbatch times, we launch + # asynchronous communication at the same time across the + # pipeline-parallel group. + if config.grad_sync_func is not None: + grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank + if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( + grad_sync_microbatch_id + ): + grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False) + enable_grad_sync() + config.grad_sync_func(model[grad_sync_chunk_id].parameters()) + synchronized_model_chunks.add(grad_sync_chunk_id) + disable_grad_sync() + + return input_tensor_grad + + # Run warmup forward passes. + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config)) + + fwd_wait_handles = None + bwd_wait_handles = None + + for k in range(num_warmup_microbatches): + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + k % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + output_tensor = forward_step_helper(k, checkpoint_activations_microbatch) + + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (total_num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + if not config.overlap_p2p_comm: + if ( + k == (num_warmup_microbatches - 1) + and not forward_only + and not all_warmup_microbatches + ): + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + ( + input_tensor, + output_tensor_grad, + ) = p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + else: + input_tensor = p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config + ) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + else: + input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + if ( + k == (num_warmup_microbatches - 1) + and not forward_only + and not all_warmup_microbatches + ): + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. + forward_k = k + num_warmup_microbatches + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + forward_k % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + if config.overlap_p2p_comm: + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + # assert fwd_wait_handles is not None + + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + else: # no p2p overlap + output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Communicate tensors. + ( + input_tensor, + output_tensor_grad, + ) = p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Run cooldown backward passes (flush out pipeline). + if not forward_only: + if config.overlap_p2p_comm and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks - 1].append( + p2p_communication.recv_backward(tensor_shape, config=config) + ) + for k in range(num_microbatches_remaining, total_num_microbatches): + input_tensor_grad = backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (total_num_microbatches - 1): + recv_next = False + output_tensor_grads[next_backward_model_chunk_id].append( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config + ) + ) + + # Launch any remaining grad reductions + enable_grad_sync() + if config.grad_sync_func is not None: + params = [] + for model_chunk_id in range(num_model_chunks): + if model_chunk_id not in synchronized_model_chunks: + params.extend(model[model_chunk_id].parameters()) + synchronized_model_chunks.add(model_chunk_id) + if params: + config.grad_sync_func(params) + + return forward_data_store + + +def get_tensor_shapes( + *, + rank: int, + model_type: ModelType, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int, + config, +): + # Determine right tensor sizes (based on position of rank with respect to split + # rank) and model size. + # Send two tensors if model is T5 and rank is in decoder stage: + # first tensor is decoder (pre-transpose), + # second tensor is encoder (post-transpose). + # If model is T5 and rank is at the boundary: + # send one tensor (post-transpose from encoder). + # Otherwise, send one tensor (pre-transpose). + tensor_shapes = [] + + if config.sequence_parallel: + seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size() + if model_type == ModelType.encoder_and_decoder: + decoder_seq_length = ( + decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size() + ) + + if model_type == ModelType.encoder_and_decoder: + if parallel_state.is_pipeline_stage_before_split(rank): + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size)) + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + else: + tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) + return tensor_shapes + + +def recv_forward(tensor_shapes, config): + input_tensors = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + input_tensors.append(None) + else: + input_tensors.append(p2p_communication.recv_forward(tensor_shape, config)) + return input_tensors + + +def recv_backward(tensor_shapes, config): + output_tensor_grads = [] + for tensor_shape in tensor_shapes: + if tensor_shape is None: + output_tensor_grads.append(None) + else: + output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config)) + return output_tensor_grads + + +def send_forward(output_tensors, tensor_shapes, config): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_forward(output_tensor, config) + + +def send_backward(input_tensor_grads, tensor_shapes, config): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + continue + p2p_communication.send_backward(input_tensor_grad, config) + + +def send_forward_recv_backward(output_tensors, tensor_shapes, config): + if not isinstance(output_tensors, list): + output_tensors = [output_tensors] + output_tensor_grads = [] + for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): + if tensor_shape is None: + output_tensor_grads.append(None) + continue + output_tensor_grad = p2p_communication.send_forward_recv_backward( + output_tensor, tensor_shape, config + ) + output_tensor_grads.append(output_tensor_grad) + return output_tensor_grads + + +def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config): + if not isinstance(input_tensor_grads, list): + input_tensor_grads = [input_tensor_grads] + input_tensors = [] + for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): + if tensor_shape is None: + input_tensors.append(None) + continue + input_tensor = p2p_communication.send_backward_recv_forward( + input_tensor_grad, tensor_shape, config + ) + input_tensors.append(input_tensor) + return input_tensors + + +def forward_backward_pipelining_without_interleaving( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, +): + """Run non-interleaved 1F1B schedule, with communication between pipeline + stages. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + + if isinstance(model, list): + assert ( + len(model) == 1 + ), "non-interleaved pipeline parallelism does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + if config.overlap_p2p_comm: + raise ValueError( + "Non-interleaved pipeline parallelism does not support overlapping p2p communication" + ) + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if no_sync_func is None and isinstance(model, torchDDP): + no_sync_func = model.no_sync + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + # Compute number of warmup microbatches. + num_warmup_microbatches = ( + parallel_state.get_pipeline_model_parallel_world_size() + - parallel_state.get_pipeline_model_parallel_rank() + - 1 + ) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. + # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if config.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + 1 + + model_type = get_model_type(model) + + rank = parallel_state.get_pipeline_model_parallel_rank() + recv_tensor_shapes = get_tensor_shapes( + rank=rank - 1, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + send_tensor_shapes = get_tensor_shapes( + rank=rank, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + + # Input, output tensors only need to be saved when doing backward passes + input_tensors = None + output_tensors = None + if not forward_only: + input_tensors = [] + output_tensors = [] + forward_data_store = [] + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + i % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + send_forward(output_tensor, send_tensor_shapes, config) + + if not forward_only: + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + + # Before running 1F1B, need to receive first forward tensor. + # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_microbatches_remaining > 0: + input_tensor = recv_forward(recv_tensor_shapes, config) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + last_iteration = i == (num_microbatches_remaining - 1) + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + (i + num_warmup_microbatches) % max_outstanding_backprops + ) >= config.num_microbatches_with_partial_activation_checkpoints + else: + checkpoint_activations_microbatch = None + + output_tensor = forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + + if forward_only: + send_forward(output_tensor, send_tensor_shapes, config) + + if not last_iteration: + input_tensor = recv_forward(recv_tensor_shapes, config) + + else: + output_tensor_grad = send_forward_recv_backward( + output_tensor, send_tensor_shapes, config + ) + + # Add input_tensor and output_tensor to end of list. + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + + # Pop input_tensor and output_tensor from the start of the list for + # the backward pass. + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + if last_iteration: + input_tensor = None + send_backward(input_tensor_grad, recv_tensor_shapes, config) + else: + input_tensor = send_backward_recv_forward( + input_tensor_grad, recv_tensor_shapes, config + ) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_warmup_microbatches): + + # Enable async grad reduction in the last backward pass + # Note: If grad sync function is provided, only enable + # async grad reduction in first pipeline stage. Other + # pipeline stages do grad reduction during pipeline + # bubble. + if i == num_warmup_microbatches - 1: + if config.grad_sync_func is None or rank == 0: + enable_grad_sync() + + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + output_tensor_grad = recv_backward(send_tensor_shapes, config) + + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + send_backward(input_tensor_grad, recv_tensor_shapes, config) + + # Launch any remaining grad reductions + if no_sync_context is not None: + enable_grad_sync() + if config.grad_sync_func is not None: + config.grad_sync_func(model.parameters()) + + return forward_data_store diff --git a/training/DeepSpeed-Domino/megatron/core/requirements.txt b/training/DeepSpeed-Domino/megatron/core/requirements.txt new file mode 100644 index 000000000..08ed5eeb4 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/requirements.txt @@ -0,0 +1 @@ +torch \ No newline at end of file diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/__init__.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/__init__.py new file mode 100644 index 000000000..dabda5213 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/__init__.py @@ -0,0 +1,56 @@ +from .cross_entropy import vocab_parallel_cross_entropy +from .data import broadcast_data +from .layers import ( + ColumnParallelLinear, + RowParallelLinear, + VocabParallelEmbedding, + copy_tensor_model_parallel_attributes, + linear_with_grad_accumulation_and_async_allreduce, + param_is_not_tensor_parallel_duplicate, + set_defaults_if_not_set_tensor_model_parallel_attributes, + set_tensor_model_parallel_attributes, +) +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + scatter_to_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, +) +from .random import checkpoint, get_cuda_rng_tracker, model_parallel_cuda_manual_seed +from .utils import ( + gather_split_1d_tensor, + split_tensor_along_last_dim, + split_tensor_into_1d_equal_chunks, +) + +__all__ = [ + # cross_entropy.py + "vocab_parallel_cross_entropy", + # data.py + "broadcast_data", + # layers.py + "ColumnParallelLinear", + "RowParallelLinear", + "VocabParallelEmbedding", + "set_tensor_model_parallel_attributes", + "set_defaults_if_not_set_tensor_model_parallel_attributes", + "copy_tensor_model_parallel_attributes", + "param_is_not_tensor_parallel_duplicate", + "linear_with_grad_accumulation_and_async_allreduce", + # mappings.py + "copy_to_tensor_model_parallel_region", + "gather_from_tensor_model_parallel_region", + "gather_from_sequence_parallel_region", + # "reduce_from_tensor_model_parallel_region", + "scatter_to_tensor_model_parallel_region", + "scatter_to_sequence_parallel_region", + # random.py + "checkpoint", + "get_cuda_rng_tracker", + "model_parallel_cuda_manual_seed", + # utils.py + "split_tensor_along_last_dim", + "split_tensor_into_1d_equal_chunks", + "gather_split_1d_tensor", +] diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/cross_entropy.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/cross_entropy.py new file mode 100644 index 000000000..645fd1ea0 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/cross_entropy.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import VocabUtility + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) + + # Get the partition's vocab indecies + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce( + predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), + ) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), + ) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + vocab_size = exp_logits.size(-1) + if label_smoothing > 0: + """ + We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. + = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) + = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i + = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K + From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py + """ + assert 1.0 > label_smoothing > 0.0 + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + + # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. + log_probs = torch.log(exp_logits) + mean_log_probs = log_probs.mean(dim=-1) + loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + + ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + if label_smoothing > 0: + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update + average_grad = 1 / vocab_size + grad_2d[arange_1d, :] -= smoothing * average_grad + else: + grad_2d[arange_1d, masked_target_1d] -= softmax_update + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks + + Arguments: + vocab_parallel_logits: logits split across tensor parallel ranks + dimension is [sequence_length, batch_size, hidden_size] + + target: correct vocab ids of dimseion [sequence_length, micro_batch_size] + + lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) + default is no smoothing (=0.0) + """ + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/data.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/data.py new file mode 100644 index 000000000..45c4fe7eb --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/data.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_src_rank, +) + +_MAX_DATA_DIM = 5 + + +def _check_data_types(keys, data, target_dtype): + """Check that all the keys have the same target data type.""" + for key in keys: + assert data[key].dtype == target_dtype, ( + '{} has data type {} which ' + 'is different than {}'.format(key, data[key].dtype, target_dtype) + ) + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if get_tensor_model_parallel_rank() == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = torch.cuda.LongTensor(sizes) + torch.distributed.broadcast( + sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() + ) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data disctionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) + + # Pack on rank zero. + if get_tensor_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + else: + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) + + # Broadcast + torch.distributed.broadcast( + flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() + ) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/layers.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/layers.py new file mode 100644 index 000000000..834f821e1 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/layers.py @@ -0,0 +1,899 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import math +import os +import warnings +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.parameter import Parameter + +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.parallel_state import ( + get_global_memory_buffer, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, +) +from .random import get_cuda_rng_tracker +from .utils import VocabUtility, divide, split_tensor_along_last_dim + +_grad_accum_fusion_available = True +try: + import fused_weight_gradient_mlp_cuda +except ImportError: + _grad_accum_fusion_available = False + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + 'tensor_model_parallel': False, + 'partition_dim': -1, + 'partition_stride': 1, +} + + +def param_is_not_tensor_parallel_duplicate(param): + return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( + get_tensor_model_parallel_rank() == 0 + ) + + +def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): + # Make sure the attributes are not set. + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + assert not hasattr(tensor, attribute) + # Set the attributes. + setattr(tensor, 'tensor_model_parallel', is_parallel) + setattr(tensor, 'partition_dim', dim) + setattr(tensor, 'partition_stride', stride) + + +def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + def maybe_set(attribute, value): + if not hasattr(tensor, attribute): + setattr(tensor, attribute, value) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) + + +def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + def maybe_copy(attribute): + if hasattr(source_tensor, attribute): + setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) + + for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: + maybe_copy(attribute) + + +def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): + """Initialize affine weight for model parallel on GPU.""" + + set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) + + with get_cuda_rng_tracker().fork(): + init_method(weight) + + +def _initialize_affine_weight_cpu( + weight, + output_size, + input_size, + per_partition_size, + partition_dim, + init_method, + stride=1, + return_master_weight=False, + *, + params_dtype=torch.float32, +): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + + set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False) + init_method(master_weight) + master_weight = master_weight.to(dtype=params_dtype) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + + Keyword Arguments: + config: A megatron.core.ModelParallelConfig object + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method: Callable, + config: ModelParallelConfig, + ): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2.0 + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + ( + self.vocab_start_index, + self.vocab_end_index, + ) = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + # Allocate weights and initialize. + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + _initialize_affine_weight_cpu( + self.weight, + self.num_embeddings, + self.embedding_dim, + self.num_embeddings_per_partition, + 0, + init_method, + params_dtype=config.params_dtype, + ) + else: + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1) + + def forward(self, input_): + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_tensor_model_parallel_region(output_parallel) + return output + + +class LinearWithFrozenWeight(torch.autograd.Function): + """Linear operator that does not calculate gradient for weight. + This op and LinearWithGradAccumulationAndAsyncCommunication performs + mathematically-identical forward and DGRAD. + + Conceptually this op is the same as torch.nn.functional.linear with + weight.requires_grad==False, but in experiments they are not identical + mathematically. """ + + @staticmethod + @custom_fwd + def forward( + ctx, input, weight, bias, + ): + ctx.save_for_backward(weight) + output = torch.matmul(input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + (weight,) = ctx.saved_tensors + grad_input = grad_output.matmul(weight) + return grad_input, None, None + + +def linear_with_frozen_weight( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + gradient_accumulation_fusion: bool, + async_grad_allreduce: bool, + sequence_parallel: bool, +) -> torch.Tensor: + """Linear layer execution with weight.requires_grad == False. + + This function handles linear layers with weight frozen (untrainable). + In the forward, it only saves weight and does not save input activations. + In the backward, it does not perform weight gradient calculation, or + weight gradient allreduce. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): dummy argument, used to + keep the API unified between all forward implementation functions. + + async_grad_allreduce (bool required): dummy argument, used to + keep the API unified between all forward implementation functions. + + sequence_parallel (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + """ + + if sequence_parallel: + input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True) + else: + input = input + + args = [ + input, + weight, + bias, + ] + + return LinearWithFrozenWeight.apply(*args) + + +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + @custom_fwd + def forward( + ctx, + input, + weight, + bias, + gradient_accumulation_fusion, + async_grad_allreduce, + sequence_parallel, + ): + ctx.save_for_backward(input, weight) + ctx.use_bias = bias is not None + ctx.gradient_accumulation_fusion = gradient_accumulation_fusion + ctx.async_grad_allreduce = async_grad_allreduce + ctx.sequence_parallel = sequence_parallel + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, input, group=get_tensor_model_parallel_group() + ) + total_input = all_gather_buffer + else: + total_input = input + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True + ) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Doing gather + slicing during the NeMo forward pass can make this tensor + # not be contiguous. PyTorch only checks if the tensor is contiguous, and only + # clones it if it's not contiguous: + # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view( + grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] + ) + total_input = total_input.view( + total_input.shape[0] * total_input.shape[1], total_input.shape[2] + ) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = torch.distributed.all_reduce( + grad_input, group=get_tensor_model_parallel_group(), async_op=True + ) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # all-reduce is scheduled before the weight gradient computation + + if ctx.sequence_parallel: + assert not ctx.async_grad_allreduce + dim_size = list(input.size()) + sub_grad_input = torch.empty( + dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False + ) + # reduce_scatter + handle = torch.distributed._reduce_scatter_base( + sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True + ) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + if ctx.gradient_accumulation_fusion: + if weight.main_grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( + total_input, grad_output, weight.main_grad + ) + elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( + total_input, grad_output, weight.main_grad + ) + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_grad_accumulation_and_async_allreduce( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + gradient_accumulation_fusion: bool, + async_grad_allreduce: bool, + sequence_parallel: bool, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. + + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + async_grad_allreduce (bool required): Do the allreduce of input + gradients asyncronously with the computation of weight + gradients. If sequence_parallel is True, this must be + False, as no all reduce is performed. + + sequence_parallel (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + """ + args = [ + input, + weight, + bias, + gradient_accumulation_fusion, + async_grad_allreduce, + sequence_parallel, + ] + + if not linear_with_grad_accumulation_and_async_allreduce.warned: + if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if sequence_parallel: + warnings.warn( + "When using sequence parallelism it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + linear_with_grad_accumulation_and_async_allreduce.warned = True + + if async_grad_allreduce: + warnings.warn( + "When using async grad allreduce it is recommended to set the " + "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " + "maximum speedup" + ) + linear_with_grad_accumulation_and_async_allreduce.warned = True + + return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) + + +linear_with_grad_accumulation_and_async_allreduce.warned = False + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + + skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed + as a keyword argument `weight` during the forward pass. Note + that this does not affect bias, which will be allocated if + bias is True. Defaults to False. + + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size, + output_size, + *, + config: ModelParallelConfig, + init_method: Callable, + bias=True, + gather_output=False, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + skip_weight_param_allocation: bool = False, + ): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.config = config + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + if not skip_weight_param_allocation: + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, self.input_size, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) + else: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=0, stride=stride + ) + else: + self.weight = None + + if bias: + if config.use_cpu_initialization: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=config.params_dtype) + ) + else: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + self.async_tensor_model_parallel_allreduce = ( + config.async_tensor_model_parallel_allreduce and world_size > 1 + ) + + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and world_size <= 1: + warnings.warn( + f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + self.sequence_parallel = False + + if config.gradient_accumulation_fusion and not _grad_accum_fusion_available: + raise RuntimeError( + "ColumnParallelLinear was called with gradient_accumulation_fusion set " + "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " + "module is not found. To use gradient_accumulation_fusion you must " + "install APEX with --cpp_ext and --cuda_ext. For example: " + "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" " + "Note that the extension requires CUDA>=11. Otherwise, you must turn off " + "gradient accumulation fusion." + ) + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + + if self.async_tensor_model_parallel_allreduce and self.sequence_parallel: + raise RuntimeError( + "`async_tensor_model_parallel_allreduce` and `sequence_parallel` " + "cannot be enabled at the same time." + ) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): + """Forward of ColumnParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + weight (optional): weight tensor to use, compulsory when + skip_weight_param_allocation is True. + + Returns: + - output + - bias + + """ + if weight is None: + if self.weight is None: + raise RuntimeError( + "weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True." + ) + weight = self.weight + else: + # Check the weight passed in is the correct shape + expected_shape = (self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError( + f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected" + ) + + bias = self.bias if not self.skip_bias_add else None + + if self.async_tensor_model_parallel_allreduce or self.sequence_parallel: + input_parallel = input_ + else: + input_parallel = copy_to_tensor_model_parallel_region(input_) + # Matrix multiply. + if not weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + output_parallel = self._forward_impl( + input=input_parallel, + weight=weight, + bias=bias, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel=self.sequence_parallel, + ) + if self.gather_output: + # All-gather across the partitions. + assert not self.sequence_parallel + output = gather_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool = True, + input_is_parallel: bool = False, + stride: int = 1, + keep_master_weight_for_test: bool = False, + skip_bias_add: bool = False, + ): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + self.skip_bias_add = skip_bias_add + self.config = config + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`") + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.output_size, self.input_size_per_partition, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + params_dtype=config.params_dtype, + ) + else: + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=1, stride=stride + ) + if bias: + if config.use_cpu_initialization: + self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype)) + else: + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + assert not self.sequence_parallel + input_parallel = scatter_to_tensor_model_parallel_region(input_) + # Matrix multiply. + if not self.weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + output_parallel = self._forward_impl( + input=input_parallel, + weight=self.weight, + bias=None, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=False, + sequence_parallel=False, + ) + + # All-reduce across all the partitions. + if self.sequence_parallel: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/mappings.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/mappings.py new file mode 100644 index 000000000..9f753e732 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/mappings.py @@ -0,0 +1,283 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size() == 1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) + + return input_ + + +def _split_along_last_dim(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_tensor_model_parallel_rank() + output = input_list[rank].contiguous() + + return output + + +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert ( + dim_size % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = get_tensor_model_parallel_rank() + dim_offset = rank * local_dim_size + + output = input_[dim_offset : dim_offset + local_dim_size].contiguous() + + return output + + +def _gather_along_last_dim(input_): + """Gather tensors and concatinate along the last dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = get_tensor_model_parallel_rank() + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + torch.distributed._all_gather_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) + + return output + + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base( + output, input_.contiguous(), group=get_tensor_model_parallel_group() + ) + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_last_dim(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _gather_along_last_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split_along_last_dim(grad_output) + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. + if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + + +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def scatter_to_tensor_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + + +def gather_from_tensor_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) + + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) + + +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/random.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/random.py new file mode 100644 index 000000000..c3a9e2b72 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/random.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + +import contextlib + +import torch +from torch import _C +from torch.cuda import _lazy_call +from torch.cuda import device as device_ctx_manager +from torch.utils.checkpoint import detach_variable + +from megatron.core.parallel_state import ( + get_data_parallel_rank, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.utils import safely_set_viewless_tensor_data + +from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-tensor-model-parallel regions. + tensor-model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() + # Data parallel gets the original seed. + data_parallel_seed = seed + + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed) + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + @staticmethod + def forward(ctx, run_function, distribute_saved_activations, *args): + ctx.run_function = run_function + ctx.distribute_saved_activations = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + with torch.no_grad(): + outputs = run_function(*args) + + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) + ) + + # Store everything. + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), " + "please use .backward() if possible" + ) + inputs = ctx.saved_tensors + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape) + ) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + return (None, None) + grads + + +def checkpoint(function, distribute_saved_activations, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/training/DeepSpeed-Domino/megatron/core/tensor_parallel/utils.py b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/utils.py new file mode 100644 index 000000000..a79ae1e87 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/tensor_parallel/utils.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from typing import List, Sequence + +import torch + +from megatron.core import parallel_state +from megatron.core.utils import divide + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): + """ Break a tensor into equal 1D chunks across tensor parallel ranks. + + Returns a Tensor or View with this rank's portion of the data. + + Arguments: + tensor: The tensor to split + + Keyword Arguments: + new_buffer (bool): If True, returns a new Tensor. + If False, returns a view into the existing Tensor. + Default is False + + """ + partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size() + start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() + end_index = start_index + partition_size + if new_buffer: + data = torch.empty( + partition_size, + dtype=tensor.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor): + """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor + model parallel ranks. + + Returns a new Tensor with the gathered data. + + Arguments: + tensor: A Tensor or view of this rank's portion of the data. + """ + numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size() + gathered = torch.empty( + numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False + ) + # TODO: This API is experimental in pytorch (as of Feb 2022) and + # this might break in future pytorch releases. We chose this API + # as opposed to torch.distributed.all_gather for efficiency reasons. + # This API calls directly NCCL all-gather versus the former does + # internal copies and can potentially cause slow down. + torch.distributed._all_gather_base( + gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + return gathered + + +class VocabUtility: + """ Split the vocabulary into `world_size` chunks and return the first + and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last) + + """ + + @staticmethod + def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, rank, world_size: int + ) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int + ) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size + ) diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/__init__.py b/training/DeepSpeed-Domino/megatron/core/transformer/__init__.py new file mode 100644 index 000000000..c4ae4739d --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from .transformer_config import TransformerConfig diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/attention.py b/training/DeepSpeed-Domino/megatron/core/transformer/attention.py new file mode 100644 index 000000000..913c7c051 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/attention.py @@ -0,0 +1,368 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from abc import ABC, abstractmethod + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import divide + +from .enums import AttnMaskType +from .transformer_config import TransformerConfig + + +class Attention(MegatronModule, ABC): + """Attention layer abstract class. + + This layer only contains common modules required for the "self attn" and + "cross attn" specializations. + """ + + def __init__( + self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding, + ): + super().__init__(config=config) + + self.config = config + self.layer_number = layer_number + self.attn_mask_type = attn_mask_type + + # For normal attention without groups, num_query_groups == num_attention_heads, + # so these two will be the same + self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads + self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups + + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = divide( + self.query_projection_size, self.config.num_attention_heads + ) + self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) + self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) + + self.dot_product_attention = TEDotProductAttention( + config=self.config, layer_number=self.layer_number, attn_mask_type=self.attn_mask_type + ) + + self.checkpoint_dot_product_attention = self.config.recompute_granularity == 'selective' + + # Output. + self.linear_proj = TERowParallelLinear( + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=True, + ) + + def _checkpointed_attention_forward( + self, query, key, value, attention_mask, rotary_pos_emb=None + ): + """Forward method with selective activation checkpointing.""" + + def custom_forward(*inputs): + query = inputs[0] + key = inputs[1] + value = inputs[2] + attention_mask = inputs[3] + output_ = self.dot_product_attention(query, key, value, attention_mask) + return output_ + + hidden_states = tensor_parallel.checkpoint( + custom_forward, False, query, key, value, attention_mask, rotary_pos_emb + ) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype): + """Allocate memory to store kv cache during inference.""" + + return torch.empty( + inference_max_sequence_length, + batch_size, + self.num_query_groups_per_partition, + self.hidden_size_per_attention_head, + dtype=dtype, + device=torch.cuda.current_device(), + ) + + def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb): + """ + Saves the generated key and value tensors to the end of the buffers in inference_params. + Returns the full size keys and values from the provided inference_params, as well as + adjusted rotary_pos_emb. + + Returns a tuple: (key, value, rotary_pos_emb) + + """ + if inference_params is None: + return key, value, rotary_pos_emb + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + is_first_step = False + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_length = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_length, inf_max_batch_size, key.dtype + ) + inference_value_memory = self._allocate_memory( + inf_max_seq_length, inf_max_batch_size, value.dtype + ) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + is_first_step = True + else: + # Get the pre-allocated buffers for this layer + inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ + self.layer_number + ] + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key + inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value + key = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + return key, value, rotary_pos_emb + + @abstractmethod + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + This method needs to be implemented based on whether the derived class + is "self-attn" or "cross-attn". + """ + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + ): + # hidden_states: [sq, b, h] + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + key, value, rotary_pos_emb = self._adjust_key_value_for_inference( + inference_params, key, value, rotary_pos_emb + ) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] + # This is a noop for normal attention where ng == np. When using group query attention this + # creates a view that has the keys and values virtually repeated along their dimension to + # match the number of queries. + if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1: + key = key.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 + ) + value = value.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 + ) + + if self.checkpoint_dot_product_attention: + core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask) + else: + core_attn_out = self.dot_product_attention(query, key, value, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias + + +class SelfAttention(Attention): + """Self-attention layer class + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding + ): + super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type) + + self.linear_qkv = TELayerNormColumnParallelLinear( + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + ) + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split( + mixed_qkv, + [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ], + dim=3, + ) + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + return query, key, value + + +class CrossAttention(Attention): + """Cross-attention layer class + + Cross-attention layer takes input with size [s, b, h] and context with size + [s, b, h] and returns output of the same size. + """ + + def __init__( + self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding + ): + super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type) + + if self.config.num_query_groups != self.config.num_attention_heads: + raise ValueError( + f"Group query attention is not currently supported in cross attention." + ) + assert self.query_projection_size == self.kv_projection_size + + self.linear_q = TELayerNormColumnParallelLinear( + self.config.hidden_size, + self.query_projection_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + ) + + self.linear_kv = TELayerNormColumnParallelLinear( + self.config.hidden_size, + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + ) + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + return query, key, value diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/custom_layers/__init__.py b/training/DeepSpeed-Domino/megatron/core/transformer/custom_layers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/custom_layers/transformer_engine.py b/training/DeepSpeed-Domino/megatron/core/transformer/custom_layers/transformer_engine.py new file mode 100644 index 000000000..0b66009cd --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/custom_layers/transformer_engine.py @@ -0,0 +1,249 @@ +from importlib.metadata import version +from typing import Callable + +import torch +import transformer_engine as te +from pkg_resources import packaging + +from megatron.core.parallel_state import get_tensor_model_parallel_group +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.transformer_config import TransformerConfig + + +def _get_extra_te_kwargs(config: TransformerConfig): + extra_transformer_engine_kwargs = {} + from importlib.metadata import version + + from pkg_resources import packaging + + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("0.12.0"): + if config.use_cpu_initialization: + extra_transformer_engine_kwargs["device"] = 'cpu' + else: + extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() + return extra_transformer_engine_kwargs + + +class TENorm: + """ + A conditional wrapper to initialize an instance of Transformer-Engine's + `LayerNorm` or `RMSNorm` based on input + """ + + def __new__( + cls, + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + sequence_parallel: bool = False, + normalization="LayerNorm", + **kwargs + ): + zero_centered_gamma = kwargs.get('zero_centered_gamma', False) + if normalization == "LayerNorm": + instance = te.pytorch.LayerNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=sequence_parallel, + zero_centered_gamma=zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + elif normalization == "RMSNorm": + assert hasattr( + te.pytorch, "RMSNorm" + ), "Transformer-Engine >= v0.11 required to use this feature" + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=sequence_parallel, + zero_centered_gamma=zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) + else: + raise Exception('Only LayerNorm and RMSNorm are curently supported') + + return instance + + +class TELinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + input_size: int, + output_size: int, + config: TransformerConfig, + parallel_mode: str, + init_method: Callable, + *, + bias: bool = True, + skip_bias_add: bool = False, + **kwargs + ): + self.config = config + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + + super().__init__( + in_features=input_size, + out_features=output_size, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=get_cuda_rng_tracker, + init_method=init_method, + params_dtype=self.config.params_dtype, + parallel_mode=parallel_mode, + bias=bias, + return_bias=self.te_return_bias, + **_get_extra_te_kwargs(config), + **kwargs, + ) + + def forward(self, x): + out = super().forward(x) + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + +class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): + """ + Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines + layernorm and linear layers + """ + + def __init__( + self, + input_size: int, + output_size: int, + config: TransformerConfig, + init_method: Callable, + bias: bool, + skip_bias_add: bool, + **kwargs + ): + self.config = config + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("0.11.0"): + kwargs["normalization"] = self.config.normalization + + super().__init__( + in_features=input_size, + out_features=output_size, + bias=bias, + sequence_parallel=self.config.sequence_parallel, + fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=get_cuda_rng_tracker, + init_method=init_method, + params_dtype=self.config.params_dtype, + parallel_mode="column", + return_bias=self.te_return_bias, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + **kwargs, + ) + + def forward(self, x): + out = super().forward(x) + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + +class TEColumnParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `ColumnParallelLinear` layer. + """ + + def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs): + self.config = config + super().__init__( + input_size=input_size, + output_size=output_size, + config=self.config, + parallel_mode="column", + **kwargs, + ) + + +class TERowParallelLinear(TELinear): + """ + Wrapper for the Transformer-Engine's `Linear` layer but specialized similar + to megatron's `RowParallelLinear` layer. + """ + + def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs): + self.config = config + super().__init__( + input_size=input_size, + output_size=output_size, + config=self.config, + parallel_mode="row", + **kwargs, + ) + + +class TEDotProductAttention(te.pytorch.DotProductAttention): + """ + Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + has "flash attention" enabled. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + config: TransformerConfig, + layer_number: int = 1, + attn_mask_type: AttnMaskType = AttnMaskType.padding, + **kwargs + ): + self.config = config + super().__init__( + num_attention_heads=self.config.num_attention_heads, + kv_channels=self.config.kv_channels, + attention_dropout=self.config.attention_dropout, + layer_number=layer_number, + attn_mask_type=attn_mask_type.name, + sequence_parallel=self.config.sequence_parallel, + tp_size=self.config.tensor_model_parallel_size, + get_rng_state_tracker=get_cuda_rng_tracker, + tp_group=get_tensor_model_parallel_group(check_initialized=False), + **kwargs, + ) diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/dot_product_attention.py b/training/DeepSpeed-Domino/megatron/core/transformer/dot_product_attention.py new file mode 100644 index 000000000..d99adb4c3 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/dot_product_attention.py @@ -0,0 +1,165 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + + +import math + +import torch +from torch import Tensor + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import attention_mask_func +from megatron.core.utils import divide + + +class DotProductAttention(MegatronModule): + """ + Region where selective activation recomputation is applied. + This region is memory intensive but less compute intensive which + makes activation checkpointing more efficient for LLMs (20B+). + See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + + We use the following notation: + h: hidden size + n: number of attention heads + p: number of tensor model parallel partitions + b: batch size + s: sequence length + """ + + def __init__( + self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding + ): + super().__init__(config=config) + + self.config: TransformerConfig = config + + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + + projection_size = self.config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(projection_size, world_size) + self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = divide(config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.config.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + input_in_fp16=self.config.fp16, + input_in_bf16=self.config.bf16, + attn_mask_type=self.attn_mask_type, + scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, + mask_func=attention_mask_func, + softmax_in_fp32=self.config.attention_softmax_in_fp32, + scale=coeff, + ) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(self.config.attention_dropout) + + def forward( + self, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor, attention_mask: Tensor + ): + + # =================================== + # Raw attention scores. [b, n/p, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sq, b, np, hn] -> [sq, b * np, hn] + # This will be a simple view when doing normal attention, but in group query attention + # the key and value tensors are repeated to match the queries so you can't use simple strides + # to extract the queries. + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, + "mpu", + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.config.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/enums.py b/training/DeepSpeed-Domino/megatron/core/transformer/enums.py new file mode 100644 index 000000000..3583daa17 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/enums.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import enum + + +# can we get rid of this? +# it's being used in pipeline schedules +class ModelType(enum.Enum): + encoder_or_decoder = 1 + encoder_and_decoder = 2 + + +# class LayerType(enum.Enum): +# encoder = 1 +# decoder = 2 + + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/identity_op.py b/training/DeepSpeed-Domino/megatron/core/transformer/identity_op.py new file mode 100644 index 000000000..79dcddc1f --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/identity_op.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import torch + + +class IdentityOp(torch.nn.Module): + """ + This is a placeholder for IdentityOp (NoOp) + """ + + def __init__(self, *args, **kwargs): + super(IdentityOp, self).__init__() + + def forward(self, x, *args, **kwargs): + return x diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/mlp.py b/training/DeepSpeed-Domino/megatron/core/transformer/mlp.py new file mode 100644 index 000000000..16696ceaf --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/mlp.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import torch +import torch.nn.functional as F + +from megatron.core import tensor_parallel +from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl +from megatron.core.transformer.custom_layers.transformer_engine import ( + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig + + +class MLP(MegatronModule): + """ + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + + + Returns an output and a bias to be added to the output. + If config.add_bias_linear is False, the bias returned is None. + + We use the following notation: + h: hidden size + p: number of tensor model parallel partitions + b: batch size + s: sequence length + """ + + def __init__(self, config: TransformerConfig): + super().__init__(config=config) + + self.config: TransformerConfig = config + + # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf + ffn_hidden_size = self.config.ffn_hidden_size + if self.config.gated_linear_unit: + ffn_hidden_size *= 2 + + self.linear_fc1 = TELayerNormColumnParallelLinear( + self.config.hidden_size, + ffn_hidden_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=True, + ) + + if self.config.gated_linear_unit: + + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + + self.activation_func = glu + else: + self.activation_func = self.config.activation_func + + self.linear_fc2 = TERowParallelLinear( + self.config.ffn_hidden_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=True, + ) + + def forward(self, hidden_states): + + # [s, b, 4 * h/p] + intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) + + if self.config.bias_gelu_fusion: + assert self.config.add_bias_linear is True + assert self.activation_func == F.gelu + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + output, output_bias = self.linear_fc2(intermediate_parallel) + return output, output_bias diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/module.py b/training/DeepSpeed-Domino/megatron/core/transformer/module.py new file mode 100644 index 000000000..c0f08fe11 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/module.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Megatron Module""" + +import torch +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.transformer.transformer_config import TransformerConfig + +_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) +_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + +class MegatronModule(torch.nn.Module): + """Megatron specific extensions of torch Module with support + for pipelining.""" + + # def __init__(self, config: TransformerConfig, share_word_embeddings=True): + def __init__(self, config: TransformerConfig): + super().__init__() + self.config = config + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """Use this function to override the state dict for + saving checkpoints. + """ + + return self.state_dict(prefix=prefix, keep_vars=keep_vars) + + def sharded_state_dict(self, prefix=''): + """ Override sharded_state_dict when using distributed checkpointing. + keep_vars must always be set to True so that optimizer states + can be sharded. + """ + return self.state_dict(prefix=prefix, keep_vars=True) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` + #is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_float16(val, float16_convertor): + """Convert fp32 `val` to fp16/bf16""" + + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, _FLOAT_TYPES): + val = float16_convertor(val) + return val + + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + """Convert fp16/bf16 `val` to fp32""" + + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +class Float16Module(MegatronModule): + def __init__(self, config: TransformerConfig, module: torch.nn.Module): + super(Float16Module, self).__init__(config) + self.config = config + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + if self.fp16: + self.add_module('module', module.half()) + + def float16_convertor(val): + return val.half() + + elif self.bf16: + self.add_module('module', module.bfloat16()) + + def float16_convertor(val): + return val.bfloat16() + + else: + raise Exception('Either config.fp16 or config.bf16 should be True.') + + self.float16_convertor = float16_convertor + + def set_input_tensor(self, input_tensor): + return self.module.set_input_tensor(input_tensor) + + def forward(self, *inputs, **kwargs): + if parallel_state.is_pipeline_first_stage(): + inputs = fp32_to_float16(inputs, self.float16_convertor) + outputs = self.module(*inputs, **kwargs) + if parallel_state.is_pipeline_last_stage(): + outputs = float16_to_fp32(outputs) + return outputs + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """ Retrieve state_dict from the module being wrapped.""" + return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) + + def sharded_state_dict(self, prefix=''): + """ Retrieve sharded_state_dict from the module being wrapped. + """ + return self.module.sharded_state_dict(prefix=prefix) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/transformer_block.py b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_block.py new file mode 100644 index 000000000..2b9ba7908 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_block.py @@ -0,0 +1,286 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import re +from contextlib import nullcontext + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor + + +class TransformerBlock(MegatronModule): + """Transformer class.""" + + def __init__( + self, + config: TransformerConfig, + self_attn_mask_type=AttnMaskType.padding, + post_layer_norm=True, + pre_process=True, + post_process=True, + ): + super().__init__(config=config) + + self.config: TransformerConfig = config + + self.self_attn_mask_type = self_attn_mask_type + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + + # required for pipeline parallel schedules + self.input_tensor = None + + self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' + + self.num_layers_per_pipeline_rank = ( + self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + self._build_layers() + + def _build_layers(self): + # Transformer layers. + # @jcasper can we improve how we deal with layer_number? + # currently it's only used in CoreAttention? + # if self.apply_query_key_layer_scaling: + # coeff = self.layer_number + # self.norm_factor *= coeff + def build_layer(layer_number): + layer = TransformerLayer( + config=self.config, + layer_number=layer_number, + self_attn_mask_type=self.self_attn_mask_type, + ) + return layer + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + # Interleaved pipeline parallelism: + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + num_layers_per_virtual_rank = self.num_layers_per_pipeline_rank // vp_size + + num_layers_to_build = num_layers_per_virtual_rank + + else: + # Non-interleaved pipeline parallelism: + # Each stage gets a contiguous set of layers. + + num_layers_to_build = self.num_layers_per_pipeline_rank + + # offset is implicit in TransformerLayer + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(num_layers_to_build)]) + + # # TODO: add back standalone_embedding_stage + # if self.num_layers == 0: + # # When a standalone embedding stage is used (e.g., + # # args.standalone_embedding_stage == True), virtual pipeline ranks + # # on pipeline rank 0 will have zero transformer layers assigned to + # # them. This results in the model's input and output tensors to be + # # the same, which will cause failure for certain output tensor + # # optimizations (e.g., pipeline output deallocation). To remedy + # # this, we assign a 'no-op' layer on these ranks, which will + # # disconnect the input tensor from the output tensor. + # self.num_layers = 1 + # self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) + # else: + # self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + persist_layer_norm=self.config.persist_layer_norm, + sequence_parallel=self.config.sequence_parallel, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + normalization=self.config.normalization, + ) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, rotary_pos_emb): + """Forward method with activation checkpointing.""" + + def custom(start, end): + def custom_forward(*args, **kwargs): + x_, *args = args + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, *args, **kwargs) + return x_ + + return custom_forward + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers_per_pipeline_rank: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.config.recompute_num_layers), + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + rotary_pos_emb, + ) + + l += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers_per_pipeline_rank): + if l < self.config.recompute_num_layers: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + rotary_pos_emb, + ) + else: + hidden_states = custom(l, l + 1)(hidden_states, attention_mask, rotary_pos_emb) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, inference_params=None, rotary_pos_emb=None): + # hidden_states (float): [s, b, h] + # attention_mask (bool): [1, 1, s, s] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True, + ) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=self.config.fp8_margin, + interval=self.config.fp8_interval, + fp8_format=fp8_format, + amax_compute_algo=self.config.fp8_amax_compute_algo, + amax_history_len=self.config.fp8_amax_history_len, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group() + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context and fp8_context: + # Forward pass. + if self.config.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) + else: + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + ) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + def sharded_state_dict(self, prefix=''): + + sharded_state_dict = {} + + layer_prefix = f'{prefix}layers.' + for layer in self.layers: + sharded_state_dict.update(layer.sharded_state_dict(prefix=layer_prefix)) + + if self.post_process and self.post_layer_norm: + state_dict = self.state_dict(keep_vars=True) + + tensor = state_dict['final_layernorm.weight'] + layer_name = f'{prefix}final_layernorm.weight' + sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(tensor, layer_name) + + # RMSNorm doesn't have bias. + if 'final_layernorm.bias' in state_dict.keys(): + tensor = state_dict['final_layernorm.bias'] + layer_name = f'{prefix}final_layernorm.bias' + sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint( + tensor, layer_name + ) + + return sharded_state_dict diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/transformer_config.py b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_config.py new file mode 100644 index 000000000..532c89b00 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_config.py @@ -0,0 +1,273 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Callable + +import torch +import torch.nn.functional as F + +from megatron.core import ModelParallelConfig +from megatron.core.utils import init_method_normal, scaled_init_method_normal + + +@dataclass +class TransformerConfig(ModelParallelConfig): + """Configuration object for megatron-core transformers. + + Attributes: + + # model architecture + num_layers (int): Number of transformer layers in a transformer block. + hidden_size (int): Transformer hidden size. + ffn_hidden_size (int): Transformer Feed-Forward Network hidden size. + This is set to 4*hidden_size if not provided. Defaults to None.') + num_attention_heads (int): Number of transformer attention heads. + kv_channels (int): Projection weights dimension in multi-head attention. + This is set to hidden_size // num_attention_heads if not provided. + Defaults to None. + num_query_groups (int): Number of query groups for group query attention. If None, normal attention is used. + + hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1. + attention_dropout (float): Post attention dropout probability. Defaults to 0.1. + fp32_residual_connection (bool): If true, move residual connections to fp32. + apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residule connection ordering. + Defaults to False. + layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5. + + layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values + around 0. This improves numerical stability. Defaults to False. + + add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two + in MLP layer). Default is True. + + gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False. + + activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu. + + # initialization + init_method (Callable): Method to initialize weights. Note that bias is always set to + zero. Should be a function that takes a single Tensor and + initializes it. Defaults to + megatron.core.utils.init_method_normal(init_method_std) which is + torch.nn.init.normal_ with mean=0.0 and std=init_method_Std. + + output_layer_init_method (Callable): Method to initialize weights of the output layer of + both attention and MLP blocks. Defaults to + megatron.core.utils.scaled_init_method_normal(init_method_std) + which is torch.nn.init.normal_ with mean=0.0 and + std=init_method_std / math.sqrt(2.0 * num_layers). + + init_method_std (float): Standard deviation of the zero mean normal for the default + initialization method, not used if init_method and + output_layer_init_method are provided. Defaults to 0.02. + + # mixed-precision + apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True. + attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32. + This should be true if apply_query_key_layer_scaling is true. + + # fusion + bias_gelu_fustion (bool): If true, fuses bias and gelu. Defaults to False. + masked_softmax_fusion (bool): If true, uses softmax fusion. + persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel. + This kernel only supports a fixed set of hidden sizes. + Defaults to False. + bias_dropout_fusion (bool): If true, uses bias dropout fusion. + + # activation recomputation + + recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory + intensive part of attention is checkpointed. These memory intensive activations + are also less compute intensive which makes activation checkpointing more efficient + for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer + Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint + the entire transformer layer. Must be 'selective' or 'full'. 'selective' always uses all layers. + Defaults to None. + + recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer + block and recompute the input activation of each divided chunk at the specified + granularity. block will recompute the input activations for only a set number of + transformer layers per pipeline stage. The rest of the layers in the pipeline stage + will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to + None. + + recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer + layers in each uniformly divided recompute unit. When recompute_method is block, + recompute_num_layers is the number of transformer layers to recompute within each + pipeline stage. Must be None for 'selective' activation checkpointing. Defaults to None. + + distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel + group. Defaults to None. + + # fp8 related (via Transformer Engine). For detailed info, refer the the Transformer Engine docs at + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html + + fp8 (str): If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices: (1) 'e4m3' + uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8 activation and weight tensors and + e5m2 for all FP8 output activation gradient tensors. Defaults to None. + + fp8_margin (int): Margin for the scaling factor computation. + + fp8_interval (int): Controls how often the scaling factor is recomputed. + + fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation. + + fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation. + There are 2 predefined choices: `max` chooses the largest `amax` in the history + window, while `most_recent` always chooses the most recently seen value. + + fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation in higher precision. + Defaults to True. + + # Experimental + normalization (str): Swtich b/w `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily + used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`. + + + """ + + # model architecture + num_layers: int = 0 + hidden_size: int = 0 + num_attention_heads: int = 0 + num_query_groups: int = None + + ffn_hidden_size: int = None + kv_channels: int = None + hidden_dropout: float = 0.1 + attention_dropout: float = 0.1 + fp32_residual_connection: bool = False + # @jcasper should we keep this option? + apply_residual_connection_post_layernorm: bool = False + layernorm_epsilon: float = 1e-5 + layernorm_zero_centered_gamma: bool = False + add_bias_linear: bool = True + gated_linear_unit: bool = False + activation_func: Callable = F.gelu + + # initialization + init_method: Callable = None + output_layer_init_method: Callable = None + init_method_std: float = 0.02 + + # mixed-precision + apply_query_key_layer_scaling: bool = True + attention_softmax_in_fp32: bool = True + + # communication + + # fusion + bias_gelu_fusion: bool = False # TODO: this should be bias_activation_fusion ? + masked_softmax_fusion: bool = False + persist_layer_norm: bool = False + bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion? + + # activation recomputation + recompute_granularity: str = None + recompute_method: str = None + recompute_num_layers: int = None + distribute_saved_activations: bool = None + + # fp8 related + fp8: str = None + fp8_margin: int = 0 + fp8_interval: int = 1 + fp8_amax_history_len: int = 1 + fp8_amax_compute_algo: str = "most_recent" + fp8_wgrad: bool = True + + # experimental section (TODO: move to apt. section above once stable) + normalization: bool = "LayerNorm" # alt value supported by TE: "RMSNorm" + + def __post_init__(self): + """ Python dataclass method that is used to modify attributes after initialization. + See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. + """ + super().__post_init__() + if self.fp16 and self.bf16: + raise ValueError( + f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.' + ) + + if self.num_attention_heads % self.tensor_model_parallel_size != 0: + raise ValueError( + f"num_attention_heads ({self.num_attention_heads}) must be a multiple of " + f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." + ) + + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + if self.kv_channels is None: + self.kv_channels = self.hidden_size // self.num_attention_heads + + if self.num_query_groups is None: + self.num_query_groups = self.num_attention_heads + + if self.num_query_groups % self.tensor_model_parallel_size != 0: + raise ValueError( + f"num_query_groups ({self.num_query_groups}) must be a multiple of " + f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." + ) + + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + + if self.recompute_granularity is not None: + if not self.recompute_granularity in ['full', 'selective']: + raise ValueError( + f'When using recompute_granuarlity: {self.recompute_granularity} must be "full" or "selective".' + ) + + if self.recompute_method is not None: + if not self.recompute_method in ['block', 'uniform']: + raise ValueError( + f'recompute_method: {self.recompute_method} must be "block" or "uniform".' + ) + elif self.recompute_granularity != 'selective': + raise ValueError( + f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"' + ) + + if self.recompute_granularity != 'selective' and self.recompute_num_layers is None: + raise ValueError( + f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between ' + f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}' + ) + elif ( + self.recompute_granularity == 'selective' and self.recompute_num_layers is not None + ): + raise ValueError( + f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.' + ) + + if self.distribute_saved_activations and self.sequence_parallel: + raise ValueError( + f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}' + ) + + if self.virtual_pipeline_model_parallel_size is not None: + if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0: + raise ValueError( + f'num_layers: {self.num_layers} must be divisible by virtual_model_parallel_size {self.virtual_pipeline_model_parallel_size}' + ) + + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + + if self.bias_gelu_fusion: + if not self.add_bias_linear: + raise ValueError( + "When bias_gelu_fusion is True, add_bias_linear must also be True." + ) + + if self.activation_func != F.gelu: + raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.') + + if self.init_method is None: + self.init_method = init_method_normal(self.init_method_std) + + if self.output_layer_init_method is None: + self.output_layer_init_method = scaled_init_method_normal( + self.init_method_std, self.num_layers + ) diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/transformer_layer.py b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_layer.py new file mode 100644 index 000000000..73b9aadc6 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/transformer_layer.py @@ -0,0 +1,270 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import re +from functools import partial + +import torch + +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ( + ShardedObject, + ShardedTensor, + ShardedTensorFactory, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.attention import SelfAttention +from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_viewless_tensor + + +class TransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config: TransformerConfig, + layer_number: int = 1, + self_attn_mask_type=AttnMaskType.padding, + ): + super().__init__(config=config) + self.config: TransformerConfig = config + + self.layer_number = layer_number + self._get_layer_offset() + + self.self_attn_mask_type = self_attn_mask_type + + # Layernorm on the input data. + # TODO: add pytorch only layernorm + self.input_layernorm = IdentityOp( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + persist_layer_norm=self.config.persist_layer_norm, + sequence_parallel=self.config.sequence_parallel, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + normalization=self.config.normalization, + ) + + # Self attention. + self.self_attention = SelfAttention( + config=self.config, layer_number=layer_number, attn_mask_type=self_attn_mask_type, + ) + + # Layernorm on the attention output + self.post_self_attn_layernorm = IdentityOp( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + persist_layer_norm=self.config.persist_layer_norm, + sequence_parallel=self.config.sequence_parallel, + zero_centered_gamma=self.config.layernorm_zero_centered_gamma, + normalization=self.config.normalization, + ) + + # MLP + self.mlp = MLP(config=self.config) + + # @jcasper how should we handle nvfuser? + # Set bias+dropout+add fusion grad_enable execution handler. + # TORCH_MAJOR = int(torch.__version__.split('.')[0]) + # TORCH_MINOR = int(torch.__version__.split('.')[1]) + # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad + self.bias_dropout_add_exec_handler = torch.enable_grad + + def _get_layer_offset(self): + + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + + num_layers_per_pipeline_rank = ( + self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + total_num_layers = self.config.num_layers + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size + total_virtual_chunks = total_num_layers // vp_size + offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) + + else: + # Each stage gets a contiguous set of layers. + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + offset = pipeline_rank * num_layers_per_pipeline_rank + else: + offset = 0 + + return offset + + def forward( + self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output_with_bias = self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + # Residual connection. + if self.config.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + bias_dropout_add_func = get_bias_dropout_add(self.training, self.config.bias_dropout_fusion) + + # bias_dropout_add fusion returning fp32 instead of bf16 + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output_with_bias, residual, self.config.hidden_dropout + ) + + # Layer norm post the self attention. + layernorm_output = self.post_self_attn_layernorm(layernorm_input) + + # MLP. + mlp_output_with_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.config.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output_with_bias, residual, self.config.hidden_dropout + ) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True + ) + + return output + + def sharded_state_dict(self, prefix=''): + + # state_dict = self.state_dict(prefix=prefix, keep_vars=True) + state_dict = self.state_dict(keep_vars=True) + + tensor_parallel_layers_axis_map = { + 'self_attention.linear_qkv.weight': 0, + 'self_attention.linear_qkv.bias': 0, + 'self_attention.linear_proj.weight': 1, + 'mlp.linear_fc1.weight': 0, + 'mlp.linear_fc1.bias': 0, + 'mlp.linear_fc2.weight': 1, + } + + offset = self._get_layer_offset() + num_layers = self.config.num_layers + + sharded_state_dict = {} + + for layer_name in state_dict.keys(): + tensor = state_dict[layer_name] + global_layer_offset = self.layer_number - 1 # self.layer_number starts at 1 + layer_key = f'{prefix}{global_layer_offset - offset}.{layer_name}' # module list index in TransformerBlock + sharded_offsets = [(0, global_layer_offset, num_layers)] # PP sharding + + # TODO: move it to MLP after merging the "sharded_state_dict modularization" MR + is_glu_weight = ( + layer_name == 'mlp.linear_fc1.weight' and self.mlp.config.gated_linear_unit + ) + + if layer_name in tensor_parallel_layers_axis_map: + tp_axis = tensor_parallel_layers_axis_map[layer_name] + # TP sharding + if not is_glu_weight: + sharded_offsets.append( + [ + tp_axis + 1, # +1 for PP dimension + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ] + ) + replica_id = parallel_state.get_data_parallel_rank() + else: + replica_id = ( + parallel_state.get_data_parallel_rank() + * parallel_state.get_data_parallel_world_size() + + parallel_state.get_tensor_model_parallel_rank() + ) + + if layer_name.endswith('._extra_state'): + sharded_state_dict[layer_key] = ShardedObject( + f'{prefix}{layer_name}', + tensor, + (num_layers,), + (global_layer_offset,), + replica_id, + ) + elif is_glu_weight: + # We must split the tensor into 2 parts, each sharded separately. + # This requires a ShardedTensorFactory which `chunk`s during saving + # and `cat`s during loading + assert tp_axis == 0, f'TP axis for GLU weight should be 0, got: {tp_axis}' + tp_rank = parallel_state.get_tensor_model_parallel_rank() + tp_size = parallel_state.get_tensor_model_parallel_world_size() + + sh_ten_builder = partial( + ShardedTensor.from_rank_offsets, replica_id=replica_id, prepend_axis_num=1 + ) # for PP sharding + + # NOTE: passing `tp_axis` as argument due to late binding in closures + def sh_ten_build_fn(key: str, t: torch.Tensor, tp_axis=tp_axis): + offset_w = (tp_axis + 1, tp_rank, tp_size * 2) + offset_v = (tp_axis + 1, tp_size + tp_rank, tp_size * 2) + with torch.no_grad(): + tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_axis) + return [ + sh_ten_builder(key, tensor_w, *sharded_offsets, offset_w), + sh_ten_builder(key, tensor_v, *sharded_offsets, offset_v), + ] + + def sh_ten_merge_fn(sub_state_dict): + with torch.no_grad(): + return torch.cat(sub_state_dict) + + sharded_state_dict[layer_key] = ShardedTensorFactory( + f'{prefix}{layer_name}', tensor, sh_ten_build_fn, sh_ten_merge_fn + ) + else: + sharded_state_dict[layer_key] = ShardedTensor.from_rank_offsets( + f'{prefix}{layer_name}', + tensor, + *sharded_offsets, + replica_id=replica_id, + prepend_axis_num=1, # for PP sharding + ) + + return sharded_state_dict diff --git a/training/DeepSpeed-Domino/megatron/core/transformer/utils.py b/training/DeepSpeed-Domino/megatron/core/transformer/utils.py new file mode 100644 index 000000000..c3740f848 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/transformer/utils.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for transformer layers.""" + +import torch + +from megatron import get_args + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def get_linear_layer(rows, columns, init_method): + """Simple linear layer with weight initialization.""" + layer = torch.nn.Linear(rows, columns) + if get_args().perform_initialization: + init_method(layer.weight) + with torch.no_grad(): + layer.bias.zero_() + return layer + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def openai_gelu(x): + return gelu_impl(x) + + +# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter +@torch.jit.script +def erf_gelu(x): + return ( + x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) + ) diff --git a/training/DeepSpeed-Domino/megatron/core/utils.py b/training/DeepSpeed-Domino/megatron/core/utils.py new file mode 100644 index 000000000..7a0bc385c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/core/utils.py @@ -0,0 +1,207 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Utility functions used throughout Megatron core""" +import math +import operator +from functools import reduce + +import torch + +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedTensor + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def get_attr_wrapped_model(model, attr, allow_none=True): + """Get an attribute from a wrapped model""" + if isinstance(model, list): + raise RuntimeError("_get_attr_wrapped_model given a list of models") + + if allow_none: + + def condition(model, attr): + return not hasattr(model, attr) + + else: + + def condition(model, attr): + return getattr(model, attr, None) is None + + while condition(model, attr): + if not hasattr(model, "module"): + raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}") + + model = model.module + return getattr(model, attr) + + +def get_model_type(model): + return get_attr_wrapped_model(model, 'model_type') + + +def get_model_config(model): + return get_attr_wrapped_model(model, 'config', allow_none=False) + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,) + out.data = inp.data + return out + + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. + ''' + + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. + + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) + + +def assert_viewless_tensor(tensor, extra_msg=None): + '''Assert that a tensor is not a view (i.e., its '._base' field is + not set).''' + if isinstance(tensor, list): + [assert_viewless_tensor(t) for t in tensor] + return tensor + if not isinstance(tensor, torch.Tensor): + return tensor + assert tensor._base is None, ( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + "likely accumulate over iterations). %s" + ) % extra_msg + return tensor + + +def safely_set_viewless_tensor_data(tensor, new_data_tensor): + '''Safely set tensor's '.data' field. + + Check first that the tensor is viewless (i.e., '._base' not set). If not, + raise an exception. + ''' + assert_viewless_tensor( + tensor, + extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s." + % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape), + ) + tensor.data = new_data_tensor + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def make_tp_sharded_tensor_for_checkpoint(tensor, key, tp_axis=0, replica_id=None, **kwargs): + """ Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group. """ + + return ShardedTensor.from_rank_offsets( + key, + tensor, + ( + tp_axis, + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ), + replica_id=parallel_state.get_data_parallel_rank() if replica_id is None else replica_id, + **kwargs, + ) + + +def make_sharded_tensor_for_checkpoint(tensor, key, **kwargs): + """ Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group). """ + + return ShardedTensor.from_rank_offsets( + key, + tensor, + replica_id=parallel_state.get_data_parallel_rank() + * parallel_state.get_data_parallel_world_size() + + parallel_state.get_tensor_model_parallel_rank(), + **kwargs, + ) diff --git a/training/DeepSpeed-Domino/megatron/data/Makefile b/training/DeepSpeed-Domino/megatron/data/Makefile new file mode 100644 index 000000000..8f9db7686 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/Makefile @@ -0,0 +1,9 @@ +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/training/DeepSpeed-Domino/megatron/data/__init__.py b/training/DeepSpeed-Domino/megatron/data/__init__.py new file mode 100644 index 000000000..cd5f898c6 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/__init__.py @@ -0,0 +1 @@ +from . import indexed_dataset diff --git a/training/DeepSpeed-Domino/megatron/data/autoaugment.py b/training/DeepSpeed-Domino/megatron/data/autoaugment.py new file mode 100644 index 000000000..585a4fa6a --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/autoaugment.py @@ -0,0 +1,320 @@ +"""AutoAugment data augmentation policy for ImageNet. + +-- Begin license text. + +MIT License + +Copyright (c) 2018 Philip Popien + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +-- End license text. + +Code adapted from https://github.com/DeepVoltaire/AutoAugment. + +This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in +Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation +policies. + +Reference: +[1] https://arxiv.org/abs/1805.09501 +""" + +import random + +import numpy as np +from PIL import Image +from PIL import ImageEnhance +from PIL import ImageOps + +_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable. + + +class ImageNetPolicy: + """Definition of an ImageNetPolicy. + + Implements a fixed AutoAugment data augmentation policy targeted at + ImageNet training by randomly applying at runtime one of the 25 pre-defined + data augmentation sub-policies provided in Reference [1]. + + Usage example as a Pytorch Transform: + >>> transform=transforms.Compose([transforms.Resize(256), + >>> ImageNetPolicy(), + >>> transforms.ToTensor()]) + """ + + def __init__(self, fillcolor=(128, 128, 128)): + """Initialize an ImageNetPolicy. + + Args: + fillcolor (tuple): RGB color components of the color to be used for + filling when needed (default: (128, 128, 128), which + corresponds to gray). + """ + # Instantiate a list of sub-policies. + # Each entry of the list is a SubPolicy which consists of + # two augmentation operations, + # each of those parametrized as operation, probability, magnitude. + # Those two operations are applied sequentially on the image upon call. + self.policies = [ + SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor), + SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), + SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), + SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor), + SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), + SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor), + SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor), + SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor), + SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor), + SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor), + SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor), + SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor), + SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor), + SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), + SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), + SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor), + SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor), + SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor), + SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor), + SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor), + SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), + SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), + SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), + SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), + SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), + ] + + def __call__(self, img): + """Define call method for ImageNetPolicy class.""" + policy_idx = random.randint(0, len(self.policies) - 1) + return self.policies[policy_idx](img) + + def __repr__(self): + """Define repr method for ImageNetPolicy class.""" + return "ImageNetPolicy" + + +class SubPolicy: + """Definition of a SubPolicy. + + A SubPolicy consists of two augmentation operations, + each of those parametrized as operation, probability, magnitude. + The two operations are applied sequentially on the image upon call. + """ + + def __init__( + self, + operation1, + probability1, + magnitude_idx1, + operation2, + probability2, + magnitude_idx2, + fillcolor, + ): + """Initialize a SubPolicy. + + Args: + operation1 (str): Key specifying the first augmentation operation. + There are fourteen key values altogether (see supported_ops below + listing supported operations). probability1 (float): Probability + within [0., 1.] of applying the first augmentation operation. + magnitude_idx1 (int): Integer specifiying the strength of the first + operation as an index further used to derive the magnitude from a + range of possible values. + operation2 (str): Key specifying the second augmentation operation. + probability2 (float): Probability within [0., 1.] of applying the + second augmentation operation. + magnitude_idx2 (int): Integer specifiying the strength of the + second operation as an index further used to derive the magnitude + from a range of possible values. + fillcolor (tuple): RGB color components of the color to be used for + filling. + Returns: + """ + # List of supported operations for operation1 and operation2. + supported_ops = [ + "shearX", + "shearY", + "translateX", + "translateY", + "rotate", + "color", + "posterize", + "solarize", + "contrast", + "sharpness", + "brightness", + "autocontrast", + "equalize", + "invert", + ] + assert (operation1 in supported_ops) and ( + operation2 in supported_ops + ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation." + + assert ( + 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0 + ), "SubPolicy: prob1 and prob2 should be within [0., 1.]." + + assert ( + isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10 + ), "SubPolicy: idx1 should be specified as an integer within [0, 10]." + + assert ( + isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10 + ), "SubPolicy: idx2 should be specified as an integer within [0, 10]." + + # Define a dictionary where each key refers to a specific type of + # augmentation and the corresponding value is a range of ten possible + # magnitude values for that augmentation. + num_levels = _MAX_LEVEL + 1 + ranges = { + "shearX": np.linspace(0, 0.3, num_levels), + "shearY": np.linspace(0, 0.3, num_levels), + "translateX": np.linspace(0, 150 / 331, num_levels), + "translateY": np.linspace(0, 150 / 331, num_levels), + "rotate": np.linspace(0, 30, num_levels), + "color": np.linspace(0.0, 0.9, num_levels), + "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype( + np.int + ), + "solarize": np.linspace(256, 0, num_levels), # range [0, 256] + "contrast": np.linspace(0.0, 0.9, num_levels), + "sharpness": np.linspace(0.0, 0.9, num_levels), + "brightness": np.linspace(0.0, 0.9, num_levels), + "autocontrast": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + "equalize": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + "invert": [0] + * num_levels, # This augmentation doesn't use magnitude parameter. + } + + def rotate_with_fill(img, magnitude): + """Define rotation transformation with fill. + + The input image is first rotated, then it is blended together with + a gray mask of the same size. Note that fillcolor as defined + elsewhere in this module doesn't apply here. + + Args: + magnitude (float): rotation angle in degrees. + Returns: + rotated_filled (PIL Image): rotated image with gray filling for + disoccluded areas unveiled by the rotation. + """ + rotated = img.convert("RGBA").rotate(magnitude) + rotated_filled = Image.composite( + rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated + ) + return rotated_filled.convert(img.mode) + + # Define a dictionary of augmentation functions where each key refers + # to a specific type of augmentation and the corresponding value defines + # the augmentation itself using a lambda function. + # pylint: disable=unnecessary-lambda + func_dict = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + ( + 1, + 0, + magnitude * img.size[0] * random.choice([-1, 1]), + 0, + 1, + 0, + ), + fillcolor=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + ( + 1, + 0, + 0, + 0, + 1, + magnitude * img.size[1] * random.choice([-1, 1]), + ), + fillcolor=fillcolor, + ), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize( + img, magnitude + ), + "solarize": lambda img, magnitude: ImageOps.solarize( + img, magnitude + ), + "contrast": lambda img, magnitude: ImageEnhance.Contrast( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "brightness": lambda img, magnitude: ImageEnhance.Brightness( + img + ).enhance(1 + magnitude * random.choice([-1, 1])), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + # Store probability, function and magnitude of the first augmentation + # for the sub-policy. + self.probability1 = probability1 + self.operation1 = func_dict[operation1] + self.magnitude1 = ranges[operation1][magnitude_idx1] + + # Store probability, function and magnitude of the second augmentation + # for the sub-policy. + self.probability2 = probability2 + self.operation2 = func_dict[operation2] + self.magnitude2 = ranges[operation2][magnitude_idx2] + + def __call__(self, img): + """Define call method for SubPolicy class.""" + # Randomly apply operation 1. + if random.random() < self.probability1: + img = self.operation1(img, self.magnitude1) + + # Randomly apply operation 2. + if random.random() < self.probability2: + img = self.operation2(img, self.magnitude2) + + return img diff --git a/training/DeepSpeed-Domino/megatron/data/bert_dataset.py b/training/DeepSpeed-Domino/megatron/data/bert_dataset.py new file mode 100644 index 000000000..036e6bccc --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/bert_dataset.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""BERT Style dataset.""" + +import numpy as np +import torch + +from megatron import ( + get_args, + get_tokenizer, + mpu, + print_rank_0 +) +from megatron.data.dataset_utils import ( + get_samples_mapping, + get_a_and_b_segments, + truncate_segments, + create_tokens_and_tokentypes, + create_masked_lm_predictions +) + +class BertDataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, masked_lm_prob, + max_seq_length, short_seq_prob, seed, binary_head): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.binary_head = binary_head + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping(self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 3, # account for added tokens + short_seq_prob, + self.seed, + self.name, + self.binary_head) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + start_idx, end_idx, seq_length = self.samples_mapping[idx] + sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 + np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) + return build_training_sample(sample, seq_length, + self.max_seq_length, # needed for padding + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, self.sep_id, + self.mask_id, self.pad_id, + self.masked_lm_prob, np_rng, + self.binary_head) + + + + +def build_training_sample(sample, + target_seq_length, max_seq_length, + vocab_id_list, vocab_id_to_token_dict, + cls_id, sep_id, mask_id, pad_id, + masked_lm_prob, np_rng, binary_head): + """Biuld training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + """ + + if binary_head: + # We assume that we have at least two sentences in the sample + assert len(sample) > 1 + assert target_seq_length <= max_seq_length + + # Divide sample into two segments (A and B). + if binary_head: + tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, + np_rng) + else: + tokens_a = [] + for j in range(len(sample)): + tokens_a.extend(sample[j]) + tokens_b = [] + is_next_random = False + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), + len(tokens_b), max_num_tokens, np_rng) + + # Build tokens and toketypes. + tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, + cls_id, sep_id) + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions( + tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng) + + # Padding. + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ + = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length) + + train_sample = { + 'text': tokens_np, + 'types': tokentypes_np, + 'labels': labels_np, + 'is_random': int(is_next_random), + 'loss_mask': loss_mask_np, + 'padding_mask': padding_mask_np, + 'truncated': int(truncated)} + return train_sample + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0, \ + f"num_tokens ({num_tokens}) is greater than " \ + "max_seq_length ({max_seq_length})." + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/training/DeepSpeed-Domino/megatron/data/biencoder_dataset_utils.py b/training/DeepSpeed-Domino/megatron/data/biencoder_dataset_utils.py new file mode 100644 index 000000000..c08f06792 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/biencoder_dataset_utils.py @@ -0,0 +1,209 @@ +import os +import time + +import numpy as np +import torch + +from megatron import get_args, get_tokenizer, print_rank_0 +from megatron.core import mpu, tensor_parallel +from megatron.data.dataset_utils import create_masked_lm_predictions, \ + pad_and_convert_to_numpy +from megatron.data.data_samplers import MegatronPretrainingSampler + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_one_epoch_dataloader(dataset, micro_batch_size=None): + """Specifically one epoch to be used in an indexing job.""" + args = get_args() + + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + num_workers = args.num_workers + + # Use megatron's sampler with consumed samples set to 0 as + # this is only for evaluation and don't intend to resume half way. + # Also, set the drop last to false as don't intend to remove + # the last batch + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=0, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + drop_last=False) + + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + + +def get_ict_batch(data_iterator): + # Items and their type. + keys = ['query_tokens', 'query_mask', + 'context_tokens', 'context_mask', 'block_data'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is None: + data = None + else: + data = next(data_iterator) + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + query_tokens = data_b['query_tokens'].long() + query_mask = data_b['query_mask'] < 0.5 + context_tokens = data_b['context_tokens'].long() + context_mask = data_b['context_mask'] < 0.5 + block_indices = data_b['block_data'].long() + + return query_tokens, query_mask,\ + context_tokens, context_mask, block_indices + + +def join_str_list(str_list): + """Join a list of strings, handling spaces appropriately""" + result = "" + for s in str_list: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + return result + + +class BlockSampleData(object): + """A struct for fully describing a fixed-size block of data as used in REALM + + :param start_idx: for first sentence of the block + :param end_idx: for last sentence of the block (may be partially truncated in sample construction) + :param doc_idx: the index of the document from which the block comes in the original indexed dataset + :param block_idx: a unique integer identifier given to every block. + """ + def __init__(self, start_idx, end_idx, doc_idx, block_idx): + self.start_idx = start_idx + self.end_idx = end_idx + self.doc_idx = doc_idx + self.block_idx = block_idx + + def as_array(self): + return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) + + def as_tuple(self): + return self.start_idx, self.end_idx, self.doc_idx, self.block_idx + + +class BlockSamplesMapping(object): + def __init__(self, mapping_array): + # make sure that the array is compatible with BlockSampleData + assert mapping_array.shape[1] == 4 + self.mapping_array = mapping_array + + def __len__(self): + return self.mapping_array.shape[0] + + def __getitem__(self, idx): + """Get the data associated with an indexed sample.""" + sample_data = BlockSampleData(*self.mapping_array[idx]) + return sample_data + + +def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): + """Get samples mapping for a dataset over fixed size blocks. This function also requires + a dataset of the titles for the source documents since their lengths must be taken into account. + + :return: samples_mapping (BlockSamplesMapping) + """ + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{}s'.format(seed) + if use_one_sent_docs: + indexmap_filename += '_1sentok' + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if mpu.get_data_parallel_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert block_dataset.doc_idx.dtype == np.int64 + assert block_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + + from megatron.data import helpers + mapping_array = helpers.build_blocks_mapping( + block_dataset.doc_idx, + block_dataset.sizes, + title_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length - 3, # account for added tokens + seed, + verbose, + use_one_sent_docs) + + + print_rank_0(' > done building samples index mapping') + np.save(indexmap_filename, mapping_array, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + + mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + samples_mapping = BlockSamplesMapping(mapping_array) + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + mapping_array.shape[0])) + + return samples_mapping diff --git a/training/DeepSpeed-Domino/megatron/data/blendable_dataset.py b/training/DeepSpeed-Domino/megatron/data/blendable_dataset.py new file mode 100644 index 000000000..8ff5ce3da --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/blendable_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Blendable dataset.""" + +import hashlib +import os +import time + +import numpy as np +import torch + +from megatron import print_rank_0 +from megatron.core import mpu + +class BlendableDataset(torch.utils.data.Dataset): + + + def __init__(self, datasets, weights, size, *, + data_cache_path=None): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = size + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indicies. + def _build_indices(): + start_time = time.time() + assert num_datasets < 255 + dataset_index = np.zeros(self.size, dtype=np.uint8) + dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from megatron.data import helpers + helpers.build_blending_indices(dataset_index, dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print_rank_0('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + return dataset_index, dataset_sample_index + + desc = "Blendable dataset\n\n" + desc += "Datasets:\n" + for dataset in datasets: + desc += dataset.desc + "\n\n" + desc += f"Weights: {weights}\n" + desc += f"Size: {size}\n" + self.desc = desc + + if data_cache_path: + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_path = os.path.join(data_cache_path, desc_hash + ".dsc") + index_path = os.path.join(data_cache_path, desc_hash + "_index.npy") + sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy") + cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path) + cache_success = True + if torch.distributed.get_rank() == 0 and not cache_hit: + print(' > WARNING: could not find index map files for blendable' + ' dataset, building indices on rank 0 ...', flush=True) + dataset_index, dataset_sample_index = _build_indices() + try: + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(desc_path, 'wt') as fd: + fd.write(desc) + np.save(index_path, dataset_index, allow_pickle=True) + np.save(sample_index_path, dataset_sample_index, + allow_pickle=True) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_path})') + print('or a file in it. This is set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + cache_success = False + + + counts = torch.cuda.LongTensor([cache_success]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + if counts[0].item() != ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() + + # Load on all ranks. + print_rank_0(f'> loading blendable dataset index: {index_path}') + self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_index.size == self.size + + print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}') + self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_sample_index.size == self.size + else: + self.dataset_index, self.dataset_sample_index = _build_indices() + + + # Check size + _ = self.__getitem__(self.size - 1) + try: + _ = self.__getitem__(self.size) + raise RuntimeError('BlendedDataset size is improperly bounded') + except IndexError: + pass + print_rank_0('> size of blendable dataset: ' + '{} samples'.format(self.size)) + + + def __len__(self): + return self.size + + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return { + "dataset_idx" : dataset_idx, + **self.datasets[dataset_idx][sample_idx], + } diff --git a/training/DeepSpeed-Domino/megatron/data/data_samplers.py b/training/DeepSpeed-Domino/megatron/data/data_samplers.py new file mode 100644 index 000000000..8dec2c192 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/data_samplers.py @@ -0,0 +1,186 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Dataloaders.""" + + +import random +import torch +import numpy as np +from torch.utils.data import Dataset +from megatron import get_args +from megatron.core import mpu + + +def build_pretraining_data_loader(dataset, consumed_samples): + """Buld dataloader given an input dataset.""" + + if dataset is None: + return None + args = get_args() + + # Megatron sampler + if args.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + elif args.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + dataset, + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + data_sharding=args.data_sharding) + else: + raise Exception('{} dataloader type is not supported.'.format( + args.dataloader_type)) + + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True) + +class MegatronPretrainingSampler: + + def __init__(self, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + # Keep a copy of input params for later use. + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.drop_last = drop_last + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.consumed_samples < self.total_samples, \ + 'no samples left to consume: {}, {}'.format(self.consumed_samples, + self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +class RandomSeedDataset(Dataset): + + def __init__(self, dataset): + args = get_args() + self.base_seed = args.seed + self.curr_seed = args.seed + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def set_epoch(self, epoch): + self.curr_seed = self.base_seed + epoch + + def __getitem__(self, idx): + seed = idx + self.curr_seed + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + return self.dataset[idx] + + +class MegatronPretrainingRandomSampler: + + def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, + data_parallel_rank, data_parallel_size, data_sharding): + # Keep a copy of input params for later use. + self.dataset = dataset + self.total_samples = total_samples + self.consumed_samples = consumed_samples + self.micro_batch_size = micro_batch_size + self.data_parallel_rank = data_parallel_rank + self.data_parallel_size = data_parallel_size + self.data_sharding = data_sharding + self.micro_batch_times_data_parallel_size = \ + self.micro_batch_size * data_parallel_size + self.last_batch_size = \ + self.total_samples % self.micro_batch_times_data_parallel_size + + # Sanity checks. + assert self.total_samples > 0, \ + 'no sample to consume: {}'.format(self.total_samples) + assert self.micro_batch_size > 0 + assert data_parallel_size > 0 + assert self.data_parallel_rank < data_parallel_size, \ + 'data_parallel_rank should be smaller than data size: {}, ' \ + '{}'.format(self.data_parallel_rank, data_parallel_size) + + def __len__(self): + return self.total_samples + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + if isinstance(self.dataset, RandomSeedDataset): + self.dataset.set_epoch(self.epoch) + + # data sharding and random sampling + if self.data_sharding: + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ + * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + else: + full_bucket_size = (self.total_samples // self.micro_batch_size) \ + * self.micro_batch_size + full_bucket_offset = current_epoch_samples + g = torch.Generator() + g.manual_seed(self.epoch) + idx_range_total = \ + torch.randperm(full_bucket_size, generator=g).tolist() + idx_range_active = idx_range_total[full_bucket_offset:] + idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] diff --git a/training/DeepSpeed-Domino/megatron/data/dataset_utils.py b/training/DeepSpeed-Domino/megatron/data/dataset_utils.py new file mode 100644 index 000000000..571d3141e --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/dataset_utils.py @@ -0,0 +1,806 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Most of the code here has been copied from: +# https://github.com/google-research/albert/blob/master/create_pretraining_data.py +# with some modifications. + +import math +import os +import time +import collections + +import numpy as np +import torch + +from megatron import ( + get_args, + print_rank_0 +) +from megatron.core import mpu +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + +DSET_TYPE_BERT = 'standard_bert' +DSET_TYPE_ICT = 'ict' +DSET_TYPE_T5 = 't5' +DSET_TYPE_MULTIMODAL = 'multimodal' + +DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_MULTIMODAL] + + +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0]*num_datasets + prefixes = [0]*num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2*i]) + prefixes[i] = (data_prefix[2*i+1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + if isinstance(train_valid_test_num_samples, list): + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + for val in train_valid_test_num_samples]) + else: + # Used when separate dataset files are provided for train, + # valid and test + datasets_train_valid_test_num_samples = [ + int(math.ceil(train_valid_test_num_samples * weight * 1.005)) + for weight in weights] + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + import os + import subprocess + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(['make', '-C', path]) + if ret.returncode != 0: + print("Making C++ dataset helpers module failed, exiting.") + import sys + sys.exit(1) + + +def get_a_and_b_segments(sample, np_rng): + """Divide sample into a and b segments.""" + + # Number of sentences in the sample. + n_sentences = len(sample) + # Make sure we always have two sentences. + assert n_sentences > 1, 'make sure each sample has at least two sentences.' + + # First part: + # `a_end` is how many sentences go into the `A`. + a_end = 1 + if n_sentences >= 3: + # Note that randin in numpy is exclusive. + a_end = np_rng.randint(1, n_sentences) + tokens_a = [] + for j in range(a_end): + tokens_a.extend(sample[j]) + + # Second part: + tokens_b = [] + for j in range(a_end, n_sentences): + tokens_b.extend(sample[j]) + + # Random next: + is_next_random = False + if np_rng.random() < 0.5: + is_next_random = True + tokens_a, tokens_b = tokens_b, tokens_a + + return tokens_a, tokens_b, is_next_random + + +def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): + """Truncates a pair of sequences to a maximum sequence length.""" + #print(len_a, len_b, max_num_tokens) + assert len_a > 0 + if len_a + len_b <= max_num_tokens: + return False + while len_a + len_b > max_num_tokens: + if len_a > len_b: + len_a -= 1 + tokens = tokens_a + else: + len_b -= 1 + tokens = tokens_b + if np_rng.random() < 0.5: + del tokens[0] + else: + tokens.pop() + return True + + +def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): + """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" + + tokens = [] + tokentypes = [] + # [CLS]. + tokens.append(cls_id) + tokentypes.append(0) + # Segment A. + for token in tokens_a: + tokens.append(token) + tokentypes.append(0) + # [SEP]. + tokens.append(sep_id) + tokentypes.append(0) + # Segment B. + for token in tokens_b: + tokens.append(token) + tokentypes.append(1) + if tokens_b: + # [SEP]. + tokens.append(sep_id) + tokentypes.append(1) + + return tokens, tokentypes + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def is_start_piece(piece): + """Check if the current word piece is the starting piece (BERT).""" + # When a word has been split into + # WordPieces, the first token does not have any marker and any subsequence + # tokens are prefixed with ##. So whenever we see the ## token, we + # append it to the previous set of word indexes. + return not piece.startswith("##") + + +def create_masked_lm_predictions(tokens, + vocab_id_list, vocab_id_to_token_dict, + masked_lm_prob, + cls_id, sep_id, mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, + geometric_dist=False, + masking_style="bert"): + """Creates the predictions for the masked LM objective. + Note: Tokens here are vocab ids and not text tokens.""" + + cand_indexes = [] + # Note(mingdachen): We create a list for recording if the piece is + # the starting piece of current token, where 1 means true, so that + # on-the-fly whole word masking is possible. + token_boundary = [0] * len(tokens) + + for (i, token) in enumerate(tokens): + if token == cls_id or token == sep_id: + token_boundary[i] = 1 + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (do_whole_word_mask and len(cand_indexes) >= 1 and + not is_start_piece(vocab_id_to_token_dict[token])): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + if is_start_piece(vocab_id_to_token_dict[token]): + token_boundary[i] = 1 + + output_tokens = list(tokens) + + masked_lm_positions = [] + masked_lm_labels = [] + + if masked_lm_prob == 0: + return (output_tokens, masked_lm_positions, + masked_lm_labels, token_boundary) + + num_to_predict = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + + ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) + if not geometric_dist: + # Note(mingdachen): + # By default, we set the probilities to favor shorter ngram sequences. + pvals = 1. / np.arange(1, max_ngrams + 1) + pvals /= pvals.sum(keepdims=True) + if favor_longer_ngram: + pvals = pvals[::-1] + + ngram_indexes = [] + for idx in range(len(cand_indexes)): + ngram_index = [] + for n in ngrams: + ngram_index.append(cand_indexes[idx:idx + n]) + ngram_indexes.append(ngram_index) + + np_rng.shuffle(ngram_indexes) + + (masked_lms, masked_spans) = ([], []) + covered_indexes = set() + for cand_index_set in ngram_indexes: + if len(masked_lms) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes: + continue + + if not geometric_dist: + n = np_rng.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + else: + # Sampling "n" from the geometric distribution and clipping it to + # the max_ngrams. Using p=0.2 default from the SpanBERT paper + # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) + n = min(np_rng.geometric(0.2), max_ngrams) + + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # Note(mingdachen): + # Repeatedly looking for a candidate that does not exceed the + # maximum number of predictions by trying shorter ngrams. + while len(masked_lms) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_token = None + if masking_style == "bert": + # 80% of the time, replace with [MASK] + if np_rng.random() < 0.8: + masked_token = mask_id + else: + # 10% of the time, keep original + if np_rng.random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] + elif masking_style == "t5": + masked_token = mask_id + else: + raise ValueError("invalid value of masking style") + + output_tokens[index] = masked_token + masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) + + masked_spans.append(MaskedLmInstance( + index=index_set, + label=[tokens[index] for index in index_set])) + + assert len(masked_lms) <= num_to_predict + np_rng.shuffle(ngram_indexes) + + select_indexes = set() + if do_permutation: + for cand_index_set in ngram_indexes: + if len(select_indexes) >= num_to_predict: + break + if not cand_index_set: + continue + # Note(mingdachen): + # Skip current piece if they are covered in lm masking or previous ngrams. + for index_set in cand_index_set[0]: + for index in index_set: + if index in covered_indexes or index in select_indexes: + continue + + n = np.random.choice(ngrams[:len(cand_index_set)], + p=pvals[:len(cand_index_set)] / + pvals[:len(cand_index_set)].sum(keepdims=True)) + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + + while len(select_indexes) + len(index_set) > num_to_predict: + if n == 0: + break + index_set = sum(cand_index_set[n - 1], []) + n -= 1 + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(select_indexes) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes or index in select_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + select_indexes.add(index) + assert len(select_indexes) <= num_to_predict + + select_indexes = sorted(select_indexes) + permute_indexes = list(select_indexes) + np_rng.shuffle(permute_indexes) + orig_token = list(output_tokens) + + for src_i, tgt_i in zip(select_indexes, permute_indexes): + output_tokens[src_i] = orig_token[tgt_i] + masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) + + masked_lms = sorted(masked_lms, key=lambda x: x.index) + # Sort the spans by the index of the first span + masked_spans = sorted(masked_spans, key=lambda x: x.index[0]) + + for p in masked_lms: + masked_lm_positions.append(p.index) + masked_lm_labels.append(p.label) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans) + + +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, + masked_labels, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(tokentypes) == num_tokens + assert len(masked_positions) == len(masked_labels) + + # Tokens and token types. + filler = [pad_id] * padding_length + tokens_np = np.array(tokens + filler, dtype=np.int64) + tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) + + # Padding mask. + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, + dtype=np.int64) + + # Lables and loss mask. + labels = [-1] * max_seq_length + loss_mask = [0] * max_seq_length + for i in range(len(masked_positions)): + assert masked_positions[i] < num_tokens + labels[masked_positions[i]] = masked_labels[i] + loss_mask[masked_positions[i]] = 1 + labels_np = np.array(labels, dtype=np.int64) + loss_mask_np = np.array(loss_mask, dtype=np.int64) + + return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np + + +def build_train_valid_test_datasets_with_prefixes(data_impl, + train_valid_test_num_samples, + max_seq_length, + seed, + skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + binary_head=False, + max_seq_length_dec=None, + dataset_type='standard_bert'): + print_rank_0("Separate data paths provided for train, valid & test.") + + train_dataset, valid_dataset, test_dataset = None, None, None + # Single dataset. + if train_data_prefix is not None: + train_dataset = build_dataset("train", train_data_prefix, data_impl, + train_valid_test_num_samples[0], + max_seq_length, seed, skip_warmup, + binary_head, max_seq_length_dec, + dataset_type=dataset_type) + + if valid_data_prefix is not None: + valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, + train_valid_test_num_samples[1], + max_seq_length, seed, False, + binary_head, max_seq_length_dec, + dataset_type=dataset_type) + + if test_data_prefix is not None: + test_dataset = build_dataset("test", test_data_prefix, data_impl, + train_valid_test_num_samples[2], + max_seq_length, seed, False, + binary_head, max_seq_length_dec, + dataset_type=dataset_type) + + return (train_dataset, valid_dataset, test_dataset) + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, seed, + skip_warmup, binary_head=False, + max_seq_length_dec=None, + dataset_type='standard_bert'): + + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, seed, + skip_warmup, + binary_head, + max_seq_length_dec, + dataset_type=dataset_type) + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + train_num_samples, valid_num_samples, test_num_samples = map( + sum, + zip(*datasets_train_valid_test_num_samples) + ) + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, seed, skip_warmup, binary_head, + max_seq_length_dec, dataset_type=dataset_type) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, seed, + skip_warmup, binary_head, + max_seq_length_dec, + dataset_type='standard_bert'): + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + dataset_type, + skip_warmup) + + # Get start and end indices of train/valid/train into doc-idx + # Note that doc-idx is desinged to be num-docs + 1 so we can + # easily iterate over it. + total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + print_rank_0(' sentence indices in [{}, {}) total of {} ' + 'sentences'.format(start_index, end_index, + end_index - start_index)) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_split_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + + dataset = build_dataset( + name, data_prefix, data_impl, + train_valid_test_num_samples[index], max_seq_length, + seed, skip_warmup, binary_head, max_seq_length_dec, + dataset_type, indexed_dataset) + + # Set the original pointer so dataset remains the main dataset. + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + assert indexed_dataset.doc_idx[0] == 0 + assert indexed_dataset.doc_idx.shape[0] == \ + (total_num_of_documents + 1) + return dataset + + train_dataset = build_split_dataset(0, 'train') + valid_dataset = build_split_dataset(1, 'valid') + test_dataset = build_split_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def build_dataset(name, data_prefix, data_impl, max_num_samples, + max_seq_length, seed, skip_warmup, binary_head, + max_seq_length_dec, dataset_type='standard_bert', + indexed_dataset=None): + + from megatron.data.bert_dataset import BertDataset + from megatron.data.ict_dataset import ICTDataset + from megatron.data.t5_dataset import T5Dataset + from megatron.data.multimodal_dataset import MultiModalDataset + + if dataset_type not in DSET_TYPES: + raise ValueError("Invalid dataset_type: ", dataset_type) + + if indexed_dataset is None: + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + dataset_type, + skip_warmup) + + kwargs = dict( + name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=max_num_samples, + max_seq_length=max_seq_length, + seed=seed, + ) + + if dataset_type == DSET_TYPE_ICT: + args = get_args() + + title_dataset = get_indexed_dataset_( + args.titles_data_path, + data_impl, + dataset_type, + skip_warmup) + + dataset = ICTDataset( + block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + binary_head=binary_head, + **kwargs + ) + elif dataset_type == DSET_TYPE_T5: + args = get_args() + dataset = T5Dataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=args.mask_prob, + max_seq_length_dec=max_seq_length_dec, + short_seq_prob=args.short_seq_prob, + **kwargs + ) + elif dataset_type == DSET_TYPE_BERT: + args = get_args() + dataset = BertDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=args.mask_prob, + short_seq_prob=args.short_seq_prob, + binary_head=binary_head, + **kwargs + ) + elif dataset_type == DSET_TYPE_MULTIMODAL: + args = get_args() + dataset = MultiModalDataset( + name=name, + data_prefix=data_prefix, + indexed_dataset=indexed_dataset, + num_samples=max_num_samples, + seq_length=max_seq_length, + seed=seed, + img_h=args.img_h, + img_w=args.img_w, + ) + else: + raise NotImplementedError("Dataset type not fully implemented.") + + return dataset + + +def get_indexed_dataset_(data_prefix, data_impl, dataset_type, skip_warmup): + + print_rank_0(' > building dataset index ...') + + start_time = time.time() + multimodal = dataset_type == DSET_TYPE_MULTIMODAL + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup, + multimodal) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + + print_rank_0(' > indexed dataset stats:') + print_rank_0(' number of documents: {}'.format( + indexed_dataset.doc_idx.shape[0] - 1)) + print_rank_0(' number of sentences: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """ Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(',') != -1: + splits = [float(s) for s in splits_string.split(',')] + elif splits_string.find('/') != -1: + splits = [float(s) for s in splits_string.split('/')] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + +def get_samples_mapping(indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + name, + binary_head): + """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) + indexmap_filename += '_{}s'.format(seed) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + # First compile and then import. + from megatron.data import helpers + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1) + print_rank_0(' > done building samples index maping') + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elasped time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + samples_mapping.shape[0])) + + return samples_mapping diff --git a/training/DeepSpeed-Domino/megatron/data/gpt_dataset.py b/training/DeepSpeed-Domino/megatron/data/gpt_dataset.py new file mode 100644 index 000000000..088748bc9 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/gpt_dataset.py @@ -0,0 +1,590 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""GPT style dataset.""" + +import hashlib +import os +import time + +import numpy as np +import torch + +from megatron import print_rank_0 +from megatron.core import mpu +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples +from megatron.data.dataset_utils import get_train_valid_test_split_ +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + return_doc_ids=False, *, + data_cache_path=None): + """Build train, valid, and test datasets.""" + + if data_prefix: + print_rank_0("Single data path provided for train, valid & test") + + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + train_num_samples, valid_num_samples, test_num_samples = map( + sum, + zip(*datasets_train_valid_test_num_samples) + ) + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, seed, skip_warmup, + return_doc_ids, + data_cache_path=data_cache_path) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples, + data_cache_path=data_cache_path) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples, + data_cache_path=data_cache_path) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples, + data_cache_path=data_cache_path) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + else: + print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.") + + train_dataset, valid_dataset, test_dataset = None, None, None + # Single dataset. + if train_data_prefix is not None: + train_dataset = build_dataset("train", train_data_prefix, data_impl, + splits_string, + train_valid_test_num_samples[0], + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) + + if valid_data_prefix is not None: + valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, + splits_string, + train_valid_test_num_samples[1], + seq_length, seed, False, + data_cache_path=data_cache_path) + + + if test_data_prefix is not None: + test_dataset = build_dataset("test", test_data_prefix, data_impl, + splits_string, + train_valid_test_num_samples[2], + seq_length, seed, False, + data_cache_path=data_cache_path) + + return (train_dataset, valid_dataset, test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + return_doc_ids=False, *, + data_cache_path=None): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_length, seed, + return_doc_ids, + data_cache_path=data_cache_path) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def build_dataset(dataset_name, data_prefix, data_impl, + splits_string, num_samples, + seq_length, seed, skip_warmup, + *, + data_cache_path=None): + dataset = None + if len(data_prefix) == 1: + dataset = _build_dataset(dataset_name, data_prefix[0], data_impl, + splits_string, num_samples, seq_length, + seed, skip_warmup, + data_cache_path=data_cache_path) + else: + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, num_samples) + prefixes, weights, dataset_num_samples = output + num_samples = sum(dataset_num_samples) + + # Build individual datasets. + datasets = [] + for i in range(len(prefixes)): + ds = _build_dataset(dataset_name, prefixes[i], data_impl, + splits_string, dataset_num_samples[i], + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) + if ds: + datasets.append(ds) + + if datasets: + dataset = BlendableDataset(datasets, weights, num_samples, + data_cache_path=data_cache_path) + + return dataset + + +def _build_dataset(dataset_name, data_prefix, data_impl, splits_string, + num_samples, seq_length, seed, skip_warmup, + *, + data_cache_path=None): + """ + Build dataset. This method is called when individual + train, valid, test datasets are provided + """ + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + + print_rank_0(' {}:'.format(dataset_name)) + print_rank_0(' document indices in [0, {}) total of {} ' + 'documents'.format(total_num_of_documents, total_num_of_documents)) + + documents = np.arange(start=0, stop=total_num_of_documents, + step=1, dtype=np.int32) + + dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, + data_cache_path=data_cache_path) + + return dataset + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class GPTDataset(torch.utils.data.Dataset): + + def __init__(self, name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, + return_doc_ids=False, *, + data_cache_path=None): + + self.name = name + self.indexed_dataset = indexed_dataset + self.return_doc_ids = return_doc_ids + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ + _build_index_mappings(self.name, data_prefix, + documents, self.indexed_dataset.sizes, + splits_string, num_samples, seq_length, seed, + data_cache_path=data_cache_path) + + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + doc_ids = [] + if doc_index_f == doc_index_l: + doc_ids.append(self.doc_idx[doc_index_f]) + sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1) + else: + # Otherwise, get the rest of the initial document. + doc_ids.append(self.doc_idx[doc_index_f]) + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], + offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + doc_ids.append(self.doc_idx[i]) + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + doc_ids.append(self.doc_idx[doc_index_l]) + sample_list.append(self.indexed_dataset.get( + self.doc_idx[doc_index_l], + length=offset_l + 1)) + sample = np.concatenate(sample_list) + + if self.return_doc_ids: # for retro preprocessing + return {'text': np.array(sample, dtype=np.int64), + 'doc_ids': np.array(doc_ids, dtype=np.int64)} + else: + return {'text': np.array(sample, dtype=np.int64)} + + +def _build_index_mappings(name, data_prefix, documents, sizes, + splits_string, num_samples, seq_length, seed, + *, + data_cache_path): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + '_doc_idx.npy' + sample_idx_filename = desc_hash + '_sample_idx.npy' + shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + 'desc': os.path.join(prefix, desc_filename), + 'doc': os.path.join(prefix, doc_idx_filename), + 'sample': os.path.join(prefix, sample_idx_filename), + 'shuffle': os.path.join(prefix, shuffle_idx_filename) + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break + data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_success = True + + # Build the indexed mapping if not exist. + if build_indices and torch.distributed.get_rank() == 0: + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ + 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = (last_epoch_num_samples < + int(0.80 * num_samples_per_epoch)) + if separate_last_epoch: + string = ' > last epoch number of samples ({}) is smaller '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to True' + else: + string = ' > last epoch number of samples ({}) is larger '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, + num_samples_per_epoch), flush=True) + + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path['desc'], 'wt') as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(idx_path['doc'], doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + np.save(idx_path['sample'], sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_dir})') + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print('the data files are in and can be set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + data_cache_success = False + + counts = torch.cuda.LongTensor([data_cache_success]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + if counts[0].item() != ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() + + # Load mappings. + start_time = time.time() + print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx, desc, desc_hash + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Begining offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += (remaining_seq_length + doc_length - 1) + remaining_seq_length = 0 + else: + # Otherwise, start from the begining of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print(' > building shuffle index with split [0, {}) and [{}, {}) ' + '...'.format(num_samples, num_samples, total_size), flush=True) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + diff --git a/training/DeepSpeed-Domino/megatron/data/helpers.cpp b/training/DeepSpeed-Domino/megatron/data/helpers.cpp new file mode 100644 index 000000000..09f5f9762 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/helpers.cpp @@ -0,0 +1,701 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/training/DeepSpeed-Domino/megatron/data/ict_dataset.py b/training/DeepSpeed-Domino/megatron/data/ict_dataset.py new file mode 100644 index 000000000..6dac35ff9 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/ict_dataset.py @@ -0,0 +1,156 @@ +import itertools +import random + +import numpy as np +from torch.utils.data import Dataset + +from megatron import get_tokenizer +from megatron import get_args +from megatron.data.dataset_utils import get_indexed_dataset_ +from megatron.data.realm_dataset_utils import get_block_samples_mapping + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + +def get_ict_dataset(use_titles=True, query_in_block_prob=1): + """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) + rather than for training, since it is only built with a single epoch sample mapping. + """ + args = get_args() + block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + + kwargs = dict( + name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs + ) + dataset = ICTDataset(**kwargs) + return dataset + + +class ICTDataset(Dataset): + """Dataset containing sentences and their blocks for an inverse cloze task.""" + def __init__(self, name, block_dataset, title_dataset, data_prefix, + num_epochs, max_num_samples, max_seq_length, query_in_block_prob, + seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + self.name = name + self.seed = seed + self.max_seq_length = max_seq_length + self.query_in_block_prob = query_in_block_prob + self.block_dataset = block_dataset + self.title_dataset = title_dataset + self.rng = random.Random(self.seed) + self.use_titles = use_titles + self.use_one_sent_docs = use_one_sent_docs + + self.samples_mapping = get_block_samples_mapping( + block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.tokenizer = get_tokenizer() + self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_list = self.tokenizer.inv_vocab + self.cls_id = self.tokenizer.cls + self.sep_id = self.tokenizer.sep + self.mask_id = self.tokenizer.mask + self.pad_id = self.tokenizer.pad + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" + sample_data = self.samples_mapping[idx] + start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() + + if self.use_titles: + title = self.title_dataset[int(doc_idx)] + title_pad_offset = 3 + len(title) + else: + title = None + title_pad_offset = 2 + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 + + # randint() is inclusive for Python rng + rand_sent_idx = self.rng.randint(0, len(block) - 1) + + # keep the query in the context query_in_block_prob fraction of the time. + if self.rng.random() < self.query_in_block_prob: + query = block[rand_sent_idx].copy() + else: + query = block.pop(rand_sent_idx) + + # still need to truncate because blocks are concluded when + # the sentence lengths have exceeded max_seq_length. + query = query[:self.max_seq_length - 2] + block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + + query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) + context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) + + query_mask = make_attention_mask(query_tokens, query_tokens) + context_mask = make_attention_mask(context_tokens, context_tokens) + + block_data = sample_data.as_array() + + sample = { + 'query_tokens': query_tokens, + 'query_mask': query_mask, + 'query_pad_mask': query_pad_mask, + 'context_tokens': context_tokens, + 'context_mask': context_mask, + 'context_pad_mask': context_pad_mask, + 'block_data': block_data, + } + + return sample + + def get_block(self, start_idx, end_idx, doc_idx): + """Get the IDs for an evidence block plus the title of the corresponding document""" + block = [self.block_dataset[i] for i in range(start_idx, end_idx)] + title = self.title_dataset[int(doc_idx)] + + block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def get_null_block(self): + """Get empty block and title - used in REALM pretraining""" + block, title = [], [] + block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) + + return block_tokens, block_pad_mask + + def concat_and_pad_tokens(self, tokens, title=None): + """Concat with special tokens and pad sequence to self.max_seq_length""" + tokens = list(tokens) + if title is None: + tokens = [self.cls_id] + tokens + [self.sep_id] + else: + title = list(title) + tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] + assert len(tokens) <= self.max_seq_length + + num_pad = self.max_seq_length - len(tokens) + pad_mask = [1] * len(tokens) + [0] * num_pad + tokens += [self.pad_id] * num_pad + + return np.array(tokens), np.array(pad_mask) diff --git a/training/DeepSpeed-Domino/megatron/data/image_folder.py b/training/DeepSpeed-Domino/megatron/data/image_folder.py new file mode 100644 index 000000000..de15b29bf --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/image_folder.py @@ -0,0 +1,302 @@ +# BSD 3-Clause License +# +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# code taken from +# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py +# added support for classes_fraction and data_per_class_fraction + +from torchvision.datasets import VisionDataset +from PIL import Image + +import os +import os.path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple +import numpy as np + +def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: + """Checks if a file is an allowed extension. + Args: + filename (string): path to a file + extensions (tuple of strings): extensions to consider (lowercase) + Returns: + bool: True if the filename ends with one of given extensions + """ + return filename.lower().endswith(extensions) + + +def is_image_file(filename: str) -> bool: + """Checks if a file is an allowed image extension. + Args: + filename (string): path to a file + Returns: + bool: True if the filename ends with a known image extension + """ + return has_file_allowed_extension(filename, IMG_EXTENSIONS) + + +def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + data_per_class_fraction: float, + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + Args: + directory (str): root dataset directory + class_to_idx (Dict[str, int]): dictionary mapping class name to class index + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + Raises: + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + instances = [] + directory = os.path.expanduser(directory) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + local_instances = [] + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + local_instances.append(item) + + instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)]) + + return instances + + +class DatasetFolder(VisionDataset): + """A generic data loader where the samples are arranged in this way: :: + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/[...]/xxz.ext + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/[...]/asd932_.ext + Args: + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (tuple[string]): A list of allowed extensions. + both extensions and is_valid_file should not be passed. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + is_valid_file (callable, optional): A function that takes path of a file + and check if the file is a valid file (used to check of corrupt files) + both extensions and is_valid_file should not be passed. + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset + """ + + def __init__( + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + classes_fraction=1.0, + data_per_class_fraction=1.0, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> None: + super(DatasetFolder, self).__init__(root, transform=transform, + target_transform=target_transform) + self.classes_fraction = classes_fraction + self.data_per_class_fraction = data_per_class_fraction + classes, class_to_idx = self._find_classes(self.root) + samples = self.make_dataset(self.root, + class_to_idx, + self.data_per_class_fraction, + extensions, + is_valid_file) + if len(samples) == 0: + msg = "Found 0 files in subfolders of: {}\n".format(self.root) + if extensions is not None: + msg += "Supported extensions are: {}".format(",".join(extensions)) + raise RuntimeError(msg) + + self.loader = loader + self.extensions = extensions + self.total = len(samples) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = samples + self.targets = [s[1] for s in samples] + + @staticmethod + def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + data_per_class_fraction: float, + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> List[Tuple[str, int]]: + return make_dataset(directory, + class_to_idx, + data_per_class_fraction, + extensions=extensions, + is_valid_file=is_valid_file) + + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: + """ + Finds the class folders in a dataset. + Args: + dir (string): Root directory path. + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + Ensures: + No class is a subdirectory of another. + """ + all_classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes = all_classes[0:int(len(all_classes) * self.classes_fraction)] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + curr_index = index + for x in range(self.total): + try: + path, target = self.samples[curr_index] + sample = self.loader(path) + break + except Exception as e: + curr_index = np.random.randint(0, self.total) + + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self) -> int: + return len(self.samples) + + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + + +def pil_loader(path: str) -> Image.Image: + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +# TODO: specify the return type +def accimage_loader(path: str) -> Any: + import accimage + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_loader(path: str) -> Any: + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) + + +class ImageFolder(DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + root/dog/xxx.png + root/dog/xxy.png + root/dog/[...]/xxz.png + root/cat/123.png + root/cat/nsdf3.png + root/cat/[...]/asd932_.png + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + is_valid_file (callable, optional): A function that takes path of an Image file + and check if the file is a valid file (used to check of corrupt files) + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of (image path, class_index) tuples + """ + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + classes_fraction=1.0, + data_per_class_fraction=1.0, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + ): + super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + classes_fraction=classes_fraction, + data_per_class_fraction=data_per_class_fraction, + is_valid_file=is_valid_file) + self.imgs = self.samples + diff --git a/training/DeepSpeed-Domino/megatron/data/indexed_dataset.py b/training/DeepSpeed-Domino/megatron/data/indexed_dataset.py new file mode 100644 index 000000000..05ef5c4b2 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/indexed_dataset.py @@ -0,0 +1,623 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch +from megatron import print_rank_0 + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False, multimodal=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup, multimodal) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float64, + 7: np.float32, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + doc_offset = len(self.sizes) + + begin = self.data_offsets[-1] + for data_offset in index.data_offsets[1:]: + self.data_offsets.append(begin + data_offset) + self.sizes.extend(index.sizes) + + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. + enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) + pad_mask = np.array(pad_mask, dtype=np.int64) + + return enc_ids, tokentypes_enc, pad_mask + + +def build_sample(row_id, context_ids, context_types, context_pad_mask): + """Convert to numpy and return a sample consumed by the batch producer.""" + + context_ids = np.array(context_ids, dtype=np.int64) + context_types = np.array(context_types, dtype=np.int64) + context_mask = make_attention_mask(context_ids, context_ids) + + sample = ({ + 'row_id': row_id, + 'context': context_ids, + 'context_mask': context_mask, + 'context_types': context_types, + 'context_pad_mask': context_pad_mask + }) + return sample + + +class OpenRetrievalEvidenceDataset(ABC, Dataset): + """Open Retrieval Evidence dataset class.""" + + def __init__(self, task_name, dataset_name, datapath, tokenizer, + max_seq_length): + # Store inputs. + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. + print_rank_0(datapath) + self.samples, self.id2text = self.process_samples_from_single_path( + datapath) + + args = get_args() + if args.sample_rate < 1: # subsample + k = int(len(self.samples) * args.sample_rate) + self.samples = random.sample(self.samples, k) + + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + row = self.samples[idx] + + context_ids, context_types, context_pad_mask = \ + build_tokens_types_paddings_from_text(row, self.tokenizer, + self.max_seq_length) + + sample = build_sample(row['doc_id'], + context_ids, + context_types, + context_pad_mask) + return sample + + @staticmethod + def process_samples_from_single_path(filename): + print_rank_0(' > Processing {} ...'.format(filename)) + total = 0 + + rows = [] + id2text = {} + + with open(filename) as tsvfile: + reader = csv.reader(tsvfile, delimiter='\t') + next(reader, None) # skip the headers + for row in reader: + # file format: doc_id, doc_text, title + doc_id = int(row[0]) + text = row[1] + title = row[2] + + rows.append({'doc_id': doc_id, + 'text': text, + 'title': title}) + + assert doc_id not in id2text + id2text[doc_id] = (text, title) + + total += 1 + if total % 100000 == 0: + print_rank_0(' > processed {} rows so far ...'.format( + total)) + + print_rank_0(' >> processed {} samples.'.format(len(rows))) + return rows, id2text diff --git a/training/DeepSpeed-Domino/megatron/data/realm_dataset_utils.py b/training/DeepSpeed-Domino/megatron/data/realm_dataset_utils.py new file mode 100644 index 000000000..21445573e --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/realm_dataset_utils.py @@ -0,0 +1,199 @@ +import os +import time + +import numpy as np +import torch + +from megatron import print_rank_0 +from megatron.core import mpu, tensor_parallel +from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy +from megatron import get_args, get_tokenizer, print_rank_0 + + +def get_one_epoch_dataloader(dataset, micro_batch_size=None): + """Specifically one epoch to be used in an indexing job.""" + args = get_args() + + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + global_batch_size = micro_batch_size * world_size + num_workers = args.num_workers + + sampler = torch.utils.data.SequentialSampler(dataset) + # importantly, drop_last must be False to get all the data. + assert False, 'DistributedBatchSampler deprecated, change the implementation' + from megatron.data.samplers import DistributedBatchSampler + batch_sampler = DistributedBatchSampler(sampler, + batch_size=global_batch_size, + drop_last=False, + rank=rank, + world_size=world_size) + + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + + +def get_ict_batch(data_iterator): + # Items and their type. + keys = ['query_tokens', 'query_pad_mask', + 'block_tokens', 'block_pad_mask', 'block_data'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is None: + data = None + else: + data = next(data_iterator) + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + query_tokens = data_b['query_tokens'].long() + query_pad_mask = data_b['query_pad_mask'].long() + block_tokens = data_b['block_tokens'].long() + block_pad_mask = data_b['block_pad_mask'].long() + block_indices = data_b['block_data'].long() + + return query_tokens, query_pad_mask,\ + block_tokens, block_pad_mask, block_indices + + +def join_str_list(str_list): + """Join a list of strings, handling spaces appropriately""" + result = "" + for s in str_list: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + return result + + +class BlockSampleData(object): + """A struct for fully describing a fixed-size block of data as used in REALM + + :param start_idx: for first sentence of the block + :param end_idx: for last sentence of the block (may be partially truncated in sample construction) + :param doc_idx: the index of the document from which the block comes in the original indexed dataset + :param block_idx: a unique integer identifier given to every block. + """ + def __init__(self, start_idx, end_idx, doc_idx, block_idx): + self.start_idx = start_idx + self.end_idx = end_idx + self.doc_idx = doc_idx + self.block_idx = block_idx + + def as_array(self): + return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) + + def as_tuple(self): + return self.start_idx, self.end_idx, self.doc_idx, self.block_idx + + +class BlockSamplesMapping(object): + def __init__(self, mapping_array): + # make sure that the array is compatible with BlockSampleData + assert mapping_array.shape[1] == 4 + self.mapping_array = mapping_array + + def __len__(self): + return self.mapping_array.shape[0] + + def __getitem__(self, idx): + """Get the data associated with an indexed sample.""" + sample_data = BlockSampleData(*self.mapping_array[idx]) + return sample_data + + +def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): + """Get samples mapping for a dataset over fixed size blocks. This function also requires + a dataset of the titles for the source documents since their lengths must be taken into account. + + :return: samples_mapping (BlockSamplesMapping) + """ + + if not num_epochs: + if not max_num_samples: + raise ValueError("Need to specify either max_num_samples " + "or num_epochs") + num_epochs = np.iinfo(np.int32).max - 1 + if not max_num_samples: + max_num_samples = np.iinfo(np.int64).max - 1 + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + if num_epochs != (np.iinfo(np.int32).max - 1): + indexmap_filename += '_{}ep'.format(num_epochs) + if max_num_samples != (np.iinfo(np.int64).max - 1): + indexmap_filename += '_{}mns'.format(max_num_samples) + indexmap_filename += '_{}msl'.format(max_seq_length) + indexmap_filename += '_{}s'.format(seed) + if use_one_sent_docs: + indexmap_filename += '_1sentok' + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if mpu.get_data_parallel_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. + assert block_dataset.doc_idx.dtype == np.int64 + assert block_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building samples index mapping for {} ...'.format( + name)) + + from megatron.data import helpers + mapping_array = helpers.build_blocks_mapping( + block_dataset.doc_idx, + block_dataset.sizes, + title_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length - 3, # account for added tokens + seed, + verbose, + use_one_sent_docs) + + + print_rank_0(' > done building samples index mapping') + np.save(indexmap_filename, mapping_array, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elapsed time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + + mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') + samples_mapping = BlockSamplesMapping(mapping_array) + + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + mapping_array.shape[0])) + + return samples_mapping diff --git a/training/DeepSpeed-Domino/megatron/data/realm_index.py b/training/DeepSpeed-Domino/megatron/data/realm_index.py new file mode 100644 index 000000000..1fa4a309e --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/realm_index.py @@ -0,0 +1,224 @@ +import itertools +import os +import pickle +import shutil + +import numpy as np +import torch + +from megatron import get_args +from megatron.core import mpu + + +def detach(tensor): + return tensor.detach().cpu().numpy() + + +class OpenRetreivalDataStore(object): + """ + Serializable data structure for holding data for blocks -- + embeddings and necessary metadata for Retriever + """ + def __init__(self, embedding_path=None, load_from_path=True, rank=None): + self.embed_data = dict() + if embedding_path is None: + args = get_args() + embedding_path = args.embedding_path + rank = args.rank + self.embedding_path = embedding_path + self.rank = rank + + if load_from_path: + self.load_from_file() + + block_data_name = os.path.splitext(self.embedding_path)[0] + self.temp_dir_name = block_data_name + '_tmp' + + def state(self): + return { + 'embed_data': self.embed_data, + } + + def clear(self): + """ + Clear the embedding data structures to save memory. + The metadata ends up getting used, and is also much smaller in + dimensionality so it isn't really worth clearing. + """ + self.embed_data = dict() + + def load_from_file(self): + """Populate members from instance saved to file""" + + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print("\n> Unpickling BlockData", flush=True) + state_dict = pickle.load(open(self.embedding_path, 'rb')) + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print(">> Finished unpickling BlockData\n", flush=True) + + self.embed_data = state_dict['embed_data'] + + def add_block_data(self, row_id, block_embeds, allow_overwrite=False): + """ + Add data for set of blocks + :param row_id: 1D array of unique int ids for the blocks + :param block_embeds: 2D array of embeddings of the blocks + In the case of retriever this will be [start_idx, end_idx, doc_idx] + """ + for idx, embed in zip(row_id, block_embeds): + if not allow_overwrite and idx in self.embed_data: + raise ValueError("Unexpectedly tried to overwrite block data") + + self.embed_data[idx] = np.float16(embed) + + def save_shard(self): + """ + Save the block data that was created this in this process + """ + if not os.path.isdir(self.temp_dir_name): + os.makedirs(self.temp_dir_name, exist_ok=True) + + # save the data for each shard + with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \ + as writer: + pickle.dump(self.state(), writer) + + def merge_shards_and_save(self): + #Combine all the shards made using save_shard + shard_names = os.listdir(self.temp_dir_name) + seen_own_shard = False + + for fname in os.listdir(self.temp_dir_name): + shard_rank = int(os.path.splitext(fname)[0]) + if shard_rank == self.rank: + seen_own_shard = True + continue + + with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f: + data = pickle.load(f) + old_size = len(self.embed_data) + shard_size = len(data['embed_data']) + + # add the shard's data and check to make sure there + # is no overlap + self.embed_data.update(data['embed_data']) + assert len(self.embed_data) == old_size + shard_size + + assert seen_own_shard + + # save the consolidated shards and remove temporary directory + with open(self.embedding_path, 'wb') as final_file: + pickle.dump(self.state(), final_file) + shutil.rmtree(self.temp_dir_name, ignore_errors=True) + + print("Finished merging {} shards for a total of {} embeds".format( + len(shard_names), len(self.embed_data)), flush=True) + + +class FaissMIPSIndex(object): + """ + Wrapper object for a BlockData which similarity search via FAISS under the hood + """ + def __init__(self, embed_size, embed_data=None, use_gpu=False): + self.embed_size = embed_size + self.embed_data = embed_data + self.use_gpu = use_gpu + + self.mips_index = None + self._set_mips_index() + + def _set_mips_index(self): + """ + Create a Faiss Flat index with inner product as the metric + to search against + """ + try: + import faiss + except ImportError: + raise Exception("Error: Please install faiss to use FaissMIPSIndex") + + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print("\n> Building index", flush=True) + + cpu_index = faiss.IndexFlatIP(self.embed_size) + + if self.use_gpu: + # create resources and config for GpuIndex + config = faiss.GpuMultipleClonerOptions() + config.shard = True + config.useFloat16 = True + gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config) + self.mips_index = faiss.IndexIDMap(gpu_index) + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print(">> Initialized index on GPU", flush=True) + else: + # CPU index supports IDs so wrap with IDMap + self.mips_index = faiss.IndexIDMap(cpu_index) + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print(">> Initialized index on CPU", flush=True) + + # if we were constructed with a BlockData, then automatically load it + # when the FAISS structure is built + if self.embed_data is not None: + self.add_embed_data(self.embed_data) + + def reset_index(self): + """Delete existing index and create a new""" + del self.mips_index + + # reset the block data so that _set_block_index will reload it as well + if self.embed_data is not None: + embed_data_path = self.embed_data.embedding_path + del self.embed_data + self.embed_data = OpenRetreivalDataStore(embed_data_path) + + self._set_mips_index() + + def update_index(self): + """Delete existing index and create a new""" + del self.mips_index + + # reset the block data so that _set_mips_index will reload it as well + if self.embed_data is not None: + self.embed_data.load_from_file() + self._set_mips_index() + + def add_embed_data(self, all_embed_data): + """Add the embedding of each block to the underlying FAISS index""" + + # this assumes the embed_data is a dict : {int: np.array} + block_indices, block_embeds = zip(*all_embed_data.embed_data.items()) + + # the embeddings have to be entered in as float32 even though the math + # internally is done with float16. + embeds_arr = np.float32(np.array(block_embeds)) + indices_arr = np.array(block_indices) + + # we no longer need the embedding data since it's in the index now + all_embed_data.clear() + + self.mips_index.add_with_ids(embeds_arr, indices_arr) + + if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: + print(">>> Finished adding block data to index", flush=True) + + def search_mips_index(self, query_embeds, top_k, reconstruct=True): + """ + Get the top-k blocks by the index distance metric. + + :param reconstruct: if True: return a [num_queries x k x embed_dim] + array of blocks + if False: return [num_queries x k] array of + distances, and another for indices + """ + query_embeds = np.float32(detach(query_embeds)) + + if reconstruct: + # get the vectors themselves + top_k_block_embeds = self.mips_index.search_and_reconstruct(\ + query_embeds, top_k) + return top_k_block_embeds + else: + # get distances and indices of closest vectors + distances, block_indices = self.mips_index.search(query_embeds, top_k) + return distances, block_indices diff --git a/training/DeepSpeed-Domino/megatron/data/t5_dataset.py b/training/DeepSpeed-Domino/megatron/data/t5_dataset.py new file mode 100644 index 000000000..e60681490 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/t5_dataset.py @@ -0,0 +1,257 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""T5 Style dataset.""" + +import collections + +import numpy as np +import torch + +from megatron import get_tokenizer +from megatron.data.dataset_utils import ( + create_masked_lm_predictions, + get_samples_mapping +) + +class T5Dataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + num_epochs, max_num_samples, masked_lm_prob, + max_seq_length, max_seq_length_dec, + short_seq_prob, seed): + + # Params to store. + self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + self.max_seq_length_dec = max_seq_length_dec + + # Dataset. + self.indexed_dataset = indexed_dataset + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping(self.indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + self.max_seq_length - 2, # account for added tokens + short_seq_prob, + self.seed, + self.name, + False) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + self.bos_id = tokenizer.bos_token_id + self.eos_id = tokenizer.eos_token_id + self.sentinel_tokens = tokenizer.additional_special_tokens_ids + assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" + + def __len__(self): + return self.samples_mapping.shape[0] + + def __getitem__(self, idx): + + start_index, end_index, seq_length = self.samples_mapping[idx] + sample = [] + for index in range(start_index, end_index): + sample.append(self.indexed_dataset[index]) + # Note that this rng state should be numpy and not python since + # python randint is inclusive whereas the numpy one is exclusive. + np_rng = np.random.RandomState(seed=(self.seed + idx)) + return build_training_sample(sample, seq_length, + self.max_seq_length, # needed for padding + self.max_seq_length_dec, + self.vocab_id_list, + self.vocab_id_to_token_dict, + self.cls_id, self.sep_id, + self.mask_id, self.pad_id, + self.masked_lm_prob, np_rng, + self.bos_id, self.eos_id, + self.sentinel_tokens) + + +def build_training_sample(sample, target_seq_length, + max_seq_length, max_seq_length_dec, + vocab_id_list, vocab_id_to_token_dict, + cls_id, sep_id, mask_id, pad_id, + masked_lm_prob, np_rng, bos_id=None, + eos_id=None, sentinel_tokens=None): + """Build training sample. + + Arguments: + sample: A list of sentences in which each sentence is a list token ids. + target_seq_length: Desired sequence length. + max_seq_length: Maximum length of the sequence. All values are padded to + this length. + vocab_id_list: List of vocabulary ids. Used to pick a random id. + vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. + cls_id: Start of example id. + sep_id: Separator id. + mask_id: Mask token id. + pad_id: Padding token id. + masked_lm_prob: Probability to mask tokens. + np_rng: Random number genenrator. Note that this rng state should be + numpy and not python since python randint is inclusive for + the opper bound whereas the numpy one is exclusive. + bos_id: start of decoder example id + eos_id: end of generation id + sentinel_tokens: unique value to be substituted for every replaced span + """ + + assert target_seq_length <= max_seq_length + + # flatten sentences into one list + tokens = [token for sentence in sample for token in sentence] + + # Truncate to `target_sequence_length`. + max_num_tokens = target_seq_length + truncated = len(tokens) > max_num_tokens + tokens = tokens[:max_num_tokens] + + # Masking. + max_predictions_per_seq = masked_lm_prob * max_num_tokens + (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions( + tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, + cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, + max_ngrams=10, geometric_dist=True, masking_style="t5") + + # Padding. + tokens_enc, tokens_dec_in, labels, enc_mask, \ + dec_mask, enc_dec_mask, loss_mask \ + = pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, max_seq_length, + max_seq_length_dec, masked_spans, + bos_id, eos_id, sentinel_tokens) + + train_sample = { + 'text_enc': tokens_enc, + 'text_dec': tokens_dec_in, + 'labels': labels, + 'loss_mask': loss_mask, + 'truncated': int(truncated), + 'enc_mask': enc_mask, + 'dec_mask': dec_mask, + 'enc_dec_mask': enc_dec_mask, + } + return train_sample + + +def pad_and_convert_to_numpy(tokens, masked_positions, + masked_labels, pad_id, + max_seq_length, max_seq_length_dec, + masked_spans=None, bos_id=None, + eos_id=None, sentinel_tokens=None): + """Pad sequences and convert them to numpy.""" + + sentinel_tokens = collections.deque(sentinel_tokens) + t5_input = [] + (t5_decoder_in, t5_decoder_out) = ([bos_id], []) + (start_index, end_index) = (0, None) + for span in masked_spans: + flag = sentinel_tokens.popleft() + + # Append the same tokens in decoder input and output + t5_decoder_in.append(flag) + t5_decoder_in.extend(span.label) + t5_decoder_out.append(flag) + t5_decoder_out.extend(span.label) + + end_index = span.index[0] + t5_input.extend(tokens[start_index: end_index]) + t5_input.append(flag) + + # the next start index is the token after the last span token + start_index = span.index[-1] + 1 + + # Add token to the t5_decoder_out + t5_decoder_out.append(eos_id) + + # Add the remaining tokens to the t5 input + t5_input.extend(tokens[start_index:]) + + # assert (len(t5_input) - len(masked_spans)) + \ + # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens) + + # Some checks. + + # Encoder-side padding mask. + num_tokens = len(t5_input) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + assert len(masked_positions) == len(masked_labels) + + # Tokens.. + filler = [pad_id] * padding_length + tokens_enc = np.array(t5_input + filler, dtype=np.int64) + + # Decoder-side padding mask. + num_tokens_dec = len(t5_decoder_in) + padding_length_dec = max_seq_length_dec - num_tokens_dec + assert padding_length_dec >= 0 + filler_dec = [pad_id] * padding_length_dec + tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64) + + # Create attention masks + enc_mask = make_attention_mask(tokens_enc, tokens_enc) + enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc) + dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in) + dec_mask = dec_mask * make_history_mask(tokens_dec_in) + + # Labels mask. + labels = t5_decoder_out + ([-1] * padding_length_dec) + labels = np.array(labels, dtype=np.int64) + + # Loss mask + loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec) + loss_mask = np.array(loss_mask, dtype=np.int64) + + return tokens_enc, tokens_dec_in, labels, enc_mask, \ + dec_mask, enc_dec_mask, loss_mask + + +def make_attention_mask(source_block, target_block): + """ + Returns a 2-dimensional (2-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) + mask = mask.astype(np.int64) + # (source_length, target_length) + return mask + + +def make_attention_mask_3d(source_block, target_block): + """ + Returns a 3-dimensional (3-D) attention mask + :param source_block: 1-D array + :param target_block: 1-D array + """ + mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1) + # (batch, source_length, target_length) + # mask = mask.astype(np.int64) + return mask + + +def make_history_mask(block): + length = block.shape[0] + arange = np.arange(length) + history_mask = (arange[None, ] <= arange[:, None]) + history_mask = history_mask.astype(np.int64) + return history_mask + + +def make_history_mask_3d(block): + batch, length = block.shape + arange = torch.arange(length, device=block.device) + history_mask = (arange[None, ] <= arange[:, None])[None, ] + history_mask = history_mask.expand(batch, length, length) + return history_mask diff --git a/training/DeepSpeed-Domino/megatron/data/test/test_indexed_dataset.py b/training/DeepSpeed-Domino/megatron/data/test/test_indexed_dataset.py new file mode 100644 index 000000000..12fec8d81 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/test/test_indexed_dataset.py @@ -0,0 +1,125 @@ +# This file isn't really a formal automated test, it's just a place to +# put some code used during development and manual testing of +# indexed_dataset. + +from megatron.data import indexed_dataset +from megatron.tokenizer import build_tokenizer +import argparse +import os +import sys + +import torch + +script_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(script_dir, "../../../")) + + +def test_indexed_dataset(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + print(len(ds.doc_idx)) + print(len(ds)) + print(ds.doc_idx[-1]) + if ds.supports_prefetch: + # just prefetch the whole thing in test (so assume it is small) + ds.prefetch(range(len(ds))) + if args.count > len(ds.doc_idx) - 1: + args.count = len(ds.doc_idx) - 1 + + for i in range(args.count): + start = ds.doc_idx[i] + end = ds.doc_idx[i + 1] + ids = ds[start:end] + print(f"Document {i}:") + print("--------------") + for s in ids: + assert len(s) > 0 + l = s.data.tolist() + text = tokenizer.detokenize(l) + print(text) + print("---") + + +def test_indexed_dataset_get(args): + ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) + tokenizer = build_tokenizer(args) + size = ds.sizes[0] + print(f"size: {size}") + full = ds.get(0) + print(full) + # print(tokenizer.detokenize(full.data.tolist())) + print("---") + end = ds.get(0, offset=size - 10) + print(end) + # print(tokenizer.detokenize(end.data.tolist())) + + start = ds.get(0, length=10) + print(start) + # print(tokenizer.detokenize(start.data.tolist())) + + part = ds.get(0, offset=2, length=8) + print(part) + # print(tokenizer.detokenize(part.data.tolist())) + +# def test_albert_dataset(args): +# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) +# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) +# # ds = AlbertDataset(idataset, tokenizer) +# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, +# args.epochs, args.max_num_samples, +# args.masked_lm_prob, args.seq_length, +# args.short_seq_prob, args.seed) +# truncated = 0 +# total = 0 +# for i, s in enumerate(ds): +# ids = s['text'] +# tokens = ds.tokenizer.convert_ids_to_tokens(ids) +# print(tokens) +# if i >= args.count-1: +# exit() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, help='prefix to data files') + parser.add_argument('--dataset-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer']) + parser.add_argument('--count', type=int, default=10, + help='Number of samples/documents to print') + + group = parser.add_argument_group(title='tokenizer') + group.add_argument('--tokenizer-type', type=str, required=True, + choices=['BertWordPieceLowerCase', + 'GPT2BPETokenizer'], + help='What type of tokenizer to use.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file') + group.add_argument('--merge-file', type=str, default=None, + help='Path to the BPE merge file (if necessary).') + + parser.add_argument('--epochs', type=int, default=5, + help='Number of epochs to plan for') + parser.add_argument('--max-num-samples', type=int, default=None, + help='Maximum number of samples to plan for') + parser.add_argument('--masked-lm-prob', type=float, default=0.15, + help='probability of masking tokens') + parser.add_argument('--seq-length', type=int, default=512, + help='maximum sequence length') + parser.add_argument('--short-seq-prob', type=float, default=0.1, + help='probability of creating a short sequence') + parser.add_argument('--seed', type=int, default=1234, + help='random seed') + args = parser.parse_args() + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + + if args.dataset_impl == "infer": + args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) + +# test_albert_dataset(args) + test_indexed_dataset_get(args) + + +if __name__ == "__main__": + main() diff --git a/training/DeepSpeed-Domino/megatron/data/test/test_preprocess_data.sh b/training/DeepSpeed-Domino/megatron/data/test/test_preprocess_data.sh new file mode 100755 index 000000000..d121c8595 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/test/test_preprocess_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +IMPL=cached +python ../preprocess_data.py \ + --input test_samples.json \ + --vocab vocab.txt \ + --dataset-impl ${IMPL} \ + --output-prefix test_samples_${IMPL} \ + --workers 1 \ + --log-interval 2 diff --git a/training/DeepSpeed-Domino/megatron/data/vit_dataset.py b/training/DeepSpeed-Domino/megatron/data/vit_dataset.py new file mode 100644 index 000000000..82391e915 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/data/vit_dataset.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import os +import random +import numpy as np +import torch +import torchvision.transforms as T +from torchvision import datasets +from megatron import get_args +from megatron.data.image_folder import ImageFolder +from megatron.data.autoaugment import ImageNetPolicy +from megatron.data.data_samplers import RandomSeedDataset +from PIL import Image, ImageFilter, ImageOps + + +class GaussianBlur(object): + """ + Apply Gaussian Blur to the PIL image. + """ + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = random.random() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + + +class Solarization(object): + """ + Apply Solarization to the PIL image. + """ + def __init__(self, p): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + return ImageOps.solarize(img) + else: + return img + + +class ClassificationTransform(): + def __init__(self, image_size, train=True): + args = get_args() + assert args.fp16 or args.bf16 + self.data_type = torch.half if args.fp16 else torch.bfloat16 + if train: + self.transform = T.Compose([ + T.RandomResizedCrop(image_size), + T.RandomHorizontalFlip(), + T.ColorJitter(0.4, 0.4, 0.4, 0.1), + ImageNetPolicy(), + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + else: + self.transform = T.Compose([ + T.Resize(image_size), + T.CenterCrop(image_size), + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + + def __call__(self, input): + output = self.transform(input) + return output + + +class InpaintingTransform(): + def __init__(self, image_size, train=True): + + args = get_args() + self.mask_factor = args.mask_factor + self.mask_type = args.mask_type + self.image_size = image_size + self.patch_size = args.patch_dim + self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size)) + self.train = train + assert args.fp16 or args.bf16 + self.data_type = torch.half if args.fp16 else torch.bfloat16 + + if self.train: + self.transform = T.Compose([ + T.RandomResizedCrop(self.image_size), + T.RandomHorizontalFlip(), + T.ColorJitter(0.4, 0.4, 0.4, 0.1), + ImageNetPolicy(), + T.ToTensor(), + T.ConvertImageDtype(self.data_type) + ]) + else: + self.transform = T.Compose([ + T.Resize(self.image_size, interpolation=2), + T.CenterCrop(self.image_size), + T.ToTensor(), + T.ConvertImageDtype(self.data_type) + ]) + + def gen_mask(self, image_size, mask_size, mask_type, patch_size): + # output: mask as a list with indices for missing patches + action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]] + assert image_size[0] == image_size[1] + img_size_patch = image_size[0] // patch_size + + # drop masked patches + mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float) + + if mask_type == 'random': + x = torch.randint(0, img_size_patch, ()) + y = torch.randint(0, img_size_patch, ()) + for i in range(mask_size): + r = torch.randint(0, len(action_list), ()) + x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1) + y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1) + x_offset = x * patch_size + y_offset = y * patch_size + mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 + else: + assert mask_type == 'row' + count = 0 + for x in reversed(range(img_size_patch)): + for y in reversed(range(img_size_patch)): + if (count < mask_size): + count += 1 + x_offset = x * patch_size + y_offset = y * patch_size + mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 + return mask + + def __call__(self, input): + trans_input = self.transform(input) + mask = self.gen_mask(self.image_size, self.mask_size, + self.mask_type, self.patch_size) + mask = mask.unsqueeze(dim=0) + return trans_input, mask + + +class DinoTransform(object): + def __init__(self, image_size, train=True): + args = get_args() + self.data_type = torch.half if args.fp16 else torch.bfloat16 + + flip_and_color_jitter = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.RandomApply( + [T.ColorJitter(brightness=0.4, contrast=0.4, + saturation=0.2, hue=0.1)], + p=0.8 + ), + T.RandomGrayscale(p=0.2), + ]) + + if args.fp16 or args.bf16: + normalize = T.Compose([ + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + T.ConvertImageDtype(self.data_type) + ]) + else: + normalize = T.Compose([ + T.ToTensor(), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + # first global crop + scale_const = 0.4 + self.global_transform1 = T.Compose([ + T.RandomResizedCrop(image_size, + scale=(scale_const, 1), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(1.0), + normalize + ]) + # second global crop + self.global_transform2 = T.Compose([ + T.RandomResizedCrop(image_size, + scale=(scale_const, 1), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(0.1), + Solarization(0.2), + normalize + ]) + # transformation for the local small crops + self.local_crops_number = args.dino_local_crops_number + self.local_transform = T.Compose([ + T.RandomResizedCrop(args.dino_local_img_size, + scale=(0.05, scale_const), + interpolation=Image.BICUBIC), + flip_and_color_jitter, + GaussianBlur(p=0.5), + normalize + ]) + + def __call__(self, image): + crops = [] + crops.append(self.global_transform1(image)) + crops.append(self.global_transform2(image)) + for _ in range(self.local_crops_number): + crops.append(self.local_transform(image)) + return crops + + +def build_train_valid_datasets(data_path, image_size=224): + args = get_args() + + if args.vision_pretraining_type == 'classify': + train_transform = ClassificationTransform(image_size) + val_transform = ClassificationTransform(image_size, train=False) + elif args.vision_pretraining_type == 'inpaint': + train_transform = InpaintingTransform(image_size, train=False) + val_transform = InpaintingTransform(image_size, train=False) + elif args.vision_pretraining_type == 'dino': + train_transform = DinoTransform(image_size, train=True) + val_transform = ClassificationTransform(image_size, train=False) + else: + raise Exception('{} vit pretraining type is not supported.'.format( + args.vit_pretraining_type)) + + # training dataset + train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] + train_data = ImageFolder( + root=train_data_path, + transform=train_transform, + classes_fraction=args.classes_fraction, + data_per_class_fraction=args.data_per_class_fraction + ) + train_data = RandomSeedDataset(train_data) + + # validation dataset + val_data_path = data_path[1] + val_data = ImageFolder( + root=val_data_path, + transform=val_transform + ) + val_data = RandomSeedDataset(val_data) + + return train_data, val_data diff --git a/training/DeepSpeed-Domino/megatron/dist_signal_handler.py b/training/DeepSpeed-Domino/megatron/dist_signal_handler.py new file mode 100644 index 000000000..a60204f00 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/dist_signal_handler.py @@ -0,0 +1,81 @@ +import signal + +import torch + + +def get_world_size(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +def get_device(local_rank=None): + backend = torch.distributed.get_backend() + if backend == 'nccl': + if local_rank is None: + device = torch.device('cuda') + else: + device = torch.device(f'cuda:{local_rank}') + elif backend == 'gloo': + device = torch.device('cpu') + else: + raise RuntimeError + return device + + +def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): + if not torch.distributed.is_available() or \ + not torch.distributed.is_initialized(): + return [item] + + device = get_device(local_rank) + + if group is not None: + group_size = group.size() + else: + group_size = get_world_size() + + tensor = torch.tensor([item], device=device, dtype=dtype) + output_tensors = [ + torch.zeros(1, dtype=tensor.dtype, device=tensor.device) + for _ in range(group_size) + ] + torch.distributed.all_gather(output_tensors, tensor, group, async_op) + output = [elem.item() for elem in output_tensors] + return output + + +class DistributedSignalHandler: + def __init__(self, sig=signal.SIGTERM): + self.sig = sig + + def signals_received(self): + all_received = all_gather_item( + self._signal_received, dtype=torch.int32 + ) + return all_received + + def __enter__(self): + self._signal_received = False + self.released = False + self.original_handler = signal.getsignal(self.sig) + + def handler(signum, frame): + self._signal_received = True + + signal.signal(self.sig, handler) + + return self + + def __exit__(self, type, value, tb): + self.release() + + def release(self): + if self.released: + return False + + signal.signal(self.sig, self.original_handler) + self.released = True + return True diff --git a/training/DeepSpeed-Domino/megatron/fused_kernels/__init__.py b/training/DeepSpeed-Domino/megatron/fused_kernels/__init__.py new file mode 100644 index 000000000..87cceac3e --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/fused_kernels/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(args): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME + ) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 8: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / "build" + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=[ + "-O3", + ], + extra_cuda_cflags=[ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "--use_fast_math", + ] + + extra_cuda_flags + + cc_flag, + verbose=(args.rank == 0), + ) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/training/DeepSpeed-Domino/megatron/fused_kernels/compat.h b/training/DeepSpeed-Domino/megatron/fused_kernels/compat.h new file mode 100644 index 000000000..5495d7807 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/fused_kernels/compat.h @@ -0,0 +1,17 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/training/DeepSpeed-Domino/megatron/fused_kernels/tests/__init__.py b/training/DeepSpeed-Domino/megatron/fused_kernels/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/DeepSpeed-Domino/megatron/fused_kernels/tests/test_fused_kernels.py b/training/DeepSpeed-Domino/megatron/fused_kernels/tests/test_fused_kernels.py new file mode 100644 index 000000000..74024c502 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/fused_kernels/tests/test_fused_kernels.py @@ -0,0 +1,388 @@ +import math + +import torch +from torch.nn import LayerNorm + +from megatron.model.enums import AttnMaskType +from megatron.model.fused_layer_norm import MixedFusedLayerNorm +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.utils import attention_mask_func +from megatron.fused_kernels import load + +def test_load_fused_kernels(): + try: + import fused_layer_norm_cuda + import scaled_masked_softmax_cuda + import scaled_upper_triang_masked_softmax_cuda + import torch + + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + +def test_fused_softmax(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + embedding_output = bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + + # (bsz, 1, 1, seq_len) + mask = bert.get_extended_attention_mask( + attention_mask=tokens["attention_mask"].cuda(), + input_shape=tokens["input_ids"].shape, + device=bert.device, + ) + # (bsz, 1, seq_len, seq_len) + mask = mask.repeat(1, 1, mask.size()[-1], 1) + + attention = bert.encoder.layer[0].attention.self + key_layer = attention.transpose_for_scores(attention.key(embedding_output)) + query_layer = attention.transpose_for_scores(attention.query(embedding_output)) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores /= math.sqrt(key_layer.size()[-1]) + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attention_scores, + (mask != 0), + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attention_scores, + (mask != 0), + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_fused_upper_triangle_mask_softmax(): + gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi" # 24 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + attention_mask = tokens["attention_mask"].cuda() + attention_mask = attention_mask.view(attention_mask.size(0), -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) + attn = gpt.h[0] + + hidden_states = gpt.wte(tokens["input_ids"].cuda()) + q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) + q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) + k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + sq, sk = q.size(-2), k.size(-2) + causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() + total_mask = ~(causal_mask & (attention_mask == 0)) + """ + tensor([[[[False, True, True, ..., True, True, True], + [False, False, True, ..., True, True, True], + [False, False, False, ..., True, True, True], + ..., + [False, False, False, ..., False, True, True], + [False, False, False, ..., False, False, True], + [False, False, False, ..., False, False, False]]] + """ + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attn_weights, + total_mask, + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attn_weights, + total_mask, + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_layer_norm(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + # [bsz, seq_len, d_model] + embedding_output = ( + bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + .cuda() + .half() + ) + + fused_layernorm_layer = ( + MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + torch_layernorm_layer = ( + LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + fused_output = fused_layernorm_layer(embedding_output) + torch_output = torch_layernorm_layer(embedding_output) + test_result = (fused_output - torch_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_layer_norm" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_layer_norm" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def forward_torch_softmax(input, mask, scale): + input = input * scale + mask_output = attention_mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + return probs + + +def test_masked_softmax_forward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + error = (softmax_results_torch - softmax_results).abs().max() + assert error < 1e-3 + +def test_masked_softmax_backward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') + masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) + + inputs.requires_grad = True + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + softmax_results_torch.backward(backward) + error = (back_grad - inputs.grad).abs().max() + assert error < 1e-3 + + +def test_allmasked_softmax_forward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + softmax_results_torch = torch.zeros_like(inputs) + error = (softmax_results_torch - softmax_results).abs().max() + assert error == 0.0 + + +def test_allmasked_softmax_backward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') + masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) + inputs.requires_grad = True + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + softmax_results_torch.backward(backward) + error = (back_grad - inputs.grad).abs().max() + assert error < 1e-3 + + +if __name__ == "__main__": + try: + from transformers import BertTokenizer, GPT2Tokenizer + from transformers.models.bert.modeling_bert import BertModel + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + import transformers + + transformers.logging.set_verbosity( + transformers.logging.FATAL, + ) + + except: + print("\n[Fail] Please install `transformers` package to test fused kernels\n") + exit(-1) + + load() + test_masked_softmax_forward() + test_masked_softmax_backward() + test_allmasked_softmax_forward() + test_allmasked_softmax_backward() + test_load_fused_kernels() + test_fused_softmax() + test_fused_upper_triangle_mask_softmax() + test_layer_norm() diff --git a/training/DeepSpeed-Domino/megatron/fused_kernels/type_shim.h b/training/DeepSpeed-Domino/megatron/fused_kernels/type_shim.h new file mode 100644 index 000000000..d60a6f8c6 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/fused_kernels/type_shim.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/training/DeepSpeed-Domino/megatron/global_vars.py b/training/DeepSpeed-Domino/megatron/global_vars.py new file mode 100644 index 000000000..4e0118e10 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/global_vars.py @@ -0,0 +1,191 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron global variables.""" + +import os +import sys +import torch + +from megatron import dist_signal_handler +from megatron.tokenizer import build_tokenizer +from .microbatches import build_num_microbatches_calculator +from .timers import Timers + +_GLOBAL_ARGS = None +_GLOBAL_RETRO_ARGS = None +_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None +_GLOBAL_TOKENIZER = None +_GLOBAL_TENSORBOARD_WRITER = None +_GLOBAL_ADLR_AUTORESUME = None +_GLOBAL_TIMERS = None +_GLOBAL_SIGNAL_HANDLER = None + +def get_args(): + """Return arguments.""" + _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') + return _GLOBAL_ARGS + + +def get_retro_args(): + """Return retro arguments.""" + return _GLOBAL_RETRO_ARGS + + +def get_num_microbatches(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() + + +def get_current_global_batch_size(): + return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() + + +def update_num_microbatches(consumed_samples, consistency_check=True): + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, + consistency_check) + + +def get_tokenizer(): + """Return tokenizer.""" + _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') + return _GLOBAL_TOKENIZER + + +def get_tensorboard_writer(): + """Return tensorboard writer. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_TENSORBOARD_WRITER + + +def get_adlr_autoresume(): + """ADLR autoresume object. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_ADLR_AUTORESUME + + +def get_timers(): + """Return timers.""" + _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + return _GLOBAL_TIMERS + + +def get_signal_handler(): + _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') + return _GLOBAL_SIGNAL_HANDLER + + +def _set_signal_handler(): + global _GLOBAL_SIGNAL_HANDLER + _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') + _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__() + + + +def set_global_variables(args, build_tokenizer=True): + """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" + + assert args is not None + + _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') + set_args(args) + + _build_num_microbatches_calculator(args) + if build_tokenizer: + _ = _build_tokenizer(args) + _set_tensorboard_writer(args) + _set_adlr_autoresume(args) + _set_timers(args) + + if args.exit_signal_handler: + _set_signal_handler() + + +def set_args(args): + global _GLOBAL_ARGS + _GLOBAL_ARGS = args + + +def set_retro_args(retro_args): + global _GLOBAL_RETRO_ARGS + _GLOBAL_RETRO_ARGS = retro_args + + +def _build_num_microbatches_calculator(args): + + global _GLOBAL_NUM_MICROBATCHES_CALCULATOR + _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, + 'num microbatches calculator') + + _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( + args) + + +def _build_tokenizer(args): + """Initialize tokenizer.""" + global _GLOBAL_TOKENIZER + _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer') + _GLOBAL_TOKENIZER = build_tokenizer(args) + return _GLOBAL_TOKENIZER + + +def rebuild_tokenizer(args): + global _GLOBAL_TOKENIZER + _GLOBAL_TOKENIZER = None + return _build_tokenizer(args) + + +def _set_tensorboard_writer(args): + """Set tensorboard writer.""" + global _GLOBAL_TENSORBOARD_WRITER + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, + 'tensorboard writer') + + if hasattr(args, 'tensorboard_dir') and \ + args.tensorboard_dir and args.rank == (args.world_size - 1): + try: + from torch.utils.tensorboard import SummaryWriter + print('> setting tensorboard ...') + _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( + log_dir=args.tensorboard_dir, + max_queue=args.tensorboard_queue_size) + except ModuleNotFoundError: + print('WARNING: TensorBoard writing requested but is not ' + 'available (are you using PyTorch 1.1.0 or later?), ' + 'no TensorBoard logs will be written.', flush=True) + + +def _set_adlr_autoresume(args): + """Initialize ADLR autoresume.""" + global _GLOBAL_ADLR_AUTORESUME + _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume') + + if args.adlr_autoresume: + if args.rank == 0: + print('enabling autoresume ...', flush=True) + sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) + try: + from userlib.auto_resume import AutoResume + except BaseException: + print('ADLR autoresume is not available, exiting ...') + sys.exit() + + _GLOBAL_ADLR_AUTORESUME = AutoResume + + +def _set_timers(args): + """Initialize timers.""" + global _GLOBAL_TIMERS + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option) + + +def _ensure_var_is_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is not None, '{} is not initialized.'.format(name) + + +def _ensure_var_is_not_initialized(var, name): + """Make sure the input variable is not None.""" + assert var is None, '{} is already initialized.'.format(name) + + + diff --git a/training/DeepSpeed-Domino/megatron/indexer.py b/training/DeepSpeed-Domino/megatron/indexer.py new file mode 100644 index 000000000..45f530a7d --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/indexer.py @@ -0,0 +1,129 @@ +import sys +import time +import torch +import torch.distributed as dist + +from megatron import get_args, print_rank_0 +from megatron.core import mpu +from megatron.checkpointing import load_biencoder_checkpoint +from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset +from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch +from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader +from megatron.data.realm_index import detach, OpenRetreivalDataStore +from megatron.model.biencoder_model import get_model_provider +from megatron.training import get_model + + +class IndexBuilder(object): + """ + Object for taking one pass over a dataset and creating a BlockData of its + embeddings + """ + def __init__(self): + args = get_args() + self.model = None + self.dataloader = None + self.evidence_embedder_obj = None + self.biencoder_shared_query_context_model = \ + args.biencoder_shared_query_context_model + + # need to know whether we're using a REALM checkpoint (args.load) + # or ICT checkpoint + assert not (args.load and args.ict_load) + + self.log_interval = args.indexer_log_interval + self.batch_size = args.indexer_batch_size + + self.load_attributes() + self.is_main_builder = mpu.get_data_parallel_rank() == 0 + self.num_total_builders = mpu.get_data_parallel_world_size() + self.iteration = self.total_processed = 0 + + def load_attributes(self): + """ + Load the necessary attributes: model, dataloader and empty BlockData + """ + only_context_model = True + if self.biencoder_shared_query_context_model: + only_context_model = False + + model = get_model(get_model_provider(only_context_model=\ + only_context_model, biencoder_shared_query_context_model=\ + self.biencoder_shared_query_context_model)) + + self.model = load_biencoder_checkpoint(model, + only_context_model=only_context_model) + + assert len(self.model) == 1 + self.model[0].eval() + + self.dataset = get_open_retrieval_wiki_dataset() + self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ + self.batch_size)) + + self.evidence_embedder_obj = OpenRetreivalDataStore( \ + load_from_path=False) + + def track_and_report_progress(self, batch_size): + """ + Utility function for tracking progress + """ + self.iteration += 1 + self.total_processed += batch_size * self.num_total_builders + if self.is_main_builder and self.iteration % self.log_interval == 0: + print('Batch {:10d} | Total {:10d}'.format(self.iteration, + self.total_processed), flush=True) + + def build_and_save_index(self): + """ + Goes through one epoch of the dataloader and adds all data to this + instance's BlockData. + + The copy of BlockData is saved as a shard, which when run in a + distributed setting will be consolidated by the rank 0 process + and saved as a final pickled BlockData. + """ + assert len(self.model) == 1 + unwrapped_model = self.model[0] + + while not hasattr(unwrapped_model, 'embed_text'): + unwrapped_model = unwrapped_model.module + + while True: + try: + # batch also has query_tokens and query_pad_data + row_id, context_tokens, context_mask, context_types, \ + context_pad_mask = get_open_retrieval_batch( \ + self.dataloader) + except (StopIteration, IndexError): + break + + # TODO: can we add with torch.no_grad() to reduce memory usage + # detach, separate fields and add to BlockData + assert context_mask.dtype == torch.bool + context_logits = unwrapped_model.embed_text( + unwrapped_model.context_model, context_tokens, context_mask, + context_types) + + context_logits = detach(context_logits) + row_id = detach(row_id) + + self.evidence_embedder_obj.add_block_data(row_id, context_logits) + self.track_and_report_progress(batch_size=len(row_id)) + + # This process signals to finalize its shard and then synchronize with + # the other processes + self.evidence_embedder_obj.save_shard() + torch.distributed.barrier() + del self.model + + # rank 0 process builds the final copy + if self.is_main_builder: + self.evidence_embedder_obj.merge_shards_and_save() + # make sure that every single piece of data was embedded + assert len(self.evidence_embedder_obj.embed_data) == \ + len(self.dataset) + self.evidence_embedder_obj.clear() + + # complete building the final copy + torch.distributed.barrier() diff --git a/training/DeepSpeed-Domino/megatron/initialize.py b/training/DeepSpeed-Domino/megatron/initialize.py new file mode 100644 index 000000000..367ba85cb --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/initialize.py @@ -0,0 +1,349 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron initialization.""" + +import random +import os +import time + +import numpy as np +import torch +from datetime import timedelta + +from megatron import fused_kernels +from megatron import get_adlr_autoresume +from megatron import get_args +from megatron import get_tensorboard_writer +from megatron.core import mpu, tensor_parallel +from megatron.arguments import parse_args, validate_args +from megatron.checkpointing import load_args_from_checkpoint +from megatron.global_vars import set_global_variables +from megatron.model.transformer import bias_dropout_add_fused_train +from megatron.model.fused_bias_gelu import bias_gelu + + +def initialize_megatron( + extra_args_provider=None, + args_defaults={}, + ignore_unknown_args=False, + allow_no_cuda=False, +): + """Set global variables, initialize distributed, and + set autoresume and random seeds. + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know + what you are doing. + Returns a function to finalize distributed env initialization + (optionally, only when args.lazy_mpu_init == True) + """ + if not allow_no_cuda: + # Make sure cuda is available. + assert torch.cuda.is_available(), "Megatron requires CUDA." + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + assert args.load is not None, "--use-checkpoints-args requires --load argument" + load_args_from_checkpoint(args) + + validate_args(args, args_defaults) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. + _initialize_distributed() + + # Random seeds for reproducibility. + if args.rank == 0: + print("> setting random seeds to {} ...".format(args.seed)) + _set_random_seed(args.seed, args.data_parallel_random_init) + + args = get_args() + if args.lazy_mpu_init: + # TODO is this still a necessary option? + args.use_cpu_initialization = True + # delayed initialization of DDP-related stuff + # We only set basic DDP globals + mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) + # and return function for external DDP manager + # to call when it has DDP initialized + mpu.set_tensor_model_parallel_rank(args.rank) + return finish_mpu_init + else: + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + # No continuation function + return None + + +def _compile_dependencies(): + + args = get_args() + + # ========================= + # Compile dataset C++ code. + # ========================= + # TODO: move this to ninja + if torch.distributed.get_rank() == 0: + start_time = time.time() + print("> compiling dataset index builder ...") + from megatron.data.dataset_utils import compile_helper + + compile_helper() + print( + ">>> done with dataset index builder. Compilation time: {:.3f} " + "seconds".format(time.time() - start_time), + flush=True, + ) + + # ================== + # Load fused kernels + # ================== + + # Custom kernel constraints check. + seq_len = args.seq_length + attn_batch_size = ( + args.num_attention_heads / args.tensor_model_parallel_size + ) * args.micro_batch_size + # Constraints on sequence length and attn_batch_size to enable warp based + # optimization and upper triangular optimization (for causal mask) + custom_kernel_constraint = ( + seq_len > 16 + and seq_len <= 16384 + and seq_len % 4 == 0 + and attn_batch_size % 4 == 0 + ) + # Print a warning. + if not ( + (args.fp16 or args.bf16) + and custom_kernel_constraint + and args.masked_softmax_fusion + ): + if args.rank == 0: + print( + "WARNING: constraints for invoking optimized" + " fused softmax kernel are not met. We default" + " back to unfused kernel invocations.", + flush=True, + ) + + # Always build on rank zero first. + if torch.distributed.get_rank() == 0: + start_time = time.time() + print("> compiling and loading fused kernels ...", flush=True) + fused_kernels.load(args) + torch.distributed.barrier() + else: + torch.distributed.barrier() + fused_kernels.load(args) + # Simple barrier to make sure all ranks have passed the + # compilation phase successfully before moving on to the + # rest of the program. We think this might ensure that + # the lock is released. + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print( + ">>> done with compiling and loading fused kernels. " + "Compilation time: {:.3f} seconds".format(time.time() - start_time), + flush=True, + ) + + +def _initialize_distributed(): + """Initialize torch.distributed and core model parallel.""" + args = get_args() + + device_count = torch.cuda.device_count() + if torch.distributed.is_initialized(): + + if args.rank == 0: + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + + else: + + if args.rank == 0: + print("> initializing torch distributed ...", flush=True) + # Manually set the device ids. + if device_count > 0: + device = args.rank % device_count + if args.local_rank is not None: + assert ( + args.local_rank == device + ), "expected local-rank to be the same as rank % device-count." + else: + args.local_rank = device + torch.cuda.set_device(device) + # Call the init process + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, + rank=args.rank, + timeout=timedelta(minutes=args.distributed_timeout_minutes), + ) + + # Set the tensor model-parallel, pipeline model-parallel, and + # data-parallel communicators. + if device_count > 0: + if mpu.model_parallel_is_initialized(): + print("model parallel is already initialized") + else: + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + args.fp8 is not None, + ) + if args.rank == 0: + print( + f"> initialized tensor model parallel with size " + f"{mpu.get_tensor_model_parallel_world_size()}" + ) + print( + f"> initialized pipeline model parallel with size " + f"{mpu.get_pipeline_model_parallel_world_size()}" + ) + + +def _init_autoresume(): + """Set autoresume start time.""" + autoresume = get_adlr_autoresume() + if autoresume: + torch.distributed.barrier() + autoresume.init() + torch.distributed.barrier() + + +def _set_random_seed(seed_, data_parallel_random_init=False): + """Set random seed for reproducability.""" + if seed_ is not None and seed_ > 0: + # Ensure that different pipeline MP stages get different seeds. + seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + # Ensure different data parallel ranks get different seeds + if data_parallel_random_init: + seed = seed + (10 * mpu.get_data_parallel_rank()) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.device_count() > 0: + tensor_parallel.model_parallel_cuda_manual_seed(seed) + else: + raise ValueError("Seed ({}) should be a positive integer.".format(seed)) + + +def write_args_to_tensorboard(): + """Write arguments to tensorboard.""" + args = get_args() + writer = get_tensorboard_writer() + if writer: + for arg in vars(args): + writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration) + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + _warmup_jit_function() + + +def _warmup_jit_function(): + """Compilie JIT functions before the main training steps""" + args = get_args() + if args.bf16: + dtype = torch.bfloat16 + elif args.fp16: + dtype = torch.float16 + else: + dtype = torch.float32 + + # Warmup fused bias+gelu + bias = torch.rand( + args.ffn_hidden_size // args.tensor_model_parallel_size, + dtype=dtype, + device="cuda", + ) + input = torch.rand( + ( + args.seq_length, + args.micro_batch_size, + args.ffn_hidden_size // args.tensor_model_parallel_size, + ), + dtype=dtype, + device="cuda", + ) + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for bias_grad, input_grad in zip([True, True], [False, True]): + bias.requires_grad, input.requires_grad = bias_grad, input_grad + for _ in range(5): + output = bias_gelu(bias, input) + del bias, input, output + + # Warmup fused bias+dropout+add + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + else: + seq_length = args.seq_length + input = torch.rand( + (seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + residual = torch.rand( + (seq_length, args.micro_batch_size, args.hidden_size), + dtype=dtype, + device="cuda", + ) + bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as( + residual + ) + dropout_rate = 0.1 + # Warmup JIT fusions with the input grad_enable state of both forward + # prop and recomputation + for input_grad, bias_grad, residual_grad in zip( + [False, True], [True, True], [True, True] + ): + input.requires_grad = input_grad + bias.requires_grad = bias_grad + residual.requires_grad = residual_grad + for _ in range(5): + output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate) + del bias, input, residual, output + torch.cuda.empty_cache() diff --git a/training/DeepSpeed-Domino/megatron/memory.py b/training/DeepSpeed-Domino/megatron/memory.py new file mode 100644 index 000000000..a5fef75ba --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/memory.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + + +import torch + + +# A dictionary of all the memory buffers allocated. +_MEM_BUFFS = dict() + + +def allocate_mem_buff(name, numel, dtype, track_usage): + """Allocate a memory buffer.""" + assert name not in _MEM_BUFFS, \ + 'memory buffer {} already allocated.'.format(name) + _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) + return _MEM_BUFFS[name] + + +def get_mem_buff(name): + """Get the memory buffer.""" + return _MEM_BUFFS[name] + + +class MemoryBuffer: + """Contiguous memory buffer. + Allocate a contiguous memory of type `dtype` and size `numel`. It is + used to reduce memory fragmentation. + + Usage: After the allocation, the `_start` index is set tot the first + index of the memory. A memory chunk starting from `_start` index + can be `allocated` for an input tensor, with the elements of the + tensor being coppied. The buffer can be reused by resetting the + `_start` index. + + """ + def __init__(self, name, numel, dtype, track_usage): + if torch.distributed.get_rank() == 0: + element_size = torch.tensor([], dtype=dtype).element_size() + print('> building the {} memory buffer with {} num elements ' + 'and {} dtype ({:.1f} MB)...'.format( + name, numel, dtype, numel*element_size/1024/1024), + flush=True) + self.name = name + self.numel = numel + self.dtype = dtype + self.data = torch.empty(self.numel, + dtype=self.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + # Index tracking the start of the free memory. + self._start = 0 + + # Values used for tracking usage. + self.track_usage = track_usage + if self.track_usage: + self.in_use_value = 0.0 + self.total_value = 0.0 + + + def reset(self): + """Reset the buffer start index to the beginning of the buffer.""" + self._start = 0 + + + def is_in_use(self): + """Whether the current buffer hold on to any memory.""" + return self._start > 0 + + + def numel_in_use(self): + """Return number of elements in use.""" + return self._start + + + def add(self, tensor): + """Allocate a chunk of memory from the buffer to tensor and copy + the values.""" + assert tensor.dtype == self.dtype, \ + 'Input tensor type {} different from buffer type {}'.format( + tensor.dtype, self.dtype) + # Number of elements of the input tensor. + tensor_numel = torch.numel(tensor) + new_start = self._start + tensor_numel + assert new_start <= self.numel, \ + 'Not enough memory left in the buffer ({} > {})'.format( + tensor_numel, self.numel - self._start) + # New tensor is a view into the memory. + new_tensor = self.data[self._start:new_start] + self._start = new_start + new_tensor = new_tensor.view(tensor.shape) + new_tensor.copy_(tensor) + # Return a pointer to the new tensor. + return new_tensor + + + def get_data(self): + """Return the data currently in use.""" + if self.track_usage: + self.in_use_value += float(self._start) + self.total_value += float(self.numel) + return self.data[:self._start] + + + def print_average_usage(self): + """Print memory usage average over time. We would like this value + to be as high as possible.""" + assert self.track_usage, 'You need to enable track usage.' + if torch.distributed.get_rank() == 0: + print(' > usage of {} memory buffer: {:.2f} %'.format( + self.name, self.in_use_value * 100.0 / self.total_value), + flush=True) + + + +class RingMemBuffer: + """A ring of memory buffers.""" + + def __init__(self, name, num_buffers, numel, dtype, track_usage): + self.num_buffers = num_buffers + self.buffers = [ + allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage) + for i in range(num_buffers)] + self._index = -1 + + + def get_next_buffer(self): + self._index += 1 + self._index = self._index % self.num_buffers + buff = self.buffers[self._index] + assert not buff.is_in_use(), 'buffer is already in use.' + return buff diff --git a/training/DeepSpeed-Domino/megatron/microbatches.py b/training/DeepSpeed-Domino/megatron/microbatches.py new file mode 100644 index 000000000..6449d7479 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/microbatches.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron number of micro-batches calculators.""" + +from abc import ABC +from abc import abstractmethod + + +def build_num_microbatches_calculator(args): + + # Constant num micro-batches. + if args.rampup_batch_size is None: + num_microbatches_calculator = ConstantNumMicroBatches( + args.global_batch_size, args.micro_batch_size, + args.data_parallel_size) + if args.rank == 0: + print('setting number of micro-batches to constant {}'.format( + num_microbatches_calculator.get()), flush=True) + + else: + assert len(args.rampup_batch_size) == 3, 'expected the following ' \ + 'format: --rampup-batch-size ' \ + ' ' + start_batch_size = int(args.rampup_batch_size[0]) + batch_size_increment = int(args.rampup_batch_size[1]) + ramup_samples = int(args.rampup_batch_size[2]) + if args.rank == 0: + print('will use batch size rampup starting from global batch ' + 'size {} to global batch size {} with batch size increments ' + '{} over {} samples.'.format(start_batch_size, + args.global_batch_size, + batch_size_increment, + ramup_samples), flush=True) + num_microbatches_calculator = RampupBatchsizeNumMicroBatches( + start_batch_size, batch_size_increment, ramup_samples, + args.global_batch_size, args.micro_batch_size, + args.data_parallel_size) + + return num_microbatches_calculator + + +class NumMicroBatchesCalculator(ABC): + + def __init__(self): + self.num_micro_batches = None + self.current_global_batch_size = None + + def get(self): + return self.num_micro_batches + + def get_current_global_batch_size(self): + return self.current_global_batch_size + + @abstractmethod + def update(self, consumed_samples, consistency_check): + pass + + +class ConstantNumMicroBatches(NumMicroBatchesCalculator): + + def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): + micro_batch_times_data_parallel = micro_batch_size * \ + data_parallel_size + assert global_batch_size % micro_batch_times_data_parallel == 0, \ + 'global batch size ({}) is not divisible by micro batch size ({})' \ + ' times data parallel size ({})'.format(global_batch_size, + micro_batch_size, + data_parallel_size) + self.num_micro_batches = global_batch_size // \ + micro_batch_times_data_parallel + assert self.num_micro_batches >= 1 + self.current_global_batch_size = global_batch_size + + def update(self, consumed_samples, consistency_check): + pass + + +class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): + + def __init__(self, start_batch_size, batch_size_increment, ramup_samples, + global_batch_size, micro_batch_size, data_parallel_size): + """Batch size ramp up. + Over + steps = (global-batch-size - start-batch-size) / batch_size_increment + increment batch size from start-batch-size to global-batch-size using + rampup-samples / steps + samples. + Arguments: + start_batch_size: global batch size to start with + batch_size_increment: global batch size increments + ramup_samples: number of samples to use ramp up global + batch size from `start_batch_size` to `global_batch_size` + global_batch_size: global batch size post rampup + micro_batch_size: micro batch size + data_parallel_size: data parallel size. + """ + + self.micro_batch_size = micro_batch_size + self.data_parallel_size = data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * \ + self.data_parallel_size + assert self.micro_batch_times_data_parallel_size > 0 + + assert start_batch_size > 0 + self.start_batch_size = start_batch_size + + assert global_batch_size > 0 + self.global_batch_size = global_batch_size + diff_batch_size = self.global_batch_size - self.start_batch_size + assert diff_batch_size >= 0 + assert batch_size_increment > 0 + self.batch_size_increment = batch_size_increment + assert diff_batch_size % batch_size_increment == 0, 'expected ' \ + 'global batch size interval ({}) to be divisible by global batch ' \ + 'size increment ({})'.format(diff_batch_size, batch_size_increment) + + num_increments = diff_batch_size // self.batch_size_increment + self.ramup_samples = ramup_samples + assert self.ramup_samples >= 0 + self.rampup_samples_per_increment = self.ramup_samples / num_increments + + # Initialize number of microbatches. + self.update(0, False) + + + def update(self, consumed_samples, consistency_check): + + if consumed_samples > self.ramup_samples: + self.current_global_batch_size = self.global_batch_size + else: + steps = int(consumed_samples / self.rampup_samples_per_increment) + self.current_global_batch_size = self.start_batch_size + \ + steps * self.batch_size_increment + assert self.current_global_batch_size <= self.global_batch_size + + if consistency_check: + assert self.current_global_batch_size % \ + self.micro_batch_times_data_parallel_size == 0, 'current global ' \ + 'batch size ({}) is not divisible by micro-batch-size ({}) times' \ + 'data parallel size ({})'.format(self.current_global_batch_size, + self.micro_batch_size, + self.data_parallel_size) + self.num_micro_batches = self.current_global_batch_size // \ + self.micro_batch_times_data_parallel_size diff --git a/training/DeepSpeed-Domino/megatron/model/__init__.py b/training/DeepSpeed-Domino/megatron/model/__init__.py new file mode 100644 index 000000000..f5025bf25 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm + +from .distributed import DistributedDataParallel +from .bert_model import BertModel +from .gpt_model import GPTModel +from .t5_model import T5Model +from .language_model import get_language_model +from .module import Float16Module diff --git a/training/DeepSpeed-Domino/megatron/model/bert_model.py b/training/DeepSpeed-Domino/megatron/model/bert_model.py new file mode 100644 index 000000000..0750d7e6c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/bert_model.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""BERT model.""" + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits +from megatron.model.language_model import get_language_model +from megatron.model import LayerNorm +from megatron.model.utils import openai_gelu, erf_gelu +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert + + Arguments: + config: TransformerConfig object + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + parallel_output: whether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output): + super().__init__(config=config) + + args = get_args() + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + self.parallel_output = parallel_output + + self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method) + setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel) + setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel) + + self.layernorm = LayerNorm(hidden_size, + eps=config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) + self.gelu = torch.nn.functional.gelu + if args.openai_gelu: + self.gelu = openai_gelu + elif args.onnx_safe: + self.gelu = erf_gelu + + def forward(self, hidden_states, word_embeddings_weight): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +def post_language_model_processing(lm_output, pooled_output, + lm_head, binary_head, + lm_labels, + logit_weights, + fp16_lm_cross_entropy): + # Output. + lm_logits = lm_head( + lm_output, logit_weights) + + binary_logits = None + if binary_head is not None: + binary_logits = binary_head(pooled_output) + + if lm_labels is None: + # [s b h] => [b s h] + return lm_logits.transpose(0,1).contiguous(), binary_logits + else: + # [b s] => [s b] + lm_labels = lm_labels.transpose(0,1).contiguous() + # lm_logits : [s, b, h] and lm_labels: [s, b] + if fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + # [s, b] => [b s] + lm_loss = lm_loss.transpose(0,1).contiguous() + return lm_loss, binary_logits + + +class BertModel(MegatronModule): + """Bert Language model.""" + + def __init__(self, + config, + num_tokentypes=2, + add_binary_head=True, + parallel_output=True, + pre_process=True, + post_process=True): + super().__init__(config=config) + args = get_args() + + # TODO this option is not yet implemented in BERT + assert args.untie_embeddings_and_output_weights is False + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.add_binary_head = add_binary_head + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + + self.return_embeddings = args.output_bert_embeddings + if self.return_embeddings: + assert self.post_process and self.add_binary_head + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=self.add_binary_head, + encoder_attn_mask_type=AttnMaskType.padding, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings() + if self.post_process: + self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config.hidden_size, + config, parallel_output) + self._lm_head_key = 'lm_head' + self.binary_head = None + if self.add_binary_head: + self.binary_head = get_linear_layer(config.hidden_size, 2, + config.init_method) + self._binary_head_key = 'binary_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, bert_model_input, attention_mask, + tokentype_ids=None, lm_labels=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = bert_model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + + if self.post_process and self.add_binary_head: + lm_output, pooled_output = lm_output + + # Return pooled output (e.g., when computing Bert embeddings). + if self.return_embeddings: + + # Sum attention mask. + embeddings = torch.transpose(lm_output, 0, 1) + masks = torch.sum(attention_mask, dim=1) + + # Collect masked embeddings. + output = torch.zeros( + size=(embeddings.shape[0], embeddings.shape[2]), + dtype=torch.float32, + device=torch.cuda.current_device()) + for i, (embedding, mask) in enumerate(zip(embeddings, masks)): + output[i, :] = torch.mean(embedding[1: mask - 1], dim=0) + + return output + + else: + pooled_output = None + + if self.post_process: + return post_language_model_processing(lm_output, pooled_output, + self.lm_head, self.binary_head, + lm_labels, + self.shared_embedding_or_output_weight(), + self.fp16_lm_cross_entropy) + else: + return lm_output + + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process and self.add_binary_head: + state_dict_[self._binary_head_key] \ + = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + self.lm_head.load_state_dict( + state_dict[self._lm_head_key], strict=strict) + if self.post_process and self.add_binary_head: + self.binary_head.load_state_dict( + state_dict[self._binary_head_key], strict=strict) + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/biencoder_model.py b/training/DeepSpeed-Domino/megatron/model/biencoder_model.py new file mode 100644 index 000000000..c910879dc --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/biencoder_model.py @@ -0,0 +1,328 @@ +import os +import torch +import sys + +from megatron import get_args, print_rank_0, get_tokenizer +from megatron.core import mpu +from megatron.checkpointing import fix_query_key_value_ordering +from megatron.checkpointing import get_checkpoint_tracker_filename +from megatron.checkpointing import get_checkpoint_name +from megatron.model.bert_model import bert_position_ids +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def get_model_provider(only_query_model=False, only_context_model=False, + biencoder_shared_query_context_model=False): + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building Bienoder model ...') + model = biencoder_model_provider(only_query_model=only_query_model, + only_context_model = only_context_model, + biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + + return model + + return model_provider + + +def biencoder_model_provider(only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): + """Build the model.""" + + assert mpu.get_tensor_model_parallel_world_size() == 1 and \ + mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building BiEncoderModel...') + + # simpler to just keep using 2 tokentypes since + # the LM we initialize with has 2 tokentypes + model = BiEncoderModel( + num_tokentypes=2, + parallel_output=False, + only_query_model=only_query_model, + only_context_model=only_context_model, + biencoder_shared_query_context_model=\ + biencoder_shared_query_context_model, + pre_process=pre_process, + post_process=post_process) + + return model + + +class BiEncoderModel(MegatronModule): + """Bert-based module for Biencoder model.""" + + def __init__(self, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): + super(BiEncoderModel, self).__init__() + args = get_args() + + bert_kwargs = dict( + num_tokentypes=num_tokentypes, + parallel_output=parallel_output, + pre_process=pre_process, + post_process=post_process) + + self.biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model + assert not (only_context_model and only_query_model) + self.use_context_model = not only_query_model + self.use_query_model = not only_context_model + self.biencoder_projection_dim = args.biencoder_projection_dim + + if self.biencoder_shared_query_context_model: + self.model = PretrainedBertModel(**bert_kwargs) + self._model_key = 'shared_model' + self.query_model, self.context_model = self.model, self.model + else: + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = PretrainedBertModel(**bert_kwargs) + self._query_key = 'query_model' + + if self.use_context_model: + # this model embeds evidence blocks - Embed_doc in the paper + self.context_model = PretrainedBertModel(**bert_kwargs) + self._context_key = 'context_model' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + # this is just a placeholder and will be needed when model + # parallelism will be used + # self.language_model.set_input_tensor(input_tensor) + return + + def forward(self, query_tokens, query_attention_mask, query_types, + context_tokens, context_attention_mask, context_types): + """Run a forward pass for each of the models and + return the respective embeddings.""" + + if self.use_query_model: + query_logits = self.embed_text(self.query_model, + query_tokens, + query_attention_mask, + query_types) + else: + raise ValueError("Cannot embed query without the query model.") + if self.use_context_model: + context_logits = self.embed_text(self.context_model, + context_tokens, + context_attention_mask, + context_types) + else: + raise ValueError("Cannot embed block without the block model.") + return query_logits, context_logits + + @staticmethod + def embed_text(model, tokens, attention_mask, token_types): + """Embed a batch of tokens using the model""" + logits = model(tokens, + attention_mask, + token_types) + return logits + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.biencoder_shared_query_context_model: + state_dict_[self._model_key] = \ + self.model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + else: + if self.use_query_model: + state_dict_[self._query_key] = \ + self.query_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + + if self.use_context_model: + state_dict_[self._context_key] = \ + self.context_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.biencoder_shared_query_context_model: + print_rank_0("Loading shared query-context model") + self.model.load_state_dict(state_dict[self._model_key], \ + strict=strict) + else: + if self.use_query_model: + print_rank_0("Loading query model") + self.query_model.load_state_dict( \ + state_dict[self._query_key], strict=strict) + + if self.use_context_model: + print_rank_0("Loading context model") + self.context_model.load_state_dict( \ + state_dict[self._context_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model + on iteration zero of ICT pretraining""" + args = get_args() + + if args.bert_load is None: + print_rank_0("bert-load argument is None") + return + + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT checkpoint") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading BERT checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + # Load the checkpoint. + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except ModuleNotFoundError: + from megatron.fp16_deprecated import loss_scaler + # For backward compatibility. + print_rank_0(' > deserializing using the old code structure ...') + sys.modules['fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + state_dict = torch.load(checkpoint_name, map_location='cpu') + sys.modules.pop('fp16.loss_scaler', None) + sys.modules.pop('megatron.fp16.loss_scaler', None) + except BaseException: + print_rank_0('could not load the BERT checkpoint') + sys.exit() + + checkpoint_version = state_dict.get('checkpoint_version', 0) + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + + if self.biencoder_shared_query_context_model: + self.model.language_model.load_state_dict(model_dict) + fix_query_key_value_ordering(self.model, checkpoint_version) + else: + if self.use_query_model: + self.query_model.language_model.load_state_dict(model_dict) + # give each model the same ict_head to begin with as well + if self.biencoder_projection_dim > 0: + query_proj_state_dict = \ + self.state_dict_for_save_checkpoint()\ + [self._query_key]['projection_enc'] + fix_query_key_value_ordering(self.query_model, checkpoint_version) + + if self.use_context_model: + self.context_model.language_model.load_state_dict(model_dict) + if self.query_model is not None and \ + self.biencoder_projection_dim > 0: + self.context_model.projection_enc.load_state_dict\ + (query_proj_state_dict) + fix_query_key_value_ordering(self.context_model, checkpoint_version) + + +class PretrainedBertModel(MegatronModule): + """BERT-based encoder for queries or contexts used for + learned information retrieval.""" + + def __init__(self, num_tokentypes=2, + parallel_output=True, pre_process=True, post_process=True): + super(PretrainedBertModel, self).__init__() + + args = get_args() + tokenizer = get_tokenizer() + self.pad_id = tokenizer.pad + self.biencoder_projection_dim = args.biencoder_projection_dim + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal( + args.init_method_std, args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + if args.biencoder_projection_dim > 0: + self.projection_enc = get_linear_layer(args.hidden_size, + args.biencoder_projection_dim, + init_method) + self._projection_enc_key = 'projection_enc' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = attention_mask.unsqueeze(1) + #extended_attention_mask = bert_extended_attention_mask(attention_mask) + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model(input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + # This mask will be used in average-pooling and max-pooling + pool_mask = (input_ids == self.pad_id).unsqueeze(2) + + # Taking the representation of the [CLS] token of BERT + pooled_output = lm_output[0, :, :] + + # Converting to float16 dtype + pooled_output = pooled_output.to(lm_output.dtype) + + # Output. + if self.biencoder_projection_dim: + pooled_output = self.projection_enc(pooled_output) + + return pooled_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + + if self.biencoder_projection_dim > 0: + state_dict_[self._projection_enc_key] = \ + self.projection_enc.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + print_rank_0("loading pretrained weights") + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + + if self.biencoder_projection_dim > 0: + print_rank_0("loading projection head weights") + self.projection_enc.load_state_dict( + state_dict[self._projection_enc_key], strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/classification.py b/training/DeepSpeed-Domino/megatron/model/classification.py new file mode 100644 index 000000000..bac50c54c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/classification.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Classification model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class Classification(MegatronModule): + + def __init__(self, + config, + num_classes, + num_tokentypes=2, + pre_process=True, + post_process=True): + super().__init__(config=config, share_embeddings_and_output_weights=False) + args = get_args() + + self.num_classes = num_classes + self.pre_process = pre_process + self.post_process = post_process + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. + if self.post_process: + self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) + self.classification_head = get_linear_layer(args.hidden_size, + self.num_classes, + init_method) + self._classification_head_key = 'classification_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + + if self.post_process: + _, pooled_output = lm_output + classification_output = self.classification_dropout(pooled_output) + classification_logits = self.classification_head(classification_output) + + # Reshape back to separate choices. + classification_logits = classification_logits.view(-1, self.num_classes) + + return classification_logits + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + state_dict_[self._classification_head_key] \ + = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._classification_head_key in state_dict: + self.classification_head.load_state_dict( + state_dict[self._classification_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._classification_head_key)) diff --git a/training/DeepSpeed-Domino/megatron/model/distributed.py b/training/DeepSpeed-Domino/megatron/model/distributed.py new file mode 100644 index 000000000..87d5f258d --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/distributed.py @@ -0,0 +1,232 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +from abc import ABC +from abc import abstractmethod +import math + +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from megatron import get_args +from megatron.core import mpu +from .module import MegatronModule + + +class MemoryBuffer: + + def __init__(self, numel, numel_padded, dtype): + self.numel = numel + self.numel_padded = numel_padded + self.dtype = dtype + self.data = torch.zeros(self.numel_padded, + dtype=self.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + def zero(self): + """Reset the buffer to zero.""" + self.data.zero_() + + + def get(self, shape, start_index): + """Return a tensor with the input `shape` as a view into the + 1-D data starting at `start_index`.""" + end_index = start_index + shape.numel() + assert end_index <= self.numel, \ + 'requested tensor is out of the buffer range.' + buffer_tensor = self.data[start_index:end_index] + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + + +class DistributedDataParallelBase(MegatronModule, ABC): + """Abstract class for DDP.""" + + def __init__(self, module): + super(DistributedDataParallelBase, self).__init__() + # Keep a pointer to the model. + self.module = module + + + @abstractmethod + def allreduce_gradients(self): + pass + + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + + def state_dict(self, prefix='', keep_vars=False): + return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + return self.module.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + + +class DistributedDataParallel(DistributedDataParallelBase): + """DDP with contiguous buffers options to store and accumulate gradients. + This class: + - has the potential to reduce memory fragmentation. + - provides the option to do the gradient accumulation + in a type other than the params type (for example fp32) + + Arguments: + module: input model. + accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation + and the gradient all-reduce all in in float32. If this option is + true, we require `use_contiguous_buffers` to be true too. + use_contiguous_buffers: if true, use a contiguous buffer to store the + gradients. + """ + + def __init__(self, module, + accumulate_allreduce_grads_in_fp32, + use_contiguous_buffers): + + super(DistributedDataParallel, self).__init__(module) + + self.accumulate_allreduce_grads_in_fp32 \ + = accumulate_allreduce_grads_in_fp32 + self.use_contiguous_buffers = use_contiguous_buffers + # If we are using fp32-accumulate-allreduce explicitly + # this means we need main grads in a continous buffer. + if self.accumulate_allreduce_grads_in_fp32: + assert self.use_contiguous_buffers + + # =================================== + # Rest of this part applies only to + # the case we use continuous buffers. + # =================================== + self._grad_buffers = None + self._grad_buffer_param_index_map = None + if self.use_contiguous_buffers: + self._grad_buffers = {} + self._grad_buffer_param_index_map = {} + data_parallel_world_size = mpu.get_data_parallel_world_size() + + # Simple function to define buffer type. + def _get_buffer_type(param): + return torch.float if \ + self.accumulate_allreduce_grads_in_fp32 else param.dtype + + # First calculate total number of elements per type. + type_num_elements = {} + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] = type_num_elements.get(dtype, 0) \ + + param.data.nelement() + + # Allocate the buffer. + for dtype, num_elements in type_num_elements.items(): + + # If using distributed optimizer, pad memory buffer to be + # multiple of data_parallel_world_size. (This padding is done + # due to a constraint with the reduce_scatter op, which requires + # all tensors have equal size. See: optimizer.py.) + num_elements_padded = data_parallel_world_size * \ + int(math.ceil(num_elements / data_parallel_world_size)) + + # Allocate grad buffer. + self._grad_buffers[dtype] = MemoryBuffer(num_elements, + num_elements_padded, + dtype) + + # Assume the back prop order is reverse the params order, + # store the start index for the gradients. + for param in self.module.parameters(): + if param.requires_grad: + dtype = _get_buffer_type(param) + type_num_elements[dtype] -= param.data.nelement() + param.main_grad = self._grad_buffers[dtype].get( + param.data.shape, type_num_elements[dtype]) + if dtype not in self._grad_buffer_param_index_map: + self._grad_buffer_param_index_map[dtype] = {} + self._grad_buffer_param_index_map[dtype][param] = ( + type_num_elements[dtype], + type_num_elements[dtype] + param.data.nelement(), + ) + + # Backward hook. + # Accumalation function for the gradients. We need + # to store them so they don't go out of scope. + self.grad_accs = [] + # Loop over all the parameters in the model. + for param in self.module.parameters(): + if param.requires_grad: + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator functtion. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_param_hook(param)) + self.grad_accs.append(grad_acc) + + + def _make_param_hook(self, param): + """Create the all-reduce hook for backprop.""" + # Hook used for back-prop. + def param_hook(*unused): + # Add the gradient to the buffer. + if param.grad is not None: + # The gradient function of linear layers is fused with GEMMs + param.main_grad.add_(param.grad.data) + # Now we can deallocate grad memory. + param.grad = None + return param_hook + + + def zero_grad_buffer(self): + """Set the grad buffer data to zero. Needs to be called at the + begining of each iteration.""" + assert self._grad_buffers is not None, 'buffers are not initialized.' + for _, buffer_ in self._grad_buffers.items(): + buffer_.zero() + + + def broadcast_params(self): + for param in self.module.parameters(): + torch.distributed.broadcast(param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group()) + + + def allreduce_gradients(self): + """Reduce gradients across data parallel ranks.""" + # If we have buffers, simply reduce the data in the buffer. + if self._grad_buffers is not None: + for _, buffer_ in self._grad_buffers.items(): + buffer_.data /= mpu.get_data_parallel_world_size() + torch.distributed.all_reduce( + buffer_.data, group=mpu.get_data_parallel_group()) + else: + # Otherwise, bucketize and all-reduce + buckets = {} + # Pack the buckets. + for param in self.module.parameters(): + if param.requires_grad and param.grad is not None: + tp = param.data.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + param.main_grad = param.grad + + # For each bucket, all-reduce and copy all-reduced grads. + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + coalesced /= mpu.get_data_parallel_world_size() + torch.distributed.all_reduce( + coalesced, group=mpu.get_data_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors( + coalesced, grads)): + buf.copy_(synced) diff --git a/training/DeepSpeed-Domino/megatron/model/enums.py b/training/DeepSpeed-Domino/megatron/model/enums.py new file mode 100644 index 000000000..bc4e4aa29 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/enums.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import enum + +class LayerType(enum.Enum): + encoder = 1 + decoder = 2 + retro_encoder = 3 + retro_decoder = 4 + retro_decoder_with_retriever = 5 + +class AttnType(enum.Enum): + self_attn = 1 + cross_attn = 2 + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + +# For backward compatibility with old model checkpoints +from megatron.core.enums import ModelType diff --git a/training/DeepSpeed-Domino/megatron/model/fused_bias_gelu.py b/training/DeepSpeed-Domino/megatron/model/fused_bias_gelu.py new file mode 100644 index 000000000..29222db02 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/fused_bias_gelu.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import torch + + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + +bias_gelu_impl = GeLUFunction.apply diff --git a/training/DeepSpeed-Domino/megatron/model/fused_layer_norm.py b/training/DeepSpeed-Domino/megatron/model/fused_layer_norm.py new file mode 100644 index 000000000..fd8591e4a --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/fused_layer_norm.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""This code is copied fron NVIDIA apex: + https://github.com/NVIDIA/apex + with some changes. """ + +import numbers +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +import importlib + +from megatron.core.utils import make_viewless_tensor + +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + HAVE_PERSIST_LAYER_NORM = True +except: + HAVE_PERSIST_LAYER_NORM = False + +from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction + + +global fused_layer_norm_cuda +fused_layer_norm_cuda = None + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, + no_persist_layer_norm=True, + sequence_parallel=False, + apply_layernorm_1p=False): + super(MixedFusedLayerNorm, self).__init__() + + self.apply_layernorm_1p = apply_layernorm_1p + + global fused_layer_norm_cuda + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + + # List of hiddens sizes supported in the persistent layer norm kernel + # If the hidden size is not supported, fall back to the non-persistent + # kernel. + persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, + 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536] + if normalized_shape not in persist_ln_hidden_sizes or \ + not HAVE_PERSIST_LAYER_NORM: + no_persist_layer_norm = True + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + self.no_persist_layer_norm = no_persist_layer_norm + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + + def reset_parameters(self): + + if self.apply_layernorm_1p: + init.zeros_(self.weight) + init.zeros_(self.bias) + else: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + + weight = self.weight + 1 if self.apply_layernorm_1p else self.weight + + if self.no_persist_layer_norm: + return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) + else: + output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) + + # Apex's fast layer norm function outputs a 'view' tensor (i.e., has + # a populated '_base' field). This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + output = make_viewless_tensor(inp = output, + requires_grad = input.requires_grad, + keep_graph = True) + + return output diff --git a/training/DeepSpeed-Domino/megatron/model/fused_softmax.py b/training/DeepSpeed-Domino/megatron/model/fused_softmax.py new file mode 100644 index 000000000..9bacf3374 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/fused_softmax.py @@ -0,0 +1,213 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + + +import torch +import torch.nn as nn +from megatron.model.enums import AttnMaskType + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_upper_triang_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( + inputs, scale_t[0] + ) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_upper_triang_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class ScaledSoftmax(torch.autograd.Function): + """ + Fused operation which performs following two operations in sequence + 1. Scale the tensor. + 2. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_softmax_cuda.forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert ( + self.scale is None or softmax_in_fp32 + ), "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 16 < sk <= 16384 # sk must be 16 ~ 16384 + and sq % 4 == 0 # sq must be divisor of 4 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 16384: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + if mask is not None: + return ScaledMaskedSoftmax.apply(input, mask, scale) + else: + return ScaledSoftmax.apply(input, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/training/DeepSpeed-Domino/megatron/model/gpt_model.py b/training/DeepSpeed-Domino/megatron/model/gpt_model.py new file mode 100644 index 000000000..dd47188da --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/gpt_model.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""GPT-2 model.""" + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from .module import MegatronModule + +from .enums import AttnMaskType +from .language_model import parallel_lm_logits +from .language_model import get_language_model + + +def post_language_model_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy): + + # Output. Format [s b h] + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class GPTModel(MegatronModule): + """GPT-2 Language model.""" + + def __init__(self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True): + args = get_args() + super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process) + + if not args.untie_embeddings_and_output_weights: + self.initialize_word_embeddings() + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, input_ids, position_ids, attention_mask, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, tokentype_ids=None, inference_params=None): + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + retriever_input_ids=retriever_input_ids, + retriever_position_ids=retriever_position_ids, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) + + if self.post_process: + return post_language_model_processing( + lm_output, labels, + self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), + self.parallel_output, + self.fp16_lm_cross_entropy) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Load word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/language_model.py b/training/DeepSpeed-Domino/megatron/model/language_model.py new file mode 100644 index 000000000..85b5dc5cb --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/language_model.py @@ -0,0 +1,634 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Transformer based language model.""" + +import torch +import torch.nn.functional as F + +from megatron import get_args +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding + +from .enums import AttnMaskType, LayerType +from .module import MegatronModule +from .transformer import ParallelTransformer +from .utils import get_linear_layer +from .utils import init_method_normal, scaled_init_method_normal + + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, + bias=None): + """LM logits using word embedding weights.""" + args = get_args() + # Parallel logits. + if args.async_tensor_model_parallel_allreduce or\ + args.sequence_parallel: + input_parallel = input_ + model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 + async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ + model_parallel and not args.sequence_parallel + else: + input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) + async_grad_allreduce = False + + # Matrix multiply. + logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=word_embeddings_weight, + bias=bias, + gradient_accumulation_fusion=args.gradient_accumulation_fusion, + async_grad_allreduce=async_grad_allreduce, + sequence_parallel=args.sequence_parallel) + # Gather if needed. + + if parallel_output: + return logits_parallel + + return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) + + +def get_language_model(config, num_tokentypes, add_pooler, + encoder_attn_mask_type, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + pre_process=True, post_process=True): + """Build language model and return along with the key to save.""" + args = get_args() + if config.init_method is None: + config.init_method = init_method_normal(config.init_method_std) + + if config.output_layer_init_method is None: + config.output_layer_init_method = scaled_init_method_normal(config.init_method_std, + config.num_layers) + + # Language model. + language_model = TransformerLanguageModel( + config, + encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + add_encoder=add_encoder, + add_decoder=add_decoder, + decoder_attn_mask_type=decoder_attn_mask_type, + add_pooler=add_pooler, + pre_process=pre_process, + post_process=post_process + ) + # key used for checkpoints. + language_model_key = 'language_model' + + return language_model, language_model_key + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, init_method): + super(Pooler, self).__init__() + args = get_args() + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel + + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. + + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = tensor_parallel.gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) + + pooled = hidden_states[sequence_index, :, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class Embedding(MegatronModule): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + embedding_weights_in_fp32: casts word embedding weights to + fp32 before sampling. Required to + maintain reproducibility when + training in bf16. + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + config, + num_tokentypes=0, + embedding_weights_in_fp32=False): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.init_method = config.init_method + self.num_tokentypes = num_tokentypes + + args = get_args() + + # Word embeddings (parallel). + self.embedding_weights_in_fp32 = embedding_weights_in_fp32 + self.params_dtype = args.params_dtype + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + vocab_size, self.hidden_size, config=config, init_method=config.init_method) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.add_position_embedding = args.position_embedding_type == 'learned_absolute' + if self.add_position_embedding: + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + if args.perform_initialization: + self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + if args.perform_initialization: + self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + self.fp32_residual_connection = args.fp32_residual_connection + self.sequence_parallel = args.sequence_parallel + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + if self.add_position_embedding: + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), + flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + args = get_args() + self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + if self.embedding_weights_in_fp32: + self.word_embeddings = self.word_embeddings.to(torch.float32) + words_embeddings = self.word_embeddings(input_ids) + if self.embedding_weights_in_fp32: + words_embeddings = words_embeddings.to(self.params_dtype) + self.word_embeddings = self.word_embeddings.to(self.params_dtype) + if self.add_position_embedding: + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + else: + embeddings = words_embeddings + + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + else: + assert self.tokentype_embeddings is None + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + with tensor_parallel.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.add_position_embedding: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self.add_position_embedding: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, + strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', flush=True) + + +class TransformerLanguageModel(MegatronModule): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + config, + encoder_attn_mask_type, + num_tokentypes=0, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + add_pooler=False, + pre_process=True, + post_process=True): + args = get_args() + # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. + if args.untie_embeddings_and_output_weights: assert not add_decoder + super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = config.hidden_size + self.num_tokentypes = num_tokentypes + self.init_method = config.init_method + self.add_encoder = add_encoder + self.encoder_attn_mask_type = encoder_attn_mask_type + self.add_decoder = add_decoder + self.decoder_attn_mask_type = decoder_attn_mask_type + self.add_pooler = add_pooler + self.encoder_hidden_state = None + self.add_retriever = args.retro_add_retriever + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + + # Embeddings. + if self.pre_process: + self.embedding = Embedding(self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + config, + self.num_tokentypes, + args.embedding_weights_in_fp32) + self._embedding_key = 'embedding' + + # Rotary positional embeddings + self.use_rotary_position_embeddings = \ + args.position_embedding_type == 'rope' + if self.use_rotary_position_embeddings: + self.seq_length = args.seq_length + rotary_dim = args.hidden_size // args.num_attention_heads \ + if args.kv_channels is None else args.kv_channels + + if args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * args.rotary_percent) + + # partial rotary embeddings, which is better than full rotary + # Wang and Komatsuzaki et al + # https://github.com/kingoflolz/mesh-transformer-jax/ + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim, + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + # Encoder (usually set to True, False if part of an encoder-decoder + # architecture and in encoder-only stage). + if self.add_encoder: + self.encoder = ParallelTransformer( + config, + model_type=args.model_type if not args.retro_add_retriever \ + else ModelType.retro_decoder, + self_attn_mask_type=self.encoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self._encoder_key = 'encoder' + else: + self.encoder = None + + # Decoder (usually set to False, True if part of an encoder-decoder + # architecture and in decoder-only stage). + if self.add_decoder: + self.decoder = ParallelTransformer( + config, + model_type=args.model_type, + layer_type=LayerType.decoder, + self_attn_mask_type=self.decoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process) + self._decoder_key = 'decoder' + else: + self.decoder = None + + if self.post_process: + # Pooler. + if self.add_pooler: + self.pooler = Pooler(self.hidden_size, self.init_method) + self._pooler_key = 'pooler' + + if self.untie_embeddings_and_output_weights: + self.output_layer = tensor_parallel.ColumnParallelLinear( + args.hidden_size, + args.padded_vocab_size, + config=config, + init_method=self.init_method, + bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias. + self._output_layer_key = 'output_layer' + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + if self.add_encoder and self.add_decoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with both encoder and decoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with only encoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_decoder: + if len(input_tensor) == 2: + self.decoder.set_input_tensor(input_tensor[0]) + self.encoder_hidden_state = input_tensor[1] + elif len(input_tensor) == 1: + self.decoder.set_input_tensor(None) + self.encoder_hidden_state = input_tensor[0] + else: + raise Exception('input_tensor must have either length 1 or 2') + else: + raise Exception('Stage must have at least either encoder or decoder') + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False): + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. + if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + retriever_input=retriever_input, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + if self.post_process: + if self.add_pooler: + pooled_output = self.pooler(encoder_output, + pooling_sequence_index) + + # output_enc_hidden refers to when we just need the encoder's + # output. For example, it is helpful to compute + # similarity between two sequences by average pooling + if not self.add_decoder or output_enc_hidden: + if self.add_pooler and self.post_process: + return encoder_output, pooled_output + else: + return encoder_output + + # Decoder embedding. + if self.pre_process: + decoder_input = self.embedding(dec_input_ids, + dec_position_ids) + else: + decoder_input = None + + # Run decoder. + decoder_output = self.decoder( + decoder_input, + dec_attn_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + + if self.add_pooler and self.post_process: + return decoder_output, encoder_output, pooled_output + else: + return decoder_output, encoder_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_embeddings_and_output_weights: + state_dict_[self._output_layer_key] \ + = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + if self.untie_embeddings_and_output_weights: + assert 'output_layer' in state_dict, \ + 'could not find data for output_layer in the checkpoint' + self.output_layer.load_state_dict(state_dict[self._output_layer_key], + strict=strict) + # Decoder. + if self.add_decoder: + assert 'decoder' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], + strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/module.py b/training/DeepSpeed-Domino/megatron/model/module.py new file mode 100644 index 000000000..c2887315a --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/module.py @@ -0,0 +1,197 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron Module""" + +import torch +from torch.autograd import Variable +from torch.nn.parameter import Parameter + +from megatron import get_args +from megatron.core import mpu, tensor_parallel + + +_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) +_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) + + + +def param_is_not_shared(param): + return not hasattr(param, 'shared') or not param.shared + + + +class MegatronModule(torch.nn.Module): + """Megatron specific extensions of torch Module with support + for pipelining.""" + + def __init__(self, config=None, share_embeddings_and_output_weights=True): + super(MegatronModule, self).__init__() + self.config = config + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """Use this function to override the state dict for + saving checkpoints.""" + return self.state_dict(prefix=prefix, keep_vars=keep_vars) + + + def shared_embedding_or_output_weight(self): + if self.pre_process: + return self.language_model.embedding.word_embeddings.weight + else: + if not self.share_embeddings_and_output_weights: + raise Exception('shared_embedding_or_output_weight() called for last ' + 'stage, but share_embeddings_and_output_weights is false') + return self.word_embeddings.weight + + + def initialize_word_embeddings(self): + args = get_args() + if not self.share_embeddings_and_output_weights: + raise Exception('initialize_word_embeddings() was called but ' + 'share_embeddings_and_output_weights is false') + + # This function just initializes the word embeddings in the final stage + # when we are using pipeline parallelism. Nothing to do if we aren't + # using pipeline parallelism. + if args.pipeline_model_parallel_size == 1: + return + + # Parameters are shared between the word embeddings layers, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + if mpu.is_pipeline_last_stage() and not self.pre_process: + assert not mpu.is_pipeline_first_stage() + self._word_embeddings_for_head_key = 'word_embeddings_for_head' + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + args.padded_vocab_size, self.config.hidden_size, + config=self.config, init_method=self.config.init_method) + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + + # Zero out initial weights for decoder embedding. + # NOTE: We don't currently support T5 with the interleaved schedule. + if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \ + self.pre_process: + self.language_model.embedding.zero_parameters() + + if not torch.distributed.is_initialized(): + if not getattr(MegatronModule, "embedding_warning_printed", False): + print("WARNING! Distributed processes aren't initialized, so " + "word embeddings in the last layer are not initialized. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong.") + MegatronModule.embedding_warning_printed = True + return + + # Ensure that first and last stages have the same initial parameter + # values. + if mpu.is_rank_in_embedding_group(): + torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data, + group=mpu.get_embedding_group()) + + # Ensure that encoder(first stage) and decoder(split stage) position + # embeddings have the same initial parameter values + # NOTE: We don't currently support T5 with the interleaved schedule. + if mpu.is_rank_in_position_embedding_group() and \ + args.pipeline_model_parallel_split_rank is not None: + # TODO: Support tokentype embedding. + self.language_model.embedding.cuda() + position_embeddings = self.language_model.embedding.position_embeddings + torch.distributed.all_reduce(position_embeddings.weight.data, + group=mpu.get_position_embedding_group()) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` + #is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_float16(val, float16_convertor): + """Convert fp32 `val` to fp16/bf16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, _FLOAT_TYPES): + val = float16_convertor(val) + return val + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + """Convert fp16/bf16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + val = val.float() + return val + return conversion_helper(val, float_conversion) + + + +class Float16Module(MegatronModule): + + def __init__(self, module, args): + super(Float16Module, self).__init__() + + if args.fp16: + self.add_module('module', module.half()) + def float16_convertor(val): + return val.half() + elif args.bf16: + self.add_module('module', module.bfloat16()) + def float16_convertor(val): + return val.bfloat16() + else: + raise Exception('should not be here') + + self.float16_convertor = float16_convertor + + + def set_input_tensor(self, input_tensor): + return self.module.set_input_tensor(input_tensor) + + + def forward(self, *inputs, **kwargs): + if mpu.is_pipeline_first_stage(): + inputs = fp32_to_float16(inputs, self.float16_convertor) + outputs = self.module(*inputs, **kwargs) + if mpu.is_pipeline_last_stage(): + outputs = float16_to_fp32(outputs) + return outputs + + + def state_dict(self, prefix='', keep_vars=False): + return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) + + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + return self.module.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/multiple_choice.py b/training/DeepSpeed-Domino/megatron/model/multiple_choice.py new file mode 100644 index 000000000..41f8bb49f --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/multiple_choice.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Multiple choice model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class MultipleChoice(MegatronModule): + + def __init__(self, + config, + num_tokentypes=2, + pre_process=True, + post_process=True): + super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False) + args = get_args() + + self.pre_process = pre_process + self.post_process = post_process + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. + if self.post_process: + self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) + self.multichoice_head = get_linear_layer(args.hidden_size, 1, + init_method) + self._multichoice_head_key = 'multichoice_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + # [batch, choices, sequence] --> [batch * choices, sequence] --> + # transformer --> [batch, choices] --> softmax + + # Ensure the shape is [batch-size, choices, sequence] + assert len(attention_mask.shape) == 3 + num_choices = attention_mask.shape[1] + + # Reshape and treat choice dimension the same as batch. + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + extended_attention_mask = bert_extended_attention_mask(attention_mask) + + input_ids = model_input + # Do the same as attention_mask for input_ids, tokentype_ids + assert len(input_ids.shape) == 3 + assert len(tokentype_ids.shape) == 3 + input_ids = input_ids.view(-1, input_ids.size(-1)) + tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + if self.post_process: + _, pooled_output = lm_output + multichoice_output = self.multichoice_dropout(pooled_output) + multichoice_logits = self.multichoice_head(multichoice_output) + + # Reshape back to separate choices. + multichoice_logits = multichoice_logits.view(-1, num_choices) + + return multichoice_logits + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + state_dict_[self._multichoice_head_key] \ + = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._multichoice_head_key in state_dict: + self.multichoice_head.load_state_dict( + state_dict[self._multichoice_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._multichoice_head_key)) diff --git a/training/DeepSpeed-Domino/megatron/model/realm_model.py b/training/DeepSpeed-Domino/megatron/model/realm_model.py new file mode 100644 index 000000000..654f2992f --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/realm_model.py @@ -0,0 +1,204 @@ +import os +import torch + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name +from megatron.model import BertModel +from .module import MegatronModule +from megatron.core import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.language_model import get_language_model +from megatron.model.utils import scaled_init_method_normal +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids + + +def general_ict_model_provider(only_query_model=False, only_block_model=False): + """Build the model.""" + args = get_args() + assert args.ict_head_size is not None, \ + "Need to specify --ict-head-size to provide an ICTBertModel" + assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building ICTBertModel...') + + # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes + model = ICTBertModel( + ict_head_size=args.ict_head_size, + num_tokentypes=2, + parallel_output=True, + only_query_model=only_query_model, + only_block_model=only_block_model) + + return model + + +class ICTBertModel(MegatronModule): + """Bert-based module for Inverse Cloze task.""" + def __init__(self, + ict_head_size, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_block_model=False): + super(ICTBertModel, self).__init__() + bert_kwargs = dict( + ict_head_size=ict_head_size, + num_tokentypes=num_tokentypes, + parallel_output=parallel_output + ) + assert not (only_block_model and only_query_model) + self.use_block_model = not only_query_model + self.use_query_model = not only_block_model + + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = IREncoderBertModel(**bert_kwargs) + self._query_key = 'question_model' + + if self.use_block_model: + # this model embeds evidence blocks - Embed_doc in the paper + self.block_model = IREncoderBertModel(**bert_kwargs) + self._block_key = 'context_model' + + def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask): + """Run a forward pass for each of the models and return the respective embeddings.""" + query_logits = self.embed_query(query_tokens, query_attention_mask) + block_logits = self.embed_block(block_tokens, block_attention_mask) + return query_logits, block_logits + + def embed_query(self, query_tokens, query_attention_mask): + """Embed a batch of tokens using the query model""" + if self.use_query_model: + query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0) + query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) + return query_ict_logits + else: + raise ValueError("Cannot embed query without query model.") + + def embed_block(self, block_tokens, block_attention_mask): + """Embed a batch of tokens using the block model""" + if self.use_block_model: + block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0) + block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) + return block_ict_logits + else: + raise ValueError("Cannot embed block without block model.") + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.use_query_model: + state_dict_[self._query_key] \ + = self.query_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + + if self.use_block_model: + state_dict_[self._block_key] \ + = self.block_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.use_query_model: + print("Loading ICT query model", flush=True) + self.query_model.load_state_dict( + state_dict[self._query_key], strict=strict) + + if self.use_block_model: + print("Loading ICT block model", flush=True) + self.block_model.load_state_dict( + state_dict[self._block_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining""" + args = get_args() + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT load for ICT") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except BaseException: + raise ValueError("Could not load checkpoint") + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + self.query_model.language_model.load_state_dict(model_dict) + self.block_model.language_model.load_state_dict(model_dict) + + # give each model the same ict_head to begin with as well + query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head'] + self.block_model.ict_head.load_state_dict(query_ict_head_state_dict) + + +class IREncoderBertModel(MegatronModule): + """BERT-based encoder for queries or blocks used for learned information retrieval.""" + def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True): + super(IREncoderBertModel, self).__init__() + args = get_args() + + self.ict_head_size = ict_head_size + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method) + + self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method) + self._ict_head_key = 'ict_head' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = bert_extended_attention_mask( + attention_mask, next(self.language_model.parameters()).dtype) + position_ids = bert_position_ids(input_ids) + + lm_output, pooled_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + + # Output. + ict_logits = self.ict_head(pooled_output) + return ict_logits, None + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + state_dict_[self._ict_head_key] \ + = self.ict_head.state_dict(prefix=prefix, + keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + self.ict_head.load_state_dict( + state_dict[self._ict_head_key], strict=strict) + + diff --git a/training/DeepSpeed-Domino/megatron/model/t5_model.py b/training/DeepSpeed-Domino/megatron/model/t5_model.py new file mode 100644 index 000000000..f9fabd340 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/t5_model.py @@ -0,0 +1,186 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""T5 model.""" + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits, get_language_model +from megatron.model import LayerNorm +from megatron.model.utils import ( + openai_gelu, + get_linear_layer +) +from .module import MegatronModule + + +def t5_extended_attention_mask(attention_mask_list): + + def attn_mask_postprocess(attn_mask): + # [b, 1, s, s] + extended_attention_mask = attn_mask.unsqueeze(1) + return extended_attention_mask + + return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] + + +def t5_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class T5LMHead(MegatronModule): + """Masked LM head for T5 + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + parallel_output: wether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, parallel_output): + super(T5LMHead, self).__init__() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + self.bias.model_parallel = True + self.bias.partition_dim = 0 + self.bias.stride = 1 + self.parallel_output = parallel_output + + def forward(self, hidden_states, word_embeddings_weight): + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +class T5Model(MegatronModule): + """T5 Language model.""" + + def __init__(self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + add_encoder=True, + add_decoder=True): + super().__init__(config=config) + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=False, + add_encoder=add_encoder, + add_decoder=add_decoder, + encoder_attn_mask_type=AttnMaskType.padding, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings() + + if self.post_process and self.add_decoder: + self.lm_head = T5LMHead( + self.shared_embedding_or_output_weight().size(0), + parallel_output) + self._lm_head_key = 'lm_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask, + decoder_attn_mask, encoder_decoder_attn_mask, + tokentype_ids=None, lm_labels=None, enc_hidden_states=None): + + # Converting the attention masks to proper parameter settings + encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( + [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]) + + encoder_position_ids = t5_position_ids(encoder_input_ids) + decoder_position_ids = t5_position_ids(decoder_input_ids) + + lm_output = self.language_model(encoder_input_ids, + encoder_position_ids, + encoder_attn_mask, + decoder_input_ids, + decoder_position_ids, + decoder_attn_mask, + encoder_decoder_attn_mask, + tokentype_ids=tokentype_ids, + enc_hidden_states=enc_hidden_states) + + if self.post_process and self.add_decoder: + decoder_output, encoder_output = lm_output + # Output. [s, b, h] + lm_logits = self.lm_head(decoder_output, + self.shared_embedding_or_output_weight()) + + if lm_labels is None: + # [s b h] => [b s h] + return lm_logits.transpose(0,1).contiguous() + else: + # [b s] => [s b] + lm_labels = lm_labels.transpose(0,1).contiguous() + if self.fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + # [s b] => [b s] + lm_loss = lm_loss.transpose(0,1).contiguous() + return lm_loss + elif self.add_decoder and not self.add_encoder: + decoder_output, encoder_output = lm_output + return decoder_output + else: + encoder_output = lm_output + return encoder_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process and self.add_decoder: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process and self.add_decoder: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process and self.add_decoder: + self.lm_head.load_state_dict(state_dict[self._lm_head_key], + strict=strict) + # Load word embeddings. + if self.post_process and not self.pre_process and self.add_decoder: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/training/DeepSpeed-Domino/megatron/model/transformer.py b/training/DeepSpeed-Domino/megatron/model/transformer.py new file mode 100644 index 000000000..a7898156f --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/transformer.py @@ -0,0 +1,1710 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +"""Transformer.""" +from contextlib import nullcontext +import math +import numpy as np +import torch +import torch.nn.functional as F +from typing import Optional + +from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches +from .module import MegatronModule +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.model import LayerNorm +from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.fused_bias_gelu import bias_gelu_impl +from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + except ImportError: + flash_attn_unpadded_func = None + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + +class DropPath(MegatronModule): + """Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_state): + if self.drop_prob == 0. or not self.training: + return hidden_state + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + # hidden_state: [s, b, h] + shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor.floor_() # binarize + output = hidden_state.div(keep_prob) * random_tensor + return output + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super(ParallelMLP, self).__init__() + args = get_args() + + self.add_bias = config.add_bias_linear + + ffn_hidden_size = config.ffn_hidden_size + if config.gated_linear_unit: + ffn_hidden_size *= 2 + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + ffn_hidden_size, + config=config, + init_method=config.init_method, + bias=self.add_bias, + gather_output=False, + skip_bias_add=True, + ) + + self.bias_gelu_fusion = False + self.activation_func = None + self.swiglu = args.swiglu + + if args.openai_gelu: + self.activation_func = openai_gelu + elif args.onnx_safe: + self.activation_func = erf_gelu + elif args.swiglu: + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + self.activation_func = swiglu + elif args.squared_relu: + def squared_relu(x): + return torch.pow(F.relu(x), 2) + self.activation_func = squared_relu + else: + self.bias_gelu_fusion = args.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = tensor_parallel.RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=self.add_bias, + input_is_parallel=True + ) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + assert self.add_bias is True + assert self.activation_func == F.gelu + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + +class SwitchMLP(MegatronModule): + """ + Routes input to one of N MLP "experts" + """ + def __init__(self, config): + super(SwitchMLP, self).__init__() + args = get_args() + self.router = torch.nn.Linear(config.hidden_size, args.num_experts) + self.experts = torch.nn.ModuleList() + for i in range(args.num_experts): + self.experts.append(ParallelMLP(config)) + + def forward(self, hidden_states): + # hidden_states: [s, b, h] + s = hidden_states.size(0) + b = hidden_states.size(1) + h = hidden_states.size(2) + route = self.router(hidden_states) + route = torch.nn.functional.softmax(route, dim=2) + max_prob, max_ind = torch.max(route, dim=2) + max_prob = torch.unsqueeze(max_prob, 2) # [s b 1] + + # TODO (rprenger) TODO this could be made easier to read + # Converting [s, b, h] to [s*b, h]. + # Each vector could be routed differently + hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h] + max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1] + max_ind = max_ind.view(-1) # [s*b] + + output_total = torch.empty_like(hidden_states) + output_bias_total = torch.empty_like(hidden_states) + #TODO (rprenger) This does each expert in serial, but it could be parallelized + + for expert_num, expert in enumerate(self.experts): + local_indices = (max_ind == expert_num).nonzero() + hidden = hidden_states[local_indices,:] + output, output_bias = expert(hidden) + if output_bias is not None: + output_bias = output_bias.expand_as(output) + output_bias_total[local_indices,:] = output_bias + output_total[local_indices,:] = output + + output_total = output_total*max_prob + output_total = output_total.view(s, b, h) + if output_bias is not None: + output_bias_total = output_bias_total*max_prob + output_bias_total = output_bias_total.view(s, b, h) + else: + output_bias_total = None + + return output_total, output_bias_total + + +class CoreAttention(MegatronModule): + + def __init__(self, layer_number, config, + attn_mask_type=AttnMaskType.padding): + super(CoreAttention, self).__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = core.utils.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = core.utils.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, + self.attn_mask_type, + config.masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, + value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, + device=None, dtype=None): + super().__init__() + assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + 'e.g., with pip install flash-attn') + assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the query, key, and value. (B, S, H, D) + """ + + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.is_cuda for i in (q,k,v))) + + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + + q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + + if self.training: + # during training q,k,v always have same seqlen + assert seqlen_k == seqlen_q + + is_causal = self.causal + cu_seqlens_k = cu_seqlens_q + dropout_p = self.dropout_p + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = seqlen_q == seqlen_k + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=q.device) + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, causal=is_causal + ) + + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + return output + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(ParallelAttention, self).__init__() + args = get_args() + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = config.params_dtype + self.sequence_parallel = config.sequence_parallel + + self.group_query_attention = args.group_query_attention + self.num_query_groups = args.num_query_groups + + query_projection_size = config.kv_channels * config.num_attention_heads + if self.group_query_attention: + kv_projection_size = args.kv_channels * args.num_query_groups + else: + kv_projection_size = args.kv_channels * args.num_attention_heads + + self.use_flash_attn = args.use_flash_attn \ + and attention_type == AttnType.self_attn \ + and self.attn_mask_type == AttnMaskType.causal + if self.use_flash_attn: + if flash_attn_unpadded_func is None: + raise ImportError('FlashAttention is not installed, please install with ' + 'pip install flash-attn') + assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' + 'self-attention for now') + assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' + 'supports causal mask for now') + if rearrange is None: + raise ImportError('einops is not installed, please install with pip install einops') + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = core.utils.divide( + query_projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + if self.group_query_attention: + if args.num_query_groups % world_size != 0: + raise NotImplementedError('Currently the num_query_groups should be ' + 'a multiple of the tensor parallel size') + self.num_query_groups_per_partition = core.utils.divide( + args.num_query_groups, world_size) + else: + self.num_query_groups_per_partition = self.num_attention_heads_per_partition + + # Strided linear layer. + if attention_type == AttnType.self_attn: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + else: + assert attention_type == AttnType.cross_attn + + if self.group_query_attention: + raise NotImplementedError("Grouped query attention not implemented for cross-attention.") + assert query_projection_size == kv_projection_size + + self.query = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.core_attention = CoreAttention(self.layer_number, config, + self.attn_mask_type) + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + if self.use_flash_attn: + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attention_dropout + ) + + # Output. + self.dense = tensor_parallel.RowParallelLinear( + query_projection_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=args.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True) + + def _checkpointed_attention_forward(self, query_layer, key_layer, + value_layer, attention_mask, + rotary_pos_emb=None): + """Forward method with activation checkpointing.""" + def custom_forward(*inputs): + query_layer = inputs[0] + key_layer = inputs[1] + value_layer = inputs[2] + attention_mask = inputs[3] + output_ = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + return output_ + + q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ + else rotary_pos_emb + + hidden_states = tensor_parallel.checkpoint( + custom_forward, + False, query_layer, key_layer, value_layer, attention_mask, + q_pos_emb, k_pos_emb) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None, + rotary_pos_emb=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + is_first_step = False + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + is_first_step = True + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + if self.attention_type == AttnType.self_attn: + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query_layer, + key_layer, + value_layer) = torch.split( + mixed_x_layer, + [ + ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, + value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + # duplicate the pos_emb for self attention + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[ + :sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[ + :sequence_end, batch_start:batch_end, ...] + + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] + key_layer = key_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + value_layer = value_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + if not self.use_flash_attn: + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + else: + q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() + for x in (query_layer, key_layer, value_layer)] + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + context_layer = self.core_attention_flash(q, k, v) + else: + context_layer = self.core_attention_flash(q, k, v) + context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class ParallelTransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.): + # retriever=None): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel, + apply_layernorm_1p=args.apply_layernorm_1p) + + # Self attention. + self.self_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=not config.persist_layer_norm, + sequence_parallel=config.sequence_parallel, + apply_layernorm_1p=args.apply_layernorm_1p) + + # Cross attention. + if self.layer_type in (LayerType.decoder, + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + LayerType.retro_encoder): + self.inter_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.cross_attn) + # Layernorm on the attention output. + self.post_inter_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=not config.persist_layer_norm, + sequence_parallel=config.sequence_parallel, + apply_layernorm_1p=args.apply_layernorm_1p) + + # MLP + if args.num_experts is not None: + self.mlp = SwitchMLP(config) + else: + self.mlp = ParallelMLP(config) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + if args.retro_add_retriever: + retro_args = get_retro_args() + self.retro_num_neighbors = args.retro_num_neighbors + self.retro_chunk_length = retro_args.retro_gpt_chunk_length + self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length + + # Retriever (bi-directional transformer with cross attention) + if layer_type == LayerType.retro_decoder_with_retriever: + self.retriever = ParallelTransformer( + config=config, + model_type=ModelType.retro_encoder, + self_attn_mask_type=AttnMaskType.padding, + pre_process=True, + post_process=False, + ) + self._retriever_key = 'retriever' + else: + self.retriever = None + + def default_decoder_cross_attention(self, + encoder_output, + enc_dec_attn_mask, + layernorm_input, + layernorm_output, + bias_dropout_add_func): + '''Cross attention for a standard encoder-decoder model.''' + + # Attention. + attention_output, attention_bias = \ + self.inter_attention(layernorm_output, + enc_dec_attn_mask, + encoder_output=encoder_output) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + + # Bias-dropout-add. + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + + # Layer norm. + layernorm_output = self.post_inter_attention_layernorm(layernorm_input) + + return layernorm_input, layernorm_output + + def retro_encoder_cross_attention(self, + retriever_output, + layernorm_input, + layernorm_output, + bias_dropout_add_func): + """Cross attention for Retro encoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = layernorm_output.shape # [r, bs * l * k, d] + + # Divide sequence dimension into chunks. + chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length, + -1, + self.retro_num_neighbors, + d) + chunked_outputs_before_layer_norm = \ + layernorm_input.reshape(self.retro_retrieved_length, -1, + self.retro_num_neighbors, d) # [r, bs*l, k, d] + + # Per-chunk attention. + layernorm_inputs = [] + layernorm_outputs = [] + for k in range(self.retro_num_neighbors): + + # Attention. + chunked_output = chunked_outputs[:,:,k].contiguous() + attention_output, attention_bias = \ + self.inter_attention( + chunked_output, # Q (neighbor embedding) + None, + encoder_output=retriever_output) # K, V (hidden act) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = chunked_output + else: + residual = chunked_outputs_before_layer_norm[:,:,k] + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + layernorm_inputs.append(layernorm_input) + + # Layer norm. + layernorm_output = \ + self.post_inter_attention_layernorm(layernorm_input) + layernorm_outputs.append(layernorm_output) + + # Concatenate layer norms. + # layernorm_input : [r, k * bs * l, d] + # layernorm_output : [r, k * bs * l, d] + layernorm_input = \ + torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d) + layernorm_output = \ + torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d) + + return layernorm_input, layernorm_output + + def retro_decoder_cross_attention(self, + retriever_input, + retriever_output, + retriever_attn_mask, + layernorm_input, + layernorm_output, + inference_params, + bias_dropout_add_func): + """Cross attention for Retro decoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + m : Number of tokens per chunk. + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = layernorm_output.shape + l = int(np.ceil(ns / self.retro_chunk_length)) + + # Retrieve neighbors. + if self.layer_type == LayerType.retro_decoder_with_retriever: + first_ns = ns % self.retro_chunk_length + if first_ns > 0: + raise Exception("test this case.") + first_chunk, rest_chunk = \ + layernorm_output[:first_ns], layernorm_output[first_ns:] + first_chunk = torch.nn.functional.pad( + first_chunk, + (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), + 'constant', + 0) + chunked_output = \ + torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] + else: + chunked_output = layernorm_output # [l * m, bs, d] + chunked_output = chunked_output \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) \ + .reshape(self.retro_chunk_length, bs * l, d) \ + .contiguous() + + # Get Encoder Output + retriever_output = self.retriever( + hidden_states=retriever_input, + attention_mask=retriever_attn_mask, + retriever_output=chunked_output, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) # [r, k * bs * l , d] + retriever_output = retriever_output.reshape( + self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] + + # Chunks. + pad = (ns - 1) % self.retro_chunk_length + attending_chunks = layernorm_output[pad:] + padded_chunks = torch.nn.functional.pad( + attending_chunks, + (0, 0, 0, 0, 0, self.retro_chunk_length - 1), + 'constant', 0) + padded_chunked_output = padded_chunks \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) + padded_chunked_output = padded_chunked_output.reshape( + self.retro_chunk_length, bs * l, d).contiguous() + + # Encoder output. + attention_output, attention_bias = \ + self.inter_attention(padded_chunked_output, + None, + encoder_output=retriever_output) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + layernorm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(attention_output), + torch.zeros_like(attention_output), + self.hidden_dropout) + layernorm_input = layernorm_input \ + .reshape(self.retro_chunk_length, bs, l, d) \ + .permute(2, 0, 1, 3) # [l, m, bs, d] + layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d) + layernorm_input = torch.nn.functional.pad( + layernorm_input, + (0, 0, 0, 0, pad, 0), + 'constant', 0)[:ns] # [ns, b, d] + layernorm_input = layernorm_input + residual + + # Layer norm post the decoder attention + layernorm_output = self.post_inter_attention_layernorm(layernorm_input) + + return retriever_output, layernorm_input, layernorm_output + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.drop_path is None: + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + else: + out = torch.nn.functional.dropout(attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training) + layernorm_input = residual + self.drop_path(out) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Cross attention. + if self.layer_type == LayerType.encoder: + pass + elif self.layer_type == LayerType.decoder: + layernorm_input, layernorm_output = \ + self.default_decoder_cross_attention( + encoder_output, + enc_dec_attn_mask, + layernorm_input, + layernorm_output, + bias_dropout_add_func) + elif self.layer_type == LayerType.retro_encoder: + layernorm_input, layernorm_output = \ + self.retro_encoder_cross_attention( + retriever_output, + layernorm_input, + layernorm_output, + bias_dropout_add_func) + elif self.layer_type in (LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever): + retriever_output, layernorm_input, layernorm_output = \ + self.retro_decoder_cross_attention( + retriever_input, + retriever_output, + retriever_attn_mask, + layernorm_input, + layernorm_output, + inference_params, + bias_dropout_add_func) + else: + raise Exception("Unsupported layer type, '%s'." % + self.layer_type.name) + + # MLP. + mlp_output, mlp_bias = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + if self.drop_path is None: + if mlp_bias is not None: + mlp_bias = mlp_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias, + residual, + self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = core.utils.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + + else: + if mlp_bias is not None: + mlp_output = mlp_output + mlp_bias + out = torch.nn.functional.dropout(mlp_output, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + + if self.layer_type == LayerType.retro_decoder_with_retriever: + return output, retriever_output + else: + return output + + +class NoopTransformerLayer(MegatronModule): + """A single 'no-op' transformer layer. + + The sole purpose of this layer is for when a standalone embedding layer + is used (i.e., args.standalone_embedding_stage == True). In this case, + zero transformer layers are assigned when pipeline rank == 0. Additionally, + when virtual pipeline rank >= 1, zero total model parameters are created + (virtual rank 0 contains the input embedding). This results in the model's + input and output tensors being the same, which causes an error when + performing certain memory optimiations on the output tensor (e.g., + deallocating it). Thus, this layer disconnects the input from the output + via a clone. Since ranks containing a no-op layer are generally under- + utilized (both compute and memory), there's no worry of any performance + degredation. + """ + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def _get_num_layers(args, model_type, is_decoder=False): + """Compute the number of transformer layers resident on the current rank.""" + is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) + if model_type == ModelType.retro_encoder: + num_layers = args.retro_encoder_layers + elif mpu.get_pipeline_model_parallel_world_size() > 1: + if is_encoder_and_decoder_model: + assert args.pipeline_model_parallel_split_rank is not None + + # When a standalone embedding stage is used, a rank is taken from + # the encoder's ranks, to be used for the encoder's embedding + # layer. This way, the rank referenced by the 'split rank' remains + # the same whether or not a standalone embedding stage is used. + num_ranks_in_encoder = ( + args.pipeline_model_parallel_split_rank - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ + 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) + assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ + 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) + if mpu.is_pipeline_stage_before_split(): + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.encoder_num_layers // num_ranks_in_encoder + ) + else: + num_layers = args.decoder_num_layers // num_ranks_in_decoder + else: + assert args.num_layers == args.encoder_num_layers + assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.num_layers // args.transformer_pipeline_model_parallel_size + ) + else: + if not is_decoder: + num_layers = args.encoder_num_layers + else: + num_layers = args.decoder_num_layers + return num_layers + + +def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, + layer_number): + args = get_args() + if args.retro_add_retriever and layer_number in retro_layer_numbers: + if model_type == ModelType.retro_decoder: + return LayerType.retro_decoder_with_retriever \ + if layer_number == retro_layer_numbers[0] \ + else LayerType.retro_decoder + elif model_type == ModelType.retro_encoder: + return LayerType.retro_encoder + else: + raise Exception("Unsupported model type, '%s'." % model_type) + else: + return default_layer_type + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__(self, config, + model_type, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + post_layer_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0): + super(ParallelTransformer, self).__init__() + args = get_args() + + self.layer_type = layer_type + self.model_type = model_type + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + self.transformer_impl = args.transformer_impl + self.retro_add_retriever = args.retro_add_retriever + + # Store activation checkpoiting flag. + self.recompute_granularity = config.recompute_granularity + self.recompute_method = config.recompute_method + self.recompute_num_layers = config.recompute_num_layers + self.distribute_saved_activations = \ + config.distribute_saved_activations and not config.sequence_parallel + + self.sequence_parallel = config.sequence_parallel + + # Transformer Engine Init. + self.transformer_engine_v_0_10 = False + self.transformer_engine_v_0_11 = False + self.transformer_engine_v_0_8 = False + if self.transformer_impl == 'transformer_engine': + global transformer_engine + import transformer_engine + from importlib.metadata import version + from pkg_resources import packaging + + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("0.8.0"): + self.transformer_engine_v_0_8 = True + if te_version >= packaging.version.Version("0.10.0"): + self.transformer_engine_v_0_10 = True + if te_version >= packaging.version.Version("0.11.0"): + self.transformer_engine_v_0_11 = True + + del version, packaging + + assert not args.squared_relu, "TransformerEngine does not support squared relu activation." + + self.use_fp8 = args.fp8 is not None + self.fp8_recipe = None + self.fp8_group = None + if self.use_fp8: + assert args.transformer_impl == 'transformer_engine', \ + 'transformer-engine required for fp8 training and inference' + self.fp8_group = mpu.get_amax_reduction_group() + if args.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif args.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.") + self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=args.fp8_margin, + interval=args.fp8_interval, + fp8_format=fp8_format, + amax_history_len=args.fp8_amax_history_len, + amax_compute_algo=args.fp8_amax_compute_algo, + override_linear_precision=(False, False, not args.fp8_wgrad), + ) + + self.num_microbatches_in_previous_step = -1 + self.microbatch_count = 0 + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + # Number of layers. + self.num_layers = _get_num_layers(args, model_type, + layer_type==LayerType.decoder) + + self.drop_path_rates = [ + rate.item() for rate in + torch.linspace(0, self.drop_path_rate, config.num_layers)] + + self.retro_layer_numbers = None + if model_type == ModelType.retro_decoder: + retro_layer_start = 6 if config.num_layers <= 15 else 9 + self.retro_layer_numbers = \ + np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() + if model_type == ModelType.retro_encoder: + self.retro_layer_numbers = [1] + + # Transformer layers. + if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." + assert args.transformer_impl == 'local', \ + "Transformer engine does not support Retro layers." + def build_layer(layer_number): + if args.transformer_impl == 'local': + current_layer_type = _get_layer_type( + model_type, layer_type, self.retro_layer_numbers, + layer_number) + return ParallelTransformerLayer( + config, + layer_number, + layer_type=current_layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1]) + else: + # This argument is only available from TE v0.10 onwards. + extra_transformer_engine_kwargs = {} + if self.transformer_engine_v_0_8: + extra_transformer_engine_kwargs["bias"] = args.add_bias_linear + if self.transformer_engine_v_0_10: + extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu" + if self.transformer_engine_v_0_11: + extra_transformer_engine_kwargs["normalization"] = args.normalization + return transformer_engine.pytorch.TransformerLayer( + config.hidden_size, + config.ffn_hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.layernorm_epsilon, + hidden_dropout=config.hidden_dropout, + attention_dropout=config.attention_dropout, + init_method=config.init_method, + output_layer_init_method=config.output_layer_init_method, + layer_number=layer_number, + kv_channels=config.kv_channels, + self_attn_mask_type=self_attn_mask_type.name, + tp_group=mpu.get_tensor_model_parallel_group(), + get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, + fuse_wgrad_accumulation=config.gradient_accumulation_fusion, + apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, + attention_softmax_in_fp32=config.attention_softmax_in_fp32, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + sequence_parallel=config.sequence_parallel, + params_dtype=config.params_dtype, + apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=self.drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + **extra_transformer_engine_kwargs) + + if config.virtual_pipeline_model_parallel_size is not None: + assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + assert args.model_type != ModelType.encoder_and_decoder + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + config.num_layers // config.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + if args.model_type == ModelType.encoder_and_decoder and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if layer_type == LayerType.encoder: + offset = pipeline_rank * self.num_layers + else: + num_ranks_in_enc = args.pipeline_model_parallel_split_rank + offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers + else: + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + if self.num_layers == 0: + # When a standalone embedding stage is used (e.g., + # args.standalone_embedding_stage == True), virtual pipeline ranks + # on pipeline rank 0 will have zero transformer layers assigned to + # them. This results in the model's input and output tensors to be + # the same, which will cause failure for certain output tensor + # optimizations (e.g., pipeline output deallocation). To remedy + # this, we assign a 'no-op' layer on these ranks, which will + # disconnect the input tensor from the output tensor. + self.num_layers = 1 + self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + # Update dropout rate for Retro encoder. + if model_type == ModelType.retro_encoder: + for layer in self.layers: + if layer.self_attention.use_flash_attn: + layer.self_attention.core_attention_flash.dropout_p = \ + torch.nn.Dropout(args.retro_encoder_attention_dropout) + else: + layer.self_attention.core_attention.attention_dropout.p =\ + args.retro_encoder_attention_dropout + layer.hidden_dropout = args.retro_encoder_hidden_dropout + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=args.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel, + apply_layernorm_1p=args.apply_layernorm_1p) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + rotary_pos_emb, is_first_microbatch): + """Forward method with activation checkpointing.""" + def custom(start, end): + def custom_forward(*args, **kwargs): + x_, *args = args + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, *args, **kwargs) + return x_ + return custom_forward + + te_forward_kwargs = {} + if self.transformer_impl == 'transformer_engine': + te_forward_kwargs['is_first_microbatch'] = is_first_microbatch + if self.transformer_engine_v_0_10: + te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + + if self.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and + # checkpoint the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + + l += self.recompute_num_layers + + elif self.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.recompute_num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + else: + if self.transformer_impl == 'transformer_engine': + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None): + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers. + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states diff --git a/training/DeepSpeed-Domino/megatron/model/utils.py b/training/DeepSpeed-Domino/megatron/model/utils.py new file mode 100644 index 000000000..cf3727c02 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Utilities for models.""" + +import math + +import torch + +from megatron import get_args + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def get_linear_layer(rows, columns, init_method): + """Simple linear layer with weight initialization.""" + layer = torch.nn.Linear(rows, columns) + if get_args().perform_initialization: + init_method(layer.weight) + with torch.no_grad(): + layer.bias.zero_() + return layer + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) +def openai_gelu(x): + return gelu_impl(x) + +#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter +@torch.jit.script +def erf_gelu(x): + return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) diff --git a/training/DeepSpeed-Domino/megatron/model/vision/classification.py b/training/DeepSpeed-Domino/megatron/model/vision/classification.py new file mode 100644 index 000000000..4d1a4e902 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/classification.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Vision Transformer(VIT) model.""" + +import torch +from torch.nn.init import trunc_normal_ +from megatron import get_args +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead +from megatron.model.vision.mit_backbone import mit_b3_avg +from megatron.model.module import MegatronModule + +class VitClassificationModel(MegatronModule): + """Vision Transformer Model.""" + + def __init__(self, config, num_classes, finetune=False, + pre_process=True, post_process=True): + super(VitClassificationModel, self).__init__() + args = get_args() + + self.hidden_size = args.hidden_size + self.num_classes = num_classes + self.finetune = finetune + self.pre_process = pre_process + self.post_process = post_process + self.backbone = VitBackbone( + config=config, + pre_process=self.pre_process, + post_process=self.post_process, + single_token_output=True + ) + + if self.post_process: + if not self.finetune: + self.head = VitMlpHead(self.hidden_size, self.num_classes) + else: + self.head = get_linear_layer( + self.hidden_size, + self.num_classes, + torch.nn.init.zeros_ + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.backbone.set_input_tensor(input_tensor) + + def forward(self, input): + hidden_states = self.backbone(input) + + if self.post_process: + hidden_states = self.head(hidden_states) + + return hidden_states + + +class MitClassificationModel(MegatronModule): + """Mix vision Transformer Model.""" + + def __init__(self, num_classes, + pre_process=True, post_process=True): + super(MitClassificationModel, self).__init__() + args = get_args() + + self.hidden_size = args.hidden_size + self.num_classes = num_classes + + self.backbone = mit_b3_avg() + self.head = torch.nn.Linear(512, num_classes) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + pass + + def forward(self, input): + hidden_states = self.backbone(input) + hidden_states = self.head(hidden_states) + + return hidden_states diff --git a/training/DeepSpeed-Domino/megatron/model/vision/dino.py b/training/DeepSpeed-Domino/megatron/model/vision/dino.py new file mode 100644 index 000000000..1c577d2e1 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/dino.py @@ -0,0 +1,290 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py +# reworked/refactored some parts to make it run in Megatron. +import math +import apex +import einops +import torch +import numpy as np +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from megatron import get_args, print_rank_0 +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone +from megatron.model.module import MegatronModule +from megatron.model.vision.mit_backbone import mit_b5_avg +from megatron.model.vision.esvit_swin_backbone import get_swin + + +class DINOLoss(torch.nn.Module): + def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, + warmup_teacher_temp_epochs, nepochs, student_temp=0.1, + center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.ncrops = ncrops + self.register_buffer("center", torch.zeros(1, out_dim)) + # we apply a warm up for the teacher temperature because + # a too high temperature makes the training instable at the beginning + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + self.teacher_temp = teacher_temp + + def forward(self, student_output, teacher_output, iteration): + """ + Cross-entropy between softmax outputs of the teacher + and student network. + """ + args = get_args() + student_out = student_output / self.student_temp + student_out = student_out.chunk(self.ncrops) + + epoch = iteration // args.iter_per_epoch + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) + + teacher_out = teacher_out.detach().chunk(2) + + total_loss = 0 + n_loss_terms = 0 + for iq, q in enumerate(teacher_out): + for v in range(len(student_out)): + if v == iq: + # we skip cases where student and teacher operate on the same view + continue + loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) + total_loss += loss.mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + """ + Update center used for teacher output. + """ + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + torch.distributed.all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size()) + self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) + +class DINOHead(torch.nn.Module): + def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3): + super().__init__() + args = get_args() + hidden_dim = args.dino_head_hidden_size + bottleneck_dim = args.dino_bottleneck_size + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = torch.nn.Linear(in_dim, bottleneck_dim) + else: + layers = [torch.nn.Linear(in_dim, hidden_dim)] + layers.append(torch.nn.GELU()) + for _ in range(nlayers - 2): + layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + layers.append(torch.nn.GELU()) + layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = torch.nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = torch.nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x + + +class MultiCropWrapper(MegatronModule): + + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + """ + def __init__(self, backbone, head): + super(MultiCropWrapper, self).__init__() + # disable layers dedicated to ImageNet labels classification + #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity() + self.backbone = backbone + self.head = head + + def forward(self, x): + # convert to list + if not isinstance(x, list): + x = [x] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in x]), + return_counts=True, + )[1], 0) + + start_idx = 0 + for end_idx in idx_crops: + _out = self.backbone(torch.cat(x[start_idx: end_idx])) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + # Run the head forward on the concatenated features. + if self.training: + return self.head(output) + else: + return output + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, + warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = \ + np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) \ + * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def get_student_backbone_and_num_features(config, pre_process=True, post_process=True): + args = get_args() + + if args.vision_backbone_type == 'vit': + student = VitBackbone(config, + pre_process=pre_process, + post_process=post_process, + drop_path_rate=0.1, + single_token_output=True) + num_features = args.hidden_size + elif args.vision_backbone_type == 'mit': + student = mit_b5_avg(drop_path_rate=0.1) + num_features = 512 + elif args.vision_backbone_type == 'swin': + student = get_swin() + num_features = student.num_features + else: + raise Exception('{} vision backbone is not supported.'.format( + args.vision_backbone_type)) + + return student, num_features + +def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True): + args = get_args() + + if args.vision_backbone_type == 'vit': + teacher = VitBackbone(config, + pre_process=pre_process, + post_process=post_process, + single_token_output=True) + num_features = args.hidden_size + elif args.vision_backbone_type == 'mit': + teacher = mit_b5_avg(drop_path_rate=0.0) + num_features = 512 + elif args.vision_backbone_type == 'swin': + teacher = get_swin(is_teacher=True) + num_features = teacher.num_features + else: + raise Exception('{} vision backbone is not supported.'.format( + args.vision_backbone_type)) + return teacher, num_features + + +class DINOPretrainModel(MegatronModule): + def __init__(self, config, pre_process=True, post_process=True): + super(DINOPretrainModel, self).__init__() + args = get_args() + self.out_dim = 65536 + + self.dino_loss = DINOLoss( + self.out_dim, + args.dino_local_crops_number + 2, + args.dino_warmup_teacher_temp, + args.dino_teacher_temp, + args.dino_warmup_teacher_temp_epochs, + 300, + ) + + self.pre_process = pre_process + self.post_process = post_process + self.momentum_teacher = 0.996 + + student_backbone, num_features = \ + get_student_backbone_and_num_features(config, pre_process, post_process) + + self.student = MultiCropWrapper( + student_backbone, + DINOHead(num_features, self.out_dim, + norm_last_layer=args.dino_norm_last_layer) + ) + + self.momentum_schedule = cosine_scheduler( + self.momentum_teacher, 1, + args.train_iters // args.iter_per_epoch, + args.iter_per_epoch + ) + + teacher_backbone, num_features = \ + get_teacher_backbone_and_num_features(config, pre_process, post_process) + self.teacher = MultiCropWrapper( + teacher_backbone, + DINOHead(num_features, self.out_dim) + ) + self.teacher.load_state_dict(self.student.state_dict()) + + for p in self.teacher.parameters(): + if hasattr(p, "requires_grad") and p.requires_grad is not None: + p.requires_grad = False + + def set_input_tensor(self, tensor): + pass + + def forward(self, input): + student_output = None + if self.training: + student_output = self.student(input) + teacher_output = self.teacher(input[:2]) + else: + teacher_output = self.teacher(input) + return student_output, teacher_output + + def cancel_gradients_last_layer(self, iteration): + args = get_args() + epoch = iteration // args.iter_per_epoch + if epoch < args.dino_freeze_last_layer: + for n, p in self.student.named_parameters(): + if "last_layer" in n: + p.grad = None + + def update_momentum(self, iteration): + with torch.no_grad(): + m = self.momentum_schedule[iteration] + for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + diff --git a/training/DeepSpeed-Domino/megatron/model/vision/esvit_swin_backbone.py b/training/DeepSpeed-Domino/megatron/model/vision/esvit_swin_backbone.py new file mode 100644 index 000000000..70aee3db4 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/esvit_swin_backbone.py @@ -0,0 +1,849 @@ +# Copyright (c) 2021 Microsoft +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Modified by Chunyuan Li (chunyl@microsoft.com) +# Swin Transformer +# -------------------------------------------------------- + +import os +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import torch.distributed as dist +from torch.nn.init import trunc_normal_ +from megatron.model.transformer import DropPath +from megatron import get_args +from megatron.model import LayerNorm +import numpy as np +from math import sqrt + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super(Mlp, self).__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super(WindowAttention, self).__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type()) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn_out = attn + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn_out + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + @staticmethod + def compute_macs(module, input, output): + B, N, C = input[0].shape + + module.__flops__ += module.flops(N) * B + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = input_resolution[0] + self.W = input_resolution[1] + + self.attn_mask_dict = {} + + + def create_attn_mask(self, H, W): + # calculate attention mask for SW-MSA + + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + + def forward(self, x): + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + + if H in self.attn_mask_dict.keys(): + attn_mask = self.attn_mask_dict[H] + else: + self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device) + attn_mask = self.attn_mask_dict[H] + + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + x, _ = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def forward_with_features(self, x): + fea = [] + for blk in self.blocks: + x, _ = blk(x) + fea.append(x) + if self.downsample is not None: + x = self.downsample(x) + return x, fea + + def forward_with_attention(self, x): + attns = [] + for blk in self.blocks: + x, attn = blk(x) + attns.append(attn) + if self.downsample is not None: + x = self.downsample(x) + return x, attns + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. + patch_size (int | tuple(int)): Patch size. + in_chans (int): Number of input channels. + num_classes (int): Number of classes for classification head. + embed_dim (int): Embedding dimension. + depths (tuple(int)): Depth of Swin Transformer layers. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): normalization layer. + ape (bool): If True, add absolute position embedding to the patch embedding. + patch_norm (bool): If True, add normalization after patch embedding. + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + # todo: to be implemented + return {'relative_position_bias_table'} + + def forward(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x_region = self.norm(x) # B L C + x = self.avgpool(x_region.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return x + + + def forward_feature_maps(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x_grid = self.norm(x) # B L C + x = self.avgpool(x_grid.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return x, x_grid + + + def forward_selfattention(self, x, n=1): + # n=1 return the last layer attn map; otherwise return attn maps in all layers + + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + if n==1: + return self.forward_last_selfattention(x) + else: + return self.forward_all_selfattention(x) + + def forward_last_selfattention(self, x): + + for i, layer in enumerate(self.layers): + if i < len(self.layers) - 1: + x = layer(x) + else: + x, attns = layer.forward_with_attention(x) + return attns[-1] + + def forward_all_selfattention(self, x): + attn_out = [] + + for layer in self.layers: + x, attns = layer.forward_with_attention(x) + attn_out += attns + + return attn_out + + + def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]): + + num_blks = sum(depth) + start_idx = num_blks - n + + sum_cur = 0 + for i, d in enumerate(depth): + sum_cur_new = sum_cur + d + if start_idx >= sum_cur and start_idx < sum_cur_new: + start_stage = i + start_blk = start_idx - sum_cur + sum_cur = sum_cur_new + + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # we will return the averaged token features from the `n` last blocks + # note: there is no [CLS] token in Swin Transformer + output = [] + s = 0 + for i, layer in enumerate(self.layers): + x, fea = layer.forward_with_features(x) + + if i >= start_stage: + for x_ in fea[start_blk:]: + + if i == len(self.layers)-1: # use the norm in the last stage + x_ = self.norm(x_) + + x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C + # print(f'Stage {i}, x_avg {x_avg.shape}') + output.append(x_avg) + + start_blk = 0 + + return torch.cat(output, dim=-1) + + + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + if dist.get_rank() == 0: + print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}") + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained, map_location='cpu') + logging.info(f'=> loading pretrained model {pretrained}') + model_dict = self.state_dict() + pretrained_dict = { + k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() + } + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] is '*' + or 'relative_position_index' not in k + or 'attn_mask' not in k + ) + + if need_init: + if verbose: + logging.info(f'=> init {k} from {pretrained}') + + if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): + relative_position_bias_table_pretrained = v + relative_position_bias_table_current = model_dict[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if nH1 != nH2: + logging.info(f"Error in loading {k}, passing") + else: + if L1 != L2: + logging.info( + '=> load_pretrained: resized variant: {} to {}' + .format((L1, nH1), (L2, nH2)) + ) + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) + + if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): + absolute_pos_embed_pretrained = v + absolute_pos_embed_current = model_dict[k] + _, L1, C1 = absolute_pos_embed_pretrained.size() + _, L2, C2 = absolute_pos_embed_current.size() + if C1 != C1: + logging.info(f"Error in loading {k}, passing") + else: + if L1 != L2: + logging.info( + '=> load_pretrained: resized variant: {} to {}' + .format((1, L1, C1), (1, L2, C2)) + ) + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) + absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) + absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( + absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') + v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) + + need_init_state_dict[k] = v + self.load_state_dict(need_init_state_dict, strict=False) + + def freeze_pretrained_layers(self, frozen_layers=[]): + for name, module in self.named_modules(): + if ( + name.split('.')[0] in frozen_layers + or '.'.join(name.split('.')[0:2]) in frozen_layers + or (len(frozen_layers) > 0 and frozen_layers[0] is '*') + ): + for _name, param in module.named_parameters(): + param.requires_grad = False + logging.info( + '=> set param {} requires grad to False' + .format(name) + ) + for name, param in self.named_parameters(): + if ( + name.split('.')[0] in frozen_layers + or (len(frozen_layers) > 0 and frozen_layers[0] is '*') + and param.requires_grad is True + ): + param.requires_grad = False + logging.info( + '=> set param {} requires grad to False' + .format(name) + ) + return self + + +def get_swin(is_teacher=False): + args = get_args() + + if args.swin_backbone_type == "tiny": + embed_dim = 96 + depths = [2, 2, 6, 2] + num_heads = [3, 6, 12, 24] + drop_path_rate = 0.1 + elif args.swin_backbone_type == 'h3': + embed_dim = 384 + depths = [2, 2, 18, 2] + num_heads = [6, 12, 24, 48] + drop_path_rate = 0.2 + else: + embed_dim = 128 + depths = [2, 2, 18, 2] + num_heads = [4, 8, 16, 32] + drop_path_rate = 0.2 + + swin = SwinTransformer( + img_size=224, + in_chans=3, + num_classes=1000, + patch_size=4, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=7, + mlp_ratio=4, + qkv_bias=True, + drop_rate=0, + attn_drop_rate=0, + drop_path_rate=(0.0 if is_teacher else drop_path_rate), + norm_layer=partial(LayerNorm, eps=1e-6), + ape=False, + patch_norm=True, + ) + + return swin + diff --git a/training/DeepSpeed-Domino/megatron/model/vision/inpainting.py b/training/DeepSpeed-Domino/megatron/model/vision/inpainting.py new file mode 100644 index 000000000..cda03315b --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/inpainting.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +i +import math +import apex +import einops +import torch +import torch.nn.functional as F +from megatron import get_args, print_rank_0 +from megatron.model.utils import get_linear_layer +from megatron.model.vision.vit_backbone import VitBackbone +from megatron.model.module import MegatronModule +from megatron.model.vision.mit_backbone import mit_b3 +from megatron.model.vision.utils import resize_ + + +class VitInpaintingModel(MegatronModule): + + def __init__(self, config, pre_process=True, post_process=True): + super(VitInpaintingModel, self).__init__() + args = get_args() + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = config.hidden_size + self.backbone = VitBackbone( + config=config, + pre_process=self.pre_process, + post_process=self.post_process, + class_token=False, + ) + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.seq_length = args.seq_length + # full mask + + if self.post_process: + self.linear_decoder = get_linear_layer( + self.hidden_size, + self.backbone.flatten_dim, + torch.nn.init.zeros_ + ) + + def set_input_tensor(self, input_tensor): + self.backbone.set_input_tensor(input_tensor) + + def forward(self, input): + + hidden_states = self.backbone(input) + + if not self.post_process: + return hidden_states + decoded_output = self.linear_decoder(hidden_states) + output = einops.rearrange( + decoded_output, + "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", + p1=self.patch_dim, + p2=self.patch_dim, + h=self.img_h//self.patch_dim, + w=self.img_w//self.patch_dim, + ) + + return output + + +class MLP(torch.nn.Module): + """ + Linear Embedding + """ + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = torch.nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class MitInpaintingModel(MegatronModule): + """Mix vision Transformer Model.""" + + def __init__(self, pre_process=True, post_process=True): + super(MitInpaintingModel, self).__init__() + self.pre_process = pre_process + self.post_process = post_process + + args = get_args() + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.flatten_dim = self.patch_dim * self.patch_dim * 3 + self.backbone = mit_b3() + + self.in_channels = [64, 128, 320, 512] + self.embedding_dim = 768 + + c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + + self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) + self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) + self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) + self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) + + self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False) + self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim) + self.dropout = torch.nn.Dropout2d(0.1) + + self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + pass + + def forward(self, input): + c1, c2, c3, c4 = self.backbone(input) + + n, _, h, w = c4.shape + _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) + _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) + _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) + _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) + + _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) + + _c = torch.cat([_c4, _c3, _c2, _c1], dim=1) + _c = self.conv_fuse(_c) + + x = self.norm(_c) + x = F.relu(x, inplace=True) + x = self.dropout(x) + + x = self.linear_pred(x) + + output = einops.rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_dim, + p2=self.patch_dim, + h=self.img_h//self.patch_dim, + w=self.img_w//self.patch_dim, + ) + + return output diff --git a/training/DeepSpeed-Domino/megatron/model/vision/knn_monitor.py b/training/DeepSpeed-Domino/megatron/model/vision/knn_monitor.py new file mode 100644 index 000000000..a7d79854e --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/knn_monitor.py @@ -0,0 +1,129 @@ +import torch.nn.functional as F +import torch +from megatron import print_rank_0, get_args +from megatron.core import mpu +from megatron.data.vit_dataset import ClassificationTransform +from megatron.data.image_folder import ImageFolder + +_FEATURE_BANK = None + + +def build_data_loader(dataset, drop_last=True, shuffle=False): + """Data loader. Note that batch-size is the local (per GPU) batch-size.""" + # Sampler. + args = get_args() + micro_batch_size = 16 + num_workers = args.num_workers + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, + drop_last=drop_last, shuffle=shuffle + ) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=micro_batch_size, + sampler=sampler, + shuffle=False, + num_workers=num_workers, + drop_last=not drop_last, + pin_memory=True, + ) + return data_loader + + +def compute_feature_bank(model): + args = get_args() + global _FEATURE_BANK + feature_bank = [] + feature_label = [] + + train_ds = ImageFolder( + root=args.data_path[0], + transform=ClassificationTransform((args.img_h, args.img_w), train=False), + data_per_class_fraction=1.0 + ) + classes = len(train_ds.classes) + dataloader = build_data_loader(train_ds) + + for m in model: + m.eval() + + with torch.no_grad(): + for i, batch in enumerate(dataloader): + images = batch[0].cuda().contiguous() + labels = batch[1].cuda().contiguous() + student_feature, teacher_feature = model[0](images) + feature = F.normalize(teacher_feature.float(), dim=1) + feature_bank.append(feature) + feature_label.append(labels) + + for m in model: + m.train() + + # [N', D] + feature_bank = torch.cat(feature_bank, dim=0).contiguous() + feature_label = torch.cat(feature_label, dim=0).contiguous() + + feature_banks = [torch.zeros_like(feature_bank) + for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather(feature_banks, + feature_bank, + group=mpu.get_data_parallel_group()) + + assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()], + feature_bank)) + + feature_labels = [torch.zeros_like(feature_label) + for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather(feature_labels, + feature_label, + group=mpu.get_data_parallel_group()) + + # [D, N] + feature_banks = torch.cat(feature_banks, dim=0).t().contiguous() + # [N] + feature_labels = torch.cat(feature_labels, dim=0).contiguous() + print_rank_0("feature_banks size is {}".format(feature_banks.size())) + print_rank_0("feature labels size is {}".format(feature_labels.size())) + + _FEATURE_BANK = (feature_banks, feature_labels, classes) + + +def get_feature_bank(): + global _FEATURE_BANK + assert _FEATURE_BANK is not None + return _FEATURE_BANK + + +# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 +# implementation follows http://github.com/zhirongw/lemniscate.pytorch and +# https://github.com/leftthomas/SimCLR +def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), + dim=-1, + index=sim_indices) + sim_weight = (sim_weight / knn_t).exp() + + # counts for each class + one_hot_label = torch.zeros(feature.size(0) * knn_k, + classes, + device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter(dim=-1, + index=sim_labels.view(-1, 1), + value=1.0) + # weighted score ---> [B, C] + pred_scores = torch.sum( + one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), + dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + return pred_labels diff --git a/training/DeepSpeed-Domino/megatron/model/vision/mit_backbone.py b/training/DeepSpeed-Domino/megatron/model/vision/mit_backbone.py new file mode 100644 index 000000000..6640b105d --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/mit_backbone.py @@ -0,0 +1,415 @@ +# Copyright (c) 2023, NVIDIA Corporation. All rights reserved. + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch.nn.init import trunc_normal_ +from megatron.model.transformer import DropPath +from megatron.model import LayerNorm + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MixVisionTransformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.output_avg = output_avg + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) + x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + if not self.output_avg: + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs + + def forward(self, x): + x = self.forward_features(x) + + if self.output_avg: + x = x[3].mean(dim=1) + + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + +class mit_b0(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b1(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b2(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +class mit_b3(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b3_avg(MixVisionTransformer): + def __init__(self, drop_path_rate=0.1, **kwargs): + super(mit_b3_avg, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) + +class mit_b4(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b5(MixVisionTransformer): + def __init__(self, **kwargs): + super(mit_b5, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +class mit_b5_avg(MixVisionTransformer): + def __init__(self, drop_path_rate=0.1, **kwargs): + super(mit_b5_avg, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) + diff --git a/training/DeepSpeed-Domino/megatron/model/vision/swin_backbone.py b/training/DeepSpeed-Domino/megatron/model/vision/swin_backbone.py new file mode 100644 index 000000000..9a622c707 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/swin_backbone.py @@ -0,0 +1,625 @@ +# Copyright (c) 2021 Microsoft +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Swin Transformer +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from math import sqrt + +from megatron import get_args +from functools import partial + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = input_resolution[0] + self.W = input_resolution[1] + + self.attn_mask_dict = {} + + def create_attn_mask(self, H, W): + # calculate attention mask for SW-MSA + + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + + def forward(self, x): + B, L, C = x.shape + H = int(sqrt(L)) + W = H + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_b4_ds = x + if self.downsample is not None: + x = self.downsample(x) + return x_b4_ds, x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, + norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True, + use_checkpoint=False, output_avg=False, **kwargs): + super().__init__() + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.output_avg = output_avg + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + outs = [] + + for i, layer in enumerate(self.layers): + px, x = layer(x) + b, n, c = px.shape + + if i != len(self.layers) - 1 or not self.output_avg: + px = px.permute(0, 2, 1).contiguous() + px = px.reshape(b, c, h, w) + # is this a fair assumption ?? i think it's baked into the architecture + h, w = h//2, w//2 + outs.append(px) + + if self.output_avg: + return outs[-1].mean(dim=1) + + return outs + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + +def get_swin(drop_path_rate=0.3, output_avg=False): + args = get_args() + + window_size = 7 + embed_dim = 128 + depths = [2, 2, 18, 2] + num_heads = [4, 8, 16, 32] + swin = SwinTransformer( + img_size=(args.img_h, args.img_w,), + in_chans=3, + patch_size=args.patch_dim, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + drop_path_rate=drop_path_rate, + output_avg=output_avg, + ) + + return swin + diff --git a/training/DeepSpeed-Domino/megatron/model/vision/utils.py b/training/DeepSpeed-Domino/megatron/model/vision/utils.py new file mode 100644 index 000000000..b4068912c --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/utils.py @@ -0,0 +1,27 @@ +import warnings +import torch +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/training/DeepSpeed-Domino/megatron/model/vision/vit_backbone.py b/training/DeepSpeed-Domino/megatron/model/vision/vit_backbone.py new file mode 100644 index 000000000..1efef9c17 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/model/vision/vit_backbone.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Vision Transformer(VIT) model.""" + +import math +import einops +import torch +import apex +import torch.nn.functional as F +from megatron import get_args +from megatron.model.transformer import ParallelTransformer +from megatron.model.utils import ( + get_linear_layer, + init_method_normal, + scaled_init_method_normal, +) +from megatron.model.module import MegatronModule + +CLASS_TOKEN_LENGTH = 8 + +class VitMlpHead(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, num_classes): + super(VitMlpHead, self).__init__() + self.dense_in = torch.nn.Linear(hidden_size, hidden_size) + self.relu = torch.nn.ReLU() + self.dense_out = torch.nn.Linear(hidden_size, num_classes) + torch.nn.init.constant_(self.dense_out.bias, -10) + + def forward(self, hidden_states): + # hidden_states: [b, 1, h] + # sequence_index: index of the token to pool. + dense_in_result = self.dense_in(hidden_states) + tanh_result = torch.tanh(dense_in_result) + dense_out_result = self.dense_out(tanh_result) + return dense_out_result + + +def isPerfectSquare(x): + if(x >= 0): + sr = math.sqrt(x) + return (int(sr) * int(sr) == x) + return False + + +def twod_interpolate_position_embeddings_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + + args = get_args() + num_patches_per_dim_h = args.img_h // args.patch_dim + num_patches_per_dim_w = args.img_w // args.patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + hidden_size = args.hidden_size + + key = prefix + "weight" + + assert key in state_dict + if key in state_dict: + input_param = state_dict[key] + + input_seq_len = input_param.shape[0] + assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)) + input_has_class_token = not isPerfectSquare(input_seq_len) + num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len + num_tok_output = num_patches + output_has_class_token = args.class_token_present + + # update input_param and load it to state_dict[key] + if input_has_class_token: + input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :] + input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :] + else: + input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size) + input_param_grid = input_param + + assert input_param.shape[1] == hidden_size + + if num_tok_input != num_tok_output: + + gs_input = int(math.sqrt(num_tok_input)) + gs_new = (num_patches_per_dim_h, num_patches_per_dim_w) + + input_param_grid = input_param_grid.transpose(0, 1).contiguous() + input_param_grid = input_param_grid.reshape( + (1, -1, gs_input, gs_input) + ) + input_param_grid = input_param_grid.float() + scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input) + + input_param_grid = F.interpolate( + input_param_grid, scale_factor=scale_factor, mode="bilinear" + ) + + input_param_grid = input_param_grid.half() + input_param_grid = input_param_grid.reshape((-1, num_tok_output)) + input_param_grid = input_param_grid.transpose(0, 1).contiguous() + + assert input_param_grid.shape[1] == hidden_size + + input_param = input_param_grid + assert ( + input_param.shape[0] == num_tok_output + and input_param.shape[1] == hidden_size + ) + + if output_has_class_token: + input_param = torch.cat((input_param_tok, input_param), dim=0) + + state_dict[key] = input_param + + +class VitBackbone(MegatronModule): + """Vision Transformer Model.""" + + def __init__(self, + config, + pre_process=True, + post_process=True, + class_token=True, + single_token_output=False, + post_layer_norm=True, + drop_path_rate=0.0): + super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False) + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + + self.pre_process = pre_process + self.post_process = post_process + self.class_token = class_token + self.post_layer_norm = post_layer_norm + self.hidden_size = args.hidden_size + self.patch_dim = args.patch_dim + self.img_h = args.img_h + self.img_w = args.img_w + self.micro_batch_size = args.micro_batch_size + self.single_token_output = single_token_output + self.drop_path_rate = drop_path_rate + + assert self.img_h % self.patch_dim == 0 + assert self.img_w % self.patch_dim == 0 + self.num_patches_per_dim_h = self.img_h // self.patch_dim + self.num_patches_per_dim_w = self.img_w // self.patch_dim + self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w + self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0) + self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels + self.input_tensor = None + self.position_ids = None + + if self.pre_process: + # cls_token + if self.class_token: + self.cls_token = torch.nn.Parameter( + torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size) + ) + torch.nn.init.zeros_(self.cls_token) + self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda() + + # Linear encoder + self.linear_encoder = torch.nn.Linear( + self.flatten_dim, self.hidden_size + ) + + # embedding + self.position_embeddings = torch.nn.Embedding( + self.seq_length, self.hidden_size + ) + init_method_normal(args.init_method_std)( + self.position_embeddings.weight + ) + + args.class_token_present = self.class_token + self.position_embeddings._register_load_state_dict_pre_hook( + twod_interpolate_position_embeddings_hook + ) + + self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout) + + # Transformer + self.transformer = ParallelTransformer( + config, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=self.post_layer_norm, + drop_path_rate=self.drop_path_rate + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.transformer.set_input_tensor(input_tensor) + + def forward(self, input): + + if self.pre_process: + rearranged_input = einops.rearrange( + input, + "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", + p1=self.patch_dim, + p2=self.patch_dim, + ) + + assert rearranged_input.dtype == torch.half + encoder_output = self.linear_encoder(rearranged_input) + + concatenated_tokens = encoder_output + if self.class_token: + cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1) + concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1) + + token_embeddings = concatenated_tokens + \ + self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]]) + # [b, s, h] => [s, b, h] + token_embeddings = token_embeddings.transpose(0, 1).contiguous() + hidden_states = self.embedding_dropout(token_embeddings) + else: + hidden_states = input + + hidden_states = self.transformer(hidden_states, None) + + if self.post_process: + # [s b h] => [b s h] + if self.single_token_output: + hidden_states = hidden_states[0] + else: + hidden_states = hidden_states.transpose(0, 1).contiguous() + + return hidden_states + diff --git a/training/DeepSpeed-Domino/megatron/text_generation_server.py b/training/DeepSpeed-Domino/megatron/text_generation_server.py new file mode 100644 index 000000000..8bd6c26fc --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/text_generation_server.py @@ -0,0 +1,241 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import datetime +import torch +import json +import threading +from flask import Flask, request, jsonify, current_app +from flask_restful import Resource, Api +from megatron import get_args +from megatron.text_generation import generate_and_post_process +from megatron.text_generation import beam_search_and_post_process + + +GENERATE_NUM = 0 +BEAM_NUM = 1 +lock = threading.Lock() + +class MegatronGenerate(Resource): + def __init__(self, model): + self.model = model + + @staticmethod + def send_do_generate(): + choice = torch.cuda.LongTensor([GENERATE_NUM]) + torch.distributed.broadcast(choice, 0) + + @staticmethod + def send_do_beam_search(): + choice = torch.cuda.LongTensor([BEAM_NUM]) + torch.distributed.broadcast(choice, 0) + + def put(self): + args = get_args() + + if not "prompts" in request.get_json(): + return "prompts argument required", 400 + + if "max_len" in request.get_json(): + return "max_len is no longer used. Replace with tokens_to_generate", 400 + + if "sentences" in request.get_json(): + return "sentences is no longer used. Replace with prompts", 400 + + prompts = request.get_json()["prompts"] + if not isinstance(prompts, list): + return "prompts is not a list of strings", 400 + + if len(prompts) == 0: + return "prompts is empty", 400 + + if len(prompts) > 128: + return "Maximum number of prompts is 128", 400 + + tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow + if "tokens_to_generate" in request.get_json(): + tokens_to_generate = request.get_json()["tokens_to_generate"] + if not isinstance(tokens_to_generate, int): + return "tokens_to_generate must be an integer greater than 0" + if tokens_to_generate < 0: + return "tokens_to_generate must be an integer greater than or equal to 0" + + logprobs = False + if "logprobs" in request.get_json(): + logprobs = request.get_json()["logprobs"] + if not isinstance(logprobs, bool): + return "logprobs must be a boolean value" + + if tokens_to_generate == 0 and not logprobs: + return "tokens_to_generate=0 implies logprobs should be True" + + temperature = 1.0 + if "temperature" in request.get_json(): + temperature = request.get_json()["temperature"] + if not (type(temperature) == int or type(temperature) == float): + return "temperature must be a positive number less than or equal to 100.0" + if not (0.0 < temperature <= 100.0): + return "temperature must be a positive number less than or equal to 100.0" + + top_k = 0.0 + if "top_k" in request.get_json(): + top_k = request.get_json()["top_k"] + if not (type(top_k) == int): + return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000" + if not (0 <= top_k <= 1000): + return "top_k must be equal to or greater than 0 and less than or equal to 1000" + + top_p = 0.0 + if "top_p" in request.get_json(): + top_p = request.get_json()["top_p"] + if not (type(top_p) == float): + return "top_p must be a positive float less than or equal to 1.0" + if top_p > 0.0 and top_k > 0.0: + return "cannot set both top-k and top-p samplings." + if not (0 <= top_p <= 1.0): + return "top_p must be less than or equal to 1.0" + + top_p_decay = 0.0 + if "top_p_decay" in request.get_json(): + top_p_decay = request.get_json()["top_p_decay"] + if not (type(top_p_decay) == float): + return "top_p_decay must be a positive float less than or equal to 1.0" + if top_p == 0.0: + return "top_p_decay cannot be set without top_p" + if not (0 <= top_p_decay <= 1.0): + return "top_p_decay must be less than or equal to 1.0" + + top_p_bound = 0.0 + if "top_p_bound" in request.get_json(): + top_p_bound = request.get_json()["top_p_bound"] + if not (type(top_p_bound) == float): + return "top_p_bound must be a positive float less than or equal to top_p" + if top_p == 0.0: + return "top_p_bound cannot be set without top_p" + if not (0.0 < top_p_bound <= top_p): + return "top_p_bound must be greater than 0 and less than top_p" + + add_BOS = False + if "add_BOS" in request.get_json(): + add_BOS = request.get_json()["add_BOS"] + if not isinstance(add_BOS, bool): + return "add_BOS must be a boolean value" + + if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS: + return "Empty prompts require add_BOS=true" + + stop_on_double_eol = False + if "stop_on_double_eol" in request.get_json(): + stop_on_double_eol = request.get_json()["stop_on_double_eol"] + if not isinstance(stop_on_double_eol, bool): + return "stop_on_double_eol must be a boolean value" + + stop_on_eol = False + if "stop_on_eol" in request.get_json(): + stop_on_eol = request.get_json()["stop_on_eol"] + if not isinstance(stop_on_eol, bool): + return "stop_on_eol must be a boolean value" + + prevent_newline_after_colon = False + if "prevent_newline_after_colon" in request.get_json(): + prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"] + if not isinstance(prevent_newline_after_colon, bool): + return "prevent_newline_after_colon must be a boolean value" + + random_seed = -1 + if "random_seed" in request.get_json(): + random_seed = request.get_json()["random_seed"] + if not isinstance(random_seed, int): + return "random_seed must be integer" + if random_seed < 0: + return "random_seed must be a positive integer" + + no_log = False + if "no_log" in request.get_json(): + no_log = request.get_json()["no_log"] + if not isinstance(no_log, bool): + return "no_log must be a boolean value" + + beam_width = None + if "beam_width" in request.get_json(): + beam_width = request.get_json()["beam_width"] + if not isinstance(beam_width, int): + return "beam_width must be integer" + if beam_width < 1: + return "beam_width must be an integer > 1" + if len(prompts) > 1: + return "When doing beam_search, batch size must be 1" + + stop_token=50256 + if "stop_token" in request.get_json(): + stop_token = request.get_json()["stop_token"] + if not isinstance(stop_token, int): + return "stop_token must be an integer" + + length_penalty = 1 + if "length_penalty" in request.get_json(): + length_penalty = request.get_json()["length_penalty"] + if not isinstance(length_penalty, float): + return "length_penalty must be a float" + + with lock: # Need to get lock to keep multiple threads from hitting code + + if not no_log: + print("request IP: " + str(request.remote_addr)) + print(json.dumps(request.get_json()),flush=True) + print("start time: ", datetime.datetime.now()) + + try: + if beam_width is not None: + MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search + response, response_seg, response_scores = \ + beam_search_and_post_process( + self.model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + beam_size = beam_width, + add_BOS=add_BOS, + stop_token=stop_token, + num_return_gen=beam_width, # Returning whole beam + length_penalty=length_penalty, + prevent_newline_after_colon=prevent_newline_after_colon + ) + + return jsonify({"text": response, + "segments": response_seg, + "scores": response_scores}) + else: + MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate + response, response_seg, response_logprobs, _ = \ + generate_and_post_process( + self.model, + prompts=prompts, + tokens_to_generate=tokens_to_generate, + return_output_log_probs=logprobs, + top_k_sampling=top_k, + top_p_sampling=top_p, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + add_BOS=add_BOS, + use_eod_token_for_early_termination=True, + stop_on_double_eol=stop_on_double_eol, + stop_on_eol=stop_on_eol, + prevent_newline_after_colon=prevent_newline_after_colon, + random_seed=random_seed) + + return jsonify({"text": response, + "segments": response_seg, + "logprobs": response_logprobs}) + + except ValueError as ve: + return ve.args[0] + print("end time: ", datetime.datetime.now()) + + +class MegatronServer(object): + def __init__(self, model): + self.app = Flask(__name__, static_url_path='') + api = Api(self.app) + api.add_resource(MegatronGenerate, '/api', resource_class_args=[model]) + + def run(self, url, port): + self.app.run(url, threaded=True, debug=False, port=port) diff --git a/training/DeepSpeed-Domino/megatron/timers.py b/training/DeepSpeed-Domino/megatron/timers.py new file mode 100644 index 000000000..a9478fa01 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/timers.py @@ -0,0 +1,304 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron timers.""" + +from abc import ABC +from abc import abstractmethod +import time + +import torch + + + +class TimerBase(ABC): + + def __init__(self, name): + self.name = name + + @abstractmethod + def start(self, barrier=False): + pass + + @abstractmethod + def stop(self, barrier=False): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def elapsed(self, reset=True, barrier=False): + pass + + + +class DummyTimer(TimerBase): + + def __init__(self): + super().__init__('dummy timer') + + def start(self, barrier=False): + return + + def stop(self, barrier=False): + return + + def reset(self): + return + + def elapsed(self, reset=True, barrier=False): + raise Exception('dummy timer should not be used to ' + 'calculate elapsed time') + + + +class Timer(TimerBase): + """ + Comment on using `barrier`: If this flag is passed, then all + the caller processes will wait till all reach the timing routine. + It is up to the user to make sure all the ranks in `barrier_group` + call it otherwise, it will result in a hang. + Comment on `barrier_group`: By default it is set to None which + in torch distributed land, it will result in the global communicator. + """ + + def __init__(self, name): + super().__init__(name) + self._elapsed = 0.0 + self._started = False + # Note that None will default to the global process group + self._barrier_group = None + self._start_time = time.time() + + + def set_barrier_group(self, barrier_group): + self._barrier_group = barrier_group + + + def start(self, barrier=False): + """Start the timer.""" + assert not self._started, 'timer has already been started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._start_time = time.time() + self._started = True + + + def stop(self, barrier=False): + """Stop the timer.""" + assert self._started, 'timer is not started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._elapsed += (time.time() - self._start_time) + self._started = False + + + def reset(self): + """Reset timer.""" + self._elapsed = 0.0 + self._started = False + + + def elapsed(self, reset=True, barrier=False): + """Calculate the elapsed time.""" + _started = self._started + # If the timing in progress, end it first. + if self._started: + self.stop(barrier=barrier) + # Get the elapsed time. + _elapsed = self._elapsed + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if _started: + self.start(barrier=barrier) + return _elapsed + + + +class Timers: + """Group of timers.""" + + def __init__(self, log_level, log_option): + self._log_level = log_level + self._log_option = log_option + self._timers = {} + self._log_levels = {} + self._dummy_timer = DummyTimer() + self._max_log_level = 2 + + + def __call__(self, name, log_level=None): + # If the timer has already been set, then check if the log-level + # is provided, it matches the one that the timer was created with. + if name in self._timers: + if log_level is not None: + assert log_level == self._log_levels[name], \ + 'input log level {} does not match already existing '\ + 'log level {} for {} timer'.format( + log_level, self._log_levels[name], name) + return self._timers[name] + # If timer does not exist and no log level is provided, + # set it to the max log level which is 2. + if log_level is None: + log_level = self._max_log_level + assert log_level <= self._max_log_level, \ + 'log level {} is larger than max supported log level {}'.format( + log_level, self._max_log_level) + # Now if the input log level is larger than the one set for + # the timers class, just ignore it and return a dummy timer. + if log_level > self._log_level: + return self._dummy_timer + # Otherwise, initalize the timer and set the level. + self._timers[name] = Timer(name) + self._log_levels[name] = log_level + return self._timers[name] + + + def _get_elapsed_time_all_ranks(self, names, reset, barrier): + """ + Assumptions: + - All the ranks call this function. + - `names` are identical on all ranks. + If the above assumptions are not met, calling this function will + result in hang. + Arguments: + - names: list of timer names + - reset: reset the timer after recording the elapsed time + - barrier: if set, do a global barrier before time measurments + """ + + # First make sure all the callers are in sync. + if barrier: + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Here we can use gather on the rank we want to print the + # timing, however, there is no gather_base support in + # pytorch yet. It is simpler to deal with a single tensor + # and since we are only gathering a small amount of data, + # it should be ok to use all-gather instead of gather. + rank_name_to_time = torch.zeros((world_size, len(names)), + dtype=torch.float, + device=torch.cuda.current_device()) + for i, name in enumerate(names): + if name in self._timers: + # Here we don't need to pass the barrier flag as all + # the processes are already in sync. This avoids the + # issue of different timers having different barrier + # groups inside their class. + rank_name_to_time[rank, i] = self._timers[name].elapsed( + reset=reset) + + # See the note above for why we are not using gather. + torch.distributed._all_gather_base(rank_name_to_time.view(-1), + rank_name_to_time[rank, :].view(-1)) + + return rank_name_to_time + + + def _get_global_min_max_time(self, names, reset, barrier, normalizer): + """Report only min and max times across all ranks.""" + + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + name_to_min_max_time = {} + for i, name in enumerate(names): + rank_to_time = rank_name_to_time[:, i] + # filter out the ones we did not have any timings for + rank_to_time = rank_to_time[rank_to_time > 0.0] + # If the timer exists: + if rank_to_time.numel() > 0: + name_to_min_max_time[name] = ( + rank_to_time.min().item() / normalizer, + rank_to_time.max().item() / normalizer) + return name_to_min_max_time + + + def _get_global_min_max_time_string(self, names, reset, barrier, + normalizer, max_only): + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if not name_to_min_max_time: + return None + output_string = '(min, max) time across ranks (ms):' + for name in name_to_min_max_time: + min_time, max_time = name_to_min_max_time[name] + if max_only: + output_string += '\n {}: {:.2f}'.format( + (name+' ').ljust(48, '.'), max_time) + else: + output_string += '\n {}: ({:.2f}, {:.2f})'.format( + (name+' ').ljust(48, '.'), min_time, max_time) + return output_string + + + def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): + """Report times across all ranks.""" + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + + output_string = 'times across ranks (ms):' + no_reported_timing = True + for i, name in enumerate(names): + not_yet_found = True + for rank in range(torch.distributed.get_world_size()): + if rank_name_to_time[rank, i] > 0: + no_reported_timing = False + if not_yet_found: + not_yet_found = False + output_string += '\n {}:'.format(name) + output_string += '\n rank {:2d}: {:.2f}'.format( + rank, rank_name_to_time[rank, i] / normalizer) + if no_reported_timing: + return None + return output_string + + + def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): + """Log a group of timers.""" + + # Print. + assert normalizer > 0.0 + if self._log_option in ['max', 'minmax']: + max_only = False + if self._log_option == 'max': + max_only = True + output_string = self._get_global_min_max_time_string( + names, reset, barrier, normalizer/1000.0, max_only) + elif self._log_option == 'all': + output_string = self._get_all_ranks_time_string(names, + reset, barrier, + normalizer/1000.0) + else: + raise Exception('unknown timing log option {}'.format( + self._log_option)) + + # If no input rank is provided, log on last rank. + if rank is None: + rank = torch.distributed.get_world_size() - 1 + if rank == torch.distributed.get_rank() and output_string is not None: + print(output_string, flush=True) + + + def write(self, names, writer, iteration, normalizer=1.0, + reset=False, barrier=False): + """Write timers to a tensorboard writer + Note that we only report maximum time across ranks to tensorboard. + """ + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if writer is not None: + for name in name_to_min_max_time: + _, max_time = name_to_min_max_time[name] + writer.add_scalar(name + '-time', max_time, iteration) diff --git a/training/DeepSpeed-Domino/megatron/tokenizer/__init__.py b/training/DeepSpeed-Domino/megatron/tokenizer/__init__.py new file mode 100644 index 000000000..59ceb3386 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/tokenizer/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + + +from .tokenizer import build_tokenizer diff --git a/training/DeepSpeed-Domino/megatron/tokenizer/bert_tokenization.py b/training/DeepSpeed-Domino/megatron/tokenizer/bert_tokenization.py new file mode 100644 index 000000000..642041e77 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/tokenizer/bert_tokenization.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. + + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding = "utf-8") as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + + def vocab_size(self): + return len(self.vocab) + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/training/DeepSpeed-Domino/megatron/tokenizer/gpt2_tokenization.py b/training/DeepSpeed-Domino/megatron/tokenizer/gpt2_tokenization.py new file mode 100644 index 000000000..3f37e4490 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/tokenizer/gpt2_tokenization.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for OpenAI GPT.""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sys +import json +import logging +import os +import regex as re +from io import open + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE + # tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'gpt2': 1024, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \ + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) + # redirect to the cache, if necessary + try: + from .file_utils import cached_path + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + vocab_file, merges_file)) + return None + if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + logger.info("loading merges file {}".format(merges_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + logger.info("loading merges file {} from cache at {}".format( + merges_file, resolved_merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls( + resolved_vocab_file, + resolved_merges_file, + special_tokens=special_tokens, + *inputs, + **kwargs) + return tokenizer + + def __init__(self, vocab_file, merges_file, errors='replace', + special_tokens=None, max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for + # capitalized versions of contractions + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) + for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except BaseException: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format( + len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + index = len(self.encoder) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return vocab_file, merge_file, special_tokens_file diff --git a/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py b/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py new file mode 100644 index 000000000..79dab75a0 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/tokenizer/tokenizer.py @@ -0,0 +1,536 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod + +from .bert_tokenization import FullTokenizer as FullBertTokenizer +from .gpt2_tokenization import GPT2Tokenizer + + +def build_tokenizer(args): + """Initialize tokenizer.""" + if args.rank == 0: + print('> building {} tokenizer ...'.format(args.tokenizer_type), + flush=True) + + # Select and instantiate the tokenizer. + if args.tokenizer_type == 'BertWordPieceLowerCase': + assert args.vocab_file is not None + tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, + lower_case=True, + vocab_extra_ids=args.vocab_extra_ids) + elif args.tokenizer_type == 'BertWordPieceCase': + assert args.vocab_file is not None + tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, + lower_case=False, + vocab_extra_ids=args.vocab_extra_ids) + elif args.tokenizer_type == 'GPT2BPETokenizer': + assert args.vocab_file is not None + assert args.merge_file is not None + tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + elif args.tokenizer_type == 'SentencePieceTokenizer': + assert args.tokenizer_model is not None + tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids) + elif args.tokenizer_type == 'GPTSentencePieceTokenizer': + assert args.tokenizer_model is not None + tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model) + elif args.tokenizer_type == 'NullTokenizer': + assert args.vocab_size is not None + tokenizer = _NullTokenizer(args.vocab_size) + else: + raise NotImplementedError('{} tokenizer is not ' + 'implemented.'.format(args.tokenizer_type)) + + # Add vocab size. + args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, + args) + + return tokenizer + + +def _vocab_size_with_padding(orig_vocab_size, args): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + multiple = args.make_vocab_size_divisible_by * \ + args.tensor_model_parallel_size + while (after % multiple) != 0: + after += 1 + if args.rank == 0: + print(' > padded vocab (size: {}) with {} dummy tokens ' + '(new size: {})'.format( + orig_vocab_size, after - orig_vocab_size, after), flush=True) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError('detokenizer is not implemented for {} ' + 'tokenizer'.format(self.name)) + + @property + def cls(self): + raise NotImplementedError('CLS is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def sep(self): + raise NotImplementedError('SEP is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def pad(self): + raise NotImplementedError('PAD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def eod(self): + raise NotImplementedError('EOD is not provided for {} ' + 'tokenizer'.format(self.name)) + + @property + def mask(self): + raise NotImplementedError('MASK is not provided for {} ' + 'tokenizer'.format(self.name)) + + +class _BertWordPieceTokenizer(AbstractTokenizer): + """Original BERT wordpiece tokenizer.""" + + def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): + if lower_case: + name = 'BERT Lower Case' + else: + name = 'BERT Upper Case' + super().__init__(name) + self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) + self.cls_id = self.tokenizer.vocab['[CLS]'] + self.sep_id = self.tokenizer.vocab['[SEP]'] + self.pad_id = self.tokenizer.vocab['[PAD]'] + self.mask_id = self.tokenizer.vocab['[MASK]'] + self._additional_special_tokens = [] + + # (dsachan) Add BOS and EOS tokens + SPECIAL_TOKENS = {'eos_token': '[EOS]', + 'bos_token': '[BOS]'} + self._bos_token = '[BOS]' + self.add_token(self._bos_token) + self._bos_token_id = self.vocab.get(self._bos_token) + + self._eos_token = '[EOS]' + self.add_token(self._eos_token) + self._eos_token_id = self.vocab.get(self._eos_token) + + # (dsachan) Add additional special tokens + # These can be used as sentinel tokens in T5 model inputs + additional_special_tokens = [] + additional_special_tokens.extend( + ["".format(i) for i in range(vocab_extra_ids)]) + self.add_additional_special_tokens(additional_special_tokens) + + def add_token(self, token): + if token not in self.vocab: + self.inv_vocab[self.vocab_size] = token + # self.vocab_size comes from len(vocab) + # and it will increase as we add elements + self.vocab[token] = self.vocab_size + + def add_additional_special_tokens(self, tokens_list): + setattr(self, "additional_special_tokens", tokens_list) + for value in tokens_list: + self.add_token(value) + + @property + def vocab_size(self): + return self.tokenizer.vocab_size() + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self.tokenizer.inv_vocab + + def tokenize(self, text): + text_tokens = self.tokenizer.tokenize(text) + return self.tokenizer.convert_tokens_to_ids(text_tokens) + + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode_token_ids(self, token_ids): + tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + exclude_list = ['[PAD]', '[CLS]'] + non_pads = [t for t in tokens if t not in exclude_list] + + result = "" + for s in non_pads: + if s.startswith("##"): + result += s[2:] + else: + result += " " + s + + return result + + @property + def cls(self): + return self.cls_id + + @property + def sep(self): + return self.sep_id + + @property + def pad(self): + return self.pad_id + + @property + def mask(self): + return self.mask_id + + @property + def bos_token(self): + """ Beginning of sentence token id """ + return self._bos_token + + @property + def eos_token(self): + """ End of sentence token id """ + return self._eos_token + + @property + def additional_special_tokens(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self._additional_special_tokens + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self._bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self._eos_token_id + + @property + def additional_special_tokens_ids(self): + """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + return [self.vocab.get(token) for token in self._additional_special_tokens] + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + + +class _GPT2BPETokenizer(AbstractTokenizer): + """Original GPT2 BPE tokenizer.""" + + def __init__(self, vocab_file, merge_file): + name = 'GPT2 BPE' + super().__init__(name) + + self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', + special_tokens=[], max_len=None) + self.eod_id = self.tokenizer.encoder['<|endoftext|>'] + + @property + def vocab_size(self): + return len(self.tokenizer.encoder) + + @property + def vocab(self): + return self.tokenizer.encoder + + @property + def inv_vocab(self): + return self.tokenizer.decoder + + def tokenize(self, text): + return self.tokenizer.encode(text) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def eod(self): + return self.eod_id + + +class _SentencePieceTokenizer(AbstractTokenizer): + """SentencePieceTokenizer-Megatron wrapper""" + + def __init__(self, model_file, vocab_extra_ids=0): + name = 'SentencePieceTokenizer' + super().__init__(name) + + import sentencepiece + self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file) + self._initalize(vocab_extra_ids) + + def _populate_vocab(self): + self._vocab = {} + self._inv_vocab = {} + + for i in range(len(self.tokenizer)): + t = self.tokenizer.id_to_piece(i) + self._inv_vocab[i] = t + self._vocab[t] = i + + def _initalize(self, vocab_extra_ids): + self._populate_vocab() + self._special_tokens = {} + self._inv_special_tokens = {} + + self._t5_tokens = [] + + def _add_special_token(t): + if t not in self._vocab: + next_id = len(self._vocab) + self._vocab[t] = next_id + self._inv_vocab[next_id] = t + self._special_tokens[t] = self._vocab[t] + self._inv_special_tokens[self._vocab[t]] = t + + _add_special_token('') + self._cls_id = self._vocab[''] + _add_special_token('') + self._sep_id = self._vocab[''] + _add_special_token('') + self._eod_id = self._vocab[''] + _add_special_token('') + self._mask_id = self._vocab[''] + + pad_id = self.tokenizer.pad_id() + try: + pad_token = self.tokenizer.id_to_piece(pad_id) + except IndexError: + pad_token = '' + _add_special_token(pad_token) + self._pad_id = self._vocab[pad_token] + + bos_id = self.tokenizer.bos_id() + try: + bos_token = self.tokenizer.id_to_piece(bos_id) + except IndexError: + bos_token = '' + _add_special_token(bos_token) + self._bos_id = self._vocab[bos_token] + + eos_id = self.tokenizer.eos_id() + try: + eos_token = self.tokenizer.id_to_piece(eos_id) + except IndexError: + eos_token = '' + _add_special_token(eos_token) + self._eos_id = self._vocab[eos_token] + + for i in range(vocab_extra_ids): + t = "".format(i) + _add_special_token(t) + self._t5_tokens += [t] + + @property + def vocab_size(self): + return len(self._vocab) + + @property + def vocab(self): + return self._vocab + + @property + def inv_vocab(self): + return self._inv_vocab + + @property + def decoder(self): + return self._inv_vocab + + @property + def encoder(self): + return self._vocab + + # From: + # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89 + def tokenize(self, text): + ids = [] + idx = 0 + + while 1: + indices = {} + for token in self._special_tokens: + try: + indices[token] = text[idx:].index(token) + except ValueError: + continue + if len(indices) == 0: + break + + next_token = min(indices, key=indices.get) + next_idx = idx + indices[next_token] + + ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx])) + ids.append(self._special_tokens[next_token]) + idx = next_idx + len(next_token) + + ids.extend(self.tokenizer.encode_as_ids(text[idx:])) + return ids + + # From: + # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125 + def detokenize(self, ids): + text = "" + last_i = 0 + + for i, id in enumerate(ids): + if id in self._inv_special_tokens: + text += self.tokenizer.decode_ids(ids[last_i:i]) + " " + text += self._inv_special_tokens[id] + " " + last_i = i + 1 + + text += self.tokenizer.decode_ids(ids[last_i:]) + return text + + @property + def cls(self): + return self._cls_id + + @property + def sep(self): + return self._sep_id + + @property + def pad(self): + return self._pad_id + + @property + def bos_token_id(self): + return self._bos_id + + @property + def bos(self): + return self._bos_id + + @property + def eod(self): + return self._eod_id + + @property + def eos_token_id(self): + return self._eos_id + + @property + def eos(self): + return self._eos_id + + @property + def mask(self): + return self._mask_id + + @property + def additional_special_tokens_ids(self): + return [self.vocab[k] for k in self._t5_tokens] + +class _GPTSentencePieceTokenizer(_SentencePieceTokenizer): + """SentencePieceTokenizer-Megatron wrapper""" + + def __init__(self, model_file,): + super().__init__(model_file, vocab_extra_ids=0) + + def _initalize(self, vocab_extra_ids): + self._populate_vocab() + + self._pad_id = self.tokenizer.pad_id() + self._bos_id = self.tokenizer.bos_id() + self._eos_id = self.tokenizer.eos_id() + + def tokenize(self, text): + return self.tokenizer.encode_as_ids(text) + + def detokenize(self, ids): + return self.tokenizer.decode_ids(ids) + + @property + def cls(self): + return -1 + + @property + def sep(self): + return -1 + + @property + def mask(self): + return -1 + + @property + def eod(self): + return self._eos_id + + @property + def additional_special_tokens_ids(self): + return None + +class _NullTokenizer: + def __init__(self, vocab_size): + vocab_size = int(vocab_size) + self._eos_id = vocab_size + self.vocab_size = vocab_size+1 + + def tokenize(self, text): + return [int(x) for x in text.split(' ')] + + def detokenize(self, ids): + text = [str(x) for x in ids] + return ' '.join(text) + + @property + def cls(self): + return -1 + + @property + def sep(self): + return -1 + + @property + def mask(self): + return -1 + + @property + def eod(self): + return self._eos_id + + @property + def additional_special_tokens_ids(self): + return None diff --git a/training/DeepSpeed-Domino/megatron/utils.py b/training/DeepSpeed-Domino/megatron/utils.py new file mode 100644 index 000000000..008f89fa8 --- /dev/null +++ b/training/DeepSpeed-Domino/megatron/utils.py @@ -0,0 +1,213 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""General utilities.""" + +import sys + +import torch +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C + +from megatron import ( + get_args, + get_adlr_autoresume, +) +from megatron.core import mpu +from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate +from megatron.model.module import param_is_not_shared + + +def unwrap_model(model, module_instances=(torchDDP)): + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + +def calc_params_l2_norm(model): + """Calculate l2 norm of parameters """ + args = get_args() + if not isinstance(model, list): + model = [model] + # Remove duplicate params. + params_data = [] + for model_ in model: + for param in model_.parameters(): + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if is_not_shared and is_not_tp_duplicate: + if args.bf16: + params_data.append(param.data.float()) + else: + params_data.append(param.data) + # Calculate norm + dummy_overflow_buf = torch.cuda.IntTensor([0]) + norm, _ = multi_tensor_applier( + amp_C.multi_tensor_l2norm, + dummy_overflow_buf, + [params_data], + False # no per-parameter norm + ) + norm_2 = norm * norm + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group()) + return norm_2.item() ** 0.5 + + +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat( + [loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, + group=mpu.get_data_parallel_group()) + averaged_losses = averaged_losses / \ + torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) + + return averaged_losses + + +def report_memory(name): + """Simple GPU memory report.""" + mega_bytes = 1024.0 * 1024.0 + string = name + ' memory (MB)' + string += ' | allocated: {}'.format( + torch.cuda.memory_allocated() / mega_bytes) + string += ' | max allocated: {}'.format( + torch.cuda.max_memory_allocated() / mega_bytes) + string += ' | reserved: {}'.format( + torch.cuda.memory_reserved() / mega_bytes) + string += ' | max reserved: {}'.format( + torch.cuda.max_memory_reserved() / mega_bytes) + if mpu.get_data_parallel_rank() == 0: + print("[Rank {}] {}".format(torch.distributed.get_rank(), string), + flush=True) + + +def print_params_min_max_norm(optimizer, iteration): + """Print min, max, and norm of all parameters.""" + index = 0 + rank = torch.distributed.get_rank() + string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n' + optimizer_ = optimizer.optimizer + for param_group in optimizer_.param_groups: + for param in param_group['params']: + index += 1 + min_ = param.data.min() + max_ = param.data.max() + norm = torch.linalg.norm(param.data) + string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( + iteration, rank, index, int(param.tensor_model_parallel)) + string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) + print(string, flush=True) + + +def check_adlr_autoresume_termination(iteration, model, + optimizer, opt_param_scheduler): + """Check for autoresume signal and exit if it is received.""" + from megatron.checkpointing import save_checkpoint + + args = get_args() + autoresume = get_adlr_autoresume() + # Add barrier to ensure consistnecy. + torch.distributed.barrier() + if autoresume.termination_requested(): + if args.save: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + print_rank_0(">>> autoresume termination request found!") + if torch.distributed.get_rank() == 0: + autoresume.request_resume() + print_rank_0(">>> training terminated. Returning") + sys.exit(0) + + +def get_ltor_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + + return attention_mask, loss_mask, position_ids + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + +def is_last_rank(): + return torch.distributed.get_rank() == ( + torch.distributed.get_world_size() - 1) + +def print_rank_last(message): + """If distributed is initialized, print only on last rank.""" + if torch.distributed.is_initialized(): + if is_last_rank(): + print(message, flush=True) + else: + print(message, flush=True) diff --git a/training/DeepSpeed-Domino/pretrain_base.py b/training/DeepSpeed-Domino/pretrain_base.py new file mode 100644 index 000000000..800305555 --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_base.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_base.py in Megatron-LM + +from functools import partial +import torch +from domino.arguments import get_args +from domino.arguments import get_tokenizer +from domino.utils import get_ltor_masks_and_position_ids +from domino.utils import average_losses_across_data_parallel_group +from domino.tensor_parallel.data import broadcast_data + +def forward_step(data_iterator, model): + """Forward step.""" + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + return output_tensor, partial(loss_func, loss_mask) + + +def get_batch(data_iterator): + """Generate a batch.""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def loss_func(loss_mask, output_tensor): + """Loss function.""" + raw_loss = output_tensor.view(-1).float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(raw_loss * loss_mask) / loss_mask.sum() + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'lm loss': averaged_loss[0]} \ No newline at end of file diff --git a/training/DeepSpeed-Domino/pretrain_gpt.py b/training/DeepSpeed-Domino/pretrain_gpt.py new file mode 100644 index 000000000..7fc4650e1 --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_gpt.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_gpt.py in Megatron-LM + +import time +import torch +from domino.utils import print_rank_0 +from domino.initialize import initialize_domino, set_jit_fusion_options +from domino.arguments import get_args, core_transformer_config_from_args +from domino.data.gpt_dataset import build_train_valid_test_datasets +from domino.training import pretrain +from domino.modules.module import DominoModule +from domino.modules.enums import AttnMaskType +from domino.language_model import parallel_lm_logits +from domino.language_model import get_language_model +from domino.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy + + +_TRAIN_START_TIME = time.time() + +def post_language_model_processing(lm_output, labels, logit_weights, parallel_output): + output = parallel_lm_logits(lm_output, logit_weights, parallel_output) + labels = labels.transpose(0, 1).contiguous() + loss = vocab_parallel_cross_entropy(output.float(), labels) + loss = loss.transpose(0, 1).contiguous() + return loss + + +class GPTModel(DominoModule): + def __init__( + self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + ): + super().__init__(config=config) + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.language_model = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self.initialize_word_embeddings() + + def set_input_tensor(self, input_tensor): + self.language_model.set_input_tensor(input_tensor) + + def forward( + self, + input_ids, + position_ids, + attention_mask, + labels=None, + inference_params=None, + ): + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params, + ) + + if self.post_process: + return post_language_model_processing( + lm_output, + labels, + self.shared_embedding_or_output_weight(), + self.parallel_output, + ) + else: + return lm_output + +def main(): + initialize_domino() + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME]) + torch.distributed.all_reduce(start_time_tensor, + op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( + time.time() - _TRAIN_START_TIME)) + + print_rank_0('Building GPT model ...') + config = core_transformer_config_from_args(get_args()) + model = GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True + ) + + args = get_args() + print_rank_0('Load GPT dataset ...') + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + + pretrain(model, train_ds, valid_ds, test_ds) + + +if __name__ == "__main__": + main() diff --git a/training/DeepSpeed-Domino/pretrain_gpt3_2.7b.sh b/training/DeepSpeed-Domino/pretrain_gpt3_2.7b.sh new file mode 100644 index 000000000..c22b6866e --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_gpt3_2.7b.sh @@ -0,0 +1,72 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_gpt.sh in Megatron-LM + +#!/bin/bash --login + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=2 +MASTER_ADDR=localhost +MASTER_PORT=6001 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=checkpoint +rm -rf $CHECKPOINT_PATH/* +VOCAB_FILE="dataset/gpt2-vocab.json" +MERGE_FILE="dataset/gpt2-merges.txt" +DATA_PATH="dataset/my-gpt2_text_document" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --num-layers 12 \ + --hidden-size 2560 \ + --num-attention-heads 32 \ + --seq-length 512 \ + --max-position-embeddings 512 \ + --micro-batch-size 64 \ + --global-batch-size 64 \ + --lr 0.00015 \ + --train-iters 10 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --fp16 \ + --tensor-model-parallel-size $WORLD_SIZE \ + --seed 3407 +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 +" + +OUTPUT_ARGS=" + --log-interval 1 \ +" + +cmd="deepspeed --num_gpus $WORLD_SIZE \ + pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS + " + +echo $cmd +eval $cmd diff --git a/training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh b/training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh new file mode 100644 index 000000000..131411d2b --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_gpt3_6.7b.sh @@ -0,0 +1,71 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_gpt.sh in Megatron-LM + +#!/bin/bash --login + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=2 +MASTER_ADDR=localhost +MASTER_PORT=6001 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=checkpoint +rm -rf $CHECKPOINT_PATH/* +VOCAB_FILE="dataset/gpt2-vocab.json" +MERGE_FILE="dataset/gpt2-merges.txt" +DATA_PATH="dataset/my-gpt2_text_document" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --micro-batch-size 8 \ + --global-batch-size 8 \ + --lr 0.00015 \ + --train-iters 80 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --no-gradient-accumulation-fusion \ + --fp16 \ + --tensor-model-parallel-size $WORLD_SIZE +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 +" + +OUTPUT_ARGS=" + --log-interval 1 \ +" + +cmd="deepspeed --num_gpus $WORLD_SIZE \ + pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS + " +echo $cmd +eval $cmd diff --git a/training/DeepSpeed-Domino/pretrain_llama.py b/training/DeepSpeed-Domino/pretrain_llama.py new file mode 100644 index 000000000..25e8b01a0 --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_llama.py @@ -0,0 +1,160 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_llama.py in Megatron-LM + +import time +import torch +from domino.utils import print_rank_0 +from domino.initialize import initialize_domino, set_jit_fusion_options +from domino.arguments import get_args, core_transformer_config_from_args +from domino.data.gpt_dataset import build_train_valid_test_datasets +from domino.training import pretrain +from domino.modules.module import DominoModule +from domino.modules.enums import AttnMaskType +from domino.language_model import parallel_lm_logits +from domino.language_model import get_language_model +from domino.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy + + +# The earliest we can measure the start time. +_TRAIN_START_TIME = time.time() + +def post_language_model_processing(lm_output, labels, logit_weights, parallel_output): + output = parallel_lm_logits(lm_output, logit_weights, parallel_output) + labels = labels.transpose(0, 1).contiguous() + loss = vocab_parallel_cross_entropy(output.float(), labels) + loss = loss.transpose(0, 1).contiguous() + return loss + + +class LLaMAModel(DominoModule): + """LLaMA Language model.""" + + def __init__( + self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + ): + args = get_args() + super(LLaMAModel, self).__init__( + config=config, + share_embeddings_and_output_weights=True) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.padded_vocab_size = args.padded_vocab_size + self.language_model = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self.initialize_word_embeddings() + self.lm_head = torch.nn.Linear( + args.hidden_size, args.padded_vocab_size, bias=False + ) + + def set_input_tensor(self, input_tensor): + self.language_model.set_input_tensor(input_tensor) + + def _causal_lm_process(self, lm_output, labels): + lm_output = lm_output.transpose(0, 1) + logits = self.lm_head(lm_output) + loss = None + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., :-1].contiguous() + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0) + shift_logits = shift_logits.view(-1, self.padded_vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return loss + + def forward( + self, + input_ids, + position_ids, + attention_mask, + labels=None, + inference_params=None, + ): + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params, + ) + + if self.post_process: + return self._causal_lm_process(lm_output=lm_output, labels=labels) + else: + return lm_output + + +def main(): + initialize_domino() + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + global _TRAIN_START_TIME + start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME]) + torch.distributed.all_reduce(start_time_tensor, + op=torch.distributed.ReduceOp.MIN) + _TRAIN_START_TIME = start_time_tensor.item() + print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( + time.time() - _TRAIN_START_TIME)) + + print_rank_0('Building LLaMA model ...') + config = core_transformer_config_from_args(get_args()) + model = LLaMAModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True + ) + + args = get_args() + print_rank_0('Load LLaMA dataset ...') + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) + + pretrain(model, train_ds, valid_ds, test_ds) + + +if __name__ == "__main__": + main() diff --git a/training/DeepSpeed-Domino/pretrain_llama_13b.sh b/training/DeepSpeed-Domino/pretrain_llama_13b.sh new file mode 100644 index 000000000..1a438513a --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_llama_13b.sh @@ -0,0 +1,84 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_llama.sh in Megatron-LM + +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=2 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=checkpoint +rm -rf $CHECKPOINT_PATH/* +VOCAB_FILE="dataset/gpt2-vocab.json" +MERGE_FILE="dataset/gpt2-merges.txt" +DATA_PATH="dataset/my-gpt2_text_document" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +LLAMA_ARGS=" + --llama-model \ + --num-layers 2 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --position-embedding-type rope \ + --swiglu \ + --ffn-hidden-size 11008 \ + --disable-bias-linear \ + --normalization RMSNorm \ + --layernorm-epsilon 1e-6 \ + --micro-batch-size 8 \ + --global-batch-size 8 \ + --lr 0.00015 \ + --train-iters 100 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --no-gradient-accumulation-fusion \ + --fp16 \ + --tensor-model-parallel-size $WORLD_SIZE \ + --seed 3407 \ + --causal-lm +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 1 +" + +cmd="deepspeed --num_gpus $WORLD_SIZE \ + pretrain_llama.py \ + $LLAMA_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS + " +echo $cmd +eval $cmd diff --git a/training/DeepSpeed-Domino/pretrain_llama_7b.sh b/training/DeepSpeed-Domino/pretrain_llama_7b.sh new file mode 100644 index 000000000..ddef81382 --- /dev/null +++ b/training/DeepSpeed-Domino/pretrain_llama_7b.sh @@ -0,0 +1,80 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# This file is adapted from pretrain_llama.sh in Megatron-LM + +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=2 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CHECKPOINT_PATH=checkpoint +rm -rf $CHECKPOINT_PATH/* +VOCAB_FILE="dataset/gpt2-vocab.json" +MERGE_FILE="dataset/gpt2-merges.txt" +DATA_PATH="dataset/my-gpt2_text_document" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +export PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +LLAMA_ARGS=" + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --position-embedding-type rope \ + --swiglu \ + --ffn-hidden-size 11008\ + --normalization RMSNorm \ + --layernorm-epsilon 1e-6 \ + --micro-batch-size 16 \ + --global-batch-size 16 \ + --lr 0.00015 \ + --train-iters 80 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --lr-warmup-fraction .01 \ + --clip-grad 1.0 \ + --fp16 \ + --tensor-model-parallel-size $WORLD_SIZE \ + --seed 3407 +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --merge-file $MERGE_FILE \ + --split 949,50,1 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 1 +" + +cmd="deepspeed --num_gpus $WORLD_SIZE \ + pretrain_llama.py \ + $LLAMA_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS + " +echo $cmd +eval $cmd diff --git a/training/DeepSpeed-Domino/requirements.txt b/training/DeepSpeed-Domino/requirements.txt new file mode 100644 index 000000000..a39083e71 --- /dev/null +++ b/training/DeepSpeed-Domino/requirements.txt @@ -0,0 +1,6 @@ +apex +deepspeed +nltk +pybind11 +transformers +