Skip to content

Commit

Permalink
[QNN] Add qnn.rsqrt op (#9982)
Browse files Browse the repository at this point in the history
* Add qnn.rsqrt op

* Add comment
  • Loading branch information
sfvaroglu authored Jan 24, 2022
1 parent 65b4b09 commit 6f2b35f
Show file tree
Hide file tree
Showing 8 changed files with 369 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,41 @@ def mul(
)


def rsqrt(x, scale, zero_point, output_scale, output_zero_point):
"""Quantized reciprocal square root.
Parameters
----------
x : relay.Expr
The quantized input tensor.
scale: relay.Expr
The scale of the quantized expr.
zero_point: relay.Expr
The zero point of quantized expr.
output_scale: relay.Expr
The scale of the output quantized expr.
output_zero_point: relay.Expr
The zero point of output quantized expr.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.rsqrt(
x,
scale,
zero_point,
output_scale,
output_zero_point,
)


def subtract(
lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point
):
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,22 @@ def global_avgpool2d(expr, type_map):
return [out, t]


@register_fake_quantization_to_integer("rsqrt")
def rsqrt(expr, type_map):
"""Rewrite a rsqrt op"""
arg = expr.args[0]
x_t = type_map[arg]
out_t = type_map[expr]
out = relay.qnn.op.rsqrt(
arg,
x_t.scale,
x_t.zero_point,
out_t.scale,
out_t.zero_point,
)
return [out, x_t]


@register_fake_quantization_to_integer("nn.bias_add")
def bias_add(expr, type_map):
"""Rewrite a bias_add op"""
Expand Down Expand Up @@ -394,6 +410,7 @@ def binary(expr, type_map):
out_t.scale,
out_t.zero_point,
)

return [out, out_t]

return register_fake_quantization_to_integer(op_name, binary)
Expand Down
53 changes: 53 additions & 0 deletions src/relay/qnn/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,59 @@ struct QnnBinaryOpArguments {
}
};

/*
* Number of inputs for the Qnn unary operators.
*/
static constexpr int kNumQnnUnaryOpInputs = 5;

/*
* Number of expected arg types.
*/
static constexpr int kNumQnnUnaryOpArgTypes = 6;

/*
* \brief Simple struct to organize the inputs to the Qnn
* unary operators. The main reason to have a struct
* is to be able to perform the common checks needed at a
* central location.
*/
struct QnnUnaryOpArguments {
Expr x;
Expr scale;
Expr zero_point;
Expr output_scale;
Expr output_zero_point;

explicit QnnUnaryOpArguments(const Array<Expr>& new_args) {
ICHECK_EQ(new_args.size(), kNumQnnUnaryOpInputs);
int idx = 0;
x = new_args[idx++];
scale = new_args[idx++];
zero_point = new_args[idx++];
output_scale = new_args[idx++];
output_zero_point = new_args[idx++];
ICHECK_EQ(idx, kNumQnnUnaryOpInputs);
}
};

/*
* \brief Simple structure to hold the input tensor's dtype
* and shape. This structure allows a common point to do
* all the validation checks for Qnn unary operators.
*/
struct QnnUnaryOpTensorType {
DataType dtype;
Array<PrimExpr> shape;

explicit QnnUnaryOpTensorType(const Array<tvm::relay::Type>& arg_types, const int32_t arg_idx) {
ICHECK_EQ(arg_types.size(), kNumQnnUnaryOpArgTypes);
auto tensor_type = arg_types[arg_idx].as<TensorTypeNode>();
ICHECK(tensor_type != nullptr);
dtype = tensor_type->dtype;
shape = tensor_type->shape;
}
};

/*
* \brief Simple structure to hold the input tensor's dtype
* and shape. This structure allows a common point to do
Expand Down
126 changes: 126 additions & 0 deletions src/relay/qnn/op/rsqrt.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/qnn/op/rsqrt.cc
* \brief QNN rsqrt operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>

#include "op_common.h"

namespace tvm {
namespace relay {
namespace qnn {

bool QnnRsqrtRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// Expected Types: data, scale, zero_point, output_scale, output_zero_point
ICHECK_EQ(types.size(), 6);
const auto* x = types[0].as<TensorTypeNode>();
if (x == nullptr) return false;
ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
<< "Expected quantized rsqrt type(int8, uint8) for input but was " << x->dtype;

// Check the types of scale and zero points.
for (size_t i = 1; i < 5; ++i) {
if (types[i].as<IncompleteTypeNode>()) {
return false;
}
}
ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale
ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point
ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point

// Assign types for scale and zero points.
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale
reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // output_scale
reporter->Assign(types[4], TensorType({}, DataType::Int(32))); // output_zero_point

// Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
// IdentityRel infer type function.
Array<Type> tensor_types = {types[0], types[5]};
return IdentityRel(tensor_types, 2, attrs, reporter);
}

// Positional relay function to create quantized rsqrt operator used by frontend FFI.
Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale,
Expr output_zero_point) {
static const Op& op = Op::Get("qnn.rsqrt");
return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {});
}

/*
* \brief Canonicalizes the QNN rsqrt op.
* \param attrs The empty attribute.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for add op.
*/
Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// At this time, due to the complexity of implementing this op in int8 or uint8,
// we dequantize the input, run the op in float, and then quantize the output (as below).
// This acts as a placeholder for future hardware enablement, where more hardware specific
// canonicalization can be provided.

// Get the args.
QnnUnaryOpArguments args(new_args);

// Get the input dtype and shape.
QnnUnaryOpTensorType input_type(arg_types, 0);

// Get the types for dequantize/quantize.
Array<tvm::relay::Type> types;
for (size_t i = 1; i < 5; ++i) {
types.push_back(arg_types[i]);
}

// Dequantize input.
auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1);

// Compute Rsqrt(Q_x')
auto output = Rsqrt(dequantized_arg);

// Quantize output.
return Quantize(output, args.output_scale, args.output_zero_point, input_type.dtype, types, -1);
}

RELAY_REGISTER_OP("qnn.rsqrt")
.describe("Elementwise rsqrt for quantized tensors.")
.set_num_inputs(5)
.add_argument("data", "Quantized Tensor", "The input data.")
.add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
.add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
.add_argument("output_zero_point", "Tensor",
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QRsqrt", QnnRsqrtRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnRsqrtCanonicalize);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt);

} // namespace qnn
} // namespace relay
} // namespace tvm
27 changes: 27 additions & 0 deletions src/relay/qnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,33 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
attrs.operator->(), input_shape, attrs->out_dtype);
}

Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
const DequantizeAttrs* attrs);

static inline Expr Dequantize(const Expr& data, const Expr& input_scale,
const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
const int& axis = -1) {
auto attrs = make_object<DequantizeAttrs>();
attrs->axis = std::move(axis);

return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
}

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
const QuantizeAttrs* attrs);

static inline Expr Quantize(const Expr& data, const Expr& output_scale,
const Expr& output_zero_point, const DataType& out_dtype,
const Array<tvm::relay::Type>& types, const int& axis = -1) {
auto attrs = make_object<QuantizeAttrs>();
attrs->axis = std::move(axis);
attrs->out_dtype = std::move(out_dtype);

return QuantizeLower(data, output_scale, output_zero_point, types, attrs.operator->());
}

static inline int64_t get_const_int(const tvm::PrimExpr& x) {
auto* value_ptr = tir::as_const_int(x);
ICHECK(value_ptr) << "Expr is not a constant int";
Expand Down
5 changes: 5 additions & 0 deletions src/relay/transforms/pattern_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,11 @@ inline Expr Sqrt(Expr x) {
return Call(op, {x}, Attrs(), {});
}

inline Expr Rsqrt(Expr x) {
static const Op& op = Op::Get("rsqrt");
return Call(op, {x}, Attrs(), {});
}

inline Expr Relu(Expr x) {
static const Op& op = Op::Get("nn.relu");
return Call(op, {x}, Attrs(), {});
Expand Down
93 changes: 93 additions & 0 deletions tests/python/relay/test_op_qnn_rsqrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay


def dequantize(data, scale, zp):
return scale * (np.asarray(data) - zp)


def generate_golden_output(dequantized_x, output_scale, output_zero_point):
rsqrt = 1 / np.sqrt(dequantized_x)
output = np.around(rsqrt / output_scale + output_zero_point)

q_min = np.iinfo(np.uint8).min
q_max = np.iinfo(np.uint8).max
return np.clip(output, q_min, q_max)


def test_saturation():
# Same params
data_dtype = "uint8"
scale = output_scale = 0.125
zero_point = output_zero_point = 0

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.qnn.op.rsqrt(
x=x,
scale=relay.const(scale, "float32"),
zero_point=relay.const(zero_point, "int32"),
output_scale=relay.const(output_scale, "float32"),
output_zero_point=relay.const(output_zero_point, "int32"),
)

func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
x_dequantized = dequantize(x_data, scale, zero_point)
golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point)

op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

np.testing.assert_equal(op_res.numpy(), np.uint8(golden_output))

# Different scale
scale = 0.125
output_scale = 0.25

y = relay.qnn.op.rsqrt(
x=x,
scale=relay.const(scale, "float32"),
zero_point=relay.const(zero_point, "int32"),
output_scale=relay.const(output_scale, "float32"),
output_zero_point=relay.const(output_zero_point, "int32"),
)

func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
x_dequantized = dequantize(x_data, scale, zero_point)
golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point)

op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)

np.testing.assert_equal(op_res.numpy(), golden_output)


if __name__ == "__main__":
test_saturation()
Loading

0 comments on commit 6f2b35f

Please sign in to comment.