Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【BUAA】【Infer Symbolic Shape】Add moe, multiclass_nms3, shadow_feed_tensor operator for CINN compiler #67337

Merged
merged 34 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0ecfd69
no test
Jeff114514 Jul 31, 2024
8a582f3
unchanged vector
Jeff114514 Aug 5, 2024
1c94ef1
Merge branch 'PaddlePaddle:develop' into sf_tensor
Jeff114514 Aug 6, 2024
d7cb873
moe
Jeff114514 Aug 6, 2024
2a41813
Update same_operands_result.cc
Jeff114514 Aug 6, 2024
4752672
Merge pull request #5 from Jeff114514/sf_tensor
Jeff114514 Aug 6, 2024
3913b9a
unchange
Jeff114514 Aug 6, 2024
7d00b39
name
Jeff114514 Aug 7, 2024
ce0552d
name
Jeff114514 Aug 7, 2024
b180bb9
test
Jeff114514 Aug 8, 2024
77fecae
Merge branch 'moe' into tmp
Jeff114514 Aug 10, 2024
43ad95d
Merge pull request #9 from Jeff114514/tmp
Jeff114514 Aug 10, 2024
fa9ad91
err
Jeff114514 Aug 11, 2024
b7d57d0
same
Jeff114514 Aug 11, 2024
60880cb
dyn_cast
Jeff114514 Aug 12, 2024
ff9ab66
new sym
Jeff114514 Aug 12, 2024
4392539
no moe optest
Jeff114514 Aug 12, 2024
fe37df4
new test
Jeff114514 Aug 12, 2024
6fef51b
mkddim
Jeff114514 Aug 12, 2024
8190ce6
no mkddim
Jeff114514 Aug 12, 2024
e35dfb6
no mkddim
Jeff114514 Aug 12, 2024
d9c191b
pass
Jeff114514 Aug 12, 2024
e62de69
remove unused
Jeff114514 Aug 13, 2024
5686d2b
fix
Jeff114514 Aug 14, 2024
f2cdd48
Merge branch 'develop' into moe
Jeff114514 Aug 14, 2024
39d024b
fix
Jeff114514 Aug 14, 2024
bf708f0
fix
Jeff114514 Aug 14, 2024
7283fff
fix
Jeff114514 Aug 14, 2024
27af0b0
fix
Jeff114514 Aug 14, 2024
a2ae505
fixx
Jeff114514 Aug 14, 2024
776a427
fixx
Jeff114514 Aug 14, 2024
6b9205c
fixx
Jeff114514 Aug 14, 2024
8adfc22
todo
Jeff114514 Aug 15, 2024
276e829
todo
Jeff114514 Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1460,19 +1460,72 @@ bool RoiAlignOpInferSymbolicShape(
// return MergedMomentumOpInferSymbolicShape(op, infer_context);
// }

// bool MoeOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }
bool MulticlassNms3OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &bboxes_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &scores_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

// bool MulticlassNMS3OpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
const std::vector<symbol::DimExpr> &box_dims = bboxes_shape_or_data.shape();
const std::vector<symbol::DimExpr> &score_dims = scores_shape_or_data.shape();
const size_t score_size = score_dims.size();

PADDLE_ENFORCE_EQ(
score_size == 2 || score_size == 3,
true,
common::errors::InvalidArgument(
"The rank of Input(Scores) must be 2 or 3. But received rank = %d",
score_size));
PADDLE_ENFORCE_EQ(
box_dims.size(),
3,
common::errors::InvalidArgument(
"The rank of Input(BBoxes) must be 3. But received rank = %d",
box_dims.size()));

if (score_size == 3) {
PADDLE_ENFORCE_EQ(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个5个dim的约束需要修改为使用Addequalcstr()添加

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不包含||的约束已进行修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

score_size和box_dims.size()均为int类型

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不包含||的约束已进行修改

剩下的约束目前shape dialect无法表示了,直接删掉或者写成TODO,不能直接使用静态值判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已完成修改

box_dims[2] == 4 || box_dims[2] == 8 || box_dims[2] == 16 ||
box_dims[2] == 24 || box_dims[2] == 32,
true,
common::errors::InvalidArgument("The last dimension of Input(BBoxes) "
"must be 4 or 8 or 16 or 24 or 32"));
infer_context->AddEqualCstr(box_dims[1], score_dims[2]);
} else {
infer_context->AddEqualCstr(box_dims[2], symbol::DimExpr(4));
infer_context->AddEqualCstr(box_dims[1], score_dims[1]);
}

const auto &next_symbol_out_and_index = infer_context->GetNextSymName();
const auto &next_symbol_nms_rois_num = infer_context->GetNextSymName();

std::vector<symbol::DimExpr> out_shape;
out_shape.emplace_back(next_symbol_out_and_index);
out_shape.emplace_back(box_dims[2] + 2);

std::vector<symbol::DimExpr> index_shape;
index_shape.emplace_back(next_symbol_out_and_index);
Copy link
Contributor

@gongshaotian gongshaotian Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

怎么确定的out_shape 与 index_shape的第一维是相同的动态维度

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiClassNMSKernel out和index第一维总是相同的

index_shape.emplace_back(1);

std::vector<symbol::DimExpr> nms_rois_num_shape;
nms_rois_num_shape.emplace_back(next_symbol_nms_rois_num);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果能确定这个动态维度与上面两个无关,这里就可以不再使用中间变量了,直接emplace_back(infer_context->GetNextSymName());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已完成修改


infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(index_shape)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(nms_rois_num_shape)});

