Move tp_plan to config
kwen2501 committed Oct 23, 2024
1 parent e60fb87 commit fd7f7c7
Showing 4 changed files with 51 additions and 22 deletions.
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
@@ -54,6 +54,7 @@
    prune_conv1d_layer,
    prune_layer,
    prune_linear_layer,
    translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -4994,6 +4995,17 @@ def tensor_parallel(self, device_mesh):
        def tplize(mod: torch.nn.Module) -> None:
            tp_plan = getattr(mod, "_tp_plan", None)
            if tp_plan:
                logger.debug(
                    f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}"
                )
                # In model configs, we use a neutral type (string) to specify
                # parallel styles, here we translate them into torch TP types.
                # Using tree_map because `tp_plan` is a dict.
                tp_plan = torch.utils._pytree.tree_map(
                    translate_to_torch_parallel_style,
                    tp_plan,
                )
                # Apply TP to current module.
                torch.distributed.tensor.parallel.parallelize_module(
                    mod,
                    device_mesh=device_mesh,
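For reference, `tensor_parallel` walks the model's submodules and calls `tplize` on each one, so any `_tp_plan` made of plain strings is translated into torch tensor-parallel styles and handed to `parallelize_module`. Below is a minimal usage sketch (not part of this commit): it assumes a single-node `torchrun` launch with one GPU per rank, and the checkpoint name is only a placeholder.

# Hypothetical usage sketch (not in this commit) -- launch with:
#   torchrun --nproc-per-node=<num_gpus> run_tp.py
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))  # 1-D mesh over all ranks

model = AutoModelForCausalLM.from_pretrained("<llama-checkpoint>", torch_dtype=torch.bfloat16)
model.to("cuda")  # each rank loads the full model, then shards it in place
model.tensor_parallel(device_mesh)  # applies each submodule's _tp_plan via parallelize_module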
10 changes: 10 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `LlamaModel`
    _base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
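The plan above follows the usual Megatron-style pairing: the projections that fan out (`q/k/v_proj`, `gate_proj`, `up_proj`) are column-parallel and the projections that fan back in (`o_proj`, `down_proj`) are row-parallel, so the intermediate activations stay sharded and each attention/MLP block pays a single all-reduce on its output. A toy illustration of that pairing on a standalone two-linear MLP (not part of this commit; it must run under `torchrun`):

import os
import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

mlp = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
# "0" (up projection) is column-parallel: its output stays sharded on the hidden dim.
# "2" (down projection) is row-parallel: it consumes the sharded activation and
# all-reduces, so only one collective is paid for the whole block.
parallelize_module(mlp, mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})

out = mlp(torch.randn(8, 1024, device="cuda"))  # full-size result on every rank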
25 changes: 3 additions & 22 deletions src/transformers/models/llama/modeling_llama.py
@@ -24,11 +24,6 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
@@ -231,12 +226,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):


class LlamaMLP(nn.Module):
    _tp_plan = {
        "gate_proj": ColwiseParallel(),
        "up_proj": ColwiseParallel(),
        "down_proj": RowwiseParallel(),
    }

    def __init__(self, config):
        super().__init__()
        self.config = config
@@ -267,13 +256,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    _tp_plan = {
        "q_proj": ColwiseParallel(),
        "k_proj": ColwiseParallel(),
        "v_proj": ColwiseParallel(),
        "o_proj": RowwiseParallel(),
    }

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
@@ -824,8 +806,9 @@ def __init__(self, config: LlamaConfig):
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.gradient_checkpointing = False
        self._tp_plan = config._base_model_tp_plan
        # Initialize weights and apply final processing
        self.post_init()

@@ -1081,9 +1064,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {
        "lm_head": ColwiseParallel(output_layouts=Replicate()),
    }
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
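Note that `lm_head` maps to `"colwise_rep"` rather than plain `"colwise"`: the matmul is still sharded over the vocabulary (output) dimension, but the result is gathered back to a replicated, full-size logits tensor, which is what the downstream loss and generation code expect. In torch terms (mirroring the helper added to `pytorch_utils.py` below):

from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel

colwise = ColwiseParallel()                                # output left sharded on the last dim
colwise_rep = ColwiseParallel(output_layouts=Replicate())  # all-gathered back to full logits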
26 changes: 26 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -20,6 +20,11 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

@@ -326,3 +331,24 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    else:
        # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
        return torch.isin(elements, test_elements)


def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
    styles, here we translate them into torch.distributed tensor-parallel
    types.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str"
        )

    if style == "colwise":
        return ColwiseParallel()
    elif style == "rowwise":
        return RowwiseParallel()
    elif style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")
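A quick sketch of how this helper is consumed (mirroring `tplize` in `modeling_utils.py` above). The plan literal is only a trimmed stand-in for the MLP entries of the Llama plan, and the import path assumes this commit is installed:

import torch
from transformers.pytorch_utils import translate_to_torch_parallel_style

plan = {"gate_proj": "colwise", "up_proj": "colwise", "down_proj": "rowwise"}
# Same call tplize makes: map every string in the plan dict to a torch TP style object.
torch_plan = torch.utils._pytree.tree_map(translate_to_torch_parallel_style, plan)
# -> {"gate_proj": ColwiseParallel(), "up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()}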
