Skip to content

Commit

Permalink
directly use HCHWc impl in conv2d_strategy_cpu || refine REGEX
Browse files Browse the repository at this point in the history
  • Loading branch information
Menooker committed May 5, 2020
1 parent da6ac91 commit 2d9a616
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
22 changes: 8 additions & 14 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
logger = logging.getLogger('strategy')

_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

@schedule_injective.register("cpu")
def schedule_injective_cpu(attrs, outs, target):
Expand Down Expand Up @@ -88,7 +88,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
raise ValueError("dilation should be positive value")

if groups == 1:
def add_implementation_nchw():
if layout == "NCHW":
assert kernel_layout == "OIHW"
if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw_int8),
Expand All @@ -99,12 +100,9 @@ def add_implementation_nchw():
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
if layout == "NCHW":
assert kernel_layout == "OIHW"
add_implementation_nchw()
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
add_implementation_nchw()
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
Expand All @@ -122,7 +120,9 @@ def add_implementation_nchw():
else:
raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
def add_implementation_depthwise_nchw(channel_multiplier):
if layout == "NCHW":
assert kernel_layout == "OIHW"
channel_multiplier = get_const_tuple(inputs[1].shape)[1]
if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
Expand All @@ -135,15 +135,9 @@ def add_implementation_depthwise_nchw(channel_multiplier):
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic")
if layout == "NCHW":
assert kernel_layout == "OIHW"
channel_multiplier = get_const_tuple(inputs[1].shape)[1]
add_implementation_depthwise_nchw(channel_multiplier)
elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
kernel_shape = get_const_tuple(inputs[1].shape)
channel_multiplier = kernel_shape[1] * kernel_shape[4]
add_implementation_depthwise_nchw(channel_multiplier)
return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
logger = logging.getLogger('topi')

_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
Expand Down

0 comments on commit 2d9a616

Please sign in to comment.