[AutoParallel] 3D parallel on MLP with PIR (PaddlePaddle#64369)
* pir 3D parallel on mlp

* fix bug in nd_mesh_reshard when [Partial(),Shard(0)]-->[Replicate(),Shard(0)]

* fix bug in 3D parallel unit test

* add _fetch_value test in ut
pkuzyc authored and co63oc committed May 19, 2024
1 parent de53e6f commit 517a3af
Showing 7 changed files with 234 additions and 73 deletions.
23 changes: 19 additions & 4 deletions python/paddle/distributed/auto_parallel/api.py
@@ -2210,16 +2210,31 @@ def __call__(self, *args):
+        self.outs = outs

         if self._mode == "predict":
-            if "outputs" in outs:
-                return outs["outputs"]
+            if "outputs" in self.outs:
+                return self.outs["outputs"]
             else:
                 return None
         else:
-            if "loss" in outs:
-                return outs["loss"]
+            if "loss" in self.outs:
+                return self.outs["loss"]
             else:
                 return None

+    def _fetch_value(self, value, name=None):
+        """
+        Get the value of the variable with the given name.
+        Args:
+            value (pir.Value): The pir Value to fetch.
+            name (str|None, optional): The user-defined name of
+                the fetched result. If None, the order of the Value
+                in the fetch list will be used. Default: None.
+        """
+        self._engine._pir_fetch_values.append(value)
+        if name is None:
+            name = len(self._engine._pir_fetch_values) - 1
+        self._engine._pir_user_defined_fetch_names.append(name)
+
     def state_dict(self, mode="all"):
         """
         Get the state dict of model and optimizer.
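For context, a minimal usage sketch of the new _fetch_value hook. This is illustrative only: dist_model is assumed to be the DistModel returned by paddle.distributed.to_static(...), and act_value an intermediate pir.Value taken from its distributed program; neither name appears in the commit.

    # Illustrative sketch, not part of the commit.
    dist_model._fetch_value(act_value, name="act_out")  # register an extra fetch

    loss = dist_model(inputs, labels)  # Engine.run now fetches loss plus registered values
    # Engine.run keys each extra fetch by its user-defined name (or list index),
    # and __call__ stores the returned logs, so the result should be reachable as:
    act = dist_model.outs["act_out"]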
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -212,6 +212,8 @@ def __init__(
         self._fwd_main_progs = {}
         self._pir_dist_main_progs = {}
         self._pir_dense_main_progs = {}
+        self._pir_fetch_values = []
+        self._pir_user_defined_fetch_names = []
         self._orig_optimizer = copy.deepcopy(self._optimizer)

         self._executor = None
@@ -1809,6 +1811,7 @@ def run(self, data=None, feed=None, fetch_list=None, mode=None):
                 fetch_names = []
             else:
                 fetch_names = [loss_value]
+            fetch_names += self._pir_fetch_values

             outs = self._executor.run(
                 self.main_program,
@@ -1821,8 +1824,12 @@
         if self._in_pir_mode:
             if no_fetch:
                 logs = {"outputs": None, "loss": None}
+                start_idx = 0
             else:
                 logs = {"outputs": outs[0], "loss": outs[0]}
+                start_idx = 1
+            for i, name in enumerate(self._pir_user_defined_fetch_names):
+                logs[name] = outs[start_idx + i]
             return logs

         logs = self._prepare_logger(
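A tiny standalone illustration of the fetch bookkeeping above (toy data, not Paddle code): outs holds the loss first when it is fetched, followed by the user-registered values in order, and an unnamed fetch falls back to its position in the fetch list.

    outs = ["loss_tensor", "fetch_a", "fetch_b"]
    user_names = ["act_out", 1]  # the second value was registered without a name
    no_fetch = False

    if no_fetch:
        logs = {"outputs": None, "loss": None}
        start_idx = 0
    else:
        logs = {"outputs": outs[0], "loss": outs[0]}
        start_idx = 1
    for i, name in enumerate(user_names):
        logs[name] = outs[start_idx + i]

    print(logs)  # {'outputs': 'loss_tensor', 'loss': 'loss_tensor', 'act_out': 'fetch_a', 1: 'fetch_b'}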
@@ -123,7 +123,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         for partial_dim, partial_type in in_partial_status.items():
             if (
                 partial_dim in out_partial_status
-                or ori_dst_dist_attr.dims_mapping[partial_dim] > -1
+                or partial_dim in ori_dst_dist_attr.dims_mapping
             ):
                 continue

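Why the condition changed, with assumed values from the [Partial(), Shard(0)] --> [Replicate(), Shard(0)] case named in the commit message: dims_mapping is indexed by tensor dimension and stores mesh-dimension ids, while partial_dim is a mesh dimension, so indexing dims_mapping with it answered the wrong question. The membership test instead asks whether any tensor dimension is sharded along that mesh dimension.

    # Assumed example: 2-D mesh, destination placements [Replicate(), Shard(0)],
    # so tensor dim 0 maps to mesh dim 1 and the dims_mapping is [1, -1, -1].
    dims_mapping = [1, -1, -1]
    partial_dim = 0                              # mesh dim that is Partial() on the source

    old_check = dims_mapping[partial_dim] > -1   # True  -> handling of the partial dim was skipped
    new_check = partial_dim in dims_mapping      # False -> the allreduce is inserted as expected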
@@ -74,7 +74,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
             -1 not in dst_type.shape
         ), "dynamic shape is not supported by pir-auto parallel yet."
         recv_value = paddle._C_ops.recv_v2(
-            dst_type.shape,
+            dst_type._local_shape,
             dst_type.dtype,
             src_local_rank,
             comm_group.id,
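The recv side allocates the per-rank buffer, so it needs the local shard shape rather than the global tensor shape. A toy calculation, assuming a value sharded along a single dimension (not Paddle code):

    # Assumed sizes: a global shape of [8, 16, 32] sharded along dim 0 across
    # 2 ranks gives a local buffer of [4, 16, 32] on each rank.
    def local_shape(global_shape, shard_dim, num_shards):
        shape = list(global_shape)
        if shard_dim >= 0:
            shape[shard_dim] //= num_shards
        return shape

    print(local_shape([8, 16, 32], shard_dim=0, num_shards=2))  # [4, 16, 32]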
71 changes: 68 additions & 3 deletions test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh.py
@@ -97,7 +97,6 @@ def run_pp_to_rr_case(self):
         )

         new_ops = dist_program.global_block().ops
-        old_ops_name = [op.name() for op in main_program.global_block().ops]
         new_ops_name = [op.name() for op in dist_program.global_block().ops]

         rank_id = dist.get_rank()
@@ -150,7 +149,6 @@ def run_pr_to_rs_case(self):
         )

         new_ops = dist_program.global_block().ops
-        old_ops_name = [op.name() for op in main_program.global_block().ops]
         new_ops_name = [op.name() for op in dist_program.global_block().ops]

         rank_id = dist.get_rank()
@@ -204,7 +202,6 @@ def run_ss_to_ss_case(self):
         )

         new_ops = dist_program.global_block().ops
-        old_ops_name = [op.name() for op in main_program.global_block().ops]
         new_ops_name = [op.name() for op in dist_program.global_block().ops]

         all_gather_ops = []
@@ -274,11 +271,79 @@ def run_ss_to_ss_case(self):
         tgt_out_value = (self._mesh.process_ids, [1, 0, -1], {})
         self.validate(op, tgt_operand, tgt_result, tgt_in_value, tgt_out_value)

+    def run_ps_to_ps_case(self):
+        # [Partial(), Shard(0)] --> [Replicate(), Shard(1)]
+        # c_allreduce_sum + all_gather + slice
+        main_program, dist_program = self.create_program(
+            [self.BATCH_SIZE, self.SEQ_LEN, self.HIDDEN_SIZE],
+            [dist.Partial(dist.ReduceType.kRedSum), dist.Shard(0)],
+            [dist.Replicate(), dist.Shard(1)],
+        )
+
+        ops = dist_program.global_block().ops
+        op_names = [op.name() for op in ops]
+        assert "pd_op.c_allreduce_sum_" in op_names
+        assert "pd_op.c_allgather" in op_names
+        assert "pd_op.slice" in op_names
+
+        allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum_")]
+        allgather_op = ops[op_names.index("pd_op.c_allgather")]
+        slice_op = ops[op_names.index("pd_op.slice")]
+
+        # check the allreduce_sum
+        rank_id = dist.get_rank()
+        if rank_id in [0, 2]:
+            process_ids = [0, 2]
+        elif rank_id in [1, 3]:
+            process_ids = [1, 3]
+        tgt_operand = (process_ids, [-1, -1, -1], {0: dist.ReduceType.kRedSum})
+        tgt_result = (process_ids, [-1, -1, -1], {})
+        tgt_in_value = (
+            self._mesh.process_ids,
+            [1, -1, -1],
+            {0: dist.ReduceType.kRedSum},
+        )
+        tgt_out_value = (self._mesh.process_ids, [1, -1, -1], {})
+        self.validate(
+            allreduce_sum_op,
+            tgt_operand,
+            tgt_result,
+            tgt_in_value,
+            tgt_out_value,
+        )
+
+        # check the allgather
+        if rank_id in [0, 1]:
+            process_ids = [0, 1]
+        elif rank_id in [2, 3]:
+            process_ids = [2, 3]
+        tgt_operand = (process_ids, [0, -1, -1], {})
+        tgt_result = (process_ids, [-1, -1, -1], {})
+        tgt_in_value = (self._mesh.process_ids, [1, -1, -1], {})
+        tgt_out_value = (self._mesh.process_ids, [-1, -1, -1], {})
+        self.validate(
+            allgather_op, tgt_operand, tgt_result, tgt_in_value, tgt_out_value
+        )
+
+        # check the slice
+        if rank_id in [0, 1]:
+            process_ids = [0, 1]
+        elif rank_id in [2, 3]:
+            process_ids = [2, 3]
+        tgt_operand = (process_ids, [-1, -1, -1], {})
+        tgt_result = (process_ids, [-1, 0, -1], {})
+        tgt_in_value = (self._mesh.process_ids, [-1, -1, -1], {})
+        tgt_out_value = (self._mesh.process_ids, [-1, 1, -1], {})
+        self.validate(
+            slice_op, tgt_operand, tgt_result, tgt_in_value, tgt_out_value
+        )
+
     def run_test_cases(self):
         self.run_pp_to_rr_case()
         self.run_pr_to_rs_case()
         self.run_pr_to_ss_case()
         self.run_ss_to_ss_case()
+        self.run_ps_to_ps_case()


 if __name__ == '__main__':
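These reshard cases assert against a 2x2 process mesh (ranks 0 through 3), so the script presumably has to run under a four-card distributed launch; a typical invocation would look like the following (the command is an assumption, not taken from the commit):

    python -m paddle.distributed.launch --devices=0,1,2,3 test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh.py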
