Skip to content

Commit

Permalink
adapt for resnet (#44685)
Browse the repository at this point in the history
  • Loading branch information
zhaoyinglia authored Jul 28, 2022
1 parent a9f76d0 commit 2cec4c8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ def complete_update_annotation(self, serial_main_program):
op_dist_attr.set_output_dims_mapping(
input_var.name, [-1])
else:
assert "Moment" in input_name
assert "Moment" in input_name or "Velocity" in input_name
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(
input_var.name, ref_dims_mapping)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,10 @@ def backward(ctx, *args, **kwargs):
rank_id = _get_corresponding_rank(
ctx, process_mesh, rank_id)

# NOTE: consider that the variable's shape is None
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(
var_dim_mapping) > 0 else -1
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.processes,
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,14 +1058,16 @@ def set_grad_var_shape(program, dist_context):
"dropout_grad", "tanh_grad", "slice", "assign",
"matmul_v2_triple_grad", "elementwise_add_triple_grad",
"fill_constant", "sqrt_grad",
"fused_softmax_mask_upper_triangle_grad"
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad", "relu_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt",
"fused_softmax_mask_upper_triangle"
"fused_softmax_mask_upper_triangle", "flatten_contiguous_range",
"relu"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
Expand Down

0 comments on commit 2cec4c8

Please sign in to comment.