From 2cec4c88ce54177685c00a3ec4a3398fc1a691e8 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Thu, 28 Jul 2022 14:01:15 +0800 Subject: [PATCH] adapt for resnet (#44685) --- python/paddle/distributed/auto_parallel/completion.py | 2 +- .../distributed/auto_parallel/operators/dist_default.py | 4 +++- python/paddle/distributed/auto_parallel/utils.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index d18c05a058eea..4eb2f45cc1859 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1396,7 +1396,7 @@ def complete_update_annotation(self, serial_main_program): op_dist_attr.set_output_dims_mapping( input_var.name, [-1]) else: - assert "Moment" in input_name + assert "Moment" in input_name or "Velocity" in input_name input_var_attr.dims_mapping = ref_dims_mapping op_dist_attr.set_input_dims_mapping( input_var.name, ref_dims_mapping) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 9b288d36e46eb..d0eba355e7bec 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -555,8 +555,10 @@ def backward(ctx, *args, **kwargs): rank_id = _get_corresponding_rank( ctx, process_mesh, rank_id) + # NOTE: consider that the variable's shape is None mesh_shape = process_mesh.topology - batch_size_axis = var_dim_mapping[0] + batch_size_axis = var_dim_mapping[0] if len( + var_dim_mapping) > 0 else -1 if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: need_gradient_allreduce = True group_ranks = _get_comm_group(process_mesh.processes, diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index c4f9ad8b6bc84..5d9499d9286f3 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1058,14 +1058,16 @@ def set_grad_var_shape(program, dist_context): "dropout_grad", "tanh_grad", "slice", "assign", "matmul_v2_triple_grad", "elementwise_add_triple_grad", "fill_constant", "sqrt_grad", - "fused_softmax_mask_upper_triangle_grad" + "fused_softmax_mask_upper_triangle_grad", + "flatten_contiguous_range_grad", "relu_grad" ] forward_list = [ "reshape2", "softmax_with_cross_entropy", "transpose2", "softmax", "cross_entropy2", "dropout", "tanh", ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", "elementwise_add_grad_grad", "shape", "sqrt", - "fused_softmax_mask_upper_triangle" + "fused_softmax_mask_upper_triangle", "flatten_contiguous_range", + "relu" ] if op.type in need_set_shape_list: for forward_op in block.ops: