Skip to content

Commit

Permalink
[QNN] Concat - Refactoring to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Aug 22, 2019
1 parent d3eb9cb commit 41c82e4
Show file tree
Hide file tree
Showing 8 changed files with 317 additions and 125 deletions.
28 changes: 28 additions & 0 deletions include/tvm/relay/qnn/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
}
};

/*! \brief Attributes used in QNN concatenate operators */
struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
  Array<tvm::Expr> input_scales;       // quantization scale of each input tensor
  Array<tvm::Expr> input_zero_points;  // quantization zero point of each input tensor
  double output_scale;                 // quantization scale of the output tensor
  int32_t output_zero_point;           // quantization zero point of the output tensor
  int axis;                            // concatenation axis, in `[-ndim, ndim)`

  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
    TVM_ATTR_FIELD(input_scales)
        .describe("The list of scales of input quantized tensors.");

    TVM_ATTR_FIELD(input_zero_points)
        .describe("The list of zero points of input quantized tensors.");

    TVM_ATTR_FIELD(output_zero_point)
        .describe("The zero_point for the output tensor.");

    TVM_ATTR_FIELD(output_scale)
        .describe("The scale for the output tensor.");

    TVM_ATTR_FIELD(axis)
        // NOTE: trailing space added so the two literal pieces don't fuse into
        // "concatenated.Should" in the rendered description.
        .describe("The axis at which the input arrays are concatenated. "
                  "Should lie in range `[-ndim, ndim)`.")
        .set_default(0);
  }
};  // struct QnnConcatenateAttrs

} // namespace qnn
} // namespace relay
} // namespace tvm
Expand Down
55 changes: 15 additions & 40 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from tvm import relay
from tvm.expr import FloatImm, IntImm
from tvm.relay.expr import Tuple
from . import _make

