diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
index 97532183f87a3..469db025ce650 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
@@ -20,6 +20,11 @@
 #include "paddle/pir/core/builtin_type_interfaces.h"
 #include "paddle/pir/dialect/shape/ir/shape_attribute.h"
 
+// to make the code shorter
+using ShapeOrData = symbol::ShapeOrDataDimExprs;
+using TensorExprs = symbol::TensorShapeOrDataDimExprs;
+using TensorListExprs = symbol::TensorListShapeOrDataDimExprs;
+
 template <typename T>
 struct AttributeTrait;
 
@@ -78,9 +83,6 @@ bool SameOperandsAndResultShape(
   symbol::ShapeOrDataDimExprs operand_shape_or_data =
       shape_analysis->GetShapeOrDataForValue(operand_source);
 
-  op->set_attribute("symbolic_shape",
-                    pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
-                                                     operand_shape_or_data));
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
   return true;
@@ -143,9 +145,7 @@ bool InferSymbolicShapeElementWiseBinary(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(shapes)};
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
+
   return true;
 }
 
@@ -184,9 +184,6 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(sym_dims)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -263,9 +260,7 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
           sym_shape, operand_shape_or_data.shape())};
 
   shape_analysis->SetShapeOrDataForValue(res, shape_or_data);
-  op->set_attribute("symbolic_shape",
-                    pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
-                                                     shape_or_data));
+
   return true;
 }
 
@@ -305,9 +300,6 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
   symbol::ShapeOrDataDimExprs shape_data(
       symbol::TensorShapeOrDataDimExprs(out_dims, out_dims_data));
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
@@ -368,9 +360,7 @@ bool ReduceInferDim(pir::Operation *op,
   pir::Value res = op->result(0);
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(shapes)};
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
+
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
 }
@@ -441,9 +431,6 @@ bool ReshapeOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(out_dims)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
   pir::Value res0 = op->result(0);
   pir::Value res1 = op->result(1);
 
@@ -476,10 +463,6 @@ bool FullIntArrayOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(shape, data)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
@@ -537,10 +520,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(sym_shape, out_data)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
 }
@@ -580,10 +559,6 @@ bool FullOpInferSymbolicShape(pir::Operation *op,
       symbol::TensorShapeOrDataDimExprs(sym_shape)};
   shape_data.SetData(sym_data);
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
@@ -629,10 +604,6 @@ bool ConcatOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(out_dims)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -685,10 +656,6 @@ bool GatherNdOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(result_sym_dims)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -702,7 +669,7 @@ bool PowOpInferSymbolicShape(pir::Operation *op,
 bool Pow_OpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return PowOpInferSymbolicShape(op, shape_analysis);
 }
 
@@ -812,10 +779,6 @@ bool SqueezeOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(output_shape_sym)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -891,10 +854,6 @@ bool UnsqueezeOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(result_sym_dims)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -949,10 +908,6 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(out_shape)};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   pir::Value res = op->result(0);
   shape_analysis->SetShapeOrDataForValue(res, shape_data);
 
