[AutoParallel] Reconstruct sharding mesh dimension inference logic - Part2 add sharding_mesh_dimension param (PaddlePaddle#9382)

* add custom sharding_dim

* Update training_args.py

* Update auto_trainer.py

* Update auto_trainer.py
AndSonder authored Nov 20, 2024
1 parent d74df4d commit bd1fd92
Showing 2 changed files with 22 additions and 3 deletions.
paddlenlp/trainer/auto_trainer.py: 10 additions, 3 deletions

@@ -117,18 +117,25 @@ def _wrap_for_dist_loader(self, train_dataloader):
     def _wrap_for_auto(self, model, train_dataloader):
         logger.info("Wrapping model for auto paralle")
         dist_loader = self._wrap_for_dist_loader(train_dataloader)
+        sharding_parallel_mesh_dimension = self.args.sharding_parallel_mesh_dimension

         if ShardingOption.SHARD_OP in self.args.sharding:
             self.optimizer = dist.shard_optimizer(
-                self.optimizer, dist.ShardingStage1(), self.args.gradient_accumulation_steps
+                self.optimizer,
+                dist.ShardingStage1(sharding_mesh_dim=sharding_parallel_mesh_dimension),
+                self.args.gradient_accumulation_steps,
             )
         elif ShardingOption.SHARD_GRAD_OP in self.args.sharding:
             self.optimizer = dist.shard_optimizer(
-                self.optimizer, dist.ShardingStage2(), self.args.gradient_accumulation_steps
+                self.optimizer,
+                dist.ShardingStage2(sharding_mesh_dim=sharding_parallel_mesh_dimension),
+                self.args.gradient_accumulation_steps,
             )
         elif ShardingOption.FULL_SHARD in self.args.sharding:
             self.optimizer = dist.shard_optimizer(
-                self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps
+                self.optimizer,
+                dist.ShardingStage3(sharding_mesh_dim=sharding_parallel_mesh_dimension),
+                self.args.gradient_accumulation_steps,
             )
         else:
             self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)
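For context, here is a minimal sketch (not part of this commit) of what the named sharding dimension means against a multi-dimensional process mesh. The mesh shape, the dimension names "dp"/"mp", and the toy model are assumptions for illustration; only the `dist.ShardingStage1(sharding_mesh_dim=...)` call pattern mirrors the hunk above.

import paddle
import paddle.distributed as dist

# Hypothetical 2-D mesh over 8 ranks: 4-way "dp" x 2-way "mp" (assumption).
mesh = dist.ProcessMesh([[0, 1], [2, 3], [4, 5], [6, 7]], dim_names=["dp", "mp"])

# Toy model and optimizer; in a real auto-parallel program the parameters
# would already be distributed over `mesh` (e.g. via dist.shard_layer).
model = paddle.nn.Linear(1024, 1024)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# The keyword added in this change names the mesh axis along which sharding
# stage 1 partitions optimizer states, instead of leaving it to be inferred.
opt = dist.shard_optimizer(opt, dist.ShardingStage1(sharding_mesh_dim="dp"))

In `_wrap_for_auto` that name is taken from `self.args.sharding_parallel_mesh_dimension`, so the trainer passes the user's choice straight through to the chosen sharding stage.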
paddlenlp/trainer/training_args.py: 12 additions, 0 deletions

@@ -227,6 +227,9 @@ class TrainingArguments:
             Sharding parameter in certain cards group. For example, aussume we use 2 machines each with 8 cards,
             then set sharding_parallel_degree=8, sharding will only communication inside machine.
             default -1 means sharding parameters between all workers.
+        sharding_parallel_mesh_dimension (`str`, *optional*, defaults to `dp`)
+            Specifies the name of the dimension in a multi-dimensional parallelism mesh that is responsible for sharding.
+            default `dp` for default parallelism mesh.
         tensor_parallel_degree (`int`, *optional*, defaults to `-1`)
             Tensor parallelism is parallel technique proposed in (https://arxiv.org/pdf/2104.04473.pdf see 2.3 Tensor Model Parallelism).
             This technique splits one transformer layer into multi-cards (For examples, tensor_parallel_degree=4, will split a layer to 4-parts)

@@ -562,6 +565,15 @@ class TrainingArguments:
             )
         },
     )
+    sharding_parallel_mesh_dimension: str = field(
+        default="dp",
+        metadata={
+            "help": (
+                "Specifies the name of the dimension in a multi-dimensional parallelism mesh that is responsible for sharding. "
+                "default `dp` for default parallelism mesh. "
+            )
+        },
+    )
     sharding_comm_buffer_size_MB: int = field(
         default=-1,
         metadata={
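Because the new field sits on `TrainingArguments`, it surfaces through the trainer's argument parser like the other sharding options. The snippet below is a usage sketch rather than part of the commit; the output directory and the `stage1` choice are placeholders, and a real auto-parallel job would also pass `--enable_auto_parallel` and run under a multi-card launch.

from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    [
        "--output_dir", "./checkpoints",
        "--sharding", "stage1",
        # Name of the mesh dimension that sharding partitions over;
        # "dp" is also the default when the flag is omitted.
        "--sharding_parallel_mesh_dimension", "dp",
    ]
)
print(training_args.sharding_parallel_mesh_dimension)  # -> "dp"

Defaulting to `dp` keeps the existing behaviour for the usual data-parallel mesh, while meshes whose sharding axis has a different name can now say so explicitly.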
