From 70f2b54849bf0c861e70beb283188ad1aeef1b16 Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Fri, 19 Apr 2024 16:55:46 +0800 Subject: [PATCH] Add InferSymbolicShape for pd_op.nonzero (#62987) * add pd_op.nonzero * update * update * update * update * update * update --- .../infer_symbolic_shape/unary_infer_sym.cc | 23 ++++++++++ .../infer_symbolic_shape/unary_infer_sym.h | 1 + paddle/phi/api/yaml/ops.yaml | 1 + .../symbolic/test_infer_sym_shape_unary_op.py | 42 +++++++++++++++++++ 4 files changed, 67 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 4dab7e358f05e..b69727cb9d4f8 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -328,6 +328,29 @@ bool MinOpInferSymbolicShape(pir::Operation *op, return MaxOpInferSymbolicShape(op, shape_analysis); } +bool NonzeroOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); + int rank = x_shape.size(); + + PADDLE_ENFORCE_GE( + rank, + 1UL, + phi::errors::InvalidArgument( + "Input(x) should have number of dimension at least 1.")); + + std::string sym_name = shape_analysis->GetNextSymName(); + std::vector out_shape{symbol::DimExpr{sym_name}, + symbol::DimExpr{rank}}; + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_shape)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + bool PadOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { // input(0): Tensor x diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 2b7cd2c3cf4f9..e52b9aabc1568 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -35,6 +35,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 8a1aa0e36e6e1..84194d1eeb8e6 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2122,6 +2122,7 @@ kernel : func : nonzero data_type: condition + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : npu_identity args : (Tensor x, int format = -1) diff --git a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py index 954f195f52f47..7a3507d44bc20 100644 --- a/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py +++ b/test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py @@ -443,6 +443,48 @@ def test_eval_symbolic(self): return True +class NonzeroNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + out_nonzero = paddle.nonzero(x) + return out_nonzero + + +class NonzeroOpInferSymbolicShapeTest(TestBase): + def prepare_data(self): + self.cases = [np.random.rand(4, 5, 6)] + # pdb.set_trace() + + for _ in range(np.random.randint(1, 10)): + self.cases[0][np.random.randint(0, 3)][np.random.randint(0, 4)][ + np.random.randint(0, 5) + ] = 0 + + self.expected = [ + 'shape[S3, 3], data[NULL]', + ] + + def test_eval_symbolic(self): + net = NonzeroNet() + + for i in range(len(self.cases)): + x = self.cases[i] + x_spec = InputSpec( + shape=[None for index in range(len(x.shape))], dtype='float32' + ) + + input_spec = [x_spec] + net = apply_to_static(net, False, input_spec) + net.eval() + + # check the infer result + check_infer_results(net, input_spec, 'pd_op.nonzero', self.expected) + + return True + + class PutAlongAxisNet(paddle.nn.Layer): def __init__(self): super().__init__()