diff --git a/paddle/phi/infermeta/spmd_rules/reshape.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc index 9ca886f0dc637..20e1120aa6fda 100644 --- a/paddle/phi/infermeta/spmd_rules/reshape.cc +++ b/paddle/phi/infermeta/spmd_rules/reshape.cc @@ -125,6 +125,17 @@ std::vector> MakeReshapeDimTrans( for (auto in_dim : src_dims) { if (src_shape[in_dim] > 1) { input_dims.emplace_back(std::make_shared(in_dim)); + } else if (src_shape[in_dim] == 1 && s == 1 && t == 1) { + // NOTE: for the case like: + // shape: [1, 512, 4096] --> [1, 2, 256, 4096], + // input dims_mapping: [0, 1, -1] + // expected output dims_mapping: [0, 1, -1, -1] (not [-1, 1, -1, + // -1]) + // In this case, the dim0 in target shape is 1 and it is from + // dim0 in source shape. make the dim0's transformation be InputDim + // rather than Singleton so that the sharding status can be + // propagated. + input_dims.emplace_back(std::make_shared(in_dim)); } } std::shared_ptr flatten = make_flatten(input_dims); diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py index 80ec33aecfcdb..e70761e705cb0 100644 --- a/test/auto_parallel/spmd_rules/test_reshape_rule.py +++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py @@ -291,6 +291,22 @@ def test_reshape_infer_forward(self): infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1] ) + # shape: [1, 2048, 12288] --> [0, 0, 6, 2048] + # dims_mapping: [0, -1, 1] --> [0, -1, 1], [0, -1, 1, -1] + self.x_dist_tensor_spec.shape = [1, 2048, 12288] + self.attrs["shape"] = [0, 0, 6, 2048] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, self.attrs['shape'] + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1] + ) + # shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1] # raise error self.attrs["shape"] = [3, 24, 6, -1, -1]