From cab0bc3bf6c57803c0d9bb6a58df8941dd9f10db Mon Sep 17 00:00:00 2001 From: gomida Date: Thu, 28 Mar 2019 20:48:46 +0900 Subject: [PATCH 1/5] Adding ADD operator to tflite frontend for compiling the MobileNetV2 --- python/tvm/relay/frontend/tflite.py | 22 +++++++++++++++++++- tests/python/frontend/tflite/test_forward.py | 22 ++++++++++++++++++-- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d45bb33859b2..18ee215aa92a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -45,7 +45,8 @@ def __init__(self, model, subgraph, exp_tab): 'SOFTMAX': self.convert_softmax, 'SQUEEZE': self.convert_squeeze, 'MAX_POOL_2D': self.convert_max_pool2d, - "CONCATENATION": self.convert_concatenation + 'CONCATENATION': self.convert_concatenation, + 'ADD': self.convert_add } def check_unsupported_ops(self): @@ -289,6 +290,25 @@ def convert_concatenation(self, op): out = self.convert_fused_activation_function(out, fused_activation_fn) return out + def convert_add(self, op): + """Convert TFLite add""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + rhs_tensor = input_tensors[1] + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + out = _op.add(lhs_expr, rhs_expr) + + return out + def convert_squeeze(self, op): """Convert TFLite squeeze""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3ccc895dc60e..0085379fe5cb 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -388,7 +388,7 @@ def test_forward_softmax(): # Mobilenet # --------- -def test_forward_mobilenet(): +def test_forward_mobilenet_v1(): '''test mobilenet v1 tflite model''' # MobilenetV1 temp = util.tempdir() @@ -405,6 +405,23 @@ def test_forward_mobilenet(): rtol=1e-5, atol=1e-5) temp.remove() +def test_forward_mobilenet_v2(): + '''test mobilenet v2 tflite model''' + # MobilenetV2 + temp = util.tempdir() + tflite_model_file = tf_testing.get_workload_official( + "http://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz", + "mobilenet_v2_1.0_224.tflite", temp) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + temp.remove() + ####################################################################### # Inception V3 # ------------ @@ -441,5 +458,6 @@ def test_forward_inception_v3_net(): test_forward_softmax() # End to End - test_forward_mobilenet() + test_forward_mobilenet_v1() + test_forward_mobilenet_v2() test_forward_inception_v3_net() From ce8a51a48cfb9505a77296548423654c77ee7753 Mon Sep 17 00:00:00 2001 From: gomida Date: Sun, 31 Mar 2019 11:20:13 +0900 Subject: [PATCH 2/5] * fix MobileNetV2 for updated get_workload_official * Change URL for MobileNetV2 --- tests/python/frontend/tflite/test_forward.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index f46867a886da..b562001c9d22 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -406,10 +406,9 @@ def test_forward_mobilenet_v1(): def test_forward_mobilenet_v2(): '''test mobilenet v2 tflite model''' # MobilenetV2 - temp = util.tempdir() tflite_model_file = tf_testing.get_workload_official( - "http://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz", - "mobilenet_v2_1.0_224.tflite", temp) + "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz", + "mobilenet_v2_1.0_224.tflite") with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') @@ -418,7 +417,6 @@ def test_forward_mobilenet_v2(): tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) - temp.remove() ####################################################################### # Inception V3 From 3866bc5df8686cab1fff3d5bc145cb09a20a39e4 Mon Sep 17 00:00:00 2001 From: gomida Date: Sun, 31 Mar 2019 13:42:22 +0900 Subject: [PATCH 3/5] Modification for when the ADD operator receives a constant without being fused --- python/tvm/relay/frontend/common.py | 3 +++ python/tvm/relay/frontend/tflite.py | 33 ++++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index d07f2af3e08b..abfb60e44ea7 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -258,6 +258,9 @@ def set_expr(self, name, expr): if name not in self.exprs: self.exprs[name] = expr + def has_expr(self, name): + return True if name in self.exprs else False + def set_padding(self, paddings): self.paddings = paddings self.in_padding = True diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9287e1d7628e..9d616c9e4025 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -305,11 +305,35 @@ def convert_add(self, op): assert len(input_tensors) == 2, "input tensors length should be 2" lhs_tensor = input_tensors[0] - rhs_tensor = input_tensors[1] lhs_expr = self.get_expr(lhs_tensor.tensor_idx) - rhs_expr = self.get_expr(rhs_tensor.tensor_idx) - out = _op.add(lhs_expr, rhs_expr) + rhs_tensor = input_tensors[1] + if self.has_expr(rhs_tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses ADD operators + # with constants - it means both will be tensors. + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + else: + # However, in some corner cases, the ADD operator is not fused, + # we can receive as constant. + rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) + rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), + dtype=rhs_type_str) + + # In this case, we have to be careful about formatting. + input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy()) + if input_shape_length in (1, 2): + pass + elif input_shape_length == 3: + # N H*W C to N C H*W + rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1)) + elif input_shape_length == 4: + # N H W C to N C H W + rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2)) + else: + msg = 'Input shape length {} for operator ADD is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) + + out = _op.add(lhs_expr, rhs_expr) return out def convert_squeeze(self, op): @@ -574,6 +598,9 @@ def convert_pool2d(self, op, pool_type): def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def has_expr(self, input_tensor_idx): + return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def build_str_map(obj): """Build string map of TFLite enum int value From 48bd15547a20221ee1e076454e1d247dc20df2a3 Mon Sep 17 00:00:00 2001 From: gomida Date: Sun, 31 Mar 2019 21:34:45 +0900 Subject: [PATCH 4/5] Adding unit test for ADD operator --- tests/python/frontend/tflite/test_forward.py | 62 ++++++++++++++++++-- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b562001c9d22..fa66454897ec 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -11,6 +11,7 @@ from tvm.contrib import util import tensorflow as tf from tensorflow.python.framework import constant_op +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables @@ -99,7 +100,7 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, - output_tensors, output_need_transpose_nchw=False, + output_tensors, output_need_transpose=False, init_global_variables=False): """Generic function to generate and compare TFLite and TVM output""" tflite_in_data = convert_to_list(tflite_in_data) @@ -126,9 +127,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device) for i in range(len(tflite_output)): - if output_need_transpose_nchw: + if output_need_transpose: + dim = len(tvm_output[i].shape) + if dim == 3: + # N C H*W to N H*W C + axes = (0, 2, 1) + elif dim == 4: + # N C H W to N H W C + axes = (0, 2, 3, 1) + else: + raise NotImplementedError("Not support input shape {} of transpose : ". + format(str(dim))) tvm.testing.assert_allclose(tflite_output[i], - np.transpose(tvm_output[i], axes=(0, 2, 3, 1)), + np.transpose(tvm_output[i], axes=axes), atol=1e-5, rtol=1e-5) else: tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], @@ -152,7 +163,7 @@ def _test_pooling_iteration(input_shape, **kwargs): out = nn_ops.pool(in_data, **kwargs) compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out], - output_need_transpose_nchw=True) + output_need_transpose=True) def _test_pooling(input_shape, **kwargs): @@ -236,7 +247,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, # TFLite output is NHWC, TVM is NCHW, we need transpose compare_tflite_with_tvm(tflite_data_array, tvm_data_array, 'Placeholder:0', [in_data], [out], - output_need_transpose_nchw=True) + output_need_transpose=True) def test_forward_convolution(): @@ -330,6 +341,44 @@ def test_forward_concatenation(): np.arange(6).reshape((2, 1, 1, 3))], 1) +####################################################################### +# Add +# ------- + +def _test_add(data): + """ One iteration of add """ + + assert len(data) == 2 + need_transpose = False + if len(data[0].shape) == 1 or len(data[0].shape) == 2: + tvm_data = data + elif len(data[0].shape) == 3: + need_transpose = True + tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] + elif len(data[0].shape) == 4: + need_transpose = True + tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] + else: + raise NotImplementedError("Not support input shape {} of add : ". + format(str(len(data.shape)))) + + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), + array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] + out = math_ops.add(in_data[0], in_data[1]) + compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], + in_data, [out], need_transpose) + +def test_forward_add(): + """ Add """ + _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) + _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) + _test_add([np.arange(3.0, dtype=np.float32).reshape((1, 3)), + np.arange(3.0, dtype=np.float32).reshape((1, 3))]) + + ####################################################################### # Squeeze # ------- @@ -451,6 +500,9 @@ def test_forward_inception_v3_net(): test_forward_pooling() test_forward_softmax() + # Math + test_forward_add() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2() From 18025efc42f4e5a3624537adca98e14583e07d7d Mon Sep 17 00:00:00 2001 From: gomida Date: Mon, 1 Apr 2019 19:27:44 +0900 Subject: [PATCH 5/5] Adding ADD unit test for constant input --- tests/python/frontend/tflite/test_forward.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index fa66454897ec..9abbfad8e429 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -11,6 +11,7 @@ from tvm.contrib import util import tensorflow as tf from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops @@ -343,7 +344,7 @@ def test_forward_concatenation(): ####################################################################### # Add -# ------- +# --- def _test_add(data): """ One iteration of add """ @@ -362,6 +363,7 @@ def _test_add(data): raise NotImplementedError("Not support input shape {} of add : ". format(str(len(data.shape)))) + # Test with two tensors with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] @@ -369,6 +371,14 @@ def _test_add(data): compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], in_data, [out], need_transpose) + # Test with tensor and constant + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] + out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) + compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'], + in_data, [out], need_transpose) + + def test_forward_add(): """ Add """ _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),