diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 6b2e073822f11..fce7b78d2d0b8 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -22,10 +22,43 @@ from tvm import relay from .. import op as reg +################################################# +# Register the functions for different operators. +################################################# + # Registering QNN Conv2D legalization function. @reg.register_qnn_legalize("qnn.conv2d") def legalize_qnn_conv2d(attrs, inputs, types): - """Legalizes QNN conv2d op. + return qnn_conv2d_legalize(attrs, inputs, types) + +# Registering QNN dense legalization function. +@reg.register_qnn_legalize("qnn.dense") +def legalize_qnn_dense(attrs, inputs, types): + return qnn_dense_legalize(attrs, inputs, types) + +# Default to None. If overridden by target, this will not be run. +# Generic QNN Conv2D legalization function. +@tvm.target.generic_func +def qnn_conv2d_legalize(attrs, inputs, types): + """Default legalization is None.""" + return None + +# Generic QNN Conv2D legalization function. +@tvm.target.generic_func +def qnn_dense_legalize(attrs, inputs, types): + """Default legalization is None.""" + return None + +################### +# Helper functions. +################### + +# Helper function for lowering in the abscence of fast Int8 arithmetic units. +def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): + """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do + not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions + much more efficiently if the convolution or dense operator input datatypes are int16 instead of + int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277. Parameters ---------- @@ -41,19 +74,27 @@ def legalize_qnn_conv2d(attrs, inputs, types): result : tvm.relay.Expr The legalized expr """ - return qnn_conv2d_legalize(attrs, inputs, types) -# Generic QNN Conv2D legalization function. -@tvm.target.generic_func -def qnn_conv2d_legalize(attrs, inputs, types): - """Default legalization is None.""" - return None + # Collect the input exprs. + data, kernel = inputs -# Intel x86 QNN Conv2D legalization function. -@qnn_conv2d_legalize.register('cpu') -def _qnn_conv2d_legalize(attrs, inputs, types): - """Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good, - we dont transform. Else, we shift the tensor values and zero points to change the dtype. + input_zp = attrs['input_zero_point'] + kernel_zp = attrs['kernel_zero_point'] + + shift_data = relay.subtract(relay.cast(data, dtype='int16'), + relay.const(input_zp, 'int16')) + shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'), + relay.const(kernel_zp, 'int16')) + new_attrs = {k : attrs[k] for k in attrs.keys()} + del new_attrs['kernel_zero_point'] + del new_attrs['input_zero_point'] + return relay_op(shift_data, shift_kernel, **new_attrs) + +# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting. +def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): + """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes + are already good, we dont transform. Else, we shift the tensor values and zero points to change + the dtype. Converting from int8 to uint8 can be done in following manner. @@ -95,14 +136,6 @@ def _shift(data, out_dtype): data_modified = relay.cast(data_modified, out_dtype) return data_modified - def _is_int8_hw_support(target): - """ - Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake - and above. - """ - supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} - return supported_arches.intersection(set(target.options)) - # Collect the dtypes. data_dtype = types[0].dtype kernel_dtype = types[1].dtype @@ -110,11 +143,6 @@ def _is_int8_hw_support(target): # Collect the input exprs. data, kernel = inputs - # The VNNI transformations are applicable only Skylake and above.g - target = tvm.target.current_target(allow_none=False) - if not _is_int8_hw_support(target): - return None - # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied. if data_dtype == 'uint8' and kernel_dtype == 'int8': return None @@ -137,4 +165,120 @@ def _is_int8_hw_support(target): new_attrs = {k : attrs[k] for k in attrs.keys()} new_attrs['input_zero_point'] = input_zp new_attrs['kernel_zero_point'] = kernel_zp - return relay.qnn.op.conv2d(data, kernel, **new_attrs) + return relay_op(data, kernel, **new_attrs) + +# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting. +def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): + """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However, + many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms + conv2d/dense such that both the dtypes are same. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + + def _shift(data, out_dtype): + """Shifts (add/subtracts) the qnn tensor with +/-128)""" + if out_dtype == 'uint8': + shift = 128 + elif out_dtype == 'int8': + shift = -128 + else: + raise ValueError("Unsupport out dtype.") + data_modified = relay.cast(data, 'int32') + data_modified = relay.add(data_modified, relay.const(shift, 'int32')) + data_modified = relay.cast(data_modified, out_dtype) + return data_modified + + # Collect the dtypes. + data_dtype = types[0].dtype + kernel_dtype = types[1].dtype + + # Collect the input exprs. + data, kernel = inputs + + # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied. + if data_dtype == kernel_dtype: + return None + + assert 'int8' in data_dtype and 'int8' in kernel_dtype, \ + "Qnn Conv2D only accepts uint8 or int8 inputs" + + # Shift input if necessary. + input_zp = attrs['input_zero_point'] + data = _shift(data, kernel_dtype) + if data_dtype == 'int8': + input_zp = input_zp + 128 + elif data_dtype == 'uint8': + input_zp = input_zp - 128 + else: + raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs") + + # Call qnn.conv2d with modified inputs and zero points. + new_attrs = {k : attrs[k] for k in attrs.keys()} + new_attrs['input_zero_point'] = input_zp + return relay_op(data, kernel, **new_attrs) + +def is_fast_int8_hw_present(): + """ + Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake + and above. + """ + target = tvm.target.current_target(allow_none=False) + intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} + is_present_intel = intel_supported_arches.intersection(set(target.options)) + + arm_supported_attr = '+v8.2a,+dotprod' + is_present_arm = False + for opt in target.options: + if arm_supported_attr in opt: + is_present_arm = True + + return is_present_intel or is_present_arm + +######################## +# ARM CPU legalizations. +######################## + +@qnn_conv2d_legalize.register('arm_cpu') +def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): + # ARM likes the dtypes to be same. + if is_fast_int8_hw_present(): + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) + +@qnn_dense_legalize.register('arm_cpu') +def _qnn_dense_legalize_arm_cpu(attrs, inputs, types): + # ARM likes the dtypes to be same. + if is_fast_int8_hw_present(): + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) + +########################## +# Intel CPU legalizations. +########################## + +@qnn_conv2d_legalize.register('cpu') +def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types): + # The VNNI transformations prefer uint8 x int8 datatypes. + if is_fast_int8_hw_present(): + return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d) + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) + +@qnn_dense_legalize.register('cpu') +def _qnn_dense_legalize_intel_cpy(attrs, inputs, types): + # The VNNI transformations prefer uint8 x int8 datatypes. + if is_fast_int8_hw_present(): + return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense) + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 55c1fa6d3187b..8ace7bc745eae 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -23,6 +23,14 @@ from tvm.relay.qnn.op import register_qnn_legalize from tvm.relay import transform, analysis +def alpha_equal(x, y): + """ + Wrapper around alpha equality which ensures that + the hash function respects equality. + """ + x = x['main'] + y = y['main'] + return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -82,11 +90,11 @@ def expected(): b = run_opt_pass(expected(), transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + def test_qnn_legalize_qnn_conv2d(): - data_shape = (1, 64, 256, 256) - kernel_shape = (128, 64, 3, 3) - for dtype in ['uint8', 'int8']: - data_dtype = kernel_dtype = dtype + def _get_mod(data_dtype, kernel_dtype): + data_shape = (1, 64, 256, 256) + kernel_shape = (128, 64, 3, 3) data = relay.var("data", shape=data_shape, dtype=data_dtype) kernel = relay.var("kernel", shape=kernel_shape, @@ -104,12 +112,145 @@ def test_qnn_legalize_qnn_conv2d(): mod = relay.Function(relay.analysis.free_vars(func), func) mod = relay.Module.from_expr(mod) + return mod + + # Check uint8 x uint8 and int8 x int8 transformation + for dtype in ('uint8', 'int8'): + mod = _get_mod(dtype, dtype) + ############################################################# + # Check transformations for platforms with fast Int8 support. + ############################################################# + # Check that Intel VNNI gets picked up. with tvm.target.create('llvm -mcpu=skylake-avx512'): - mod = relay.qnn.transform.Legalize()(mod) + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext() + + # Since same dtype, there should not be any transformation + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert alpha_equal(mod, legalized_mod) + + ################################################################ + # Check transformations for platforms without fast Int8 support. + ################################################################ + # Older Intel versions. + with tvm.target.create('llvm'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Older ARM vesions. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Check uint8 x int8 transformation + mod = _get_mod('uint8', 'int8') + ############################################################# + # Check transformations for platforms with fast Int8 support. + ############################################################# + # Check no transformation for Intel VNNI. + with tvm.target.create('llvm -mcpu=skylake-avx512'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert alpha_equal(mod, legalized_mod) + + # ARM - so check that transformation has happened. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext() + + ################################################################ + # Check transformations for platforms without fast Int8 support. + ################################################################ + # Older Intel versions. + with tvm.target.create('llvm'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Older ARM vesions. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + +def test_qnn_legalize_qnn_dense(): + def _get_mod(data_dtype, kernel_dtype): + data_shape = (10, 3) + kernel_shape = (20, 3) + data = relay.var("data", shape=data_shape, + dtype=data_dtype) + kernel = relay.var("kernel", shape=kernel_shape, + dtype=kernel_dtype) + func = relay.qnn.op.dense( + data, kernel, + input_zero_point=1, + kernel_zero_point=1, + out_dtype='int32') + + mod = relay.Function(relay.analysis.free_vars(func), func) + mod = relay.Module.from_expr(mod) + return mod + + # Check uint8 x uint8 and int8 x int8 transformation + for dtype in ('uint8', 'int8'): + mod = _get_mod(dtype, dtype) + + ############################################################# + # Check transformations for platforms with fast Int8 support. + ############################################################# + # Check that Intel VNNI gets picked up. + with tvm.target.create('llvm -mcpu=skylake-avx512'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext() + + # Since same dtype, there should not be any transformation + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert alpha_equal(mod, legalized_mod) + + ################################################################ + # Check transformations for platforms without fast Int8 support. + ################################################################ + # Older Intel versions. + with tvm.target.create('llvm'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Older ARM vesions. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Check uint8 x int8 transformation + mod = _get_mod('uint8', 'int8') + ############################################################# + # Check transformations for platforms with fast Int8 support. + ############################################################# + # Check no transformation for Intel VNNI. + with tvm.target.create('llvm -mcpu=skylake-avx512'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert alpha_equal(mod, legalized_mod) + + # ARM - so check that transformation has happened. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext() + + ################################################################ + # Check transformations for platforms without fast Int8 support. + ################################################################ + # Older Intel versions. + with tvm.target.create('llvm'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() + + # Older ARM vesions. + with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'): + legalized_mod = relay.qnn.transform.Legalize()(mod) + assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext() - assert 'cast' in mod.astext() if __name__ == "__main__": test_qnn_legalize() test_qnn_legalize_qnn_conv2d() + test_qnn_legalize_qnn_dense()