diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 2b8d3656831dd..fbc2ed24548b2 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -27,7 +27,7 @@ logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): @@ -88,7 +88,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): raise ValueError("dilation should be positive value") if groups == 1: - def add_implementation_nchw(): + if layout == "NCHW": + assert kernel_layout == "OIHW" if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype): strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), @@ -99,12 +100,9 @@ def add_implementation_nchw(): wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") - if layout == "NCHW": - assert kernel_layout == "OIHW" - add_implementation_nchw() elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio - add_implementation_nchw() + return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") @@ -122,7 +120,9 @@ def add_implementation_nchw(): else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - def add_implementation_depthwise_nchw(channel_multiplier): + if layout == "NCHW": + assert kernel_layout == "OIHW" + channel_multiplier = get_const_tuple(inputs[1].shape)[1] if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1: strategy.add_implementation( wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), @@ -135,15 +135,9 @@ def add_implementation_depthwise_nchw(channel_multiplier): wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.generic") - if layout == "NCHW": - assert kernel_layout == "OIHW" - channel_multiplier = get_const_tuple(inputs[1].shape)[1] - add_implementation_depthwise_nchw(channel_multiplier) elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio - kernel_shape = get_const_tuple(inputs[1].shape) - channel_multiplier = kernel_shape[1] * kernel_shape[4] - add_implementation_depthwise_nchw(channel_multiplier) + return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWOI" logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 63d8d9b13db87..d1c607f6a3e5d 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -33,7 +33,7 @@ logger = logging.getLogger('topi') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") -_OIHWio_matcher = re.compile("^OIHW[0-9]+i[-+]?[0-9]+o$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):