Add tensor model parallel inference and training with GPT-Neo model
hyunwoongko committed Aug 30, 2021
1 parent 4046e66 commit 5bf8655
Showing 7 changed files with 1,803 additions and 1 deletion.
88 changes: 87 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -33,6 +33,13 @@
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...parallelization_utils import (
    ColumnParallelLinear,
    Layer,
    ParallelizationMixin,
    ParallelPolicy,
    RowParallelLinear,
)
from ...utils import logging
from .configuration_gpt_neo import GPTNeoConfig
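
For orientation, the replacement layer types imported above presumably follow the usual Megatron-LM convention: a column-parallel linear shards its weight along the output dimension, while a row-parallel linear shards along the input dimension and reduces partial results across ranks. The snippet below is a conceptual sketch of that split with plain PyTorch tensors, not the parallelization_utils implementation; the sizes and rank are illustrative.

# Conceptual sketch only: not the parallelization_utils implementation.
import torch

hidden_size, world_size, rank = 2048, 2, 0
weight = torch.randn(hidden_size, hidden_size)  # nn.Linear weight layout: (out_features, in_features)

# Column-parallel: shard the output dimension; each rank computes a slice of the outputs.
col_shard = weight.chunk(world_size, dim=0)[rank]  # shape (1024, 2048)

# Row-parallel: shard the input dimension; partial outputs are summed across ranks.
row_shard = weight.chunk(world_size, dim=1)[rank]  # shape (2048, 1024)

print(col_shard.shape, row_shard.shape)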

@@ -579,7 +586,7 @@ def forward(
        return outputs  # hidden_states, present, (attentions, cross_attentions)


class GPTNeoPreTrainedModel(PreTrainedModel):
class GPTNeoPreTrainedModel(PreTrainedModel, ParallelizationMixin):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
@@ -608,6 +615,24 @@ def _init_weights(self, module):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parallelize(
        self,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
    ):
        """
        Parallelize the model with the given model parallel sizes.

        Args:
            tensor_model_parallel_size (int): number of tensor model parallel partitions
            pipeline_model_parallel_size (int): number of pipeline model parallel partitions
        """
        self._parallelize(
            policies=[GPTNeoParallelPolicy],
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
        )
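
A minimal usage sketch for the parallelize method above, assuming a stock GPT-Neo checkpoint and two tensor-parallel ranks. The launch requirements (one process per GPU, an initialized process group, and so on) are not shown in this diff and are assumed here.

# Usage sketch under the assumptions above; run with a distributed launcher, e.g. one process per GPU.
from transformers import GPTNeoForCausalLM

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
model.parallelize(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)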


GPT_NEO_START_DOCSTRING = r"""
@@ -1144,3 +1169,64 @@ def forward(
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


class GPTNeoParallelPolicy(ParallelPolicy):
    @staticmethod
    def replace_arguments(config, world_size):
        return {
            # 1. reduce hidden size
            "attn.attention.embed_dim": config.hidden_size // world_size,
            # 2. reduce number of heads
            "attn.attention.num_heads": config.num_heads // world_size,
        }

    @staticmethod
    def attn_qkv():
        return [
            Layer(
                weight="attn.attention.q_proj.weight",
                replace=ColumnParallelLinear,
            ),
            Layer(
                weight="attn.attention.k_proj.weight",
                replace=ColumnParallelLinear,
            ),
            Layer(
                weight="attn.attention.v_proj.weight",
                replace=ColumnParallelLinear,
            ),
        ]

    @staticmethod
    def attn_out():
        return [
            Layer(
                weight="attn.attention.out_proj.weight",
                replace=RowParallelLinear,
            ),
        ]

    @staticmethod
    def mlp_in():
        return [
            Layer(
                weight="mlp.c_fc.weight",
                bias="mlp.c_fc.bias",
                replace=ColumnParallelLinear,
            ),
        ]

    @staticmethod
    def mlp_out():
        return [
            Layer(
                weight="mlp.c_proj.weight",
                bias="mlp.c_proj.bias",
                replace=RowParallelLinear,
            ),
        ]

    @staticmethod
    def original_layer_class():
        return GPTNeoBlock
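
To make the sharding arithmetic concrete, the sketch below feeds the public GPT-Neo 1.3B configuration (hidden_size 2048, 16 attention heads) and two tensor-parallel ranks into replace_arguments from the policy above: each rank keeps half of the attention embedding dimension and half of the heads, with the q/k/v projections split column-wise and out_proj split row-wise. This is an illustration run alongside the module, not part of the commit.

# Illustration of the per-rank attention sizes produced by the policy above.
from transformers import GPTNeoConfig

config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B")  # hidden_size=2048, num_heads=16
world_size = 2  # tensor model parallel size

print(GPTNeoParallelPolicy.replace_arguments(config, world_size))
# -> {'attn.attention.embed_dim': 1024, 'attn.attention.num_heads': 8}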