[QNN] Concat - Refactoring to C++ #3819

Merged
merged 1 commit on Aug 31, 2019
30 changes: 29 additions & 1 deletion include/tvm/relay/qnn/attrs.h
@@ -88,7 +88,7 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
  int32_t input_zero_point;
  double input_scale;

-  TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
+  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
    TVM_ATTR_FIELD(input_zero_point)
        .describe("The zero_point for the input tensor of this op.");

@@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
  }
};

/*! \brief Attributes used in QNN concatenate operator */
struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
  Array<tvm::Expr> input_scales;
  Array<tvm::Expr> input_zero_points;
  double output_scale;
  int32_t output_zero_point;
  int axis;

  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
    TVM_ATTR_FIELD(input_scales)
        .describe("The list of scales of input quantized tensors.");

    TVM_ATTR_FIELD(input_zero_points)
        .describe("The list of zero points of input quantized tensors.");

    TVM_ATTR_FIELD(output_zero_point)
        .describe("The zero_point for the output tensor.");

    TVM_ATTR_FIELD(output_scale)
        .describe("The scale for the output tensor.");

    TVM_ATTR_FIELD(axis)
        .describe("The axis at which the input arrays are concatenated. "
                  "Should lie in range `[-ndim, ndim)`.")
        .set_default(0);
  }
};  // struct QnnConcatenateAttrs

} // namespace qnn
} // namespace relay
} // namespace tvm
55 changes: 15 additions & 40 deletions python/tvm/relay/qnn/op/qnn.py
@@ -18,7 +18,8 @@
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
-from tvm import relay
+from tvm.expr import FloatImm, IntImm
+from tvm.relay.expr import Tuple
from . import _make

def requantize(data,
@@ -134,6 +135,8 @@ def dequantize(data,
    return _make.dequantize(data,
                            input_scale,
                            input_zero_point)


def concatenate(data,
                input_scales,
                input_zero_points,
@@ -169,42 +172,14 @@ def concatenate(data,
"""

data = list(data)
requantized_exprs = list(data)

# Find the dtype of the input expr. This is required for the requantize op. Since, this is
# concatenate op, the dtype of the input is same as dtype of the output.
mod = relay.Module.from_expr(data[0])
mod = relay.transform.InferType()(mod)
entry = mod["main"]
data0 = entry if isinstance(data[0], relay.Function) else entry.body
in_dtype = data0.checked_type.dtype

# First check if all the input qnn params match. If yes, we can call concatenate first, followed
# by a requantize.
if all(scale == input_scales[0] for scale in input_scales)\
and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
out = relay.concatenate(tuple(data), axis)
input_scale = input_scales[0]
input_zero_point = input_zero_points[0]
if input_scale != output_scale or input_zero_point != output_zero_point:
out = requantize(data=out,
input_scale=input_scales[0],
input_zero_point=input_zero_points[0],
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return out

# If the output qnn params do not match the input qnn params, we can call requantize on the
# input expr first, followed by a concatenate on the requantized input exprs.
for idx, quantized_expr in enumerate(data):
input_scale = input_scales[idx]
input_zero_point = input_zero_points[idx]
if input_scale != output_scale or input_zero_point != output_zero_point:
requantized_exprs[idx] = requantize(data=quantized_expr,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return relay.concatenate(tuple(requantized_exprs), axis)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")

return _make.concatenate(Tuple(data),
[FloatImm("float64", x) for x in input_scales],
[IntImm("int32", x) for x in input_zero_points],
output_scale,
output_zero_point,
axis)
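
For reference, a minimal usage sketch of the refactored Python API. The shapes, scales, and zero points below are illustrative, not taken from this PR; the concatenate signature and the CanonicalizeOps pass follow the diff above and the updated tests below.

# Illustrative sketch only; assumes TVM as of this PR.
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 4), dtype="uint8")
y = relay.var("y", shape=(2, 4), dtype="uint8")
z = relay.qnn.op.concatenate((x, y),
                             input_scales=(0.25, 0.5),
                             input_zero_points=(2, 1),
                             output_scale=0.5,
                             output_zero_point=1,
                             axis=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
# CanonicalizeOps lowers qnn.concatenate into requantize + concatenate.
mod = relay.qnn.transform.CanonicalizeOps()(mod)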
134 changes: 134 additions & 0 deletions src/relay/qnn/op/concatenate.cc
@@ -0,0 +1,134 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 * \file src/relay/qnn/op/concatenate.cc
 * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
 */

#include <tvm/ir.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h"
#include "../../pass/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);

Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
                        Array<tvm::Expr> input_zero_points, double output_scale,
                        int32_t output_zero_point, int axis) {
  auto attrs = make_node<QnnConcatenateAttrs>();
  attrs->input_scales = std::move(input_scales);
  attrs->input_zero_points = std::move(input_zero_points);
  attrs->output_scale = output_scale;
  attrs->output_zero_point = output_zero_point;
  attrs->axis = axis;
  static const Op& op = Op::Get("qnn.concatenate");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

/*
 * \brief Canonicalizes the QNN concatenate op.
 * \param attrs The QNN concatenate attrs.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for the concatenate op.
 */
Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                                const Array<tvm::relay::Type>& arg_types) {
  // Get the attrs.
  CHECK_EQ(new_args.size(), 1);
  auto& data = new_args[0];
  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
  CHECK(concatenate_attrs != nullptr);
  auto input_scales = concatenate_attrs->input_scales;
  auto input_zero_points = concatenate_attrs->input_zero_points;
  auto output_scale = concatenate_attrs->output_scale;
  auto output_zero_point = concatenate_attrs->output_zero_point;

  // Get the input dtype and shape.
  CHECK_GE(arg_types.size(), 1);
  auto tuple_type = arg_types[0].as<TupleTypeNode>();
  CHECK(tuple_type != nullptr);

  // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize
  // at the start, we can insert requantize at the end if and only if all the input tensors have
  // the same qnn params. This can be done in the future.

  // If the output qnn params do not match the input qnn params, we can call requantize on the
  // input exprs first, followed by a concatenate on the requantized input exprs.

  auto tuple_data = data.as<TupleNode>();
  CHECK(tuple_data != nullptr);

  int idx = 0;
  Array<Expr> requantized_exprs;
  for (auto quantized_expr : tuple_data->fields) {
    // Get the input scale for the idx-th quantized input tensor.
    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
    CHECK(input_scale_expr != nullptr);
    auto input_scale = input_scale_expr->value;

    // Get the zero point for the idx-th quantized input tensor.
    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
    CHECK(input_zero_point_expr != nullptr);
    auto input_zero_point = input_zero_point_expr->value;

    // Check if the output and input qnn params are the same. If not, requantize.
    if (input_scale != output_scale || input_zero_point != output_zero_point) {
      // Get the input shape and dtype.
      auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
      auto input_dtype = tensor_type->dtype;
      auto input_shape = tensor_type->shape;

      // Requantize the input.
      auto requantized_expr = Requantize(quantized_expr, input_shape, input_scale,
                                         input_zero_point, output_scale, output_zero_point,
                                         input_dtype);
      requantized_exprs.push_back(requantized_expr);
    } else {
      requantized_exprs.push_back(quantized_expr);
    }
    idx++;
  }
  return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
}

RELAY_REGISTER_OP("qnn.concatenate")
.describe(R"code(Concatenate the quantized input tensors along the given axis.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to concatenate.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.concatenate")
.set_body_typed(MakeQnnConcatenate);

} // namespace qnn
} // namespace relay
} // namespace tvm
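
The per-input requantize above maps each input's integer values onto the output's quantization grid, and is skipped when the input and output qnn params already match. A plain-Python sketch of that arithmetic, with made-up numbers (the real qnn.requantize lowering additionally works in fixed point, saturates to the output dtype, and supports configurable rounding):

# real_value = scale * (q - zero_point); requantizing re-expresses that real
# value using the output (scale, zero_point). Illustrative numbers only.
def requantize(q, input_scale, input_zero_point, output_scale, output_zero_point):
    real_value = input_scale * (q - input_zero_point)
    return round(real_value / output_scale) + output_zero_point

# Input quantized as (scale=0.25, zp=2), output params (scale=0.5, zp=1):
assert requantize(10, 0.25, 2, 0.5, 1) == 5   # 0.25*(10-2) = 2.0 -> 2.0/0.5 + 1
# Matching params are a no-op, which is why the else-branch keeps the expr as-is:
assert requantize(10, 0.5, 1, 0.5, 1) == 10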
19 changes: 19 additions & 0 deletions src/relay/qnn/util.h
@@ -28,6 +28,8 @@
#include <tvm/expr.h>
#include <tvm/relay/expr.h>
#include <limits>
#include <string>
#include <utility>

namespace tvm {
namespace relay {
@@ -67,6 +69,23 @@ static inline const int32_t GetQmax(const DataType& dtype) {
}
}

Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype);

static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
                              double input_scale, int32_t input_zero_point, double output_scale,
                              int32_t output_zero_point, const DataType& out_dtype,
                              const std::string& rounding = "TONEAREST") {
  auto attrs = make_node<RequantizeAttrs>();
  attrs->input_scale = std::move(input_scale);
  attrs->input_zero_point = std::move(input_zero_point);
  attrs->output_scale = std::move(output_scale);
  attrs->output_zero_point = std::move(output_zero_point);
  attrs->rounding = std::move(rounding);
  attrs->out_dtype = std::move(out_dtype);
  return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype);
}

} // namespace qnn
} // namespace relay
} // namespace tvm
4 changes: 0 additions & 4 deletions tests/python/relay/test_qnn_concatenate.py
@@ -39,7 +39,6 @@ def test_same_io_qnn_params():
                                 axis=axis)

    func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 0
    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]
@@ -68,7 +67,6 @@ def test_different_io_qnn_params():
                                 axis=axis)

    func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 2
    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]
@@ -97,7 +95,6 @@ def test_few_same_io_qnn_params():
                                 axis=axis)

    func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]
@@ -126,7 +123,6 @@ def test_same_i_qnn_params():
                                 axis=axis)

    func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    func = mod["main"]