diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 44badd0f412302..3884ee2be081ec 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -218,6 +218,8 @@ title: CPU inference - local: perf_infer_gpu_one title: GPU inference + - local: perf_infer_gpu_multi + title: Multi-GPU inference title: Optimizing inference - local: big_models title: Instantiate a big model diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md new file mode 100644 index 00000000000000..9975094411527a --- /dev/null +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -0,0 +1,68 @@ + + +# Multi-GPU inference + +Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes, and parallelizes computations such as matrix multiplication. + +To enable tensor parallel, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]: + +```python +import os +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + +# Initialize distributed +rank = int(os.environ["RANK"]) +device = torch.device(f"cuda:{rank}") +torch.distributed.init_process_group("nccl", device_id=device) + +# Retrieve tensor parallel model +model = AutoModelForCausalLM.from_pretrained( + model_id, + tp_plan="auto", +) + +# Prepare input tokens +tokenizer = AutoTokenizer.from_pretrained(model_id) +prompt = "Can I help" +inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + +# Distributed run +outputs = model(inputs) +``` + +You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU: + +``` +torchrun --nproc-per-node 4 demo.py +``` + +PyTorch tensor parallel is currently supported for the following models: +* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) + +You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request. + +### Expected speedups + +You can benefit from considerable speedups for inference, especially for inputs with large batch size or long sequences. + +For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows: + +
+ +
diff --git a/docs/source/en/performance.md b/docs/source/en/performance.md index 94e756cf33ada6..b9176be04ec206 100644 --- a/docs/source/en/performance.md +++ b/docs/source/en/performance.md @@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se * [Inference on a single CPU](perf_infer_cpu) * [Inference on a single GPU](perf_infer_gpu_one) -* [Multi-GPU inference](perf_infer_gpu_one) +* [Multi-GPU inference](perf_infer_gpu_multi) * [XLA Integration for TensorFlow Models](tf_xla) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 60f9f34cf861c9..e49eab86b4e12f 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin): outputs of the model during inference. - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized naming of attributes. + - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor + parallel plan applied to the sub-module when `model.tensor_parallel` is called. Common attributes (present in all subclasses): @@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin): sub_configs: Dict[str, "PretrainedConfig"] = {} is_composition: bool = False attribute_map: Dict[str, str] = {} + base_model_tp_plan: Optional[Dict[str, Any]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): @@ -848,6 +851,9 @@ def to_diff_dict(self) -> Dict[str, Any]: if "_attn_implementation_internal" in serializable_config_dict: del serializable_config_dict["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_tp_plan"] return serializable_config_dict @@ -867,6 +873,9 @@ def to_dict(self) -> Dict[str, Any]: del output["_commit_hash"] if "_attn_implementation_internal" in output: del output["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in output: + del output["base_model_tp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dd067c99a0b97d..57532c0c711b85 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -55,6 +55,7 @@ prune_conv1d_layer, prune_layer, prune_linear_layer, + translate_to_torch_parallel_style, ) from .quantizers import AutoHfQuantizer, HfQuantizer from .quantizers.quantizers_utils import get_module_from_name @@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Has support for a `QuantoQuantizedCache` instance as `past_key_values` _supports_quantized_cache = False + # A tensor parallel plan to be applied to the model when TP is enabled. For + # top-level models, this attribute is currently defined in respective model + # code. For base models, this attribute comes from + # `config.base_model_tp_plan` during `post_init`. + _tp_plan = None + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -1370,6 +1377,9 @@ def post_init(self): """ self.init_weights() self._backward_compatibility_gradient_checkpointing() + # If current model is a base model, attach `base_model_tp_plan` from config + if self.base_model is self: + self._tp_plan = self.config.base_model_tp_plan def dequantize(self): """ @@ -3399,6 +3409,11 @@ def from_pretrained( # Cache path to the GGUF file gguf_path = None + tp_plan = kwargs.pop("tp_plan", None) + if tp_plan is not None and tp_plan != "auto": + # TODO: we can relax this check when we support taking tp_plan from a json file, for example. + raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.") + if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -4000,6 +4015,7 @@ def from_pretrained( # Instantiate model. init_contexts = [no_init_weights(_enable=_fast_init)] + tp_device = None if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed @@ -4012,6 +4028,16 @@ def from_pretrained( f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" ) init_contexts.append(init_empty_weights()) + elif tp_plan is not None: + if not torch.distributed.is_initialized(): + raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") + + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. + device_type = torch._C._get_accelerator().type + device_module = torch.get_device_module(device_type) + # Get device with index assuming equal number of devices per host + tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count()) + init_contexts.append(tp_device) if is_deepspeed_zero3_enabled() and is_quantized: init_contexts.append(set_quantized_state()) @@ -4145,32 +4171,38 @@ def from_pretrained( if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = cls._load_pretrained_model( - model, - state_dict, - loaded_state_dict_keys, # XXX: rename? - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - low_cpu_mem_usage=low_cpu_mem_usage, - device_map=device_map, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - gguf_path=gguf_path, - weights_only=weights_only, - ) + load_contexts = [] + # Make sure we load onto targeted device + if tp_device is not None: + load_contexts.append(tp_device) + + with ContextManagers(load_contexts): + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + gguf_path=gguf_path, + weights_only=weights_only, + ) # make sure token embedding weights are still tied if needed model.tie_weights() @@ -4254,6 +4286,16 @@ def from_pretrained( } return model, loading_info + if tp_plan is not None: + assert tp_device is not None, "tp_device not set!" + if not model.supports_tp_plan: + raise NotImplementedError("This model does not have a tensor parallel plan.") + # Assuming sharding the model onto the world + world_size = torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) + # Apply Tensor Parallelism + model.tensor_parallel(device_mesh) + return model @classmethod @@ -4943,6 +4985,54 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable + @property + def supports_tp_plan(self): + """ + Returns whether the model has a tensor parallelism plan. + """ + if self._tp_plan is not None: + return True + # Check if base model has a TP plan + if getattr(self.base_model, "_tp_plan", None) is not None: + return True + return False + + def tensor_parallel(self, device_mesh): + """ + Tensor parallelize the model across the given device mesh. + + Args: + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. + def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan is None: + return + logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") + # In model configs, we use a neutral type (string) to specify + # parallel styles, here we translate them into torch TP types. + # Using tree_map because `tp_plan` is a dict. + tp_plan = torch.utils._pytree.tree_map( + translate_to_torch_parallel_style, + tp_plan, + ) + # Apply TP to current module. + torch.distributed.tensor.parallel.parallelize_module( + mod, + device_mesh=device_mesh, + parallelize_plan=tp_plan, + ) + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. + self.apply(tplize) + @property def loss_function(self): if getattr(self.config, "loss_type", None) is not None: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b215fb6561bf81..0261f997da1110 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index fa3fadc4349a3c..6fead73eced704 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 626e5537fc06bc..6a3d8f27fb177d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config): [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 248ec4021791b1..58a89d90b44ff5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig): dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta ) self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config: GlmConfig): super().__init__(config) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index a3667e06534564..98d5ecdd2a4fdb 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig): model_type = "llama" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `LlamaModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4d95f01849d678..679296648a9135 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -240,25 +239,7 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -320,31 +301,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -386,12 +350,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -564,9 +523,10 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -850,7 +810,10 @@ def __init__(self, config: LlamaConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -1113,6 +1076,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) @@ -1211,13 +1175,8 @@ def forward( ) hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index d4eb348260c1a4..8de6bc90ea3fec 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -980,7 +980,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 60225d4759c6ab..d865c51e50578e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1020,7 +1020,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 52eb1f544bb484..5a9cca39b88570 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -971,6 +971,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124 class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index cbb8db0f59dd02..47cb0964eca8b6 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -888,7 +888,7 @@ def _init_weights(self, module): "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", OLMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe class OlmoeModel(OlmoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index d3164b17fe130c..2b3cf7eb0cb82e 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -775,7 +775,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index a808f2cb63e861..a595f8bc9e1af6 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -20,6 +20,11 @@ from packaging import version from safetensors.torch import storage_ptr, storage_size from torch import nn +from torch.distributed.tensor import Replicate +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, +) from .utils import is_torch_xla_available, logging @@ -329,3 +334,22 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) else: # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 return torch.isin(elements, test_elements) + + +def translate_to_torch_parallel_style(style: str): + """ + In model configurations, we use a neutral type (string) to specify parallel + styles, here we translate them into torch.distributed tensor-parallel + types. + """ + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + else: + raise ValueError(f"Unsupported parallel style value: {style}") diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py new file mode 100644 index 00000000000000..2139a648867b61 --- /dev/null +++ b/tests/tp/test_tp.py @@ -0,0 +1,91 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from transformers import is_torch_available +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.testing_utils import ( + TestCasePlus, + execute_subprocess_async, + get_torch_dist_unique_port, + require_torch_multi_gpu, +) + + +if is_torch_available(): + import torch + + +class TestTensorParallel(TestCasePlus): + @require_torch_multi_gpu + def test_tp(self): + distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} + --master_port={get_torch_dist_unique_port()} + {self.test_file_dir}/test_tp.py + """.split() + output_dir = self.get_auto_remove_tmp_dir() + args = f"--output_dir {output_dir} --report_to none".split() + cmd = ["torchrun"] + distributed_args + args + print(cmd) + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + +if __name__ == "__main__": + # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: + # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py + # or + # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py + + if not is_torch_available(): + exit(0) + + # Test settings + model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + bs = 4 + seqlen = 64 + + # Get distributed settings + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Initialize distributed + device = torch.device(f"cuda:{rank}") + torch.distributed.init_process_group("nccl", device_id=device) + device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,)) + + # Get model config + config = LlamaConfig.from_pretrained(model_id) + # Shrink model size + config.num_hidden_layers //= 8 + config.vocab_size //= 8 + + # Instantiate model + with device: + model = LlamaModel(config) + + model.eval() + + # Tensor Parallel + if world_size > 1: + model.tensor_parallel(device_mesh) + + # Run model + inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device) + with torch.no_grad(): + out = model(inputs) + + assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])