From edf5c2ff3964c7842c34660e3e8877b74f1a33ae Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 22 Sep 2022 15:50:01 +0800
Subject: [PATCH 1/5] sync recv for 1f1b

---
 .../fleet/meta_parallel/pipeline_parallel.py  | 25 +++++++----
 .../pp_utils/p2p_communication.py             | 42 ++++++++++++-------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 02a1b421526df..0a24c76aedc63 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -126,7 +126,8 @@ def forward_backward_pipeline(self, data, scaler=None):
         output_buffers = []
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
+                                            sync_recv=True)
 
             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -135,7 +136,8 @@ def forward_backward_pipeline(self, data, scaler=None):
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
+                                            sync_recv=True)
 
         for i in range(steady_steps):
             last_iter = (i == (steady_steps - 1))
@@ -167,7 +169,7 @@ def forward_backward_pipeline(self, data, scaler=None):
             output_tensor = output_buffers.pop(0)
 
             output_tensor_grad = p2p.recv_backward(
-                self.is_pipeline_last_stage())
+                self.is_pipeline_last_stage(), sync_recv=True)
 
             input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                     output_tensor_grad)
@@ -237,7 +239,8 @@ def eval_batch(self, data, compute_loss=False):
         output_buffers = []
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
+                                            sync_recv=True)
 
             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -246,7 +249,8 @@ def eval_batch(self, data, compute_loss=False):
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
+                                            sync_recv=True)
 
         for i in range(steady_steps):
             last_iter = (i == (steady_steps - 1))
@@ -258,7 +262,8 @@ def eval_batch(self, data, compute_loss=False):
             output_buffers.append(output_tensor)
 
             if not last_iter:
-                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
+                                                sync_recv=True)
 
         if self._compute_loss:
             self.train_loss = self._broadcast_final_loss()
@@ -525,8 +530,9 @@ def interleave_pipeline(self,
             steady_steps = num_steps - startup_steps
 
         self.set_virtual_pipeline_rank(0)
-        self.input_tensors[0].append(
-            p2p.recv_forward(self.is_pipeline_first_stage()))
+        self.input_tensors[0].append(p2p.recv_forward(
+            self.is_pipeline_first_stage()),
+                                     sync_recv=False)
 
         # run startup steps
         for micro_step in range(startup_steps):
@@ -647,7 +653,8 @@ def interleave_pipeline(self,
         if not forward_only:
             if all_startup_steps:
                 self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(self.is_pipeline_last_stage()))
+                    p2p.recv_backward(self.is_pipeline_last_stage(),
+                                      sync_recv=False))
 
             for micro_step in range(steady_steps, num_steps):
                 # cooldown loop
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 160c5f1511220..4de4c505941cc 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -216,8 +216,11 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.recv_partial(tensor, src_rank_in_group,
+        task = group.process_group.recv_partial(tensor, src_rank_in_group,
                                                 nranks, rank_id)
+        if use_calc_stream:
+            task.wait()
+        return task
 
 
 def recv_partial(tensor,
@@ -275,7 +278,11 @@ def allgather_partial(tensor,
                                  nranks, rank_id)
 
 
-def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _p2p_helper(tensor_send_next,
+                tensor_send_prev,
+                recv_prev,
+                recv_next,
+                sync_recv=True):
     global _hcg
 
     tensor_recv_prev = None
@@ -354,7 +361,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  nranks=mp_degree,
                                  rank_id=mp_rank,
                                  group=_hcg.recv_prev_group,
-                                 use_calc_stream=True))
+                                 use_calc_stream=sync_recv))
         else:
             tasks.append(
                 recv_partial(tensor_recv_prev,
@@ -362,7 +369,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.recv_prev_group,
-                             use_calc_stream=True))
+                             use_calc_stream=sync_recv))
 
     if tensor_send_next is not None:
         if isinstance(tensor_send_next, tuple):
@@ -394,7 +401,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                                  nranks=mp_degree,
                                  rank_id=mp_rank,
                                  group=_hcg.recv_next_group,
-                                 use_calc_stream=True))
+                                 use_calc_stream=sync_recv))
 
         else:
             tasks.append(
@@ -403,10 +410,10 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                              nranks=mp_degree,
                              rank_id=mp_rank,
                              group=_hcg.recv_next_group,
-                             use_calc_stream=True))
+                             use_calc_stream=sync_recv))
 
-    if in_dygraph_mode():
-        # wait isend/irecv tasks in eager dygraph mode with new comm library
+    if not sync_recv and in_dygraph_mode():
+        # wait irecv tasks in eager dygraph mode with new comm library
         for task in tasks:
             assert task is not None
             task.wait()
@@ -443,7 +450,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(pp_first_stage):
+def recv_forward(pp_first_stage, sync_recv=True):
     if pp_first_stage:
         input_tensor = None
     else:
@@ -454,18 +461,20 @@ def recv_forward(pp_first_stage):
         input_tensor, _ = _p2p_helper(tensor_send_next=None,
                                       tensor_send_prev=None,
                                       recv_prev=True,
-                                      recv_next=False)
+                                      recv_next=False,
+                                      sync_recv=sync_recv)
     return input_tensor
 
 
-def recv_backward(pp_last_stage):
+def recv_backward(pp_last_stage, sync_recv=True):
     if pp_last_stage:
         output_tensor_grad = None
     else:
         _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
                                             tensor_send_prev=None,
                                             recv_prev=False,
-                                            recv_next=True)
+                                            recv_next=True,
+                                            sync_recv=sync_recv)
     return output_tensor_grad
 
 
@@ -527,7 +536,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next)
+        recv_next=recv_next,
+        sync_recv=False)
     return input_tensor, output_tensor_grad
 
 
@@ -544,7 +554,8 @@ def send_forward_recv_forward(output_tensor, recv_prev):
     input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor,
                                   tensor_send_prev=None,
                                   recv_prev=recv_prev,
-                                  recv_next=False)
+                                  recv_next=False,
+                                  sync_recv=False)
     return input_tensor
 
 
@@ -553,5 +564,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next):
     _, output_tensor_grad = _p2p_helper(tensor_send_next=None,
                                         tensor_send_prev=input_tensor_grad,
                                         recv_prev=False,
-                                        recv_next=recv_next)
+                                        recv_next=recv_next,
+                                        sync_recv=False)
     return output_tensor_grad

From 1398183b8e1d4fca91c0645283d632cb114042a0 Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 22 Sep 2022 15:53:00 +0800
Subject: [PATCH 2/5] update

---
 .../fleet/meta_parallel/pipeline_parallel.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 0a24c76aedc63..ae7c29f2f09fc 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -126,8 +126,7 @@ def forward_backward_pipeline(self, data, scaler=None):
         output_buffers = []
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
-                                            sync_recv=True)
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -136,8 +135,7 @@ def forward_backward_pipeline(self, data, scaler=None):
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
-                                            sync_recv=True)
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
         for i in range(steady_steps):
             last_iter = (i == (steady_steps - 1))
@@ -169,7 +167,7 @@ def forward_backward_pipeline(self, data, scaler=None):
             output_tensor = output_buffers.pop(0)
 
             output_tensor_grad = p2p.recv_backward(
-                self.is_pipeline_last_stage(), sync_recv=True)
+                self.is_pipeline_last_stage())
 
             input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                     output_tensor_grad)
@@ -239,8 +237,7 @@ def eval_batch(self, data, compute_loss=False):
         output_buffers = []
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
-                                            sync_recv=True)
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
@@ -249,8 +246,7 @@ def eval_batch(self, data, compute_loss=False):
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
-                                            sync_recv=True)
+            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
         for i in range(steady_steps):
             last_iter = (i == (steady_steps - 1))
@@ -262,8 +258,7 @@ def eval_batch(self, data, compute_loss=False):
             output_buffers.append(output_tensor)
 
             if not last_iter:
-                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage(),
-                                                sync_recv=True)
+                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
 
         if self._compute_loss:
             self.train_loss = self._broadcast_final_loss()

From 98229c6f4822086dfa7f9fb0b54001886f6c84af Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 22 Sep 2022 16:00:05 +0800
Subject: [PATCH 3/5] for general pp

---
 .../fleet/meta_parallel/pp_utils/p2p_communication.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 4de4c505941cc..a1004a872e47d 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -241,7 +241,7 @@ def recv_partial(tensor,
         return _partial_recv_op(tensor, group, use_calc_stream, ring_id,
                                 src_rank, nranks, rank_id)
     else:
-        if _in_legacy_dygraph():
+        if _in_legacy_dygraph() or use_calc_stream:
             recv_op = paddle.distributed.recv
         elif in_dygraph_mode():
             recv_op = paddle.distributed.irecv

From 194b8ba574c95dbde7cd297058cc047597328dfc Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 22 Sep 2022 16:06:26 +0800
Subject: [PATCH 4/5] add assertion

---
 .../fleet/meta_parallel/pp_utils/p2p_communication.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index a1004a872e47d..7962e2dd4373e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -207,6 +207,7 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
                      rank_id):
     src_rank_in_group = src if group is None else group.get_group_rank(src)
     if _in_legacy_dygraph():
+        assert use_calc_stream
        return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream',
                                          use_calc_stream, 'ring_id', ring_id,
                                          'peer', src_rank_in_group, 'num',

From efb5bab424f0ad208e39dc5aaabc01d690c2554a Mon Sep 17 00:00:00 2001
From: liuyuang
Date: Thu, 22 Sep 2022 17:00:44 +0800
Subject: [PATCH 5/5] bug fix

---
 .../distributed/fleet/meta_parallel/pipeline_parallel.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index ae7c29f2f09fc..56429b748064d 100755
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -525,9 +525,8 @@ def interleave_pipeline(self,
             steady_steps = num_steps - startup_steps
 
         self.set_virtual_pipeline_rank(0)
-        self.input_tensors[0].append(p2p.recv_forward(
-            self.is_pipeline_first_stage()),
-                                     sync_recv=False)
+        self.input_tensors[0].append(
+            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False))
 
         # run startup steps
         for micro_step in range(startup_steps):
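Illustrative note (not part of the patch series above): the series hinges on the difference between a blocking receive on the calculation stream (paddle.distributed.recv, the sync_recv=True path used by 1F1B) and an asynchronous paddle.distributed.irecv whose returned task must be waited on before the buffer is read (the sync_recv=False path used by the interleaved schedule). The following is a minimal two-rank sketch of that pattern, assuming a job launched with paddle.distributed.launch; the helper name recv_activation is hypothetical and not part of the patched code.

import paddle
import paddle.distributed as dist


def recv_activation(buf, src, sync_recv=True):
    """Hypothetical helper mirroring the recv pattern introduced above."""
    if sync_recv:
        # Blocking recv on the calculation stream: the tensor is ready as
        # soon as the call returns, matching the 1F1B schedule's needs.
        dist.recv(buf, src=src)
        return None
    # Async recv: communication may overlap with compute, but the caller
    # must wait() on the returned task before touching `buf`.
    return dist.irecv(buf, src=src)


if __name__ == "__main__":
    dist.init_parallel_env()
    buf = paddle.zeros([4], dtype="float32")
    if dist.get_rank() == 0:
        dist.send(paddle.arange(4, dtype="float32"), dst=1)
    else:
        task = recv_activation(buf, src=0, sync_recv=False)
        if task is not None:
            task.wait()
        print(buf.numpy())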