diff --git a/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.cc b/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.cc
new file mode 100644
index 00000000000000..c2afdd86b57a76
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/infermeta/spmd_rules/matmul.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi {
+namespace distributed {
+
+SpmdInfo FusedLinearParamGradAddInferSpmd(const DistMetaTensor& x,
+                                          const DistMetaTensor& dout,
+                                          const DistMetaTensor& dweight,
+                                          const DistMetaTensor& dbias,
+                                          bool multi_precision,
+                                          bool has_bias) {
+  auto dy_spmd_info =
+      MatmulInferSpmd(x, dout, /*trans_x=*/true, /*trans_y=*/false);
+  auto& x_dist_attr = PADDLE_GET_CONST(TensorDistAttr, dy_spmd_info.first[0]);
+  auto& dout_dist_attr =
+      PADDLE_GET_CONST(TensorDistAttr, dy_spmd_info.first[1]);
+  auto weight_grad_dist_attr =
+      PADDLE_GET_CONST(TensorDistAttr, dy_spmd_info.second[0]);
+
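+  // The grad inferred from matmul(x^T, dout) keeps the leading batch dims of
+  // x/dout, while the real dweight is 2-D; any sharding of those broadcast
+  // dims has to be folded into partial status over the same mesh axes.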
+  weight_grad_dist_attr = ReduceGradBroadCastDims(2, weight_grad_dist_attr);
+
+  TensorDistAttr dweight_dist_attr = dweight.dist_attr();
+  auto dweight_shape = common::vectorize(dweight.dims());
+  TensorDistAttr dbias_dist_attr = dbias.dist_attr();
+  auto dbias_shape = common::vectorize(dbias.dims());
+
+  TensorDistAttr bias_grad_dist_attr;
+  if (has_bias) {
+    bias_grad_dist_attr = ReduceGradBroadCastDims(1, dout.dist_attr());
+  }
+
+  // check that dweight matches the inferred weight_grad
+  if (!IsEmpty(dweight_shape)) {
+    PADDLE_ENFORCE_EQ(dweight_dist_attr,
+                      weight_grad_dist_attr,
+                      phi::errors::InvalidArgument(
+                          "dweight_dist_attr [%s] and weight_grad_dist_attr "
+                          "[%s] should be equal",
+                          dweight_dist_attr.to_string(),
+                          weight_grad_dist_attr.to_string()));
+  }
+  // check that dbias matches the inferred bias_grad
+  if (!IsEmpty(dbias_shape)) {
+    PADDLE_ENFORCE_EQ(
+        dbias_dist_attr,
+        bias_grad_dist_attr,
+        phi::errors::InvalidArgument(
+            "dbias_dist_attr [%s] and bias_grad_dist_attr [%s] should be equal",
+            dbias_dist_attr.to_string(),
+            bias_grad_dist_attr.to_string()));
+  }
+
+  return {{x_dist_attr, dout_dist_attr, dweight_dist_attr, dbias_dist_attr},
+          {weight_grad_dist_attr, bias_grad_dist_attr}};
+}
+
+SpmdInfo FusedLinearParamGradAddInferSpmdFakeReverse() { return SpmdInfo(); }
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h b/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h
new file mode 100644
index 00000000000000..794202598a2c30
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+SpmdInfo FusedLinearParamGradAddInferSpmd(const DistMetaTensor& x,
+                                          const DistMetaTensor& dout,
+                                          const DistMetaTensor& dweight,
+                                          const DistMetaTensor& dbias,
+                                          bool multi_precision,
+                                          bool has_bias);
+
+SpmdInfo FusedLinearParamGradAddInferSpmdFakeReverse();
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc
index 5cd895401dc96e..b9b15707737ea3 100644
--- a/paddle/phi/infermeta/spmd_rules/matmul.cc
+++ b/paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -286,43 +286,6 @@ static bool DistAttrsAreBasicallyEqual(
           in_dist_attr.partial_status() == out_dist_attr.partial_status());
 }
 
-TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
-                                       const ArgDistAttr& grad) {
-  auto& grad_in = PADDLE_GET_CONST(TensorDistAttr, grad);
-  auto grad_dim = grad_in.dims_mapping().size();
-  auto input_dim = input.dims_mapping().size();
-  PADDLE_ENFORCE_GE(
-      grad_dim,
-      input_dim,
-      phi::errors::InvalidArgument("grad dim must ge than input dim, but we "
-                                   "got grad_dim [%d], input_dim[%d]",
-                                   grad_dim,
-                                   input_dim));
-  if (grad_dim == input_dim) {
-    return grad_in;
-  }
-  size_t broadcast_dim = grad_dim - input_dim;
-  // gather partial status
-  auto partial_dims = grad_in.partial_dims();
-  auto& grad_dims_mapping = grad_in.dims_mapping();
-  auto dims_mapping = input.dims_mapping();
-  for (size_t i = 0; i < grad_dim; ++i) {
-    auto mapping = grad_dims_mapping[i];
-    if (i < broadcast_dim) {
-      if (mapping >= 0) {
-        partial_dims.insert(mapping);
-      }
-    } else {
-      dims_mapping[i - broadcast_dim] = mapping;
-    }
-  }
-  auto grad_out = CopyTensorDistAttrForOutput(input);
-  grad_out.set_dims_mapping(dims_mapping);
-  std::vector<int64_t> partial_status(partial_dims.begin(), partial_dims.end());
-  grad_out.set_partial_status(partial_status);
-  return grad_out;
-}
-
 SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x,
                              const DistMetaTensor& y,
                              const DistMetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h
index a75132ab30e0e7..1015f61802bc4a 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.h
+++ b/paddle/phi/infermeta/spmd_rules/rules.h
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/flash_attention.h"
 #include "paddle/phi/infermeta/spmd_rules/flatten.h"
 #include "paddle/phi/infermeta/spmd_rules/full_like.h"
+#include "paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h"
 #include "paddle/phi/infermeta/spmd_rules/fused_rope.h"
 #include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
@@ -618,10 +619,18 @@ PD_REGISTER_SPMD_RULE(
     cross_entropy_with_softmax,
     PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd),
     PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse));
+
 PD_REGISTER_SPMD_RULE(
     softmax_with_cross_entropy,
     PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd),
     PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse));
 
+// fused_linear_param_grad_add has no reverse infer spmd rule
+PD_REGISTER_SPMD_RULE(
+    fused_linear_param_grad_add,
+    PD_INFER_SPMD(phi::distributed::FusedLinearParamGradAddInferSpmd),
+    PD_INFER_SPMD(
+        phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse));
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc
index e3d079dc4c88c5..18096ea7e12c6e 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.cc
+++ b/paddle/phi/infermeta/spmd_rules/utils.cc
@@ -554,5 +554,55 @@ void DebugInfoForInferSpmd(const std::string& rule_name,
   }
 }
 
+TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
+                                       const ArgDistAttr& grad) {
+  const auto& grad_in = PADDLE_GET_CONST(TensorDistAttr, grad);
+  return ReduceGradBroadCastDims(input, grad_in);
+}
+
+TensorDistAttr ReduceGradBroadCastDims(int64_t input_dims,
+                                       const TensorDistAttr& grad) {
+  TensorDistAttr input;
+  std::vector<int64_t> dim_mapping(input_dims, -1);
+  input.set_dims_mapping(dim_mapping);
+  return ReduceGradBroadCastDims(input, grad);
+}
+
+TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
+                                       const TensorDistAttr& grad) {
+  auto grad_dim = grad.dims_mapping().size();
+  auto input_dim = input.dims_mapping().size();
+  PADDLE_ENFORCE_GE(
+      grad_dim,
+      input_dim,
+      phi::errors::InvalidArgument(
+          "grad_dim must be greater than or equal to input_dim, but got "
+          "grad_dim [%d] and input_dim [%d]",
+          grad_dim,
+          input_dim));
+  if (grad_dim == input_dim) {
+    return grad;
+  }
+  size_t broadcast_dim = grad_dim - input_dim;
+  // gather partial status: a sharded broadcast dim becomes a pending
+  // reduction over that mesh axis
+  auto partial_dims = grad.partial_dims();
+  auto& grad_dims_mapping = grad.dims_mapping();
+  auto dims_mapping = input.dims_mapping();
+  for (size_t i = 0; i < grad_dim; ++i) {
+    auto mapping = grad_dims_mapping[i];
+    if (i < broadcast_dim) {
+      if (mapping >= 0) {
+        partial_dims.insert(mapping);
+      }
+    } else {
+      dims_mapping[i - broadcast_dim] = mapping;
+    }
+  }
+  auto grad_out = CopyTensorDistAttrForOutput(input);
+  grad_out.set_dims_mapping(dims_mapping);
+  std::vector<int64_t> partial_status(partial_dims.begin(),
+                                      partial_dims.end());
+  grad_out.set_partial_status(partial_status);
+  return grad_out;
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h
index c58e80ba02608f..a59e582f151ae3 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.h
+++ b/paddle/phi/infermeta/spmd_rules/utils.h
@@ -205,5 +205,14 @@ std::vector<int64_t> GetDimsMappingForAxes(
 void DebugInfoForInferSpmd(const std::string& rule_name,
                            const SpmdInfo& infer_result);
 
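+// ReduceGradBroadCastDims collapses the leading broadcast dims of `grad` onto
+// `input`'s rank: trailing dims_mapping entries are kept, and every mesh axis
+// that sharded a broadcast dim becomes partial status (the grad still needs a
+// sum over that axis). Illustrative example (values not from this patch):
+// reducing a grad with dims_mapping [0, -1, 1] to a 2-D input yields
+// dims_mapping [-1, 1] with partial status {0}.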
+TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
+                                       const ArgDistAttr& grad);
+
+TensorDistAttr ReduceGradBroadCastDims(const TensorDistAttr& input,
+                                       const TensorDistAttr& grad);
+
+TensorDistAttr ReduceGradBroadCastDims(int64_t input_dims,
+                                       const TensorDistAttr& grad);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_for_fused_linear_param_grad_add.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_for_fused_linear_param_grad_add.py
new file mode 100644
index 00000000000000..7edb7dc1335268
--- /dev/null
+++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_for_fused_linear_param_grad_add.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle import _C_ops
+
+
+class TestFusedParamGradAddForSemiAutoParallel:
+    def __init__(self):
+        self._dtype = os.getenv("dtype")
+        self._backend = os.getenv("backend")
+        self._seed = eval(os.getenv("seed"))
+        self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
+
+    def check_tensor_eq(self, a, b):
+        np1 = a.numpy()
+        np2 = b.numpy()
+        np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)
+
+    def test_body(self):
+        x_shape = [4, 16, 32]
+        y_shape = [4, 16, 64]
+
+        paddle.seed(self._seed)
+        np.random.seed(self._seed)
+
+        x_np = np.random.random(size=x_shape).astype(self._dtype)
+        y_np = np.random.random(size=y_shape).astype(self._dtype)
+
+        def run_acc_step(x, y):
+            weight_grad = None
+            bias_grad = None
+            for _ in range(2):
+                weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(
+                    x,
+                    y,
+                    weight_grad,
+                    bias_grad,
+                    False,  # multi_precision
+                    True,  # has_bias
+                )
+            return weight_grad, bias_grad
+
+        x = paddle.to_tensor(x_np)
+        y = paddle.to_tensor(y_np)
+        x.stop_gradient = True
+        y.stop_gradient = True
+
+        weight_grad, bias_grad = run_acc_step(x, y)
+
+        # test mp col split
+        x_placements = [dist.Shard(0), dist.Replicate()]
+        y_placements = [dist.Shard(0), dist.Shard(2)]
+
+        dist_x = dist.shard_tensor(x_np, self._mesh, x_placements)
+        dist_y = dist.shard_tensor(y_np, self._mesh, y_placements)
+        dist_x.stop_gradient = True
+        dist_y.stop_gradient = True
+
+        weight_grad_dist, bias_grad_dist = run_acc_step(dist_x, dist_y)
+        self.check_tensor_eq(weight_grad, weight_grad_dist)
+        self.check_tensor_eq(bias_grad, bias_grad_dist)
+
+        # test mp row split
+        x_placements = [dist.Shard(0), dist.Shard(2)]
+        y_placements = [dist.Shard(0), dist.Replicate()]
+        dist_x = dist.shard_tensor(x_np, self._mesh, x_placements)
+        dist_y = dist.shard_tensor(y_np, self._mesh, y_placements)
+        dist_x.stop_gradient = True
+        dist_y.stop_gradient = True
+        weight_grad_dist, bias_grad_dist = run_acc_step(dist_x, dist_y)
+        self.check_tensor_eq(weight_grad, weight_grad_dist)
+        self.check_tensor_eq(bias_grad, bias_grad_dist)
+
+    def test_fused_linear_param_grad_add(self):
+        self.test_body()
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+        elif self._backend == "gpu":
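+            # pin each rank to its own GPU; ranks 0-3 back the 2x2 mesh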
+            paddle.set_device("gpu:" + str(dist.get_rank()))
+        else:
+            raise ValueError("Only cpu and gpu backends are supported.")
+
+        self.test_fused_linear_param_grad_add()
+
+
+if __name__ == '__main__':
+    TestFusedParamGradAddForSemiAutoParallel().run_test_case()
diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
index 8066573aff078b..3fb0885b671af3 100644
--- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
+++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
@@ -42,6 +42,16 @@ def test_simple_net_hybrid_strategy(self):
             )
         ckpt_path.cleanup()
 
+    def test_fused_linear_param_grad_add(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_for_fused_linear_param_grad_add.py",
+                user_defined_envs=envs,
+            )
+
 
 class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase):
     def setUp(self):
diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt
index 8cfba30807d205..d8c99d33a189f9 100644
--- a/test/auto_parallel/spmd_rules/CMakeLists.txt
+++ b/test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -27,6 +27,8 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_triu_rule MODULES test_triu_rule)
   py_test_modules(test_flash_attention_rule MODULES test_flash_attention_rule)
   py_test_modules(test_tile_rule MODULES test_tile_rule)
+  py_test_modules(test_fused_linear_param_grad_add_rule MODULES
+                  test_fused_linear_param_grad_add_rule)
   # End of unittests WITH single card WITHOUT timeout
 endif()
 
diff --git a/test/auto_parallel/spmd_rules/test_fused_linear_param_grad_add_rule.py b/test/auto_parallel/spmd_rules/test_fused_linear_param_grad_add_rule.py
new file mode 100644
index 00000000000000..fc767b0d15c90f
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_fused_linear_param_grad_add_rule.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+from paddle.framework import core
+
+
+class TestFusedLinearParamGradAddSPMDRule(unittest.TestCase):
+    """
+    Unit tests for the fused_linear_param_grad_add spmd rule.
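+    Covers column-parallel and row-parallel weight shardings on a 2-D mesh.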
+    """
+
+    def setUp(self):
+        self.process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])
+
+    def build_inputs(self, dims_mapping, shape):
+        tensor_dist_attr = TensorDistAttr()
+        tensor_dist_attr.dims_mapping = dims_mapping
+        tensor_dist_attr.process_mesh = self.process_mesh
+        return DistTensorSpec(shape, tensor_dist_attr)
+
+    def test_infer_forward(self):
+        rule = core.get_phi_spmd_rule("fused_linear_param_grad_add")
+
+        # test mp split by col
+        input = self.build_inputs([0, -1, -1], [2, 512, 1024])
+        out_grad = self.build_inputs([0, -1, 1], [2, 512, 2048])
+        dweight = self.build_inputs([], [])
+        dbias = self.build_inputs([], [])
+        inferred_dist_attrs = rule.infer_forward(
+            input, out_grad, dweight, dbias, False, True
+        )
+        self.assertEqual(inferred_dist_attrs[1][0].dims_mapping, [-1, 1])
+        self.assertEqual(inferred_dist_attrs[1][1].dims_mapping, [1])
+
+        # test mp split by row
+        input = self.build_inputs([0, -1, 1], [2, 512, 1024])
+        out_grad = self.build_inputs([0, -1, -1], [2, 512, 2048])
+        dweight = self.build_inputs([], [])
+        dbias = self.build_inputs([], [])
+        inferred_dist_attrs = rule.infer_forward(
+            input, out_grad, dweight, dbias, False, True
+        )
+        self.assertEqual(inferred_dist_attrs[1][0].dims_mapping, [1, -1])
+        self.assertEqual(inferred_dist_attrs[1][1].dims_mapping, [-1])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt
index c5bdc7cb5e8506..d5bf3f1cf5d3a7 100644
--- a/test/cpp/auto_parallel/CMakeLists.txt
+++ b/test/cpp/auto_parallel/CMakeLists.txt
@@ -41,6 +41,14 @@ if(WITH_DISTRIBUTE)
     spmd_rules
     phi)
 
+  paddle_test(
+    fused_linear_param_grad_add_spmd_rule_test
+    SRCS
+    fused_linear_param_grad_add_spmd_rule_test.cc
+    DEPS
+    spmd_rule_test_util
+    spmd_rules
+    phi)
 endif()
 
 cc_test(
diff --git a/test/cpp/auto_parallel/fused_linear_param_grad_add_spmd_rule_test.cc b/test/cpp/auto_parallel/fused_linear_param_grad_add_spmd_rule_test.cc
new file mode 100644
index 00000000000000..109d183940dfcd
--- /dev/null
+++ b/test/cpp/auto_parallel/fused_linear_param_grad_add_spmd_rule_test.cc
@@ -0,0 +1,103 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "test/cpp/auto_parallel/spmd_rule_test_util.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+TEST(FusedLinearParamGradAddSPMDRule, Ctor) {
+  // build input data class
+  std::vector<int64_t> mesh_shape = {2, 3};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  // shapes: [batch, seq, hidden]
+  std::vector<int64_t> x_shape = {2, 512, 1024};
+  std::vector<int64_t> out_shape = {2, 512, 2048};
+  std::vector<int64_t> weight_shape = {1024, 2048};
+  std::vector<int64_t> bias_shape = {2048};
+
+  // test mp col split
+  {
+    TensorDistAttr x_dist_attr = TensorDistAttr();
+    x_dist_attr.set_process_mesh(process_mesh);
+    x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1, -1}));
+    x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
+
+    TensorDistAttr out_dist_attr = TensorDistAttr();
+    out_dist_attr.set_process_mesh(process_mesh);
+    out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1, 1}));
+    out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
+
+    phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr);
+    phi::distributed::DistMetaTensor out(phi::make_ddim(out_shape),
+                                         out_dist_attr);
+    phi::distributed::DistMetaTensor dweight;
+    phi::distributed::DistMetaTensor dbias;
+    // run three steps to mimic gradient accumulation: the dist attrs inferred
+    // in one step are fed back as dweight/dbias in the next
+    for (int i = 0; i < 3; i++) {
+      auto spmd_info =
+          FusedLinearParamGradAddInferSpmd(x, out, dweight, dbias, false, true);
+      check_dim_mapping(spmd_info.second[0], {-1, 1});
+      check_partial_dims(spmd_info.second[0], {0});
+      check_dim_mapping(spmd_info.second[1], {1});
+      check_partial_dims(spmd_info.second[1], {0});
+      dweight = phi::distributed::DistMetaTensor(
+          phi::make_ddim(weight_shape),
+          PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[0]));
+      dbias = phi::distributed::DistMetaTensor(
+          phi::make_ddim(bias_shape),
+          PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[1]));
+    }
+  }
+
+  // test mp row split
+  {
+    TensorDistAttr x_dist_attr = TensorDistAttr();
+    x_dist_attr.set_process_mesh(process_mesh);
+    x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1, 1}));
+    x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
+
+    TensorDistAttr out_dist_attr = TensorDistAttr();
+    out_dist_attr.set_process_mesh(process_mesh);
+    out_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1, -1}));
+    out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false, false}));
+
+    phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr);
+    phi::distributed::DistMetaTensor out(phi::make_ddim(out_shape),
+                                         out_dist_attr);
+    phi::distributed::DistMetaTensor dweight;
+    phi::distributed::DistMetaTensor dbias;
+    for (int i = 0; i < 3; i++) {
+      auto spmd_info =
+          FusedLinearParamGradAddInferSpmd(x, out, dweight, dbias, false, true);
+      check_dim_mapping(spmd_info.second[0], {1, -1});
+      check_partial_dims(spmd_info.second[0], {0});
+      check_dim_mapping(spmd_info.second[1], {-1});
+      check_partial_dims(spmd_info.second[1], {0});
+      dweight = phi::distributed::DistMetaTensor(
+          phi::make_ddim(weight_shape),
+          PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[0]));
+      dbias = phi::distributed::DistMetaTensor(
+          phi::make_ddim(bias_shape),
+          PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[1]));
+    }
+  }
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
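Note: the dims-mapping reduction that both tests exercise can be sanity-checked outside Paddle. Below is a minimal standalone sketch of what ReduceGradBroadCastDims computes for the column-split case above; Attr and ReduceBroadcast are illustrative stand-ins, not the Paddle API.

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

struct Attr {
  std::vector<int64_t> dims_mapping;
  std::set<int64_t> partial_dims;
};

// Keep the trailing dims of `grad` aligned to `input_rank` and fold any
// sharded leading (broadcast) dim into partial status.
Attr ReduceBroadcast(size_t input_rank, const Attr& grad) {
  Attr out;
  out.partial_dims = grad.partial_dims;
  size_t broadcast = grad.dims_mapping.size() - input_rank;
  out.dims_mapping.assign(grad.dims_mapping.begin() + broadcast,
                          grad.dims_mapping.end());
  for (size_t i = 0; i < broadcast; ++i) {
    if (grad.dims_mapping[i] >= 0) {
      out.partial_dims.insert(grad.dims_mapping[i]);
    }
  }
  return out;
}

int main() {
  // dweight inferred from matmul(x^T, dout): dims_mapping [0, -1, 1], as in
  // the column-split test; the batch dim sharded on mesh axis 0 becomes a
  // partial sum when reduced to the 2-D weight.
  Attr grad{{0, -1, 1}, {}};
  Attr w = ReduceBroadcast(2, grad);
  std::cout << "dims_mapping =";
  for (int64_t m : w.dims_mapping) std::cout << ' ' << m;  // -1 1
  std::cout << ", partial =";
  for (int64_t p : w.partial_dims) std::cout << ' ' << p;  // 0
  std::cout << '\n';
}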