Skip to content

Commit

Permalink
add a case when the target shape and source shape are both 1 in resha…
Browse files Browse the repository at this point in the history
…pe spmd rule (#63681)
  • Loading branch information
pkuzyc authored Apr 19, 2024
1 parent e880d10 commit d725a5e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ std::vector<std::shared_ptr<DimTrans>> MakeReshapeDimTrans(
for (auto in_dim : src_dims) {
if (src_shape[in_dim] > 1) {
input_dims.emplace_back(std::make_shared<InputDim>(in_dim));
} else if (src_shape[in_dim] == 1 && s == 1 && t == 1) {
// NOTE: for the case like:
// shape: [1, 512, 4096] --> [1, 2, 256, 4096],
// input dims_mapping: [0, 1, -1]
// expected output dims_mapping: [0, 1, -1, -1] (not [-1, 1, -1,
// -1])
// In this case, the dim0 in target shape is 1 and it is from
// dim0 in source shape. make the dim0's transformation be InputDim
// rather than Singleton so that the sharding status can be
// propagated.
input_dims.emplace_back(std::make_shared<InputDim>(in_dim));
}
}
std::shared_ptr<DimTrans> flatten = make_flatten(input_dims);
Expand Down
16 changes: 16 additions & 0 deletions test/auto_parallel/spmd_rules/test_reshape_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,22 @@ def test_reshape_infer_forward(self):
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)

# shape: [1, 2048, 12288] --> [0, 0, 6, 2048]
# dims_mapping: [0, -1, 1] --> [0, -1, 1], [0, -1, 1, -1]
self.x_dist_tensor_spec.shape = [1, 2048, 12288]
self.attrs["shape"] = [0, 0, 6, 2048]
self.x_dist_tensor_spec.set_dims_mapping([0, -1, 1])
result_dist_attrs = self.rule.infer_forward(
self.x_dist_tensor_spec, self.attrs['shape']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]

self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
)

# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
# raise error
self.attrs["shape"] = [3, 24, 6, -1, -1]
Expand Down

0 comments on commit d725a5e

Please sign in to comment.