Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN]shape inference for logsumexp logcumsumexp linspace logspace min poisson repeat_interleave topk uniform #62800

Merged
merged 14 commits into from
Mar 27, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,28 @@ bool FullWithTensorOpInferSymbolicShape(

bool LinspaceOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // linspace(start, stop, num) produces a 1-D tensor with `num` elements,
  // so only the `num` operand (index 2) determines the output shape.
  // (Removed the stale "NOT implemented" PADDLE_THROW that preceded this
  // code and made it unreachable.)
  const auto &num_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(2));

  // Prefer the constant value of `num` when the analysis has captured it;
  // otherwise fall back to its symbolic shape entry.
  const symbol::DimExpr num = [&] {
    if (num_shape_or_data.data().has_value()) {
      return num_shape_or_data.data().value()[0];
    }
    return num_shape_or_data.shape()[0];
  }();

  shape_analysis->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{num})});
  return true;
}
bool LogspaceOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // logspace has the same operand layout as linspace (start, stop, num, ...)
  // and also yields a 1-D tensor of `num` elements, so delegate to the
  // linspace rule. (Removed the stale PADDLE_THROW stub and the dead
  // `return true;` that preceded the delegation.)
  return LinspaceOpInferSymbolicShape(op, shape_analysis);
}

bool StackOpInferSymbolicShape(pir::Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,7 @@ bool TriuIndicesOpInferSymbolicShape(
}
bool UniformOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // uniform's output shape is determined the same way as gaussian's (from
  // its shape operand/attribute), so reuse that rule. (Removed the stale
  // "NOT implemented" PADDLE_THROW stub that made this unreachable.)
  return GaussianOpInferSymbolicShape(op, shape_analysis);
}

} // namespace paddle::dialect
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ OP_SAME_OPERANDS_AND_RESULT(LogicalNot_)
OP_SAME_OPERANDS_AND_RESULT(Logit)
OP_SAME_OPERANDS_AND_RESULT(Logit_)
OP_SAME_OPERANDS_AND_RESULT(Pow)
OP_SAME_OPERANDS_AND_RESULT(Poisson)
OP_SAME_OPERANDS_AND_RESULT(Pow_)
OP_SAME_OPERANDS_AND_RESULT(Print)
OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Print)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,16 @@ bool KthvalueOpInferSymbolicShape(

bool LogcumsumexpOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // logcumsumexp preserves the input shape exactly like cumsum, so the
  // cumsum inference rule applies unchanged. (Removed the stale
  // PADDLE_THROW stub and dead `return true;` preceding the delegation.)
  return CumsumOpInferSymbolicShape(op, shape_analysis);
}

bool LogsumexpOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // logsumexp is a reduction over `axis` with optional kept dims; reuse the
  // generic reduce-dim inference. (Removed the stale PADDLE_THROW stub that
  // made this unreachable.)
  bool keepdim = GetBoolAttr(op, "keepdim");
  std::vector<int64_t> axis = details::GetVectorAttr(op, "axis");
  // An empty axis list means "reduce over all dimensions".
  bool reduce_all = axis.empty();
  return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all);
}