def requantize(data,
Expand Down Expand Up @@ -134,6 +135,8 @@ def dequantize(data,
return _make.dequantize(data,
input_scale,
input_zero_point)


def concatenate(data,
input_scales,
input_zero_points,
Expand Down Expand Up @@ -169,42 +172,14 @@ def concatenate(data,
"""

data = list(data)
requantized_exprs = list(data)

# Find the dtype of the input expr. This is required for the requantize op. Since, this is
# concatenate op, the dtype of the input is same as dtype of the output.
mod = relay.Module.from_expr(data[0])
mod = relay.transform.InferType()(mod)
entry = mod["main"]
data0 = entry if isinstance(data[0], relay.Function) else entry.body
in_dtype = data0.checked_type.dtype

# First check if all the input qnn params match. If yes, we can call concatenate first, followed
# by a requantize.
if all(scale == input_scales[0] for scale in input_scales)\
and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
out = relay.concatenate(tuple(data), axis)
input_scale = input_scales[0]
input_zero_point = input_zero_points[0]
if input_scale != output_scale or input_zero_point != output_zero_point:
out = requantize(data=out,
input_scale=input_scales[0],
input_zero_point=input_zero_points[0],
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return out

# If the output qnn params do not match the input qnn params, we can call requantize on the
# input expr first, followed by a concatenate on the requantized input exprs.
for idx, quantized_expr in enumerate(data):
input_scale = input_scales[idx]
input_zero_point = input_zero_points[idx]
if input_scale != output_scale or input_zero_point != output_zero_point:
requantized_exprs[idx] = requantize(data=quantized_expr,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return relay.concatenate(tuple(requantized_exprs), axis)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")

return _make.concatenate(Tuple(data),
[FloatImm("float64", x) for x in input_scales],
[IntImm("int32", x) for x in input_zero_points],
output_scale,
output_zero_point,
axis)
83 changes: 2 additions & 81 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "transform.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -210,86 +211,6 @@ RELAY_REGISTER_OP("expand_dims")
// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);

bool ConcatenateRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
/* If we receive a tuple we can continue, if we receive
* anything but an incomplete type we should signal an
* error.
*/
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
throw relay::Error(
RELAY_ERROR(
"concatenate requires a tuple of tensors as the first argument, found "
<< PrettyPrint(types[0])));
} else if (types[0].as<IncompleteTypeNode>() != nullptr) {
return false;
}

const auto* param = attrs.as<ConcatenateAttrs>();
if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
return false;
}
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
// Sanity check: ndim and dtype.
const int ndim = static_cast<int>(first->shape.size());
const DataType dtype = first->dtype;

for (const Type& ele : tensor_tuple->fields) {
if (ele.as<IncompleteTypeNode>()) {
return false;
}

const auto& e = Downcast<TensorType>(ele);

int e_ndim = static_cast<int>(e->shape.size());
const DataType& e_dtype = e->dtype;
if (e_ndim != ndim) {
throw relay::Error("relay.concatenate requires all tensors have the same ndim");
}
if (e_dtype != dtype) {
throw relay::Error("relay.concatenate requires all tensors have the same dtype");
}
}
// Sanity check: axis
int axis = param->axis;
if (!(-ndim <= axis && axis < ndim)) {
throw relay::Error(RELAY_ERROR(
"concatenate only accepts `axis` in [-ndim, ndim)" <<
", but got axis = " << axis <<
", and ndim = " << ndim));
}
axis = axis < 0 ? ndim + axis : axis;
// Calculate shape
std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
IndexExpr &concat_dim = oshape[axis];
bool has_any = false;
if (concat_dim.as<Any>()) {
has_any = true;
} else {
for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
if (e->shape[axis].as<Any>()) {
has_any = true;
break;
}
concat_dim += e->shape[axis];
}
}

if (has_any) {
concat_dim = Any::make();
}

auto rtype = TensorTypeNode::make(oshape, dtype);
reporter->Assign(types[1], rtype);
return true;
}

Array<Tensor> ConcatenateCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
Expand Down Expand Up @@ -358,7 +279,7 @@ RELAY_REGISTER_OP("concatenate")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.add_type_rel("Concatenate", ConcatenateRel)
.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
.set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
Expand Down
122 changes: 122 additions & 0 deletions src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/op/tensor/transform.h
 * \brief Transform op attributes that can be shared among Relay and its dialects.
*/
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

/*!
 * \brief Type relation for concatenate-like operators, shared between Relay
 *        concatenate and its dialects (e.g. QNN concatenate).
 *
 * \tparam AttrType The attrs node type; it only needs an `axis` field.
 *
 * types = [input tuple type, result type]. Returns false to defer inference
 * while any input type is still incomplete; throws a relay::Error on a
 * genuine type error (non-tuple input, ndim/dtype mismatch, bad axis).
 */
template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  // types: [data, result]
  CHECK_EQ(types.size(), 2);
  // Defer while the input type has not been inferred yet. This must be
  // checked before the tuple cast: an IncompleteType also fails
  // as<TupleTypeNode>(), and would otherwise be misreported as a type error.
  if (types[0].as<IncompleteTypeNode>() != nullptr) {
    return false;
  }
  // Anything other than a tuple of tensors is a hard error.
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    throw relay::Error(
        RELAY_ERROR(
          "concatenate requires a tuple of tensors as the first argument, found "
          << PrettyPrint(types[0])));
  }

  const auto* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
    return false;
  }
  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
  // Sanity check: every field must agree with the first on ndim and dtype.
  const int ndim = static_cast<int>(first->shape.size());
  const DataType dtype = first->dtype;

  for (const Type& ele : tensor_tuple->fields) {
    if (ele.as<IncompleteTypeNode>()) {
      return false;
    }

    const auto& e = Downcast<TensorType>(ele);

    int e_ndim = static_cast<int>(e->shape.size());
    const DataType& e_dtype = e->dtype;
    if (e_ndim != ndim) {
      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
    }
    if (e_dtype != dtype) {
      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
    }
  }
  // Sanity check: axis must lie in [-ndim, ndim); negative axes count from the end.
  int axis = param->axis;
  if (!(-ndim <= axis && axis < ndim)) {
    throw relay::Error(RELAY_ERROR(
        "concatenate only accepts `axis` in [-ndim, ndim)" <<
        ", but got axis = " << axis <<
        ", and ndim = " << ndim));
  }
  axis = axis < 0 ? ndim + axis : axis;
  // Output shape = first input's shape with the concat axis summed over all inputs.
  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
  IndexExpr &concat_dim = oshape[axis];
  bool has_any = false;
  if (concat_dim.as<Any>()) {
    has_any = true;
  } else {
    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
      if (e->shape[axis].as<Any>()) {
        has_any = true;
        break;
      }
      concat_dim += e->shape[axis];
    }
  }

  // Any dynamic extent along the axis makes the whole output extent dynamic.
  if (has_any) {
    concat_dim = Any::make();
  }

  auto rtype = TensorTypeNode::make(oshape, dtype);
  reporter->Assign(types[1], rtype);
  return true;
}

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_
7 changes: 7 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,13 @@ static inline Expr Full(Expr fill_value,
return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}

// Build a call to the "concatenate" operator that joins the tensors in the
// tuple expression `data` along dimension `axis`.
static inline Expr Concatenate(Expr data, int axis) {
  static const Op& concat_op = Op::Get("concatenate");
  auto concat_attrs = make_node<ConcatenateAttrs>();
  concat_attrs->axis = axis;
  return CallNode::make(concat_op, {data}, Attrs(concat_attrs), {});
}

Expr MakeConcatenate(Expr data, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
Expand Down
Loading

0 comments on commit 41c82e4

Please sign in to comment.