
[QNN] Requantize operator #3531

Merged: 37 commits (Aug 8, 2019)
847dd52
[Relay] [Quantization] WIP - Common files for the quantization work.
Jul 8, 2019
ed11cd7
[Relay] [Quantization] WIP - Prototyping requantize op.
Jul 8, 2019
91b58a5
Requantize operator implementation.
anijain2305 Jul 10, 2019
13fcc70
Typo and lint fixes.
anijain2305 Jul 10, 2019
ac4dfdc
Doc fix.
anijain2305 Jul 10, 2019
01cad3a
Uncommenting the lint script (fixing mistake).
anijain2305 Jul 10, 2019
6405755
Modifying the unit tests.
anijain2305 Jul 10, 2019
7a49bee
Moving C++ files into src/relay/qnn
anijain2305 Jul 11, 2019
154e64f
Moving python files to python/tvm/relay/qnn. Some minor fixes.
anijain2305 Jul 11, 2019
324e75c
Moving the attrs.h inside the include directory.
anijain2305 Jul 11, 2019
ffec47f
Pushing files that I forgot earlier. Changing util location.
anijain2305 Jul 11, 2019
72436ff
Incorporating comments. API change. Lint fixes.
anijain2305 Jul 15, 2019
9a721ad
Modifying the GetFixedPointMultiplierShift API as per comments.
anijain2305 Jul 15, 2019
fb9cece
Forgot the dialect change.
anijain2305 Jul 15, 2019
be7101f
Changing rewrite to qnn_lower.
anijain2305 Jul 15, 2019
0a5642a
Renaming Quantize to Qnn for clarity.
anijain2305 Jul 15, 2019
a9c1ce0
Remove use_int_domain.
anijain2305 Jul 17, 2019
a0d0324
Incorporating review comments.
anijain2305 Jul 19, 2019
513b544
Adding API doc for QNN dialect.
anijain2305 Jul 19, 2019
435ca27
Move the qnn_lower pass to transform namespace.
anijain2305 Jul 19, 2019
e4f6a4e
Moving from expr to module. Adding namespace in C++.
anijain2305 Jul 19, 2019
10a20d3
Minor sentence rewrites. Added qnn namespace.
anijain2305 Jul 19, 2019
927825d
Added the API doc.
anijain2305 Jul 19, 2019
48f5a52
Changing default out_dtype to int8. Adding a test with in/out_dtype a…
anijain2305 Jul 19, 2019
1422f6d
Style fixes. Better error messages.
anijain2305 Jul 19, 2019
66a4d76
Adding documentation.
anijain2305 Jul 22, 2019
99483c2
More documentation fixes.
anijain2305 Jul 22, 2019
f8439e6
Adding out dtype check for requantize.
anijain2305 Jul 22, 2019
e756843
Adding corner case for FP32 to fixed point conversion.
anijain2305 Jul 22, 2019
5d7938f
Adding extra line.
anijain2305 Jul 22, 2019
10ce99d
Documentation fix.
anijain2305 Jul 22, 2019
f2e09d1
Adding static inline.
anijain2305 Jul 23, 2019
65c0b46
Incorporating jackwish comment. Removed idtype from requantize lowering.
anijain2305 Jul 24, 2019
8d2c3ad
Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/…
anijain2305 Jul 26, 2019
2d15b54
Style fixes.
anijain2305 Jul 29, 2019
ff17a91
Fix the docs.
anijain2305 Aug 2, 2019
c46b56c
Move to Legalize API.
anijain2305 Aug 6, 2019
15 changes: 15 additions & 0 deletions docs/langref/relay_op.rst
@@ -202,6 +202,16 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.contrib.adaptive_avg_pool2d


**Level 11: Dialect Operators**

This level supports dialect operators.

.. autosummary::
:nosignatures:

tvm.relay.qnn.op.requantize


Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
@@ -340,3 +350,8 @@ Level 10 Definitions
.. autofunction:: tvm.relay.nn.batch_matmul
.. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
.. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d


Level 11 Definitions
--------------------
.. autofunction:: tvm.relay.qnn.op.requantize
71 changes: 71 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/qnn/attrs.h
* \brief Auxiliary attributes for qnn operators.
*/
#ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {
namespace qnn {

/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
double input_scale;
int32_t input_zero_point;
double output_scale;
int32_t output_zero_point;
std::string rounding;
DataType out_dtype;

TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
TVM_ATTR_FIELD(input_scale)
.describe("The scale of the input tensor.");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(output_scale)
.describe("The scale of the output tensor.");
TVM_ATTR_FIELD(output_zero_point)
.describe("The zero point of the output tensor.");
TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
Member comment:

Please double-check the choice of the enum "TONEAREST". Are there existing API choices that are similar? I have no preference, but it would be great to survey the related APIs. Should it be "TO_NEAREST"? (The TF API convention seems to be TO_NEAREST.)

Contributor (author) reply:

I had a similar question while coding this. I used TONEAREST because the supporting document, https://www.gnu.org/software/libc/manual/html_node/Rounding.html, has FE_TONEAREST. To be consistent with that document, I kept TONEAREST. But I like TO_NEAREST better, as it aligns with the rest of the codebase.

.describe("Defines the rounding direction when the value is midway between "
"two representable values. There are two supported modes - UPWARD "
"or TONEAREST. Both modes behave exactly the same except at the "
"midpoints between two representable values. At the midpoint, "
"UPWARD rounds towards positive infinity (for example, -1.5 is "
"rounded to -1). TONEAREST is the standard rounding, where the "
"value is rounded away from zero at midpoints (for example, -1.5 "
"rounds to -2). More context can be found in the glibc manual: "
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_QNN_ATTRS_H_
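The midpoint behavior of the two rounding modes documented in the `rounding` attribute above can be sketched in plain Python. This is an illustration only, not the PR's actual lowering (which works in fixed-point C++); the helper names are hypothetical:

```python
import math

def round_upward(x):
    # UPWARD: midpoints round towards positive infinity, e.g. -1.5 -> -1.
    return math.floor(x + 0.5)

def round_to_nearest(x):
    # TONEAREST (as described in the attribute): midpoints round away
    # from zero, e.g. -1.5 -> -2.
    return math.floor(x + 0.5) if x >= 0 else math.ceil(x - 0.5)
```

Both helpers agree on non-midpoint values; they differ only when the input is exactly halfway between two integers.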
3 changes: 3 additions & 0 deletions python/tvm/relay/__init__.py
@@ -53,6 +53,9 @@
from . import backend
from . import quantize

# Dialects
from . import qnn

from .scope_builder import ScopeBuilder

# Span
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""QNN dialect operators and IR passes."""
from __future__ import absolute_import as _abs
from . import op
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/op/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""QNN dialect operators."""
from __future__ import absolute_import as _abs
from .qnn import *
20 changes: 20 additions & 0 deletions python/tvm/relay/qnn/op/_make.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
from ...._ffi.function import _init_api

_init_api("relay.qnn.op._make", __name__)
74 changes: 74 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from . import _make

def requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding="TONEAREST",
out_dtype="int8"):
r"""Requantize operator.

The requantize operator converts one quantized tensor representation to
another quantized tensor representation. For the output tensor, we are
provided with output scale and zero point. The computation is as follows

Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)

Parameters
----------
data : tvm.relay.Expr
The input data to the operator.

input_scale: float
The quantization scale for the input tensor.

input_zero_point: int
The zero point of the input tensor.

output_scale: float
The quantization scale for the output tensor.

output_zero_point: int
The zero point of the output tensor.

rounding : string, optional
Defines the rounding direction when the value is midway between two
representable values.

out_dtype : str, optional
Specifies the output data type.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""

return _make.requantize(data,
input_scale,
input_zero_point,
output_scale,
output_zero_point,
rounding,
out_dtype)
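The requantize computation documented in the docstring above can be written as a floating-point reference in a few lines of Python. This sketches the math only (the operator itself is lowered to fixed-point arithmetic); the function name and the int8 clamp range are assumptions:

```python
def requantize_ref(q_input, input_scale, input_zero_point,
                   output_scale, output_zero_point,
                   qmin=-128, qmax=127):
    # Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
    real_value = input_scale * (q_input - input_zero_point)
    q_output = output_zero_point + real_value / output_scale
    # Round to the nearest integer, then clamp to the output dtype range
    # (int8 here, matching the default out_dtype).
    return max(qmin, min(qmax, round(q_output)))

# Re-express a value quantized with scale 0.5 in a scale-0.25 representation:
# requantize_ref(10, 0.5, 0, 0.25, 0) -> 20
```

Note that Python's built-in `round` uses half-to-even rounding at midpoints, so a faithful reference for the op's UPWARD/TONEAREST modes would substitute the corresponding midpoint rule.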
20 changes: 20 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -394,6 +394,26 @@ inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, b
}


static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return CallNode::make(op, {condition, x, y});
}

static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
static const Op& op = Op::Get("greater_equal");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

static inline Expr Full(Expr fill_value,
Array<IndexExpr> shape,
DataType dtype) {
auto attrs = make_node<InitOpAttrs>();
attrs->shape = std::move(shape);
attrs->dtype = std::move(dtype);
static const Op& op = Op::Get("full");
return CallNode::make(op, {fill_value}, Attrs(attrs), {});
}

Expr MakeConcatenate(Expr data, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);