diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 0ffd07e77d9e..acfa9148a756 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -60,6 +60,8 @@ def __init__(self, model, subgraph, exp_tab): 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'RESHAPE': self.convert_reshape, + 'RESIZE_BILINEAR': self.convert_resize_bilinear, + 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'SOFTMAX': self.convert_softmax, 'SQUEEZE': self.convert_squeeze, 'MAX_POOL_2D': self.convert_max_pool2d, @@ -225,6 +227,58 @@ def convert_reshape(self, op): return out + def _convert_resize(self, method, op): + """Generic method to Convert TFLite RESIZE operators""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.Operator import Operator + from tflite.ResizeBilinearOptions import ResizeBilinearOptions + # ResizeNearestNeighborOptions was added in tflite v1.13 + tflite_ver = 1120 + if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions): + from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions + tflite_ver = 1130 + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + # images, 4-D Tensor with shape NHWC. + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # size - 1-D int32 Tensor of 2 elements: new_height, new_width + target_size = tuple(self.get_tensor_value(input_tensors[1])) + + # Options - align_corners (bool) + resize_options = None + align_corners = False + if method == "BILINEAR": + assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions + resize_options = ResizeBilinearOptions() + elif tflite_ver >= 1130: + assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions + resize_options = ResizeNearestNeighborOptions() + + if resize_options is not None: + op_options = op.BuiltinOptions() + resize_options.Init(op_options.Bytes, op_options.Pos) + align_corners = resize_options.AlignCorners() + + # Use layout NHWC + out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners) + return out + + def convert_resize_bilinear(self, op): + """Convert TFLite RESIZE_BILINEAR""" + return self._convert_resize("BILINEAR", op) + + def convert_resize_nearest_neighbor(self, op): + """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" + return self._convert_resize("NEAREST_NEIGHBOR", op) + def convert_logistic(self, op): """Convert TFLite LOGISTIC""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 795a08966e1d..0a011a6e709b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -289,6 +289,37 @@ def test_forward_reshape(): _test_reshape(np.arange(6), [-1]) +####################################################################### +# Resize +# ------ + +def _test_resize(tf_resize_op, data, align_corners): + """ One iteration of Resize """ + + assert len(data) == 2 + + # Test with tensor and constant + with tf.Graph().as_default(): + images_tensor = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in') + size = ops.convert_to_tensor(data[1], dtype=data[1].dtype) + out_tensor = tf_resize_op(images=images_tensor, size=size, align_corners=align_corners) + compare_tflite_with_tvm([data[0]], ['in:0'], [images_tensor], [out_tensor]) + + +def test_all_resize(): + """ Resize """ + data = [np.random.rand(1, 16, 16, 3).astype("float32"), np.array([8, 8], dtype=np.int32)] + ### RESIZE_BILINEAR + _test_resize(tf.image.resize_bilinear, data, align_corners=False) + _test_resize(tf.image.resize_bilinear, data, align_corners=True) + ### RESIZE_NEAREST_NEIGHBOR (was added in v1.13) + # According to topi resize.h + # Align corners not supported for nearest neighbour + from tflite.BuiltinOperator import BuiltinOperator + if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()): + _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False) + + ####################################################################### # Concatenation # ------------- @@ -651,6 +682,7 @@ def test_forward_ssd_mobilenet_v1(): test_forward_concatenation() test_forward_pad() test_forward_reshape() + test_all_resize() test_forward_squeeze() # NN diff --git a/topi/include/topi/image/resize.h b/topi/include/topi/image/resize.h index 287ff9406618..0119aed3aff0 100644 --- a/topi/include/topi/image/resize.h +++ b/topi/include/topi/image/resize.h @@ -384,7 +384,7 @@ inline Tensor resize_bilinear(const Tensor& input, * \param shape Output shape to resize to. * \param layout input layout * \param align_corners To preserve centers of 4 corner pixels -* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR) +* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR) * \param name Name of the operation * \param tag The tag to mark the operation * diff --git a/topi/include/topi/nn/upsampling.h b/topi/include/topi/nn/upsampling.h index eb75d5dd8d92..d7551afbe04f 100644 --- a/topi/include/topi/nn/upsampling.h +++ b/topi/include/topi/nn/upsampling.h @@ -43,7 +43,7 @@ using namespace topi::image; * \param input The input tensor. * \param shape Output shape to upsample. * \param layout input layout -* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR) +* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR) * \param name Name of the operation * \param tag The tag to mark the operation *