From 52769c5ca81d7036671222e87bb581c6a57ca8c5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 11 Nov 2018 00:52:22 +0800 Subject: [PATCH] [TOPI][CUDA] int8 group conv2d (#2075) --- nnvm/python/nnvm/top/nn.py | 5 + python/tvm/autotvm/task/nnvm_integration.py | 14 +- topi/python/topi/cuda/__init__.py | 3 +- topi/python/topi/cuda/group_conv2d_nchw.py | 308 ++++++++++++++++++ topi/python/topi/generic/nn.py | 19 ++ topi/python/topi/nn/conv2d.py | 77 +++++ .../python/topi/testing/conv2d_nchw_python.py | 37 ++- topi/tests/python/common.py | 15 + topi/tests/python/test_topi_conv2d_int8.py | 13 +- topi/tests/python/test_topi_group_conv2d.py | 215 ++++++++++++ 10 files changed, 690 insertions(+), 16 deletions(-) create mode 100644 topi/python/topi/cuda/group_conv2d_nchw.py create mode 100644 topi/tests/python/test_topi_group_conv2d.py diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 688dcccab110..44552ff7dcd0 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -109,6 +109,9 @@ def compute_conv2d(attrs, inputs, _): groups == channels: out = topi.nn.depthwise_conv2d_nchw( inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + elif layout == "NCHW": + out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, + out_dtype=out_dtype) elif layout == "NHWC" and \ kernel_layout == "HWOI" and \ groups == get_const_int(inputs[0].shape[3]) and \ @@ -144,6 +147,8 @@ def schedule_conv2d(attrs, outs, target): return topi.generic.schedule_depthwise_conv2d_nchw(outs) elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI": return topi.generic.schedule_depthwise_conv2d_nhwc(outs) + elif layout == "NCHW": + return topi.generic.schedule_group_conv2d_nchw(outs) else: raise ValueError("No compatible schedule") diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py index 80b62229a34e..6a07194a594d 100644 --- a/python/tvm/autotvm/task/nnvm_integration.py +++ b/python/tvm/autotvm/task/nnvm_integration.py @@ -58,7 +58,8 @@ def __init__(self): # NOTE: To add more symbols, you only need to change the following lists # nnvm symbol -> topi compute self.symbol2topi = { - nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw], + nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, + topi.nn.group_conv2d_nchw], nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], nnvm.sym.dense: [topi.nn.dense], } @@ -67,6 +68,7 @@ def __init__(self): self.topi_to_task = { topi.nn.conv2d: "topi_nn_conv2d", topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw", + topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", topi.nn.dense: "topi_nn_dense", } @@ -76,6 +78,7 @@ def __init__(self): topi.generic.schedule_conv2d_nhwc], topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw, topi.generic.schedule_depthwise_conv2d_nhwc], + topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], topi.nn.dense: [topi.generic.schedule_dense], } @@ -143,6 +146,15 @@ def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs): s = topi.generic.schedule_depthwise_conv2d_nchw([C]) return s, [A, W, C] + @register("topi_nn_group_conv2d_nchw") + def _topi_nn_group_conv2d_nchw(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = 
args[:2]
+        C = topi.nn.group_conv2d_nchw(*args, **kwargs)
+        s = topi.generic.schedule_group_conv2d_nchw([C])
+        return s, [A, W, C]
+
     @register("topi_nn_conv2d_transpose_nchw")
     def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
         assert not kwargs, "Do not support kwargs in template function call"
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index e1db2c6fdf63..28d2eb258bea 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -2,10 +2,11 @@
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw
+from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, group_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
 from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
+from .group_conv2d_nchw import schedule_group_conv2d_nchw_cuda
 from .reduction import schedule_reduce
 from .softmax import schedule_softmax
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
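Note on the int8 layout used by the new CUDA template below: the compute packs
NCHW data into NCHW4c (and the kernel into a matching 6-D layout) so that four
input channels sit contiguously for the dp4a instruction. As a standalone
illustration (the helper name `pack_nchw4c` is ours, not part of the patch),
the `packed_data` stage corresponds to:

    import numpy as np

    def pack_nchw4c(x, block=4):
        # [n, c, h, w] -> [n, c//block, h, w, block]; element (n, c, h, w)
        # lands at (n, c // block, h, w, c % block).
        n, c, h, w = x.shape
        assert c % block == 0
        return x.reshape(n, c // block, block, h, w).transpose(0, 1, 3, 4, 2)

    x = np.arange(2 * 8 * 3 * 3, dtype="int32").reshape(2, 8, 3, 3)
    px = pack_nchw4c(x)
    assert px.shape == (2, 2, 3, 3, 4)
    assert px[1, 1, 2, 0, 3] == x[1, 1 * 4 + 3, 2, 0]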
diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py
new file mode 100644
index 000000000000..739691131284
--- /dev/null
+++ b/topi/python/topi/cuda/group_conv2d_nchw.py
@@ -0,0 +1,308 @@
+# pylint: disable=invalid-name
+"""The template for cuda group_conv2d_nchw"""
+import tvm
+from tvm import autotvm
+
+from .injective import _schedule_injective
+from .tensor_intrin import dp4a
+from ..nn.pad import pad
+from ..nn.util import get_pad_tuple
+from ..util import traverse_inline, get_const_tuple, get_const_int
+from .. import nn, generic
+
+
+@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8'])
+def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
+                           out_dtype='float32'):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width] or
+        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] or
+        6-D with shape [num_filter_chunk, in_channel_chunk // groups, filter_height,
+        filter_width, num_filter_block, in_channel_block]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups : int
+        number of groups
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
+    """
+    ic_block_factor = 4
+    oc_block_factor = 4
+
+    pre_computed = len(kernel.shape) == 6
+    if not pre_computed:
+        batch, channels, height, width = get_const_tuple(data.shape)
+        out_channels, in_channels, kernel_h, kernel_w = get_const_tuple(
+            kernel.shape)
+
+        assert channels % groups == 0, \
+            "input channels must be divisible by the number of groups"
+        assert out_channels % groups == 0, \
+            "output channels must be divisible by the number of groups"
+        assert (channels // groups) % ic_block_factor == 0, \
+            "number of input channels per group must be divisible by {}".format(ic_block_factor)
+        assert (out_channels // groups) % oc_block_factor == 0, \
+            "number of output channels per group must be divisible by {}".format(oc_block_factor)
+
+        packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
+                                   ic_block_factor),
+                                  lambda n, c, h, w, vc: data[n, c*ic_block_factor + vc, h, w],
+                                  name="packed_data")
+        packed_kernel = tvm.compute(
+            (out_channels // oc_block_factor, in_channels // ic_block_factor, kernel_h, kernel_w,
+             oc_block_factor, ic_block_factor),
+            lambda oc_chunk, ic_chunk, kh, kw, oc_block, ic_block:
+            kernel[oc_chunk * oc_block_factor + oc_block,
+                   ic_chunk * ic_block_factor + ic_block, kh, kw],
+            name="packed_kernel")
+    else:
+        packed_data = data
+        packed_kernel = kernel
+
+    batch, ic_chunk, in_height, in_width, _ = get_const_tuple(
+        packed_data.shape)
+    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
+        packed_kernel.shape)
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left, 0]
+    pad_after = [0, 0, pad_down, pad_right, 0]
+    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
+
+    # compute the output shape
+    out_height = (in_height - (kernel_h - 1) * dilation_h -
+                  1 + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - (kernel_w - 1) * dilation_w -
+                 1 + pad_left + pad_right) // stride_w + 1
+
+    oshape = (batch, oc_chunk, out_height, out_width, oc_block)
+
+    icc = tvm.reduce_axis((0, ic_chunk // groups), name='ic_chunk')
+    icb = tvm.reduce_axis((0, ic_block_factor), name='ic_block')
+    kh = tvm.reduce_axis((0, kernel_h), name='kh')
+    kw = tvm.reduce_axis((0, kernel_w), name='kw')
+
+    conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
+                       tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
+                                        oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
+                               .astype('int32') *
+                               packed_kernel[occ, icc,
+                                             kh, kw, ocb, icb]
+                               .astype('int32'),
+                               axis=[icc, kh, kw, icb]))
+
+    output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
+                         tag='group_conv2d_NCHWc_int8')
+    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
+        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
+    cfg.add_flop(num_flop)
+
+    return output
+
+
+_dp4a = dp4a('shared', 'shared', 'local')
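The conv compute above selects the input-channel chunks for each output-channel
chunk with `occ // (oc_chunk // groups) * (ic_chunk // groups) + icc`. A quick
standalone sanity check of that index arithmetic (illustrative only, not part
of the patch):

    # Output chunk `occ` belongs to group g = occ // (oc_chunk // groups) and
    # reads input chunks [g * (ic_chunk // groups), (g + 1) * (ic_chunk // groups)).
    oc_chunk, ic_chunk, groups = 8, 4, 2
    for occ in range(oc_chunk):
        g = occ // (oc_chunk // groups)
        chunks = [g * (ic_chunk // groups) + icc for icc in range(ic_chunk // groups)]
        assert chunks == list(range(g * (ic_chunk // groups),
                                    (g + 1) * (ic_chunk // groups)))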
+
+
+def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
+    """Schedule group conv2d int8 NCHWc template"""
+    workload = output.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+
+    conv = output.op.input_tensors[0]
+    packed_data, packed_kernel = conv.op.input_tensors
+
+    if isinstance(packed_data.op, tvm.tensor.ComputeOp) and "pad" in packed_data.op.tag:
+        pad_data = packed_data
+        packed_data = pad_data.op.input_tensors[0]
+    else:
+        pad_data = packed_data
+
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # skip this part during tuning to make records accurate
+        # this part will be pre-computed during NNVM's pre-compute optimization pass
+        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
+        s[packed_kernel].pragma(
+            s[packed_kernel].op.axis[0], "debug_skip_region")
+    else:
+        if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
+                packed_kernel.name == 'packed_kernel':
+            # data and kernel are not pre-computed, schedule layout transform here
+            _schedule_injective(packed_data.op, s)
+            _schedule_injective(packed_kernel.op, s)
+
+    if pad_data != packed_data:
+        s[pad_data].compute_inline()
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [conv])
+    WW = s.cache_read(packed_kernel, 'shared', [conv])
+
+    s[conv].set_scope('local')
+
+    # handle bias
+    if output.op not in s.outputs:
+        s[output].compute_inline()
+        output = s.outputs[0].output(0)
+
+    oc_chunk = get_const_int(output.shape[1])
+    # tile and bind spatial axes
+    n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+
+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy,
+                      vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile and bind reduction axes
+    n, f, y, x, c = s[conv].op.axis
+    rc, ry, rx, rc_block = s[conv].op.reduce_axis
+    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
+    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
+    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
+    rco, rci = cfg['tile_rc'].apply(s, conv, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, conv, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, conv, rx)
+
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x, c, rc_block)
+    _, rc_block = s[conv].split(rc_block, factor=4)
+    s[conv].tensorize(rc_block, _dp4a)
+
+    s[AA].compute_at(s[conv], rxo)
+    s[WW].compute_at(s[conv], rxo)
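The innermost `rc_block` axis is split by a factor of 4 and tensorized with
`_dp4a`, so each step consumes four int8 input/kernel pairs and accumulates
into an int32 register. Semantically, one dp4a step computes the following
(illustrative sketch only; the actual intrinsic is defined in tensor_intrin.py):

    # acc += dot(a[0:4], b[0:4]) with int8 inputs and an int32 accumulator.
    import numpy as np

    def dp4a_step(acc, a4, b4):  # hypothetical helper name
        return acc + int(np.dot(a4.astype(np.int32), b4.astype(np.int32)))

    a = np.array([1, -2, 3, -4], dtype=np.int8)
    b = np.array([5, 6, -7, 8], dtype=np.int8)
    assert dp4a_step(0, a, b) == 1*5 + (-2)*6 + 3*(-7) + (-4)*8  # == -60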
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        c = s[load].op.axis[-1]
+        c_outer, c = s[load].split(c, factor=4)
+        s[load].vectorize(c)
+        fused = s[load].op.axis[:-1] + [c_outer]
+        fused = s[load].fuse(*fused)
+
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # double buffer
+    cfg.define_knob('AA_double_buffer', [0, 1])
+    cfg.define_knob('WW_double_buffer', [0, 1])
+    if cfg['AA_double_buffer'].val:
+        s[AA].double_buffer()
+    if cfg['WW_double_buffer'].val:
+        s[WW].double_buffer()
+
+    # unroll
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step',
+                     cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', False)
+
+    return s
+
+
+@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
+                                ["cuda", "gpu"], ["direct", "int8"])
+def schedule_group_conv2d_nchw_cuda(cfg, outs):
+    """TOPI schedule callback of group conv2d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of group conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for group conv2d.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "group_conv2d_NCHWc_int8":
+            schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index a48b85638fb1..0f4b51b81433 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -173,6 +173,25 @@ def schedule_depthwise_conv2d_nhwc(outs):
     """
     return _default_schedule(outs, False)
 
+
+@tvm.target.generic_func
+def schedule_group_conv2d_nchw(outs):
+    """Schedule for group_conv2d_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of group_conv2d_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_bitserial_conv2d_nchw(outs):
     """Schedule for bitserial_conv2d_nchw
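A minimal end-to-end usage sketch of the new generic op with the fallback
schedule above, assuming the TVM/TOPI API of this patch's era (it mirrors the
calls in the new test file below; shapes are arbitrary):

    import tvm
    import topi

    # groups=2: 8 input channels, each filter sees 8 // 2 = 4 of them.
    A = tvm.placeholder((1, 8, 16, 16), name='A')
    W = tvm.placeholder((8, 4, 3, 3), name='W')
    with tvm.target.create('llvm'):
        C = topi.nn.group_conv2d_nchw(A, W, 1, 1, 1, 2)
        s = topi.generic.schedule_group_conv2d_nchw([C])
    func = tvm.build(s, [A, W, C], 'llvm')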
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 2b88886524bd..d4b9393c19dd 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -403,3 +403,80 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
         4-D with shape [batch, out_height, out_width, out_channel]
     """
     raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
+
+
+@tvm.target.generic_func
+def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    Filter : tvm.Tensor
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    dilation : int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups : int
+        number of groups
+
+    out_dtype : str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
+    num_filter, _, kernel_h, kernel_w = get_const_tuple(Filter.shape)
+
+    assert in_channel % groups == 0, "input channels must be divisible by the number of groups"
+    assert num_filter % groups == 0, "output channels must be divisible by the number of groups"
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (kernel_h, kernel_w))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = simplify(
+        (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify(
+        (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1)
+    # compute graph
+    pad_before = [0, 0, pad_top, pad_left]
+    pad_after = [0, 0, pad_down, pad_right]
+    temp = pad(Input, pad_before, pad_after, name="pad_temp")
+    rc = tvm.reduce_axis((0, in_channel // groups), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    return tvm.compute(
+        (batch, out_channel, out_height, out_width),
+        lambda nn, ff, yy, xx: tvm.sum(
+            temp[nn, ff // (num_filter//groups) * (in_channel//groups) + rc,
+                 yy * stride_h + ry * dilation_h,
+                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
+            Filter[ff, rc, ry, rx].astype(out_dtype),
+            axis=[rc, ry, rx]), tag="group_conv2d_nchw")
diff --git a/topi/python/topi/testing/conv2d_nchw_python.py b/topi/python/topi/testing/conv2d_nchw_python.py
index 4a40d02d215c..7d2aa0d0fedf 100644
--- a/topi/python/topi/testing/conv2d_nchw_python.py
+++ b/topi/python/topi/testing/conv2d_nchw_python.py
@@ -4,8 +4,8 @@
 import scipy.signal
 
 
-def conv2d_nchw_python(a_np, w_np, stride, padding):
-    """Convolution operator in HWCN layout.
+def _conv2d_nchw_python(a_np, w_np, stride, padding):
+    """Convolution operator in NCHW layout.
 
     Parameters
     ----------
@@ -66,3 +66,36 @@
                 apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
             b_np[n, f] += out[::stride_h, ::stride_w]
     return b_np
+
+
+def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
+    """Convolution operator in NCHW layout.
+ + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + w_np : numpy.ndarray + 4-D with shape [num_filter, in_channel // groups, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str or a list/tuple of two ints + Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] + + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + a_slices = np.array_split(a_np, groups, axis=1) + w_slices = np.array_split(w_np, groups, axis=0) + b_slices = [_conv2d_nchw_python(a_slice, w_slice, stride, padding) + for a_slice, w_slice in zip(a_slices, w_slices)] + b_np = np.concatenate(b_slices, axis=1) + return b_np diff --git a/topi/tests/python/common.py b/topi/tests/python/common.py index 763db5f86be2..f34f3b331fd1 100644 --- a/topi/tests/python/common.py +++ b/topi/tests/python/common.py @@ -1,5 +1,9 @@ """Common utility for topi test""" +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity + + def get_all_backend(): """return all supported target @@ -10,3 +14,14 @@ def get_all_backend(): """ return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', 'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu'] + + +class NCHWcInt8Fallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'int8' + self.memory[key] = cfg + return cfg diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py index fd5e91eed72d..272a72f82619 100644 --- a/topi/tests/python/test_topi_conv2d_int8.py +++ b/topi/tests/python/test_topi_conv2d_int8.py @@ -9,7 +9,7 @@ from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple -from common import get_all_backend +from common import get_all_backend, NCHWcInt8Fallback oc_block_factor = 4 @@ -88,17 +88,6 @@ def check_device(device): check_device(device) -class NCHWcInt8Fallback(autotvm.FallbackContext): - def _query_inside(self, target, workload): - key = (target, workload) - if key in self.memory: - return self.memory[key] - cfg = FallbackConfigEntity() - cfg.template_key = 'int8' - self.memory[key] = cfg - return cfg - - def test_conv2d_nchw(): with NCHWcInt8Fallback(): # ResNet18 workloads where channels in / out are multiple of oc_block_factor diff --git a/topi/tests/python/test_topi_group_conv2d.py b/topi/tests/python/test_topi_group_conv2d.py new file mode 100644 index 000000000000..c1ff656fcd93 --- /dev/null +++ b/topi/tests/python/test_topi_group_conv2d.py @@ -0,0 +1,215 @@ +"""Example code to do group convolution.""" + +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend, NCHWcInt8Fallback + + +def verify_group_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, in_size, num_filter, + kernel, stride, padding, dilation, groups)) + + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_channel, in_height, 
in_width), name='A') + W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W') + bias = tvm.placeholder((num_filter, 1, 1), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype) + + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): + print("Skip because int8 intrinsics are not available") + return + + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.generic.schedule_group_conv2d_nchw([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\ + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \ + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in ["llvm"]: + check_device(device) + + +oc_block_factor = 4 + + +def verify_group_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups, add_bias=False, add_relu=False): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, in_size, num_filter, + kernel, stride, padding, dilation, groups)) + + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8') + W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W', dtype='int8') + bias = tvm.placeholder((num_filter // oc_block_factor, 1, 1, oc_block_factor), name='bias', + dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype) + w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype(dtype) + + # convert to NCHWc + _, _, out_height, out_width = c_np.shape + c_np = 
c_np.reshape((batch, num_filter // oc_block_factor, oc_block_factor, \
+                             out_height, out_width)).transpose(0, 1, 3, 4, 2)
+
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
+            print("Skip because int8 intrinsics are not available")
+            return
+
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation, groups, out_dtype=dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" %\
+                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" % \
+                (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups))
+            func(a, w, c)
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ["cuda"]:
+        check_device(device)
+
+
+def test_group_conv2d_nchw():
+    # ResNeXt-50 workload
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 256, 56, 256, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 256, 28, 256, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 512, 28, 512, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 512, 14, 512, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
+    verify_group_conv2d_nchw(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
+
+    # bias, relu
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
+                             add_bias=True)
+
+    # dilation
+    verify_group_conv2d_nchw(1, 128, 56, 128, 3, 1, 1, 2, 32)
+
+    # batch size
+    verify_group_conv2d_nchw(2, 128, 56, 128, 3, 1, 1, 1, 32)
+    verify_group_conv2d_nchw(9, 128, 56, 128, 3, 1, 1, 1, 32)
+
+
+def test_group_conv2d_NCHWc_int8():
+    with NCHWcInt8Fallback():
+        # ResNeXt-50 workload
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 256, 56, 256, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 256, 28, 256, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 512, 28, 512, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 512, 14, 512, 3, 1, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32)
+        verify_group_conv2d_NCHWc_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32)
+
+        # bias, relu
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True)
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True)
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True,
+                                       add_bias=True)
+        # dilation
+        verify_group_conv2d_NCHWc_int8(1, 128, 56, 128, 3, 1, 1, 2, 32)
+
+        # batch size
+        verify_group_conv2d_NCHWc_int8(2, 128, 56, 128, 3, 1, 1, 1, 
32) + verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32) + + +if __name__ == "__main__": + test_group_conv2d_nchw() + test_group_conv2d_NCHWc_int8()
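For reference, the reshape-then-transpose that `get_ref_data` applies to `c_np`
in verify_group_conv2d_NCHWc_int8 above is the same NCHW -> NCHW4c blocking as
on the input side, applied to the output tensor. A standalone round-trip check
(illustrative only, not part of the patch):

    import numpy as np

    c = np.random.randint(-128, 127, size=(1, 8, 5, 5)).astype("int8")
    n, f, h, w = c.shape
    blocked = c.reshape(n, f // 4, 4, h, w).transpose(0, 1, 3, 4, 2)   # NCHW4c
    restored = blocked.transpose(0, 1, 4, 2, 3).reshape(n, f, h, w)    # back to NCHW
    assert (restored == c).all()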