bool MaxOpInferSymbolicShape(pir::Operation *op,
Expand Down Expand Up @@ -325,9 +325,7 @@ bool MaxOpInferSymbolicShape(pir::Operation *op,

bool MinOpInferSymbolicShape(pir::Operation *op,
                             pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // min reduces with the same axis/keepdim attributes as max, so the max
  // rule infers the identical output shape. (Removed the stale
  // "NOT implemented" PADDLE_THROW stub that made this unreachable.)
  return MaxOpInferSymbolicShape(op, shape_analysis);
}

bool PadOpInferSymbolicShape(pir::Operation *op,
Expand All @@ -337,13 +335,6 @@ bool PadOpInferSymbolicShape(pir::Operation *op,
return true;
}

// Unimplemented stub: always raises Unimplemented for pd_op.poisson.
// NOTE(review): poisson is elementwise (output shape == input shape), so
// this could presumably be covered by the same-operands-and-result path
// instead of throwing — confirm against the op registry.
bool PoissonOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  PADDLE_THROW(phi::errors::Unimplemented(
      op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
  // Unreachable after the throw; kept to satisfy the bool return type.
  return true;
}

bool ProdOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
bool keepdim = GetBoolAttr(op, "keep_dim");
Expand All @@ -368,8 +359,44 @@ bool ProdOpInferSymbolicShape(pir::Operation *op,

bool RepeatInterleaveOpInferSymbolicShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // repeat_interleave with a scalar `repeats` multiplies the size of the
  // `axis` dimension by `repeats` and leaves every other dimension as-is.
  // NOTE: when the python-level axis is None (flattened semantics), the IR
  // pipeline inserts a flatten op before this one (confirmed in PR review),
  // so `axis` is always a concrete int attribute here.
  pir::Value operand_source = op->operand_source(0);
  const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(operand_source);

  const auto &attributes = op->attributes();
  int repeats = attributes.at("repeats").dyn_cast<pir::Int32Attribute>().data();
  int axis = attributes.at("axis").dyn_cast<pir::Int32Attribute>().data();

  // Use the captured constant data when available, else the symbolic shape.
  const std::vector<symbol::DimExpr> in_dims_sym =
      operand_shape_or_data.data().has_value()
          ? operand_shape_or_data.data().value()
          : operand_shape_or_data.shape();

  const int x_rank = static_cast<int>(in_dims_sym.size());

  // Normalize python-style negative axes (-1 == last dim). Without this,
  // `i == axis` below never matches a negative axis and the inferred shape
  // would silently stay unchanged (the accompanying tests use axis=-1/-2).
  if (axis < 0) axis += x_rank;

  std::vector<symbol::DimExpr> out_sym_shape;
  out_sym_shape.reserve(x_rank);
  for (int i = 0; i < x_rank; ++i) {
    if (i == axis) {
      out_sym_shape.push_back(in_dims_sym[i] * repeats);
    } else {
      out_sym_shape.push_back(in_dims_sym[i]);
    }
  }

  shape_analysis->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(out_sym_shape)});

  return true;
}

Expand Down Expand Up @@ -716,8 +743,45 @@ bool TileOpInferSymbolicShape(pir::Operation *op,

bool TopkOpInferSymbolicShape(pir::Operation *op,
                              pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // top-k keeps every dimension of x except `axis`, which becomes k; both
  // results (values, indices) share that shape. (Removed the stale
  // "NOT implemented" PADDLE_THROW stub at the top of the old body.)
  const symbol::ShapeOrDataDimExprs &x_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
  const symbol::ShapeOrDataDimExprs &k_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(1));

  const auto &attributes = op->attributes();
  int axis = attributes.at("axis").dyn_cast<pir::Int32Attribute>().data();

  // Use the captured constant data when available, else the symbolic shape.
  const std::vector<symbol::DimExpr> in_dims_sym =
      x_shape_or_data.data().has_value() ? x_shape_or_data.data().value()
                                         : x_shape_or_data.shape();

  const int x_rank = static_cast<int>(in_dims_sym.size());

  // NOTE(review): assumes k is a known constant carried in the value data;
  // a purely symbolic k would leave data() empty and value() would fail —
  // TODO confirm upstream guarantees a constant k operand here.
  int k = k_shape_or_data.data().value()[0].Get<int64_t>();

  // Normalize python-style negative axis before comparing against i.
  if (axis < 0) axis += x_rank;

  std::vector<symbol::DimExpr> out_sym_shape;
  out_sym_shape.reserve(x_rank);
  for (int i = 0; i < x_rank; ++i) {
    if (i == axis) {
      out_sym_shape.push_back(symbol::DimExpr(k));
    } else {
      out_sym_shape.push_back(in_dims_sym[i]);
    }
  }

  symbol::ShapeOrDataDimExprs shape_data{
      symbol::TensorShapeOrDataDimExprs(out_sym_shape)};

  // Both outputs (values and indices) have the same shape.
  shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
  shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data);

  return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape)
Expand Down
78 changes: 78 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_multinary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,52 @@ def test_eval_symbolic(self):
return out


class LinspaceNet(paddle.nn.Layer):
    """Network whose forward emits a single pd_op.linspace."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # `x` only fixes the program's input spec; linspace ignores it.
        return paddle.linspace(start=0, stop=5, num=10)


class LinspaceOpInferSymbolicShapeTest(TestBase):
    """Checks the symbolic shape inferred for pd_op.linspace."""

    def prepare_data(self):
        # num=10 is a compile-time constant, so the shape is fully static.
        self.expected = ['shape[10], data[NULL]']

    def test_eval_symbolic(self):
        spec = InputSpec(shape=[None, None, 2], dtype='float32')
        net = apply_to_static(LinspaceNet(), False, [spec])
        net.eval()
        check_infer_results(net, [spec], 'pd_op.linspace', self.expected)
        return True


class LogspaceNet(paddle.nn.Layer):
    """Network whose forward emits a single pd_op.logspace."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # `x` only fixes the program's input spec; logspace ignores it.
        return paddle.logspace(start=1, stop=5, num=10)


class LogspaceOpInferSymbolicShapeTest(TestBase):
    """Checks the symbolic shape inferred for pd_op.logspace."""

    def prepare_data(self):
        # num=10 is a compile-time constant, so the shape is fully static.
        self.expected = ['shape[10], data[NULL]']

    def test_eval_symbolic(self):
        spec = InputSpec(shape=[None, None, 2], dtype='float32')
        net = apply_to_static(LogspaceNet(), False, [spec])
        net.eval()
        check_infer_results(net, [spec], 'pd_op.logspace', self.expected)
        return True


class SliceNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -185,6 +231,38 @@ def test_eval_symbolic(self):
return True


class PoissonNet(paddle.nn.Layer):
    """Network whose forward emits a single pd_op.poisson."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # poisson is elementwise: output shape follows the input's.
        return paddle.poisson(x)


class PoissonOpInferSymbolicShapeTest(TestBase):
    """Checks that pd_op.poisson keeps the symbolic input shape."""

    def prepare_data(self):
        self.cases = [np.random.rand(2, 3, 4)]
        # Every input dim is dynamic, so all three stay symbolic.
        self.expected = ['shape[S0, S1, S2], data[NULL]']

    def test_eval_symbolic(self):
        net = PoissonNet()

        for case in self.cases:
            spec = InputSpec(
                shape=[None] * len(case.shape), dtype='float32'
            )
            net = apply_to_static(net, False, [spec])
            net.eval()
            check_infer_results(net, [spec], 'pd_op.poisson', self.expected)

        return True


class TrilNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand Down
56 changes: 56 additions & 0 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,61 @@ def test_eval_symbolic(self):
return True


class RepeatInterleaveNet(paddle.nn.Layer):
    # Network emitting four pd_op.repeat_interleave ops, one per axis value
    # (including negative axes). Each call adds its own op to the program;
    # only the last result is returned, but all four are shape-inferred,
    # and their program order must match the test's `expected` list.
    def __init__(self):
        super().__init__()

    def forward(self, x):
        out = paddle.repeat_interleave(x, 2, axis=0)
        out = paddle.repeat_interleave(x, 2, axis=1)
        out = paddle.repeat_interleave(x, 2, axis=-1)
        out = paddle.repeat_interleave(x, 2, axis=-2)
        return out


class RepeatInterleaveOpInferSymbolicShapeTest(TestBase):
    """Checks repeat_interleave shape inference, incl. negative axes."""

    def prepare_data(self):
        # One entry per repeat_interleave op, in program order.
        self.expected = [
            'shape[Mul(S0, 2), S1], data[NULL]',
            'shape[S0, Mul(S1, 2)], data[NULL]',
            'shape[S0, S1, Mul(S2, 2)], data[NULL]',
            'shape[S0, Mul(S1, 2), S2], data[NULL]',
        ]

    def test_eval_symbolic(self):
        spec = InputSpec(shape=[None, None, None], dtype='float32')
        net = apply_to_static(RepeatInterleaveNet(), False, [spec])
        net.eval()
        check_infer_results(
            net, [spec], 'pd_op.repeat_interleave', self.expected
        )
        return True


class UniformNet(paddle.nn.Layer):
    """Network whose forward emits a single pd_op.uniform."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Fully static shape argument, so the inferred shape is concrete.
        return paddle.tensor.random.uniform(shape=[12, 32], min=1.0, max=2.0)


class UniformOpInferSymbolicShapeTest(TestBase):
    """Checks the symbolic shape inferred for pd_op.uniform."""

    def prepare_data(self):
        # shape=[12, 32] is static, so both dims are concrete.
        self.expected = ['shape[12, 32], data[NULL]']

    def test_eval_symbolic(self):
        spec = InputSpec(shape=[None, None, 2], dtype='float32')
        net = apply_to_static(UniformNet(), False, [spec])
        net.eval()
        check_infer_results(net, [spec], 'pd_op.uniform', self.expected)
        return True


if __name__ == '__main__':
unittest.main()
Loading