Skip to content

Commit

Permalink
[QNN] Concatenate operator (apache#3730)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and wweic committed Sep 6, 2019
1 parent 19cfb1f commit 5ca146a
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 0 deletions.
73 changes: 73 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""QNN dialect operators."""

from __future__ import absolute_import as _abs
from tvm import relay
from . import _make

def requantize(data,
Expand Down Expand Up @@ -72,3 +73,75 @@ def requantize(data,
output_zero_point,
rounding,
out_dtype)

def concatenate(data,
input_scales,
input_zero_points,
output_scale,
output_zero_point,
axis):
"""Concatenate the quantized input tensors along the given axis.
Parameters
----------
data : Union(List[relay.Expr], Tuple[relay.Expr])
The list of quantized tensors.
input_scales : List[float32]
The list of scales of input quantized tensors.
input_zero_points : List[int32]
The list of zero points of input quantized tensors.
output_scale : float32
The scale of the output quantized tensor.
output_zero_point : int32
The zero point of the output quantized tensor.
axis : int
The axis along which the tensors are concatenated.
Returns
-------
result: relay.Expr
The concatenated quantized tensor.
"""

data = list(data)
requantized_exprs = list(data)

# Find the dtype of the input expr. This is required for the requantize op. Since, this is
# concatenate op, the dtype of the input is same as dtype of the output.
data0 = relay.transform.infer_type(data[0])
in_dtype = data0.checked_type.dtype

# First check if all the input qnn params match. If yes, we can call concatenate first, followed
# by a requantize.
if all(scale == input_scales[0] for scale in input_scales)\
and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
out = relay.concatenate(tuple(data), axis)
input_scale = input_scales[0]
input_zero_point = input_zero_points[0]
if input_scale != output_scale or input_zero_point != output_zero_point:
out = requantize(data=out,
input_scale=input_scales[0],
input_zero_point=input_zero_points[0],
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return out

# If the output qnn params do not match the input qnn params, we can call requantize on the
# input expr first, followed by a concatenate on the requantized input exprs.
for idx, quantized_expr in enumerate(data):
input_scale = input_scales[idx]
input_zero_point = input_zero_points[idx]
if input_scale != output_scale or input_zero_point != output_zero_point:
requantized_exprs[idx] = requantize(data=quantized_expr,
input_scale=input_scale,
input_zero_point=input_zero_point,
output_scale=output_scale,
output_zero_point=output_zero_point,
out_dtype=in_dtype)
return relay.concatenate(tuple(requantized_exprs), axis)
145 changes: 145 additions & 0 deletions tests/python/relay/test_qnn_concatenate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
import topi.testing

def test_same_io_qnn_params():
data_dtype = 'int32'
axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale],
input_zero_points=[0, 0],
output_scale=y_scale,
output_zero_point=0,
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 0
mod = relay.Module.from_expr(func)
mod = relay.transform.Legalize()(mod)
func = mod["main"]

golden_output = np.concatenate((x_data, y_data), axis=axis)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

def test_different_io_qnn_params():
data_dtype = 'int32'
axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale],
input_zero_points=[3, 4],
output_scale=y_scale,
output_zero_point=1,
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 2
mod = relay.Module.from_expr(func)
mod = relay.transform.Legalize()(mod)
func = mod["main"]

golden_output = np.concatenate((x_data - 2, y_data - 3), axis=axis)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

def test_few_same_io_qnn_params():
data_dtype = 'int32'
axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale],
input_zero_points=[0, 1],
output_scale=y_scale,
output_zero_point=1,
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 1
mod = relay.Module.from_expr(func)
mod = relay.transform.Legalize()(mod)
func = mod["main"]

golden_output = np.concatenate((x_data + 1, y_data), axis=axis)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

def test_same_i_qnn_params():
data_dtype = 'int32'
axis = 0
x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

x = relay.var("x", shape=(1, 64), dtype=data_dtype)
y = relay.var("y", shape=(1, 64), dtype=data_dtype)
z = relay.qnn.op.concatenate((x, y),
input_scales=[x_scale, y_scale],
input_zero_points=[0, 0],
output_scale=y_scale,
output_zero_point=1,
axis=axis)

func = relay.Function([x, y], z)
assert func.astext().count('requantize') == 1
mod = relay.Module.from_expr(func)
mod = relay.transform.Legalize()(mod)
func = mod["main"]

golden_output = np.concatenate((x_data + 1, y_data + 1), axis=axis)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


if __name__ == '__main__':
test_same_io_qnn_params()
test_different_io_qnn_params()
test_few_same_io_qnn_params()
test_same_i_qnn_params()

0 comments on commit 5ca146a

Please sign in to comment.