[ARM] Fix int8 NCHWc compute and alter layout #10839
Conversation
@@ -364,7 +365,7 @@ def get_ref_data():
        #     ),
    ]

-    # TODO(tvm-team): Properly run ARM code on CI aarch64 environment
+    # TODO(tvm-team): Figure out ARM dot product availability on CI aarch64 environment
cc @Mousius @u99127, I'd love to test the dot-product schedule on the aarch64 CI, do you know if it is supported? Automatic detection would require reading /proc/cpuinfo etc. as suggested by @u99127 in #10773 (comment), which I'd rather avoid.
As far as I know, the CI environment should be good to run the dot-product schedules, I can take a look at cpuinfo detection later 😸
Yup, I enabled the dot-product test on CI, and it seems to be working!
https://ci.tlcpack.ai/blue/rest/organizations/jenkins/pipelines/tvm/branches/PR-10839/runs/5/nodes/316/steps/542/log/?start=0
(Search for `Running on target: llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod` in the log.)
@@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel):
    kernel = te.compute(
        (oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
        lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
-            occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
+            occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w
cc @tkonolige please have a look at this change. Since `test_topi_conv2d_int8.py` doesn't use the alter-layout code (which had a bug) and `_pack_data` uses `n_elems = 4`, the aarch64 CI failure on `test_topi_conv2d_int8.py` was probably due to this bug.
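To see why the old index expression in `_pack_data` was wrong, here is a small standalone check (plain Python, not TVM code; `ic_bn = 8` and `n_elems = 4` are just example block sizes matching this schedule). The input-channel offset within a block should enumerate every channel exactly once:

```python
# Enumerate the within-block input-channel offsets produced by the old and
# new index formulas from _pack_data. The packed kernel splits each ic block
# of size ic_bn into (ic_bn // n_elems) groups of n_elems channels.
ic_bn, n_elems = 8, 4  # example block sizes for the int8 NCHWc schedule

def offsets(formula):
    return sorted(
        formula(icbc, icbb)
        for icbc in range(ic_bn // n_elems)  # group index
        for icbb in range(n_elems)           # element within group
    )

# Old (buggy) expression: icbc * ic_bn // n_elems + icbb
buggy = offsets(lambda icbc, icbb: icbc * ic_bn // n_elems + icbb)
# New (fixed) expression: icbc * n_elems + icbb
fixed = offsets(lambda icbc, icbb: icbc * n_elems + icbb)

print(buggy)  # [0, 1, 2, 2, 3, 3, 4, 5] -- duplicates some channels, skips 6 and 7
print(fixed)  # [0, 1, 2, 3, 4, 5, 6, 7] -- each channel packed exactly once
```

The buggy formula reads channels 2 and 3 twice and never reads channels 6 and 7, so the packed kernel silently loses data.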
Thanks for fixing this @masahi! After reading through the code again (that I wrote...), it is doing a 4x4 dot product, so n_elems should be 4.
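For readers unfamiliar with the instruction: a rough numpy model of the 4x4 int8 dot-product step (sdot-style) is sketched below. This is illustrative only, not the TVM intrinsic itself; the function name and shapes are assumptions. Each of the 4 output lanes accumulates a dot product of `n_elems = 4` int8 pairs, which is why `n_elems` must be 4 on this path:

```python
import numpy as np

def dot_4x4(acc, data, kernel):
    """Model of a 4x4 int8 dot-product instruction.

    acc:    (4,) int32 accumulators
    data:   (4,) int8 input values (one group of n_elems channels)
    kernel: (4, 4) int8, one row per output lane
    Returns acc + kernel @ data, widened to int32 as the hardware does.
    """
    return acc + kernel.astype(np.int32) @ data.astype(np.int32)

acc = np.zeros(4, dtype=np.int32)
data = np.array([1, 2, 3, 4], dtype=np.int8)
kernel = np.eye(4, dtype=np.int8)  # identity: lane i picks out data[i]
print(dot_4x4(acc, data, kernel))  # [1 2 3 4]
```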
This PR fixes a bug in TE ARM int8 compute for NCHWc conv2d, introduced in apache#10310. The compute itself, not the schedule, is broken, for the following reasons:

* We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566, and `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478.
* The ARM code that calls this compute does not explicitly pass `n_elems`: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute only loops over `n_elems = 4` of the input channel dimension.

Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since the compute now does a 4x8 innermost loop while the intrinsic is supposed to do a 4x4 dot product, see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there does fix the tensorize pattern matching, but the result was still incorrect.

Rather than fixing the intrin implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt it for a 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change is enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot-product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155, which computes a 4x4 dot product in one instruction.

@tkonolige I suggest doing the perf benchmark again, since the numbers in apache#10310 are invalid.

cc @mbrookhart @Mousius @junrushao1994 @vinx13
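The extent mismatch described above can be reproduced in miniature with plain numpy (illustrative sketch, not TVM code; the values are made up). When the packed kernel's innermost axis has extent 8 but the reduction only iterates over `n_elems = 4`, half of each channel block's contribution is silently dropped:

```python
import numpy as np

# One input-channel block of 8 values and the matching kernel slice,
# i.e. the innermost axis that alter layout packed with extent 8.
data = np.arange(1, 9, dtype=np.int8)   # [1, 2, ..., 8]
kernel = np.ones(8, dtype=np.int8)

# Correct reduction: all 8 products contribute.
full = int(data.astype(np.int32) @ kernel.astype(np.int32))      # 36

# What the TE compute actually did: loop only over the default n_elems = 4.
n_elems = 4
partial = int(data[:n_elems].astype(np.int32)
              @ kernel[:n_elems].astype(np.int32))               # 10

print(full, partial)  # 36 10 -- half the products are dropped
```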