Skip to content

Commit

Permalink
Add Pack operator to TFLite (apache#3521)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristan-arm authored and wweic committed Jul 11, 2019
1 parent d9aedcc commit 1a268ed
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, model, subgraph, exp_tab):
'REDUCE_PROD': self._convert_reduce_prod,
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
}

Expand Down Expand Up @@ -789,6 +790,33 @@ def convert_pad(self, op):
out = _op.nn.pad(in_expr, paddings)
return out

def convert_pack(self, op):
"""Convert TFLite pack"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.PackOptions import PackOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors should be 1"

assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions
op_options = op.BuiltinOptions()
pack_options = PackOptions()
pack_options.Init(op_options.Bytes, op_options.Pos)
pack_axis = pack_options.Axis()

in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs]
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out

def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

Expand Down
36 changes: 36 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,41 @@ def test_forward_pad():
np.array([[1, 1], [2, 2]], dtype=np.int32)])


#######################################################################
# Pack
# -------------

def _test_pack(data, axis):
""" One iteration of pack """

assert len(data) >= 1

with tf.Graph().as_default():
in_data = [
array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
for idx, tensor in enumerate(data)]
out = array_ops.pack(in_data, axis=axis)
name = ["in_{}:0".format(idx) for idx in range(len(data))]

compare_tflite_with_tvm(data, name, in_data, [out])


def test_forward_pack():
""" Pack """
_test_pack(
[np.arange(6).reshape((1, 2, 1, 3)),
np.arange(6).reshape((1, 2, 1, 3))], 1)

_test_pack(
[np.arange(6).reshape((3, 2)),
np.arange(6).reshape((3, 2))], 1)

_test_pack(
[np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3))], 1)


#######################################################################
# Logistic
# --------
Expand Down Expand Up @@ -750,6 +785,7 @@ def test_forward_ssd_mobilenet_v1():
# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
Expand Down

0 comments on commit 1a268ed

Please sign in to comment.