Skip to content

Commit

Permalink
{QNN] Making scale/zero_points as expr instead of attrs.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Jan 3, 2020
1 parent 3f43bee commit dd63f3d
Show file tree
Hide file tree
Showing 27 changed files with 756 additions and 776 deletions.
176 changes: 0 additions & 176 deletions include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,10 @@ namespace qnn {

/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
std::string rounding;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(rounding).set_default("UPWARD")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
Expand All @@ -67,175 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {

/*! \brief Attribute for quantize operator */
struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
int32_t output_zero_point;
double output_scale;
DataType out_dtype;

TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, can be one of [int8 or uint8].");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the activation of this op.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale for the activation of this op.");
}
};

/*! \brief Attribute for dequantize operator */
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
int32_t input_zero_point;
double input_scale;

TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero_point for the input tensor of this op.");
TVM_ATTR_FIELD(input_scale)
.describe("The scale for the input tensor of this op.");
}
};

/*! \brief Attributes used in QNN concatenate operator */
struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
Array<tvm::Expr> input_scales;
Array<tvm::Expr> input_zero_points;
double output_scale;
int32_t output_zero_point;
int axis;

TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
TVM_ATTR_FIELD(input_scales)
.describe("The list of scales of input quantized tensors.");
TVM_ATTR_FIELD(input_zero_points)
.describe("The list of zero points of input quantized tensors.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the output tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale for the output tensor.");
TVM_ATTR_FIELD(axis)
.describe("The axis at which the input arrays are concatenated."
"Should lie in range `[-ndim, ndim)`.")
.set_default(0);
}
}; // struct QnnConcatenateAttrs

/*! \brief Attribute for QNN Conv2d operator */
struct QnnConv2DAttrs : public tvm::AttrsNode<QnnConv2DAttrs> {
// Traditional conv2d attributes.
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;

// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;
// The input tensor scale and kernel tensor scales are stored
// for easy access to this information.
double input_scale;
double kernel_scale;

TVM_DECLARE_ATTRS(QnnConv2DAttrs, "relay.attrs.QnnConv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
TVM_ATTR_FIELD(input_scale)
.describe("The quantization scale for the input tensor.");
TVM_ATTR_FIELD(kernel_scale)
.describe("The quantization scale for the weight tensor.");
}
};

/*! \brief Attribute for QNN binary operator */
struct QnnBinaryOpAttrs : public tvm::AttrsNode<QnnBinaryOpAttrs> {
int32_t lhs_zero_point;
double lhs_scale;
int32_t rhs_zero_point;
double rhs_scale;
int32_t output_zero_point;
double output_scale;

TVM_DECLARE_ATTRS(QnnBinaryOpAttrs, "relay.attrs.QnnBinaryOpAttrs") {
TVM_ATTR_FIELD(lhs_zero_point)
.describe("The zero_point for the lhs input tensor of this op.");
TVM_ATTR_FIELD(lhs_scale)
.describe("The scale for the lhs input tensor of this op.");
TVM_ATTR_FIELD(rhs_zero_point)
.describe("The zero_point for the rhs input tensor of this op.");
TVM_ATTR_FIELD(rhs_scale)
.describe("The scale for the rhs input tensor of this op.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the activation of this op.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale for the activation of this op.");
}
};

/*! \brief Attributes for qnn dense operator */
struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
IndexExpr units;
DataType out_dtype;
// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;
double input_scale;
double kernel_scale;

TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.QnnDenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
TVM_ATTR_FIELD(input_scale)
.describe("The input tensor scale.");
TVM_ATTR_FIELD(kernel_scale)
.describe("The kernel tensor scale.");
}
};

Expand Down
13 changes: 8 additions & 5 deletions python/tvm/relay/frontend/mxnet_qnn_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""

import numpy as np
from tvm import relay
from tvm.relay.qnn.op.qnn import dequantize

