diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7f7ae3068e2a..caf8f92b1209 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab): 'FULLY_CONNECTED': self.convert_fully_connected, 'GREATER_EQUAL': self.convert_greater_equal, 'GREATER': self.convert_greater, + 'HARD_SWISH': self.convert_hard_swish, 'L2_NORMALIZATION': self.convert_l2_normalization, 'LESS_EQUAL': self.convert_less_equal, 'LESS': self.convert_less, @@ -595,6 +596,42 @@ def convert_relu(self, op): return out + def convert_hard_swish(self, op): + """Convert TFLite Hard swish""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + assert isinstance(op, Operator) + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + def _relu6(data): + return _op.tensor.clip(data, 0.0, 6.0) + + def _hard_swish(data): + return data * _relu6(data + relay.const(3.0)) / relay.const(6.0) + + # Dequantize if the input is quantized. + if input_tensor.qnn_params: + in_expr = self.dequantize(in_expr, input_tensor) + + # Perform hardswish + out = _hard_swish(in_expr) + + # Go back to integer dataype if the original operator was quantized. + if output_tensor.qnn_params: + out = self.quantize(out, output_tensor) + + return out + def convert_concatenation(self, op): """Convert TFLite concatenation""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 42726b7038d5..037e054013d8 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1625,6 +1625,24 @@ def test_forward_mobilenet_v2(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# Mobilenet V3 +# ------------ + +def test_forward_mobilenet_v3(): + """Test the Mobilenet V3 TF Lite model.""" + # MobilenetV3 + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz", + "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + ####################################################################### # Inception # --------- @@ -1723,6 +1741,32 @@ def test_forward_qnn_mobilenet_v2_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +####################################################################### +# Mobilenet V3 Quantized +# ---------------------- + +def test_forward_qnn_mobilenet_v3_net(): + """Test the Quantized TFLite Mobilenet V3 model.""" + # MobilenetV3 + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz", + "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + # Test image. Checking the labels because the requantize implementation is different between + # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via + # labels. Also, giving a real image, instead of random inputs. + data = get_real_image(224, 224) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + ####################################################################### # SSD Mobilenet # ------------- @@ -1831,6 +1875,7 @@ def test_forward_mediapipe_hand_landmark(): # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2() + test_forward_mobilenet_v3() test_forward_inception_v3_net() test_forward_inception_v4_net() test_forward_ssd_mobilenet_v1() @@ -1840,3 +1885,4 @@ def test_forward_mediapipe_hand_landmark(): test_forward_qnn_inception_v1_net() test_forward_qnn_mobilenet_v1_net() test_forward_qnn_mobilenet_v2_net() + test_forward_qnn_mobilenet_v3_net()