Skip to content

Commit

Permalink
Add InferSymbolicShape for pd_op.nonzero (#62987)
Browse files Browse the repository at this point in the history
* add pd_op.nonzero

* update

* update

* update

* update

* update

* update
  • Loading branch information
WintersMontagne10335 authored Apr 19, 2024
1 parent 0ae676f commit 70f2b54
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,29 @@ bool MinOpInferSymbolicShape(pir::Operation *op,
return MaxOpInferSymbolicShape(op, shape_analysis);
}

bool NonzeroOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const auto &x_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape = x_shape_or_data.shape();
int rank = x_shape.size();

PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument(
"Input(x) should have number of dimension at least 1."));

std::string sym_name = shape_analysis->GetNextSymName();
std::vector<symbol::DimExpr> out_shape{symbol::DimExpr{sym_name},
symbol::DimExpr{rank}};

symbol::ShapeOrDataDimExprs shape_data{
symbol::TensorShapeOrDataDimExprs(out_shape)};
shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
return true;
}

bool PadOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
// input(0): Tensor x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2122,6 +2122,7 @@
kernel :
func : nonzero
data_type: condition
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : npu_identity
args : (Tensor x, int format = -1)
Expand Down
42 changes: 42 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,48 @@ def test_eval_symbolic(self):
return True


class NonzeroNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
out_nonzero = paddle.nonzero(x)
return out_nonzero


class NonzeroOpInferSymbolicShapeTest(TestBase):
def prepare_data(self):
self.cases = [np.random.rand(4, 5, 6)]
# pdb.set_trace()

for _ in range(np.random.randint(1, 10)):
self.cases[0][np.random.randint(0, 3)][np.random.randint(0, 4)][
np.random.randint(0, 5)
] = 0

self.expected = [
'shape[S3, 3], data[NULL]',
]

def test_eval_symbolic(self):
net = NonzeroNet()

for i in range(len(self.cases)):
x = self.cases[i]
x_spec = InputSpec(
shape=[None for index in range(len(x.shape))], dtype='float32'
)

input_spec = [x_spec]
net = apply_to_static(net, False, input_spec)
net.eval()

# check the infer result
check_infer_results(net, input_spec, 'pd_op.nonzero', self.expected)

return True


class PutAlongAxisNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 70f2b54

Please sign in to comment.