Tensor-Parallelism general support #1512
@@ -2,10 +2,40 @@
import torch
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy
from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
from .replace_policy import replace_policies
from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
from ..runtime.weight_quantizer import WeightQuantization
from torch import nn


class LinearAllreduce(nn.Module):
    def __init__(self, weight, bias=None, mp_group=None):
        super(LinearAllreduce, self).__init__()
        self.weight = weight
        self.bias = bias
        self.mp_group = mp_group

    def forward(self, input):
        output = torch.matmul(input, self.weight)
        if self.mp_group is not None:
            torch.distributed.all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output += self.bias
        return output


class LinearLayer(nn.Module):
    def __init__(self, weight, bias=None):
        super(LinearLayer, self).__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, input):
        output = torch.matmul(input, self.weight)
        if self.bias is not None:
            output += self.bias
        return output


class ReplaceWithTensorSlicing:
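The two modules added above encode the usual tensor-parallel split: LinearLayer holds a weight shard sliced along the output dimension and needs no communication, while LinearAllreduce holds a shard sliced along the input dimension, so its partial result has to be summed across the model-parallel group. A minimal sketch of the LinearAllreduce forward contract, run single-process with mp_group=None so no collective is issued; the shapes here are illustrative and not taken from the PR:

```python
import torch

# Pretend mp_size == 2 and this rank owns half of a 16 -> 8 projection.
shard_in, out_features = 16 // 2, 8
w_shard = torch.randn(shard_in, out_features)   # laid out for torch.matmul(input, weight)
b = torch.zeros(out_features)

row_parallel = LinearAllreduce(w_shard, bias=b, mp_group=None)
x = torch.randn(4, shard_in)        # this rank's slice of the input features
y = row_parallel(x)                 # [4, 8]; with a real mp_group each rank's partial sum
print(y.shape)                      # would be combined by torch.distributed.all_reduce
```

In the replacement path further down, replace_wo_policy builds these shards through ReplaceWithTensorSlicing and moves them to the current CUDA device; the snippet only shows what the wrapper modules compute.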
@@ -103,13 +133,17 @@ def replace_transformer_layer(orig_layer_impl,
                              training=True,
                              quantize=False,
                              quantize_settings=None,
                              return_tuple=False):
                              return_tuple=True,
                              replace_with_kernel_inject=False,
                              linear_layer_setting=None):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
            replace_with_kernel_inject is set; otherwise, it provides the names of the two linear layers as
            a tuple: (attention_output projection, transformer output projection)
        micro_batch_size (int): micro batch size per gpu used during training/eval
        config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value

@@ -127,7 +161,12 @@ def replace_transformer_layer(orig_layer_impl,
            It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
        return_tuple (bool): if set, the transformer layer returns a tuple as the output.
            Note: this flag needs to be set for huggingface models.
        replace_with_kernel_inject (bool): injection mode; if true, kernels will be added along with
            configuring Tensor-Parallelism
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear
            layers and embedding layers
        attention_params (list of strings) [Optional]: shows the parameters in the attention part that need
            to be adjusted based on the model-parallelism
    Returns:
        Updated nn.module with replaced transformer layers
    """
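Taken together, the new arguments give the entry point two modes: with replace_with_kernel_inject=True the existing policy-driven kernel injection runs, and with it left False only the generic tensor-parallel module replacement (replace_wo_policy below) is applied. A hedged sketch of the second mode follows; the import location, the GPT2Block layer class, the module names in the policy tuple, and the config handling are illustrative assumptions rather than values fixed by this PR, and the real keyword set is longer than shown:

```python
import transformers
from transformers.models.gpt2.modeling_gpt2 import GPT2Block   # layer class assumed for illustration
from deepspeed.module_inject import replace_transformer_layer  # import path assumed

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

model = replace_transformer_layer(
    orig_layer_impl=GPT2Block,             # the transformer layer class to look for
    model=model,
    # Without kernel injection, `policy` is just the pair of linear layers whose
    # outputs need an all-reduce: (attention output projection, transformer output projection).
    policy=("attn.c_proj", "mlp.c_proj"),  # hypothetical names; depends on how names are matched
    micro_batch_size=1,
    config=model.config,
    seed=42,
    training=False,
    return_tuple=True,
    replace_with_kernel_inject=False,
)
```

In this mode the heavy lifting is done by replace_wo_policy in the hunk below, which slices nn.Linear / Conv1D weights and wraps them in the LinearLayer / LinearAllreduce modules defined earlier.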
@@ -299,18 +338,126 @@ def transpose(data):
            new_module.output_b.data = _4hh_b
        return new_module

    def replace_wo_policy(module, all_reduce_linears):
        def _replace(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            if name in all_reduce_linears:
                new_weight = torch.empty(
                    (child.weight.shape[0]
                     if conv_linear_layer else child.weight.shape[1] // mp_size,
                     child.weight.shape[1]
                     if conv_linear_layer else child.weight.shape[0]),
                    device=child.weight.device,
                    dtype=torch.half if fp16 else torch.float)
                if not conv_linear_layer:
                    child.weight.data.view(-1).copy_(
                        child.weight.data.transpose(-1,
                                                    -2).contiguous().view(-1))
                    child.weight.data = child.weight.data.reshape(
                        child.weight.data.shape[-1],
                        child.weight.data.shape[-2])
                data = mp_replace.copy(new_weight,
                                       child.weight.data).to(torch.cuda.current_device())
                return LinearAllreduce(data, child.bias if child.bias is None else \
                            child.bias.to(torch.cuda.current_device()), mp_group)
            else:
                new_weight = torch.empty(
                    (child.weight.shape[0] //
                     mp_size if conv_linear_layer else child.weight.shape[1],
                     child.weight.shape[1]
                     if conv_linear_layer else child.weight.shape[0] // mp_size),
                    device=child.weight.device,
                    dtype=torch.half if fp16 else torch.float)
                if not conv_linear_layer:
                    child.weight.data.view(-1).copy_(
                        child.weight.data.transpose(-1,
                                                    -2).contiguous().view(-1))
                    child.weight.data = child.weight.data.reshape(
                        child.weight.data.shape[-1],
                        child.weight.data.shape[-2])
                data = mp_replace.copy(new_weight, child.weight.data)
                new_bias = torch.empty((child.weight.shape[1] // mp_size),
                                       device=child.weight.device,
                                       dtype=torch.half if fp16 else torch.float)
                bias_data = None if child.bias is None else mp_replace.copy(
                    new_bias,
                    child.bias.data).to(torch.cuda.current_device())
                return LinearLayer(data.to(torch.cuda.current_device()), bias_data)

        def _slice_embedding(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            new_weight = torch.empty((child.weight.shape[0],
                                      child.weight.shape[1] // mp_size),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            data = mp_replace.copy(new_weight, child.weight.data)
            new_embedding = nn.Embedding(child.weight.shape[0],
                                         child.weight.shape[1] // mp_size)
            new_embedding.weight.data.copy_(data)
            return new_embedding

        def update_mp_params(child):
[Inline review thread on update_mp_params]
"@RezaYazdaniAminabadi @stas00 This code works well for a few cases, but I don't think it's a good structure to scale to 70 models. Is there a more efficient way?"
"cc @jaketae do you have any nice ideas for this?"
"It is true that this part needs some refactoring. Please let me know if you have some ideas."
"I left a comment on the issue; please take note."
            if hasattr(child, 'n_heads'):
                child.n_heads = child.n_heads // mp_size
            if hasattr(child, 'inner_dim'):
                child.inner_dim = child.inner_dim // mp_size
            if hasattr(child, 'num_heads'):
                child.num_heads = child.num_heads // mp_size
            if hasattr(child, 'num_attention_heads'):
                child.num_attention_heads = child.num_attention_heads // mp_size
            if hasattr(child, 'all_head_size'):
                child.all_head_size = child.all_head_size // mp_size
            if hasattr(child, 'embed_dim'):
                child.embed_dim = child.embed_dim // mp_size

        conv_linear_layer = False
        if linear_layer_setting is not None:
            linear_policies = {linear_layer_setting[0]: _replace}
            if len(linear_layer_setting) == 2:
                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
        else:
            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
[Inline review thread on the HFGPT2LayerPolicy check]
"Rather than doing this, how about loading that module object and checking that it is a"
"That's a good point. Thanks @hyunwoongko"
                try:
                    import transformers
                    conv_linear_layer = True
                    linear_policies = {transformers.modeling_utils.Conv1D: _replace}
[Inline review thread on the Conv1D branch]
"Why did you set GPT2 not to slice embeddings?"
"Any special reason?"
"Please tell me if I misunderstood."
"That embedding is not part of the layer, but of the model. What I am slicing here are the transformer layers; that embedding is just a small part of the model."
                except ImportError:
                    linear_policies = {nn.Linear: _replace}
            else:
                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}
[Inline review thread on the nn.Embedding policy]
"How do you differentiate between positional embedding and token embedding?"
"I used"

        def _replace_module(r_module, prev_name=''):
            for name, child in r_module.named_children():
                if child.__class__ in linear_policies:
                    setattr(
                        r_module,
                        name,
                        linear_policies[child.__class__](child,
                                                         prev_name + '.' + name,
                                                         conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)

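As a quick sanity check on _replace above: for a plain nn.Linear (conv_linear_layer False) the weight is first transposed in place, then either its input dimension is sliced and wrapped in LinearAllreduce (the all-reduce case) or its output dimension and bias are sliced and wrapped in LinearLayer. The snippet below only illustrates which dimension each branch ends up splitting for rank 0, assuming mp_size=2; it bypasses ReplaceWithTensorSlicing entirely:

```python
import torch

mp_size = 2
out_f, in_f = 1024, 4096
w = torch.randn(out_f, in_f)            # nn.Linear stores weight as [out_features, in_features]
w_t = w.t().contiguous()                # _replace works on the transposed [in_f, out_f] layout

# Branch 1: name in all_reduce_linears -> slice the input dimension (row parallel).
row_shard = w_t[:in_f // mp_size, :]    # [2048, 1024]; outputs are partial sums -> LinearAllreduce
# Branch 2: everything else -> slice the output dimension plus the bias (column parallel).
col_shard = w_t[:, :out_f // mp_size]   # [4096, 512] with a [512] bias shard -> LinearLayer

print(row_shard.shape, col_shard.shape)
```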
    def replace_fn(child, _policy, layer_id=0):
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, preln=preln)

        else:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child,
                                             _policy,
                                             inference=True,
                                             preln=(policy is not HFBertLayerPolicy),
                                             layer_id=layer_id)
            if replace_with_kernel_inject:
                new_module = replace_with_policy(
                    child,
                    _policy,
                    inference=True,
                    preln=(_policy is not HFBertLayerPolicy),
                    layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)

        return new_module

@@ -327,7 +474,6 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.

    Returns:
        Updated nn.module with original bert-style transformer layers
    """

@@ -396,7 +542,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
@@ -422,20 +567,17 @@ def _replace_module(model, policies, layer_id=0):
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            orig = repr(child)
            setattr(
                model,
                name,
                policies[child.__class__][0](child,
                                             policies[child.__class__][-1],
                                             layer_id))
            new = getattr(model, name)
            layer_id += 1
        else:
            _, layer_id = _replace_module(child, policies, layer_id=layer_id)

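The loop above is the same recursive pattern used by the new replace_wo_policy: walk named_children, and when a child's class has an entry in the policy dict, setattr the replacement onto the parent, otherwise recurse. A stripped-down, DeepSpeed-independent illustration of just that mechanism (the identity replacement is purely for demonstration):

```python
import torch.nn as nn

def swap_children(module, policies):
    # Recursively replace any child whose class appears in `policies`.
    for name, child in module.named_children():
        if child.__class__ in policies:
            setattr(module, name, policies[child.__class__](child))
        else:
            swap_children(child, policies)
    return module

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
swap_children(model, {nn.Linear: lambda child: nn.Identity()})
print(model)  # both nn.Linear children are now nn.Identity
```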
Is the strategy of this method to apply a column slice by default, and to apply a row slice when the names of specific layers are passed in?

Can't we automate this a bit more? It would be nice to have a strategy that doesn't require parameter names at all.

@RezaYazdaniAminabadi You have probably thought about this more than I have while building it. I'm curious about your opinion.
I've been thinking about it briefly; what about a profiling strategy? Replace the forward function of each Linear or Conv1D module with a profiling_forward function that measures how long the layer's forward pass takes, and add a get_first_forwarded_time function that returns the time of the first forward. Once that value exists, profiling_forward no longer measures time, and the slicing decision could then be based on get_first_forwarded_time.
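A rough sketch of what that proposal could look like; profiling_forward and get_first_forwarded_time are the names used in the comment above, while the wrapping mechanics and the timer are assumptions made purely for illustration:

```python
import time
import torch
import torch.nn as nn

def attach_profiling_forward(module):
    # Wrap module.forward so that only the first call records its wall-clock duration.
    original_forward = module.forward
    module._first_forward_time = None

    def profiling_forward(*args, **kwargs):
        if module._first_forward_time is None:
            start = time.perf_counter()
            out = original_forward(*args, **kwargs)
            module._first_forward_time = time.perf_counter() - start
            return out
        return original_forward(*args, **kwargs)

    module.forward = profiling_forward
    return module

def get_first_forwarded_time(module):
    return module._first_forward_time

layer = attach_profiling_forward(nn.Linear(8, 8))
layer(torch.randn(2, 8))
print(get_first_forwarded_time(layer))
```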
It seems more flexible to just use torch.fx than this approach. I'll start automating the whole process of tensor & pipeline parallelization using torch.fx.

I'm missing the full context. Do you suggest having a policy record for each model, like in the example you have shown here: #1512 (comment)?
It'd help to see several full examples; then it's much easier to see how they can be integrated.
For example, I started integrating Deepspeed-Inference in huggingface/transformers#14426 after studying a few examples here: microsoft/DeepSpeedExamples#144, so I can see what's common, what's unique, and which code sections are the driver and need to go into the Trainer loop.
Monkey-see, monkey-do style is the easiest, without needing to figure out all the low-level details.
Does that make sense?

Yes, please.
As I reported to you originally, it didn't appear that the different OSLO components can be integrated separately; they require all the other OSLO components to work.
Deepspeed-Inference, on the other hand, I can integrate into the HF Trainer relatively easily, since it doesn't require anything other than wrapping the model; we just need to figure out a few MPU quirks. With OSLO I have no idea how to do it, because what I tried didn't work.
But let's not derail this PR; let's discuss OSLO either in the OSLO or the HF Transformers issues. This PR is about Deepspeed-Inference.

This has nothing to do with DeepSpeed, so let's talk about it in the transformers issue:
huggingface/transformers#13690

I think it is still not so easy to figure out which layers should use all-reduce, as it can depend on the architecture. But I may be missing something here. Maybe we can have an offline chat about this? Thanks

@RezaYazdaniAminabadi Yes, an offline chat would be better. When would work for you?