Skip to content

Commit

Permalink
[QNN] Concat - Refactoring to C++ (apache#3819)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and wweic committed Sep 16, 2019
1 parent f1639a8 commit 11d62df
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 45 deletions.
30 changes: 29 additions & 1 deletion include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
int32_t input_zero_point;
double input_scale;

TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero_point for the input tensor of this op.");

Expand All @@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
}
};

/*! \brief Attributes used in QNN concatenate operator */
struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
Array<tvm::Expr> input_scales;
Array<tvm::Expr> input_zero_points;
double output_scale;
int32_t output_zero_point;
int axis;

TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
TVM_ATTR_FIELD(input_scales)
.describe("The list of scales of input quantized tensors.");

TVM_ATTR_FIELD(input_zero_points)
.describe("The list of zero points of input quantized tensors.");

TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the output tensor.");

TVM_ATTR_FIELD(output_scale)
.describe("The scale for the output tensor.");

TVM_ATTR_FIELD(axis)
.describe("The axis at which the input arrays are concatenated."
"Should lie in range `[-ndim, ndim)`.")
.set_default(0);
}
}; // struct QnnConcatenateAttrs

} // namespace qnn
} // namespace relay
} // namespace tvm
Expand Down
55 changes: 15 additions & 40 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from tvm import relay
from tvm.expr import FloatImm, IntImm
from tvm.relay.expr import Tuple
from . import _make

def requantize(data,
Expand Down Expand Up @@ -134,6 +135,8 @@ def dequantize(data,
return _make.dequantize(data,
input_scale,
input_zero_point)


def concatenate(data,
input_scales,
input_zero_points,
Expand Down Expand Up @@ -169,42 +172,14 @@ def concatenate(data,
"""

data = list(data)
requantized_exprs = list(data)

# Find the dtype of the input expr. This is required for the requantize op. Since, this is
# concatenate op, the dtype of the input is same as dtype of the output.
mod = relay.Module.from_expr(data[0])
mod = relay.transform.InferType()(mod)
entry = mod["main"]
data0 = entry if isinstance(data[0], relay.Function) else entry.body
in_dtype = data0.checked_type.dtype

# First check if all the input qnn params match. If yes, we can call concatenate first, followed
# by a requantize.
if all(scale == input_scales[0] for scale in input_scales)\
and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
out = relay.concatenate(tuple(data), axis)
input_scale = input_scales[0]
input_zero_point = input_zero_points[0]
if input_scale != output_scale or input_zero_point != output_zero_point:
out = requantize(data=out,
input_scale=input_scales[0],
input_zero_point=input_zero_points[0],
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return out

# If the output qnn params do not match the input qnn params, we can call requantize on the
# input expr first, followed by a concatenate on the requantized input exprs.
for idx, quantized_expr in enumerate(data):
input_scale = input_scales[idx]
input_zero_point = input_zero_points[idx]
if input_scale != output_scale or input_zero_point != output_zero_point:
requantized_exprs[idx] = requantize(data=quantized_expr,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return relay.concatenate(tuple(requantized_exprs), axis)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")

return _make.concatenate(Tuple(data),
[FloatImm("float64", x) for x in input_scales],
[IntImm("int32", x) for x in input_zero_points],
output_scale,
output_zero_point,
axis)
134 changes: 134 additions & 0 deletions src/relay/qnn/op/concatenate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/qnn/op/concatenate.cc
* \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
*/

#include <tvm/ir.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h"
#include "../../pass/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);

Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
Array<tvm::Expr> input_zero_points, double output_scale,
int32_t output_zero_point, int axis) {
auto attrs = make_node<QnnConcatenateAttrs>();
attrs->input_scales = std::move(input_scales);
attrs->input_zero_points = std::move(input_zero_points);
attrs->output_scale = output_scale;
attrs->output_zero_point = output_zero_point;
attrs->axis = axis;
static const Op& op = Op::Get("qnn.concatenate");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

/*
* \brief Canonicalizes the QNN concatenate op.
* \param attrs The QNN concatenate attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for concatenate op.
*/
Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the attrs.
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
CHECK(concatenate_attrs != nullptr);
auto input_scales = concatenate_attrs->input_scales;
auto input_zero_points = concatenate_attrs->input_zero_points;
auto output_scale = concatenate_attrs->output_scale;
auto output_zero_point = concatenate_attrs->output_zero_point;

// Get the input dtype and shape.
CHECK_GE(arg_types.size(), 1);
auto tuple_type = arg_types[0].as<TupleTypeNode>();
CHECK(tuple_type != nullptr);

// FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in
// the start, we can insert requantize at the end if and only if all the input tensors have same
// qnn params. This can be done in future.

// If the output qnn params do not match the input qnn params, we can call requantize on the input
// expr first, followed by a concatenate on the requantized input exprs.

auto tuple_data = data.as<TupleNode>();
CHECK(tuple_data != nullptr);

int idx = 0;
Array<Expr> requantized_exprs;
for (auto quantized_expr : tuple_data->fields) {
// Get the input scale for the idx quantized input tensor.
auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
CHECK(input_scale_expr != nullptr);
auto input_scale = input_scale_expr->value;

// Get the zero point for the idx quantized input tensor.
auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
CHECK(input_zero_point_expr != nullptr);
auto input_zero_point = input_zero_point_expr->value;

// Check if output and input qnn params are same. If not, requantize.
if (input_scale != output_scale || input_zero_point != output_zero_point) {
// Get the input shape and dtype.
auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
auto input_dtype = tensor_type->dtype;
auto input_shape = tensor_type->shape;

// Requantize the input.
auto requantized_expr = Requantize(quantized_expr, input_shape, input_scale, input_zero_point,
output_scale, output_zero_point, input_dtype);
requantized_exprs.push_back(requantized_expr);
} else {
requantized_exprs.push_back(quantized_expr);
}
idx++;
}
return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
}

RELAY_REGISTER_OP("qnn.concatenate")
.describe(R"code(Concatenate the quantized input tensors along the given axis.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to concatenate.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.concatenate")
.set_body_typed(MakeQnnConcatenate);

} // namespace qnn
} // namespace relay
} // namespace tvm
19 changes: 19 additions & 0 deletions src/relay/qnn/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <tvm/expr.h>
#include <tvm/relay/expr.h>
#include <limits>
#include <string>
#include <utility>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -67,6 +69,23 @@ static inline const int32_t GetQmax(const DataType& dtype) {
}
}

Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
const Array<IndexExpr>& input_shape, const DataType& out_dtype);

static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
double input_scale, int32_t input_zero_point, double output_scale,
int32_t output_zero_point, const DataType& out_dtype,
const std::string& rounding = "TONEAREST") {
auto attrs = make_node<RequantizeAttrs>();
attrs->input_scale = std::move(input_scale);
attrs->input_zero_point = std::move(input_zero_point);
attrs->output_scale = std::move(output_scale);
attrs->output_zero_point = std::move(output_zero_point);
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype);
}

} // namespace qnn
} // namespace relay
} // namespace tvm
Expand Down
4 changes: 0 additions & 4 deletions tests/python/relay/test_qnn_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def test_same_io_qnn_params():
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 0
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
Expand Down Expand Up @@ -68,7 +67,6 @@ def test_different_io_qnn_params():
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 2
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
Expand Down Expand Up @@ -97,7 +95,6 @@ def test_few_same_io_qnn_params():
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 1
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
Expand Down Expand Up @@ -126,7 +123,6 @@ def test_same_i_qnn_params():
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 1
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]
Expand Down

0 comments on commit 11d62df

Please sign in to comment.