-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NNVM][CONVOLUTION] Group convolution generalization for NHWC #1232
Changes from 3 commits
c10865e
7f56f4a
e7eeb51
34fe3d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,9 +98,14 @@ def compute_conv2d(attrs, inputs, _): | |
if groups == 1: | ||
out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout) | ||
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels: | ||
# NCHW | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. then this comment can be removed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handled |
||
out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding) | ||
elif groups == get_const_int(inputs[0].shape[3]) and groups == channels: | ||
# NHWC | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and this line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. handled. |
||
out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding) | ||
else: | ||
raise ValueError("not support arbitrary group number for now") | ||
|
||
if attrs.get_bool("use_bias"): | ||
bias = inputs[2] | ||
expand_axis = 1 if layout == "NCHW" else 0 | ||
|
@@ -112,13 +117,19 @@ def compute_conv2d(attrs, inputs, _): | |
def schedule_conv2d(attrs, outs, target): | ||
"""Schedule definition of conv2d""" | ||
groups = attrs.get_int("groups") | ||
channels = attrs.get_int("channels") | ||
layout = attrs["layout"] | ||
with tvm.target.create(target): | ||
if groups == 1 and layout == "NCHW": | ||
return topi.generic.schedule_conv2d_nchw(outs) | ||
elif groups == 1 and layout == "NHWC": | ||
return topi.generic.schedule_conv2d_nhwc(outs) | ||
return topi.generic.schedule_depthwise_conv2d_nchw(outs) | ||
elif groups == channels and layout == "NCHW": | ||
return topi.generic.schedule_depthwise_conv2d_nchw(outs) | ||
elif groups == channels and layout == "NHWC": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI": There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
return topi.generic.schedule_depthwise_conv2d_nhwc(outs) | ||
else: | ||
raise ValueError("No compatible schedule") | ||
|
||
@reg.register_alter_op_layout("conv2d") | ||
def alter_conv2d_layout(attrs, inputs, tinfos): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add more checks here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done