Simplify Tensor Parallel implementation with PyTorch TP #34184

Merged (26 commits), Nov 18, 2024
Changes shown from 2 commits

Commits
e60fb87  Simplify Tensor Parallel implementation with PyTorch TP (kwen2501, Oct 15, 2024)
fd7f7c7  Move tp_plan to config (kwen2501, Oct 23, 2024)
9224cab  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 30, 2024)
79cc524  Lint (kwen2501, Oct 30, 2024)
a2934b3  Format and warning (kwen2501, Oct 30, 2024)
a8fc418  Disable copy-from check (kwen2501, Oct 30, 2024)
e84a388  Conditionally get attr from config (kwen2501, Oct 31, 2024)
396d158  make fix-copies (kwen2501, Oct 31, 2024)
7b346b5  Move base_model_tp_plan to PretrainedConfig (kwen2501, Oct 31, 2024)
d60679b  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Oct 31, 2024)
dda058a  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 1, 2024)
12fbbe7  Move TP into from_pretrained (kwen2501, Nov 7, 2024)
02c8c39  Add device context for load (kwen2501, Nov 7, 2024)
073c521  Do not serialize (kwen2501, Nov 7, 2024)
db6e5ee  Move _tp_plan setting to post_init (kwen2501, Nov 7, 2024)
5bb294e  Add has_tp_plan (kwen2501, Nov 14, 2024)
290a7f1  Add test_tp (kwen2501, Nov 15, 2024)
bd2e89c  Add 'Multi-gpu inference' doc (kwen2501, Nov 15, 2024)
4892cef  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 15, 2024)
9648f31  Add backward support for device type identification (kwen2501, Nov 15, 2024)
93ba283  Auto-detect accelerator (kwen2501, Nov 16, 2024)
73524c9  supports_tp_plan (kwen2501, Nov 16, 2024)
f312e55  copyright year (kwen2501, Nov 16, 2024)
ca93bdb  Merge remote-tracking branch 'origin/main' into tp_llama (kwen2501, Nov 17, 2024)
dc2672f  Merge branch 'main' into tp_llama (kwen2501, Nov 18, 2024)
1e27d6f  Fix copy (kwen2501, Nov 18, 2024)
37 changes: 37 additions & 0 deletions src/transformers/modeling_utils.py
@@ -54,6 +54,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -4979,6 +4980,42 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.

Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan:
logger.debug(
f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}"
)
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
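For context, here is a minimal usage sketch of the new `tensor_parallel()` entry point as it stands at this commit (later commits in this PR move the call into `from_pretrained`). The checkpoint name and mesh shape are illustrative, and the script assumes a `torchrun` launch:

```python
# Minimal sketch, assuming `torchrun --nproc-per-node 4 run_tp.py` on one node.
# The checkpoint name is illustrative; any model whose config carries a TP plan works.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh over all ranks

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Shard submodules in place according to the _tp_plan attributes set from the config.
model.tensor_parallel(device_mesh)

inputs = tokenizer("Tensor parallelism lets us", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)
```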
10 changes: 10 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `LlamaModel`
_base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

Review thread on `_base_model_tp_plan`:

Comment: Do we want to allow this for external use by removing the leading underscore, so that users can define TP plan tweaks from config.json? Given that, shall we also allow providing a custom TP plan as input to LlamaConfig() that overrides the default?

Contributor Author (kwen2501): Yeah, good idea. We can make this public once we prove things work.

Collaborator: Yep, base_model_tp_plan should be supported as input to the PretrainedConfig!

Comment: So this variable, base_model_tp_plan, has to be added to PretrainedConfig (class PretrainedConfig(PushToHubMixin)) with a default value of an empty dict {}, which I believe is the best possible default for any config subclass inheriting from it.

Contributor Author (kwen2501): Thanks @kmehant @ArthurZucker for the suggestion. I moved base_model_tp_plan to PretrainedConfig in the latest commit.

def __init__(
self,
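To make the wildcard keys concrete, the sketch below expands the default plan against a model's named submodules with `fnmatch`. This is illustrative only: the merged code hands the plan directly to `parallelize_module`, and the tiny config sizes are made up to keep the example lightweight.

```python
# Illustrative sketch: expand the wildcard keys of _base_model_tp_plan into the
# concrete submodule names they refer to. Config sizes are arbitrary small values.
import fnmatch

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=128,
)
model = LlamaForCausalLM(config)

plan = config._base_model_tp_plan  # {"layers.*.self_attn.q_proj": "colwise", ...}

# Plan keys are relative to the base LlamaModel, i.e. model.model here.
expanded = {
    name: style
    for name, _ in model.model.named_modules()
    for pattern, style in plan.items()
    if fnmatch.fnmatch(name, pattern)
}

# e.g. expanded["layers.0.self_attn.q_proj"] == "colwise"
#      expanded["layers.1.mlp.down_proj"] == "rowwise"
print(expanded)
```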
78 changes: 18 additions & 60 deletions src/transformers/models/llama/modeling_llama.py
@@ -237,25 +237,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj


@@ -317,31 +299,14 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
@@ -383,12 +348,7 @@ def forward(

attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
@@ -559,9 +519,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
@@ -845,8 +806,9 @@ def __init__(self, config: LlamaConfig):
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False

self.gradient_checkpointing = False
self._tp_plan = config._base_model_tp_plan
# Initialize weights and apply final processing
self.post_init()

@@ -1102,6 +1064,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
@@ -1198,13 +1161,8 @@ def forward(
)

hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:
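The `view(bsz, q_len, -1, self.head_dim)` change above is what lets the same attention code run sharded or unsharded: once `q_proj` is column-sharded over a mesh of size N, the local projection only produces `num_heads / N` heads, so the head count has to be inferred. A standalone shape check with illustrative sizes:

```python
# Standalone shape check for the view(bsz, q_len, -1, head_dim) change.
# Sizes are illustrative: 32 heads of dim 128, column-sharded over 4 ranks.
import torch

bsz, q_len, num_heads, head_dim, tp_size = 2, 16, 32, 128, 4

# Unsharded: the q_proj output carries all heads.
q_full = torch.randn(bsz, q_len, num_heads * head_dim)
assert q_full.view(bsz, q_len, -1, head_dim).shape[2] == num_heads

# Column-sharded q_proj: each rank produces only num_heads // tp_size heads,
# so a hard-coded view(bsz, q_len, num_heads, head_dim) would fail here.
q_local = torch.randn(bsz, q_len, (num_heads // tp_size) * head_dim)
assert q_local.view(bsz, q_len, -1, head_dim).shape[2] == num_heads // tp_size
```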
26 changes: 26 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -20,6 +20,11 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

@@ -326,3 +331,24 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
return torch.isin(elements, test_elements)


def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str"
)

if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
Review thread on `translate_to_torch_parallel_style`:

Collaborator: A mapping from TP style to the correct function might be better. We are also thinking of defining a TensorParallelConfig; your feedback is welcome here, as I don't know the variety of classes / args that might be used!

Contributor Author (kwen2501): Thanks for the comment! Indeed a mapping style would look better. The only caveat is that the returned value here is an object rather than a constant or class (see the () behind ColwiseParallel). If we prepare a map, we may always be returning the same object -- the parallelize_module API may be able to support that, I guess; I am just not sure if there is a contract guaranteeing it today.

Comment: Pitching in :) We should be able to use the same object, since it applies the required parallel operation to the module and returns a new copy -- https://github.com/pytorch/pytorch/blob/86d4b7d60b264cae5a04a1b20719bcd7a5752a4c/torch/distributed/tensor/parallel/api.py#L95. Have also tested it empirically while benchmarking (#34194). Thanks!

Collaborator: Sounds good!
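For reference, a mapping-based variant along the lines discussed above might look like the sketch below. This is not the merged implementation; it pre-builds one style object per string and assumes, per the discussion, that parallelize_module can safely receive the same object for multiple modules.

```python
# Sketch of the mapping-based alternative (not the merged implementation).
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

_TP_STYLE_MAP = {
    "colwise": ColwiseParallel(),
    "rowwise": RowwiseParallel(),
    "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
}


def translate_to_torch_parallel_style(style: str):
    """Map a neutral string style from a model config to a torch TP style object."""
    try:
        return _TP_STYLE_MAP[style]
    except KeyError:
        raise ValueError(f"Unsupported parallel style value: {style}") from None
```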