Skip to content

Commit

Permalink
[QNN] Add support for per channel weight scale in dense op (apache#4880)
Browse files Browse the repository at this point in the history
* add test case for per channel dense

* add unit arg in tflite frontend

* update qnn legalize test

* fix output dim index
  • Loading branch information
masahi authored and zhiics committed Mar 2, 2020
1 parent 8061a35 commit 2b3204c
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 55 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,13 +982,15 @@ def convert_fully_connected(self, op):

weight_value = self.get_tensor_value(weight_tensor)
weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
weight_shape = _infer_shape(weight_expr)

if input_tensor.qnn_params:
out = _qnn.op.dense(in_expr, weight_expr,
input_zero_point=input_tensor.qnn_params['zero_point'],
kernel_zero_point=weight_tensor.qnn_params['zero_point'],
input_scale=input_tensor.qnn_params['scale'],
kernel_scale=weight_tensor.qnn_params['scale'],
units=weight_shape[0],
out_dtype='int32')
else:
out = _op.nn.dense(in_expr, weight_expr)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def dense(data,
kernel_zero_point,
input_scale,
kernel_scale,
units=None,
units,
out_dtype="int32"):
"""Qnn Dense operator.
Applies a quantized linear transformation
Expand All @@ -371,7 +371,7 @@ def dense(data,
stored for access to this during relay. This information is not
needed in the pass pipeline after qnn.conv2d is lowered to the
sequence of steps as in nn.conv2d. See also input_scale in Requantize.
units : int, optional
units : int
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense can be int32 or int16.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point
CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[5], DataType::Float(32))); // kernel_scale
AssignType(types[5], DataType::Float(32), param->units, reporter);

CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";

Expand Down
75 changes: 23 additions & 52 deletions tests/python/relay/test_op_qnn_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,52 +75,8 @@ def make_configuration(quantized_data,
return config


def make_uint_configuration(use_bias=False, requantize_output=False):
input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
input_zero_point, kernel_zero_point = 127, 127
input_scale = 0.5
kernel_scale = 0.5
output_scale = 1.0
in_dtype = 'uint8'
out_dtype = 'int32' if not requantize_output else 'uint8'
units = 3
quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107,
129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \
.astype(in_dtype) \
.reshape(input_shape)
quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
129, 131, 133, 135, 137, 139, 141, 143, 145, 147,
129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \
.astype(in_dtype) \
.reshape(kernel_shape)
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None

if requantize_output:
assert use_bias
output = np.array([151, 152, 153, 185, 186, 187])
elif use_bias:
output = np.array([96, 100, 104, 232, 236, 240 ])
else:
output = np.array([92, 92, 92, 228, 228, 228 ])
output = output.astype(out_dtype).reshape(output_shape)
return make_configuration(quantized_data=quantized_data_np,
quantized_kernel=quantized_kernel_np,
dtype=in_dtype,
input_shape=input_shape,
kernel_shape=kernel_shape,
input_zero_point=input_zero_point,
kernel_zero_point=kernel_zero_point,
input_scale=input_scale,
kernel_scale= kernel_scale,
units=units,
output=output,
bias=bias,
requantize=requant_params)


def make_int_configuration(use_bias=False, requantize_output=False):
input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3)
def make_int_configuration(use_bias=False, requantize_output=False, per_channel=False):
input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3)
input_zero_point, kernel_zero_point = -1, -1
in_dtype = 'int8'
out_dtype = 'int32' if not requantize_output else 'int8'
Expand All @@ -138,15 +94,22 @@ def make_int_configuration(use_bias=False, requantize_output=False):
kernel_scale = 0.5
output_scale = 1.0
bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None
requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None

if requantize_output:
if per_channel:
assert use_bias and requantize_output
kernel_scale = np.array([0.5, 0.3, 0.4], dtype=np.float32)
output = np.array([23, 14, 20, 57, 34, 47])
elif requantize_output:
assert use_bias
output = np.array([23, 24, 25, 57, 58, 59])
elif use_bias:
output = np.array([96, 100, 104, 232, 236, 240 ])
output = np.array([96, 100, 104, 232, 236, 240])
else:
output = np.array([92, 92, 92, 228, 228, 228 ])
output = np.array([92, 92, 92, 228, 228, 228])

requant_params = make_requantize_params(input_scale * kernel_scale,
output_scale, -1, 'int8') if requantize_output else None

output = output.astype(out_dtype).reshape(output_shape)
return make_configuration(quantized_data=quantized_data_np,
quantized_kernel=quantized_kernel_np,
Expand Down Expand Up @@ -206,8 +169,8 @@ def qnn_dense_driver(test_configuration):
with relay.build_config(opt_level=2):
graph, lib, params = relay.build(mod, "llvm", params=None)
mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod.set_input(quantized_data_name,test_configuration[quantized_data_name])
mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name])
mod.set_input(quantized_data_name, test_configuration[quantized_data_name])
mod.set_input(quantized_kernel_name, test_configuration[quantized_kernel_name])
if test_configuration[bias_name] is not None:
mod.set_input(bias_name, test_configuration[bias_name])
mod.set_input(**params)
Expand Down Expand Up @@ -241,7 +204,15 @@ def test_qnn_dense_with_requantized_output():
qnn_dense_driver(int8_requantized_output_with_bias_params)


def test_per_channel_weight_scale():
with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense):
config = make_int_configuration(use_bias=True, requantize_output=True,
per_channel=True)
qnn_dense_driver(config)


if __name__ == "__main__":
test_qnn_dense_without_bias()
test_qnn_dense_with_bias()
test_qnn_dense_with_requantized_output()
test_per_channel_weight_scale()
1 change: 1 addition & 0 deletions tests/python/relay/test_pass_qnn_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _get_mod(data_dtype, kernel_dtype):
kernel_zero_point=relay.const(1, 'int32'),
input_scale=relay.const(1, 'float32'),
kernel_scale=relay.const(1, 'float32'),
units=kernel_shape[0],
out_dtype='int32')

mod = relay.Function(relay.analysis.free_vars(func), func)
Expand Down

0 comments on commit 2b3204c

Please sign in to comment.