diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 44badd0f412302..3884ee2be081ec 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -218,6 +218,8 @@
title: CPU inference
- local: perf_infer_gpu_one
title: GPU inference
+ - local: perf_infer_gpu_multi
+ title: Multi-GPU inference
title: Optimizing inference
- local: big_models
title: Instantiate a big model
diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md
new file mode 100644
index 00000000000000..9975094411527a
--- /dev/null
+++ b/docs/source/en/perf_infer_gpu_multi.md
@@ -0,0 +1,68 @@
+
+
+# Multi-GPU inference
+
+Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes, and parallelizes computations such as matrix multiplication.
+
+To enable tensor parallel, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:
+
+```python
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+# Initialize distributed
+rank = int(os.environ["RANK"])
+device = torch.device(f"cuda:{rank}")
+torch.distributed.init_process_group("nccl", device_id=device)
+
+# Retrieve tensor parallel model
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ tp_plan="auto",
+)
+
+# Prepare input tokens
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+prompt = "Can I help"
+inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+# Distributed run
+outputs = model(inputs)
+```
+
+You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:
+
+```
+torchrun --nproc-per-node 4 demo.py
+```
+
+PyTorch tensor parallel is currently supported for the following models:
+* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+
+You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.
+
+### Expected speedups
+
+You can benefit from considerable speedups for inference, especially for inputs with large batch size or long sequences.
+
+For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:
+
+
+
+
diff --git a/docs/source/en/performance.md b/docs/source/en/performance.md
index 94e756cf33ada6..b9176be04ec206 100644
--- a/docs/source/en/performance.md
+++ b/docs/source/en/performance.md
@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se
* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
-* [Multi-GPU inference](perf_infer_gpu_one)
+* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 60f9f34cf861c9..e49eab86b4e12f 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
outputs of the model during inference.
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
naming of attributes.
+ - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
Common attributes (present in all subclasses):
@@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin):
sub_configs: Dict[str, "PretrainedConfig"] = {}
is_composition: bool = False
attribute_map: Dict[str, str] = {}
+ base_model_tp_plan: Optional[Dict[str, Any]] = None
_auto_class: Optional[str] = None
def __setattr__(self, key, value):
@@ -848,6 +851,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
+ # Do not serialize `base_model_tp_plan` for now
+ if "base_model_tp_plan" in serializable_config_dict:
+ del serializable_config_dict["base_model_tp_plan"]
return serializable_config_dict
@@ -867,6 +873,9 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
+ # Do not serialize `base_model_tp_plan` for now
+ if "base_model_tp_plan" in output:
+ del output["base_model_tp_plan"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index dd067c99a0b97d..57532c0c711b85 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -55,6 +55,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
+ translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False
+ # A tensor parallel plan to be applied to the model when TP is enabled. For
+ # top-level models, this attribute is currently defined in respective model
+ # code. For base models, this attribute comes from
+ # `config.base_model_tp_plan` during `post_init`.
+ _tp_plan = None
+
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1370,6 +1377,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
+ # If current model is a base model, attach `base_model_tp_plan` from config
+ if self.base_model is self:
+ self._tp_plan = self.config.base_model_tp_plan
def dequantize(self):
"""
@@ -3399,6 +3409,11 @@ def from_pretrained(
# Cache path to the GGUF file
gguf_path = None
+ tp_plan = kwargs.pop("tp_plan", None)
+ if tp_plan is not None and tp_plan != "auto":
+ # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+ raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
+
if is_fsdp_enabled():
low_cpu_mem_usage = True
@@ -4000,6 +4015,7 @@ def from_pretrained(
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
+ tp_device = None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
@@ -4012,6 +4028,16 @@ def from_pretrained(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
+ elif tp_plan is not None:
+ if not torch.distributed.is_initialized():
+ raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
+
+ # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+ device_type = torch._C._get_accelerator().type
+ device_module = torch.get_device_module(device_type)
+ # Get device with index assuming equal number of devices per host
+ tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+ init_contexts.append(tp_device)
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
@@ -4145,32 +4171,38 @@ def from_pretrained(
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
- (
- model,
- missing_keys,
- unexpected_keys,
- mismatched_keys,
- offload_index,
- error_msgs,
- ) = cls._load_pretrained_model(
- model,
- state_dict,
- loaded_state_dict_keys, # XXX: rename?
- resolved_archive_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=ignore_mismatched_sizes,
- sharded_metadata=sharded_metadata,
- _fast_init=_fast_init,
- low_cpu_mem_usage=low_cpu_mem_usage,
- device_map=device_map,
- offload_folder=offload_folder,
- offload_state_dict=offload_state_dict,
- dtype=torch_dtype,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- gguf_path=gguf_path,
- weights_only=weights_only,
- )
+ load_contexts = []
+ # Make sure we load onto targeted device
+ if tp_device is not None:
+ load_contexts.append(tp_device)
+
+ with ContextManagers(load_contexts):
+ (
+ model,
+ missing_keys,
+ unexpected_keys,
+ mismatched_keys,
+ offload_index,
+ error_msgs,
+ ) = cls._load_pretrained_model(
+ model,
+ state_dict,
+ loaded_state_dict_keys, # XXX: rename?
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ sharded_metadata=sharded_metadata,
+ _fast_init=_fast_init,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ device_map=device_map,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=torch_dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ gguf_path=gguf_path,
+ weights_only=weights_only,
+ )
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -4254,6 +4286,16 @@ def from_pretrained(
}
return model, loading_info
+ if tp_plan is not None:
+ assert tp_device is not None, "tp_device not set!"
+ if not model.supports_tp_plan:
+ raise NotImplementedError("This model does not have a tensor parallel plan.")
+ # Assuming sharding the model onto the world
+ world_size = torch.distributed.get_world_size()
+ device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
+ # Apply Tensor Parallelism
+ model.tensor_parallel(device_mesh)
+
return model
@classmethod
@@ -4943,6 +4985,54 @@ def _is_quantized_training_enabled(self):
return self.hf_quantizer.is_trainable
+ @property
+ def supports_tp_plan(self):
+ """
+ Returns whether the model has a tensor parallelism plan.
+ """
+ if self._tp_plan is not None:
+ return True
+ # Check if base model has a TP plan
+ if getattr(self.base_model, "_tp_plan", None) is not None:
+ return True
+ return False
+
+ def tensor_parallel(self, device_mesh):
+ """
+ Tensor parallelize the model across the given device mesh.
+
+ Args:
+ device_mesh (`torch.distributed.DeviceMesh`):
+ The device mesh to use for tensor parallelism.
+ """
+
+ # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+ # No op if `_tp_plan` attribute does not exist under the module.
+ # This is a helper function to be used with `model.apply` to recursively
+ # parallelize a model.
+ def tplize(mod: torch.nn.Module) -> None:
+ tp_plan = getattr(mod, "_tp_plan", None)
+ if tp_plan is None:
+ return
+ logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
+ # In model configs, we use a neutral type (string) to specify
+ # parallel styles, here we translate them into torch TP types.
+ # Using tree_map because `tp_plan` is a dict.
+ tp_plan = torch.utils._pytree.tree_map(
+ translate_to_torch_parallel_style,
+ tp_plan,
+ )
+ # Apply TP to current module.
+ torch.distributed.tensor.parallel.parallelize_module(
+ mod,
+ device_mesh=device_mesh,
+ parallelize_plan=tp_plan,
+ )
+
+ # `apply` is a native method of `nn.Module` that recursively applies a
+ # function to every submodule.
+ self.apply(tplize)
+
@property
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index b215fb6561bf81..0261f997da1110 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index fa3fadc4349a3c..6fead73eced704 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 626e5537fc06bc..6a3d8f27fb177d 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 248ec4021791b1..58a89d90b44ff5 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig):
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
)
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config: GlmConfig):
super().__init__(config)
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index a3667e06534564..98d5ecdd2a4fdb 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `LlamaModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
def __init__(
self,
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4d95f01849d678..679296648a9135 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -21,7 +21,6 @@
from typing import List, Optional, Tuple, Union
import torch
-import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -240,25 +239,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
- if self.config.pretraining_tp > 1:
- slice = self.intermediate_size // self.config.pretraining_tp
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
-
- gate_proj = torch.cat(
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
- )
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
-
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
- down_proj = [
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
- ]
- down_proj = sum(down_proj)
- else:
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
@@ -320,31 +301,14 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
- )
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
@@ -386,12 +350,7 @@ def forward(
attn_output = attn_output.reshape(bsz, q_len, -1)
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
- else:
- attn_output = self.o_proj(attn_output)
+ attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
@@ -564,9 +523,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
@@ -850,7 +810,10 @@ def __init__(self, config: LlamaConfig):
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -1113,6 +1076,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
@@ -1211,13 +1175,8 @@ def forward(
)
hidden_states = outputs[0]
- if self.config.pretraining_tp > 1:
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
- logits = torch.cat(logits, dim=-1)
- else:
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index d4eb348260c1a4..8de6bc90ea3fec 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -980,7 +980,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 60225d4759c6ab..d865c51e50578e 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -1020,7 +1020,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py
index 52eb1f544bb484..5a9cca39b88570 100644
--- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py
+++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py
@@ -971,6 +971,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124
class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index cbb8db0f59dd02..47cb0964eca8b6 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -888,7 +888,7 @@ def _init_weights(self, module):
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
OLMOE_START_DOCSTRING,
)
-# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
class OlmoeModel(OlmoePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index d3164b17fe130c..2b3cf7eb0cb82e 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -775,7 +775,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
return causal_mask
-# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index a808f2cb63e861..a595f8bc9e1af6 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -20,6 +20,11 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
+from torch.distributed.tensor import Replicate
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+)
from .utils import is_torch_xla_available, logging
@@ -329,3 +334,22 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
return torch.isin(elements, test_elements)
+
+
+def translate_to_torch_parallel_style(style: str):
+ """
+ In model configurations, we use a neutral type (string) to specify parallel
+ styles, here we translate them into torch.distributed tensor-parallel
+ types.
+ """
+ if not isinstance(style, str):
+ raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
+
+ if style == "colwise":
+ return ColwiseParallel()
+ elif style == "rowwise":
+ return RowwiseParallel()
+ elif style == "colwise_rep":
+ return ColwiseParallel(output_layouts=Replicate())
+ else:
+ raise ValueError(f"Unsupported parallel style value: {style}")
diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py
new file mode 100644
index 00000000000000..2139a648867b61
--- /dev/null
+++ b/tests/tp/test_tp.py
@@ -0,0 +1,91 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from transformers import is_torch_available
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaModel
+from transformers.testing_utils import (
+ TestCasePlus,
+ execute_subprocess_async,
+ get_torch_dist_unique_port,
+ require_torch_multi_gpu,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+class TestTensorParallel(TestCasePlus):
+ @require_torch_multi_gpu
+ def test_tp(self):
+ distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+ --master_port={get_torch_dist_unique_port()}
+ {self.test_file_dir}/test_tp.py
+ """.split()
+ output_dir = self.get_auto_remove_tmp_dir()
+ args = f"--output_dir {output_dir} --report_to none".split()
+ cmd = ["torchrun"] + distributed_args + args
+ print(cmd)
+ execute_subprocess_async(cmd, env=self.get_env())
+ # successful return here == success - any errors would have caused an error in the sub-call
+
+
+if __name__ == "__main__":
+ # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
+ # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
+ # or
+ # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
+
+ if not is_torch_available():
+ exit(0)
+
+ # Test settings
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+ bs = 4
+ seqlen = 64
+
+ # Get distributed settings
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ # Initialize distributed
+ device = torch.device(f"cuda:{rank}")
+ torch.distributed.init_process_group("nccl", device_id=device)
+ device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
+
+ # Get model config
+ config = LlamaConfig.from_pretrained(model_id)
+ # Shrink model size
+ config.num_hidden_layers //= 8
+ config.vocab_size //= 8
+
+ # Instantiate model
+ with device:
+ model = LlamaModel(config)
+
+ model.eval()
+
+ # Tensor Parallel
+ if world_size > 1:
+ model.tensor_parallel(device_mesh)
+
+ # Run model
+ inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
+ with torch.no_grad():
+ out = model(inputs)
+
+ assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])