diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index cb1b0854d27f..d2aca1890f85 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -517,6 +517,65 @@ def test_upsampling():
     _test_upsampling("NHWC", "BILINEAR")
 
 
+def test_conv2d_int8_intrinsics():
+    def _compile(input_dtype, weight_dtype, output_dtype, target):
+        n, ic, h, w, oc, ch, cw = 1, 16, 224, 224, 32, 3, 3
+        x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
+        w = relay.var("w", relay.TensorType((oc, ic, ch, cw), weight_dtype))
+        y = relay.nn.conv2d(x, w,
+                            kernel_size=(ch, cw),
+                            channels=oc,
+                            padding=(1, 1),
+                            dilation=(1, 1),
+                            out_dtype=output_dtype)
+        func = relay.Function([x, w], y)
+        wdata = np.random.rand(oc, ic, ch, cw) * 10
+        parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
+        with relay.build_config(opt_level=3):
+            graph, lib, params = relay.build(func, target, params=parameters)
+        assembly = lib.get_source("asm")
+        return assembly
+
+    # Compile conv2d for x86 (Skylake) and check that the assembly contains *pmadd* instructions.
+    target = "llvm -mcpu=skylake-avx512"
+    name = "llvm.x86.avx512.pmaddubs.w.512"
+    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name)
+    if llvm_id != 0:
+        # The Intel int8 instructions need uint8 data and an int8 kernel.
+        asm = _compile(input_dtype="uint8",
+                       weight_dtype="int8",
+                       output_dtype="int32",
+                       target=target)
+        # Check that the intrinsic is present in the assembly.
+        assert "pmaddubs" in asm
+
+        # Ensure that code is still generated when the datatypes are not HW supported.
+        asm = _compile(input_dtype="int8",
+                       weight_dtype="int8",
+                       output_dtype="int32",
+                       target=target)
+        # Check that the intrinsic is not present in the assembly.
+        assert "pmaddubs" not in asm
+
+        # Ensure that code is still generated when the datatypes are not HW supported.
+        asm = _compile(input_dtype="uint8",
+                       weight_dtype="uint8",
+                       output_dtype="int32",
+                       target=target)
+        # Check that the intrinsic is not present in the assembly.
+        assert "pmaddubs" not in asm
+
+    # Check that a vectorized instruction is generated for older Intel
+    # generations, because we default to the NCHWc layout.
+    target = "llvm -mcpu=core-avx2"
+    asm = _compile(input_dtype="int8",
+                   weight_dtype="int8",
+                   output_dtype="int32",
+                   target=target)
+    # Check that vector integer multiply and add instructions are generated.
+ assert "vpmulld" in asm and "vpadd" in asm + + if __name__ == "__main__": test_pool2d() test_avg_pool2d_no_count_pad() @@ -532,3 +591,4 @@ def test_upsampling(): test_conv2d_run() test_batch_flatten() test_upsampling() + test_conv2d_int8_intrinsics() diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 57c1d20c422f..bc49ba27d6a9 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -391,10 +391,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn - if data.dtype == 'uint8': - oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) - else: - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn # output shape @@ -413,26 +410,6 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou kh = tvm.reduce_axis((0, kernel_height), name='kh') kw = tvm.reduce_axis((0, kernel_width), name='kw') - if data.dtype == 'uint8': - assert out_dtype == "int32", \ - "INT8 convolution requires input dtype = uint8 and output dtype=int32" - # Intel performs dot product of 2 "4" Int8 values - # Current implementation requires ic_bn to be a multiple of 4 - n_elems = 4 - assert ic_bn % n_elems == 0 - - ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer') - ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') - ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') - return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw, - ic_f_inner * n_elems + ic_s_inner] - .astype(out_dtype) * - kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, - oc_block, ic_s_inner].astype(out_dtype), - axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), - name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") - # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, ic%ic_bn].astype(out_dtype) * diff --git a/topi/python/topi/x86/check_targets.py b/topi/python/topi/x86/check_targets.py deleted file mode 100644 index 1b929c2ef752..000000000000 --- a/topi/python/topi/x86/check_targets.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
-"""Checks different x86 targets for target specific schedules"""
-
-def check_skylake(target):
-    """
-    Checks if the target is skylake
-    """
-
-    for opt in target.options:
-        if opt == '-mcpu=skylake-avx512':
-            return True
-    return False
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index d51945e14282..40aa09a6fb3d 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -37,6 +37,29 @@
 
 logger = logging.getLogger('topi')
 
+def _is_int8_hw_support(data_dtype, kernel_dtype, target):
+    """
+    Checks whether we can use the Intel DLBoost instructions:
+    1) The datatypes are correct.
+    2) The LLVM version supports the instructions.
+    3) The target is Skylake or newer.
+    """
+    # 1) Check datatypes
+    is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8'
+
+    # 2) Check LLVM support
+    llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512"
+    llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8)
+    is_llvm_support = llvm_id != 0
+
+    # 3) Check target
+    is_target_support = False
+    for opt in target.options:
+        if opt == '-mcpu=skylake-avx512':
+            is_target_support = True
+
+    return is_dtype_support and is_llvm_support and is_target_support
+
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
                         layout='NCHW'):
     """
@@ -68,7 +91,8 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
         kh, kw, oc, _ = kshape
     elif pat.match(layout) is not None:
         n, ic_chunk, h, w, ic_bn = dshape
-        if data.dtype == 'uint8':
+        target = tvm.target.current_target(allow_none=False)
+        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
            oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
            ic = ic_chunk*ic_bn
            assert ic == k_ic*k_ic_f*kic_s
@@ -276,7 +300,6 @@ def traverse(op):
 
             args = [s, cfg, data_vec, conv_out, outs[0]]
             if data.dtype == 'uint8':
-                # int8 conv kernel is 7-dim
                 kh, kw, _, _, _ = get_const_tuple(kernel.shape)
                 if kh == 1 and kw == 1:
                     conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
@@ -453,19 +476,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
              new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
+        dispatch_ctx.update(target, new_workload, cfg)
     else:
-        out_channel, _, kh, kw = get_const_tuple(kernel.shape)
-        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-
-        # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
-                                     dtype=kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
-             new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
-
-    dispatch_ctx.update(target, new_workload, cfg)
+        if _is_int8_hw_support(data.dtype, kernel.dtype, target):
+            # Convert the kernel data layout from 4D to 7D
+            n_elems = 4
+            out_channel, _, kh, kw = get_const_tuple(kernel.shape)
+            data_expr, kernel_expr = inputs
+            kernel_IHWO = F.transpose(kernel_expr, axes=(1, 2, 3, 0))
+            kernel_IHWOo = F.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn))
+            kernel_OHWoI = F.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
+            kernel_OHWoIi = F.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn,
+                                                     in_channel//ic_bn, ic_bn))
+            kernel_OHWoIie = F.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn,
+                                                       in_channel//ic_bn, ic_bn//n_elems, n_elems))
+            kernel_OIHWioe = F.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6))
+            copy_inputs = [data_expr, kernel_OIHWioe]
+            # Store altered operator's config
+            new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn,
+                                          in_channel//ic_bn, ic_bn//n_elems,
+                                          n_elems))
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation,
+                 new_attrs[layout_name], new_attrs['out_layout'], out_dtype],
+                conv2d_NCHWc)
+            dispatch_ctx.update(target, new_workload, cfg)
+        else:
+            out_channel, _, kh, kw = get_const_tuple(kernel.shape)
+            # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+            new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+            # Store altered operator's config
+            new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
+                                          kh, kw, ic_bn, oc_bn), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
+                 new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
+            dispatch_ctx.update(target, new_workload, cfg)
 
     if is_depthwise:
         if F.__name__ == 'nnvm.symbol':
@@ -505,7 +551,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
 
-    if data.dtype == 'uint8':
+    target = tvm.target.current_target(allow_none=False)
+    if _is_int8_hw_support(data.dtype, kernel.dtype, target):
         oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
             get_const_tuple(kernel.shape)
     else:
@@ -539,7 +586,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     kh = tvm.reduce_axis((0, kernel_height), name='kh')
     kw = tvm.reduce_axis((0, kernel_width), name='kw')
 
-    if data.dtype == 'uint8' and groups == 1:
+    if _is_int8_hw_support(data.dtype, kernel.dtype, target) and groups == 1:
         assert out_dtype == "int32", \
             "INT8 convolution requires input dtype = uint8 and output dtype=int32"
         # Intel performs dot product of 2 "4" Int8 values
@@ -559,7 +606,8 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                           oc_block, ic_s_inner].astype(out_dtype),
                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
-    if data.dtype == 'uint8':
+
+    if _is_int8_hw_support(data.dtype, kernel.dtype, target):
         # for int8 group conv support
         n_elems = 4
         ic_chunk = in_channel//ic_bn
@@ -615,7 +663,8 @@ def traverse(op):
                 data = data_pad.op.input_tensors[0]
 
             args = [s, cfg, data_vec, conv_out, outs[0]]
-            if data.dtype == 'uint8':
+            target = tvm.target.current_target(allow_none=False)
+            if _is_int8_hw_support(data.dtype, kernel.dtype, target):
                 # int8 conv kernel is 7-dim
                 _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape)
                 if kh == 1 and kw == 1:
diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py
index 256cea569c68..c486cbadef95 100644
--- a/topi/python/topi/x86/conv2d_avx_1x1.py
+++ b/topi/python/topi/x86/conv2d_avx_1x1.py
@@ -24,7 +24,6 @@
 from ..nn.util import infer_pad, get_pad_tuple
 from ..util import get_const_tuple, simplify
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
-from .check_targets import check_skylake
 from .util import get_fp32_len
 
 def _fallback_schedule(cfg, wkl):
@@ -187,13 +186,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     More details - https://software.intel.com/en-us/articles/
     lower-numerical-precision-deep-learning-inference-and-training
     """
-    target = tvm.target.current_target(allow_none=False)
-    int32_lanes = -1
-    if check_skylake(target):
-        int32_lanes = 16
-    else:
-        return s
-    assert int32_lanes != -1
+    int32_lanes = 16
 
     oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1]
     _, _, _, _, ic_bn = get_const_tuple(data.shape)
@@ -310,13 +303,11 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
    packing of weight to make the address access be friendly to int8 intrinsic
    """
-    target = tvm.target.current_target(allow_none=False)
-    int32_lanes = -1
-    if check_skylake(target):
-        int32_lanes = 16
-    else:
-        return s
-    assert int32_lanes != -1
+    # FIXME - https://github.com/dmlc/tvm/issues/3598
+    # pylint: disable=unreachable
+    return s
+
+    int32_lanes = 16
 
     # assertion to fail the unhandled case
     _, _, _, ic_num = get_const_tuple(data.shape)
diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py
index 44867c9e33d5..2088eb0d693d 100644
--- a/topi/python/topi/x86/conv2d_avx_common.py
+++ b/topi/python/topi/x86/conv2d_avx_common.py
@@ -23,7 +23,6 @@
 from ..nn.util import infer_pad
 from ..util import get_const_tuple
 from .tensor_intrin import dot_16x1x16_int8_int8_int32
-from .check_targets import check_skylake
 from .util import get_fp32_len
 
 def _fallback_schedule(cfg, wkl):
@@ -186,19 +185,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
     More details - https://software.intel.com/en-us/articles/
     lower-numerical-precision-deep-learning-inference-and-training
     """
-
-    # Currently INT8 operations are supported for only Skylake
-    # In future the _intrin_reduce4int8 will be updated for VNNI instructions
-    # In case of unsupported target, the schedule will go to the original
-    # compute
-
-    target = tvm.target.current_target(allow_none=False)
-    int32_lanes = -1
-    if check_skylake(target):
-        int32_lanes = 16
-    else:
-        return s
-    assert int32_lanes != -1
+    int32_lanes = 16
 
     reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
     _, _, _, _, ic_bn = get_const_tuple(data.shape)
diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
index 6ed1b4aabc16..2de814c33fd3 100644
--- a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
+++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
@@ -24,6 +24,7 @@
 import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
+from nose.tools import nottest
 
 from common import get_all_backend
 
@@ -97,15 +98,22 @@ def check_device(device):
             func(a, w, c)
             tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
 
-    # for device in ["llvm -mcpu=skylake-avx512"]:
-    for device in ["llvm"]:
+    # for device in ["llvm"]:
+    for device in ["llvm -mcpu=skylake-avx512"]:
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
             check_device(device)
 
-
+@nottest
 def test_conv2d_NCHWc():
     # ResNet50 workloads
     verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3)
 
 if __name__ == "__main__":
-    test_conv2d_NCHWc()
+    # This test requires a Skylake or newer Intel machine to generate the
+    # correct instructions. It calls the topi operator directly, so the kernel
+    # shape must already match the target: older Intel generations need a 6D
+    # kernel, while this test uses a 7D kernel that only works on Skylake+.
+    # The test is therefore disabled.
+
+    # test_conv2d_NCHWc()
+    pass
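
Reviewer note on the layout change (not part of the patch): the chain of F.transpose/F.reshape calls added in _alter_conv2d_layout packs a 4D OIHW kernel into the 7D layout (out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn//n_elems, oc_bn, n_elems), which conv2d_NCHWc_int8 indexes as (oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner). The NumPy sketch below mirrors that chain and checks it against a direct reshape/transpose of the same kernel; the sizes (oc_bn = ic_bn = 16, n_elems = 4) are illustrative only, since in the patch they come from the AutoTVM config.

import numpy as np

# Illustrative sizes; in the patch oc_bn/ic_bn come from the AutoTVM config.
OC, IC, KH, KW = 32, 16, 3, 3
oc_bn, ic_bn, n_elems = 16, 16, 4

kernel = np.random.randint(-128, 127, size=(OC, IC, KH, KW), dtype=np.int8)

# Mirror of the F.transpose/F.reshape chain in _alter_conv2d_layout.
k_IHWO = kernel.transpose(1, 2, 3, 0)
k_IHWOo = k_IHWO.reshape(IC, KH, KW, OC // oc_bn, oc_bn)
k_OHWoI = k_IHWOo.transpose(3, 1, 2, 4, 0)
k_OHWoIi = k_OHWoI.reshape(OC // oc_bn, KH, KW, oc_bn, IC // ic_bn, ic_bn)
k_OHWoIie = k_OHWoIi.reshape(OC // oc_bn, KH, KW, oc_bn,
                             IC // ic_bn, ic_bn // n_elems, n_elems)
k_OIHWioe = k_OHWoIie.transpose(0, 4, 1, 2, 5, 3, 6)

# Direct packing: split OC into (oc_chunk, oc_block) and IC into
# (ic_outer, ic_f_inner, ic_s_inner), then order the axes the way the
# conv2d_NCHWc_int8 compute reads them.
expected = (kernel.reshape(OC // oc_bn, oc_bn,
                           IC // ic_bn, ic_bn // n_elems, n_elems, KH, KW)
            .transpose(0, 2, 5, 6, 3, 1, 4))

assert k_OIHWioe.shape == (OC // oc_bn, IC // ic_bn, KH, KW,
                           ic_bn // n_elems, oc_bn, n_elems)
np.testing.assert_array_equal(k_OIHWioe, expected)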