-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape】Add moe, multiclass_nms3, shadow_feed_tensor operator for CINN compiler #67337
【BUAA】【Infer Symbolic Shape】Add moe, multiclass_nms3, shadow_feed_tensor operator for CINN compiler #67337
Changes from 26 commits
0ecfd69
8a582f3
1c94ef1
d7cb873
2a41813
4752672
3913b9a
7d00b39
ce0552d
b180bb9
77fecae
43ad95d
fa9ad91
b7d57d0
60880cb
ff9ab66
4392539
fe37df4
6fef51b
8190ce6
e35dfb6
d9c191b
e62de69
5686d2b
f2cdd48
39d024b
bf708f0
7283fff
27af0b0
a2ae505
776a427
6b9205c
8adfc22
276e829
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1460,19 +1460,72 @@ bool RoiAlignOpInferSymbolicShape( | |
// return MergedMomentumOpInferSymbolicShape(op, infer_context); | ||
// } | ||
|
||
// bool MoeOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// { | ||
// // pass | ||
// return true; | ||
// } | ||
bool MulticlassNms3OpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &bboxes_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &scores_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
// bool MulticlassNMS3OpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext | ||
// *infer_context) { | ||
// // pass | ||
// return true; | ||
// } | ||
const std::vector<symbol::DimExpr> &box_dims = bboxes_shape_or_data.shape(); | ||
const std::vector<symbol::DimExpr> &score_dims = scores_shape_or_data.shape(); | ||
const size_t score_size = score_dims.size(); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
score_size == 2 || score_size == 3, | ||
true, | ||
common::errors::InvalidArgument( | ||
"The rank of Input(Scores) must be 2 or 3. But received rank = %d", | ||
score_size)); | ||
PADDLE_ENFORCE_EQ( | ||
box_dims.size(), | ||
3, | ||
common::errors::InvalidArgument( | ||
"The rank of Input(BBoxes) must be 3. But received rank = %d", | ||
box_dims.size())); | ||
|
||
if (score_size == 3) { | ||
PADDLE_ENFORCE_EQ( | ||
box_dims[2] == 4 || box_dims[2] == 8 || box_dims[2] == 16 || | ||
box_dims[2] == 24 || box_dims[2] == 32, | ||
true, | ||
common::errors::InvalidArgument("The last dimension of Input(BBoxes) " | ||
"must be 4 or 8 or 16 or 24 or 32")); | ||
infer_context->AddEqualCstr(box_dims[1], score_dims[2]); | ||
} else { | ||
infer_context->AddEqualCstr(box_dims[2], symbol::DimExpr(4)); | ||
infer_context->AddEqualCstr(box_dims[1], score_dims[1]); | ||
} | ||
|
||
const auto &next_symbol_out_and_index = infer_context->GetNextSymName(); | ||
const auto &next_symbol_nms_rois_num = infer_context->GetNextSymName(); | ||
|
||
std::vector<symbol::DimExpr> out_shape; | ||
out_shape.emplace_back(next_symbol_out_and_index); | ||
out_shape.emplace_back(box_dims[2] + 2); | ||
|
||
std::vector<symbol::DimExpr> index_shape; | ||
index_shape.emplace_back(next_symbol_out_and_index); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 怎么确定的out_shape 与 index_shape的第一维是相同的动态维度 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MultiClassNMSKernel out和index第一维总是相同的 |
||
index_shape.emplace_back(1); | ||
|
||
std::vector<symbol::DimExpr> nms_rois_num_shape; | ||
nms_rois_num_shape.emplace_back(next_symbol_nms_rois_num); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果能确定这个动态维度与上面两个无关,这里就可以不再使用中间变量了,直接emplace_back(infer_context->GetNextSymName()); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已完成修改 |
||
|
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(out_shape)}); | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(1), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(index_shape)}); | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(2), | ||
symbol::ShapeOrDataDimExprs{ | ||
symbol::TensorShapeOrDataDimExprs(nms_rois_num_shape)}); | ||
|
||
return true; | ||
} | ||
|
||
bool MeshgridOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
|
@@ -1532,8 +1585,7 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape( | |
} | ||
|
||
// bool NceOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// { | ||
// pir::InferSymbolicShapeContext *infer_context){ | ||
// // pass | ||
// return true; | ||
// } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,6 +109,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit_) | |
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid_) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mish) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Moe) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_) | ||
|
@@ -138,6 +139,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter) | |
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeedTensors) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个算子有点问题先不添加了 |
||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign) | ||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个5个dim的约束需要修改为使用Addequalcstr()添加
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不包含
||
的约束已进行修改There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
score_size和box_dims.size()均为int类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
剩下的约束目前shape dialect无法表示了,直接删掉或者写成TODO,不能直接使用静态值判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已完成修改