[TOPI] Basic x86 schedules #775
Conversation
n, c, h, w = op.axis
fused = s[op].fuse(n, c)
s[op].parallel(fused)
s[op].vectorize(w)
What's this for? Shouldn't it be under `if 'conv2d_nchw' in op.tag:`?
This is for parallelizing and vectorizing bias + batchnorm + relu, which comes right after conv2d.
To understand this, you need to compile your network (resnet18, say) with NNVM and dump the lowered IR after operator fusion.
See dmlc/nnvm#292, and also https://github.com/dmlc/nnvm/issues/275
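For context, here is a minimal, hedged sketch (assuming the tvm/topi Python API of this era, with a plain bias-add + relu standing in for the full bias + batchnorm + relu epilogue) of what the quoted block is doing once the epilogue has been fused behind conv2d:

```python
import tvm
import topi

data = tvm.placeholder((1, 64, 56, 56), name="data")
kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")
bias = tvm.placeholder((64,), name="bias")

# conv2d followed by a bias-add + relu epilogue (a stand-in for bias + bn + relu)
conv = topi.nn.conv2d_nchw(data, kernel, 1, 1)
out = tvm.compute(conv.shape,
                  lambda n, c, h, w: tvm.max(conv[n, c, h, w] + bias[c],
                                             tvm.const(0.0, conv.dtype)),
                  name="epilogue")

# the output op is the epilogue stage, not the conv itself, so this is the
# stage whose (n, c) axes get fused and parallelized and whose width axis
# gets vectorized
s = tvm.create_schedule(out.op)
n, c, h, w = out.op.axis
fused = s[out].fuse(n, c)
s[out].parallel(fused)
s[out].vectorize(w)
```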
s[C].reorder(fused, rc, h, wo, ry, rx, wi) # move rc to outer loop
s[C].unroll(rx)
s[C].unroll(ry)
s[C].vectorize(wi)
I tested this schedule on an AWS C5 instance (skylake-avx512, 4 CPUs) and it does not perform well: about 13 ms for the conv alone. My input is (1, 64, 56, 56) and the kernel size is (64, 64, 3, 3).
Looks like the final `s[C].vectorize(wi)` is not working for this input size (56), meaning no vector instructions are generated.
For my use cases, all input widths and heights are powers of two, and I wrote my schedules assuming such inputs.
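As a hedged aside (this is not the schedule in this PR), one common way to get the width to vectorize even when it is not a power of two is to split it by a fixed power-of-two factor first, so the inner loop has a constant extent that maps to a vector width:

```python
import tvm

A = tvm.placeholder((64, 56), name="A")
B = tvm.compute(A.shape, lambda c, w: A[c, w] * 2.0, name="B")

s = tvm.create_schedule(B.op)
c, w = B.op.axis
wo, wi = s[B].split(w, factor=8)  # 8 x fp32 maps to one 256-bit vector
s[B].vectorize(wi)

# inspect the lowered IR to check that vectorized (ramp) loads/stores appear
print(tvm.lower(s, [A, B], simple_mode=True))
```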
@generic.schedule_conv2d_nhwc.register(["cpu"])
def schedule_conv2d_nhwc(outs):
On the same C5 instance, this one works pretty well: 1.8 ms for a (1, 56, 56, 64) input and a (3, 3, 64, 64) kernel.
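For reference, a hedged sketch of how a measurement like this could be reproduced, assuming the tvm/topi API of this era and assuming the NHWC conv added in this PR takes (data, kernel, stride, padding) like its NCHW counterpart:

```python
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

target = "llvm -mcpu=skylake-avx512"
data = tvm.placeholder((1, 56, 56, 64), name="data")     # NHWC input
kernel = tvm.placeholder((3, 3, 64, 64), name="kernel")  # HWCF weights
conv = topi.nn.conv2d_nhwc(data, kernel, 1, 1)           # assumed (stride, padding) = (1, 1)

with tvm.target.create(target):
    # dispatches to the cpu schedule registered via generic.schedule_conv2d_nhwc
    s = topi.generic.schedule_conv2d_nhwc([conv])

func = tvm.build(s, [data, kernel, conv], target=target)
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=get_const_tuple(data.shape)).astype("float32"), ctx)
w = tvm.nd.array(np.random.uniform(size=get_const_tuple(kernel.shape)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(conv.shape), dtype=conv.dtype), ctx)
evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print("conv2d_nhwc: %.3f ms" % (evaluator(a, w, c).mean * 1e3))
```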
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) == 4:
    n, c, _, _ = s[x].op.axis
    fused = s[x].fuse(n, c) # for nhwc layout, fuse n and h
I'm still confused by this. Why not put it in `schedule_conv2d`? Also, the comment does not match the code.
This schedule is not about conv2d. It is used for multithreading injective ops such as pooling, upsampling, and softmax; all elemwise/broadcast ops are also multithreaded with this schedule.
The comment is actually correct. What I mean is that `s[x].op.axis` depends on the input layout: with NCHW layout it will be (n, c, h, w), while with NHWC it will be (n, h, w, c). The comment says that no matter the layout, the schedule fuses the first two axes.
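A minimal, hedged sketch (assuming the tvm API of this era, not the exact code in this PR) of the kind of generic injective schedule being described, fusing and parallelizing the first two output axes regardless of layout:

```python
import tvm

def schedule_injective_sketch(outs):
    """Inline injective stages, then parallelize the fused first two axes."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([t.op for t in outs])
    x = outs[0]
    tvm.schedule.AutoInlineInjective(s)
    if len(s[x].op.axis) >= 2:
        # NCHW -> fuses (n, c); NHWC -> fuses (n, h)
        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
        s[x].parallel(fused)
    return s

# usage: multithread a simple elemwise op over an NCHW tensor
data = tvm.placeholder((1, 64, 56, 56), name="data")
scaled = tvm.compute(data.shape, lambda n, c, h, w: data[n, c, h, w] * 2.0,
                     name="scaled")
s = schedule_injective_sketch(scaled)
print(tvm.lower(s, [data, scaled], simple_mode=True))
```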
@yzhliu this PR is not just about x86 conv schedules; it also contains some changes I made to make a whole network run faster in conjunction with NNVM. I should also mention that my schedules are optimized for my use case (input sizes being powers of two, all convs being 3x3, etc.). In particular, I didn't optimize my schedules for ImageNet workloads at all. Nonetheless, this should be better than the current x86 schedule (which is mostly empty). The idea is that @yzhliu or others from the community will improve on it, to make it faster on a wide range of workloads, including ImageNet.
I am going to merge this, given that the solution is better than the current x86 ones. We can follow up with updates to improve it.
* add basic x86 schedules
* parallelize & vectorize batchnorm + relu
* fuse conv into bn + relu
* move rc loop to outer
* add nhwc conv
* change weight layout to hwcf
* conv + bn + relu fusion for nhwc conv
* fix conv_nhwc schedule when no fusion
* clean up default parallel schedules
* simplify elemwise parallel
* fix elemwise parallel for batch == 1
* update nhwc conv test
* fix and add comment
* fix lint
* remove redundant import
* remove default multithreading for some ops
* remove default multithreading for global pool
Contains the following:
cc @yidawang @yzhliu