Skip to content

Commit

Permalink
Simplify Tensor Parallel implementation with PyTorch TP (huggingface#…
Browse files Browse the repository at this point in the history
…34184)

* Simplify Tensor Parallel implementation with PyTorch TP

* Move tp_plan to config

* Lint

* Format and warning

* Disable copy-from check

* Conditionally get attr from config

* make fix-copies

* Move base_model_tp_plan to PretrainedConfig

* Move TP into from_pretrained

* Add device context for load

* Do not serialize

* Move _tp_plan setting to post_init

* Add has_tp_plan

* Add test_tp

* Add 'Multi-gpu inference' doc

* Add backward support for device type identification

* Auto-detect accelerator

* supports_tp_plan

* copyright year

* Fix copy
  • Loading branch information
kwen2501 authored and BernardZach committed Dec 5, 2024
1 parent 6998ade commit 43d4c76
Show file tree
Hide file tree
Showing 18 changed files with 357 additions and 92 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@
title: CPU inference
- local: perf_infer_gpu_one
title: GPU inference
- local: perf_infer_gpu_multi
title: Multi-GPU inference
title: Optimizing inference
- local: big_models
title: Instantiate a big model
Expand Down
68 changes: 68 additions & 0 deletions docs/source/en/perf_infer_gpu_multi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Multi-GPU inference

Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes, and parallelizes computations such as matrix multiplication.

To enable tensor parallel, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
model_id,
tp_plan="auto",
)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Distributed run
outputs = model(inputs)
```

You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:

```
torchrun --nproc-per-node 4 demo.py
```

PyTorch tensor parallel is currently supported for the following models:
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)

You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.

### Expected speedups

You can benefit from considerable speedups for inference, especially for inputs with large batch size or long sequences.

For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:

<div style="text-align: center">
<img src="huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct, seqlen = 512, python, w_ compile.png">
</div>
2 changes: 1 addition & 1 deletion docs/source/en/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se

* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)


Expand Down
9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
outputs of the model during inference.
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
Common attributes (present in all subclasses):
Expand Down Expand Up @@ -194,6 +196,7 @@ class PretrainedConfig(PushToHubMixin):
sub_configs: Dict[str, "PretrainedConfig"] = {}
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
_auto_class: Optional[str] = None

def __setattr__(self, key, value):
Expand Down Expand Up @@ -848,6 +851,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]

return serializable_config_dict

Expand All @@ -867,6 +873,9 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
Expand Down
142 changes: 116 additions & 26 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
Expand Down Expand Up @@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False

# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
Expand Down Expand Up @@ -1370,6 +1377,9 @@ def post_init(self):
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan

def dequantize(self):
"""
Expand Down Expand Up @@ -3399,6 +3409,11 @@ def from_pretrained(
# Cache path to the GGUF file
gguf_path = None

tp_plan = kwargs.pop("tp_plan", None)
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

if is_fsdp_enabled():
low_cpu_mem_usage = True

Expand Down Expand Up @@ -4000,6 +4015,7 @@ def from_pretrained(

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
Expand All @@ -4012,6 +4028,16 @@ def from_pretrained(
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
elif tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
init_contexts.append(tp_device)

if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
Expand Down Expand Up @@ -4145,32 +4171,38 @@ def from_pretrained(
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)
load_contexts = []
# Make sure we load onto targeted device
if tp_device is not None:
load_contexts.append(tp_device)

with ContextManagers(load_contexts):
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
Expand Down Expand Up @@ -4254,6 +4286,16 @@ def from_pretrained(
}
return model, loading_info

if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)

return model

@classmethod
Expand Down Expand Up @@ -4943,6 +4985,54 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

@property
def supports_tp_plan(self):
"""
Returns whether the model has a tensor parallelism plan.
"""
if self._tp_plan is not None:
return True
# Check if base model has a TP plan
if getattr(self.base_model, "_tp_plan", None) is not None:
return True
return False

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.
Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)

@property
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,10 @@ def __init__(self, config: GemmaConfig):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -982,6 +985,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,10 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -961,6 +964,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,8 @@ def __init__(self, config: GlmConfig):
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -967,6 +969,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config: GlmConfig):
super().__init__(config)
Expand Down
Loading

0 comments on commit 43d4c76

Please sign in to comment.