Skip to content

Commit

Permalink
protect tensor parallel usage (huggingface#34800)
Browse files Browse the repository at this point in the history
protect
  • Loading branch information
ArthurZucker authored Nov 19, 2024
1 parent eed11f3 commit dadb286
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
find_pruneable_heads_and_indices,
id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
is_torch_greater_or_equal_than_2_4,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
Expand Down Expand Up @@ -5005,6 +5006,8 @@ def tensor_parallel(self, device_mesh):
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
if not is_torch_greater_or_equal_than_2_4:
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
Expand Down
13 changes: 8 additions & 5 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

Expand All @@ -44,6 +39,14 @@
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")


if is_torch_greater_or_equal_than_2_4:
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)


def softmax_backward_data(parent, grad_output, output, dim, self):
"""
A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
Expand Down

0 comments on commit dadb286

Please sign in to comment.