diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
index 1ebdeaa7d708a..358b493f12e1c 100644
--- a/include/tvm/relay/qnn/attrs.h
+++ b/include/tvm/relay/qnn/attrs.h
@@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   }
 };
 
+/*! \brief Attributes used in QNN concatenate operators */
+struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
+  Array<tvm::Expr> input_scales;
+  Array<tvm::Expr> input_zero_points;
+  double output_scale;
+  int32_t output_zero_point;
+  int axis;
+
+  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
+    TVM_ATTR_FIELD(input_scales)
+        .describe("The list of scales of input quantized tensors.");
+
+    TVM_ATTR_FIELD(input_zero_points)
+        .describe("The list of zero points of input quantized tensors.");
+
+    TVM_ATTR_FIELD(output_zero_point)
+        .describe("The zero point for the output tensor.");
+
+    TVM_ATTR_FIELD(output_scale)
+        .describe("The scale for the output tensor.");
+
+    TVM_ATTR_FIELD(axis)
+        .describe("The axis at which the input arrays are concatenated. "
+                  "Should lie in range `[-ndim, ndim)`.")
+        .set_default(0);
+  }
+};  // struct QnnConcatenateAttrs
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index b153cd58876a7..7eb0408785dcc 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -18,7 +18,8 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm import relay
+from tvm.expr import FloatImm, IntImm
+from tvm.relay.expr import Tuple
 from . import _make
 
 def requantize(data,
@@ -134,6 +135,8 @@ def dequantize(data,
     return _make.dequantize(data,
                             input_scale,
                             input_zero_point)
 
+
+
 def concatenate(data,
                 input_scales,
                 input_zero_points,
@@ -169,42 +172,14 @@ def concatenate(data,
     """
 
     data = list(data)
-    requantized_exprs = list(data)
-
-    # Find the dtype of the input expr. This is required for the requantize op. Since, this is
-    # concatenate op, the dtype of the input is same as dtype of the output.
-    mod = relay.Module.from_expr(data[0])
-    mod = relay.transform.InferType()(mod)
-    entry = mod["main"]
-    data0 = entry if isinstance(data[0], relay.Function) else entry.body
-    in_dtype = data0.checked_type.dtype
-
-    # First check if all the input qnn params match. If yes, we can call concatenate first, followed
-    # by a requantize.
-    if all(scale == input_scales[0] for scale in input_scales)\
-            and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
-        out = relay.concatenate(tuple(data), axis)
-        input_scale = input_scales[0]
-        input_zero_point = input_zero_points[0]
-        if input_scale != output_scale or input_zero_point != output_zero_point:
-            out = requantize(data=out,
-                             input_scale=input_scales[0],
-                             input_zero_point=input_zero_points[0],
-                             output_scale=output_scale,
-                             output_zero_point=output_zero_point,
-                             out_dtype=in_dtype)
-        return out
-
-    # If the output qnn params do not match the input qnn params, we can call requantize on the
-    # input expr first, followed by a concatenate on the requantized input exprs.
-    for idx, quantized_expr in enumerate(data):
-        input_scale = input_scales[idx]
-        input_zero_point = input_zero_points[idx]
-        if input_scale != output_scale or input_zero_point != output_zero_point:
-            requantized_exprs[idx] = requantize(data=quantized_expr,
-                                                input_scale=input_scale,
-                                                input_zero_point=input_zero_point,
-                                                output_scale=output_scale,
-                                                output_zero_point=output_zero_point,
-                                                out_dtype=in_dtype)
-    return relay.concatenate(tuple(requantized_exprs), axis)
+    if not data:
+        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if not isinstance(axis, int):
+        raise ValueError("For now, we only support integer axis")
+
+    return _make.concatenate(Tuple(data),
+                             [FloatImm("float64", x) for x in input_scales],
+                             [IntImm("int32", x) for x in input_zero_points],
+                             output_scale,
+                             output_zero_point,
+                             axis)
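
A minimal usage sketch of the new Python API may help here; it mirrors the tests touched at the end of this patch, with illustrative shapes and qnn params (not taken from the patch):

    import tvm
    from tvm import relay

    # Two quantized uint8 inputs with different qnn params.
    x = relay.var("x", shape=(1, 64), dtype="uint8")
    y = relay.var("y", shape=(1, 64), dtype="uint8")

    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[0.5, 0.25],
                                 input_zero_points=[0, 0],
                                 output_scale=0.5,
                                 output_zero_point=0,
                                 axis=1)

    func = relay.Function([x, y], z)
    mod = relay.Module.from_expr(func)
    # Legalize is the pass that lowers qnn.concatenate to requantize + concatenate.
    mod = relay.transform.Legalize()(mod)
    print(mod["main"])
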
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 03a92b35d3969..e2b812d8f75bc 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -37,6 +37,7 @@
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
 #include "../../pass/alter_op_layout.h"
+#include "transform.h"
 
 namespace tvm {
 namespace relay {
@@ -210,86 +211,6 @@ RELAY_REGISTER_OP("expand_dims")
 // relay.concatenate
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
 
-bool ConcatenateRel(const Array<Type>& types,
-                    int num_inputs,
-                    const Attrs& attrs,
-                    const TypeReporter& reporter) {
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
-  /* If we receive a tuple we can continue, if we receive
-   * anything but an incomplete type we should signal an
-   * error.
-  */
-  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
-  if (tensor_tuple == nullptr) {
-    throw relay::Error(
-        RELAY_ERROR(
-          "concatenate requires a tuple of tensors as the first argument, found "
-          << PrettyPrint(types[0])));
-  } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
-    return false;
-  }
-
-  const auto* param = attrs.as<ConcatenateAttrs>();
-  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
-    return false;
-  }
-  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
-  // Sanity check: ndim and dtype.
-  const int ndim = static_cast<int>(first->shape.size());
-  const DataType dtype = first->dtype;
-
-  for (const Type& ele : tensor_tuple->fields) {
-    if (ele.as<IncompleteTypeNode>()) {
-      return false;
-    }
-
-    const auto& e = Downcast<TensorType>(ele);
-
-    int e_ndim = static_cast<int>(e->shape.size());
-    const DataType& e_dtype = e->dtype;
-    if (e_ndim != ndim) {
-      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
-    }
-    if (e_dtype != dtype) {
-      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
-    }
-  }
-  // Sanity check: axis
-  int axis = param->axis;
-  if (!(-ndim <= axis && axis < ndim)) {
-    throw relay::Error(RELAY_ERROR(
-      "concatenate only accepts `axis` in [-ndim, ndim)" <<
-      ", but got axis = " << axis <<
-      ", and ndim = " << ndim));
-  }
-  axis = axis < 0 ? ndim + axis : axis;
-  // Calculate shape
-  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  IndexExpr &concat_dim = oshape[axis];
-  bool has_any = false;
-  if (concat_dim.as<Any>()) {
-    has_any = true;
-  } else {
-    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
-      if (e->shape[axis].as<Any>()) {
-        has_any = true;
-        break;
-      }
-      concat_dim += e->shape[axis];
-    }
-  }
-
-  if (has_any) {
-    concat_dim = Any::make();
-  }
-
-  auto rtype = TensorTypeNode::make(oshape, dtype);
-  reporter->Assign(types[1], rtype);
-  return true;
-}
-
 Array<Tensor> ConcatenateCompute(const Attrs& attrs,
                                  const Array<Tensor>& inputs,
                                  const Type& out_type,
@@ -358,7 +279,7 @@ RELAY_REGISTER_OP("concatenate")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input list of tensors.")
 .set_support_level(1)
-.add_type_rel("Concatenate", ConcatenateRel)
+.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
 .set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
new file mode 100644
index 0000000000000..64ff56f8032b0
--- /dev/null
+++ b/src/relay/op/tensor/transform.h
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/op/tensor/transform.h
+ * \brief Transform op attributes that can be shared among Relay and its dialects.
+ */
+#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
+#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
+
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/error.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+template <typename AttrType>
+bool ConcatenateRel(const Array<Type>& types,
+                    int num_inputs,
+                    const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  /* If we receive a tuple we can continue, if we receive
+   * anything but an incomplete type we should signal an
+   * error.
+  */
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR(
+          "concatenate requires a tuple of tensors as the first argument, found "
+          << PrettyPrint(types[0])));
+  } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
+    return false;
+  }
+
+  const auto* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  // Sanity check: ndim and dtype.
+  const int ndim = static_cast<int>(first->shape.size());
+  const DataType dtype = first->dtype;
+
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
+    const auto& e = Downcast<TensorType>(ele);
+
+    int e_ndim = static_cast<int>(e->shape.size());
+    const DataType& e_dtype = e->dtype;
+    if (e_ndim != ndim) {
+      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
+    }
+    if (e_dtype != dtype) {
+      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
+    }
+  }
+  // Sanity check: axis
+  int axis = param->axis;
+  if (!(-ndim <= axis && axis < ndim)) {
+    throw relay::Error(RELAY_ERROR(
+      "concatenate only accepts `axis` in [-ndim, ndim)" <<
+      ", but got axis = " << axis <<
+      ", and ndim = " << ndim));
+  }
+  axis = axis < 0 ? ndim + axis : axis;
+  // Calculate shape
+  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
+  IndexExpr &concat_dim = oshape[axis];
+  bool has_any = false;
+  if (concat_dim.as<Any>()) {
+    has_any = true;
+  } else {
+    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
+      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+      if (e->shape[axis].as<Any>()) {
+        has_any = true;
+        break;
+      }
+      concat_dim += e->shape[axis];
+    }
+  }
+
+  if (has_any) {
+    concat_dim = Any::make();
+  }
+
+  auto rtype = TensorTypeNode::make(oshape, dtype);
+  reporter->Assign(types[1], rtype);
+  return true;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
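
The templated ConcatenateRel above is now shared by `concatenate` and `qnn.concatenate`. A short Python sketch of its shape logic, with `None` standing in for relay's `Any` extent (an illustrative helper, not part of the patch):

    def concat_output_shape(shapes, axis):
        # Mirrors ConcatenateRel: equal ndim required, axis in [-ndim, ndim).
        ndim = len(shapes[0])
        assert all(len(s) == ndim for s in shapes), "same ndim required"
        assert -ndim <= axis < ndim, "axis out of range"
        axis = axis + ndim if axis < 0 else axis
        out = list(shapes[0])
        if any(s[axis] is None for s in shapes):
            out[axis] = None  # one unknown extent makes the result unknown
        else:
            out[axis] = sum(s[axis] for s in shapes)
        return out

    print(concat_output_shape([(1, 64), (2, 64)], axis=0))  # [3, 64]
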
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 18e5df3e04dfd..069a34d3b15f5 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -415,6 +415,13 @@ static inline Expr Full(Expr fill_value,
   return CallNode::make(op, {fill_value}, Attrs(attrs), {});
 }
 
+static inline Expr Concatenate(Expr data, int axis) {
+  auto attrs = make_node<ConcatenateAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("concatenate");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
 Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
new file mode 100644
index 0000000000000..a40286af94c50
--- /dev/null
+++ b/src/relay/qnn/op/concatenate.cc
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/op/concatenate.cc
+ * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../op/tensor/transform.h"
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);
+
+Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
+                        Array<tvm::Expr> input_zero_points, double output_scale,
+                        int32_t output_zero_point, int axis) {
+  auto attrs = make_node<QnnConcatenateAttrs>();
+  attrs->input_scales = input_scales;
+  attrs->input_zero_points = input_zero_points;
+  attrs->output_scale = output_scale;
+  attrs->output_zero_point = output_zero_point;
+  attrs->axis = axis;
+  static const Op& op = Op::Get("qnn.concatenate");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+/*
+ * \brief Legalizes the QNN concatenate op.
+ * \param attrs The QNN concatenate attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of the args to the call node.
+ * \return The sequence of Relay ops for the concatenate op.
+ */
+Expr ConcatenateLegalize(const Attrs& attrs, const Array<Expr>& new_args,
+                         const Array<tvm::relay::Type>& arg_types) {
+  // Get the attrs.
+  CHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
+  CHECK(concatenate_attrs != nullptr);
+  auto input_scales = concatenate_attrs->input_scales;
+  auto input_zero_points = concatenate_attrs->input_zero_points;
+  auto output_scale = concatenate_attrs->output_scale;
+  auto output_zero_point = concatenate_attrs->output_zero_point;
+
+  // Get the input dtype and shape.
+  CHECK_EQ(arg_types.size(), 1);
+  auto tuple_type = arg_types[0].as<TupleTypeNode>();
+  CHECK(tuple_type != nullptr);
+  auto tensor_type = tuple_type->fields[0].as<TensorTypeNode>();
+  auto input_dtype = tensor_type->dtype;
+  auto input_shape = tensor_type->shape;
+
+  // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize at
+  // the start, we can insert requantize at the end if and only if all the input tensors have the
+  // same qnn params. This can be done in the future.
+
+  // If the output qnn params do not match the input qnn params, we can call requantize on the
+  // input expr first, followed by a concatenate on the requantized input exprs.
+
+  auto tuple_data = data.as<TupleNode>();
+  CHECK(tuple_data != nullptr);
+
+  int idx = 0;
+  Array<Expr> requantized_exprs;
+  for (auto quantized_expr : tuple_data->fields) {
+    // Get the input scale for the idx-th quantized input tensor.
+    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
+    CHECK(input_scale_expr != nullptr);
+    auto input_scale = input_scale_expr->value;
+
+    // Get the zero point for the idx-th quantized input tensor.
+    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
+    CHECK(input_zero_point_expr != nullptr);
+    auto input_zero_point = input_zero_point_expr->value;
+
+    // Check if the output and input qnn params are the same. If not, requantize.
+    if (input_scale != output_scale || input_zero_point != output_zero_point) {
+      // Create the requantize attrs for calling the requantize op.
+      auto requantize_attrs = make_node<RequantizeAttrs>();
+      requantize_attrs->input_scale = input_scale;
+      requantize_attrs->input_zero_point = input_zero_point;
+      requantize_attrs->output_scale = output_scale;
+      requantize_attrs->output_zero_point = output_zero_point;
+      requantize_attrs->rounding = "TONEAREST";
+      requantize_attrs->out_dtype = input_dtype;
+
+      auto requantized_expr =
+          RequantizeLower(quantized_expr, requantize_attrs.operator->(), input_shape);
+      requantized_exprs.push_back(requantized_expr);
+    } else {
+      requantized_exprs.push_back(quantized_expr);
+    }
+    idx++;
+  }
+  return Concatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
+}
+
+RELAY_REGISTER_OP("qnn.concatenate")
+.describe(R"code(Concatenate the quantized input tensors along the given axis.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to concatenate.")
+.set_support_level(11)
+.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
+.set_attr<FTVMLegalize>("FTVMLegalize", ConcatenateLegalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.concatenate")
+.set_body_typed(MakeQnnConcatenate);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
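
For intuition on the legalization above: each input whose qnn params differ from the output's is fed through RequantizeLower, which conceptually computes q_out = zero_point_out + round((scale_in / scale_out) * (q_in - zero_point_in)), clamped to the output dtype. A NumPy sketch of that reference behaviour, assuming round-to-nearest to match `rounding = "TONEAREST"` (`requantize_ref` is a hypothetical helper, not TVM code):

    import numpy as np

    def requantize_ref(q_in, s_in, zp_in, s_out, zp_out, dtype="uint8"):
        real = s_in * (q_in.astype(np.float64) - zp_in)  # dequantize
        q = np.round(real / s_out) + zp_out              # requantize
        info = np.iinfo(np.dtype(dtype))
        return np.clip(q, info.min, info.max).astype(dtype)

    # q_in=8 at scale 0.25 represents 2.0; at scale 0.5 that becomes q=4.
    print(requantize_ref(np.array([8]), 0.25, 0, 0.5, 0))  # [4]
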
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index 1ada7ecd070e8..b10707a88d36c 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -67,6 +67,9 @@ static inline const int32_t GetQmax(const DataType& dtype) {
   }
 }
 
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape);
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_qnn_concatenate.py b/tests/python/relay/test_qnn_concatenate.py
index b0745cf251c48..0a595435d81ab 100644
--- a/tests/python/relay/test_qnn_concatenate.py
+++ b/tests/python/relay/test_qnn_concatenate.py
@@ -39,7 +39,6 @@ def test_same_io_qnn_params():
                                    axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 0
     mod = relay.Module.from_expr(func)
     mod = relay.transform.Legalize()(mod)
     func = mod["main"]
@@ -68,7 +67,6 @@ def test_different_io_qnn_params():
                                    axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 2
     mod = relay.Module.from_expr(func)
     mod = relay.transform.Legalize()(mod)
     func = mod["main"]
@@ -97,7 +95,6 @@ def test_few_same_io_qnn_params():
                                    axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
    mod = relay.Module.from_expr(func)
     mod = relay.transform.Legalize()(mod)
     func = mod["main"]
@@ -126,7 +123,6 @@ def test_same_i_qnn_params():
                                    axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
     mod = relay.transform.Legalize()(mod)
     func = mod["main"]
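
As a closing note on the tests: when all input and output qnn params match, quantized concatenation is exact, so a NumPy golden reference is just a plain concatenate of the quantized values; a sketch of such a check (array contents illustrative, not from the test file):

    import numpy as np

    x_data = np.random.randint(0, 256, size=(1, 64)).astype("uint8")
    y_data = np.random.randint(0, 256, size=(1, 64)).astype("uint8")
    # With identical input/output scales and zero points, the expected
    # output of qnn.concatenate is a plain concatenate of the inputs.
    golden = np.concatenate((x_data, y_data), axis=1)
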