diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 059e9eaafa568..2b8d3656831dd 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -26,8 +26,8 @@
 
 logger = logging.getLogger('strategy')
 
-_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$")
-_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$")
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
 
 @schedule_injective.register("cpu")
 def schedule_injective_cpu(attrs, outs, target):
@@ -88,13 +88,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
         raise ValueError("dilation should be positive value")
 
     if groups == 1:
-        if layout.startswith("NCHW"):
-            if layout != "NCHW":
-                # check if layout is NCHWxc
-                assert _NCHWc_matcher.match(layout)
-                assert _OIHWio_matcher.match(kernel_layout)
-            else:
-                assert kernel_layout == "OIHW"
+        def add_implementation_nchw():
             if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.conv2d_nchw_int8),
@@ -105,6 +99,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
                     name="conv2d_nchw.x86")
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            add_implementation_nchw()
+        elif _NCHWc_matcher.match(layout):  # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout)  # check if kernel is OIHWio
+            add_implementation_nchw()
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
@@ -122,14 +122,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
         else:
             raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout))
     elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout.startswith("NCHW"):
-            if layout != "NCHW":
-                # check if layout is NCHWxc
-                assert _NCHWc_matcher.match(layout)
-                assert _OIHWio_matcher.match(kernel_layout)
-            else:
-                assert kernel_layout == "OIHW"
-            channel_multiplier = get_const_tuple(inputs[1].shape)[1]
+        def add_implementation_depthwise_nchw(channel_multiplier):
             if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1:
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
@@ -142,6 +135,15 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                     wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
                     name="depthwise_conv2d_nchw.generic")
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            channel_multiplier = get_const_tuple(inputs[1].shape)[1]
+            add_implementation_depthwise_nchw(channel_multiplier)
+        elif _NCHWc_matcher.match(layout):  # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout)  # check if kernel is OIHWio
+            kernel_shape = get_const_tuple(inputs[1].shape)
+            channel_multiplier = kernel_shape[1] * kernel_shape[4]
+            add_implementation_depthwise_nchw(channel_multiplier)
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index c107c5d2543a9..8a6dfca6b0cce 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -39,7 +39,7 @@ namespace tvm {
 namespace relay {
 
 extern Expr MakeReshape(Expr data,
-    Array<Integer> newshape);
+                        Array<Integer> newshape);
 
 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 95fa3599ff92b..f6ea23c84f0a8 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -318,7 +318,7 @@ static bool IsIntInArray(const Array<Integer>& axis, int v) {
 }
 
 static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
-    const Array<Integer>& axis) {
+                               const Array<Integer>& axis) {
   Array<Integer> arr;
   for (size_t i = 0; i < shape.size(); i++) {
     if (IsIntInArray(axis, i)) {
@@ -337,7 +337,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
 
 // if only one axis, use expand dim. Else, use reshape
 static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
-    const Array<Integer>& axis) {
+                                       const Array<Integer>& axis) {
   if (axis.size() > 1) {
     return ReshapeToMatchAxis(scale, shape, axis);
   } else {
@@ -407,8 +407,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
     CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
     Expr scale = ReshapeOrExpandToMatchAxis(
         slhs->scale, tlhs->shape, slhs->axes);
-    if (!scale.defined())
+    if (!scale.defined()) {
       return Expr();
+    }
     Expr rhs = Divide(new_args[1], scale);
     rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
     rnode->scale = slhs->scale;
@@ -418,8 +419,9 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
     CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
     Expr scale = ReshapeOrExpandToMatchAxis(
         srhs->scale, trhs->shape, srhs->axes);
-    if (!scale.defined())
+    if (!scale.defined()) {
       return Expr();
+    }
     Expr lhs = Divide(new_args[0], scale);
     rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
     rnode->scale = srhs->scale;
diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py
index b263c2383ca1b..63d8d9b13db87 100644
--- a/topi/python/topi/x86/conv2d_alter_op.py
+++ b/topi/python/topi/x86/conv2d_alter_op.py
@@ -32,8 +32,8 @@
 
 logger = logging.getLogger('topi')
 
-_NCHWc_matcher = re.compile("^NCHW[-+]?[0-9]+c$")
-_OIHWio_matcher = re.compile("^OIHW[-+]?[0-9]+i[-+]?[0-9]+o$")
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
 
 @conv2d_alter_layout.register("cpu")
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
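
For context, a small standalone sketch (not part of the patch above) of how the tightened layout matchers behave once the optional sign is dropped. The layout strings used here, such as NCHW16c and OIHW16i16o, are illustrative packed layouts, not values taken from the diff.

import re

# Post-change matchers from the diff: packed channel counts are unsigned,
# so the optional sign ([-+]?) is no longer accepted.
_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

# Packed data/kernel layouts still match.
assert _NCHWc_matcher.match("NCHW16c")
assert _OIHWio_matcher.match("OIHW16i16o")

# Plain layouts and signed channel counts do not.
assert not _NCHWc_matcher.match("NCHW")
assert not _NCHWc_matcher.match("NCHW-16c")
assert not _OIHWio_matcher.match("OIHW")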