From 6f2b35feb59826b010b131722c5b7d47079df8a9 Mon Sep 17 00:00:00 2001 From: "Sevin F. Varoglu" Date: Mon, 24 Jan 2022 10:59:56 -0800 Subject: [PATCH] [QNN] Add qnn.rsqrt op (#9982) * Add qnn.rsqrt op * Add comment --- python/tvm/relay/qnn/op/qnn.py | 35 +++++ .../transform/fake_quantization_to_integer.py | 17 +++ src/relay/qnn/op/op_common.h | 53 ++++++++ src/relay/qnn/op/rsqrt.cc | 126 ++++++++++++++++++ src/relay/qnn/utils.h | 27 ++++ src/relay/transforms/pattern_utils.h | 5 + tests/python/relay/test_op_qnn_rsqrt.py | 93 +++++++++++++ .../test_pass_fake_quantization_to_integer.py | 13 ++ 8 files changed, 369 insertions(+) create mode 100644 src/relay/qnn/op/rsqrt.cc create mode 100644 tests/python/relay/test_op_qnn_rsqrt.py diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index b69afd69b8c5..7f707c093ff3 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -656,6 +656,41 @@ def mul( ) +def rsqrt(x, scale, zero_point, output_scale, output_zero_point): + """Quantized reciprocal square root. + + Parameters + ---------- + x : relay.Expr + The quantized input tensor. + + scale: relay.Expr + The scale of the quantized expr. + + zero_point: relay.Expr + The zero point of quantized expr. + + output_scale: relay.Expr + The scale of the output quantized expr. + + output_zero_point: relay.Expr + The zero point of output quantized expr. + + Returns + ------- + result : relay.Expr + The computed result. + + """ + return _make.rsqrt( + x, + scale, + zero_point, + output_scale, + output_zero_point, + ) + + def subtract( lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, output_zero_point ): diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 0aa0524ae5d5..db46c2cbfd58 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -126,6 +126,22 @@ def global_avgpool2d(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("rsqrt") +def rsqrt(expr, type_map): + """Rewrite a rsqrt op""" + arg = expr.args[0] + x_t = type_map[arg] + out_t = type_map[expr] + out = relay.qnn.op.rsqrt( + arg, + x_t.scale, + x_t.zero_point, + out_t.scale, + out_t.zero_point, + ) + return [out, x_t] + + @register_fake_quantization_to_integer("nn.bias_add") def bias_add(expr, type_map): """Rewrite a bias_add op""" @@ -394,6 +410,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, ) + return [out, out_t] return register_fake_quantization_to_integer(op_name, binary) diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 9957464c6e37..4df958603ff2 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -82,6 +82,59 @@ struct QnnBinaryOpArguments { } }; +/* + * Number of inputs for the Qnn unary operators. + */ +static constexpr int kNumQnnUnaryOpInputs = 5; + +/* + * Number of expected arg types. + */ +static constexpr int kNumQnnUnaryOpArgTypes = 6; + +/* + * \brief Simple struct to organize the inputs to the Qnn + * unary operators. The main reason to have a struct + * is to be able to perform the common checks needed at a + * central location. + */ +struct QnnUnaryOpArguments { + Expr x; + Expr scale; + Expr zero_point; + Expr output_scale; + Expr output_zero_point; + + explicit QnnUnaryOpArguments(const Array& new_args) { + ICHECK_EQ(new_args.size(), kNumQnnUnaryOpInputs); + int idx = 0; + x = new_args[idx++]; + scale = new_args[idx++]; + zero_point = new_args[idx++]; + output_scale = new_args[idx++]; + output_zero_point = new_args[idx++]; + ICHECK_EQ(idx, kNumQnnUnaryOpInputs); + } +}; + +/* + * \brief Simple structure to hold the input tensor's dtype + * and shape. This structure allows a common point to do + * all the validation checks for Qnn unary operators. + */ +struct QnnUnaryOpTensorType { + DataType dtype; + Array shape; + + explicit QnnUnaryOpTensorType(const Array& arg_types, const int32_t arg_idx) { + ICHECK_EQ(arg_types.size(), kNumQnnUnaryOpArgTypes); + auto tensor_type = arg_types[arg_idx].as(); + ICHECK(tensor_type != nullptr); + dtype = tensor_type->dtype; + shape = tensor_type->shape; + } +}; + /* * \brief Simple structure to hold the input tensor's dtype * and shape. This structure allows a common point to do diff --git a/src/relay/qnn/op/rsqrt.cc b/src/relay/qnn/op/rsqrt.cc new file mode 100644 index 000000000000..55814dff422b --- /dev/null +++ b/src/relay/qnn/op/rsqrt.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/rsqrt.cc + * \brief QNN rsqrt operator. + */ +#include +#include + +#include "op_common.h" + +namespace tvm { +namespace relay { +namespace qnn { + +bool QnnRsqrtRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Expected Types: data, scale, zero_point, output_scale, output_zero_point + ICHECK_EQ(types.size(), 6); + const auto* x = types[0].as(); + if (x == nullptr) return false; + ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8)) + << "Expected quantized rsqrt type(int8, uint8) for input but was " << x->dtype; + + // Check the types of scale and zero points. + for (size_t i = 1; i < 5; ++i) { + if (types[i].as()) { + return false; + } + } + ICHECK(IsScalarType(types[1], DataType::Float(32))); // scale + ICHECK(IsScalarType(types[2], DataType::Int(32))); // zero_point + ICHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale + ICHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point + + // Assign types for scale and zero points. + reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // scale + reporter->Assign(types[2], TensorType({}, DataType::Int(32))); // zero_point + reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // output_scale + reporter->Assign(types[4], TensorType({}, DataType::Int(32))); // output_zero_point + + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay + // IdentityRel infer type function. + Array tensor_types = {types[0], types[5]}; + return IdentityRel(tensor_types, 2, attrs, reporter); +} + +// Positional relay function to create quantized rsqrt operator used by frontend FFI. +Expr MakeQuantizedRsqrt(Expr x, Expr scale, Expr zero_point, Expr output_scale, + Expr output_zero_point) { + static const Op& op = Op::Get("qnn.rsqrt"); + return Call(op, {x, scale, zero_point, output_scale, output_zero_point}, Attrs(), {}); +} + +/* + * \brief Canonicalizes the QNN rsqrt op. + * \param attrs The empty attribute. + * \param new_args The new mutated args to the call node. + * \param arg_types The types of input and output. + * \return The sequence of Relay ops for add op. + */ +Expr QnnRsqrtCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + // At this time, due to the complexity of implementing this op in int8 or uint8, + // we dequantize the input, run the op in float, and then quantize the output (as below). + // This acts as a placeholder for future hardware enablement, where more hardware specific + // canonicalization can be provided. + + // Get the args. + QnnUnaryOpArguments args(new_args); + + // Get the input dtype and shape. + QnnUnaryOpTensorType input_type(arg_types, 0); + + // Get the types for dequantize/quantize. + Array types; + for (size_t i = 1; i < 5; ++i) { + types.push_back(arg_types[i]); + } + + // Dequantize input. + auto dequantized_arg = Dequantize(args.x, args.scale, args.zero_point, types, -1); + + // Compute Rsqrt(Q_x') + auto output = Rsqrt(dequantized_arg); + + // Quantize output. + return Quantize(output, args.output_scale, args.output_zero_point, input_type.dtype, types, -1); +} + +RELAY_REGISTER_OP("qnn.rsqrt") + .describe("Elementwise rsqrt for quantized tensors.") + .set_num_inputs(5) + .add_argument("data", "Quantized Tensor", "The input data.") + .add_argument("scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("QRsqrt", QnnRsqrtRel) + .set_attr("TNonComputational", true) + .set_attr("FTVMQnnCanonicalize", QnnRsqrtCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.rsqrt").set_body_typed(MakeQuantizedRsqrt); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index 79d5549d659a..c8f3524d51ea 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -109,6 +109,33 @@ static inline Expr Requantize(const Expr& data, const Array& input_sh attrs.operator->(), input_shape, attrs->out_dtype); } +Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, + const Expr& input_zero_point, const Array& types, + const DequantizeAttrs* attrs); + +static inline Expr Dequantize(const Expr& data, const Expr& input_scale, + const Expr& input_zero_point, const Array& types, + const int& axis = -1) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + + return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->()); +} + +Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, + const Expr& output_zero_point, const Array& types, + const QuantizeAttrs* attrs); + +static inline Expr Quantize(const Expr& data, const Expr& output_scale, + const Expr& output_zero_point, const DataType& out_dtype, + const Array& types, const int& axis = -1) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->out_dtype = std::move(out_dtype); + + return QuantizeLower(data, output_scale, output_zero_point, types, attrs.operator->()); +} + static inline int64_t get_const_int(const tvm::PrimExpr& x) { auto* value_ptr = tir::as_const_int(x); ICHECK(value_ptr) << "Expr is not a constant int"; diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 69ad20a7ceaf..16a23a4ba699 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -550,6 +550,11 @@ inline Expr Sqrt(Expr x) { return Call(op, {x}, Attrs(), {}); } +inline Expr Rsqrt(Expr x) { + static const Op& op = Op::Get("rsqrt"); + return Call(op, {x}, Attrs(), {}); +} + inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); return Call(op, {x}, Attrs(), {}); diff --git a/tests/python/relay/test_op_qnn_rsqrt.py b/tests/python/relay/test_op_qnn_rsqrt.py new file mode 100644 index 000000000000..1eb9b64057ca --- /dev/null +++ b/tests/python/relay/test_op_qnn_rsqrt.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay + + +def dequantize(data, scale, zp): + return scale * (np.asarray(data) - zp) + + +def generate_golden_output(dequantized_x, output_scale, output_zero_point): + rsqrt = 1 / np.sqrt(dequantized_x) + output = np.around(rsqrt / output_scale + output_zero_point) + + q_min = np.iinfo(np.uint8).min + q_max = np.iinfo(np.uint8).max + return np.clip(output, q_min, q_max) + + +def test_saturation(): + # Same params + data_dtype = "uint8" + scale = output_scale = 0.125 + zero_point = output_zero_point = 0 + + x = relay.var("x", shape=(1, 4), dtype=data_dtype) + y = relay.qnn.op.rsqrt( + x=x, + scale=relay.const(scale, "float32"), + zero_point=relay.const(zero_point, "int32"), + output_scale=relay.const(output_scale, "float32"), + output_zero_point=relay.const(output_zero_point, "int32"), + ) + + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + func = mod["main"] + + x_data = np.array((255, 133, 0, 9)).reshape((1, 4)) + x_dequantized = dequantize(x_data, scale, zero_point) + golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point) + + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data) + + np.testing.assert_equal(op_res.numpy(), np.uint8(golden_output)) + + # Different scale + scale = 0.125 + output_scale = 0.25 + + y = relay.qnn.op.rsqrt( + x=x, + scale=relay.const(scale, "float32"), + zero_point=relay.const(zero_point, "int32"), + output_scale=relay.const(output_scale, "float32"), + output_zero_point=relay.const(output_zero_point, "int32"), + ) + + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + func = mod["main"] + + x_data = np.array((255, 133, 0, 9)).reshape((1, 4)) + x_dequantized = dequantize(x_data, scale, zero_point) + golden_output = generate_golden_output(x_dequantized, output_scale, output_zero_point) + + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data) + + np.testing.assert_equal(op_res.numpy(), golden_output) + + +if __name__ == "__main__": + test_saturation() diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index b9e9c6692899..aee2741782fd 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -303,6 +303,19 @@ def test_fake_quantize_global_avg_pool(): compare_fq_to_int(op, [x_np], True) +def test_fake_quantize_rsqrt(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.rsqrt(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np], True) + + def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")