From 316e4b45ce04505ebc458fb0c93b68eaf766519a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 12 Mar 2020 09:51:09 -0700 Subject: [PATCH] [Strategy] Support for Int8 schedules - CUDA/x86 (#5031) * [CUDA] Op strategy changes for Int8 schedules. * Applying Haichen's suggestions. * Make 4D output work for task extraction. * Make x86 work. * Fix lint. * Lint fixes. * Tests, comments, out channel a multiple of 4. * Topi test. Co-authored-by: Ubuntu --- python/tvm/relay/frontend/mxnet.py | 8 +- python/tvm/relay/op/strategy/cuda.py | 16 ++-- python/tvm/relay/qnn/op/legalizations.py | 14 ++++ tests/python/relay/test_pass_qnn_legalize.py | 14 ++++ topi/python/topi/cuda/conv2d_alter_op.py | 80 +++++++++++++++++++ topi/python/topi/cuda/conv2d_int8.py | 22 +++++- topi/python/topi/generic/conv2d.py | 66 +++++++++++----- topi/tests/python/test_topi_conv2d_int8.py | 81 ++++++++++++++++++++ 8 files changed, 273 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index c2bfd751facb..ba93bb2e3b81 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1373,8 +1373,8 @@ def _get_sum(_res, _output_scale, out_dtype): # 3) Clip/cast to change the out dtype. _res = relay.clip(_res, - a_min=float(tvm.api.min_value(out_dtype).value), - a_max=float(tvm.api.max_value(out_dtype).value)) + a_min=float(tvm.tir.op.min_value(out_dtype).value), + a_max=float(tvm.tir.op.max_value(out_dtype).value)) _res = relay.cast(_res, out_dtype) return _res @@ -1647,8 +1647,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale) rounded_bias = _op.round(multiplied_bias) clipped_bias = _op.clip(rounded_bias, - a_min=tvm.api.min_value('int32').value, - a_max=tvm.api.max_value('int32').value) + a_min=tvm.tir.op.min_value('int32').value, + a_max=tvm.tir.op.max_value('int32').value) requantized_bias = _op.cast(clipped_bias, 'int32') res = _op.nn.bias_add(res, requantized_bias, axis=-1) enable_float_output = attrs.get_bool('enable_float_output', False) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index b1e77bdc05ed..e5eff1c6b790 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -85,12 +85,18 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): if groups == 1: if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8. 
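The mxnet frontend hunks above keep the same clip-then-cast requantization pattern while moving min_value/max_value from tvm.api to tvm.tir.op. A minimal NumPy sketch of that pattern, using an illustrative clip_and_cast helper rather than the TVM API:

import numpy as np

def clip_and_cast(res, out_dtype="int8"):
    # Illustrative helper (not part of TVM): clamp the float result to the output
    # dtype's representable range before casting, mirroring
    # relay.clip(_res, a_min=min_value(out_dtype).value, a_max=max_value(out_dtype).value)
    # followed by relay.cast(_res, out_dtype).
    info = np.iinfo(out_dtype)
    return np.clip(res, float(info.min), float(info.max)).astype(out_dtype)

print(clip_and_cast(np.array([-300.0, 0.5, 300.0], dtype="float32")))  # [-128    0  127]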
assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), - name="conv2d_nchw.cuda") + if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'): + assert data.dtype == kernel.dtype + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8), + wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8), + name="conv2d_nchw_int8.cuda") + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nchw), + wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), + name="conv2d_nchw.cuda") _, _, kh, kw = get_const_tuple(kernel.shape) if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \ dilation_h == 1 and dilation_w == 1: diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index ad71313fef52..f9874b78467e 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -264,3 +264,17 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): if is_fast_int8_on_intel(): return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense) return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) + +##################### +# CUDA legalizations. +##################### + +@qnn_conv2d_legalize.register('cuda') +def _qnn_conv2d_legalize_cuda(attrs, inputs, types): + # CUDA prefers the dtypes to be same. + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) + +@qnn_dense_legalize.register('cuda') +def _qnn_dense_legalize_cuda(attrs, inputs, types): + # CUDA prefers the dtypes to be same. + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 7d3d9cc106c8..ed05096aec29 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -177,6 +177,13 @@ def _get_mod(data_dtype, kernel_dtype): legalized_mod = relay.qnn.transform.Legalize()(mod) assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + ########################################### + # Check transformations for CUDA platforms. + ########################################### + with tvm.target.create('cuda'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext() + def test_qnn_legalize_qnn_dense(): def _get_mod(data_dtype, kernel_dtype): @@ -257,6 +264,13 @@ def _get_mod(data_dtype, kernel_dtype): legalized_mod = relay.qnn.transform.Legalize()(mod) assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + ########################################### + # Check transformations for CUDA platforms. + ########################################### + with tvm.target.create('cuda'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" in legalized_mod.astext() + if __name__ == "__main__": test_qnn_legalize() diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index b59827136c70..8d9e86c192a0 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -26,6 +26,7 @@ from .. 
import nn from ..util import get_const_tuple from .conv2d_winograd import _infer_tile_size +from ..nn import conv2d_legalize logger = logging.getLogger('topi') @@ -135,3 +136,82 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.conv2d(*inputs, **new_attrs) return None + +@conv2d_legalize.register("cuda") +def _conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Conv2D op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + + # Dilation not supported yet. Return None if dilation is not (1, 1) + dilation = attrs.get_int_tuple("dilation") + if not (dilation[0] == 1 and dilation[1] == 1): + return None + + # No legalization for depthwise convolutions yet. + groups = attrs.get_int("groups") + if groups != 1: + return None + + # Collect the input tensors. + data_tensor, kernel_tensor = arg_types[0], arg_types[1] + data_dtype = data_tensor.dtype + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Collect the input exprs. + data, kernel = inputs + + # Get the conv attrs + new_attrs = {k: attrs[k] for k in attrs.keys()} + + # Get data layout. Return None if not NCHW + data_layout = attrs['data_layout'] + kernel_layout = attrs['kernel_layout'] + + # Pad input and output channels to use int8 schedule. + if data_dtype in ['int8', 'uint8']: + if data_layout == 'NCHW' and kernel_layout == "OIHW": + oc_modified = False + in_channel = data_tensor.shape[1].value + out_channel = kernel_tensor.shape[0].value + + # Pad input channel + if in_channel % 4 != 0: + new_in_channel = ((in_channel + 4) // 4) * 4 + diff = new_in_channel - in_channel + pad_width = ((0, 0), (0, diff), (0, 0), (0, 0)) + data = relay.nn.pad(data, pad_width=pad_width) + kernel = relay.nn.pad(kernel, pad_width=pad_width) + + # Pad output channel + new_out_channel = out_channel + if out_channel % 4 != 0: + new_out_channel = ((out_channel + 4) // 4) * 4 + diff = new_out_channel - out_channel + kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0))) + oc_modified = True + + if oc_modified: + new_attrs['channels'] = new_out_channel + out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + else: + out = relay.nn.conv2d(data, kernel, **new_attrs) + return out + return None diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py index ad97fa68d6aa..bc8aa35cc5ac 100644 --- a/topi/python/topi/cuda/conv2d_int8.py +++ b/topi/python/topi/cuda/conv2d_int8.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
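The _conv2d_legalize registration above pads NCHW data and OIHW kernels so that both input and output channels become multiples of 4, which the dp4a-based int8 schedule operates on, and then strided_slice restores the original output shape whenever the output channels were padded. A small sketch of the rounding arithmetic with illustrative helper names (round_up and channel_pad_width are not TOPI functions); note that inside the `% 4 != 0` guard, ((c + 4) // 4) * 4 is equivalent to the usual ceil form ((c + 3) // 4) * 4:

def round_up(value, factor=4):
    # Illustrative helper: smallest multiple of `factor` >= `value`, e.g. 3 -> 4, 8 -> 8.
    return ((value + factor - 1) // factor) * factor

def channel_pad_width(channels, axis, ndim=4, factor=4):
    # Illustrative helper: pad_width tuple in the relay.nn.pad format,
    # padding only `axis` at its end.
    diff = round_up(channels, factor) - channels
    return tuple((0, diff) if i == axis else (0, 0) for i in range(ndim))

assert channel_pad_width(3, axis=1) == ((0, 0), (0, 1), (0, 0), (0, 0))   # NCHW data, channel axis
assert channel_pad_width(62, axis=0) == ((0, 2), (0, 0), (0, 0), (0, 0))  # OIHW kernel, out-channel axis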
# pylint: disable=invalid-name +# pylint: disable=no-value-for-parameter """Int8 conv2d in NCHWc layout""" import tvm from tvm import te @@ -23,10 +24,23 @@ from .injective import schedule_injective_from_existing from .tensor_intrin import dp4a from ..nn.pad import pad +from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.util import get_pad_tuple from ..util import get_const_tuple, traverse_inline +def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'): + """Compute conv2d internally using conv2d_nchwc layout for int8 dtype""" + assert data.dtype in ('int8', 'uint8') + assert kernel.dtype in ('int8', 'uint8') + assert data.dtype == kernel.dtype + packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype) + return unpack_NCHWc_to_nchw(packed_out, out_dtype) + +def schedule_conv2d_nchw_int8(outs): + """Create schedule for tensors""" + return schedule_conv2d_NCHWc_int8(outs) + @autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda") def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype): """Convolution operator in NCHW[x]c layout for int8. @@ -205,7 +219,13 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output): output = s.outputs[0].output(0) # tile and bind spatial axes - n, f, y, x, c = s[output].op.axis + if len(s[output].op.axis) == 5: + n, f, y, x, c = s[output].op.axis + else: + # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks + # are created from scratch, therefore the real auto-tuning will still happen on 5D output. + n, f, y, x = s[output].op.axis + cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) diff --git a/topi/python/topi/generic/conv2d.py b/topi/python/topi/generic/conv2d.py index 69984a169ac6..2d9f78b645db 100644 --- a/topi/python/topi/generic/conv2d.py +++ b/topi/python/topi/generic/conv2d.py @@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out, parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) s[data_vec].parallel(parallel_axis) - oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis + # conv2d_nchwc_int8 has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) oc_bn = cfg["tile_oc"].size[-1] if oc_bn > 1: @@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out, s[CC].unroll(oc_f_inner) if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=reg_n) - s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) - parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + parallel_axis = s[O].fuse(batch, oc_chunk, oh) + 
s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) return s @@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out, parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih) s[data_vec].parallel(parallel_axis) - oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis + # Conv2d int8 schedule has 7D kernel + oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block) oc_bn = cfg["tile_oc"].size[-1] if oc_bn > 1: @@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out, s[CC].unroll(oh_inner) if C != O: - batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) - s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) - - parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) - s[O].vectorize(oc_block) - s[O].parallel(parallel_axis) + out_ndim = len(s[O].op.axis) + if out_ndim == 5: + batch, oc_chunk, oh, ow, oc_block = s[O].op.axis + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + elif out_ndim == 4: + batch, oc, oh, ow = s[O].op.axis + oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) + s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + + parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) + s[C].compute_at(s[O], parallel_axis) + s[O].vectorize(oc_block) + s[O].parallel(parallel_axis) + else: + raise ValueError("Unsupported output ndim: %s" % out_ndim) return s diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py index d784e5cd3f86..06f930ccdbbd 100644 --- a/topi/tests/python/test_topi_conv2d_int8.py +++ b/topi/tests/python/test_topi_conv2d_int8.py @@ -108,6 +108,76 @@ def check_device(device): check_device(device) +def verify_conv2d_nchw_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='int8') + W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W', dtype='int8') + bias = te.placeholder((num_filter, 1, 1), name='bias', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw") + def get_ref_data(): + a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype) + w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype) + b_np 
= np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding).astype(dtype) + + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): + print("Skip because int8 intrinsics are not available") + return + + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.cuda.conv2d_nchw_int8(A, W, (stride, stride), padding, (dilation, dilation), + dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.cuda.schedule_conv2d_nchw_int8([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in ["cuda"]: + check_device(device) + + def test_conv2d_nchw(): with Int8Fallback(): # ResNet18 workloads where channels in / out are multiple of oc_block_factor @@ -204,6 +274,17 @@ def test_conv2d_nchw(): verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True) verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True) + # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just + # performing basic testing - one test for all different scenarios - batch, dilation etc.. + verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1) + verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, add_relu=True) + verify_conv2d_nchw_int8(1, 64, 56, 64, 3, 1, 1, dilation=2) + verify_conv2d_nchw_int8(9, 64, 56, 64, 3, 1, 1) + verify_conv2d_nchw_int8(4, 4, 4, 4, 4, 4, 4) + verify_conv2d_nchw_int8(1, 32, 149, 32, 3, 1, 0) + verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0) + verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1)) + if __name__ == "__main__": test_conv2d_nchw()
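The new conv2d_nchw_int8 entry point exercised by these tests computes internally in the NCHW[x]c layout and unpacks the result back to NCHW through unpack_NCHWc_to_nchw, which is also why the schedules in this patch handle both 5D (packed) and 4D (unpacked or task-extraction) outputs. A conceptual NumPy analogue of that unpacking, with an illustrative function name and not the TOPI implementation:

import numpy as np

def unpack_nchwc_to_nchw_ref(x):
    # Illustrative NumPy reference (not the TOPI op): fold the inner channel block
    # back into the channel axis, (N, C // c, H, W, c) -> (N, C, H, W),
    # with channel index = chunk * c + block.
    n, c_chunk, h, w, c_block = x.shape
    return x.transpose(0, 1, 4, 2, 3).reshape(n, c_chunk * c_block, h, w)

x = np.arange(2 * 2 * 3 * 3 * 4, dtype=np.int32).reshape(2, 2, 3, 3, 4)
print(unpack_nchwc_to_nchw_ref(x).shape)  # (2, 8, 3, 3)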