Skip to content

Commit

Permalink
MXNet pre-quantized BERT
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Jul 11, 2020
1 parent c9c77c6 commit 9eb8b76
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 54 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
}
};

/*! \brief Attribute for dequantize operator */
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
int axis;

TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(axis)
.describe(
"The channel axis for channel wise dequantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
Expand Down
165 changes: 130 additions & 35 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,18 +1944,27 @@ def _qnn_batch_norm(inputs, attrs):

def _qnn_fully_connected(inputs, attrs, subgraphs, params):

def _get_input_scale_zp(_data, _inputs, _has_bias):
def _get_input_scale_zp(_data_dtype, _inputs, _has_bias):
data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3)
data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx]
data_dtype = _infer_type(_data).checked_type.dtype
_data_scale = get_mkldnn_uint8_scale(data_min, data_max) \
if data_dtype == 'uint8' \
if _data_dtype == 'uint8' \
else get_mkldnn_int8_scale(data_min, data_max)
_data_zp = 0
return _data_scale, _data_zp

def _get_kernel_scale_zp(_kernel, _inputs, _has_bias):
def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias):
kernel_dtype = _infer_type(_kernel).checked_type.dtype

if kernel_dtype != "int8":
raise tvm.error.OpNotImplemented(\
"Tensor wise quantized expects weights in int8 data type")

if isinstance(_kernel, tvm.relay.Call) and _kernel.op.name == "qnn.quantize":
_kernel_scale = _kernel.args[1].data.asnumpy()
_kernel_zp = _kernel.args[2].data.asnumpy()
return _kernel_scale, _kernel_zp

kernel_min_idx, kernel_max_idx = (5, 6) if _has_bias else (4, 5)
kernel_min_name = _get_name(_inputs[kernel_min_idx])
kernel_min = params[kernel_min_name].asnumpy()[0]
Expand All @@ -1967,7 +1976,34 @@ def _get_kernel_scale_zp(_kernel, _inputs, _has_bias):
_kernel_zp = 0
return _kernel_scale, _kernel_zp

def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale):
kernel_dtype = _infer_type(_kernel).checked_type.dtype
if kernel_dtype != "float32":
raise tvm.error.OpNotImplemented(\
"Channel wise quantized expects weights in float32 data type")

# Get the FP32 values, calculate min/max and then channel quantize them
np_kernel = _infer_value(_kernel, params).asnumpy()
kernel_channel_min = np.amin(np_kernel, axis=(1, ))
kernel_channel_max = np.amax(np_kernel, axis=(1, ))

np_bias = None
if _bias is not None:
np_bias = _infer_value(_bias, params).asnumpy()
return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel,
np_bias,
kernel_channel_min,
kernel_channel_max,
_data_scale)

def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
_bias = _inputs[2]
if isinstance(_bias, tvm.relay.Call) and _bias.op.name == "qnn.quantize":
_bias_scale = _bias.args[1].data.asnumpy()
_bias_requantize_scale = _bias_scale/(_data_scale * _kernel_scale)
_bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32")
return _bias_requantize_scale

bias_min_name = _get_name(_inputs[7])
bias_min = params[bias_min_name].asnumpy()[0]
bias_max_name = _get_name(_inputs[8])
Expand All @@ -1987,39 +2023,95 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
return res
else:
has_bias = not subgraph_dense_attrs.get_bool("no_bias", False)
# input
data = inputs[0]
data_scale, data_zp = _get_input_scale_zp(data, inputs, has_bias)
# kernel
kernel = inputs[1]
kernel_scale, kernel_zp = _get_kernel_scale_zp(kernel, inputs, has_bias)
units = subgraph_dense_attrs.get_int("num_hidden")
is_flatten = subgraph_dense_attrs.get_bool("flatten", True)
enable_float_output = attrs.get_bool('enable_float_output', False)
is_channel_quantized = attrs.get_bool('channel_wise_quantize', False)

########################
# Get data, kernel, bias
########################
data, kernel = inputs[0], inputs[1]
bias = None
if has_bias:
bias = inputs[2]

##############################
# Handle for shape of data > 2
##############################
if is_flatten:
data = _op.nn.batch_flatten(data)
data_shape = _infer_type(data).checked_type.shape
if len(data_shape) > 2:
data = _op.nn.batch_flatten(data)
data = _op.reverse_reshape(data, [-1, 0])

###############################
# Get data scale and zero point
###############################
data_dtype = _infer_type(data).checked_type.dtype
data_scale, data_zp = _get_input_scale_zp(data_dtype, inputs, has_bias)

#################################
# Get weight scale and zero point
#################################
if is_channel_quantized:
kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(kernel,
bias,
data_scale)
else:
kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(kernel, inputs,
has_bias)

################
# Call QNN dense
################
res = relay.qnn.op.dense(data,
kernel,
input_zero_point=relay.const(data_zp, 'int32'),
kernel_zero_point=relay.const(kernel_zp, 'int32'),
input_scale=relay.const(data_scale, 'float32'),
kernel_scale=relay.const(kernel_scale, 'float32'),
units=units)

#################
# Handle bias add
#################
if has_bias:
bias_data = inputs[2]
bias_requantize_scale = \
_get_bias_requantize_scale(inputs, data_scale, kernel_scale)
multiplied_bias = \
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
input_scale = np.float32(data_scale * kernel_scale)
if not enable_float_output:
if is_channel_quantized:
bias_scale = data_scale * kernel_scale
int32_bias = quantize_conv_bias_mkldnn_from_var(bias, bias_scale)
res = _op.nn.bias_add(res, int32_bias, axis=-1)
else:
bias_data = inputs[2]
bias_requantize_scale = \
_get_bias_requantize_scale(inputs, data_scale, kernel_scale)
multiplied_bias = \
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
a_min=tvm.tir.op.min_value('int32').value,
a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)

##############################################
# Dequantize if float32 output else Requantize
##############################################
if enable_float_output:
output_scale = np.float32(data_scale * kernel_scale)
res = relay.qnn.op.dequantize(res,
relay.const(output_scale),
input_zero_point=relay.const(0, 'int32'),
axis=1)
if with_relu:
res = _op.nn.relu(res)
else:

if is_channel_quantized:
raise tvm.error.OpNotImplemented(\
"Channel wise quantized dense with non float output is not supported yet")
out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8'
input_scale = np.float32(data_scale * kernel_scale)
min_output_range = attrs.get_float('min_calib_range')
max_output_range = attrs.get_float('max_calib_range')
output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range,
Expand All @@ -2034,17 +2126,20 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
out_dtype=out_dtype)
if with_relu:
res = _op.nn.relu(res)
return res, min_output_range, max_output_range
else:
output_scale = np.float32(data_scale * kernel_scale)
res = relay.qnn.op.dequantize(res,
relay.const(output_scale, 'float32'),
input_zero_point=relay.const(0, 'int32'))
if with_relu:
res = _op.nn.relu(res)
return res


##############################
# Handle for shape of data > 2
##############################
if len(data_shape) > 2:
new_shape = data_shape[:-1]
new_shape.append(units)
res = _op.reshape(res, new_shape)

if enable_float_output:
return res
return res, min_output_range, max_output_range

def _mx_broadcast_to(inputs, attrs):
data = inputs[0]
tgt_shape = attrs.get_int_tuple("shape", [])
Expand Down
51 changes: 49 additions & 2 deletions python/tvm/relay/frontend/nnvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
# pylint: disable=invalid-name, import-self, len-as-condition
"""Utility functions common to NNVM and MxNet conversion."""
import warnings
from ... import error
from ...tir.op import min_value
from .. import expr as _expr
from .. import op as _op
from .common import get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape

def _warn_not_used(attr, op='nnvm'):
err = "{} is ignored in {}.".format(attr, op)
Expand Down Expand Up @@ -57,9 +60,53 @@ def _impl(inputs, attrs):
def _softmax_op(new_op):
"""softmax/log_softmax"""
def _impl(inputs, attrs, _dtype='float32'):
# TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
# assert len(inputs) == 1
axis = attrs.get_int("axis", -1)
use_length = attrs.get_bool("use_length", False)
if use_length:
# The second arg is valid_length. We can use sequence mask to mask the input before
# computing softmax
assert len(inputs) == 2

data = inputs[0]
length = inputs[1]
data_shape = _infer_shape(data)
length_shape = _infer_shape(length)

if axis < 0:
axis = len(data_shape) + axis

data_ndims = len(data_shape)
length_ndims = len(length_shape)

# Sequence_mask supports axis = 0 and 1 and requires data to be in specific format.
if axis == data_ndims - 1 and data_ndims > 2 and length_ndims == 2:
new_batch_size = 1
for dim in range(length_ndims):
assert data_shape[dim] == length_shape[dim]
new_batch_size *= data_shape[dim]

# Reshape the data and length to satisfy sequence mask
data = _op.reshape(data, newshape=(new_batch_size, -1))
length = _op.reshape(length, newshape=(new_batch_size))

# Input data is now 2D, we can set the axis = 1
axis = 1
elif data_ndims > 2:
raise error.OpNotImplemented(\
"Operator softmax with use_length=True is supported only for axis -1")

res = _op.sequence_mask(data=data,
valid_length=length,
mask_value=float(min_value("float").value),
axis=axis)

# Apply softmax
res = new_op(res, axis=axis)

# Reshape back to input data shape
if len(data_shape) > 2:
return _op.reshape(res, newshape=data_shape)
return res
return new_op(inputs[0], axis=axis)
return _impl

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def quantize(data,

def dequantize(data,
input_scale,
input_zero_point):
input_zero_point,
axis=-1):
r""" Dequantize op
This operator takes quantized int8 and unit8 as input and produces
dequantized float32 as output. The output shape is the same as input shape. The input
Expand All @@ -135,6 +136,8 @@ def dequantize(data,
The input zero_point.
input_scale : tvm.relay.Expr
The input scale.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
Returns
-------
result : tvm.relay.Expr
Expand All @@ -143,7 +146,8 @@ def dequantize(data,

return _make.dequantize(data,
input_scale,
input_zero_point)
input_zero_point,
axis)


def concatenate(data,
Expand Down
Loading

0 comments on commit 9eb8b76

Please sign in to comment.