@@ -962,7 +917,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
 bool TransposeOpInferSymbolicShape(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return true;
 }
 bool Transpose_OpInferSymbolicShape(
@@ -1095,35 +1050,132 @@ bool EmbeddingOpInferSymbolicShape(
 bool SparseWeightEmbeddingOpInferSymbolicShape(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return true;
 }
 
 bool ExpandOpInferSymbolicShape(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return true;
 }
 
 bool MatmulOpInferSymbolicShape(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+  // x_dims can't be const or a reference here, since it may be broadcast below
+  std::vector<symbol::DimExpr> x_dims = [&] {
+    std::vector<symbol::DimExpr> dims;
+    const auto &x_shape_or_data =
+        shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
+    if (x_shape_or_data.data().has_value()) {
+      dims = x_shape_or_data.data().value();
+    } else {
+      dims = x_shape_or_data.shape();
+    }
+    return dims;
+  }();
+
+  // y_dims can't be const or a reference here, since it may be broadcast below
+  std::vector<symbol::DimExpr> y_dims = [&] {
+    std::vector<symbol::DimExpr> dims;
+    const auto y_shape_or_data =
+        shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
+    if (y_shape_or_data.data().has_value()) {
+      dims = y_shape_or_data.data().value();
+    } else {
+      dims = y_shape_or_data.shape();
+    }
+    return dims;
+  }();
+
+  size_t ndims_x = x_dims.size();
+  size_t ndims_y = y_dims.size();
+
+  const bool x_broadcasted = [&] {
+    bool broadcasted = false;
+    if (ndims_x == 1) {
+      x_dims.insert(x_dims.begin(), 1);
+      ndims_x = 2;
+      broadcasted = true;
+    }
+    return broadcasted;
+  }();
+
+  const bool y_broadcasted = [&] {
+    bool broadcasted = false;
+    if (ndims_y == 1) {
+      y_dims.emplace_back(1);
+      ndims_y = 2;
+      broadcasted = true;
+    }
+    return broadcasted;
+  }();
+
+  std::vector<symbol::DimExpr> out_dims;
+  if (ndims_x > ndims_y) {
+    out_dims.assign(x_dims.begin(), x_dims.end() - 2);
+  } else if (ndims_x < ndims_y) {
+    out_dims.assign(y_dims.begin(), y_dims.end() - 2);
+  } else {
+    symbol::DimExprBuilder builder{nullptr};
+    for (size_t i = 0; i < ndims_x - 2; ++i) {
+      out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i]));
+    }
+  }
+
+  symbol::DimExpr out_M =
+      op->attributes().at("transpose_x").dyn_cast<pir::BoolAttribute>().data()
+          ? x_dims[ndims_x - 1]
+          : x_dims[ndims_x - 2];
+  symbol::DimExpr out_N =
+      op->attributes().at("transpose_y").dyn_cast<pir::BoolAttribute>().data()
+          ? y_dims[ndims_y - 2]
+          : y_dims[ndims_y - 1];
+  if (!x_broadcasted) {
+    out_dims.emplace_back(out_M);
+  }
+  if (!y_broadcasted) {
+    out_dims.emplace_back(out_N);
+  }
+
+  shape_analysis->SetShapeOrDataForValue(op->result(0),
+                                         ShapeOrData{TensorExprs(out_dims)});
+
   return true;
 }
 
 bool MaxOpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
-  return true;
+  bool keepdim =
+      op->attributes().at("keepdim").dyn_cast<pir::BoolAttribute>().data();
+
+  const std::vector<int64_t> axis = [&] {
+    pir::Operation *axis_gen_op = op->operand_source(1).defining_op();
+    std::vector<int64_t> axis_vec;
+    if (axis_gen_op->isa<paddle::dialect::FullIntArrayOp>()) {
+      axis_vec = GetVectorAttr(
+          axis_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>(), "value");
+    } else {
+      // TODO(lanxianghit): there are other possible sources (pir::VectorType,
+      // paddle::dialect::DenseTensorType), but after PRIM it may always be
+      // FullIntArrayOp; to be confirmed
+      PADDLE_THROW(
+          phi::errors::Unimplemented("MaxOpInferSymbolicShape: 'axis' only "
+                                     "supports FullIntArrayOp's result now."));
+    }
+    return axis_vec;
+  }();
+
+  bool reduce_all = axis.size() == 0 ? true : false;
+
+  return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all);
 }
 
 bool TrilOpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return true;
 }
 
@@ -1135,7 +1187,7 @@ bool Tril_OpInferSymbolicShape(pir::Operation *op,
 bool WhereOpInferSymbolicShape(pir::Operation *op,
                                pir::ShapeConstraintIRAnalysis *shape_analysis) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
+      op->name() + "'s InferSymbolicShape interface is NOT implemented yet."));
   return true;
 }
 
@@ -1189,10 +1241,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
   };
 
   symbol::ShapeOrDataDimExprs shape_data{GetOutDimExprs()};
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
   return true;
 }
@@ -1239,10 +1287,6 @@ bool ConcatOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())};
 
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
-
   shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
   return true;
 }
