Skip to content

Commit

Permalink
Format and warning
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Oct 30, 2024
1 parent 79cc524 commit a2934b3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5022,16 +5022,15 @@ def tensor_parallel(self, device_mesh):
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan:
logger.debug(
f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}"
)
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,8 @@ def __init__(self, config: LlamaConfig):

self.gradient_checkpointing = False
self._tp_plan = config._base_model_tp_plan
if config.pretraining_tp != 1:
logger.warn("`pretraining_tp` is deprecated, please use `tensor_parallel` method instead.")
# Initialize weights and apply final processing
self.post_init()

Expand Down
4 changes: 1 addition & 3 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,7 @@ def translate_to_torch_parallel_style(style: str):
types.
"""
if not isinstance(style, str):
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str"
)
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

if style == "colwise":
return ColwiseParallel()
Expand Down

0 comments on commit a2934b3

Please sign in to comment.