diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index e639e2296e2a3..fcc0f99ab6330 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -161,11 +161,10 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 name="depthwise_conv2d_nchw.x86")
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.generic")
+                wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.arm_cpu")
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".
                                format(layout))
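
With this strategy change, NHWC depthwise convolutions dispatch to the arm_cpu TOPI implementation added below instead of the generic fallback. One way to sanity-check the dispatch is to build a small Relay module for a target that carries the "arm_cpu" key; this is a sketch only, and the shapes, the target string, and the use of relay.build are illustrative assumptions rather than part of the patch:

    import tvm
    from tvm import relay

    # A single NHWC/HWOI depthwise conv2d (groups == channels), which
    # conv2d_strategy_arm_cpu routes to "depthwise_conv2d_nhwc.arm_cpu".
    data = relay.var("data", shape=(1, 56, 56, 32), dtype="float32")
    weight = relay.var("weight", shape=(3, 3, 32, 1), dtype="float32")
    conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                           groups=32, channels=32,
                           data_layout="NHWC", kernel_layout="HWOI")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    # "-device=arm_cpu" is what attaches the arm_cpu key to an llvm target;
    # the return value of relay.build varies across TVM versions and is not
    # inspected here.
    relay.build(mod, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu")
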
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 733d6e9448fd9..6ffcc675560d3 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -295,7 +295,6 @@ bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 }
 
 Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) {
-  std::cout << "making " << op_name << std::endl;
   auto attrs = make_object<ReduceAttrs>();
   attrs->axis = std::move(axis);
   attrs->keepdims = keepdims;
diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py
index 802b3df195303..f36525d6669e0 100644
--- a/topi/python/topi/arm_cpu/depthwise_conv2d.py
+++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -31,7 +31,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype
     """Compute depthwise_conv2d with NCHW layout"""
     return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
 
-
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu")
 def schedule_depthwise_conv2d_nchw(cfg, outs):
     """Schedule depthwise conv2d
@@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila
     return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
 
 
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu")
+def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d nhwc
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.te.Tensor
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    kernel : tvm.te.Tensor
+        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape)
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    output_shape = (N, OH, OW, IC*channel_multiplier)
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    reduce_h = te.reduce_axis((0, KH), name='reduce_h')
+    reduce_w = te.reduce_axis((0, KW), name='reduce_w')
+
+    out = te.compute(output_shape, lambda n, h, w, c:
+                     te.sum(data_pad[n,
+                                     HSTR*h+dilation_h*reduce_h,
+                                     w*WSTR+reduce_w*dilation_w,
+                                     idxdiv(c, channel_multiplier)].astype(out_dtype) *
+                            kernel[reduce_h,
+                                   reduce_w,
+                                   idxdiv(c, channel_multiplier),
+                                   idxmod(c, channel_multiplier)].astype(out_dtype),
+                            axis=[reduce_h, reduce_w]),
+                     name='depthwise_conv2d_nhwc_output')
+
+    return out
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
+def schedule_depthwise_conv2d_nhwc(_, outs):
+    """Create the schedule for depthwise_conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+
+    def schedule_conv(conv):
+        n, h, w, c = conv.op.axis
+        r_h, r_w = conv.op.reduce_axis
+        co, ci = s[conv].split(c, 8)
+        ho, hi = s[conv].split(h, 2)
+        wo, wi = s[conv].split(w, 2)
+
+        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
+        s[conv].parallel(ho)
+        s[conv].vectorize(ci)
+
+    def schedule_conv_out(out):
+        n, h, w, c = out.op.axis
+        co, ci = s[out].split(c, 8)
+        wo, wi = s[out].split(w, 2)
+        ho, hi = s[out].split(h, 2)
+        ci_outer, ci_inner = s[out].split(ci, 4)
+        s[out].reorder(n, wo, ho, co, wi, hi)
+        s[out].vectorize(ci_inner)
+        compute_at_axis = hi
+        s[out].parallel(wo)
+        return compute_at_axis
+
+    def _callback(op):
+        if op.name == 'depthwise_conv2d_nhwc_output':
+            conv = op.output(0)
+            if conv != out:
+                compute_at_axis = schedule_conv_out(out)
+                schedule_conv(conv)
+                s[conv].compute_at(s[out], compute_at_axis)
+            else:
+                schedule_conv(out)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
 @autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu")
 def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs):
     """Create the schedule for depthwise_conv2d_nchw_spatial_pack"""
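
The new compute/schedule pair can also be exercised directly at the TOPI level, which is how the updated test below reaches it through topi.testing.dispatch. A minimal sketch, assuming an LLVM-enabled build; the shapes and target string are illustrative:

    import tvm
    from tvm import te
    import topi

    # NHWC input and HWOI kernel with channel_multiplier == 1.
    data = te.placeholder((1, 64, 64, 32), name="data", dtype="float32")
    kernel = te.placeholder((3, 3, 32, 1), name="kernel", dtype="float32")

    with tvm.target.create("llvm -device=arm_cpu"):
        # The autotvm decorators inject cfg themselves, so callers pass only
        # the workload arguments, exactly as topi.testing.dispatch does.
        out = topi.arm_cpu.compute_depthwise_conv2d_nhwc(
            data, kernel, (1, 1), "SAME", 1, "float32")
        s = topi.arm_cpu.schedule_depthwise_conv2d_nhwc([out])

    func = tvm.build(s, [data, kernel, out], "llvm -device=arm_cpu")
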
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index 693348918d3ef..62611e4ebf2fb 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -40,6 +40,7 @@
 
 _depthwise_conv2d_nhwc_implement = {
     "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
+    "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
     "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc),
 }
 
@@ -177,6 +178,9 @@ def check_device(device):
         print("Running on target: %s" % device)
 
         fcompute, fschedule = topi.testing.dispatch(device, _depthwise_conv2d_nhwc_implement)
+        if dilation > 1 and "gpu" in tvm.target.create(device).keys:
+            # skip because it uses too much shared memory on cuda
+            return
         with tvm.target.create(device):
             # declare
             DepthwiseConv2d = fcompute(Input, Filter,
@@ -385,8 +389,7 @@ def test_depthwise_conv2d():
     depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
     depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
     # dilation = 2
-    # disabled because it uses too large shared memory on cuda
-    # depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
+    depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
 
     # NCHW[x]c
     depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
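
For reference, the dilated NHWC workload re-enabled above can be invoked on its own, with device availability still filtered inside check_device; this assumes the topi test directory is on the Python path:

    # The arguments mirror the re-enabled line in test_depthwise_conv2d:
    # batch=1, in_channel=728, in_size=64, channel_multiplier=1, kernel=3,
    # stride=1, padding="SAME", dilation=2.
    from test_topi_depthwise_conv2d import depthwise_conv2d_with_workload_nhwc

    depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)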