diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py
index e38bd6bf82cf2..0e5ab4f5c5133 100644
--- a/python/paddle/distributed/passes/pass_utils.py
+++ b/python/paddle/distributed/passes/pass_utils.py
@@ -665,8 +665,13 @@ def _split_ops(block):
     return [fwd_prog, bwd_prog, opt_prog]
 
 
-def _program_for_vpp(program, num_model_chunks, dist_context):
-    _insert_sync_for_fthenb_1f1b(program, dist_context)
+def _program_for_vpp(
+    program, num_model_chunks, dist_context, enable_send_recv_overlap=False
+):
+    if enable_send_recv_overlap:
+        _overlap_send_recv(program)
+    else:
+        _insert_sync_for_fthenb_1f1b(program, dist_context)
 
     oprole_type = {0: "forward", 1: "backward", 2: "optimizer"}
 
diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
index 414032290f381..3b7e76cc4d791 100644
--- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py
+++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
@@ -552,8 +552,9 @@ def _get_virtual_pp_rank(micro_step, forward):
     def _partial_programs(self, program):
         dist_context = self.get_attr("dist_context")
         num_model_chunks = self.get_attr("vpp_degree")
+        enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
         types, sub_program_list = _program_for_vpp(
-            program, num_model_chunks, dist_context
+            program, num_model_chunks, dist_context, enable_send_recv_overlap
         )
 
         return types, sub_program_list
diff --git a/test/auto_parallel/pipeline_scheduler_vpp_unittest.py b/test/auto_parallel/pipeline_scheduler_vpp_unittest.py
index bed72232a05ca..c54359fb272c2 100644
--- a/test/auto_parallel/pipeline_scheduler_vpp_unittest.py
+++ b/test/auto_parallel/pipeline_scheduler_vpp_unittest.py
@@ -108,7 +108,7 @@ def loss_fn(pred, label):
     return loss
 
 
-def apply_pass(schedule_mode, acc_step):
+def apply_pass(schedule_mode, acc_step, enable_send_recv_overlap=False):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
@@ -119,6 +119,7 @@ def apply_pass(schedule_mode, acc_step):
     pipeline.accumulate_steps = acc_step
     pipeline.vpp_degree = 2
     pipeline.vpp_seg_method = "MyLinear"
+    pipeline.enable_send_recv_overlap = enable_send_recv_overlap
 
     return strategy
 
@@ -158,10 +159,16 @@ def init(self, engine):
         place = paddle.base.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)
 
-    def get_engine(self, schedule_mode, acc_step, manual=True):
+    def get_engine(
+        self,
+        schedule_mode,
+        acc_step,
+        manual=True,
+        enable_send_recv_overlap=False,
+    ):
         reset_prog()
 
-        strategy = apply_pass(schedule_mode, acc_step)
+        strategy = apply_pass(schedule_mode, acc_step, enable_send_recv_overlap)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model = MLPLayer(manual=manual)
@@ -206,6 +213,71 @@ def test_pp_pass(self):
 
         fw_chunk_ids = []
         bw_chunk_ids = []
+        for op in engine.main_program.global_block().ops:
+            if is_optimize_op(op):
+                break
+
+            dist_op = engine.dist_context.get_dist_op_for_program(op)
+            if is_forward_op(op):
+                fw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+            if is_backward_op(op):
+                bw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+
+        if paddle.distributed.get_rank() == 0:
+            self.assertEqual(sum(fw_chunk_ids), 9)
+            self.assertEqual(sum(bw_chunk_ids), 13)
+        else:
+            self.assertEqual(sum(fw_chunk_ids), 13)
+            self.assertEqual(sum(bw_chunk_ids), 19)
+
+        # pp2-vpp-manual-overlap
+        engine = self.get_engine(
+            schedule_mode="VPP",
+            acc_step=4,
+            manual=True,
+            enable_send_recv_overlap=True,
+        )
+        out_manual_overlap = engine.fit(
+            self.dataset, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine._strategy.pipeline.schedule_mode == "VPP"
+        assert engine._strategy.pipeline.enable_send_recv_overlap is True
+
+        fw_chunk_ids = []
+        bw_chunk_ids = []
+        for op in engine.main_program.global_block().ops:
+            if is_optimize_op(op):
+                break
+
+            dist_op = engine.dist_context.get_dist_op_for_program(op)
+            if is_forward_op(op):
+                fw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+            if is_backward_op(op):
+                bw_chunk_ids.append(dist_op.dist_attr.chunk_id)
+
+        if paddle.distributed.get_rank() == 0:
+            self.assertEqual(sum(fw_chunk_ids), 8)
+            self.assertEqual(sum(bw_chunk_ids), 13)
+        else:
+            self.assertEqual(sum(fw_chunk_ids), 12)
+            self.assertEqual(sum(bw_chunk_ids), 19)
+
+        # pp2-vpp-auto-overlap
+        engine = self.get_engine(
+            schedule_mode="VPP",
+            acc_step=4,
+            manual=False,
+            enable_send_recv_overlap=True,
+        )
+        out_auto_overlap = engine.fit(
+            self.dataset, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine._strategy.pipeline.schedule_mode == "VPP"
+        assert engine._strategy.pipeline.enable_send_recv_overlap is True
+
+        fw_chunk_ids = []
+        bw_chunk_ids = []
+        for op in engine.main_program.global_block().ops:
             if is_optimize_op(op):
                 break
 
@@ -228,6 +300,14 @@ def test_pp_pass(self):
             np.mean(out_manual.history["loss"][0]),
             np.mean(out_auto.history["loss"][0]),
         )
+        self.assertEqual(
+            np.mean(out_manual.history["loss"][0]),
+            np.mean(out_manual_overlap.history["loss"][0]),
+        )
+        self.assertEqual(
+            np.mean(out_manual.history["loss"][0]),
+            np.mean(out_auto_overlap.history["loss"][0]),
+        )
 
 
 if __name__ == "__main__":
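
For readers of the patch, a minimal usage sketch of the new knob (not part of the diff), mirroring `apply_pass()` above and assuming the test file's import style (`from paddle.distributed.fleet import auto`). Setting `pipeline.enable_send_recv_overlap = True` is the only user-facing change: the `VPP` scheduler pass reads the attribute via `self.get_attr("enable_send_recv_overlap")` in `_partial_programs` and forwards it to `_program_for_vpp`, which then calls `_overlap_send_recv(program)` instead of inserting the FThenB/1F1B sync dependencies.

```python
# Usage sketch only; field values copied from apply_pass() in the test above.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"

pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "VPP"
pipeline.accumulate_steps = 4        # micro-batches per step, as in the test
pipeline.vpp_degree = 2              # two virtual model chunks per pp rank
pipeline.vpp_seg_method = "MyLinear" # layer class name used to segment chunks
# New in this patch: overlap p2p send/recv instead of adding blocking syncs.
pipeline.enable_send_recv_overlap = True
```

The test asserts that the losses with overlap enabled match the non-overlap runs exactly, i.e. the flag changes only communication scheduling, not numerics.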