Fix vmlal.s16 code generation for int8 x int8 -> int32 #2748
src/pass/lower_intrin.cc
@@ -50,7 +50,23 @@ class IntrinInjecter : public IRMutator {
// on ARM.
if (const Broadcast* bcast = e.as<Broadcast>()) {
if (const Cast* cast = bcast->value.as<Cast>()) {
if (cast->type.bits() == cast->value.type().bits() * 2) {
auto shouldSwap = [&]() {
nit: code style, the local variable should be named should_swap?
Done, my bad.
lgtm modulo nit
A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
s = tvm.create_schedule(C.op)
s[C].vectorize(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
maybe we should remove this print
Done, my bad.
98b54d4 to 17112b5
Thanks for the comments, I've updated the patch with the review comments.
Thanks @ajtulloch @FrozenGene, this is now merged!
Thanks for merging @tqchen.
The `IntrinInjecter::SwapBroadcastCast` function was limited to cases where the result bitwidth was exactly 2x the input bitwidth. This fails for relevant cases such as a vectorized int8 x int8 -> int32 GEMM. This patch improves code generation somewhat by emitting `VMLAL.S16` instructions instead of a `MOVL` + `VMLA.S32` sequence.
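
The bitwidth condition described above can be modelled in a few lines of Python. This is only a sketch: the `DType` class and both predicate names are hypothetical, and the relaxed rule (accept any integer-to-integer widening cast) is an assumption about what such a fix could look like, not the condition taken from the merged patch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DType:
    """Hypothetical stand-in for a type: a kind ("int", "uint", "float") and a bit width."""
    kind: str
    bits: int


def should_swap_strict(src: DType, dst: DType) -> bool:
    # Behaviour described in the PR text: only an exact 2x widening qualifies.
    return dst.bits == src.bits * 2


def should_swap_relaxed(src: DType, dst: DType) -> bool:
    # Assumed relaxation (not taken from the patch): keep the 2x rule and
    # additionally accept any integer-to-integer widening cast.
    if should_swap_strict(src, dst):
        return True
    both_integer = src.kind in ("int", "uint") and dst.kind in ("int", "uint")
    return both_integer and dst.bits > src.bits


int8, int16, int32 = DType("int", 8), DType("int", 16), DType("int", 32)

# int8 -> int32 is a 4x widening, so the strict rule rejects it
# while the relaxed rule accepts it.
print(should_swap_strict(int8, int32))
print(should_swap_relaxed(int8, int32))
```

Under this model, the int8 -> int16 case keeps its existing behaviour, and the int8 -> int32 case from the vectorized GEMM becomes eligible for the broadcast/cast swap.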