-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUAA】【Infer Symbolic Shape】Add moe, multiclass_nms3, shadow_feed_tensor operator for CINN compiler #67337
Conversation
Sf tensor
你的PR提交成功,感谢你对开源项目的贡献! |
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
已修改单测 |
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc
Outdated
Show resolved
Hide resolved
box_dims.size())); | ||
|
||
if (score_size == 3) { | ||
PADDLE_ENFORCE_EQ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个5个dim的约束需要修改为使用Addequalcstr()添加
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不包含||
的约束已进行修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
score_size和box_dims.size()均为int类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不包含
||
的约束已进行修改
剩下的约束目前shape dialect无法表示了,直接删掉或者写成TODO,不能直接使用静态值判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已完成修改
out_shape.emplace_back(box_dims[2] + 2); | ||
|
||
std::vector<symbol::DimExpr> index_shape; | ||
index_shape.emplace_back(next_symbol_out_and_index); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
怎么确定的out_shape 与 index_shape的第一维是相同的动态维度
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MultiClassNMSKernel out和index第一维总是相同的
index_shape.emplace_back(1); | ||
|
||
std::vector<symbol::DimExpr> nms_rois_num_shape; | ||
nms_rois_num_shape.emplace_back(next_symbol_nms_rois_num); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果能确定这个动态维度与上面两个无关,这里就可以不再使用中间变量了,直接emplace_back(infer_context->GetNextSymName());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已完成修改
@@ -138,6 +138,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Select) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeedTensors) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个算子有点问题先不添加了
test/legacy_test/test_moe_op.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moe这个算子符号推导逻辑比较简单,而且注意到已有test_moe_api.py测试文件存在,建议删除这个单测文件
PR Category
CINN
PR Types
Others
Description
shadow_feed_tensor缺失单测,multiclass_nms3有单测但未开启check_pir测试,仿照fused_moe为moe添加一个单测