-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Bitserial ops #3844
[Relay] Bitserial ops #3844
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work @jwfromm ! I've left a few comments in there for you to address.
topi/python/topi/arm_cpu/conv2d.py
Outdated
@@ -119,7 +123,11 @@ def _callback(op): | |||
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: | |||
s[kernel].compute_inline() | |||
|
|||
_schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) | |||
# TODO: move to schedule_nhwc later | |||
if 'nhwc' in op.tag: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is in the schedule_conv2d_nchw_arm_cpu
function - should 'nhwc' ever be in op.tag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reverted changes to arm nhwc schedule in favor of jackwish version.
topi/python/topi/arm_cpu/conv2d.py
Outdated
@@ -810,25 +954,19 @@ def _conv2d_legalize(attrs, inputs, arg_types): | |||
if attrs['data_layout'] == 'NHWC': | |||
data, kernel = inputs | |||
if attrs['kernel_layout'] == 'HWIO': | |||
# Handle HWIO layout. This is common in TF graph. | |||
kernel = relay.transpose(kernel, axes=(3, 2, 0, 1)) | |||
# HWIO layout is expected for NHWC input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the assumption here is that we run bitserial_conv2d_legalize
beforehand?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although this specific check has now been reverted a similar legalize routine is still in the bitserial convolution. The NHWC computation only works when the kernel is in HWIO format so the legalize pass is just doing a conversion in case the kernel is in a different format. Since most of TVM uses a default OIHW format, this is a pretty handy check to have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the new changes! I added a few more comments to be addressed.
It looks like @jackwish has a more fleshed out version of the NHWC schedules so I think it makes sense to cut it from my PR and instead focus only on the bitserial ops. All additions to the arm nhwc conv2d have now been reverted. |
Agreed, thanks for the changes, it will help integrate jackwish' changes more easily in the future. |
Thank you @jwfromm and @tmoreau89 , I think that I borrow insights of this work :) |
* Added arm_cpu NHWC schedules. * Fixed kernel shape legalization. * Added bitserial ops to relay. * Snapshot and more missing files. * Added dense testing. * Added tests * Added ASF header to new files. * cc lint * Pylint change. * pylint fixes. * Change arm legalize test. * Added assert check to arm legalize. * Added better documentation, fixed some bad style * Reverted arm conv2d nhwc changes.
* Added arm_cpu NHWC schedules. * Fixed kernel shape legalization. * Added bitserial ops to relay. * Snapshot and more missing files. * Added dense testing. * Added tests * Added ASF header to new files. * cc lint * Pylint change. * pylint fixes. * Change arm legalize test. * Added assert check to arm legalize. * Added better documentation, fixed some bad style * Reverted arm conv2d nhwc changes.
* Added arm_cpu NHWC schedules. * Fixed kernel shape legalization. * Added bitserial ops to relay. * Snapshot and more missing files. * Added dense testing. * Added tests * Added ASF header to new files. * cc lint * Pylint change. * pylint fixes. * Change arm legalize test. * Added assert check to arm legalize. * Added better documentation, fixed some bad style * Reverted arm conv2d nhwc changes.
This PR adds relay operations for the bitserial operations conv2d, dense and bitpack. This addition allows relay frontends to leverage the already existing TOPI bitserial ops that enable very fast low-bit execution on arm CPU. There are currently some limitations in regards to automatic shape inference in large part due to the need for these ops to support optional prepacking of weight bits. For example, this makes it difficult to infer shape without explicitly knowing the number of channels ahead of time, so here we require that the channels attribute is always set.
Also included in this PR is a schedule for NHWC convolution on arm_cpu that yields good results. Because of this inclusion, I've accordingly changed the conv2d legalize routine on arm.