[QNN] Channel wise quantization - Quantize & Requantize (apache#4629)
anijain2305 authored and alexwong committed Feb 26, 2020
1 parent cc329d0 commit 4907ef9
Showing 9 changed files with 413 additions and 44 deletions.
10 changes: 10 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -33,10 +33,15 @@ namespace qnn {

/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
int axis;
std::string rounding;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(axis)
.describe("The output channel axis for channel wise quantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
TVM_ATTR_FIELD(rounding).set_default("UPWARD")
.describe("Defines the rounding direction when the value is midway between"
"two representable values. There are two supported modes - UPWARD"
@@ -56,10 +61,15 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
/*! \brief Attribute for quantize operator */
struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
DataType out_dtype;
int axis;

TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, can be one of [int8 or uint8].");
TVM_ATTR_FIELD(axis)
.describe("The output channel axis for channel wise quantization. Default value is -1,"
"which corresponds to the last axis.")
.set_default(-1);
}
};

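Both structs now take the same axis attribute, following the usual negative-axis convention: -1 resolves to the last dimension, and the per-channel scale and zero-point vectors must carry one element per channel along that axis. A minimal Python sketch of that convention (helper names are hypothetical; the logic mirrors the checks added in QuantizeRel and RequantizeRel below):

def normalize_axis(axis, ndim):
    # Mirrors the C++ lowering: axis == -1 maps to the last axis.
    return ndim - 1 if axis == -1 else axis

def check_channel_params(data_shape, scales, axis):
    # Per-channel quantization expects one scale per channel along `axis`.
    axis = normalize_axis(axis, len(data_shape))
    assert 0 <= axis < len(data_shape), "axis is out of range"
    assert len(scales) == data_shape[axis], "need one scale per channel"

# NCHW tensor quantized along the channel axis (64 channels).
check_channel_params((1, 64, 56, 56), [0.5] * 64, axis=1)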
9 changes: 9 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -26,6 +26,7 @@ def requantize(data,
input_zero_point,
output_scale,
output_zero_point,
axis=-1,
rounding="UPWARD",
out_dtype="int8"):
r"""Requantized operator.
@@ -53,6 +54,9 @@ def requantize(data,
output_zero_point: tvm.relay.Expr
The zero point of the output tensor.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.
@@ -71,13 +75,15 @@ def requantize(data,
input_zero_point,
output_scale,
output_zero_point,
axis,
rounding,
out_dtype)


def quantize(data,
output_scale,
output_zero_point,
axis=-1,
out_dtype='int8'):
r""" Quantize op
This operator takes float32 as input and produces quantized int8 or uint8 as output.
@@ -95,6 +101,8 @@ def quantize(data,
The output zero_point.
output_scale : tvm.relay.Expr
The output scale.
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
out_dtype : str, optional
The data type of the output tensor. Can be [int8, uint8]
Returns
@@ -106,6 +114,7 @@ def quantize(data,
return _make.quantize(data,
output_scale,
output_zero_point,
axis,
out_dtype)


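With the new axis argument in place, a per-channel quantize call through this Python API looks roughly like the sketch below (shapes and constant values are illustrative, not from this commit):

import numpy as np
from tvm import relay

# float32 input with 3 channels on the last axis.
data = relay.var("data", shape=(2, 5, 3), dtype="float32")

# One scale and one zero point per channel along axis=-1.
output_scale = relay.const(np.array([0.1, 0.2, 0.4], dtype="float32"))
output_zero_point = relay.const(np.zeros((3,), dtype="int32"))

q = relay.qnn.op.quantize(data, output_scale, output_zero_point,
                          axis=-1, out_dtype="int8")
print(relay.Function([data], q))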
48 changes: 46 additions & 2 deletions src/relay/pass/pattern_util.h
@@ -35,6 +35,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
#include <string>
#include <vector>
#include <utility>


@@ -221,29 +222,70 @@ inline bool IsScalar(const Expr& expr) {
return true;
}

/*!
* \brief Check if expr is a const scalar.
* \param expr The expr.
* \return True if const scalar.
*/
inline bool IsConstScalar(const Expr& expr) {
const auto* const_expr = expr.as<ConstantNode>();
if (const_expr) {
return const_expr->is_scalar();
}
return false;
}

/*!
* \brief Create a Constant with a scalar
*
* \param dtype The data type.
* \param value The value of the scalar.
* \return A Constant.
*/
template<typename T>
template <typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
if (dtype == DataType::Float(16)) {
// convert to float16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
} else {
*static_cast<DType*>(arr->data) = value;
}
})
return ConstantNode::make(arr);
}

/*!
* \brief Create a Constant with a tensor.
*
* \param dtype The data type.
* \param value The vector of the tensor values.
* \return A Constant.
*/
template <typename T>
static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
std::vector<T> value) {
runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
TVM_DTYPE_DISPATCH(dtype, DType, {
for (size_t i = 0; i < value.size(); i++) {
if (dtype == DataType::Float(16)) {
// convert to float16
// storage is uint16_t
// Similar handling as that in MakeConstantScalar
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
}
})
return ConstantNode::make(arr);
}

/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
@@ -523,6 +565,8 @@ static inline Expr Tile(Expr data, Array<Integer> reps) {
return CallNode::make(op, {data}, Attrs(attrs), {});
}

Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape);

Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);
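MakeConstantTensor generalizes MakeConstantScalar from rank-0 to arbitrary shapes, including the float16 special case where values are truncated through __truncXfYf2__ and stored as uint16_t bit patterns. A rough numpy analogue of that storage detail (illustrative only, not a TVM API):

import numpy as np

def make_constant_tensor(dtype, shape, values):
    # Build a dense constant tensor, as MakeConstantTensor does in C++.
    arr = np.asarray(values, dtype=dtype).reshape(shape)
    if dtype == "float16":
        # float16 payloads live in 16-bit storage; viewing them as uint16
        # exposes the same bit patterns the __truncXfYf2__ path writes.
        print("raw storage:", arr.view(np.uint16))
    return arr

make_constant_tensor("float16", (2, 2), [0.5, 1.0, 1.5, 2.0])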
50 changes: 41 additions & 9 deletions src/relay/qnn/op/quantize.cc
@@ -45,11 +45,18 @@ bool QuantizeRel(const Array<Type>& types,
CHECK(input_dtype == DataType::Float(32))
<< "Input type should be one of float32 but was " << input_dtype;

// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // output_zero_point

const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
int axis = quantize_attrs->axis;
axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << quantize_attrs->axis << " is out of range";
CHECK_GE(axis, 0)
<< "axis " << quantize_attrs->axis << " is out of range";

// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale
AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point

const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
@@ -60,8 +67,10 @@
return true;
}

Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType out_dtype) {
Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis,
DataType out_dtype) {
auto attrs = make_object<QuantizeAttrs>();
attrs->axis = axis;
attrs->out_dtype = std::move(out_dtype);
// result_quantized_value = result_zero_point + result_real_value / result_scale.
// A more detailed explanation can be found here -
@@ -71,13 +80,29 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, DataType
}

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const Expr& output_zero_point, const QuantizeAttrs* attrs) {
const Expr& output_zero_point, const Array<IndexExpr>& input_shape,
const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype;
const auto axis = attrs->axis;

size_t n_dim = input_shape.size();

auto expanded_output_scale = output_scale;
if (!IsConstScalar(output_scale)) {
expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis});
}

auto expanded_output_zero_point = output_zero_point;
if (!IsConstScalar(output_zero_point)) {
expanded_output_zero_point = ExpandBiasToMatchAxis(output_zero_point, n_dim, {axis});
}

const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Divide(input_tensor, output_scale);
auto scale_data = Divide(input_tensor, expanded_output_scale);
auto add_zero_point =
Cast(Round(Add(scale_data, Cast(output_zero_point, DataType::Float(32)))), DataType::Int(32));
Cast(Round(Add(scale_data, Cast(expanded_output_zero_point, DataType::Float(32)))),
DataType::Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype;
@@ -92,8 +117,15 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);

// Find input shape.
CHECK_EQ(types.size(), 4);
return QuantizeLower(data, output_scale, output_zero_point, quantize_attrs);
auto in_type = types[0];
auto in_tensor_type = in_type.as<TensorTypeNode>();
CHECK(in_tensor_type != nullptr) << "Type information missing."
<< " Please run infer_type pass.";
Array<IndexExpr> input_shape = in_tensor_type->shape;

return QuantizeLower(data, output_scale, output_zero_point, input_shape, quantize_attrs);
}

RELAY_REGISTER_OP("qnn.quantize")
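QuantizeLower now broadcasts a per-channel scale and zero point along the quantization axis before the usual divide, add-zero-point, round, clamp, cast sequence. A numpy sketch of the lowered arithmetic (int8 bounds stand in for GetQmin/GetQmax; the helper is illustrative):

import numpy as np

def quantize_lower(x, scale, zero_point, axis=-1, qmin=-128, qmax=127):
    # Reshape per-channel params so they broadcast along `axis`,
    # mirroring ExpandBiasToMatchAxis in the C++ lowering.
    bshape = [1] * x.ndim
    bshape[axis] = -1
    scale = np.reshape(scale, bshape)
    zero_point = np.reshape(zero_point, bshape)
    # Divide by scale, add zero point, round, then clamp and cast.
    q = np.round(x / scale + zero_point).astype(np.int32)
    return np.clip(q, qmin, qmax).astype(np.int8)

x = np.random.uniform(-1, 1, size=(2, 3, 4)).astype("float32")
print(quantize_lower(x, scale=[0.1, 0.2, 0.4], zero_point=[0, 0, 0], axis=1))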
57 changes: 42 additions & 15 deletions src/relay/qnn/op/requantize.cc
@@ -58,11 +58,6 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Expr& output_scale,
const Expr& output_zero_point, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
float input_scale_float = GetScalarFromConstant<float>(input_scale);
float output_scale_float = GetScalarFromConstant<float>(output_scale);
double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);

DataType hp_dtype = DataType::Int(64);

auto tensor = Cast(input_tensor, hp_dtype);
Expand All @@ -72,11 +67,34 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
tensor = Subtract(tensor, Cast(input_zero_point, hp_dtype));
}

// 2) If the input and output scales are same, we can skip the fixed point multiplication.
// 2) If the input and output scales are the same, we can skip the fixed point multiplication. Check
// if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single scale for
// the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input
// tensor. Depending on the quantization type, the fixed point multiplication routine is called.
auto scaled_int64_t = tensor;
if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
float output_scale_float = GetScalarFromConstant<float>(output_scale);
if (IsConstScalar(input_scale)) {
// This is per-tensor quantization. Single scale.
float input_scale_float = GetScalarFromConstant<float>(input_scale);
double double_multiplier =
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
}
} else {
// This is per-channel (per-axis) quantization.
std::vector<double> double_multipliers;
auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
for (auto input_axis_scale : input_axis_scales) {
double_multipliers.push_back(static_cast<double>(input_axis_scale) /
static_cast<double>(output_scale_float));
}
int axis = param->axis;
axis = (axis == -1) ? input_shape.size() - 1 : axis;
scaled_int64_t = FixedPointMultiplyPerChannel(scaled_int64_t, double_multipliers, input_shape,
axis, param->rounding);
}

// 3) Add the output zero point.
@@ -157,16 +175,24 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
in_dtype == DataType::Int(32))
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;

// Check the types of scale and zero points.
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
axis = (axis == -1) ? data->shape.size() - 1 : axis;
CHECK_LT(axis, static_cast<int>(data->shape.size()))
<< "axis " << requantize_attrs->axis << " is out of range";
CHECK_GE(axis, 0)
<< "axis " << requantize_attrs->axis << " is out of range";

// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale
AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt
// For now, the requantize output tensor is limited to full-tensor uniform quantization.
CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

const Array<tvm::Expr> oshape = data->shape;
// assign output type
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
auto out_dtype = param->out_dtype;
auto out_dtype = requantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) ||
out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
@@ -178,8 +204,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// Positional relay function to create qnn requantize operator
// used by frontend FFI.
Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
Expr output_zero_point, std::string rounding, DataType out_dtype) {
Expr output_zero_point, int axis, std::string rounding, DataType out_dtype) {
auto attrs = make_object<RequantizeAttrs>();
attrs->axis = axis;
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize");
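In the per-channel branch, RequantizeLower computes one input_scale / output_scale multiplier per channel and hands the vector to FixedPointMultiplyPerChannel. The numpy reference below shows the same math in floating point (the real lowering does the multiply in integer fixed point with the selected rounding mode, so the last bit can differ):

import numpy as np

def requantize_ref(x, input_scales, input_zp, output_scale, output_zp,
                   axis=-1, qmin=-128, qmax=127):
    # One multiplier per channel: input_scale[c] / output_scale.
    mult = np.asarray(input_scales, dtype=np.float64) / output_scale
    bshape = [1] * x.ndim
    bshape[axis] = -1
    mult = mult.reshape(bshape)
    # 1) subtract input zero point  2) rescale  3) add output zero point
    y = (x.astype(np.int64) - input_zp) * mult + output_zp
    # 4) clamp to the output range and cast
    return np.clip(np.round(y), qmin, qmax).astype(np.int8)

x = np.array([[10, 20], [30, 40]], dtype=np.int8)
print(requantize_ref(x, input_scales=[0.5, 0.25], input_zp=0,
                     output_scale=0.5, output_zp=0, axis=1))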