return true;
}

bool MeshgridOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down Expand Up @@ -1532,8 +1585,7 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
}

// bool NceOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// {
// pir::InferSymbolicShapeContext *infer_context){
// // pass
// return true;
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedAdam_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedMomentum)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedMomentum_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Moe)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MulticlassNMS3)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MulticlassNms3)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ OP_SAME_OPERANDS_AND_RESULT(Sigmoid)
OP_SAME_OPERANDS_AND_RESULT(Sigmoid_)
OP_SAME_OPERANDS_AND_RESULT(LeakyRelu)
OP_SAME_OPERANDS_AND_RESULT(LeakyRelu_)
OP_SAME_OPERANDS_AND_RESULT(Moe)
Jeff114514 marked this conversation as resolved.
Show resolved Hide resolved
OP_SAME_OPERANDS_AND_RESULT(ThresholdedRelu)
OP_SAME_OPERANDS_AND_RESULT(ThresholdedRelu_)
OP_SAME_OPERANDS_AND_RESULT(SquareSr)
Expand Down Expand Up @@ -242,6 +243,18 @@ bool ArgsortOpInferSymbolicShape(
return true;
}

bool ShadowFeedTensorsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(0);
const symbol::TensorListShapeOrDataDimExprs &shape_or_data_list =
infer_context->GetShapeOrDataForValue(operand_source)
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();
for (size_t i = 0; i < shape_or_data_list.size(); ++i) {
infer_context->SetShapeOrDataForValue(op->result(i), shape_or_data_list[i]);
}
return true;
}

} // namespace paddle::dialect

namespace cinn::dialect {} // namespace cinn::dialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mish)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Moe)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_)
Expand Down Expand Up @@ -138,6 +139,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeedTensors)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个算子有点问题先不添加了

OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@
kernel:
func: shadow_feed_tensors
param: [x]
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : share_data_
args : (Tensor x)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3283,6 +3283,7 @@
func : multiclass_nms3
data_type : scores
optional : rois_num, nms_rois_num
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : multinomial
args : (Tensor x, Scalar(int) num_samples = 1, bool replacement = false)
Expand Down Expand Up @@ -5111,6 +5112,7 @@
func: MoeInferMeta
kernel:
func: moe
interfaces: paddle::dialect::InferSymbolicShapeInterface

- op: number_count
args: (Tensor numbers, int upper_range)
Expand Down
Loading