diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7b4394e7facb..0ffd07e77d9e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -298,6 +298,12 @@ def _convert_elemwise(self, relay_op, op): """Generic method to Convert TFLite elemwise""" try: from tflite.Operator import Operator + from tflite.AddOptions import AddOptions + from tflite.SubOptions import SubOptions + from tflite.MulOptions import MulOptions + from tflite.DivOptions import DivOptions + from tflite.BuiltinOptions import BuiltinOptions + from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -320,6 +326,26 @@ def _convert_elemwise(self, relay_op, op): rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) out = relay_op(lhs_expr, rhs_expr) + + # Options (fused_activation_function) + options = None + if op.BuiltinOptionsType() == BuiltinOptions.AddOptions: + options = AddOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions: + options = SubOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions: + options = MulOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions: + options = DivOptions() + + if options is not None: + op_options = op.BuiltinOptions() + options.Init(op_options.Bytes, op_options.Pos) + fused_activation_fn = options.FusedActivationFunction() + # if we have activation fn + if fused_activation_fn != ActivationFunctionType.NONE: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_add(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3b76fad1c073..795a08966e1d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -21,6 +21,7 @@ This article is a test script to test TFLite operator with Relay. """ from __future__ import print_function +from functools import partial import numpy as np import tvm from tvm import relay @@ -146,6 +147,20 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) +def with_fused_activation_function(input_tensor, fn_name): + if fn_name is None or fn_name == "NONE": + return input_tensor + if fn_name == "RELU": + return nn_ops.relu(input_tensor) + if fn_name == "RELU6": + return nn_ops.relu6(input_tensor) + if fn_name == "RELU_N1_TO_1": + return math_ops.maximum(-1, math_ops.minimum(input_tensor, 1)) + if fn_name == "TANH": + return math_ops.tanh(input_tensor) + raise AssertionError("Unknown fused_activation_function {}".format(fn_name)) + + ####################################################################### # Pooling # ------- @@ -313,7 +328,7 @@ def test_forward_concatenation(): # Element-wise # --- -def _test_elemwise(math_op, data): +def _test_elemwise(math_op, data, fused_activation_function=None): """ One iteration of add """ assert len(data) == 2 @@ -323,12 +338,14 @@ def _test_elemwise(math_op, data): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] out = math_op(in_data[0], in_data[1]) + out = with_fused_activation_function(out, fused_activation_function) compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) # Test with tensor and constant with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) + out = with_fused_activation_function(out, fused_activation_function) compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out]) @@ -336,31 +353,31 @@ def _test_elemwise(math_op, data): # Add # --- -def _test_add(data): +def _test_add(data, fused_activation_function=None): """ One iteration of add """ - return _test_elemwise(math_ops.add, data) + return _test_elemwise(math_ops.add, data, fused_activation_function) ####################################################################### # Subtract # -------- -def _test_sub(data): +def _test_sub(data, fused_activation_function=None): """ One iteration of subtract """ - return _test_elemwise(math_ops.subtract, data) + return _test_elemwise(math_ops.subtract, data, fused_activation_function) ####################################################################### # Mul # --- -def _test_mul(data): +def _test_mul(data, fused_activation_function=None): """ One iteration of mul """ - return _test_elemwise(math_ops.multiply, data) + return _test_elemwise(math_ops.multiply, data, fused_activation_function) ####################################################################### # Divide # ------ -def _test_div(data): +def _test_div(data, fused_activation_function=None): """ One iteration of divide """ - return _test_elemwise(math_ops.divide, data) + return _test_elemwise(math_ops.divide, data, fused_activation_function) ####################################################################### # Power # ----- @@ -386,17 +403,25 @@ def _test_minimum(data): def _test_forward_elemwise(testop): """ Elewise""" testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), - np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3))]) testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), - np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))]) testop([np.arange(3.0, dtype=np.float32).reshape((1, 3)), - np.arange(3.0, dtype=np.float32).reshape((1, 3))]) + np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3))]) def test_all_elemwise(): _test_forward_elemwise(_test_add) + _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU")) + _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6")) _test_forward_elemwise(_test_sub) + _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU")) + _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU6")) _test_forward_elemwise(_test_mul) + _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU")) + _test_forward_elemwise(partial(_test_mul, fused_activation_function="RELU6")) _test_forward_elemwise(_test_div) + _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU")) + _test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6")) _test_forward_elemwise(_test_pow) _test_forward_elemwise(_test_maximum) _test_forward_elemwise(_test_minimum)