From 2327bb9f6b7cd156a573a3ea50da075d1e07923a Mon Sep 17 00:00:00 2001
From: Ina Dobreva <55383260+inadob@users.noreply.github.com>
Date: Fri, 10 Jan 2020 22:57:16 +0000
Subject: [PATCH] [Relay][Frontend][TFlite] Add parses support for SLICE
 (#4502)

* [Relay][Frontend][TFlite] Add parses support for SLICE

* TFlite 1.13: convertor gives nonsense output when size[i]==-1
* TF parser: SLICE need fixing for size[i]==-1 -> gives wrong output
  bcs of indices

* Set end[i] = input_tensor_shape[i] as suggested in PR review

* Add another test to cover size=-1 case
---
 python/tvm/relay/frontend/tflite.py          | 30 ++++++++++++++++++++
 tests/python/frontend/tflite/test_forward.py | 21 ++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 5737eae17873..cb6dbea4f41e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -103,6 +103,7 @@ def __init__(self, model, subgraph, exp_tab):
             'TANH':self.convert_tanh,
             'RELU':self.convert_relu,
             'SPLIT': self.convert_split,
+            'SLICE': self.convert_slice,
             'TRANSPOSE': self.convert_transpose,
             'CAST': self.convert_cast,
             'TILE': self.convert_tile,
@@ -1152,6 +1153,35 @@ def convert_split(self, op):
 
         return out
 
+    def convert_slice(self, op):
+        """Convert TFLite SLICE"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be == 3"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        begin = list(self.get_tensor_value(input_tensors[1]))
+        size = list(self.get_tensor_value(input_tensors[2]))
+        # strided_slice(Relay) needs the slice's end indices, not the size
+        end = size
+        input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
+        input_tensor_rank = len(input_tensor_shape)
+        for i in range(input_tensor_rank):
+            if size[i] == -1:
+                end[i] = input_tensor_shape[i]
+            else:
+                end[i] += begin[i]
+
+        out = _op.strided_slice(in_expr, begin, end)
+
+        return out
+
     def convert_transpose(self, op):
         """transpose implementation."""
         try:
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index e6805a9eaddc..1478b2538bba 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -224,6 +224,26 @@ def test_forward_split():
     _test_split((1, 3, 6, 5), -2, 3, 'float32')
     _test_split((1, 3, 5, 6), -1, 3, 'float32')
 
+#######################################################################
+# slice
+# -----
+
+def _test_slice(data, begin, size):
+    """ One iteration of SLICE """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        out = array_ops.slice(in_data, begin, size)
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_slice():
+    """ SLICE """
+    _test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2])
+    _test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3])
+    # tflite 1.13 outputs nonsense values if size[i] == -1
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
+        _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
+
 #######################################################################
 # transpose
 # ---------
@@ -1408,6 +1428,7 @@ def test_forward_mediapipe_hand_landmark():
     test_forward_reshape()
     test_all_resize()
     test_forward_squeeze()
+    test_forward_slice()
 
     # NN
     test_forward_convolution()