Commit: Tensor-Parallelism general support (#1512)

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
3 people authored Nov 12, 2021
1 parent b16dd94 commit 9ce00a2

Showing 4 changed files with 177 additions and 24 deletions.

9 changes: 7 additions & 2 deletions deepspeed/__init__.py
@@ -237,7 +237,9 @@ def init_inference(model,
dtype=None,
injection_policy=None,
replace_method='auto',
-                   quantization_setting=None):
+                   quantization_setting=None,
+                   replace_with_kernel_inject=False,
+                   return_tuple=True):
"""Initialize the DeepSpeed InferenceEngine.
Arguments:
@@ -267,6 +269,7 @@ def init_inference(model,
                 of groups used in quantization. A tuple is passed in if we want to use extra grouping
                 for the MLP part of a Transformer layer (e.g. (True, 8) means we quantize the model
                 using 8 groups for the whole network except the MLP part, for which we use 8 extra groups).
+        replace_with_kernel_inject: If set, kernel injection is applied as the inference engine is initialized.
Returns:
A deepspeed.InferenceEngine wrapped model.
@@ -286,7 +289,9 @@ def init_inference(model,
checkpoint,
dtype,
injection_policy,
+                                 return_tuple,
                                  replace_method,
-                                 quantization_setting)
+                                 quantization_setting,
+                                 replace_with_kernel_inject)

return engine
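
A minimal usage sketch of the extended signature (hedged: get_model() is a hypothetical stand-in for loading any torch.nn.Module, and the mp_size value is illustrative, not part of this commit):

import torch
import deepspeed

model = get_model()  # hypothetical helper returning a torch.nn.Module

# With replace_with_kernel_inject=False, tensor-parallelism is applied by
# slicing the model's existing Linear/Embedding modules (the replace_wo_policy
# path added below) rather than swapping in fused inference kernels.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  replace_method='auto',
                                  replace_with_kernel_inject=False,
                                  return_tuple=True)
outputs = engine(inputs)  # inputs: whatever the wrapped model's forward expects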
20 changes: 13 additions & 7 deletions deepspeed/inference/engine.py
@@ -25,7 +25,8 @@ def __init__(self,
injection_dict=None,
return_tuple=True,
replace_method='auto',
-                 quantization_setting=None):
+                 quantization_setting=None,
+                 replace_with_kernel_inject=False):
"""
Args:
model: torch.nn.Module
@@ -74,15 +75,17 @@ def __init__(self,
self.mp_group = self.mpu.get_model_parallel_group()
elif self.mp_world_size > 1:
self._create_model_parallel_group()

# apply injection policy
if self.injection_dict:
for client_module, injection_policy in self.injection_dict.items():
self._apply_injection_policy(client_module,
injection_policy,
-                                             return_tuple)
-        elif replace_method == "auto":
-            self._apply_injection_policy()
+                                             return_tuple,
+                                             replace_with_kernel_inject)
+        elif replace_method == 'auto':
+            self._apply_injection_policy(
+                return_tuple=return_tuple,
+                replace_with_kernel_inject=replace_with_kernel_inject)

device = torch.cuda.current_device()
logger.info(f"Place model to device: {device}")
@@ -152,7 +155,9 @@ def _validate_args(self, mpu):
def _apply_injection_policy(self,
client_module=None,
injection_policy=None,
-                                return_tuple=True):
+                                return_tuple=True,
+                                replace_with_kernel_inject=False):

replace_transformer_layer(client_module,
self.module,
policy=injection_policy,
@@ -166,7 +171,8 @@ def _apply_injection_policy(self,
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
-                                                    self.quantize_groups))
+                                                    self.quantize_groups),
+                                  replace_with_kernel_inject=replace_with_kernel_inject)

def _load_checkpoint(self, load_dir, load_module_strict=True):
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
170 changes: 156 additions & 14 deletions deepspeed/module_inject/replace_module.py
@@ -2,10 +2,40 @@
import torch
import deepspeed
import deepspeed.ops.transformer as transformer_inference
-from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy
+from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
from .replace_policy import replace_policies
from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
from ..runtime.weight_quantizer import WeightQuantization
from torch import nn


class LinearAllreduce(nn.Module):
def __init__(self, weight, bias=None, mp_group=None):
super(LinearAllreduce, self).__init__()
self.weight = weight
self.bias = bias
self.mp_group = mp_group

def forward(self, input):
output = torch.matmul(input, self.weight)
if self.mp_group is not None:
torch.distributed.all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output


class LinearLayer(nn.Module):
def __init__(self, weight, bias=None):
super(LinearLayer, self).__init__()
self.weight = weight
self.bias = bias

def forward(self, input):
output = torch.matmul(input, self.weight)
if self.bias is not None:
output += self.bias
return output
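
A quick sanity check of the math behind LinearAllreduce (a self-contained sketch; the shapes and the two-rank split are illustrative and simulated on a single device):

import torch

x = torch.randn(4, 8)        # [batch, in_features]
w = torch.randn(8, 6)        # [in_features, out_features], the layout LinearAllreduce matmuls against
partial0 = x[:, :4] @ w[:4]  # rank 0 holds the first half of in_features
partial1 = x[:, 4:] @ w[4:]  # rank 1 holds the second half
# all_reduce(SUM) over the partials reproduces the unsliced product, which is
# why LinearAllreduce adds the bias only after the reduction:
assert torch.allclose(partial0 + partial1, x @ w, atol=1e-4)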


class ReplaceWithTensorSlicing:
@@ -103,13 +133,17 @@ def replace_transformer_layer(orig_layer_impl,
training=True,
quantize=False,
quantize_settings=None,
-                              return_tuple=False):
+                              return_tuple=True,
+                              replace_with_kernel_inject=False,
+                              linear_layer_setting=None):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
-        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
+        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
+            replace_with_kernel_inject is set; otherwise, it provides the names of two linear layers as
+            a tuple: (attention output projection, transformer output projection)
micro_batch_size (int): micro batch size per gpu used during training/eval
config (dict): model config containing hidden size, attention heads, etc.
seed (int): random seed value
@@ -127,7 +161,12 @@ def replace_transformer_layer(orig_layer_impl,
It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
return_tuple (bool): if set, transformer layer returns a tuple as the output.
Note: this flag needs to be set for huggingface models.
        replace_with_kernel_inject (bool): injection mode; if true, kernels will be added along with
            configuring Tensor-Parallelism
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear
            layers and embedding layers
        attention_params (list of strings) [Optional]: shows the parameters in the attention part that
            need to be adjusted based on the model-parallelism
Returns:
Updated nn.module with replaced transformer layers
"""
@@ -299,18 +338,126 @@ def transpose(data):
new_module.output_b.data = _4hh_b
return new_module

def replace_wo_policy(module, all_reduce_linears):
def _replace(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
if name in all_reduce_linears:
new_weight = torch.empty(
(child.weight.shape[0]
if conv_linear_layer else child.weight.shape[1] // mp_size,
child.weight.shape[1]
if conv_linear_layer else child.weight.shape[0]),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
if not conv_linear_layer:
child.weight.data.view(-1).copy_(
child.weight.data.transpose(-1,
-2).contiguous().view(-1))
child.weight.data = child.weight.data.reshape(
child.weight.data.shape[-1],
child.weight.data.shape[-2])
data = mp_replace.copy(new_weight,
child.weight.data).to(torch.cuda.current_device())
return LinearAllreduce(data, child.bias if child.bias is None else \
child.bias.to(torch.cuda.current_device()), mp_group)
else:
new_weight = torch.empty(
(child.weight.shape[0] //
mp_size if conv_linear_layer else child.weight.shape[1],
child.weight.shape[1]
if conv_linear_layer else child.weight.shape[0] // mp_size),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
if not conv_linear_layer:
child.weight.data.view(-1).copy_(
child.weight.data.transpose(-1,
-2).contiguous().view(-1))
child.weight.data = child.weight.data.reshape(
child.weight.data.shape[-1],
child.weight.data.shape[-2])
data = mp_replace.copy(new_weight, child.weight.data)
new_bias = torch.empty((child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=torch.half if fp16 else torch.float)
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(torch.cuda.current_device())
return LinearLayer(data.to(torch.cuda.current_device()), bias_data)

def _slice_embedding(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0],
child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
data = mp_replace.copy(new_weight, child.weight.data)
new_embedding = nn.Embedding(child.weight.shape[0],
child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
return new_embedding

def update_mp_params(child):
if hasattr(child, 'n_heads'):
child.n_heads = child.n_heads // mp_size
if hasattr(child, 'inner_dim'):
child.inner_dim = child.inner_dim // mp_size
if hasattr(child, 'num_heads'):
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'all_head_size'):
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
child.embed_dim = child.embed_dim // mp_size
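        # Illustrative effect (assumed attribute values, mp_size = 2): a module
        # with num_attention_heads = 16 and all_head_size = 1024 ends up
        # advertising 8 heads and all_head_size = 512 per rank, keeping its
        # attention reshape logic consistent with the sliced qkv weights.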

conv_linear_layer = False
if linear_layer_setting is not None:
linear_policies = {linear_layer_setting[0]: _replace}
if len(linear_layer_setting) == 2:
linear_policies.update({linear_layer_setting[1]: _slice_embedding})
else:
if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
try:
import transformers
conv_linear_layer = True
                linear_policies = {transformers.modeling_utils.Conv1D: _replace}
except ImportError:
linear_policies = {nn.Linear: _replace}
else:
linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
if child.__class__ in linear_policies:
setattr(
r_module,
name,
linear_policies[child.__class__](child,
prev_name + '.' + name,
conv_linear_layer))
else:
update_mp_params(child)
_replace_module(child, name)
return r_module

return _replace_module(module)
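
To summarize the conventions above (a shape-only sketch for an nn.Linear child with mp_size = 2; _replace first re-lays the weight out as [in, out] to match the input @ weight matmul in LinearAllreduce/LinearLayer):

    # all-reduce linears (row-parallel):   [in, out] -> [in // 2, out]
    #   each rank computes a partial product over its half of the input
    #   features; LinearAllreduce sums the partials, then adds the full bias.
    # remaining linears (column-parallel): [in, out] -> [in, out // 2]
    #   each rank produces half the output features; the bias is sliced too.
    # embeddings:                          [vocab, hidden] -> [vocab, hidden // 2]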

def replace_fn(child, _policy, layer_id=0):
if training:
# copy relevant state from child -> new module
new_module = replace_with_policy(child, _policy, preln=preln)

else:
# copy relevant state from child -> new module
-            new_module = replace_with_policy(child,
-                                             _policy,
-                                             inference=True,
-                                             preln=(policy is not HFBertLayerPolicy),
-                                             layer_id=layer_id)
+            if replace_with_kernel_inject:
+                new_module = replace_with_policy(
+                    child,
+                    _policy,
+                    inference=True,
+                    preln=(_policy is not HFBertLayerPolicy),
+                    layer_id=layer_id)
+            else:
+                new_module = replace_wo_policy(child, _policy)

return new_module

@@ -327,7 +474,6 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
config (dict): model config containing hidden size, attention heads, etc.
Returns:
Updated nn.module with original bert-style transformer layers
"""
@@ -396,7 +542,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
orig_class (torch.nn.Module): the module to search for
replace_fn (method): a method to convert instances of ``orig_class`` to the
desired type and return a new instance.
Returns:
A modified ``model``.
"""
@@ -422,20 +567,17 @@ def _replace_module(model, policies, layer_id=0):
Arguments:
model (torch.nn.Module): model to augment
policies (dict): Mapping of source class to replacement function.
Returns:
Modified ``model``.
"""
for name, child in model.named_children():
if child.__class__ in policies:
-            orig = repr(child)
setattr(
model,
name,
policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id))
-            new = getattr(model, name)
layer_id += 1
else:
_, layer_id = _replace_module(child, policies, layer_id=layer_id)
@@ -620,6 +620,6 @@ def forward(self,
output = (output, presents)

if self.config.return_tuple:
-            return (output, )
+            return output if type(output) is tuple else (output, )
else:
return output
