From 84b01c3488780e51684f3c33b3a8bfe73fbe69b0 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 11 Mar 2020 09:03:30 -0700 Subject: [PATCH] Conv3D ONNX support and conv3D_ncdhw x86 schedules (#4949) * Support 3d Convolution with the ONNX frontend * add unit tests for conv3d in onnx frontend respond to PR formatting requests add x86 schedules to conv3d ncdhw test fix a doc string format issue refactor for changed upsream API * first attempt at conv3d autotuning add default schedule for conv3d_ncdhw fill in autotvm integration add a fallback for invalid schedules fix fallback fix reduction order to get simd working correctly --- python/tvm/relay/frontend/onnx.py | 16 +- python/tvm/relay/op/strategy/x86.py | 7 +- tests/python/frontend/onnx/test_forward.py | 76 +++--- topi/python/topi/x86/conv3d.py | 244 +++++++++++++++++++- topi/tests/python/test_topi_conv3d_ncdhw.py | 1 + 5 files changed, 298 insertions(+), 46 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 38ead20d1c9d..7f417d39e9f8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -91,16 +91,18 @@ def get_numpy(tensor_proto): return to_array(tensor_proto) -def dimension_picker(prefix, surfix=''): +def dimension_picker(prefix, suffix=''): """Check that dimensions are supported.""" def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 1: - return prefix + '1d' + surfix + return prefix + '1d' + suffix if len(kernel) == 2: - return prefix + '2d' + surfix - msg = 'Only 1D and 2D kernels are supported for operator {}.' - op_name = prefix + '1d/2d' + return prefix + '2d' + suffix + if len(kernel) == 3: + return prefix + '3d' + suffix + msg = 'Only 1D, 2D, and 3D kernels are supported for operator {}.' + op_name = prefix + '1d/2d/3d' raise tvm.error.OpAttributeInvalid(msg.format(op_name)) return _impl @@ -155,11 +157,11 @@ def onnx_storage_order2layout(storage_order, dims=2): def dimension_constraint(): def _dim_check(attrs): - if len(attrs['kernel_shape']) == 2 or len(attrs['kernel_shape']) == 1: + if len(attrs['kernel_shape']) in [1, 2, 3]: return True return False - return _dim_check, "Only 1d and 2d kernel supported." + return _dim_check, "Only 1d, 2d and 3d kernel supported." class OnnxOpConverter(object): diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 2fadb7f08dcd..e35838c1c5e8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -188,10 +188,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() layout = attrs.data_layout if layout == "NCDHW": - logger.warning("conv3d with layout NCDHW is not optimized for x86.") - strategy.add_implementation(wrap_compute_conv3d(topi.nn.conv3d_ncdhw), - wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw), - name="conv3d_ncdhw.generic") + strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ncdhw), + wrap_topi_schedule(topi.x86.schedule_conv3d_ncdhw), + name="conv3d_ncdhw.x86") elif layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 20d7003e1353..44696f599071 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1794,37 +1794,51 @@ def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilat def test_conv(): - # Convolution with padding - # Conv2D - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1], [3, 3], [1, 1], [1, 1]) - # Conv1D - verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [1, 1], [3], [1], [1]) - - # Convolution without padding - # Conv2D - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0], [3, 3], [1, 1], [1, 1]) - # Conv1D - verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), [0, 0], [3], [1], [1]) - - # Convolution with autopadding - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), - None, [3, 3], [1, 1], [1, 1], - auto_pad="SAME_UPPER") - # Conv1D - verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), None, [3], [1], [1], auto_pad="SAME_UPPER") - - # Convolution with non uniform stride - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), - None, [3, 3], [2, 2], [1, 1], - auto_pad="SAME_UPPER") - # Conv1D - verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), None, [3], [2], [1], auto_pad="SAME_UPPER") - - # Convolution with dilation - verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [2, 2, 2, 2], [3, 3], [1, 1], [2, 2]) - # Conv1D - verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [2, 2], [3], [1], [2]) - + def repeat(N, D): + return tuple([N for _ in range(D)]) + for D in [1, 2, 3]: + # Convolution with padding + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(1, D), + repeat(3, D), + repeat(1, D), + repeat(1, D)) + # Convolution without padding + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D)) + # Convolution with autopadding + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + None, + repeat(3, D), + repeat(1, D), + repeat(1, D), + auto_pad="SAME_UPPER") + # Convolution with non uniform stride + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + None, + repeat(3, D), + repeat(2, D), + repeat(1, D), + auto_pad="SAME_UPPER") + # Convolution with dilation + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(5, D), + 2 * repeat(2, D), + repeat(3, D), + repeat(1, D), + repeat(2, D)) def verify_convtranspose(x_shape, w_shape, y_shape, p): node = onnx.helper.make_node("ConvTranspose", diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py index 989ec4cf4ffc..27f48f8dc69a 100644 --- a/topi/python/topi/x86/conv3d.py +++ b/topi/python/topi/x86/conv3d.py @@ -42,7 +42,6 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): ---------- input : tvm.te.Tensor 5-D input data with shapes: - [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout [batch, in_depth, in_height, in_width, in_channel] for NDHWC layout filter : tvm.te.Tensor @@ -61,7 +60,6 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): ------- output : tvm.te.Tensor 5-D with shape [batch, out_depth, out_height, out_width, out_channel] for NDHWC layout - 5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout """ layout = "NDHWC" out_dtype = data.dtype if out_dtype is None else out_dtype @@ -74,14 +72,53 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype) +@autotvm.register_topi_compute("conv3d_ncdhw.x86") +def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype): + """3D convolution forward operator. + + Parameters + ---------- + input : tvm.Tensor + 5-D input data with shapes: + [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout + + filter : tvm.Tensor + 5-D filter with shape [out_channels, in_channels, kernel_depth, kernel_height, kernel_width] + + strides : int or a list/tuple of three ints + stride size, or [stride_depth, stride_height, stride_width] + + padding : int or a list/tuple of three ints + padding size, or [pad_depth, pad_height, pad_width] + + dilation: int or a list/tuple of three ints + dilation size, or [dilation_depth, dilation_height, dilation_width] + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout + """ + layout = "NCDHW" + out_dtype = data.dtype if out_dtype is None else out_dtype + strides = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides) + dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation, dilation) + + _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout) + return _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) + @autotvm.register_topi_schedule("conv3d_ndhwc.x86") def schedule_conv3d_ndhwc(cfg, outs): """TOPI schedule callback for conv3d + Parameters ---------- outs: Array of Tensor The computation graph description of conv3d in the format of an array of tensors. + Returns ------- s: Schedule @@ -111,6 +148,45 @@ def _traverse(op): traverse_inline(s, outs[0].op, _traverse) return s +@autotvm.register_topi_schedule("conv3d_ncdhw.x86") +def schedule_conv3d_ncdhw(cfg, outs): + """TOPI schedule callback for conv3d + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv3d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv3d. + """ + s = te.create_schedule([x.op for x in outs]) + + def _traverse(op): + if 'conv3d_ncdhw' in op.tag: + output = op.output(0) + conv_out = op.input_tensors[0] + kernel_vec = conv_out.op.input_tensors[1] + kernel = kernel_vec.op.input_tensors[0] + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + kd, kh, kw, i, o = get_const_tuple(kernel.shape) + args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]] + _schedule_conv3d_ncdhw(*args) + + traverse_inline(s, outs[0].op, _traverse) + return s + def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype @@ -198,6 +274,93 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): tag='conv3d_ndhwc') return conv_unpacked +def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): + out_dtype = data.dtype if out_dtype is None else out_dtype + + assert isinstance(dilation, int) or len(dilation) == 3 + if isinstance(dilation, int): + dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) + else: + dilation_d, dilation_h, dilation_w = dilation + + DSTR, HSTR, WSTR = strides + batch_size, in_channel, in_depth, in_height, in_width = get_const_tuple(data.shape) + num_filter, _, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape) + + dilated_kernel_d = (kernel_depth - 1) * dilation_d + 1 + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d( + padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)) + + pad_d = pad_front + pad_back + pad_h = pad_top + pad_down + pad_w = pad_left + pad_right + + pad_depth = in_depth + pad_d + pad_height = in_height + pad_h + pad_width = in_width + pad_w + + out_depth = simplify((in_depth + pad_d - dilated_kernel_d) // DSTR + 1) + out_height = simplify((in_height + pad_h - dilated_kernel_h) // HSTR + 1) + out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1) + + # pack data + DOPAD = (pad_d != 0 or pad_h != 0 or pad_w != 0) + if DOPAD: + data_pad = pad(data, (0, 0, pad_front, pad_top, pad_left), + (0, 0, pad_back, pad_down, pad_right), name="data_pad") + else: + data_pad = data + + # fetch schedule + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width) + data_vec = te.compute(shape, + lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w], + name='data_vec') + + # pack kernel + shape = (num_filter//oc_bn, in_channel//ic_bn, + kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn) + kernel_vec = te.compute(shape, + lambda CO, CI, d, h, w, ci, co: + kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w], + name='kernel_vec') + + # convolution + oshape = (batch_size, num_filter//oc_bn, + out_depth, out_height, out_width, oc_bn) + unpack_shape = (batch_size, num_filter, out_depth, out_height, out_width) + + ic = te.reduce_axis((0, in_channel), name='ic') + kh = te.reduce_axis((0, kernel_height), name='kh') + kw = te.reduce_axis((0, kernel_width), name='kw') + kd = te.reduce_axis((0, kernel_depth), name='kd') + idxmod = tvm.tir.indexmod + idxdiv = tvm.tir.indexdiv + + conv = te.compute(oshape, lambda n, oc_chunk, od, oh, ow, oc_block: + te.sum(data_vec[n, + idxdiv(ic, ic_bn), + od*DSTR+kd*dilation_d, + oh*HSTR+kh*dilation_h, + idxmod(ic, ic_bn), + ow*WSTR+kw*dilation_w].astype(out_dtype) * + kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw, + idxmod(ic, ic_bn), + oc_block].astype(out_dtype), + axis=[ic, kd, kh, kw]), name='conv') + conv_unpacked = te.compute(unpack_shape, + lambda n, c, d, h, w: conv[n, idxdiv(c, oc_bn), + d, h, w, + idxmod(c, oc_bn)] + .astype(out_dtype), + name='output_unpack', + tag='conv3d_ncdhw') + return conv_unpacked def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): """Create schedule configuration from input arguments""" @@ -206,6 +369,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if layout == 'NDHWC': n, d, h, w, ic = dshape kd, kh, kw, _, oc = kshape + elif layout == 'NCDHW': + n, ic, d, h, w = dshape + oc, _, kd, kh, kw = kshape else: raise ValueError("Not support this layout {} with " "schedule template.".format(layout)) @@ -227,7 +393,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, layout): """ Get default schedule config for the workload """ - if layout != 'NDHWC': + if layout not in ['NDHWC', 'NCDHW']: raise ValueError("Layout {} is not supported".format(layout)) static_data_shape = [] @@ -244,7 +410,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout=' """ Get the workload structure. """ if data_layout == 'NCDHW': _, CI, ID, IH, IW = get_const_tuple(data.shape) - CIG, CO, KD, KH, KW = get_const_tuple(kernel.shape) + CO, CIG, KD, KH, KW = get_const_tuple(kernel.shape) elif data_layout == 'NDHWC': _, ID, IH, IW, CI = get_const_tuple(data.shape) KD, KH, KW, CIG, CO = get_const_tuple(kernel.shape) @@ -365,3 +531,73 @@ def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_ou s[O].vectorize(oc_block) s[O].parallel(parallel_axis) return s + +def _schedule_conv3d_ncdhw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): + # fetch schedule + ic_bn, oc_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], + cfg["tile_ow"].size[-1], cfg["unroll_kw"].val) + + # get padding size + padding = infer_pad3d(data, data_pad, "NCDHW") + DPAD, HPAD, WPAD = padding + DOPAD = (DPAD != 0 or HPAD != 0 or WPAD != 0) + + A, W = data, kernel_vec + A0, A1 = data_pad, data_vec + + # schedule data + if DOPAD: + s[A0].compute_inline() + batch, ic_chunk, idd, ih, ic_block, iw = s[A1].op.axis + parallel_axis = s[A1].fuse(batch, ic_chunk, idd, ih) + s[A1].parallel(parallel_axis) + + # schedule kernel pack + oc_chunk, ic_chunk, od, oh, ow, ic_block, oc_block = s[W].op.axis + s[W].reorder(oc_chunk, od, oh, ic_chunk, ow, ic_block, oc_block) + if oc_bn > 1: + s[W].vectorize(oc_block) + parallel_axis = s[W].fuse(oc_chunk, od, oh) + s[W].parallel(parallel_axis) + + # schedule conv + C, O0, O = conv_out, output, last + CC = s.cache_write(C, 'global') + + _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) + s[C].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block) + s[C].fuse(oc_chunk, od, oh) + s[C].vectorize(oc_block) + + s[CC].compute_at(s[C], ow_chunk) + _, oc_chunk, od, oh, ow, oc_block = s[CC].op.axis + ic, kd, kh, kw = s[CC].op.reduce_axis + + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) + ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn) + + if unroll_kw: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, ic_block, kw, ow_block, oc_block) + s[CC].unroll(kw) + else: + s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, kw, ic_block, ow_block, oc_block) + + s[CC].fuse(oc_chunk, od, oh) + s[CC].vectorize(oc_block) + s[CC].unroll(ow_block) + + if O0 != O: + s[O0].compute_inline() + + # unpacking + batch, oc, od, oh, ow = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + s[O].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, od, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + + return s diff --git a/topi/tests/python/test_topi_conv3d_ncdhw.py b/topi/tests/python/test_topi_conv3d_ncdhw.py index 33e791716e34..c3e01aeb7a64 100644 --- a/topi/tests/python/test_topi_conv3d_ncdhw.py +++ b/topi/tests/python/test_topi_conv3d_ncdhw.py @@ -30,6 +30,7 @@ _conv3d_ncdhw_implement = { "generic": (topi.nn.conv3d_ncdhw, topi.generic.schedule_conv3d_ncdhw), + "cpu": (topi.x86.conv3d_ncdhw, topi.x86.schedule_conv3d_ncdhw), "gpu": (topi.cuda.conv3d_ncdhw, topi.cuda.schedule_conv3d_ncdhw), }