zero_centered_uint8_quantized_range = np.float32(255)
Expand Down Expand Up @@ -54,8 +55,8 @@ def _dequantize_zero_centered(data,

real_range = np.max([np.abs(np.float32(data_min)),
np.abs(np.float32(data_max))])
scale = np.divide(real_range, quantized_range)
zero_point = 0
scale = relay.const(np.divide(real_range, quantized_range), 'float32')
zero_point = relay.const(0, 'int32')
return dequantize(data, scale, zero_point)


Expand Down Expand Up @@ -186,9 +187,11 @@ def _dequantize_mxnet_min_max_uint8(data,
max_limit = np.float64(iinfo.max)
imin_range = np.float64(imin_range)
imax_range = np.float64(imax_range)
scale = np.divide((imax_range - imin_range),
(max_limit - min_limit))
zero_point = np.int(-1 * np.divide(imin_range, scale))
scale_val = np.divide((imax_range - imin_range),
(max_limit - min_limit))
zero_point_val = np.int(-1 * np.divide(imin_range, scale_val))
scale = relay.const(scale_val, 'float32')
zero_point = relay.const(zero_point_val, 'int32')
return dequantize(data, scale, zero_point)


Expand Down
50 changes: 34 additions & 16 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import math
import numpy as np
import tvm
from tvm import relay
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from .. import qnn as _qnn
from ..util import get_scalar_from_constant
from ... import nd as _nd
from .common import ExprTable
from .common import infer_shape as _infer_shape
Expand Down Expand Up @@ -177,8 +179,8 @@ def get_tensors(self, tensors_idx_list):
# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
qnn_params = dict()
qnn_params['scale'] = scale
qnn_params['zero_point'] = zero_point
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list

Expand Down Expand Up @@ -225,8 +227,16 @@ def get_tensor_type_str(self, tensor_type):
.format(str(tensor_type)))

def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \
lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point']
lhs_scale = lhs_tensor.qnn_params['scale']
rhs_scale = rhs_tensor.qnn_params['scale']
lhs_zero_point = lhs_tensor.qnn_params['zero_point']
rhs_zero_point = rhs_tensor.qnn_params['zero_point']
lhs_scale_value = get_scalar_from_constant(lhs_scale)
rhs_scale_value = get_scalar_from_constant(rhs_scale)
lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
return lhs_scale_value == rhs_scale_value and \
lhs_zero_point_value == rhs_zero_point_value

def is_quantized(self, op):
"""Check if an input tensor is quantized."""
Expand Down Expand Up @@ -748,13 +758,11 @@ def convert_fully_connected(self, op):
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

if input_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale']
kernel_scale = weight_tensor.qnn_params['scale']
out = _qnn.op.dense(in_expr, weight_expr,
input_zero_point=input_tensor.qnn_params['zero_point'],
kernel_zero_point=weight_tensor.qnn_params['zero_point'],
input_scale=input_scale,
kernel_scale=kernel_scale,
input_scale=input_tensor.qnn_params['scale'],
kernel_scale=weight_tensor.qnn_params['scale'],
out_dtype='int32')
else:
out = _op.nn.dense(in_expr, weight_expr)
Expand All @@ -781,11 +789,16 @@ def convert_fully_connected(self, op):

# Finally if the dense is quantized. Add a requantize at the end.
if output_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
input_zero_point = 0
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
out = _qnn.op.requantize(out,
input_scale=input_scale,
input_zero_point=input_zero_point,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
Expand Down Expand Up @@ -987,11 +1000,16 @@ def convert_conv(self, op, conv_type):

# Finally if the conv is quantized. Add a requantize at the end.
if output_tensor.qnn_params:
input_scale = input_tensor.qnn_params['scale'] * weight_tensor.qnn_params['scale']
input_zero_point = 0
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
out = _qnn.op.requantize(out,
input_scale=input_scale,
input_zero_point=input_zero_point,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/qnn/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@
from .qnn import *
from .op import register_qnn_legalize
from . import legalizations
from . import op_attrs
Loading

0 comments on commit dd63f3d

Please sign in to comment.