Skip to content

Commit

Permalink
[Relay][Frontend][TFlite] Add parses support for UNPACK tflite operat…
Browse files Browse the repository at this point in the history
…or (apache#4447)

* use SPLIT & SQUEEZE = UNPACK as implemented in tensorflow parser
  Relay doesn't support UNPACK
* tflite 1.13: UNPACK doesn't work as exepcted -> copies the values from
  1st unpacked tensor to the other unpacks
* tflite 1.13: doesn't accept negative axis
  • Loading branch information
inadob authored and Xingyu Zhou committed Dec 13, 2019
1 parent 5c0b608 commit 4cb9b07
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
45 changes: 45 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, model, subgraph, exp_tab):
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'PACK': self.convert_pack,
'UNPACK': self.convert_unpack,
'LOGISTIC': self.convert_logistic,
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
Expand Down Expand Up @@ -1239,6 +1240,50 @@ def convert_pack(self, op):
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out

def convert_unpack(self, op):
"""Convert TFLite unpack"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.UnpackOptions import UnpackOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
assert op.BuiltinOptionsType() == BuiltinOptions.UnpackOptions
op_options = op.BuiltinOptions()
unpack_options = UnpackOptions()
unpack_options.Init(op_options.Bytes, op_options.Pos)
num_unpacks = unpack_options.Num()
unpack_axis = unpack_options.Axis()

# Relay doesn't support 'unpack' operator so we use 'split' & 'squeeze' instead.
# We have to do 'squeeze' along the split axis but Relay expects
# squeeze_axis to be either None or List.
squeeze_axis = None if unpack_axis == 0 else [unpack_axis]

# Relay doesn't like TupleWrapper of 1 element so we isolate the case of unpacking
# a tensor by an axis with len(axis) == 1. For reference see convert_split().
# Such unpacking will result in the same tensor so we omit 'split' and only squeeze
# along the axis of dim == 1.
if num_unpacks == 1:
squeezed = _op.squeeze(in_expr, axis=squeeze_axis)
if isinstance(squeezed, _expr.TupleWrapper):
squeezed = squeezed[0]
else:
splitted = _op.split(in_expr,
indices_or_sections=num_unpacks,
axis=unpack_axis)
squeezed = _expr.TupleWrapper(
_expr.Tuple([_op.squeeze(split_item, axis=squeeze_axis) \
for split_item in splitted]), len(splitted))

return squeezed

def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""
try:
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,27 @@ def test_forward_pack():
np.arange(6).reshape((2, 1, 1, 3))], 1)


#######################################################################
# Unpack
# ------

def _test_unpack(data, axis, num_unpacks):
""" One iteration of UNPACK """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = gen_array_ops.unpack(in_data, num=num_unpacks, axis=axis, name='unpack')
out_names = ['out_' + str(n) + ':0' for n in range(num_unpacks)]
compare_tflite_with_tvm([data], 'Placeholder:0', [in_data], out, out_names=out_names)

def test_forward_unpack():
""" UNPACK """
_test_unpack(np.array(np.random.uniform(0, 5, (3, 1)), dtype=np.int32), axis=1, num_unpacks=1)
_test_unpack(np.array(np.random.uniform(0, 5, (3, 4)), dtype=np.float32), axis=0, num_unpacks=3)
# tflite 1.13 doesn't accept negative axis
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_unpack(np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3)
_test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)

#######################################################################
# Logistic
# --------
Expand Down Expand Up @@ -1280,6 +1301,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_unpack()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
Expand Down

0 comments on commit 4cb9b07

Please sign in to comment.