From 9ce00a2171af9284b8c3238f2579a4d7d6a4e190 Mon Sep 17 00:00:00 2001
From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Date: Thu, 11 Nov 2021 22:22:57 -0800
Subject: [PATCH] Tensor-Parallelism general support (#1512)

Co-authored-by: Olatunji Ruwase
Co-authored-by: Jeff Rasley
---
 deepspeed/__init__.py                      |   9 +-
 deepspeed/inference/engine.py              |  20 ++-
 deepspeed/module_inject/replace_module.py  | 170 ++++++++++++++++--
 .../inference/transformer_inference.py     |   2 +-
 4 files changed, 177 insertions(+), 24 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 3eeb6d9dbd0e..0c0e61f0ec5d 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -237,7 +237,9 @@ def init_inference(model,
                    dtype=None,
                    injection_policy=None,
                    replace_method='auto',
-                   quantization_setting=None):
+                   quantization_setting=None,
+                   replace_with_kernel_inject=False,
+                   return_tuple=True):
     """Initialize the DeepSpeed InferenceEngine.

     Arguments:
@@ -267,6 +269,7 @@ def init_inference(model,
             of groups used in quantization. A tuple is passed in if we want to mention that there is extra-grouping
             for the MLP part of a Transformer layer (e.g. (True, 8) shows we quantize the model using 8 groups for all
             the network except the MLP part that we use 8 extra grouping).
+        replace_with_kernel_inject: If set, we inject the inference kernels as we initialize the inference engine.

     Returns:
         A deepspeed.InferenceEngine wrapped model.
@@ -286,7 +289,9 @@ def init_inference(model,
                              checkpoint,
                              dtype,
                              injection_policy,
+                             return_tuple,
                              replace_method,
-                             quantization_setting)
+                             quantization_setting,
+                             replace_with_kernel_inject)

     return engine
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 57c8b67e4048..6873d8d6a19c 100644
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -25,7 +25,8 @@ def __init__(self,
                  injection_dict=None,
                  return_tuple=True,
                  replace_method='auto',
-                 quantization_setting=None):
+                 quantization_setting=None,
+                 replace_with_kernel_inject=False):
         """
         Args:
             model: torch.nn.Module
@@ -74,15 +75,17 @@ def __init__(self,
             self.mp_group = self.mpu.get_model_parallel_group()
         elif self.mp_world_size > 1:
             self._create_model_parallel_group()
-
         # apply injection policy
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
                 self._apply_injection_policy(client_module,
                                              injection_policy,
-                                             return_tuple)
-        elif replace_method == "auto":
-            self._apply_injection_policy()
+                                             return_tuple,
+                                             replace_with_kernel_inject)
+        elif replace_method == 'auto':
+            self._apply_injection_policy(
+                return_tuple=return_tuple,
+                replace_with_kernel_inject=replace_with_kernel_inject)

         device = torch.cuda.current_device()
         logger.info(f"Place model to device: {device}")
@@ -152,7 +155,9 @@ def _validate_args(self, mpu):
     def _apply_injection_policy(self,
                                 client_module=None,
                                 injection_policy=None,
-                                return_tuple=True):
+                                return_tuple=True,
+                                replace_with_kernel_inject=False):
+
         replace_transformer_layer(client_module,
                                   self.module,
                                   policy=injection_policy,
@@ -166,7 +171,8 @@ def _apply_injection_policy(self,
                                   quantize_settings=(self.quantization_scales,
                                                      self.quantize_merge_count,
                                                      self.mlp_extra_grouping,
-                                                     self.quantize_groups))
+                                                     self.quantize_groups),
+                                  replace_with_kernel_inject=replace_with_kernel_inject)

     def _load_checkpoint(self, load_dir, load_module_strict=True):
         sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
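For reference, a minimal usage sketch of the new keyword arguments added above. This is illustrative only and not part of the patch; the GPT-2 checkpoint and the 2-GPU launch are assumptions.

# Hypothetical call site for the extended deepspeed.init_inference API.
# Assumes a launch like `deepspeed --num_gpus 2 example.py`.
import torch
import deepspeed
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')

engine = deepspeed.init_inference(
    model,
    mp_size=2,                          # shard the model across two GPUs
    dtype=torch.half,
    replace_method='auto',
    replace_with_kernel_inject=True,    # new flag: inject the fused inference kernels
    return_tuple=True)                  # new flag: keep HuggingFace-style tuple outputs

# Setting replace_with_kernel_inject=False instead keeps the original module classes
# and only applies the generic tensor-parallel sharding added in replace_module.py.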
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 2cabc3ec105e..9e60e3583e38 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -2,10 +2,40 @@
 import torch
 import deepspeed
 import deepspeed.ops.transformer as transformer_inference
-from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy
+from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
 from .replace_policy import replace_policies
 from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
 from ..runtime.weight_quantizer import WeightQuantization
+from torch import nn
+
+
+class LinearAllreduce(nn.Module):
+    def __init__(self, weight, bias=None, mp_group=None):
+        super(LinearAllreduce, self).__init__()
+        self.weight = weight
+        self.bias = bias
+        self.mp_group = mp_group
+
+    def forward(self, input):
+        output = torch.matmul(input, self.weight)
+        if self.mp_group is not None:
+            torch.distributed.all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
+class LinearLayer(nn.Module):
+    def __init__(self, weight, bias=None):
+        super(LinearLayer, self).__init__()
+        self.weight = weight
+        self.bias = bias
+
+    def forward(self, input):
+        output = torch.matmul(input, self.weight)
+        if self.bias is not None:
+            output += self.bias
+        return output


 class ReplaceWithTensorSlicing:
@@ -103,13 +133,17 @@ def replace_transformer_layer(orig_layer_impl,
                               training=True,
                               quantize=False,
                               quantize_settings=None,
-                              return_tuple=False):
+                              return_tuple=True,
+                              replace_with_kernel_inject=False,
+                              linear_layer_setting=None):
     """ Replace bert-style transformer layers with DeepSpeed's transformer layer
     Arguments:
         orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
             e.g., transformers.modeling_bert.BertLayer.
         model (torch.nn.Module): user's nn.module representing their model
-        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
+        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
+            replace_with_kernel_inject is set; otherwise, it provides the names of two linear layers as
+            a tuple: (attention_output projection, transformer output projection)
         micro_batch_size (int): micro batch size per gpu used during training/eval
         config (dict): model config containing hidden size, attention heads, etc.
         seed (int): random seed value
@@ -127,7 +161,12 @@ def replace_transformer_layer(orig_layer_impl,
             It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
         return_tuple (bool): if set, transformer layer returns a tuple as the output.
             Note: this flag needs to be set for huggingface models.
-
+        replace_with_kernel_inject (bool): injection mode; if true, kernels will be added along with configuring
+            Tensor-Parallelism
+        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers
+            and embedding layers
+        attention_params (list of strings) [Optional]: shows the parameters in the attention part that need to
+            be adjusted based on the model-parallelism
     Returns:
         Updated nn.module with replaced transformer layers
     """
@@ -299,6 +338,110 @@ def transpose(data):
             new_module.output_b.data = _4hh_b

         return new_module

+    def replace_wo_policy(module, all_reduce_linears):
+        def _replace(child, name, conv_linear_layer):
+            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
+            if name in all_reduce_linears:
+                new_weight = torch.empty(
+                    (child.weight.shape[0]
+                     if conv_linear_layer else child.weight.shape[1] // mp_size,
+                     child.weight.shape[1]
+                     if conv_linear_layer else child.weight.shape[0]),
+                    device=child.weight.device,
+                    dtype=torch.half if fp16 else torch.float)
+                if not conv_linear_layer:
+                    child.weight.data.view(-1).copy_(
+                        child.weight.data.transpose(-1,
+                                                    -2).contiguous().view(-1))
+                    child.weight.data = child.weight.data.reshape(
+                        child.weight.data.shape[-1],
+                        child.weight.data.shape[-2])
+                data = mp_replace.copy(new_weight,
+                                       child.weight.data).to(torch.cuda.current_device())
+                return LinearAllreduce(data, child.bias if child.bias is None else \
+                    child.bias.to(torch.cuda.current_device()), mp_group)
+            else:
+                new_weight = torch.empty(
+                    (child.weight.shape[0] //
+                     mp_size if conv_linear_layer else child.weight.shape[1],
+                     child.weight.shape[1]
+                     if conv_linear_layer else child.weight.shape[0] // mp_size),
+                    device=child.weight.device,
+                    dtype=torch.half if fp16 else torch.float)
+                if not conv_linear_layer:
+                    child.weight.data.view(-1).copy_(
+                        child.weight.data.transpose(-1,
+                                                    -2).contiguous().view(-1))
+                    child.weight.data = child.weight.data.reshape(
+                        child.weight.data.shape[-1],
+                        child.weight.data.shape[-2])
+                data = mp_replace.copy(new_weight, child.weight.data)
+                new_bias = torch.empty((child.weight.shape[1] // mp_size),
+                                       device=child.weight.device,
+                                       dtype=torch.half if fp16 else torch.float)
+                bias_data = None if child.bias is None else mp_replace.copy(
+                    new_bias,
+                    child.bias.data).to(torch.cuda.current_device())
+                return LinearLayer(data.to(torch.cuda.current_device()), bias_data)
+
+        def _slice_embedding(child, name, conv_linear_layer):
+            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
+            new_weight = torch.empty((child.weight.shape[0],
+                                      child.weight.shape[1] // mp_size),
+                                     device=child.weight.device,
+                                     dtype=child.weight.dtype)
+            data = mp_replace.copy(new_weight, child.weight.data)
+            new_embedding = nn.Embedding(child.weight.shape[0],
+                                         child.weight.shape[1] // mp_size)
+            new_embedding.weight.data.copy_(data)
+            return new_embedding
+
+        def update_mp_params(child):
+            if hasattr(child, 'n_heads'):
+                child.n_heads = child.n_heads // mp_size
+            if hasattr(child, 'inner_dim'):
+                child.inner_dim = child.inner_dim // mp_size
+            if hasattr(child, 'num_heads'):
+                child.num_heads = child.num_heads // mp_size
+            if hasattr(child, 'num_attention_heads'):
+                child.num_attention_heads = child.num_attention_heads // mp_size
+            if hasattr(child, 'all_head_size'):
+                child.all_head_size = child.all_head_size // mp_size
+            if hasattr(child, 'embed_dim'):
+                child.embed_dim = child.embed_dim // mp_size
+
+        conv_linear_layer = False
+        if linear_layer_setting is not None:
+            linear_policies = {linear_layer_setting[0]: _replace}
+            if len(linear_layer_setting) == 2:
+                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
+        else:
+            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
+                try:
+                    import transformers
+                    conv_linear_layer = True
+                    linear_policies = {transformers.model_utils.Conv1D: _replace}
+                except ImportError:
+                    linear_policies = {nn.Linear: _replace}
+            else:
+                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}
+
+        def _replace_module(r_module, prev_name=''):
+            for name, child in r_module.named_children():
+                if child.__class__ in linear_policies:
+                    setattr(
+                        r_module,
+                        name,
+                        linear_policies[child.__class__](child,
                                                         prev_name + '.' + name,
+                                                         conv_linear_layer))
+                else:
+                    update_mp_params(child)
+                    _replace_module(child, name)
+            return r_module
+
+        return _replace_module(module)
+
     def replace_fn(child, _policy, layer_id=0):
         if training:
             # copy relevant state from child -> new module
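LinearAllreduce above implements a row-parallel linear: each rank keeps a slice of the weight's input dimension, computes a partial matmul, and the all_reduce in forward sums the partial outputs across the model-parallel group. A minimal single-process sketch of that arithmetic (shapes and names are illustrative assumptions, not part of the patch):

# Row-parallel linear, simulated on one process to show why the all-reduce recovers x @ W.
import torch

mp_size = 2
full_weight = torch.randn(1024, 1024, dtype=torch.float64)  # [in_features, out_features]
x = torch.randn(4, 1024, dtype=torch.float64)                # a batch of activations

# Each rank would hold one slice of the input dimension of the weight
# and the matching slice of the activation's feature dimension...
weight_shards = torch.chunk(full_weight, mp_size, dim=0)
x_shards = torch.chunk(x, mp_size, dim=1)

# ...compute a partial product; summing the partials is what all_reduce does.
partials = [xs @ ws for xs, ws in zip(x_shards, weight_shards)]
summed = torch.stack(partials).sum(dim=0)

assert torch.allclose(summed, x @ full_weight)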
@@ -306,11 +449,15 @@ def replace_fn(child, _policy, layer_id=0):
         else:
             # copy relevant state from child -> new module
-            new_module = replace_with_policy(child,
-                                             _policy,
-                                             inference=True,
-                                             preln=(policy is not HFBertLayerPolicy),
-                                             layer_id=layer_id)
+            if replace_with_kernel_inject:
+                new_module = replace_with_policy(
+                    child,
+                    _policy,
+                    inference=True,
+                    preln=(_policy is not HFBertLayerPolicy),
+                    layer_id=layer_id)
+            else:
+                new_module = replace_wo_policy(child, _policy)

         return new_module
@@ -327,7 +474,6 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
             e.g., transformers.modeling_bert.BertLayer.
         model (torch.nn.Module): user's nn.module representing their model
         config (dict): model config containing hidden size, attention heads, etc.
-
     Returns:
         Updated nn.module with original bert-style transformer layers
     """
@@ -396,7 +542,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
         orig_class (torch.nn.Module): the module to search for
         replace_fn (method): a method to convert instances of ``orig_class`` to the
             desired type and return a new instance.
-
     Returns:
         A modified ``model``.
     """
@@ -422,20 +567,17 @@ def _replace_module(model, policies, layer_id=0):
     Arguments:
         model (torch.nn.Module): model to augment
         policies (dict): Mapping of source class to replacement function.
-
     Returns:
         Modified ``model``.
     """
     for name, child in model.named_children():
         if child.__class__ in policies:
-            orig = repr(child)
             setattr(
                 model,
                 name,
                 policies[child.__class__][0](child,
                                              policies[child.__class__][-1],
                                              layer_id))
-            new = getattr(model, name)
             layer_id += 1
         else:
             _, layer_id = _replace_module(child, policies, layer_id=layer_id)
diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py
index 4f65c121bc8f..07a8debeb85d 100755
--- a/deepspeed/ops/transformer/inference/transformer_inference.py
+++ b/deepspeed/ops/transformer/inference/transformer_inference.py
@@ -620,6 +620,6 @@ def forward(self,
             output = (output, presents)

         if self.config.return_tuple:
-            return (output, )
+            return output if type(output) is tuple else (output, )
         else:
             return output
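Following the updated replace_transformer_layer docstring, when replace_with_kernel_inject stays False the injection policy can simply name the two all-reduce linears of each block. A hedged end-to-end sketch (the T5 checkpoint, the 2-GPU launch, and the projection names are illustrative assumptions, not part of the patch; substitute the projections of your own model):

# Kernel-free tensor-parallel injection, assuming `deepspeed --num_gpus 2 example.py`.
import torch
import deepspeed
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Block

model = T5ForConditionalGeneration.from_pretrained('t5-small')

engine = deepspeed.init_inference(
    model,
    mp_size=2,
    dtype=torch.half,
    # The tuple names the all-reduce linears of each block:
    # (attention output projection, transformer/MLP output projection).
    injection_policy={T5Block: ('SelfAttention.o', 'DenseReluDense.wo')},
    replace_with_kernel_inject=False)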