diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 81263912bc9b..e0cee9526d14 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -112,6 +112,23 @@ def conv2d(data, kernel, stride, padding, layout='NCHW', out_dtype=None):
         raise ValueError("not support this layout {} yet".format(layout))


+@tvm.target.generic_func
+def conv2d_alter_layout(attrs, inputs, tinfos):
+    """Change Conv2D layout.
+
+    Parameters
+    ----------
+    attrs : nnvm.top.AttrDict
+        Attributes of current convolution
+    inputs : nnvm.symbol
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    """
+    # not to change by default
+    return None
+
+
 def _get_workload(data, kernel, stride, padding, out_dtype):
     """ Get the workload structure. """
     _, CI, IH, IW = [x.value for x in data.shape]
@@ -425,6 +442,44 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
             name="Conv2dOutput", tag="conv2d_nhwc")
     return Output

+
+@tvm.target.generic_func
+def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
+    """Conv2D operator for nChw[x]c layout.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
+
+    kernel : tvm.Tensor
+        6-D with shape
+        [num_filter_chunk, in_channel_chunk, filter_height, filter_width,
+        in_channel_block, num_filter_block]
+
+    num_filter : int
+        number of filters, i.e., output channel size
+
+    kernel_size : tuple of two ints
+        [kernel_height, kernel_width]
+
+    stride : int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding : int or a list/tuple of two ints
+        padding size, or [pad_height, pad_width]
+
+    out_dtype : str
+        output data type
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
+    """
+    # search platform specific declaration first
+    # default declaration
+    raise ValueError("missing register for topi.nn.conv2d_NCHWc")
+
+
 # map from schedule type to declaration function
 _SCH_TO_DECL_FUNC = {
     SpatialPack: _spatial_pack,
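Note on the new conv2d_NCHWc contract: it consumes data and kernel that have already been packed out of plain NCHW/OIHW. As a rough illustration of the 5-D data layout described in the docstring above (a minimal numpy sketch; the helper name, shapes and block size are illustrative and not part of this patch):

    import numpy as np

    def pack_nchw_to_nchwc(x, c_bn):
        # NCHW -> NCHW[x]c: [n, c, h, w] -> [n, c//c_bn, h, w, c_bn]
        n, c, h, w = x.shape
        assert c % c_bn == 0
        return x.reshape(n, c // c_bn, c_bn, h, w).transpose(0, 1, 3, 4, 2)

    data_nchw = np.random.rand(1, 64, 56, 56).astype('float32')
    data_nchw16c = pack_nchw_to_nchwc(data_nchw, 16)  # shape (1, 4, 56, 56, 16)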
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 7e9632ed9622..59495d75f42d 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -4,27 +4,50 @@
 from .. import generic, tag
 from .. import nn
 from ..nn.util import infer_pad, infer_stride
-from ..nn.conv2d import conv2d, _get_workload, _get_schedule, _WORKLOADS
+from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
+    _get_workload, _get_schedule, Workload
 from . import conv2d_avx_1x1, conv2d_avx_common
 from .conv2d_avx_common import AVXConvCommonFwd
 from .conv2d_avx_1x1 import AVXConv1x1Fwd

-_AVX_SCH_TO_DECL_FUNC = {
-    AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
-    AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
-}
-
-_AVX_SCH_TO_SCH_FUNC = {
-    AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
-    AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
-}
-
 @_get_schedule.register("cpu")
 def _get_schedule_conv(wkl):
-    if wkl not in _WORKLOADS:
-        raise ValueError("no schedule for such workload: {}".format(wkl))
-    idx = _WORKLOADS.index(wkl)
+    _WORKLOADS_AVX = [
+        # workloads of resnet18_v1 on imagenet
+        Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
+        Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+        Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+        Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+        Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+        Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+        Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+        Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        Workload('float32', 'float32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1),
+        Workload('float32', 'float32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2),
+        Workload('float32', 'float32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1),
+        # workloads of resnet101_v1 on imagenet, no extra workload required
+        # workloads of resnet152_v1 on imagenet, no extra workload required
+        # workloads of resnet18_v2 on imagenet, no extra workload required
+        # workloads of resnet34_v2 on imagenet, no extra workload required
+    ]

     fp32_vec_len = 8
     target = tvm.target.current_target(allow_none=False)
@@ -32,43 +55,61 @@ def _get_schedule_conv(wkl):
         if opt == '-mcpu=skylake-avx512':
             fp32_vec_len = 16

-    _SCHEDULES_AVX_NCHW = [
-        # float32 resnet-18
+    _SCHEDULES_AVX = [
+        # workloads of resnet18_v1 on imagenet
         AVXConvCommonFwd(3, fp32_vec_len, 28, False),
-        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConvCommonFwd(16, fp32_vec_len, 28, False),
-        AVXConvCommonFwd(16, fp32_vec_len, 14, False),
-        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
-        AVXConvCommonFwd(16, fp32_vec_len, 14, True),
-        AVXConvCommonFwd(16, 32, 7, True),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
-        AVXConvCommonFwd(16, fp32_vec_len, 7, True),
-        # float32 mobilenet
-        AVXConvCommonFwd(3, fp32_vec_len, 28, False),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 28),
-        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
-        AVXConv1x1Fwd(16, fp32_vec_len, 2, 14),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
-        AVXConv1x1Fwd(16, fp32_vec_len, 1, 7),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7),
+        AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True),
+        # workloads of resnet34_v1 on imagenet, no extra workload required
+        # workloads of resnet50_v1 on imagenet
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7),
+        # workloads of resnet101_v1 on imagenet, no extra workload required
+        # workloads of resnet152_v1 on imagenet, no extra workload required
+        # workloads of resnet18_v2 on imagenet, no extra workload required
+        # workloads of resnet34_v2 on imagenet, no extra workload required
     ]
-    sch = _SCHEDULES_AVX_NCHW[idx]
+    if wkl not in _WORKLOADS_AVX:
+        if wkl.hkernel == 1 and wkl.wkernel == 1:
+            return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len)
+        return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len)
+    idx = _WORKLOADS_AVX.index(wkl)
+    sch = _SCHEDULES_AVX[idx]
     return sch


 @conv2d.register("cpu")
 def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
+    _AVX_SCH_TO_DECL_FUNC = {
+        AVXConvCommonFwd: conv2d_avx_common._declaration_conv,
+        AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv
+    }
     out_dtype = data.dtype if out_dtype is None else out_dtype
     target = tvm.target.current_target(allow_none=False)
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
-    if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW':
+    if 'avx' in str(target) and layout == 'NCHW':
         sch = _get_schedule(wkl)
         return _AVX_SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, layout, out_dtype)
     elif layout == 'NCHW':
@@ -81,9 +122,63 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
     raise ValueError("not support this layout {} yet".format(layout))


+@conv2d_alter_layout.register("cpu")
+def _alter_conv2d_layout(attrs, inputs, tinfos):
+    import nnvm.symbol as sym
+    copy_inputs = [s for s in inputs]
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    # only optimize for NCHW, groups=1 conv
+    if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1:
+        return None
+
+    data = tinfos[0]
+    kernel = tinfos[1]
+
+    import ast
+    padding = ast.literal_eval(attrs['padding'])
+    stride = ast.literal_eval(attrs['strides'])
+
+    wkl = _get_workload(data, kernel, stride, padding, data.dtype)
+    sch = _get_schedule_conv(wkl)
+    is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
+    ic_bn, oc_bn = sch.ic_bn, sch.oc_bn
+
+    new_attrs['layout'] = 'NCHW%dc' % ic_bn
+    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+    if is_kernel_1x1:
+        # (oc, ic, h, w) -> (OC, IC, ic, oc, h, w)
+        new_attrs['kernel_layout'] = 'OI%di%doHW' % (ic_bn, oc_bn)
+    else:
+        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+
+    return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
+
+
+@conv2d_NCHWc.register("cpu")
+def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype):
+    _AVX_SCH_TO_DECL_FUNC = {
+        AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
+        AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
+    }
+    n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
+    ic = ic_chunk * ic_block
+    kh, kw = kernel_size
+    wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
+                        tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
+                        stride, padding, out_dtype)
+    sch = _get_schedule(wkl)
+    return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
+
+
 @generic.schedule_conv2d_nchw.register(["cpu"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
+    _AVX_SCH_TO_SCH_FUNC = {
+        AVXConvCommonFwd: conv2d_avx_common._schedule_conv,
+        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv
+    }
     s = tvm.create_schedule([x.op for x in outs])
     target = tvm.target.current_target(allow_none=False)
@@ -213,3 +308,49 @@ def traverse(op):

     traverse(output_op)
     return s
+
+
+@generic.schedule_conv2d_NCHWc.register(["cpu"])
+def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
+    """Create schedule for tensors"""
+    _AVX_SCH_TO_SCH_FUNC = {
+        AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
+        AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc
+    }
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(op):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+
+        if 'conv2d_NCHWc' in op.tag:
+            conv_out = op.output(0)
+            kernel = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            data = data_vec.op.input_tensors[0] \
+                if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \
+                else data_vec
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+
+            n, ic_chunk, h, w, ic_block = [x.value for x in data.shape]
+            ic = ic_chunk * ic_block
+            original_data = tvm.placeholder((n, ic, h, w), dtype=conv_out.dtype)
+
+            kh, kw = kernel_size
+            original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
+
+            wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
+            sch = _get_schedule(wkl)
+            _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
+                                            kernel, conv_out, outs[0])
+
+    traverse(outs[0].op)
+    return s
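Note on _alter_conv2d_layout above: it only rewrites the symbol attributes (layout, out_layout, kernel_layout); the actual weight repacking is performed by whatever consumes those layout strings. A minimal numpy sketch of what 'OIHW%di%do' implies for the non-1x1 case (helper name and shapes are illustrative, not part of this patch):

    import numpy as np

    def pack_oihw(kernel, ic_bn, oc_bn):
        # OIHW -> OIHW[x]i[x]o:
        # [oc, ic, kh, kw] -> [oc//oc_bn, ic//ic_bn, kh, kw, ic_bn, oc_bn]
        oc, ic, kh, kw = kernel.shape
        k = kernel.reshape(oc // oc_bn, oc_bn, ic // ic_bn, ic_bn, kh, kw)
        return k.transpose(0, 2, 4, 5, 3, 1)

    weight = np.random.rand(64, 64, 3, 3).astype('float32')
    packed = pack_oihw(weight, 16, 16)  # shape (4, 4, 3, 3, 16, 16)

This matches the 6-D indexing kernel[oc_chunk, ic_chunk, kh, kw, ic_block, oc_block] used by the NCHWc declaration in conv2d_avx_common.py below.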
diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py
index afd0be2e2ded..ecbae0cc3128 100644
--- a/topi/python/topi/x86/conv2d_avx_1x1.py
+++ b/topi/python/topi/x86/conv2d_avx_1x1.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name,unused-variable,invalid-name
+# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
 """1x1 Conv2D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
 from collections import namedtuple
@@ -11,6 +11,34 @@

 AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor'])

+
+def _get_default_schedule(wkl, simd_width):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    oc_bn = 1
+    for bn in range(simd_width, 0, -1):
+        if wkl.out_filter % bn == 0:
+            oc_bn = bn
+            break
+
+    ic_bn = 1
+    for bn in range(oc_bn, 0, -1):
+        if wkl.in_filter % bn == 0:
+            ic_bn = bn
+            break
+
+    for ow_factor in range(out_width, 0, -1):
+        if out_width % ow_factor == 0:
+            for oh_factor in range(out_height, 0, -1):
+                if out_height % oh_factor == 0 and ow_factor * oh_factor < 32:
+                    return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor)
+
+    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
+
+
 def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
     assert layout == 'NCHW', "only support NCHW convolution for AVX"
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
@@ -124,3 +152,80 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
     s[O].parallel(parallel_axis)

     return s
+
+
+def _declaration_conv_NCHWc(wkl, sch, data, kernel):
+    out_dtype = wkl.out_dtype
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    batch_size = data.shape[0]
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    DOPAD = (HPAD != 0 and WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
+    ic = tvm.reduce_axis((0, wkl.in_filter), name='ic')
+    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic//sch.ic_bn, oh*HSTR, ow*WSTR, ic%sch.ic_bn]
+                               .astype(out_dtype) *
+                               kernel[oc_chunk, ic // sch.ic_bn, ic % sch.ic_bn, oc_block, 0, 0],
+                               axis=[ic]), name='conv2d_NCHWc', tag='conv2d_NCHWc')
+
+    return conv
+
+
+def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
+    # schedule data
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
+        parallel_axis = s[A].fuse(ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    batch, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor)
+    ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor)
+    s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+    s[C].vectorize(oc_block)
+
+    parallel_axis = s[C].fuse(oc_chunk, oh_outer)
+    s[CC].compute_at(s[C], parallel_axis)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    ic, = s[CC].op.reduce_axis
+
+    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
+
+    oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor)
+    ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor)
+
+    s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_chunk, ic_block, oh_inner, ow_inner, oc_block)
+    s[CC].fuse(oc_chunk, oh_outer)
+    s[CC].vectorize(oc_block)
+
+    s[CC].unroll(ow_inner)
+    s[CC].unroll(oh_inner)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor)
+        ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor)
+        s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+        parallel_axis = s[O].fuse(oc_chunk, oh_outer)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py
index 03c1021be409..063759998108 100644
--- a/topi/python/topi/x86/conv2d_avx_common.py
+++ b/topi/python/topi/x86/conv2d_avx_common.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name,unused-variable,invalid-name
+# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
 """Conv2D schedule on for Intel CPU"""
 from __future__ import absolute_import as _abs
 from collections import namedtuple
@@ -11,6 +11,34 @@

 AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])

+
+def _get_default_schedule(wkl, simd_width):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    oc_bn = 1
+    for bn in range(simd_width, 0, -1):
+        if wkl.out_filter % bn == 0:
+            oc_bn = bn
+            break
+
+    ic_bn = 1
+    for bn in range(oc_bn, 0, -1):
+        if wkl.in_filter % bn == 0:
+            ic_bn = bn
+            break
+
+    reg_n = 1
+    for n in range(31, 0, -1):
+        if out_width % n == 0:
+            reg_n = n
+            break
+
+    return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False)
+
+
 def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     assert layout == 'NCHW', "only support NCHW convolution for AVX"
@@ -141,3 +169,83 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
     s[O].parallel(parallel_axis)

     return s
+
+
+def _declaration_conv_NCHWc(wkl, sch, data, kernel):
+    out_dtype = wkl.out_dtype
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+
+    batch_size = data.shape[0]
+    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+    # pack data
+    DOPAD = (HPAD != 0 and WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    # convolution
+    oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn)
+
+    ic = tvm.reduce_axis((0, wkl.in_filter), name='ic')
+    kh = tvm.reduce_axis((0, wkl.hkernel), name='kh')
+    kw = tvm.reduce_axis((0, wkl.wkernel), name='kw')
+
+    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic//sch.ic_bn, oh*HSTR+kh, ow*WSTR+kw, ic%sch.ic_bn]
+                               .astype(out_dtype) *
+                               kernel[oc_chunk, ic//sch.ic_bn, kh, kw, ic%sch.ic_bn, oc_block],
+                               axis=[ic, kh, kw]), name='conv2d_NCHWc', tag="conv2d_NCHWc")
+
+    return conv
+
+
+def _schedule_conv_NCHWc(s, wkl, sch, data, kernel, conv_out, last):
+    # schedule data
+    A = data
+    if isinstance(s[A].op, tvm.tensor.ComputeOp):
+        batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
+        parallel_axis = s[A].fuse(ic_chunk, ih)
+        s[A].parallel(parallel_axis)
+
+    # schedule 5-D NCHW[x]c conv
+    C, O = conv_out, last
+    CC = s.cache_write(C, 'global')
+
+    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
+    ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n)
+    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+    parallel_axis = s[C].fuse(oc_chunk, oh)
+    s[C].vectorize(oc_block)
+    if C == O:
+        s[C].parallel(parallel_axis)
+
+    s[CC].compute_at(s[C], ow_chunk)
+    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
+    ic, kh, kw = s[CC].op.reduce_axis
+
+    ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)
+    ic_chunk, ic_block = s[CC].split(ic, factor=sch.ic_bn)
+
+    if sch.unroll_kw:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
+        s[CC].unroll(kw)
+    else:
+        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)
+
+    s[CC].vectorize(oc_block)
+    s[CC].unroll(ow_block)
+
+    if C != O:
+        batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+        ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n)
+        s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        parallel_axis = s[O].fuse(oc_chunk, oh)
+        s[C].compute_at(s[O], parallel_axis)
+        s[O].vectorize(oc_block)
+        s[O].parallel(parallel_axis)
+
+    return s
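Note: a minimal sketch of how the new NCHWc entry points could be exercised directly from TOPI. The shapes, the target string and the 16-element blocking are illustrative; the blocking of the placeholders has to agree with the schedule TOPI picks for the workload (fp32_vec_len is 16 only under -mcpu=skylake-avx512), and the generic topi.generic.schedule_conv2d_NCHWc is assumed to be provided by the rest of this change:

    import tvm
    import topi

    ic_bn = oc_bn = 16
    # data pre-packed to NCHW16c, kernel to OIHW16i16o
    data = tvm.placeholder((1, 64 // ic_bn, 56, 56, ic_bn), name='data')
    kernel = tvm.placeholder((64 // oc_bn, 64 // ic_bn, 3, 3, ic_bn, oc_bn), name='kernel')

    target = 'llvm -mcpu=skylake-avx512'
    with tvm.target.create(target):
        conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=64, kernel_size=(3, 3),
                                    stride=1, padding=1)
        s = topi.generic.schedule_conv2d_NCHWc(64, (3, 3), 1, 1, [conv])
        func = tvm.build(s, [data, kernel, conv], target=target)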