From c46b56c5031dcbf37d4d4ace38dc2dc176cadfd6 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Tue, 6 Aug 2019 23:29:32 +0000
Subject: [PATCH] Move to Legalize API.

---
 python/tvm/relay/qnn/__init__.py          |   1 -
 python/tvm/relay/qnn/_transform.py        |  22 --
 python/tvm/relay/qnn/transform.py         |  33 ---
 src/relay/qnn/op/requantize.cc            | 202 +++++++++++++++--
 src/relay/qnn/pass/qnn_lower.cc           | 253 ----------------------
 tests/python/relay/test_qnn_requantize.py |   2 +-
 6 files changed, 189 insertions(+), 324 deletions(-)
 delete mode 100644 python/tvm/relay/qnn/_transform.py
 delete mode 100644 python/tvm/relay/qnn/transform.py
 delete mode 100644 src/relay/qnn/pass/qnn_lower.cc

diff --git a/python/tvm/relay/qnn/__init__.py b/python/tvm/relay/qnn/__init__.py
index fa888d7ce7dd..a472109add39 100644
--- a/python/tvm/relay/qnn/__init__.py
+++ b/python/tvm/relay/qnn/__init__.py
@@ -18,4 +18,3 @@
 """QNN dialect operators and IR passes."""
 from __future__ import absolute_import as _abs
 from . import op
-from . import transform
diff --git a/python/tvm/relay/qnn/_transform.py b/python/tvm/relay/qnn/_transform.py
deleted file mode 100644
index e2ff6f9ed652..000000000000
--- a/python/tvm/relay/qnn/_transform.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#pylint: disable=unused-argument
-"""Internal module for quantization."""
-from __future__ import absolute_import
-from tvm._ffi.function import _init_api
-
-_init_api("relay.qnn._transform", __name__)
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
deleted file mode 100644
index 6ca456b4fb81..000000000000
--- a/python/tvm/relay/qnn/transform.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name
-
-"""QNN Dialect transformation passes."""
-from __future__ import absolute_import
-
-from . import _transform
-
-def QnnLower():
-    """
-    Rewrites the high-level quantized ops into low-level exisiting Relay ops.
-
-    Returns
-    -------
-    Pass : tvm.relay.transform.Pass
-        The optmized pas.
-    """
-    return _transform.QnnLower()
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 2e78d20721be..04f7e80d5c64 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -23,9 +23,10 @@
  * \brief QNN requantize operator.
  */
-#include 
 #include 
+#include 
 #include 
+#include "../../pass/pattern_util.h"
 #include "../util.h"
 namespace tvm {
 namespace relay {
 namespace qnn {
@@ -34,6 +35,185 @@
 
 TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
 
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multiplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ *       integers. The representation is
+ *       float_number = (significand) * 2^(exponent)
+ *
+ *       The significand is a number between 0.5 and 1. This is represented by
+ *       an integer number. For example, if it is int32, then the decimal point
+ *       exists between bit 31 and 30 from LSB (or between first and second bit
+ *       from the left).
+ *
+ *       Some examples are
+ *           0.25 = (0.5) * 2^(-1)
+ *           0.125 = (0.5) * 2^(-2)
+ *
+ *       Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
+  int32_t significand, exponent;
+  if (double_multiplier == 0.) {
+    significand = 0;
+    exponent = 0;
+    return std::make_pair(significand, exponent);
+  }
+
+  // Get the significand and exponent.
+  double significand_d = std::frexp(double_multiplier, &exponent);
+
+  // Convert the double significand to int significand, i.e., convert into an
+  // integer where the decimal point is between bit 31 and 30. This is done by
+  // multiplying the double value with 2^31 and then casting to int.
+  significand_d = std::round(significand_d * (1ll << 31));
+  auto significand_int64 = static_cast<int64_t>(significand_d);
+  CHECK_LE(significand_int64, (1ll << 31));
+  if (significand_int64 == (1ll << 31)) {
+    significand_int64 /= 2;
+    ++exponent;
+  }
+  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
+  significand = static_cast<int32_t>(significand_int64);
+  return std::make_pair(significand, exponent);
+}
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note Requantization using only integer computation. Here, the computation is
+ *       converted to a fixed point computation by computing output multiplier
+ *       and shift. This is useful if the target device does not support
+ *       floating point computations, or if they are very expensive.
+ *
+ *       The original computation is scale_fp32 * quantized_tensor. To convert into
+ *       integer computation, the multiplication with fp32 scalar can be
+ *       replaced by multiplication with an int value and then right shifting
+ *       the result. This approximates the floating point computation with a
+ *       fixed point computation.
+ *
+ *       The whole computation can thus be broken down into the following steps
+ *       1) Calculate the integer multiplier and integer shift.
+ *       2) Subtract the input integer zero point.
+ *       3) Multiply the fixed point multiplier with quantized tensor.
+ *       4) Round the result.
+ *       5) Right shift the result.
+ *       6) Add the output zero point.
+ *       7) Clip to the out_dtype min/max and cast to the out_dtype.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape) {
+  double double_multiplier = param->input_scale / param->output_scale;
+
+  // Choose high precision datatype to be int64. This is for avoiding overflow
+  // in multiplication of two int32 values.
+  DataType hp_dtype = Int(64);
+
+  // 1) Calculating the integer multiplier and integer shift
+  int32_t fixed_point_multiplier, shift;
+  std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+
+  // 2) Subtract the input_zero_point
+  auto tensor = Cast(input_tensor, hp_dtype);
+  if (param->input_zero_point != 0) {
+    auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
+    tensor = Subtract(tensor, input_zp);
+  }
+
+  // 3) Multiply the fixed point multiplier
+  if (left_shift != 0) {
+    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+  }
+  // Perform the multiplication in higher precision.
+  // The scalar is a fixed point value of int32 where the decimal point is
+  // between bits 31 and 30. After multiplying with input_tensor, the result is
+  // in int64 where the decimal point is sitting between bits 31 and 30 (from
+  // the right, rightmost bit is bit 0). The computation is performed in higher
+  // precision to avoid overflow in multiplying two int32 values.
+  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+  auto multiplied_t = Multiply(tensor, scalar);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal point
+  // sits. As we will be right shifting the multiplied_t, we need to first
+  // calculate the total_right_shift.
+  int total_right_shift = right_shift + 31;
+  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+
+  tensor = multiplied_t;
+  Expr round_scalar;
+  if (param->rounding == "UPWARD") {
+    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+  } else if (param->rounding == "TONEAREST") {
+    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+    auto zero = MakeConstantScalar(hp_dtype, 0);
+    auto zero_t = Full(zero, input_shape, hp_dtype);
+    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  }
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+
+  // 6) Add the output zero point.
+  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+
+  // 7) Clip to the out_dtype min/max and cast to the out_dtype.
+  auto q_min = GetQmin(param->out_dtype);
+  auto q_max = GetQmax(param->out_dtype);
+  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
+  return Cast(clipped_t, param->out_dtype);
+}
+
+/*
+ * \brief Legalize the requantize op.
+ * \param attrs The attrs of the original requantize call.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of the args to the call node.
+ * \return The sequence of Relay ops for requantize op.
+ * \note Lowering of the requantize operation. The requantize operator converts
+ *       one quantized tensor to another quantized tensor. For the output
+ *       tensor, we are provided with output scale and zero point. The
+ *       computation looks like this
+ *
+ *       Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
+ */
+Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
+                        const Array<tvm::relay::Type>& arg_types) {
+  CHECK_EQ(new_args.size(), 1);
+  auto& quantized_data = new_args[0];
+  const auto* param = attrs.as<RequantizeAttrs>();
+  CHECK(param != nullptr);
+
+  // Find input shape.
+  CHECK_EQ(arg_types.size(), 1);
+  auto input_dtype = arg_types[0];
+  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
+  CHECK(input_tensor_type != nullptr) << "Type information missing."
+                                      << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = input_tensor_type->shape;
+
+  // Check rounding validity.
+  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
+      << "QNN requantize supports two rounding modes - UPWARD and "
+      << "TONEAREST";
+  return RequantizeLower(quantized_data, param, input_shape);
+}
+
 /*
  * \brief Infer shape function of Requantize op.
  * \param types The types of input args.
@@ -42,35 +222,28 @@
  * \param reporter The type reporter that sets the dtype and shapes.
  * \return True if the infer shape succeeded.
  */
-bool RequantizeRel(const Array<Type>& types,
-                   int num_inputs,
-                   const Attrs& attrs,
+bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto in_dtype = data->dtype;
   CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
-    << "Input type should be an integer but was " << in_dtype;
+      << "Input type should be an integer but was " << in_dtype;
 
   const Array<IndexExpr> oshape = data->shape;
   // assign output type
   const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
   auto out_dtype = param->out_dtype;
   CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
-    << "Output type should be an integer but was " << out_dtype;
+      << "Output type should be an integer but was " << out_dtype;
   reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
   return true;
 }
 
 // Positional relay function to create qnn requantize operator
 // used by frontend FFI.
-Expr MakeRequantize(Expr data,
-                    double input_scale,
-                    int32_t input_zero_point,
-                    double output_scale,
-                    int32_t output_zero_point,
-                    std::string rounding,
-                    DataType out_dtype) {
+Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale,
+                    int32_t output_zero_point, std::string rounding, DataType out_dtype) {
   auto attrs = make_node<RequantizeAttrs>();
   attrs->input_scale = std::move(input_scale);
   attrs->input_zero_point = std::move(input_zero_point);
@@ -95,7 +268,8 @@ Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The quantized input tensor.")
 .set_support_level(11)
-.add_type_rel("Requantize", RequantizeRel);
+.add_type_rel("Requantize", RequantizeRel)
+.set_attr<FTVMLegalize>("FTVMLegalize", RequantizeLegalize);
 
 TVM_REGISTER_API("relay.qnn.op._make.requantize")
 .set_body_typed(MakeRequantize);
diff --git a/src/relay/qnn/pass/qnn_lower.cc b/src/relay/qnn/pass/qnn_lower.cc
deleted file mode 100644
index 5ac6f4b7bda4..000000000000
--- a/src/relay/qnn/pass/qnn_lower.cc
+++ /dev/null
@@ -1,253 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2019 by Contributors
- * \file qnn_lower.cc
- * \brief Lower qnn ops to a sequence of existing Relay ops.
- */
-
-#include 
-#include 
-#include 
-#include 
-#include "../util.h"
-#include "../../pass/pattern_util.h"
-
-namespace tvm {
-namespace relay {
-namespace qnn {
-/*!
- * \brief namespace of qnn lower pass.
- *
- * Use namespace to reduce potential naming conflict.
- */
-namespace qnn_lower {
-
-using runtime::TypedPackedFunc;
-
-// Lowering of qnn.requantize op
-
-/*
- * \brief Convert FP32 representation into fixed point representation.
- * \param double_multplier The input FP32 number.
- * \return The pair of multiplier and shift for fixed point representation.
- * \note Converts a floating point number so that it can be represented by
- *       integers. The representation is
- *       float_number = (significand) * 2^(exponent)
- *
- *       The significand is a number between 0.5 and 1. This is represented by
- *       an integer number. For example, if it is int32, then the decimal point
- *       exists between bit 31 and 30 from LSB (or between first and second bit
- *       from the left).
- *
- *       Some examples are
- *           0.25 = (0.5) * 2^(-1)
- *           0.125 = (0.5) * 2^(-2)
- *
- *       Credit to TFLite reference implementation.
- */
-std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
-    double double_multiplier) {
-  int32_t significand, exponent;
-  if (double_multiplier == 0.) {
-    significand = 0;
-    exponent = 0;
-    return std::make_pair(significand, exponent);
-  }
-
-  // Get the significand and exponent.
-  double significand_d = std::frexp(double_multiplier, &exponent);
-
-  // Convert the double significand to int significand, i.e., convert into a
-  // integer where the decimal point is between bit 31 and 30. This is done by
-  // multiplying the double value with 2^31 and then casting to int.
-  significand_d = std::round(significand_d * (1ll << 31));
-  auto significand_int64 = static_cast<int64_t>(significand_d);
-  CHECK_LE(significand_int64, (1ll << 31));
-  if (significand_int64 == (1ll << 31)) {
-    significand_int64 /= 2;
-    ++exponent;
-  }
-  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
-  significand = static_cast<int32_t>(significand_int64);
-  return std::make_pair(significand, exponent);
-}
-
-/*
- * \brief Lower requantize to a sequence of ops.
- * \param input_tensor The input tensor to requantize op.
- * \param param The requantize op attrs.
- * \param out_shape The output shape of the requantize op.
- * \return The sequence of existing Relay ops.
- * \note Requantization using only integer computation. Here, the computation is
- *       converted to a fixed point computation by computing output multiplier
- *       and shift. This is useful, if the target device does not support/have
- *       very expensive floating point computations.
- *
- *       Original compuation is scale_fp32 * quantized_tensor. To convert into
- *       integer computation, the multiplication with fp32 scalar can be
- *       replaced by multiplication with an int value and then right shifting
- *       the result. This approximates the floating point computation with a
- *       fixed point computation.
- *
- *       The whole computation this can be broken down into following steps
- *       1) Calculate the integer multiplier and integer shift.
- *       2) Subtract the input integer zero point.
- *       3) Multiply the fixed point multiplier with quantized tensor.
- *       4) Round the result.
- *       5) Right shift the result.
- *       6) Add the output zero point.
- *       7) Cast to the out_dtype.
- */
-Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& out_shape) {
-  double double_multiplier = param->input_scale/param->output_scale;
-
-  // Choose high precision datatype to be int64. This is for avoiding overflow
-  // in multiplication of two int32 values.
-  DataType hp_dtype = Int(64);
-
-  // 1) Calculating the integer multiplier and integer shift
-  int32_t fixed_point_multiplier, shift;
-  std::tie(fixed_point_multiplier, shift) =
-      GetFixedPointMultiplierShift(double_multiplier);
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-
-  // 2) Subtract the input_zero_point
-  auto tensor = Cast(input_tensor, hp_dtype);
-  if (param->input_zero_point != 0) {
-    auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
-    tensor = Subtract(tensor, input_zp);
-  }
-
-  // 3) Multiply the integer multiplier
-  if (left_shift != 0) {
-    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
-  }
-  // Perform the multiplication in higher precision.
-  // The scalar is a fixed point value of int32 where the decimal point is
-  // between bits 31 and 30. After multiplying with input_tensor, the result is
-  // in int64 where the decimal point is sitting between bits 31 and 30 (from
-  // the right, rightmost bit is bit 0). The computation is performed in higher
-  // precision to avoid overflow in multiplying two int32 values.
-  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
-  auto multiplied_t = Multiply(tensor, scalar);
-
-  // 4) Find the rounding scalar. This depends on where the final decimal point
-  // sits. As we will be right shifting the multiplied_t, we need to first
-  // calculate the total_right_shift.
-  int total_right_shift = right_shift + 31;
-  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
-
-  tensor = multiplied_t;
-  Expr round_scalar;
-  if (param->rounding == "UPWARD") {
-    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
-  } else if (param->rounding == "TONEAREST") {
-    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-    auto pos_rounder_t = Full(pos_rounder, out_shape, hp_dtype);
-    auto neg_rounder_t = Full(neg_rounder, out_shape, hp_dtype);
-
-    auto zero = MakeConstantScalar(hp_dtype, 0);
-    auto zero_t = Full(zero, out_shape, hp_dtype);
-    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t,
-                         neg_rounder_t);
-  }
-  // Add the rounding scalar.
-  tensor = Add(tensor, round_scalar);
-
-  // 5) Simply right shift the result to get the final output.
-  auto scaled_int64_t = RightShift(tensor,
-                                   MakeConstantScalar(hp_dtype, total_right_shift));
-
-  // 6) Add the output zero point.
-  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
-  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
-
-  // 7) Clip to the out_dtype min/max.
-  auto q_min = GetQmin(param->out_dtype);
-  auto q_max = GetQmax(param->out_dtype);
-  auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
-  return Cast(clipped_t, param->out_dtype);
-}
-
-/*
- * \brief Forward rewrite the requantize op.
- * \param ref_call The original call that will be lowered.
- * \param new_args The new mutated args to the call node.
- * \param ctx The node context.
- * \return The sequence of Relay ops for requantize op.
- * \note Lowering of the requantize operation. The requantize operator converts
- *       one quantized tensor to another quantized tensor. For the output
- *       tensor, we are provided with output scale and zero point. The
- *       computation looks like this
- *
- *       Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
- */
-Expr RequantizeForwardRewrite(const Call& ref_call,
-                              const Array<Expr>& new_args, const NodeRef& ctx) {
-  CHECK_EQ(new_args.size(), 1);
-  Expr quantized_data = new_args[0];
-  const auto* param = ref_call->attrs.as<RequantizeAttrs>();
-  CHECK(param != nullptr);
-
-  // Find output shape.
-  auto ref_call_t = ref_call->checked_type();
-  auto output_tt = ref_call_t.as<TensorTypeNode>();
-  CHECK(output_tt != nullptr) << "Type information missing."
-                              << " Please run infer_type pass.";
-  Array<IndexExpr> out_shape = output_tt->shape;
-
-  // Check rounding validity.
-  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
-      << "QNN requantize supports two rounding modes - UPWARD and "
-      << "TONEAREST";
-  return RequantizeLower(quantized_data, param, out_shape);
-}
-
-RELAY_REGISTER_OP("qnn.requantize")
-.set_attr<FForwardRewrite>("FQnnForwardRewrite", RequantizeForwardRewrite);
-
-Expr QnnLower(const Expr& expr) {
-  return ForwardRewrite(expr, "FQnnForwardRewrite", nullptr, nullptr);
-}
-}  // namespace qnn_lower
-
-namespace transform {
-using namespace tvm::relay::transform;
-Pass QnnLower() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-      [=](Function f, Module m, PassContext pc) {
-        return Downcast<Function>(
-            relay::qnn::qnn_lower::QnnLower(f));
-      };
-  return CreateFunctionPass(pass_func, 0, "QnnLower",
-                            {ir::StringImm::make("InferType")});
-}
-
-TVM_REGISTER_API("relay.qnn._transform.QnnLower")
-.set_body_typed(QnnLower);
-}  // namespace transform
-
-}  // namespace qnn
-}  // namespace relay
-}  // namespace tvm
diff --git a/tests/python/relay/test_qnn_requantize.py b/tests/python/relay/test_qnn_requantize.py
index 3925c1e5d573..cd478fb5ba22 100644
--- a/tests/python/relay/test_qnn_requantize.py
+++ b/tests/python/relay/test_qnn_requantize.py
@@ -57,7 +57,7 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
 
     mod = relay.Function(relay.analysis.free_vars(mod), mod)
     mod = relay.Module.from_expr(mod)
-    mod = relay.qnn.transform.QnnLower()(mod)
+    mod = relay.transform.Legalize()(mod)
     return mod
 
 def same_scale_test():
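
Note for reviewers: GetFixedPointMultiplierShift converts the floating point multiplier
scale_input/scale_output into a Q31 fixed point significand plus a power-of-two shift. For
following the math, here is a minimal Python sketch of the same conversion; the function name
and the final assert are illustrative, not part of the patch:

    import math

    def get_fixed_point_multiplier_shift(double_multiplier):
        # Returns (significand, shift) such that
        # double_multiplier ~= (significand / 2**31) * 2**shift.
        if double_multiplier == 0.0:
            return 0, 0
        # frexp yields a significand in [0.5, 1) and the matching power-of-two exponent.
        significand_d, shift = math.frexp(double_multiplier)
        # Scale so the binary point sits between bits 31 and 30, then round to an integer.
        significand = int(round(significand_d * (1 << 31)))
        if significand == (1 << 31):  # rounding spilled into bit 31; renormalize
            significand //= 2
            shift += 1
        return significand, shift

    # Example: 0.25 = 0.5 * 2^(-1), and 0.5 in Q31 is 2^30.
    assert get_fixed_point_multiplier_shift(0.25) == (1 << 30, -1)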
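End to end, the new lowering path is exercised as in the updated test: build a module containing
qnn.requantize and run the generic Legalize pass, which dispatches to RequantizeLegalize through
the FTVMLegalize op attribute. A sketch, assuming the relay.qnn.op.requantize Python wrapper
mirrors the MakeRequantize signature registered in requantize.cc (shapes and scales here are
arbitrary):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(2, 5), dtype="int32")
    out = relay.qnn.op.requantize(data,
                                  input_scale=0.5,
                                  input_zero_point=0,
                                  output_scale=0.25,
                                  output_zero_point=0,
                                  rounding="TONEAREST",
                                  out_dtype="int8")
    func = relay.Function(relay.analysis.free_vars(out), out)
    mod = relay.Module.from_expr(func)
    # qnn.requantize is now rewritten by the generic Legalize pass (FTVMLegalize)
    # instead of the removed qnn-specific QnnLower pass.
    mod = relay.transform.Legalize()(mod)
    print(mod)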