From a36cdd6b6cd95b0a4c708661dc22782481e38e7e Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 27 Feb 2023 10:25:31 +0800 Subject: [PATCH] [AutoParallel] add dist_attr in data_parallel optimization (#49744) * fix dist_attr in data_parallel in optimization * fix grad_clip pass when pp2 * fix dist_attr --- .../distributed/auto_parallel/dist_op.py | 18 +- .../paddle/distributed/auto_parallel/utils.py | 33 ++- ...uto_parallel_data_parallel_optimization.py | 209 +++++++++--------- .../passes/auto_parallel_grad_clip.py | 2 + ...rallel_supplement_explicit_dependencies.py | 3 - 5 files changed, 148 insertions(+), 117 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index ac39d62a30d91..89fd71df129c1 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -221,7 +221,14 @@ def __str__(self): ) for arg_name in self.serial_op.desc.input_arg_names(): - dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) + try: + dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) + except IndexError: + raise IndexError( + "There is not input var '{}''s dist_attr in current op '{}'".format( + arg_name, self.serial_op.desc.type() + ) + ) if self.dist_attr.is_annotated_input_dims_mapping(arg_name): annotated_str = "annotated" else: @@ -238,7 +245,14 @@ def __str__(self): ) for arg_name in self.serial_op.desc.output_arg_names(): - dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) + try: + dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) + except IndexError: + raise IndexError( + "There is not output var '{}''s dist_attr in current op '{}'".format( + arg_name, self.serial_op.desc.type() + ) + ) if self.dist_attr.is_annotated_output_dims_mapping(arg_name): annotated_str = "annotated" else: diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 4139051800ebb..1792fd51f9323 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1426,9 +1426,6 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( def naive_set_dist_op_attr_for_program_by_mesh( new_op, process_mesh, ctx, is_recompute=False ): - # hack to skip coalesce var for dist attr - if not is_recompute: - return assert process_mesh is not None new_op_dist_attr = OperatorDistAttr() @@ -2314,15 +2311,31 @@ def insert_dependencies_for_vars( }, outputs={"Out": post_vars}, ) - - # depend_op.desc.set_type("depend") depend_op._set_attr(OP_ROLE_KEY, oprole) - # depend_op.desc.set_input("Dep", [first_var.name]) - # self.desc.set_output(out_proto.name, out_arg_names) - naive_set_dist_op_attr_for_program_by_mesh( - depend_op, process_mesh, dist_context, is_recompute - ) + # TODO: condition can be removed when add correct dist_attr for coalesce vars and ops in sharding_pass + if is_recompute or process_mesh != [-1]: + depend_op_dist_attr = OperatorDistAttr() + depend_op_dist_attr.impl_idx = 0 + depend_op_dist_attr.impl_type = "default" + depend_op_dist_attr.process_mesh = process_mesh + depend_op_dist_attr.is_recompute = is_recompute + for input_varname in depend_op.desc.input_arg_names(): + var = block.var(input_varname) + mapping = dist_context.get_tensor_dist_attr_for_program( + var + ).dims_mapping + depend_op_dist_attr.set_input_dims_mapping(input_varname, mapping) + for output_varname in 
depend_op.desc.output_arg_names(): + var = block.var(output_varname) + mapping = dist_context.get_tensor_dist_attr_for_program( + var + ).dims_mapping + depend_op_dist_attr.set_output_dims_mapping(output_varname, mapping) + dist_context.set_op_dist_attr_for_program( + depend_op, depend_op_dist_attr + ) + if op_namescope is not None: depend_op._set_attr('op_namescope', "/{}".format(op_namescope)) diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index e9acb9074fee0..dbfaddf2917f9 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -15,10 +15,15 @@ from collections import OrderedDict import paddle +from paddle.distributed.auto_parallel.dist_attribute import ( + OperatorDistAttr, + TensorDistAttr, +) from paddle.distributed.auto_parallel.operators.common import ( is_data_parallel_reduce_op, is_data_parallel_scale_op, ) +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.utils import ( find_higher_order_backward_op, get_var_numel, @@ -463,6 +468,21 @@ def _update_program(self, grad_groups): group.coalesce_var = group.gradients[0] continue + ref_process_mesh = set() + concated_shapes = [] + concated_ranks = [] + for grad_ in group.gradients: + grad_dist_attr = ( + self.dist_context.get_tensor_dist_attr_for_program(grad_) + ) + ref_process_mesh.update( + set(grad_dist_attr.process_mesh.process_ids) + ) + + shape = grad_.shape + concated_shapes.extend(shape) + concated_ranks.append(len(shape)) + # create coalesce tensor group.coalesce_var = block.create_var( name=unique_name.generate( @@ -473,6 +493,13 @@ def _update_program(self, grad_groups): stop_gradient=True, ) + tensor_dist_attr = TensorDistAttr() + tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh)) + tensor_dist_attr.dims_mapping = [] + self.dist_context.set_tensor_dist_attr_for_program( + group.coalesce_var, tensor_dist_attr + ) + # update allreduce & scale op if group.scale_op_idx != -1: scale_op = block.ops[group.scale_op_idx] @@ -492,11 +519,27 @@ def _update_program(self, grad_groups): ), "should found c_allreduce_sum op but found {}".format( str(allreduce_op) ) - allreduce_op._rename_input( - allreduce_op.input_arg_names[0], group.coalesce_var.name + allreduce_op_dist_attr = ( + self.dist_context.get_op_dist_attr_for_program(allreduce_op) + ) + old_in_name = allreduce_op.input_arg_names[0] + new_in_name = group.coalesce_var.name + allreduce_op._rename_input(old_in_name, new_in_name) + input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr( + old_in_name ) - allreduce_op._rename_output( - allreduce_op.output_arg_names[0], group.coalesce_var.name + allreduce_op_dist_attr.set_input_dist_attr( + new_in_name, input_dist_attr + ) + + old_out_name = allreduce_op.output_arg_names[0] + new_out_name = group.coalesce_var.name + allreduce_op._rename_output(old_out_name, new_out_name) + out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr( + old_out_name + ) + allreduce_op_dist_attr.set_output_dist_attr( + new_out_name, out_dist_attr ) # remvoe un-used op @@ -512,15 +555,8 @@ def _update_program(self, grad_groups): block._remove_op(idx, False) # insert coalesce op - concated_shapes = [] - concated_ranks = [] - for grad_ in group.gradients: - shape = grad_.shape - concated_shapes.extend(shape) - 
concated_ranks.append(len(shape)) - grad_names = [grad.name for grad in group.gradients] - block._insert_op_without_sync( + coalesce_op = block._insert_op_without_sync( group.coalesce_op_idx, type="coalesce_tensor", inputs={"Input": grad_names}, @@ -538,8 +574,32 @@ def _update_program(self, grad_groups): }, ) + op_dist_attr = OperatorDistAttr() + op_dist_attr.impl_idx = 0 + op_dist_attr.impl_type = "default" + op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh)) + for in_name in coalesce_op.input_arg_names: + in_var = block.var(in_name) + in_var_dist_attr = ( + self.dist_context.get_tensor_dist_attr_for_program(in_var) + ) + op_dist_attr.set_input_dims_mapping( + in_name, in_var_dist_attr.dims_mapping + ) + for out_name in coalesce_op.output_arg_names: + out_var = block.var(out_name) + out_var_dist_attr = ( + self.dist_context.get_tensor_dist_attr_for_program(out_var) + ) + op_dist_attr.set_output_dims_mapping( + out_name, out_var_dist_attr.dims_mapping + ) + + self.dist_context.set_op_dist_attr_for_program( + coalesce_op, op_dist_attr + ) + block._sync_with_cpp() - # TODO update dist attr def _add_dependencies(self, grad_groups): # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based @@ -551,22 +611,12 @@ def _add_dependencies(self, grad_groups): block = default_main_program().global_block() # Build maps - vars_to_coalesce_map = {} coalesce_to_vars_map = {} - for group in grad_groups: - grad_names = [] - coalesce_name = group.coalesce_var.name - for grad in group.gradients: - vars_to_coalesce_map[grad.name] = coalesce_name - grad_names.append(grad.name) - coalesce_to_vars_map[coalesce_name] = grad_names + coalesce_to_vars_map[group.coalesce_var.name] = group # analyze dependencies - # Record ONLY the last grad that generated before allreduce - # NOTE need to be update when we allow multiple calc stream for backward calc - not_sync_coalesces = [] - prior_allreduce_deps = {} + dep_map = {} for idx, op in reversed(list(enumerate(block.ops))): if is_forward_op(op): break @@ -575,86 +625,41 @@ def _add_dependencies(self, grad_groups): if is_data_parallel_reduce_op(op): coalesce_var_name = op.output_arg_names[0] - - # NOTE only add extra deps for fused tensor, other tensor rely on - # data flow analysis of executor. 
- if self.coalesce_prefix in coalesce_var_name: - prior_allreduce_deps[coalesce_var_name] = [ - idx, - None, - coalesce_var_name, - ] - not_sync_coalesces.append(coalesce_var_name) - continue - - for out_name in op.output_arg_names: - var_name = vars_to_coalesce_map.get(out_name, None) - if var_name in not_sync_coalesces: - prior_allreduce_deps[var_name][1] = out_name - not_sync_coalesces.remove(var_name) - assert ( - len(not_sync_coalesces) == 0 - ), "Unexpected: {} has NOT been add prior Dep before allreduce.".format( - not_sync_coalesces - ) - - # Record ONLY the first grad that used after allreduce - # NOTE need to be update when we allow multiple calc stream for backward calc - not_sync_coalesces = [] - post_allreduce_deps = {} - for idx, op in enumerate(block.ops): - if is_forward_op(op): - continue - - if is_data_parallel_reduce_op(op): - coalesce_var_name = op.input_arg_names[0] if self.coalesce_prefix in coalesce_var_name: - post_allreduce_deps[coalesce_var_name] = [ - None, - coalesce_var_name, - None, + group = coalesce_to_vars_map[coalesce_var_name] + dep_map[idx] = [ + ( + idx, + group.gradients[-1], + group.coalesce_var, + op.attr(OP_ROLE_KEY), + ) ] - not_sync_coalesces.append(coalesce_var_name) - continue - - for out_name in op.input_arg_names: - var_name = vars_to_coalesce_map.get(out_name, None) - if var_name in not_sync_coalesces: - post_allreduce_deps[var_name][0] = idx - post_allreduce_deps[var_name][2] = out_name - not_sync_coalesces.remove(var_name) - - assert ( - len(not_sync_coalesces) == 0 - ), "Unexpected: {} has NOT been add post Dep after allreduce.".format( - not_sync_coalesces - ) + dep_map[idx].append( + ( + idx + 1, + group.coalesce_var, + group.gradients, + op.attr(OP_ROLE_KEY), + ) + ) - # Update program IR insert dependencise op - dep_var_pairs = [] - for deps in [prior_allreduce_deps, post_allreduce_deps]: - for pair in deps.values(): - dep_var_pairs.append(pair) - - dep_var_pairs.sort(key=lambda x: x[0], reverse=True) - for idx, prior_name, post_name in dep_var_pairs: - prior_var = block.var(prior_name) - post_var = block.var(post_name) - depend_op = insert_dependencies_for_vars( - block, - idx, - prior_var, - post_var, - self.dist_context, - OpRole.Backward, - process_mesh=[ - -1 - ], # hack to avoid initialize the dist attr for coalesce var - is_recompute=False, - sync=False, - op_namescope="data_parallel_overlap_dep", - ) - depend_op.dist_attr.execution_stream = self.gradient_sync_stream + # insert dependency op + indice = sorted(list(dep_map.keys()), reverse=True) + for i in indice: + for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]: + depend_op = insert_dependencies_for_vars( + block, + idx, + prior_vars, + post_vars, + self.dist_context, + op_role, + is_recompute=False, + sync=False, + op_namescope="data_parallel_overlap_dep", + ) + depend_op.dist_attr.execution_stream = self.gradient_sync_stream block._sync_with_cpp() # remove naive synchronization & assign allreduce stream diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 25a768e94dfaa..5a0dd9c5e39e1 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -254,6 +254,8 @@ def _is_pure_data_parallel(self): "c_allreduce_sum", ] and not is_data_parallel_reduce_op(op): return False + if op.type in ["send_v2", "recv_v2"]: + return False return True diff --git 
a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py index 07de0c1dcbfeb..7f9ed86b18b5d 100644 --- a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py +++ b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py @@ -150,9 +150,6 @@ def _apply_single_impl(self, main_program, startup_program, context): post_var, self._dist_context, OpRole.Optimize, - process_mesh=[ - -1 - ], # hack to avoid initialize the dist attr for coalesc var is_recompute=False, sync=False, op_namescope=op_namescope,
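
For reference, the core pattern this patch adds in DataParallelOptimizationPass._update_program is to give the coalesced gradient buffer and the inserted coalesce_tensor op their own dist_attr, so that later passes (for example the grad_clip pass touched above) can query dims_mapping without hitting the IndexError now raised in dist_op.py. The sketch below is illustrative only: dist_context, block, group, and coalesce_op are assumed to be the objects used in the hunks above, and the helper name annotate_coalesce_var_and_op is hypothetical.

    # Minimal sketch, assuming `group` is a gradient group exposing `gradients`
    # and `coalesce_var`, and `coalesce_op` is the inserted coalesce_tensor op.
    from paddle.distributed.auto_parallel.dist_attribute import (
        OperatorDistAttr,
        TensorDistAttr,
    )
    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

    def annotate_coalesce_var_and_op(dist_context, block, group, coalesce_op):
        # Union of the process ids of every gradient folded into the buffer.
        ref_process_ids = set()
        for grad in group.gradients:
            grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(grad)
            ref_process_ids.update(grad_dist_attr.process_mesh.process_ids)

        # 1. Tensor dist_attr for the coalesced buffer; dims_mapping is left
        #    empty, mirroring the hunk above.
        tensor_dist_attr = TensorDistAttr()
        tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_ids))
        tensor_dist_attr.dims_mapping = []
        dist_context.set_tensor_dist_attr_for_program(
            group.coalesce_var, tensor_dist_attr
        )

        # 2. Op dist_attr for the coalesce_tensor op, copying the dims_mapping
        #    of every input/output var it touches.
        op_dist_attr = OperatorDistAttr()
        op_dist_attr.impl_idx = 0
        op_dist_attr.impl_type = "default"
        op_dist_attr.process_mesh = ProcessMesh(list(ref_process_ids))
        for in_name in coalesce_op.input_arg_names:
            in_attr = dist_context.get_tensor_dist_attr_for_program(block.var(in_name))
            op_dist_attr.set_input_dims_mapping(in_name, in_attr.dims_mapping)
        for out_name in coalesce_op.output_arg_names:
            out_attr = dist_context.get_tensor_dist_attr_for_program(block.var(out_name))
            op_dist_attr.set_output_dims_mapping(out_name, out_attr.dims_mapping)
        dist_context.set_op_dist_attr_for_program(coalesce_op, op_dist_attr)

The process mesh is taken as the union of the process ids of all grouped gradients, matching how ref_process_mesh is collected in the hunk above; the same OperatorDistAttr recipe is what insert_dependencies_for_vars now builds for the depend ops it emits, which is why the process_mesh=[-1] hack could be dropped from the callers.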