[QNN] Lowering for Depthwise Convolution. #4351
Conversation
9d45f03 to 4e2ec58
@FrozenGene @jackwish @tmoreau89 A gentle ping for review :) This PR is ready for review now.
Minor comments :)
src/relay/qnn/op/convolution.cc
Outdated
if (param->kernel_layout == "OIHW") {
  out_channels = get_const_int(kernel_shape[0]);
  kernel_h = get_const_int(kernel_shape[2]);
  kernel_w = get_const_int(kernel_shape[3]);
  if (is_depthwise(param)) {
What about using a variable to hold this?
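For example, something along these lines (a hypothetical sketch of the suggestion, not the code that was actually merged):

// Hypothetical sketch: evaluate the depthwise check once and reuse it,
// instead of calling is_depthwise(param) inside each layout branch.
const bool depthwise = is_depthwise(param);
if (param->kernel_layout == "OIHW") {
  out_channels = get_const_int(kernel_shape[0]);
  kernel_h = get_const_int(kernel_shape[2]);
  kernel_w = get_const_int(kernel_shape[3]);
  if (depthwise) {
    // depthwise-specific handling of channels / channel_multiplier goes here
  }
}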
* For depthwise, we can similarly unroll the computation. The initial compute is as follows
* where cm = channel_multiplier
*
* Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/cm, oc%cm, r, s) - zp_w)
Suggest rewriting the related formulas with TeX: block style, for example, or inline style.
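For reference, the depthwise formula above could be typeset roughly like this (a sketch only; the second factor with Qa is reconstructed from the standard quantized depthwise definition and is not visible in the snippet above):

\[
Q_c(n, oc, oh, ow) = \sum_{r, s}
  \bigl(Q_w(\lfloor oc / cm \rfloor,\; oc \bmod cm,\; r, s) - zp_w\bigr)\,
  \bigl(Q_a(n,\; \lfloor oc / cm \rfloor,\; oh + r,\; ow + s) - zp_a\bigr)
\]

where cm is the channel multiplier and zp_w, zp_a are the weight and activation zero points.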
Thanks for the comments. I will send a separate PR for better comments. This will have to change in multiple files, I think.
  CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
  auto dilation_h = get_const_int(param->dilation[0]);
  auto dilation_w = get_const_int(param->dilation[1]);
- if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
+ if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
Seems is_depthwise already checked groups?
This is to differentiate between grouped convolution (old AlexNet) and depthwise convolution. We want to fall back to the simpler lowering if it's a grouped convolution, but not if it's depthwise.
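For illustration, the distinction can be expressed as a simple predicate (a hypothetical sketch, not the actual is_depthwise helper in this PR):

// Hypothetical sketch: a depthwise conv2d is a grouped conv2d whose group
// count equals the number of input channels, so every input channel is
// convolved with its own filter(s). An ordinary grouped convolution
// (e.g. the original AlexNet) has 1 < groups < in_channels and should
// fall back to the simpler lowering.
bool IsDepthwiseLike(int groups, int in_channels) {
  return groups != 1 && groups == in_channels;
}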
Sorry for this, I thought I had removed this comment after taking a closer look at it :(
- auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
- auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
+ auto term3 = Conv2DThirdTerm(weight, param, out_channels);
+ auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
That is interesting, why remove the batch semantics?
The batch_size was not needed at all for those two terms in the first place. I realized that while writing the lowering for depthwise.
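For context, a sketch of why that holds, using the notation of the standard QNN conv2d expansion of (Qa - zp_a) * (Qw - zp_w) (the symbols below are assumed from that expansion, not quoted from the diff):

\[
\text{term}_3(oc) = zp_a \sum_{c, r, s} Q_w(oc, c, r, s), \qquad
\text{term}_4 = C_{in} \, K_h \, K_w \, zp_a \, zp_w
\]

Neither expression involves the batch index n, so batch_size never enters their computation.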
src/relay/qnn/op/convolution.cc
Outdated
// We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
// Since this is integer division (floor), we can first multiply the data by the pool_size and
// then perform avg_pool2d. Reversing this causes inaccuracy due to floor division.
auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
Is this unneeded (like L215) when kernel_h * kernel_w == 1?
Nice observation! Yes, that will save a few µs :) I will make changes for the conv2d lowering as well.
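A minimal sketch of what that shortcut could look like, reusing only the names visible in the snippet above (illustrative, not the merged change):

// Illustrative sketch: skip the pre-scaling when the pool covers a single
// element, since avg_pool2d then divides by 1 and no precision is lost.
// Otherwise, pre-multiplying by kernel_h * kernel_w makes the later floor
// division exact: floor((k * S) / k) == S, whereas k * floor(S / k) can
// lose up to k - 1.
auto scaled_hw_t2 = casted_t2;
if (kernel_h * kernel_w != 1) {
  scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(Int(32), kernel_h * kernel_w));
}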
@jackwish Thanks! Addressed your comments :)
LGTM. Thank you @anijain2305
@zhiics Can you please review?
LGTM
Thanks @anijain2305 @jackwish
Adding the QNN lowering sequence for depthwise conv2d. This creates depthwise conv2d with (u)int8 inputs, opening up a path to write schedules using Intel VNNI and ARM DOT instructions.
For older HW that does not have fast int8 support, this lowering will not be called (already merged upstream in a different PR).
@FrozenGene @jackwish @tmoreau89 @yzhliu @zhiics