-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] #9911
Conversation
Fix BUG and add test case for batch_matmul [A, B, M, N] * [1, B, N, K] |
Will take a look later today |
python/tvm/relay/frontend/onnx.py
Outdated
[ | ||
out_batch, | ||
_op.strided_slice( | ||
a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use a_rank
instead of infer_shape(a_shape)[0]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just please address nits and rebase
python/tvm/relay/frontend/onnx.py
Outdated
[ | ||
out_batch, | ||
_op.strided_slice( | ||
b_shape, [infer_shape(b_shape)[0] - 2], [infer_shape(b_shape)[0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, use b_rank
hi @AndrewZhaoLuo , I have fixed the comment. |
…e#9911) * [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] * fix line Co-authored-by: tomoyazhang <[email protected]>
…e#9911) * [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] * fix line Co-authored-by: tomoyazhang <[email protected]>
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.