QNN quantize and dequantize operators. #3745

Merged · 13 commits · Aug 16, 2019
addressing review comments.
shoubhikbhatti@gmail.com committed Aug 12, 2019
commit 9b86fe61d653f0f79a00109b7df0dd941f6ccb9e
42 changes: 27 additions & 15 deletions python/tvm/relay/qnn/op/qnn.py
@@ -74,48 +74,60 @@ def requantize(data,
                       out_dtype)


-def quantize(input_data, output_zero_point, output_scale, out_dtype='int8'):
+def quantize(input_data,
+             output_scale,
+             output_zero_point,
+             out_dtype='int8'):
r""" Quantize op
This operator takes float32 as input and produces quantized int8 or unit8 as output.
The input tensor can be of any shape. The output shape is the same as input shape.
..math::
\mbox{out}[x] =
\mbox{clamp(round(input_tensor/output_scale) + output_zero_point);
out_dtype::min, out_dtype::max}
Parameters
Q_output = clamp(round(input_tensor/output_scale) + output_zero_point), out_dtype::min, out_dtype::max)
Parameters
----------
    input_data : tvm.relay.Expr
        The input tensor to be quantized. Can be of type float32.
-    output_zero_point :
+    output_zero_point : int
        The output zero_point.
-    output_scale:
+    output_scale : float
        The output scale.
-    input_dtype:
+    out_dtype : str, optional
        The data type of the output tensor. Can be [int8, uint8]
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
-    return _make.quantize(input_data, output_zero_point, output_scale, out_dtype)
+    return _make.quantize(input_data,
+                          output_scale,
+                          output_zero_point,
+                          out_dtype)

-def dequantize(input_data, input_zero_point, input_scale):
+def dequantize(input_data,
+               input_scale,
+               input_zero_point):
    r""" Dequantize op
    This operator takes quantized int8 and uint8 as input and produces
    dequantized float32 as output. The output shape is the same as input shape. The input
    tensor can be of any shape.
    Parameters
----------
    input_data : tvm.relay.Expr
        The input tensor to be dequantized. Can be of type [int8, uint8].
-    input_zero_point :
+    input_zero_point : int
        The input zero_point.
-    input_scale:
+    input_scale : float
        The input scale.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
-    return _make.dequantize(input_data, input_zero_point, input_scale)
+    return _make.dequantize(input_data,
+                            input_scale,
+                            input_zero_point)
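
For reference, the math in the two docstrings above can be sketched in NumPy. This is only an illustrative reference, not the TVM lowering; `quantize_ref` and `dequantize_ref` are hypothetical helper names, not part of this PR:

```python
import numpy as np

def quantize_ref(x, output_scale, output_zero_point, out_dtype='int8'):
    # Q_output = clamp(round(x / scale) + zero_point, dtype_min, dtype_max)
    info = np.iinfo(out_dtype)
    q = np.round(x / output_scale) + output_zero_point
    return np.clip(q, info.min, info.max).astype(out_dtype)

def dequantize_ref(q, input_scale, input_zero_point):
    # real_value = scale * (quantized_value - zero_point)
    return input_scale * (q.astype(np.int32) - input_zero_point)

x = np.array([-63.5, -63.0, 0.0, 63.0, 64.0], dtype=np.float32)
q = quantize_ref(x, output_scale=0.5, output_zero_point=0)     # [-127 -126 0 126 127]
print(dequantize_ref(q, input_scale=0.5, input_zero_point=0))  # [-63.5 -63. 0. 63. 63.5]
```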
11 changes: 4 additions & 7 deletions src/relay/qnn/op/dequantize.cc
@@ -25,7 +25,6 @@
*/

#include <tvm/relay/analysis.h>
-//#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
@@ -53,11 +52,13 @@ bool DequantizeRel(const Array<Type>& types,
}

Expr MakeDequantize(Expr data,
-                    int32_t input_zero_point,
-                    double input_scale) {
+                    double input_scale,
+                    int32_t input_zero_point) {
auto attrs = make_node<DequantizeAttrs>();
attrs->input_scale = input_scale;
attrs->input_zero_point = input_zero_point;
+  // real_value = scale * (quantized_value - zero_point)
+  // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
static const Op& op = Op::Get("qnn.dequantize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
@@ -78,10 +79,6 @@ Expr DequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
CHECK(param != nullptr);

CHECK_EQ(arg_types.size(), 1);
-  auto input_dtype = arg_types[0];
-  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
-  CHECK(input_tensor_type != nullptr) << "Type information missing."
-                                      << " Please run infer_type pass.";
return DequantizeLower(data, param);
}
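
As a quick numeric check of the dequantization comment above (values chosen purely for illustration):

```python
# real_value = scale * (quantized_value - zero_point)
scale, zero_point, q = 0.5, 127, 200   # uint8 input value, illustrative parameters
real_value = scale * (q - zero_point)  # 0.5 * 73 = 36.5
```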

11 changes: 4 additions & 7 deletions src/relay/qnn/op/quantize_op.cc
@@ -25,7 +25,6 @@
*/

#include <tvm/relay/analysis.h>
-//#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
@@ -57,13 +56,14 @@ bool QuantizeRel(const Array<Type>& types,
}

Expr MakeQuantize(Expr data,
-                  int32_t output_zero_point,
-                  double output_scale,
+                  double output_scale,
+                  int32_t output_zero_point,
DataType out_dtype) {
auto attrs = make_node<QuantizeAttrs>();
attrs->output_scale = output_scale;
attrs->output_zero_point = output_zero_point;
attrs->out_dtype = std::move(out_dtype);
+  // quantized_output = clamp(round(input_tensor/output_scale) + output_zero_point, out_dtype::min, out_dtype::max)
static const Op& op = Op::Get("qnn.quantize");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
@@ -75,7 +75,8 @@ Expr QuantizeLower(const Expr& input_tensor, const QuantizeAttrs* param) {
const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32));
-  // we are trying to do - std::min(std::max(unclamped, min_val), max_val);
+  // result_quantized_value = result_zero_point + result_real_value / result_scale.
+  // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
auto add_zero_point = Add(scale_data, output_zero_point);
auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype);
@@ -90,10 +91,6 @@ Expr QuantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
CHECK(param != nullptr);

CHECK_EQ(arg_types.size(), 1);
-  auto input_dtype = arg_types[0];
-  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
-  CHECK(input_tensor_type != nullptr) << "Type information missing."
-                                      << " Please run infer_type pass.";
return QuantizeLower(data, param);
}
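
Step for step, the lowering in QuantizeLower corresponds to the following NumPy sketch (illustrative only; it assumes GetQmin/GetQmax return the integer range of out_dtype):

```python
import numpy as np

def quantize_lower_ref(input_tensor, scale, zero_point, out_dtype=np.int8):
    info = np.iinfo(out_dtype)
    min_val, max_val = info.min, info.max                         # GetQmin / GetQmax
    scale_data = np.round(input_tensor / scale).astype(np.int32)  # Cast(Round(Divide(...)), Int(32))
    add_zero_point = scale_data + zero_point                      # Add(scale_data, output_zero_point)
    clamped_output = np.clip(add_zero_point, min_val, max_val)    # Clip(add_zero_point, min_val, max_val)
    return clamped_output.astype(out_dtype)                       # Cast(clamped_output, out_dtype)
```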

2 changes: 1 addition & 1 deletion src/relay/qnn/op/requantize.cc
@@ -235,7 +235,7 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
auto out_dtype = param->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
<< "Output type should be one of [int8, uint8, int32] integer but was " << out_dtype;
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true;
}
4 changes: 2 additions & 2 deletions tests/python/relay/test_qnn_dequantize.py
@@ -27,8 +27,8 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
input_zero_point = quant_args['in_zero_point']
input_scale = quant_args['in_scale']
-    quantized_output = relay.qnn.op.dequantize(input_data, input_zero_point=input_zero_point,
-                                               input_scale=input_scale)
+    quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale,
+                                               input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.transform.Legalize()(mod)
4 changes: 2 additions & 2 deletions tests/python/relay/test_qnn_quantize.py
@@ -33,8 +33,8 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
output_zero_point = quant_args['out_zero_point']
output_scale = quant_args['out_scale']
-    quantized_output = relay.qnn.op.quantize(input_data, output_zero_point=output_zero_point,
-                                             output_scale=output_scale, out_dtype=out_dtype)
+    quantized_output = relay.qnn.op.quantize(input_data, output_scale=output_scale,
+                                             output_zero_point=output_zero_point, out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.transform.Legalize()(mod)
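
With the updated argument order, standalone calls mirroring both test drivers would look roughly as follows (shapes and quantization parameters below are made up for illustration):

```python
import numpy as np
from tvm import relay

in_data = np.array([[-64.0, -63.5, 0.0, 63.5, 64.0]], dtype='float32')
input_data = relay.var("input_data", shape=in_data.shape, dtype='float32')

# Scale now precedes zero_point in both ops, matching the revised signatures.
quantized = relay.qnn.op.quantize(input_data, output_scale=0.5,
                                  output_zero_point=0, out_dtype='int8')
dequantized = relay.qnn.op.dequantize(quantized, input_scale=0.5,
                                      input_zero_point=0)

mod = relay.Function(relay.analysis.free_vars(dequantized), dequantized)
mod = relay.Module.from_expr(mod)
mod = relay.transform.Legalize()(mod)
```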