[AutoParallel] Fit send_recv overlap for vpp (#61541)
* fit send_recv overlap for vpp

* add test cases for using send_recv overlap
AndSonder authored Feb 5, 2024
1 parent 9882b25 commit 74d517c
Showing 3 changed files with 92 additions and 6 deletions.
9 changes: 7 additions & 2 deletions python/paddle/distributed/passes/pass_utils.py
@@ -665,8 +665,13 @@ def _split_ops(block):
     return [fwd_prog, bwd_prog, opt_prog]
 
 
-def _program_for_vpp(program, num_model_chunks, dist_context):
-    _insert_sync_for_fthenb_1f1b(program, dist_context)
+def _program_for_vpp(
+    program, num_model_chunks, dist_context, enable_send_recv_overlap=False
+):
+    if enable_send_recv_overlap:
+        _overlap_send_recv(program)
+    else:
+        _insert_sync_for_fthenb_1f1b(program, dist_context)
 
     oprole_type = {0: "forward", 1: "backward", 2: "optimizer"}
 
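Note: the overlap branch replaces the blocking synchronization that `_insert_sync_for_fthenb_1f1b` would otherwise insert after each pipeline send/recv. A minimal conceptual sketch of that kind of rewrite follows; this is not Paddle's actual `_overlap_send_recv`, and the op names and attribute toggle are assumptions for illustration only:

```python
# Conceptual sketch only -- NOT Paddle's actual _overlap_send_recv.
# Idea: instead of adding blocking sync ops after each send/recv, mark
# the communication ops to run off the calculation stream so that later
# computation does not implicitly wait on them.
def overlap_send_recv_sketch(program):
    for block in program.blocks:
        for op in block.ops:
            if op.type in ("send_v2", "recv_v2"):  # assumed op names
                # assumed attribute: detach comm from the calc stream,
                # removing the implicit barrier before the next compute op
                op._set_attr("use_calc_stream", False)
```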
3 changes: 2 additions & 1 deletion python/paddle/distributed/passes/pipeline_scheduler_pass.py
@@ -552,8 +552,9 @@ def _get_virtual_pp_rank(micro_step, forward):
     def _partial_programs(self, program):
         dist_context = self.get_attr("dist_context")
         num_model_chunks = self.get_attr("vpp_degree")
+        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
         types, sub_program_list = _program_for_vpp(
-            program, num_model_chunks, dist_context
+            program, num_model_chunks, dist_context, enable_send_recv_overlap
         )
         return types, sub_program_list
 
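For reference, the new attribute reaches `_partial_programs` through the pipeline strategy. A sketch of enabling it, mirroring `apply_pass` in the test diff below (the `auto` import path and the `pipeline.enable` line are assumptions based on similar auto-parallel tests):

```python
# Sketch: enabling send/recv overlap through the pipeline strategy.
# Import path and pipeline.enable are assumptions; the remaining fields
# mirror apply_pass() in the test file below.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
pipeline = strategy.pipeline
pipeline.enable = True                    # assumed prerequisite
pipeline.schedule_mode = "VPP"
pipeline.accumulate_steps = 4
pipeline.vpp_degree = 2
pipeline.vpp_seg_method = "MyLinear"
pipeline.enable_send_recv_overlap = True  # read via self.get_attr(...) above
```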
86 changes: 83 additions & 3 deletions test/auto_parallel/pipeline_scheduler_vpp_unittest.py
@@ -108,7 +108,7 @@ def loss_fn(pred, label):
     return loss
 
 
-def apply_pass(schedule_mode, acc_step):
+def apply_pass(schedule_mode, acc_step, enable_send_recv_overlap=False):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
@@ -119,6 +119,7 @@ def apply_pass(schedule_mode, acc_step):
     pipeline.accumulate_steps = acc_step
     pipeline.vpp_degree = 2
     pipeline.vpp_seg_method = "MyLinear"
+    pipeline.enable_send_recv_overlap = enable_send_recv_overlap
 
     return strategy
 
@@ -158,10 +159,16 @@ def init(self, engine):
         place = paddle.base.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)
 
-    def get_engine(self, schedule_mode, acc_step, manual=True):
+    def get_engine(
+        self,
+        schedule_mode,
+        acc_step,
+        manual=True,
+        enable_send_recv_overlap=False,
+    ):
         reset_prog()
 
-        strategy = apply_pass(schedule_mode, acc_step)
+        strategy = apply_pass(schedule_mode, acc_step, enable_send_recv_overlap)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model = MLPLayer(manual=manual)
@@ -206,6 +213,71 @@ def test_pp_pass(self):
 
         fw_chunk_ids = []
         bw_chunk_ids = []
+        for op in engine.main_program.global_block().ops:
+            if is_optimize_op(op):
+                break
+
+            dist_op = engine.dist_context.get_dist_op_for_program(op)
+            if is_forward_op(op):
+                fw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+            if is_backward_op(op):
+                bw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+
+        if paddle.distributed.get_rank() == 0:
+            self.assertEqual(sum(fw_chunk_ids), 9)
+            self.assertEqual(sum(bw_chunk_ids), 13)
+        else:
+            self.assertEqual(sum(fw_chunk_ids), 13)
+            self.assertEqual(sum(bw_chunk_ids), 19)
+
+        # pp2-vpp-manual-overlap
+        engine = self.get_engine(
+            schedule_mode="VPP",
+            acc_step=4,
+            manual=True,
+            enable_send_recv_overlap=True,
+        )
+        out_manual_overlap = engine.fit(
+            self.dataset, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine._strategy.pipeline.schedule_mode == "VPP"
+        assert engine._strategy.pipeline.enable_send_recv_overlap is True
+
+        fw_chunk_ids = []
+        bw_chunk_ids = []
+        for op in engine.main_program.global_block().ops:
+            if is_optimize_op(op):
+                break
+
+            dist_op = engine.dist_context.get_dist_op_for_program(op)
+            if is_forward_op(op):
+                fw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+            if is_backward_op(op):
+                bw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+
+        if paddle.distributed.get_rank() == 0:
+            self.assertEqual(sum(fw_chunk_ids), 8)
+            self.assertEqual(sum(bw_chunk_ids), 13)
+        else:
+            self.assertEqual(sum(fw_chunk_ids), 12)
+            self.assertEqual(sum(bw_chunk_ids), 19)
+
+        # pp2-vpp-auto-overlap
+        engine = self.get_engine(
+            schedule_mode="VPP",
+            acc_step=4,
+            manual=False,
+            enable_send_recv_overlap=True,
+        )
+        out_auto_overlap = engine.fit(
+            self.dataset, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine._strategy.pipeline.schedule_mode == "VPP"
+        assert engine._strategy.pipeline.enable_send_recv_overlap is True
+
+        fw_chunk_ids = []
+        bw_chunk_ids = []
+
         for op in engine.main_program.global_block().ops:
             if is_optimize_op(op):
                 break
@@ -228,6 +300,14 @@ def test_pp_pass(self):
 
         self.assertEqual(
             np.mean(out_manual.history["loss"][0]),
             np.mean(out_auto.history["loss"][0]),
         )
+        self.assertEqual(
+            np.mean(out_manual.history["loss"][0]),
+            np.mean(out_manual_overlap.history["loss"][0]),
+        )
+        self.assertEqual(
+            np.mean(out_manual.history["loss"][0]),
+            np.mean(out_auto_overlap.history["loss"][0]),
+        )
 


if __name__ == "__main__":
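The new cases reuse one verification pattern: summing `dist_attr.chunk_id` over the forward and backward ops yields a compact fingerprint of the VPP schedule. With overlap enabled, the expected forward sums drop (9→8 on rank 0, 13→12 on rank 1) while the backward sums stay at 13 and 19, consistent with the overlap path skipping the inserted sync ops. A standalone sketch of that pattern (the helper import path is an assumption; the test file imports these names):

```python
# Fingerprint check used by the new test cases, extracted for
# illustration. Helper import path is assumed from current Paddle trees.
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_forward_op,
    is_optimize_op,
)


def chunk_id_sums(engine):
    fw_sum, bw_sum = 0, 0
    for op in engine.main_program.global_block().ops:
        if is_optimize_op(op):
            break  # ops after the optimizer section are not fingerprinted
        dist_op = engine.dist_context.get_dist_op_for_program(op)
        if is_forward_op(op):
            fw_sum += dist_op.dist_attr.chunk_id
        if is_backward_op(op):
            bw_sum += dist_op.dist_attr.chunk_id
    return fw_sum, bw_sum
```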
