Skip to content

Commit

Permalink
[TFLITE]DepthToSpace and SpaceToDepth support (apache#5041)
Browse files Browse the repository at this point in the history
* [TFLITE]DepthToSpace and SpaceToDepth op parser support

* DepthToSpace and SpaceToDepth testcases

* Review comments fixed
  • Loading branch information
siju-samuel authored and zhiics committed Apr 17, 2020
1 parent 53a21a3 commit b353710
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, model, subgraph, exp_tab):
'CONCATENATION': self.convert_concatenation,
'CONV_2D': self.convert_conv2d,
'COS': self.convert_cos,
'DEPTH_TO_SPACE': self.convert_depth_to_space,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(self, model, subgraph, exp_tab):
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPACE_TO_DEPTH': self.convert_space_to_depth,
'SPLIT': self.convert_split,
'SQRT': self.convert_sqrt,
'SQUARE': self.convert_square,
Expand Down Expand Up @@ -1896,6 +1898,56 @@ def convert_space_to_batch_nd(self, op):

return reshaped_permuted_reshaped_padded

def convert_depth_to_space(self, op):
"""Convert TFLite DEPTH_TO_SPACE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.DepthToSpaceOptions import DepthToSpaceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

assert op.BuiltinOptionsType() == BuiltinOptions.DepthToSpaceOptions
op_options = op.BuiltinOptions()
depth_to_space_options = DepthToSpaceOptions()
depth_to_space_options.Init(op_options.Bytes, op_options.Pos)
block_size = depth_to_space_options.BlockSize()
out = _op.nn.depth_to_space(in_expr, block_size, layout='NHWC')

return out

def convert_space_to_depth(self, op):
"""Convert TFLite SPACE_TO_DEPTH"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.SpaceToDepthOptions import SpaceToDepthOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

assert op.BuiltinOptionsType() == BuiltinOptions.SpaceToDepthOptions
op_options = op.BuiltinOptions()
space_to_depth_options = SpaceToDepthOptions()
space_to_depth_options.Init(op_options.Bytes, op_options.Pos)
block_size = space_to_depth_options.BlockSize()
out = _op.nn.space_to_depth(in_expr, block_size, layout='NHWC')

return out

def convert_prelu(self, op):
"""Convert TFLite PReLU"""
try:
Expand Down
36 changes: 36 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,40 @@ def test_forward_prelu():
_test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((3,), 0.2, dtype="float32"))
_test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32"))

#######################################################################
# DepthToSpace
# ------------

def _test_depthtospace(data, block_size):
""" One iteration of depth_to_space operation with given data and block size """

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.depth_to_space(in_data, block_size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_depthtospace():
# DEPTH_TO_SPACE comes with TFLite >= 1.15.0 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_depthtospace(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_depthtospace(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)

#######################################################################
# SpaceToDepth
# ------------

def _test_spacetodepth(data, block_size):
""" One iteration of space_to_depth operation with given data and block size """

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.space_to_depth(in_data, block_size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_spacetodepth():
_test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)

#######################################################################
# Fully Connected
# ---------------
Expand Down Expand Up @@ -1741,6 +1775,8 @@ def test_forward_mediapipe_hand_landmark():
test_all_resize()
test_forward_squeeze()
test_forward_slice()
test_forward_depthtospace()
test_forward_spacetodepth()

# NN
test_forward_convolution()
Expand Down

0 comments on commit b353710

Please sign in to comment.