MXNet pre-quantized BERT #6039

Merged
7 commits merged on Jul 20, 2020
13 changes: 13 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -75,6 +75,19 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
}
};

/*! \brief Attribute for dequantize operator */
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
int axis;

TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(axis)
.describe(
"The channel axis for channel wise dequantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
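The new `axis` field lets `qnn.dequantize` express channel-wise dequantization: every slice along the given axis is dequantized with its own scale and zero point. A minimal NumPy sketch of those semantics (illustrative only; the helper name and shapes are assumptions, not the TVM lowering):

```python
import numpy as np

def dequantize_per_channel(q, scales, zero_points, axis=-1):
    # q           : quantized int8/uint8 array
    # scales      : float32 array of shape (q.shape[axis],)
    # zero_points : int32 array of shape (q.shape[axis],)
    # Reshape the per-channel parameters so they broadcast along `axis`.
    shape = [1] * q.ndim
    shape[axis] = -1
    scales = np.asarray(scales, dtype=np.float32).reshape(shape)
    zero_points = np.asarray(zero_points, dtype=np.int32).reshape(shape)
    # float_out = (q - zero_point) * scale, applied channel by channel.
    return (q.astype(np.int32) - zero_points).astype(np.float32) * scales
```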
165 changes: 130 additions & 35 deletions python/tvm/relay/frontend/mxnet.py
@@ -1944,18 +1944,27 @@ def _qnn_batch_norm(inputs, attrs):

def _qnn_fully_connected(inputs, attrs, subgraphs, params):

def _get_input_scale_zp(_data, _inputs, _has_bias):
def _get_input_scale_zp(_data_dtype, _inputs, _has_bias):
data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3)
data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx]
data_dtype = _infer_type(_data).checked_type.dtype
_data_scale = get_mkldnn_uint8_scale(data_min, data_max) \
if data_dtype == 'uint8' \
if _data_dtype == 'uint8' \
else get_mkldnn_int8_scale(data_min, data_max)
_data_zp = 0
return _data_scale, _data_zp

def _get_kernel_scale_zp(_kernel, _inputs, _has_bias):
def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias):
kernel_dtype = _infer_type(_kernel).checked_type.dtype

if kernel_dtype != "int8":
raise tvm.error.OpNotImplemented(\
"Tensor wise quantized expects weights in int8 data type")

if isinstance(_kernel, tvm.relay.Call) and _kernel.op.name == "qnn.quantize":
_kernel_scale = _kernel.args[1].data.asnumpy()
_kernel_zp = _kernel.args[2].data.asnumpy()
return _kernel_scale, _kernel_zp

kernel_min_idx, kernel_max_idx = (5, 6) if _has_bias else (4, 5)
kernel_min_name = _get_name(_inputs[kernel_min_idx])
kernel_min = params[kernel_min_name].asnumpy()[0]
@@ -1967,7 +1976,34 @@ def _get_kernel_scale_zp(_kernel, _inputs, _has_bias):
_kernel_zp = 0
return _kernel_scale, _kernel_zp

def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale):
kernel_dtype = _infer_type(_kernel).checked_type.dtype
if kernel_dtype != "float32":
raise tvm.error.OpNotImplemented(\
"Channel wise quantized expects weights in float32 data type")

# Get the FP32 values, calculate min/max and then channel quantize them
np_kernel = _infer_value(_kernel, params).asnumpy()
kernel_channel_min = np.amin(np_kernel, axis=(1, ))
kernel_channel_max = np.amax(np_kernel, axis=(1, ))

np_bias = None
if _bias is not None:
np_bias = _infer_value(_bias, params).asnumpy()
return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel,
np_bias,
kernel_channel_min,
kernel_channel_max,
_data_scale)

def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
_bias = _inputs[2]
if isinstance(_bias, tvm.relay.Call) and _bias.op.name == "qnn.quantize":
_bias_scale = _bias.args[1].data.asnumpy()
_bias_requantize_scale = _bias_scale/(_data_scale * _kernel_scale)
_bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32")
return _bias_requantize_scale

bias_min_name = _get_name(_inputs[7])
bias_min = params[bias_min_name].asnumpy()[0]
bias_max_name = _get_name(_inputs[8])
@@ -1987,39 +2023,95 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
return res
else:
has_bias = not subgraph_dense_attrs.get_bool("no_bias", False)
# input
data = inputs[0]
data_scale, data_zp = _get_input_scale_zp(data, inputs, has_bias)
# kernel
kernel = inputs[1]
kernel_scale, kernel_zp = _get_kernel_scale_zp(kernel, inputs, has_bias)
units = subgraph_dense_attrs.get_int("num_hidden")
is_flatten = subgraph_dense_attrs.get_bool("flatten", True)
enable_float_output = attrs.get_bool('enable_float_output', False)
is_channel_quantized = attrs.get_bool('channel_wise_quantize', False)

########################
# Get data, kernel, bias
########################
data, kernel = inputs[0], inputs[1]
bias = None
if has_bias:
bias = inputs[2]

##############################
# Handle for shape of data > 2
##############################
if is_flatten:
data = _op.nn.batch_flatten(data)
data_shape = _infer_type(data).checked_type.shape
if len(data_shape) > 2:
data = _op.nn.batch_flatten(data)
data = _op.reverse_reshape(data, [-1, 0])

###############################
# Get data scale and zero point
###############################
data_dtype = _infer_type(data).checked_type.dtype
data_scale, data_zp = _get_input_scale_zp(data_dtype, inputs, has_bias)

#################################
# Get weight scale and zero point
#################################
if is_channel_quantized:
kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(kernel,
bias,
data_scale)
else:
kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(kernel, inputs,
has_bias)

################
# Call QNN dense
################
res = relay.qnn.op.dense(data,
kernel,
input_zero_point=relay.const(data_zp, 'int32'),
kernel_zero_point=relay.const(kernel_zp, 'int32'),
input_scale=relay.const(data_scale, 'float32'),
kernel_scale=relay.const(kernel_scale, 'float32'),
units=units)

#################
# Handle bias add
#################
if has_bias:
bias_data = inputs[2]
bias_requantize_scale = \
_get_bias_requantize_scale(inputs, data_scale, kernel_scale)
multiplied_bias = \
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
input_scale = np.float32(data_scale * kernel_scale)
if not enable_float_output:
if is_channel_quantized:
bias_scale = data_scale * kernel_scale
int32_bias = quantize_conv_bias_mkldnn_from_var(bias, bias_scale)
res = _op.nn.bias_add(res, int32_bias, axis=-1)
else:
bias_data = inputs[2]
bias_requantize_scale = \
_get_bias_requantize_scale(inputs, data_scale, kernel_scale)
multiplied_bias = \
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)

##############################################
# Dequantize if float32 output else Requantize
##############################################
if enable_float_output:
output_scale = np.float32(data_scale * kernel_scale)
res = relay.qnn.op.dequantize(res,
relay.const(output_scale),
input_zero_point=relay.const(0, 'int32'),
axis=1)
if with_relu:
res = _op.nn.relu(res)
else:

if is_channel_quantized:
raise tvm.error.OpNotImplemented(\
"Channel wise quantized dense with non float output is not supported yet")
out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
input_scale = np.float32(data_scale * kernel_scale)
min_output_range = attrs.get_float('min_calib_range')
max_output_range = attrs.get_float('max_calib_range')
output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range,
@@ -2034,17 +2126,20 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
out_dtype=out_dtype)
if with_relu:
res = _op.nn.relu(res)
return res, min_output_range, max_output_range
else:
output_scale = np.float32(data_scale * kernel_scale)
res = relay.qnn.op.dequantize(res,
relay.const(output_scale, 'float32'),
input_zero_point=relay.const(0, 'int32'))
if with_relu:
res = _op.nn.relu(res)
return res


##############################
# Handle for shape of data > 2
##############################
if len(data_shape) > 2:
new_shape = data_shape[:-1]
new_shape.append(units)
res = _op.reshape(res, new_shape)

if enable_float_output:
return res
return res, min_output_range, max_output_range

def _mx_broadcast_to(inputs, attrs):
data = inputs[0]
tgt_shape = attrs.get_int_tuple("shape", [])
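For the tensor-quantized path, the bias arrives at its own scale and has to be re-expressed at the `qnn.dense` accumulator scale (`data_scale * kernel_scale`) before `bias_add`. A NumPy sketch of the multiply/round/clip/cast sequence used above (the helper name is illustrative):

```python
import numpy as np

def requantize_bias(bias_q, bias_scale, data_scale, kernel_scale):
    # qnn.dense produces int32 accumulators at scale data_scale * kernel_scale,
    # so the quantized bias must be rescaled to match before bias_add.
    requantize_scale = bias_scale / (data_scale * kernel_scale)
    rescaled = bias_q.astype('float32') * requantize_scale
    rounded = np.round(rescaled)
    clipped = np.clip(rounded,
                      np.iinfo(np.int32).min,
                      np.iinfo(np.int32).max)
    return clipped.astype('int32')
```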
52 changes: 50 additions & 2 deletions python/tvm/relay/frontend/nnvm_common.py
@@ -17,10 +17,13 @@
# pylint: disable=invalid-name, import-self, len-as-condition
"""Utility functions common to NNVM and MxNet conversion."""
import warnings
from ... import error
from ...tir.op import min_value
from .. import expr as _expr
from .. import op as _op
from .common import get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape

def _warn_not_used(attr, op='nnvm'):
err = "{} is ignored in {}.".format(attr, op)
@@ -57,9 +60,54 @@ def _impl(inputs, attrs):
def _softmax_op(new_op):
"""softmax/log_softmax"""
def _impl(inputs, attrs, _dtype='float32'):
# TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
# assert len(inputs) == 1
axis = attrs.get_int("axis", -1)
use_length = attrs.get_bool("use_length", False)
if use_length:
# The second arg is valid_length. We can use sequence mask to mask the input before
# computing softmax
assert len(inputs) == 2

data = inputs[0]
length = inputs[1]
data_shape = _infer_shape(data)
data_dtype = _infer_type(data).checked_type.dtype
length_shape = _infer_shape(length)

if axis < 0:
axis = len(data_shape) + axis

data_ndims = len(data_shape)
length_ndims = len(length_shape)

# Sequence_mask supports axis = 0 and 1 and requires data to be in specific format.
if axis == data_ndims - 1 and data_ndims > 2 and length_ndims == 2:
new_batch_size = 1
for dim in range(length_ndims):
assert data_shape[dim] == length_shape[dim]
new_batch_size *= data_shape[dim]

# Reshape the data and length to satisfy sequence mask
data = _op.reshape(data, newshape=(new_batch_size, -1))
length = _op.reshape(length, newshape=(new_batch_size))

# Input data is now 2D, we can set the axis = 1
axis = 1
elif data_ndims > 2:
raise error.OpNotImplemented(\
"Operator softmax with use_length=True is supported only for axis -1")

res = _op.sequence_mask(data=data,
valid_length=length,
mask_value=float(min_value(data_dtype).value),
axis=axis)

# Apply softmax
res = new_op(res, axis=axis)

# Reshape back to input data shape
if len(data_shape) > 2:
return _op.reshape(res, newshape=data_shape)
return res
return new_op(inputs[0], axis=axis)
return _impl

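With `use_length=True`, the softmax input is first run through `sequence_mask` so that positions beyond `valid_length` receive the dtype's minimum value; after the exponential those positions contribute essentially zero probability. A NumPy sketch of that behavior for the 2-D case (illustrative, not the Relay lowering):

```python
import numpy as np

def masked_softmax_2d(data, valid_length):
    # data         : float32 array of shape (batch, seq_len)
    # valid_length : int array of shape (batch,), valid positions per row
    batch, seq_len = data.shape
    # Mask out positions >= valid_length with the dtype's minimum value,
    # mirroring sequence_mask(..., mask_value=min_value(dtype), axis=1).
    keep = np.arange(seq_len)[None, :] < valid_length[:, None]
    masked = np.where(keep, data, np.finfo(data.dtype).min)
    # Plain softmax along the last axis; masked entries underflow to ~0.
    shifted = masked - masked.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)
```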
8 changes: 6 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
@@ -121,7 +121,8 @@ def quantize(data,

def dequantize(data,
input_scale,
input_zero_point):
input_zero_point,
axis=-1):
r""" Dequantize op
This operator takes quantized int8 and uint8 as input and produces
dequantized float32 as output. The output shape is the same as input shape. The input
Expand All @@ -135,6 +136,8 @@ def dequantize(data,
The input zero_point.
input_scale : tvm.relay.Expr
The input scale.
axis : int
The channel axis for dequantization. Default value is -1 which corresponds to the last axis.
Returns
-------
result : tvm.relay.Expr
@@ -143,7 +146,8 @@

return _make.dequantize(data,
input_scale,
input_zero_point)
input_zero_point,
axis)


def concatenate(data,
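With the extra `axis` argument, a frontend can hand `dequantize` a vector of per-channel scales instead of a single scalar. A minimal usage sketch (assuming the lowering accepts a 1-D scale tensor along `axis`; the shapes and values are made up):

```python
import numpy as np
from tvm import relay

# int8 activation with 3 channels along axis=1.
data = relay.var("data", shape=(2, 3, 4), dtype="int8")
channel_scales = relay.const(np.array([0.5, 0.25, 0.125], dtype="float32"))
zero_point = relay.const(0, "int32")

# Each channel is dequantized with its own scale.
out = relay.qnn.op.dequantize(data, channel_scales, zero_point, axis=1)
func = relay.Function([data], out)
```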