Skip to content

Commit

Permalink
Add RESIZE operators to realy TFLite frontend (apache#3370)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored and wweic committed Jun 27, 2019
1 parent d791086 commit e677d9f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 2 deletions.
54 changes: 54 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self, model, subgraph, exp_tab):
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'SOFTMAX': self.convert_softmax,
'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d,
Expand Down Expand Up @@ -225,6 +227,58 @@ def convert_reshape(self, op):

return out

def _convert_resize(self, method, op):
"""Generic method to Convert TFLite RESIZE operators"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.Operator import Operator
from tflite.ResizeBilinearOptions import ResizeBilinearOptions
# ResizeNearestNeighborOptions was added in tflite v1.13
tflite_ver = 1120
if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions):
from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions
tflite_ver = 1130
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

# images, 4-D Tensor with shape NHWC.
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

# size - 1-D int32 Tensor of 2 elements: new_height, new_width
target_size = tuple(self.get_tensor_value(input_tensors[1]))

# Options - align_corners (bool)
resize_options = None
align_corners = False
if method == "BILINEAR":
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
resize_options = ResizeBilinearOptions()
elif tflite_ver >= 1130:
assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions
resize_options = ResizeNearestNeighborOptions()

if resize_options is not None:
op_options = op.BuiltinOptions()
resize_options.Init(op_options.Bytes, op_options.Pos)
align_corners = resize_options.AlignCorners()

# Use layout NHWC
out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners)
return out

def convert_resize_bilinear(self, op):
"""Convert TFLite RESIZE_BILINEAR"""
return self._convert_resize("BILINEAR", op)

def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("NEAREST_NEIGHBOR", op)

def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
Expand Down
32 changes: 32 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,37 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1])


#######################################################################
# Resize
# ------

def _test_resize(tf_resize_op, data, align_corners):
""" One iteration of Resize """

assert len(data) == 2

# Test with tensor and constant
with tf.Graph().as_default():
images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')
size = ops.convert_to_tensor(data[1], dtype=data[1].dtype)
out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners)
compare_tflite_with_tvm([data[0]], ['in:0'], [images_tensor], [out_tensor])


def test_all_resize():
""" Resize """
data = [np.random.rand(1, 16, 16, 3).astype("float32"), np.array([8, 8], dtype=np.int32)]
### RESIZE_BILINEAR
_test_resize(tf.image.resize_bilinear, data, align_corners=False)
_test_resize(tf.image.resize_bilinear, data, align_corners=True)
### RESIZE_NEAREST_NEIGHBOR (was added in v1.13)
# According to topi resize.h
# Align corners not supported for nearest neighbour
from tflite.BuiltinOperator import BuiltinOperator
if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
_test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)


#######################################################################
# Concatenation
# -------------
Expand Down Expand Up @@ -651,6 +682,7 @@ def test_forward_ssd_mobilenet_v1():
test_forward_concatenation()
test_forward_pad()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()

# NN
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/image/resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ inline Tensor resize_bilinear(const Tensor& input,
* \param shape Output shape to resize to.
* \param layout input layout
* \param align_corners To preserve centers of 4 corner pixels
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param name Name of the operation
* \param tag The tag to mark the operation
*
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/nn/upsampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ using namespace topi::image;
* \param input The input tensor.
* \param shape Output shape to upsample.
* \param layout input layout
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param name Name of the operation
* \param tag The tag to mark the operation
*
Expand Down

0 comments on commit e677d9f

Please sign in to comment.