[Frontend][TFLite] support for FILL and SPLIT_V operators (apache#5330)
* tflite split_v ops

* tflite fill and split_v ops

* remove unnecessary operator check
maheshambule authored and Trevor Morris committed Jun 18, 2020
1 parent 220558c commit df198d5
Showing 2 changed files with 96 additions and 7 deletions.
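
A quick orientation, not part of the commit itself: the two new converters map roughly onto the following Relay expressions. relay.full and relay.split are the underlying Relay ops; the dtypes, shapes, and values below are made up purely for illustration.

from tvm import relay

# FILL(dims=[2, 3], value=7.0) lowers to a full op over a static shape.
fill_expr = relay.full(relay.const(7.0, "float32"), shape=(2, 3), dtype="float32")

# SPLIT_V(size_splits=[1, 2, 3], axis=0) lowers to a split at the cumulative
# indices (1, 3) rather than at the section sizes themselves.
data = relay.var("data", shape=(6,), dtype="float32")
split_expr = relay.split(data, indices_or_sections=(1, 3), axis=0)
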
46 changes: 46 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -78,6 +78,7 @@ def __init__(self, model, subgraph, exp_tab):
            'ELU': self.convert_elu,
            'EQUAL': self.convert_equal,
            'EXP': self.convert_exp,
            'FILL': self.convert_fill,
            'FLOOR_DIV': self.convert_floor_div,
            'FLOOR_MOD': self.convert_floor_mod,
            'FLOOR': self.convert_floor,
@@ -123,6 +124,7 @@ def __init__(self, model, subgraph, exp_tab):
            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
            'SPACE_TO_DEPTH': self.convert_space_to_depth,
            'SPLIT': self.convert_split,
            'SPLIT_V': self.convert_split_v,
            'SQRT': self.convert_sqrt,
            'SQUARE': self.convert_square,
            'SQUARED_DIFFERENCE': self.convert_squared_difference,
@@ -1212,6 +1214,21 @@ def convert_zeros_like(self, op):

        return out

    def convert_fill(self, op):
        """Convert TFLite FILL"""
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

        if self.has_expr(input_tensors[0].tensor_idx):
            raise tvm.error.OpNotImplemented("For dims parameter of Fill operator,"
                                             " only constant values are supported.")

        in_dims = list(self.get_tensor_value(input_tensors[0]))
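        # The fill value stays a Relay expression, while _op.full broadcasts it
        # to the static shape captured in in_dims.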
        in_value_expr = self.get_expr(input_tensors[1].tensor_idx)
        out = _op.full(in_value_expr, in_dims)

        return out

    def _convert_reduce(self, relay_op, op):
        """Generic method to Convert TFLite MEAN operators"""
        try:
@@ -1617,6 +1634,35 @@ def convert_split(self, op):

        return out

    def convert_split_v(self, op):
        """SPLIT_V implementation."""
        input_tensors = self.get_input_tensors(op)

        assert len(input_tensors) == 3, "input tensors length should be 3"

        input_tensor = input_tensors[0]
        input_tensor_idx = input_tensor.tensor_idx
        in_expr = self.get_expr(input_tensor_idx)

        if self.has_expr(input_tensors[1].tensor_idx):
            raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, "
                                             "only constant values are supported.")
        size_splits = list(self.get_tensor_value(input_tensors[1]))
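        # TFLite provides section sizes here, but Relay's split op expects split
        # indices, so convert the sizes to cumulative indices and drop the last.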
        size_splits = tuple(np.cumsum(size_splits)[:-1])

        axis_tensor = input_tensors[2]
        split_axis = self.get_tensor_value(axis_tensor)

        out = _op.split(in_expr, size_splits, axis=int(split_axis))
        # Relay does not handle a TupleWrapper of a single element well. This
        # only shows up with TF 1.13 when num_splits == 1; in TF 1.14 the case
        # does not appear because it is converted to a reshape automatically.
        if isinstance(out, _expr.TupleWrapper) and out.size == 1:
            out = out[0]

        return out

    def convert_slice(self, op):
        """Convert TFLite SLICE"""
        input_tensors = self.get_input_tensors(op)
57 changes: 50 additions & 7 deletions tests/python/frontend/tflite/test_forward.py
@@ -216,15 +216,19 @@ def with_fused_activation_function(input_tensor, fn_name):
        return math_ops.tanh(input_tensor)
    raise AssertionError("Unknown fused_activation_function {}".format(fn_name))

-def _test_split(in_shape, axis, num_Splits, dtype):
-    '''internal split tester taking as parameters in_shape, number of tensors to split into
-    and dtype (data type)'''
+def _test_split(in_shape, axis, num_splits, dtype):
+    """internal split tester taking as parameters in_shape, number of tensors to split into
+    and dtype (data type)"""
     np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype)
-        out = array_ops.split(in_data, num_Splits, axis=axis)
-        out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)]
-        compare_tflite_with_tvm([np_data], ['Placeholder:0'], [in_data], out,
+        in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data")
+        out = array_ops.split(in_data, num_splits, axis=axis)
+        num_splits = len(num_splits) if isinstance(num_splits, list) \
+            else num_splits
+        out_names = ['out_' + str(n) + ':0' for n in range(num_splits)]
+        compare_tflite_with_tvm([np_data], ['in_data'], [in_data], out,
                                 out_names=out_names)

def test_forward_split():
@@ -252,6 +256,9 @@ def test_forward_split():
    _test_split((1, 6, 3, 5), -3, 3, 'float32')
    _test_split((1, 3, 6, 5), -2, 3, 'float32')
    _test_split((1, 3, 5, 6), -1, 3, 'float32')
    # size_splits split
    _test_split((6,), 0, [1, 2, 3], 'float32')
    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')

#######################################################################
# slice
@@ -1210,6 +1217,39 @@ def test_forward_zeros_like():
""" ZEROS LIKE """
_test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6)))


#######################################################################
# Fill
# ----

def _test_fill(dims, value_data, value_dtype):
    """ Use the fill op to create a tensor of value_data with constant dims."""

    value_data = np.array(value_data, dtype=value_dtype)
    # TF 1.13 TFLite convert method does not accept empty shapes
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        with tf.Graph().as_default():
            value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[])
            out = tf.fill(dims, value)
            compare_tflite_with_tvm([value_data], ["value"], [value], [out])
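    # With both dims and value constant, fill is folded into a static tensor at
    # conversion time; the add with a placeholder keeps a real model input.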

    with tf.Graph().as_default():
        input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims)
        # Fill op gets converted to static tensor during conversion
        out = tf.fill(dims, value_data)
        out1 = tf.add(out, input1)
        input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype)
        compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1])


def test_forward_fill():
    """ Test FILL op """

    _test_fill((1, 2, 2, 4), 5, "int32")
    _test_fill((1, 2, 2, 4), 5, "float32")
    _test_fill((5, ), 5, "int32")


#######################################################################
# Reduce
# ------
@@ -1961,6 +2001,9 @@ def test_forward_mediapipe_hand_landmark():
    # Zeros Like
    test_forward_zeros_like()

    # Fill
    test_forward_fill()

    # Reduce
    test_all_reduce()
