diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index e62334132ecc..21115d07241c 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -25,6 +25,10 @@ from .mxnet import from_mxnet from .mxnet_qnn_op_utils import dequantize_mxnet_min_max +from .mxnet_qnn_op_utils import quantize_mxnet_min_max +from .mxnet_qnn_op_utils import get_mkldnn_int8_scale +from .mxnet_qnn_op_utils import get_mkldnn_uint8_scale +from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras from .onnx import from_onnx from .tflite import from_tflite diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1f85277712aa..508439354fb1 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return +# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return, too-many-lines """MXNet symbol frontend.""" from __future__ import absolute_import as _abs import json +import numpy as np import tvm +from tvm import relay from topi.util import get_const_tuple from .. import analysis from .. import expr as _expr @@ -30,11 +32,23 @@ from .common import StrAttrsDict from .common import infer_type as _infer_type +from .common import infer_shape as _infer_shape +from .common import infer_value as _infer_value +from .common import get_name as _get_name from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast from .nnvm_common import _clip, _transpose, _upsampling from .nnvm_common import _elemwise_sum, _reshape from .nnvm_common import _warn_not_used +from .mxnet_qnn_op_utils import quantize_mxnet_min_max, \ + quantize_conv_weights_bias_channel_mkldnn_from_var, \ + quantize_conv_bias_mkldnn_from_var, \ + get_conv_mkldnn_requantized_scale_outDtype, \ + dequantize_mxnet_min_max, \ + get_mkldnn_int8_scale, \ + get_mkldnn_uint8_scale, \ + get_mkldnn_requantize_scale_outDtype + __all__ = ['from_mxnet'] @@ -44,8 +58,9 @@ "relu" : _op.nn.relu } + def _mx_fully_connected(inputs, attrs): - import mxnet as mx + import mxnet as mx #pylint: disable=import-outside-toplevel units = attrs.get_int("num_hidden") use_bias = not attrs.get_bool("no_bias", False) try: @@ -158,19 +173,13 @@ def _mx_conv1d(inputs, attrs): return res -def _mx_conv2d(inputs, attrs): +def _get_mx_conv2d_attrs(attrs): kernel_size = attrs.get_int_tuple("kernel") - if len(kernel_size) != 2: - raise tvm.error.OpAttributeInvalid( - 'Non 1D or 2D kernels are not supported for operator Convolution') data_layout = attrs.get_str("layout", "NCHW") - channel_axis = _get_channel_axis(data_layout, "conv2d") - if "kernel_layout" in attrs.attrs: kernel_layout = attrs.get_str("kernel_layout") else: kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW" - new_attrs = {} new_attrs["channels"] = attrs.get_int("num_filter") new_attrs["kernel_size"] = kernel_size @@ -180,6 +189,17 @@ def _mx_conv2d(inputs, attrs): new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout new_attrs["kernel_layout"] = kernel_layout + return new_attrs + +def _mx_conv2d(inputs, attrs): + kernel_size = attrs.get_int_tuple("kernel") + data_layout = attrs.get_str("layout", "NCHW") + if len(kernel_size) != 2: + raise tvm.error.OpAttributeInvalid( + 'Only 2D kernels are supported for operator Convolution') + + new_attrs = _get_mx_conv2d_attrs(attrs) + channel_axis = _get_channel_axis(data_layout, "conv2d") use_bias = not attrs.get_bool("no_bias", False) res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs) if use_bias: @@ -676,7 +696,8 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") + return _op.image.resize(inputs[0], size, + coordinate_transformation_mode="align_corners") def _mx_roi_pooling(inputs, attrs): new_attrs = {} @@ -1033,6 +1054,7 @@ def _mx_contrib_fifo_buffer(inputs, attrs): new_attrs['axis'] = attrs.get_int('axis') return _op.nn.fifo_buffer(*inputs, **new_attrs) + def _mx_cond(inputs, attrs, subgraphs): assert len(subgraphs) == 3 cond_input_locs = json.loads(attrs.get_str("cond_input_locs")) @@ -1075,6 +1097,582 @@ def _mx_cond(inputs, attrs, subgraphs): return ret +def _qnn_contrib_concat(inputs, attrs): + axis = attrs.get_int("dim", 1) + num_args = attrs.get_int("num_args", -1) + assert num_args > 0 + + input_exprs = inputs[0:num_args] + + min_start_idx = num_args + max_start_idx = num_args + 1 + + mins = list() + maxs = list() + + for i in range(min_start_idx, len(inputs), 2): + mins.append(inputs[i]) + + for i in range(max_start_idx, len(inputs), 2): + maxs.append(inputs[i]) + + # Check if all the input tensors have same qnn params. + if len(set(mins)) == 1 and len(set(maxs)) == 1: + output_min = mins[0] + output_max = maxs[0] + concat = _op.concatenate(tuple(input_exprs), axis=axis) + return concat, output_min, output_max + else: + # Get all dtypes. Find input and output scales, call concatenate. + dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs] + assert all([x == 'uint8' for x in dtypes]), \ + "Current suppor is limited to uint8 inputs only." + new_min = min(mins) + new_max = max(maxs) + assert new_min == 0 + + output_scale = get_mkldnn_uint8_scale(new_min, new_max) + min_max = zip(mins, maxs) + input_scales = [get_mkldnn_uint8_scale(x, y) for (x, y) in min_max] + input_zeros = [0] * len(input_scales) + output_zero = 0 + + input_scales_expr = [relay.const(x, 'float32') for x in input_scales] + input_zeros_expr = [relay.const(x, 'int32') for x in input_zeros] + + output_scale_expr = relay.const(output_scale, 'float32') + output_zero_expr = relay.const(output_zero, 'int32') + + res = relay.qnn.op.concatenate(input_exprs, input_scales_expr, input_zeros_expr, + output_scale_expr, output_zero_expr, axis=axis) + return res, new_min, new_max + + +def _qnn_quantize(inputs, attrs): + out_dtype = 'int8' + out_type = attrs.get_str('out_type') + if out_type == 'auto': + if attrs.has_attr('min_calib_range') and attrs.has_attr('max_calib_range'): + if attrs.get_float('min_calib_range') >= 0: + out_dtype = 'uint8' + else: + out_dtype = 'int8' + else: + out_dtype = out_type + if out_dtype not in {'int8', 'uint8'}: + raise ValueError('Unsupported out_dtype: %s' % out_dtype) + min_calib_range = attrs.get_float('min_calib_range', 0.0) + max_calib_range = attrs.get_float('max_calib_range', 0.0) + quantized_output, _, _ = \ + quantize_mxnet_min_max(inputs[0], + min_range=min_calib_range, + max_range=max_calib_range, + out_dtype=out_dtype) + return quantized_output, min_calib_range, max_calib_range + + +def _qnn_contrib_quantized_fifo_buffer(inputs, attrs, params): + data = inputs[0] + buffer = inputs[1] + min_calib_range = inputs[2] + max_calib_range = inputs[3] + data_dtype = _infer_type(data).checked_type.dtype + buffer_shape = _infer_shape(buffer) + buffer_name = _get_name(buffer) + params[buffer_name] = _nd.array(np.zeros(buffer_shape).astype(data_dtype)) + new_buffer = relay.var(buffer_name, relay.TensorType(buffer_shape, data_dtype)) + inputs[1] = new_buffer + res = _op.nn.fifo_buffer(data=data, buffer=new_buffer, axis=attrs.get_int('axis')) + return res, min_calib_range, max_calib_range + + +def _get_subgraph_op(subgraphs, op_name): + assert len(subgraphs) == 1, \ + "Subgraph should have 1 node but has {}".format(len(subgraphs)) + subgraph = subgraphs[0] + nodes = subgraph['nodes'] + assert nodes is not None + for node in nodes: + if node['op'] == op_name: + return node + raise ValueError("Op {} was not found in the subgraph".format(op_name)) + + +def _qnn_conv(inputs, attrs, subgraphs, params): + def _has_fused_activation(_attrs, _supported_activations): + has_fused_activation = False + if attrs.get_bool('with_act', False) or attrs.get_bool('with_postsum_act', False): + subgraph_activation_attrs = _get_subgraph_op(subgraphs, 'Activation')['attrs'] + act_type = subgraph_activation_attrs['act_type'] + if act_type not in _supported_activations: + raise ValueError('Fused activation {} is not supported at ' + 'this time'.format(act_type)) + has_fused_activation = True + return has_fused_activation + + def _get_data_scale_and_zp(_data, _inputs, + _data_min_idx, _data_max_idx): + """ Finds the Qnn params for the data expr. """ + data_min = _inputs[_data_min_idx] + data_max = _inputs[_data_max_idx] + data_dtype = _infer_type(_data).checked_type.dtype + assert data_dtype in {'int8', 'uint8'} + if data_min < 0.0: + assert data_dtype == 'int8', \ + "Expect int8 when data_min < 0.0, consider quantize model with int8." + _data_scale = get_mkldnn_uint8_scale(data_min, data_max)\ + if data_dtype == 'uint8' \ + else get_mkldnn_int8_scale(data_min, data_max) + _data_zero_point = 0 + return _data_scale, _data_zero_point + + def _get_bn_alpha_coeff(_bn_gamma_idx, _bn_beta_idx, + _bn_running_mean_idx, _bn_running_var_idx): + """ Extract the BN coeff. These will be use later for BN folding into convolution. """ + # Extract relevant attrs from bn. + bn_attrs = _get_subgraph_op(subgraphs, 'BatchNorm')['attrs'] + bn_epsilon_param = float(bn_attrs['eps']) + bn_scale_param = bn_attrs['fix_gamma'] == "False" + bn_center_param = True + + # Extract the relevant relay expressions. + bn_running_var = inputs[_bn_running_var_idx] + bn_gamma = inputs[_bn_gamma_idx] + bn_beta = inputs[_bn_beta_idx] + bn_running_mean = inputs[_bn_running_mean_idx] + + # Get coefficient to multiply to weights. + bn_epsilon = relay.const(bn_epsilon_param, "float32") + denominator = relay.sqrt(relay.add(bn_running_var, bn_epsilon)) + _bn_scale = relay.divide(relay.const(1.0, "float32"), denominator) + if bn_scale_param: + _bn_scale = relay.multiply(bn_gamma, _bn_scale) + + # Get the shift. + _bn_shift = relay.negative(relay.multiply(bn_running_mean, _bn_scale)) + if bn_center_param: + _bn_shift = relay.add(bn_beta, _bn_shift) + + return _bn_scale, _bn_shift + + def _fold_bn(_bn_scale, _bn_shift, _has_bias, _has_bn): + """ Fold BN into kernel and bias. Get new kernel and bias. """ + _kernel = inputs[1] + if _bn_scale: + assert attrs.get_bool('with_bn', False) + # Weights are on OIHW, and _bn_scale is in O. + exp_bn_scale = relay.expand_dims(_bn_scale, axis=1, num_newaxis=3) + _kernel = relay.multiply(exp_bn_scale, _kernel) + + _bias = None + if _has_bias: + _bias = inputs[2] + if _has_bn: + assert _bn_shift is not None + assert _bn_scale is not None + _bias = relay.add(relay.multiply(_bn_scale, _bias), _bn_shift) + elif _has_bn: + assert _bn_shift is not None + assert _bn_scale is not None + _bias = _bn_shift + return _kernel, _bias + + def _get_quantized_kernel(_kernel, _bias, _data_scale): + # For quantizing, we need min/max of kernel. So, we have to pre compute this expr. + np_kernel = _infer_value(_kernel, params).asnumpy() + kernel_channel_min = np.amin(np_kernel, axis=(1, 2, 3)) + kernel_channel_max = np.amax(np_kernel, axis=(1, 2, 3)) + + np_bias = None + if _bias is not None: + np_bias = _infer_value(_bias, params).asnumpy() + return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel, + np_bias, + kernel_channel_min, + kernel_channel_max, + _data_scale) + + def _get_qnn_conv2d(_data, _kernel, _data_zero_point, + _kernel_zero_point, _data_scale, + _kernel_vector_scale, _conv2d_attrs): + return relay.qnn.op.conv2d( + _data, + _kernel, + input_zero_point=relay.const(_data_zero_point, 'int32'), + kernel_zero_point=relay.const(_kernel_zero_point, 'int32'), + input_scale=relay.const(_data_scale, 'float32'), + kernel_scale=relay.const(_kernel_vector_scale), + channels=_conv2d_attrs['channels'], + groups=_conv2d_attrs['groups'], + kernel_size=_conv2d_attrs['kernel_size'], + strides=_conv2d_attrs['strides'], + dilation=_conv2d_attrs['dilation'], + padding=_conv2d_attrs['padding'], + data_layout=_conv2d_attrs['data_layout'], + kernel_layout=_conv2d_attrs['kernel_layout']) + + def _get_requantized_op(_res, _input_scale, _output_scale, _out_dtype): + # Requantize to get the output back + return relay.qnn.op.requantize( + _res, + input_scale=relay.const(_input_scale), + input_zero_point=relay.const(0, 'int32'), + output_scale=relay.const(_output_scale, 'float32'), + output_zero_point=relay.const(0, 'int32'), + axis=1, + out_dtype=_out_dtype) + + def _get_sum(_res, _output_scale, out_dtype): + """ Handles sum of the second quantized tensor. """ + # This is done in following steps + # 1) rhs is the add's second operand. First rhs will be requantized to output scale with + # dtype int32. The int32 dtype is to keep precision high before adding. + # 2) Call normal add + # 3) Depending on final out_dtype, clip and cast (basically requantize). + + _output_scale = relay.const(_output_scale, 'float32') + data_sum = inputs[-5] + data_sum_min = inputs[-2] + data_sum_max = inputs[-1] + + data_sum_dtype = _infer_type(data_sum).checked_type.dtype + data_sum_scale = \ + get_mkldnn_uint8_scale(data_sum_min, data_sum_max) if data_sum_dtype == 'uint8' \ + else get_mkldnn_int8_scale(data_sum_min, data_sum_max) + data_sum_scale = relay.const(data_sum_scale, 'float32') + zero_point = relay.const(0, 'int32') + + # Save one requantize if the previous expr already has a requantize node. This also improves + # little bit with accuracy. + if isinstance(data_sum, _expr.Call) and data_sum.op.name == "qnn.requantize": + prev_input, prev_scale, prev_zero_point = data_sum.args[0:3] + prev_axis = data_sum.attrs.axis + data_sum = relay.qnn.op.requantize(prev_input, + input_scale=prev_scale, + input_zero_point=prev_zero_point, + output_scale=_output_scale, + output_zero_point=zero_point, + axis=prev_axis, + out_dtype='int32') + else: + data_sum = relay.qnn.op.requantize(data_sum, + input_scale=data_sum_scale, + input_zero_point=zero_point, + output_scale=_output_scale, + output_zero_point=zero_point, + out_dtype='int32') + + # 2) Add two int32 tensors. + _res = relay.add(_res, data_sum) + + # 3) Clip/cast to change the out dtype. + _res = relay.clip(_res, + a_min=float(tvm.api.min_value(out_dtype).value), + a_max=float(tvm.api.max_value(out_dtype).value)) + _res = relay.cast(_res, out_dtype) + return _res + + def _parse(): + assert len(subgraphs) == 1 + subgraph_conv_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, 'Convolution')['attrs']) + + is_quantized = attrs.get_bool('quantized', False) + if is_quantized: + # The MKLDNN has a quantized convolution subgraph. There are many different arguments + # that are taken into account to parse the subgraph. + # * no_bias + # * with_sum + # * with_bn + # * with_postsum_relu + # * with_act + # + # Note - Relu/clip handling is not required because output min/max take care of that. + # + # The parsing can be broken down into following steps + # 1) Get the input data scale and zero points. + # 2) Extract BN params. + # 3) Fold the BN params into kernel and bias. + # 4) Quantize the kernel. + # 4) Call QNN conv2d op. + # 5) Quantize bias and call bias_add. + # 6) Handle sum of quantized tensors if needed. Or just Requantize. + + has_bias = not subgraph_conv_attrs.get_bool("no_bias", False) + has_sum = attrs.get_bool('with_sum', False) + has_bn = attrs.get_bool('with_bn', False) + + ############################################### + # 1) Get the input data scale and zero point. + ############################################### + # Last 2 indexes are data min and max. If the conv has a sum, then last 2 indexes are + # for the second tensor. So, the data min max indexes are last 3 and 4 + data_min_idx = -1 + data_max_idx = -2 + if has_sum: + data_min_idx = -4 + data_max_idx = -3 + + data = inputs[0] + data_scale, data_zero_point = \ + _get_data_scale_and_zp(data, inputs, data_min_idx, data_max_idx) + + + ############################# + # 2) Extract the BN params. + ############################# + # Find the indexes to look at for BN. + bn_scale = bn_shift = None + if has_bn: + if has_bias: + bn_start_idx = 3 + else: + bn_start_idx = 2 + + bn_gamma_idx = bn_start_idx + bn_beta_idx = bn_start_idx + 1 + bn_running_mean_idx = bn_start_idx + 2 + bn_running_var_idx = bn_start_idx + 3 + + bn_scale, bn_shift = _get_bn_alpha_coeff(bn_gamma_idx, + bn_beta_idx, + bn_running_mean_idx, + bn_running_var_idx) + + ######################################## + # 3) Fold the BN into kernel and bias. + ######################################## + kernel, bias = _fold_bn(bn_scale, bn_shift, has_bias, has_bn) + + ####################################################################### + # 4) Fold BN params into kernel. Get quantized kernel and QNN params. + ####################################################################### + kernel, kernel_vector_scale, kernel_zero_point = _get_quantized_kernel(kernel, bias, + data_scale) + + ########################## + # 5) Call QNN conv2d op. + ########################## + conv2d_attrs = _get_mx_conv2d_attrs(subgraph_conv_attrs) + res = _get_qnn_conv2d(data, kernel, data_zero_point, kernel_zero_point, data_scale, + kernel_vector_scale, conv2d_attrs) + + ############################################### + # 6) Fold BN params into bias. Call bias_add. + ############################################### + if has_bias or has_bn: + bias_scale = data_scale * kernel_vector_scale + int32_bias = quantize_conv_bias_mkldnn_from_var(bias, bias_scale) + res = _op.nn.bias_add(res, int32_bias, axis=1) + + ##################################################################### + # 7) Handle sum of quantized tensors if needed. Or just Requantize. + ##################################################################### + min_output_range = attrs.get_float('min_calib_range') + max_output_range = attrs.get_float('max_calib_range') + output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(min_output_range, + max_output_range) + + # QNN conv2d output scale is product of data_scale and kernel_vector_scale + input_scale = data_scale * kernel_vector_scale + if attrs.get_bool('with_sum', False): + # There is a second tensor that has to be added to the QNN conv2d output. Therefore, + # the QNN conv2d is first requantized to output scale with int32 precision. The + # second tensor will also be requantized to output scale with int32 precision, + # followed by an add operator. + res = _get_requantized_op(res, input_scale, output_scale, 'int32') + res = _get_sum(res, output_scale, out_dtype) + else: + # Get the requantized conv output + res = _get_requantized_op(res, input_scale, output_scale, out_dtype) + + return res, min_output_range, max_output_range + else: + res = _mx_conv(inputs, subgraph_conv_attrs) + has_fused_relu = _has_fused_activation(attrs, ['relu']) + if has_fused_relu: + res = _op.nn.relu(res) + return res + + return _parse() + + +def _qnn_flatten(inputs, attrs): + #pylint: disable=unused-argument + data = inputs[0] + output_min = inputs[1] + output_max = inputs[2] + output = _op.nn.batch_flatten(data) + return output, output_min, output_max + + +def _qnn_dequantize(inputs, attrs): + #pylint: disable=unused-argument + data = inputs[0] + input_min = inputs[1] + input_max = inputs[2] + in_dtype = _infer_type(data).checked_type.dtype + result = dequantize_mxnet_min_max(data, input_min, input_max, in_dtype) + return result + + +def _qnn_activation(inputs, attrs): + act_type = attrs.get_str("act_type") + assert len(inputs) == 3 + assert act_type == "relu", "Currently only relu is supported" + data = inputs[0] + range_min = inputs[1] + range_max = inputs[2] + res = _op.nn.relu(data) + return res, range_min, range_max + + +def _qnn_pooling(inputs, attrs): + input_min = inputs[1] + input_max = inputs[2] + data = inputs[0] + data_dtype = _infer_type(data).checked_type.dtype + pool_type = attrs.get_str("pool_type") + if data_dtype in ('int8', 'uint8') and pool_type != 'max': + data = _op.cast(data, 'int32') + res = _mx_pooling([data, input_min, input_max], attrs) + if data_dtype in ('int8', 'uint8') and pool_type != 'max': + res = _op.cast(res, data_dtype) + return res, input_min, input_max + + +def _qnn_batch_norm(inputs, attrs): + # Perform batch norm in FP32 + data = inputs[0] + + # Dequantize the data. + data_min_idx, data_max_idx = (-2, -1) + data_min, data_max = inputs[data_min_idx], inputs[data_max_idx] + data_dtype = _infer_type(data).checked_type.dtype + data_scale = get_mkldnn_uint8_scale(data_min, data_max) if data_dtype == 'uint8' \ + else get_mkldnn_int8_scale(data_min, data_max) + data_zp = 0 + data = relay.qnn.op.dequantize(data, + relay.const(data_scale, 'float32'), + relay.const(data_zp, 'int32')) + + # Run BN. The last 4 inputs are same as before. + new_inputs = [data, *inputs[1:5]] + res = _mx_batch_norm(new_inputs, attrs) + + # Quantize the result + min_output_range = attrs.get_float('min_calib_range') + max_output_range = attrs.get_float('max_calib_range') + output_scale, out_dtype = get_conv_mkldnn_requantized_scale_outDtype(min_output_range, + max_output_range) + res = relay.qnn.op.quantize(res[0], + relay.const(output_scale, 'float32'), + relay.const(0, 'int32'), + out_dtype=out_dtype) + return res, min_output_range, max_output_range + + +def _qnn_fully_connected(inputs, attrs, subgraphs, params): + + def _get_input_scale_zp(_data, _inputs, _has_bias): + data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3) + data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx] + data_dtype = _infer_type(_data).checked_type.dtype + _data_scale = get_mkldnn_uint8_scale(data_min, data_max) \ + if data_dtype == 'uint8' \ + else get_mkldnn_int8_scale(data_min, data_max) + _data_zp = 0 + return _data_scale, _data_zp + + def _get_kernel_scale_zp(_kernel, _inputs, _has_bias): + kernel_dtype = _infer_type(_kernel).checked_type.dtype + kernel_min_idx, kernel_max_idx = (5, 6) if _has_bias else (4, 5) + kernel_min_name = _get_name(_inputs[kernel_min_idx]) + kernel_min = params[kernel_min_name].asnumpy()[0] + kernel_max_name = _get_name(_inputs[kernel_max_idx]) + kernel_max = params[kernel_max_name].asnumpy()[0] + _kernel_scale = get_mkldnn_uint8_scale(kernel_min, kernel_max) \ + if kernel_dtype == 'uint8' \ + else get_mkldnn_int8_scale(kernel_min, kernel_max) + _kernel_zp = 0 + return _kernel_scale, _kernel_zp + + def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): + bias_min_name = _get_name(_inputs[7]) + bias_min = params[bias_min_name].asnumpy()[0] + bias_max_name = _get_name(_inputs[8]) + bias_max = params[bias_max_name].asnumpy()[0] + bias_scale = get_mkldnn_int8_scale(bias_min, bias_max) + _bias_requantize_scale = bias_scale/(_data_scale * _kernel_scale) + _bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32") + return _bias_requantize_scale + + is_quantized = attrs.get_bool('quantized', False) + with_relu = attrs.get_bool('with_relu', False) + subgraph_dense_attrs = StrAttrsDict(_get_subgraph_op(subgraphs, "FullyConnected")['attrs']) + if not is_quantized: + res = _mx_fully_connected(inputs, subgraph_dense_attrs) + if with_relu: + res = _op.nn.relu(res) + return res + else: + has_bias = not subgraph_dense_attrs.get_bool("no_bias", False) + # input + data = inputs[0] + data_scale, data_zp = _get_input_scale_zp(data, inputs, has_bias) + # kernel + kernel = inputs[1] + kernel_scale, kernel_zp = _get_kernel_scale_zp(kernel, inputs, has_bias) + units = subgraph_dense_attrs.get_int("num_hidden") + data_shape = _infer_type(data).checked_type.shape + if len(data_shape) > 2: + data = _op.nn.batch_flatten(data) + res = relay.qnn.op.dense(data, + kernel, + input_zero_point=relay.const(data_zp, 'int32'), + kernel_zero_point=relay.const(kernel_zp, 'int32'), + input_scale=relay.const(data_scale, 'float32'), + kernel_scale=relay.const(kernel_scale, 'float32'), + units=units) + if has_bias: + bias_data = inputs[2] + bias_requantize_scale = \ + _get_bias_requantize_scale(inputs, data_scale, kernel_scale) + multiplied_bias = \ + _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale) + rounded_bias = _op.round(multiplied_bias) + clipped_bias = _op.clip(rounded_bias, + a_min=tvm.api.min_value('int32').value, + a_max=tvm.api.max_value('int32').value) + requantized_bias = _op.cast(clipped_bias, 'int32') + res = _op.nn.bias_add(res, requantized_bias, axis=-1) + enable_float_output = attrs.get_bool('enable_float_output', False) + out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8' + input_scale = np.float32(data_scale * kernel_scale) + if not enable_float_output: + min_output_range = attrs.get_float('min_calib_range') + max_output_range = attrs.get_float('max_calib_range') + output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range, + max_output_range, + out_dtype) + res = relay.qnn.op.requantize( + res, + input_scale=relay.const(input_scale, 'float32'), + input_zero_point=relay.const(0, 'int32'), + output_scale=relay.const(output_scale, 'float32'), + output_zero_point=relay.const(0, 'int32'), + out_dtype=out_dtype) + if with_relu: + res = _op.nn.relu(res) + return res, min_output_range, max_output_range + else: + output_scale = np.float32(data_scale * kernel_scale) + res = relay.qnn.op.dequantize(res, + relay.const(output_scale, 'float32'), + input_zero_point=relay.const(0, 'int32')) + if with_relu: + res = _op.nn.relu(res) + return res + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -1249,14 +1847,44 @@ def _mx_cond(inputs, attrs, subgraphs): # TODO(tvm-tvm): support all operators. # # "broadcast_to", - "contrib_fifo_buffer" : _mx_contrib_fifo_buffer, + # "contrib_fifo_buffer": _mx_contrib_fifo_buffer, + "ring_buffer": _mx_contrib_fifo_buffer, + # Qnn ops + "_contrib_quantize_v2": _qnn_quantize, + "_contrib_quantized_concat" : _qnn_contrib_concat, + # "_contrib_quantized_fifo_buffer": _qnn_contrib_quantized_fifo_buffer, + "_contrib_quantized_ring_buffer": _qnn_contrib_quantized_fifo_buffer, + "_sg_mkldnn_conv": _qnn_conv, + "_contrib_quantized_flatten": _qnn_flatten, + "_contrib_dequantize": _qnn_dequantize, + "_contrib_quantized_act": _qnn_activation, + "_contrib_quantized_pooling": _qnn_pooling, + "_contrib_quantized_batch_norm" : _qnn_batch_norm, + "_sg_mkldnn_fully_connected": _qnn_fully_connected, } # set identity list -_convert_map.update({k : _rename(k) for k in _identity_list}) +_convert_map.update({k: _rename(k) for k in _identity_list}) + +_control_flow_ops = ['_cond', '_foreach', '_while_loop'] +_qnn_subgraph_ops = ['_sg_mkldnn_conv', '_sg_mkldnn_fully_connected'] +_subgraph_ops = _control_flow_ops + _qnn_subgraph_ops +_params_ops = ['_contrib_quantized_ring_buffer'] -def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): +def _get_op_params(children, attrs, op_name, node, params): + op_params = [children, attrs] + if op_name in _subgraph_ops: + subgraphs = node['subgraphs'] + op_params.append(subgraphs) + if op_name in _qnn_subgraph_ops: + op_params.append(params) + if op_name in _params_ops: + op_params.append(params) + return op_params + + +def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): #pylint: disable=unused-argument """Convert mxnet symbol to compatible relay Function. @@ -1314,11 +1942,9 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): shape_idx += 1 node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)] elif op_name in _convert_map: - if op_name in ['_cond', '_foreach', '_while_loop']: - subgraphs = node['subgraphs'] - res = _convert_map[op_name](children, attrs, subgraphs) - else: - res = _convert_map[op_name](children, attrs) + op_params = _get_op_params(children, attrs, op_name, + node, params) + res = _convert_map[op_name](*op_params) if res is None: # defer conversion, used in RNN state initialization res = [node] @@ -1390,7 +2016,7 @@ def from_mxnet(symbol, The parameter dict to be used by nnvm """ try: - import mxnet as mx + import mxnet as mx #pylint: disable=import-outside-toplevel except ImportError as e: raise ImportError("{}. MXNet is required to parse symbols.".format(e)) @@ -1404,7 +2030,7 @@ def from_mxnet(symbol, for k, v in aux_params.items(): params[k] = _nd.array(v.asnumpy()) shape, dtype = _update_shape_dtype(shape, dtype, params) - func = _from_mxnet_impl(symbol, shape, dtype, mod) + func = _from_mxnet_impl(symbol, shape, dtype, params, mod) elif isinstance(symbol, mx.gluon.HybridBlock): if arg_params is not None or aux_params is not None: raise ValueError("arg_params and aux_params ae not used when importing HybridBlock") @@ -1418,7 +2044,7 @@ def from_mxnet(symbol, if isinstance(sym, (list, tuple)): sym = mx.sym.Group(sym) shape, dtype = _update_shape_dtype(shape, dtype, params) - func = _from_mxnet_impl(sym, shape, dtype, mod) + func = _from_mxnet_impl(sym, shape, dtype, params, mod) elif isinstance(symbol, mx.gluon.Block): raise NotImplementedError("Only Hybrid Blocks are supported now.") else: diff --git a/python/tvm/relay/frontend/mxnet_qnn_op_utils.py b/python/tvm/relay/frontend/mxnet_qnn_op_utils.py index 73d18a4f3394..a8836ff0270a 100644 --- a/python/tvm/relay/frontend/mxnet_qnn_op_utils.py +++ b/python/tvm/relay/frontend/mxnet_qnn_op_utils.py @@ -21,31 +21,73 @@ import numpy as np from tvm import relay -from tvm.relay.qnn.op.qnn import dequantize +from tvm.relay.qnn.op.qnn import quantize, dequantize -zero_centered_uint8_quantized_range = np.float32(255) -zero_centered_int8_quantized_range = np.float32(127) +# The below values are taken from - +# https://github.com/apache/incubator-mxnet/blob/master/src/operator/quantization/quantization_utils.h#L38-L39 +zero_centered_uint8_quantized_range = np.float32(255.5) +zero_centered_int8_quantized_range = np.float32(127.5) -def _dequantize_zero_centered(data, - data_min, - data_max, - quantized_range): - r"""Dequantizes the given data tensor by calculating the scale - using the MKLDNN formula `max(abs(data_min, data_max))/quantized_range`. +def _get_mkldnn_scale(data_min, + data_max, + quantized_range): + """Computes the scale as per MKLDNN specification mentioned here - + https://intel.github.io/mkl-dnn/ex_int8_simplenet.html + + Parameters + ---------- + data_min : float32 + A number representing the lower end of the tensor to be quantized. + data_max : float32 + A number representing the upper end of the tensor to be quantized. + quantized_range : float32 + 255 for uint8 and 127 for int8. This is the data type range. + + Returns + ------- + scale : A floating point number which acts as the scale for quantization. + """ + real_range = np.max([np.abs(np.float32(data_min)), + np.abs(np.float32(data_max))]) + scale = np.divide(quantized_range, real_range) + scale_inverse = np.divide(1.0, scale) + return scale_inverse + + +def _quantize_scale_with_zero_centered(data, + scale, + zero_point, + out_dtype): + quantized_output = quantize(data, + relay.const(scale, 'float32'), + relay.const(zero_point, 'int32'), + out_dtype=out_dtype) + return quantized_output, scale, zero_point + + +def _quantize_with_zero_centered(data, + data_min, + data_max, + quantized_range, + out_dtype): + """Quantizes the given data tensor by calculating the scale + using the MKLDNN formula `quantized_range / max(abs(data_min, data_max))`. Where quantized_range is 255 for uint8 and 127 for int8. The `data_min` and `data_max` are the min and max to use for the `data` tensor elements. Parameters ---------- data : tvm.relay.Expr - The input tensor to be quantized. Can be of type {int8 or uint8}. + The input tensor to be quantized. Can be of type float32. data_min : float The minimum to use data elements. data_max : float The maximum to use for data elements. quantized_range : float 255 for uint8 and 127 for int8. This is the data type range. + out_dtype : str + The output data type. Can be int8 or uint8 Returns ------- @@ -53,20 +95,23 @@ def _dequantize_zero_centered(data, The computed result. """ - real_range = np.max([np.abs(np.float32(data_min)), - np.abs(np.float32(data_max))]) - scale = relay.const(np.divide(real_range, quantized_range), 'float32') - zero_point = relay.const(0, 'int32') - return dequantize(data, scale, zero_point) - - -def _dequantize_mkldnn_min_max_int8(data, - imin_range, - imax_range): - r"""Dequantizes the given `data` in {int8 or uint8} and the given - min and max ranges and the output data type is `float32`. - The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w. - We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67 + scale = _get_mkldnn_scale(data_min, + data_max, + quantized_range) + zero_point = 0 + return _quantize_scale_with_zero_centered(data, + scale, + zero_point, + out_dtype) + + +def _quantize_mkldnn_min_max_uint8(data, + data_min, + data_max): + """Quantizes the given `data` in float32 and the given + min and max ranges and the output data type is `uint8`. + The method of quantizing is described here - https://tinyurl.com/y5k6fz5w. + We use our default quantize implementation from src/relay/qnn/op/quantize.cc:72 but compute the `scale` and `zero_point` to fit our equation. Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN stores the min and max from which we calculate the scale and zero_point. @@ -85,20 +130,20 @@ def _dequantize_mkldnn_min_max_int8(data, result : tvm.relay.Expr The computed result. """ - - return _dequantize_zero_centered(data, - data_min=imin_range, - data_max=imax_range, - quantized_range=zero_centered_int8_quantized_range) - - -def _dequantize_mkldnn_min_max_uint8(data, - imin_range, - imax_range): - r"""Dequantizes the given `data` in {int8 or uint8} and the given - min and max ranges and the output data type is `float32`. - The method of dequantize is described here - https://tinyurl.com/y5k6fz5w. - We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67 + return _quantize_with_zero_centered(data, + data_min, + data_max, + zero_centered_uint8_quantized_range, + 'uint8') + + +def _quantize_mkldnn_min_max_int8(data, + data_min, + data_max): + """Quantizes the given `data` in float32 and the given + min and max ranges and the output data type is `int8`. + The method of quantizing is described here - https://tinyurl.com/y5k6fz5w. + We use our default quantize implementation from src/relay/qnn/op/quantize.cc:72 but compute the `scale` and `zero_point` to fit our equation. Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN stores the min and max from which we calculate the scale and zero_point. @@ -107,9 +152,9 @@ def _dequantize_mkldnn_min_max_uint8(data, ---------- data : tvm.relay.Expr The input tensor to be quantized. Can be of type float32. - imin_range : float + data_min : float The minimum to use data elements. - imax_range : float + data_max : float The maximum to use for data elements. Returns @@ -118,21 +163,235 @@ def _dequantize_mkldnn_min_max_uint8(data, The computed result. """ - return _dequantize_zero_centered(data, - data_min=imin_range, - data_max=imax_range, - quantized_range=zero_centered_uint8_quantized_range) + return _quantize_with_zero_centered(data, + data_min, + data_max, + zero_centered_int8_quantized_range, + 'int8') + + +def get_mkldnn_int8_scale(range_min, + range_max): + """Computes the quantization scale using MKLDNN specifications + with the given range. The output datatype of tensor to be quantized should be + int8. + + Parameters + ---------- + range_min : float32 + A number representing the lower end of the tensor to be quantized. + range_max : float32 + A number representing the upper end of the tensor to be quantized. + + Returns + ------- + scale : A float32 number which acts as the scale for quantization. + """ + + scale = _get_mkldnn_scale(range_min, + range_max, + zero_centered_int8_quantized_range) + return np.float32(scale) + + +def get_mkldnn_uint8_scale(range_min, + range_max): + """Computes the quantization scale using MKLDNN specifications + with the given range. The output datatype of tensor to be quantized should be + uint8. + + Parameters + ---------- + range_min : float32 + A number representing the lower end of the tensor to be quantized. + range_max : float32 + A number representing the upper end of the tensor to be quantized. + + Returns + ------- + scale : A float32 number which acts as the scale for quantization. + """ + + scale = _get_mkldnn_scale(range_min, + range_max, + zero_centered_uint8_quantized_range) + return np.float32(scale) + + +def quantize_conv_weights_bias_channel_mkldnn_from_var(weights_var, + bias, + min_vector_range, + max_vector_range, + data_scale): + """Helper method to quantize the convolution kernel in prequantized model + in MXNet with MKLDNN. The kernel is always quantized to int8 output datatype. + The inputs are the raw weights which are floating point numbers. The min and + max ranges are used from the weight itself. The name supplied is used to create + a tvm.relay.var with the given name. + + Parameters + ---------- + weights_var : tvm.relay.var + The float32 representation of the weights. + bias : np.array + The float32 np array for bias. + min_vector_range : array of float32 + A number representing the minimum of the weights per channel. + max_vector_range : array of float32 + A number representing the maximum of the weights per channel. + data_scale : float + The data scale value. + Returns + ------- + result : tvm.relay.expr + The quantized representation of the weights. + """ + + quantized_range = zero_centered_int8_quantized_range + real_vector_range = np.maximum(np.absolute(min_vector_range), + np.absolute(max_vector_range)) + # If real_vector_range is 0, then to avoid division by 0 in scaling, + # make real_vector INT32_max + vector_scale = np.where(real_vector_range == 0, + 1./float(np.iinfo(np.int32).max), + np.divide(real_vector_range, quantized_range)) + + # Handle bias impact on scales as done by MxNet-MKLDNN. + if bias is not None: + common = 2.0 * bias.astype('float32') * (1/data_scale) + vector_scale_min = np.where(bias > 0, + common/float(np.iinfo(np.int32).max), + common/float(np.iinfo(np.int32).min)) + vector_scale = np.maximum(vector_scale, vector_scale_min) + + zero_point = 0 + quantized_output = quantize(weights_var, + relay.const(vector_scale), + relay.const(zero_point, 'int32'), + axis=0, + out_dtype='int8') + return quantized_output, vector_scale, zero_point + + +def get_mkldnn_requantize_scale_outDtype(min_output_range, + max_output_range, + out_dtype): + quantized_out_range = zero_centered_int8_quantized_range if out_dtype == 'int8' \ + else zero_centered_uint8_quantized_range + out_range = np.max([np.abs(np.float32(min_output_range)), + np.abs(np.float32(max_output_range))]) + output_scale = quantized_out_range / out_range + requantize_scale = np.float32(1/output_scale) + return requantize_scale + + +def get_conv_mkldnn_requantized_scale_outDtype(min_output_range, max_output_range): + out_dtype = 'uint8' if min_output_range >= 0.0 else 'int8' + requantize_scale = get_mkldnn_requantize_scale_outDtype(min_output_range, + max_output_range, + out_dtype) + return requantize_scale, out_dtype + + +def quantize_conv_bias_mkldnn_from_var(bias_var, + bias_scale): + zero_point = 0 + quantized_bias = quantize(data=bias_var, + output_scale=relay.const(bias_scale), + output_zero_point=relay.const(zero_point, 'int32'), + axis=0, + out_dtype='int32') + + return quantized_bias + + +def quantize_mxnet_min_max(data, + min_range, + max_range, + out_dtype='int8'): + """Quantizes the given `data` in float32 and the given + min and max ranges and the output data type. + Only `int8` and `uint8` is supported as output data types. + The input data type is expected to be `float32`. + Mxnet has two different flavors for quantization 1) Default 2)MKLDNN. + To get the second one Mxnet must be built with MKLDNN during compile time. + Users can choose either of the implementation for TVM runtime. + The main difference between the two implementation is that MKLDNN is centered + around 0 and the default implementation for uint8 is not. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor to be quantized. Can be of type float32. + min_range : float + The minimum to use data elements. + max_range : float + The maximum to use for data elements. + out_dtype: str, optional + The output data type, can be 'int8' or 'uint8' -def _dequantize_mxnet_min_max_int8(data, - imin_range, - imax_range): - r"""Deuantizes the given `data` in {int8 or uint8} and the given + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + if out_dtype == 'uint8': + return _quantize_mkldnn_min_max_uint8(data, + min_range, + max_range) + elif out_dtype == 'int8': + return _quantize_mkldnn_min_max_int8(data, + min_range, + max_range) + else: + raise ValueError( + "Expected out_dtype to be int8 or uint8 but was %s" % out_dtype) + + +def _dequantize_zero_centered(data, + data_min, + data_max, + quantized_range): + """Dequantizes the given data tensor by calculating the scale + using the MKLDNN formula `max(abs(data_min, data_max))/quantized_range`. + Where quantized_range is 255 for uint8 and 127 for int8. The `data_min` + and `data_max` are the min and max to use for the `data` tensor elements. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor to be quantized. Can be of type {int8 or uint8}. + data_min : float + The minimum to use data elements. + data_max : float + The maximum to use for data elements. + quantized_range : float + 255 for uint8 and 127 for int8. This is the data type range. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + real_range = np.max([np.abs(np.float32(data_min)), + np.abs(np.float32(data_max))]) + scale = relay.const(np.divide(real_range, quantized_range), 'float32') + zero_point = relay.const(0, 'int32') + return dequantize(data, scale, zero_point) + + +def _dequantize_mkldnn_min_max_int8(data, + imin_range, + imax_range): + """Dequantizes the given `data` in {int8 or uint8} and the given min and max ranges and the output data type is `float32`. - The method of dequantization is described here - https://tinyurl.com/y4d7hrzf. - We use our default dequantize implementation from src/relay/qnn/op/dequantize.cc:67 + The method of dequantizing is described here - https://tinyurl.com/y5k6fz5w. + We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67 but compute the `scale` and `zero_point` to fit our equation. - Unlike in TFLite where we get the scale and zero_point from the model, Mxnet + Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN stores the min and max from which we calculate the scale and zero_point. Parameters @@ -156,15 +415,15 @@ def _dequantize_mxnet_min_max_int8(data, quantized_range=zero_centered_int8_quantized_range) -def _dequantize_mxnet_min_max_uint8(data, - imin_range, - imax_range): - r"""Dequantizes the given `data` in {int8 or uint8} and the given +def _dequantize_mkldnn_min_max_uint8(data, + imin_range, + imax_range): + """Dequantizes the given `data` in {int8 or uint8} and the given min and max ranges and the output data type is `float32`. - The method of dequantizing is described here - https://tinyurl.com/y4d7hrzf. + The method of dequantize is described here - https://tinyurl.com/y5k6fz5w. We use our default quantize implementation from src/relay/qnn/op/dequantize.cc:67 but compute the `scale` and `zero_point` to fit our equation. - Unlike in TFLite where we get the scale and zero_point from the model, Mxnet + Unlike in TFLite where we get the scale and zero_point from the model, MKLDNN stores the min and max from which we calculate the scale and zero_point. Parameters @@ -182,25 +441,17 @@ def _dequantize_mxnet_min_max_uint8(data, The computed result. """ - iinfo = np.iinfo(np.uint8) - min_limit = np.float64(iinfo.min) - max_limit = np.float64(iinfo.max) - imin_range = np.float64(imin_range) - imax_range = np.float64(imax_range) - scale_val = np.divide((imax_range - imin_range), - (max_limit - min_limit)) - zero_point_val = np.int(-1 * np.divide(imin_range, scale_val)) - scale = relay.const(scale_val, 'float32') - zero_point = relay.const(zero_point_val, 'int32') - return dequantize(data, scale, zero_point) + return _dequantize_zero_centered(data, + data_min=imin_range, + data_max=imax_range, + quantized_range=zero_centered_uint8_quantized_range) def dequantize_mxnet_min_max(data, min_range, max_range, - in_dtype='int8', - use_mkldnn=False): - r"""Dequantizes the given `data` in {int8 or uint8} and the given + in_dtype='int8'): + """Dequantizes the given `data` in {int8 or uint8} and the given min and max ranges. The output data type is float32. Only `float32` is supported as output data types. The input data type is expected to be {int8 or uint8}. @@ -220,9 +471,6 @@ def dequantize_mxnet_min_max(data, The maximum to use for data elements for the output. in_dtype: str, optional The input data type, can be 'int8' or 'uint8' - use_mkldnn: bool, optional - If True then uses MKLDNN quantization implementation otherwise - will use default implementation. Returns ------- @@ -231,19 +479,13 @@ def dequantize_mxnet_min_max(data, """ if in_dtype == 'uint8': - if use_mkldnn: - return _dequantize_mkldnn_min_max_uint8(data, - min_range, - max_range) - else: - return _dequantize_mxnet_min_max_uint8(data, - min_range, - max_range) + return _dequantize_mkldnn_min_max_uint8(data, + min_range, + max_range) elif in_dtype == 'int8': - if use_mkldnn: - return _dequantize_mkldnn_min_max_int8(data, min_range, max_range) - else: - return _dequantize_mxnet_min_max_int8(data, min_range, max_range) + return _dequantize_mkldnn_min_max_int8(data, + min_range, + max_range) else: raise ValueError( "Expected out_dtype to be int8 or uint8 but was %s" % in_dtype) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 7381a0728567..8a6ceb81f263 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -988,4 +988,4 @@ def verify(a_np, b_np): test_forward_one_hot() test_forward_convolution() test_forward_deconvolution() - test_forward_cond() + test_forward_cond() \ No newline at end of file diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py index 78c9692ea5b3..0c7374d4d8a7 100644 --- a/tests/python/frontend/mxnet/test_qnn_ops_utils.py +++ b/tests/python/frontend/mxnet/test_qnn_ops_utils.py @@ -21,21 +21,20 @@ from tvm.contrib import graph_runtime -def test_mxnet_dequantize_op(): +def test_mkldnn_dequantize(): - def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): + def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): shape = in_data.shape input_data = relay.var("input_data", shape=shape, dtype=in_dtype) min_range = quant_args['min_range'] max_range = quant_args['max_range'] - quantized_output = \ + dequantized_output = \ relay.frontend.dequantize_mxnet_min_max(input_data, min_range=min_range, max_range=max_range, in_dtype=in_dtype) - mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) + mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -43,56 +42,55 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): rt_mod.set_input(**params) rt_mod.run() res = rt_mod.get_output(0).asnumpy() - assert np.allclose(res, verify_output_data, ) + assert np.allclose(res, verify_output_data) assert res.dtype == np.float32 def test_uint8_to_float32(): data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ .astype('uint8') \ .reshape((2, 5)) - output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ + output = np.array([0., 0.25048923, 0.50097847, 0.7514677, 1.0019569, 62.8728, 63.123287, + 63.373775, 63.624268, 63.874756]) \ .astype('float32') \ .reshape((2, 5)) quant_args = {"min_range": -63.5, "max_range": 64} - quantize_test_driver(in_dtype='uint8', - quant_args=quant_args, - in_data=data, - verify_output_data=output) + dequantize_test_driver(in_dtype='uint8', + quant_args=quant_args, + in_data=data, + verify_output_data=output) def test_int8_to_float32(): data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \ .astype('int8') \ .reshape((2, 5)) - output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316, - 61.984253, 62.48819, 62.992126, 63.496063, 64.]) \ + output = np.array([-63.247063, -62.745102, -62.24314, -61.74118, -61.23922, + 61.74118, 62.24314, 62.745102, 63.247063, 63.749023]) \ .astype('float32') \ .reshape((2, 5)) - quant_args = {"min_range": -63.5, "max_range": 64} - quantize_test_driver(in_dtype='int8', - quant_args=quant_args, - in_data=data, - verify_output_data=output) + dequantize_args = {"min_range": -63.5, "max_range": 64} + dequantize_test_driver(in_dtype='int8', + quant_args=dequantize_args, + in_data=data, + verify_output_data=output) test_uint8_to_float32() test_int8_to_float32() -def test_mkldnn_dequantize_op(): +def test_mkldnn_quantize(): - def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): + def quantize_test_driver(out_dtype, quant_args, in_data, verify_output_data): shape = in_data.shape - input_data = relay.var("input_data", shape=shape, dtype=in_dtype) + input_data = relay.var("input_data", shape=shape, dtype='float32') min_range = quant_args['min_range'] max_range = quant_args['max_range'] - quantized_output = \ - relay.frontend.dequantize_mxnet_min_max(input_data, - min_range=min_range, - max_range=max_range, - in_dtype=in_dtype, - use_mkldnn=True) + quantized_output, _, _ = \ + relay.frontend.quantize_mxnet_min_max(input_data, + min_range=min_range, + max_range=max_range, + out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -100,43 +98,76 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): rt_mod.set_input(**params) rt_mod.run() res = rt_mod.get_output(0).asnumpy() - # print(res) - # np.testing.assert_equal(res, verify_output_data) - assert np.allclose(res, verify_output_data, ) - assert res.dtype == np.float32 + assert np.allclose(res, verify_output_data) + assert res.dtype == verify_output_data.dtype - def test_uint8_to_float32(): - data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ - .astype('uint8') \ - .reshape((2, 5)) - output = np.array([0., 0.2509804, 0.5019608, 0.75294125, 1.0039216, - 62.996082, 63.247063, 63.498043, 63.749023, 64.]) \ + def test_float32_to_uint8(): + data = np.array([0., 0.25048923, 0.50097847, 0.7514677, 1.0019569, 62.8728, 63.123287, + 63.373775, 63.624268, 63.874756]) \ .astype('float32') \ .reshape((2, 5)) + output = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ + .astype('uint8') \ + .reshape((2, 5)) + quant_args = {"min_range": -63.5, "max_range": 64} - quantize_test_driver(in_dtype='uint8', + quantize_test_driver(out_dtype='uint8', quant_args=quant_args, in_data=data, verify_output_data=output) - def test_int8_to_float32(): - data = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \ - .astype('int8') \ - .reshape((2, 5)) - output = np.array([-63.496063, -62.992126, -62.48819, -61.984253, -61.480316, - 61.984253, 62.48819, 62.992126, 63.496063, 64.]) \ + def test_float32_to_int8(): + data = np.array([-63.247063, -62.745102, -62.24314, -61.74118, -61.23922, + 61.74118, 62.24314, 62.745102, 63.247063, 63.749023]) \ .astype('float32') \ .reshape((2, 5)) + output = np.array([-126, -125, -124, -123, -122, 123, 124, 125, 126, 127]) \ + .astype('int8') \ + .reshape((2, 5)) + quant_args = {"min_range": -63.5, "max_range": 64} - quantize_test_driver(in_dtype='int8', + quantize_test_driver(out_dtype='int8', quant_args=quant_args, in_data=data, verify_output_data=output) - test_uint8_to_float32() - test_int8_to_float32() + test_float32_to_uint8() + test_float32_to_int8() + + +def test_get_mkldnn_int8_scale(): + range_min = -3.904039 + range_max = 3.904039 + expected = 0.03061991354976495 + output = relay.frontend.get_mkldnn_int8_scale(range_max=range_max, + range_min=range_min) + assert np.allclose(output, expected) + + +def test_get_mkldnn_uint8_scale(): + range_min = 0.0 + range_max = 55.77269 + expected = 0.21828841189047482 + output = relay.frontend.get_mkldnn_uint8_scale(range_max=range_max, + range_min=range_min) + assert np.allclose(output, expected) + + +def test_quantize_conv_bias_mkldnn_from_var(): + bias_var = relay.var('bias', shape=(3,), dtype='float32') + bias_scale = tvm.nd.array(np.array([0.5, 0.6, 0.7])) + output = relay.frontend.quantize_conv_bias_mkldnn_from_var(bias_var, bias_scale) + assert isinstance(output, tvm.relay.expr.Call) + attrs = output.attrs + assert attrs.axis == 0 + assert attrs.out_dtype == 'int32' + assert output.op.name == 'qnn.quantize' + assert output.args[1].data == bias_scale if __name__ == "__main__": - test_mxnet_dequantize_op() - test_mkldnn_dequantize_op() + test_mkldnn_dequantize() + test_mkldnn_quantize() + test_get_mkldnn_int8_scale() + test_get_mkldnn_uint8_scale() + test_quantize_conv_bias_mkldnn_from_var() \ No newline at end of file