Skip to content

Commit

Permalink
[QNN] Add operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Aug 8, 2019
1 parent 3ac27fc commit 781590d
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
66 changes: 66 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import absolute_import as _abs
from . import _make
from tvm import relay

def requantize(data,
input_scale,
Expand Down Expand Up @@ -72,3 +73,68 @@ def requantize(data,
output_zero_point,
rounding,
out_dtype)


def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
output_zero_point, out_dtype):
"""Quantized addition with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data.
rhs : relay.Expr
The right hand side input data.
lhs_scale: float
The scale of the lhs expr.
lhs_zero_point: int
The zero point of lhs expr.
rhs_scale: float
The scale of the rhs expr.
rhs_zero_point: int
The zero point of rhs expr.
output_scale: float
The scale of the output expr.
output_zero_point: int
The zero point of output expr.
out_dtype : str
Specifies the output data type.
Returns
-------
result : relay.Expr
The computed result.
"""

# Since the input qnn params can be different than output qnn params, we first requantize the
# input tensors to the output qnn params. Then we call relay.add on the requantized inputs. TF
# follows similar handling for quantized operators.

requantized_lhs = lhs
if not(lhs_scale == output_scale and lhs_zero_point == output_zero_point):
requantized_lhs = requantize(data=lhs,
input_scale=lhs_scale,
input_zero_point=lhs_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=out_dtype)

requantized_rhs = rhs
if not(rhs_scale == output_scale and rhs_zero_point == output_zero_point):
requantized_rhs = requantize(data=rhs,
input_scale=rhs_scale,
input_zero_point=rhs_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=out_dtype)

return relay.add(requantized_lhs, requantized_rhs)
56 changes: 56 additions & 0 deletions tests/python/relay/test_qnn_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
import topi.testing

def test_qnn_add():
data_dtype = 'int32'
axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.add(lhs=x, rhs=y,
lhs_scale=x_scale,
lhs_zero_point=0,
rhs_scale=y_scale,
rhs_zero_point=0,
output_scale=y_scale,
output_zero_point=1,
out_dtype=data_dtype)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.transform.Legalize()(mod)
func = mod["main"]

golden_output = np.add(x_data, y_data)
golden_output = np.add(2, golden_output)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

if __name__ == '__main__':
test_qnn_add()

0 comments on commit 781590d

Please sign in to comment.