@@ -1292,9 +1336,7 @@ bool ReshapeOpInferSymbolicShape(
   symbol::ShapeOrDataDimExprs shape_data{
       symbol::TensorShapeOrDataDimExprs(out_dims)};
   shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
-  op->set_attribute(
-      "symbolic_shape",
-      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
+
   return true;
 }
 
diff --git a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py
index d09ba04ff6576..e20f64b5ee508 100644
--- a/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py
+++ b/test/ir/pir/cinn/symbolic/test_op_infer_sym_shape.py
@@ -202,5 +202,151 @@ def test_eval_symbolic(self):
         return out
 
 
+class MatmulNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, trans_x, trans_y):
+        out = paddle.matmul(x, y, trans_x, trans_y)
+
+        return out
+
+
+class TestMatmulOpInferSymbolicShape(TestBase):
+    def prepare_data(self):
+        self.cases = [
+            # [x, y, trans_x, trans_y]
+            [np.random.rand(1, 3), np.random.rand(3, 2), False, False],
+            # with broadcast
+            [np.random.rand(10), np.random.rand(10), False, False],  # []
+            [np.random.rand(10, 5), np.random.rand(5), False, False],  # [10]
+            [
+                np.random.rand(10, 5, 2),
+                np.random.rand(2),
+                False,
+                False,
+            ],  # [10, 5]
+            [
+                np.random.rand(10, 5, 2),
+                np.random.rand(10, 2, 5),
+                False,
+                False,
+            ],  # [10, 5, 5]
+            [
+                np.random.rand(10, 1, 5, 2),
+                np.random.rand(1, 3, 2, 5),
+                False,
+                False,
+            ],  # [10, 3, 5, 5]
+            # with transpose
+            [np.random.rand(3, 5), np.random.rand(3, 2), True, False],  # [5, 2]
+            [np.random.rand(3, 5), np.random.rand(4, 5), False, True],  # [3, 4]
+        ]
+
+        self.expected = [
+            'shape[S0, S3], data[NULL]',
+            # with broadcast
+            'shape[], data[NULL]',
+            'shape[S0], data[NULL]',
+            'shape[S0, S1], data[NULL]',
+            'shape[Broadcast(S0, S3), S1, S5], data[NULL]',
+            'shape[Broadcast(S0, S4), Broadcast(S1, S5), S2, S7], data[NULL]',
+            # with transpose
+            'shape[S1, S3], data[NULL]',
+            'shape[S0, S2], data[NULL]',
+        ]
+
+    def test_eval_symbolic(self):
+        net = MatmulNet()
+
+        for i in range(len(self.cases)):
+            x, y, trans_x, trans_y = self.cases[i]
+            x_spec = InputSpec(
+                shape=[None for index in range(len(x.shape))], dtype='float32'
+            )
+            y_spec = InputSpec(
+                shape=[None for index in range(len(y.shape))], dtype='float32'
+            )
+
+            input_spec = [x_spec, y_spec, trans_x, trans_y]
+            net = apply_to_static(net, False, input_spec)
+            net.eval()
+
+            # check the infer result
+            sym_shape_str_list = get_sym_shape_str_for_op(
+                net, input_spec, 'pd_op.matmul'
+            )
+            np.testing.assert_equal(len(sym_shape_str_list), 1)
+            np.testing.assert_equal(
+                sym_shape_str_list[0].find(self.expected[i]),
+                0,
+                f'in case i = {i}: output shape ({sym_shape_str_list[0]}) is not expected {(self.expected[i])}',
+            )
+
+        return True
+
+
+class MaxNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        out = paddle.max(x)
+        out = paddle.max(x, 0)
+        out = paddle.max(x, 1)
+        out = paddle.max(x, -1)
+        out = paddle.max(x, -2)
+
+        # keepdim=True
+        # out = paddle.max(x, 0, True)
+
+        return out
+
+
+class TestMaxOpInferSymbolicShape(TestBase):
+    def prepare_data(self):
+        self.cases = [np.random.rand(2, 4)]
+
+        self.expected = [
+            [
+                'shape[], data[NULL]',
+                'shape[S1], data[NULL]',
+                'shape[S0], data[NULL]',
+                'shape[S0], data[NULL]',
+                'shape[S1], data[NULL]',
+                # 'shape[1, S1], data[NULL]',
+            ]
+        ]
+
+    def test_eval_symbolic(self):
+        net = MaxNet()
+
+        for i in range(len(self.cases)):
+            x = self.cases[i]
+            x_spec = InputSpec(
+                shape=[None for index in range(len(x.shape))], dtype='float32'
+            )
+
+            input_spec = [x_spec]
+            net = apply_to_static(net, False, input_spec)
+            net.eval()
+
+            # check the infer result
+            sym_shape_str_list = get_sym_shape_str_for_op(
+                net, input_spec, 'pd_op.max'
+            )
+            np.testing.assert_equal(
+                len(sym_shape_str_list), len(self.expected[i])
+            )
+            for j in range(len(sym_shape_str_list)):
+                np.testing.assert_equal(
+                    sym_shape_str_list[j].find(self.expected[i][j]),
+                    0,
+                    f'in case i,j = {i},{j}: output shape ({sym_shape_str_list[j]}) is not expected {(self.expected[i][j])}',
+                )
+
+        return True
+
+
 if __name__ == '__main__':
     unittest.main()