From a38f61af45139f4e9651e33b6f124aa54b08d442 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule <15611578+maheshambule@users.noreply.github.com> Date: Wed, 29 Apr 2020 11:05:54 +0530 Subject: [PATCH] [Frontend][TFLite] L2_POOL_2D operator (#5452) * TFLITE fill and splitv ops * l2_pool_2d op changes in comment * TFLite l2_pool_2d op added test case in main * TFLite L2_POOL_2D added check for quantized input --- python/tvm/relay/frontend/tflite.py | 15 +++++++++++ tests/python/frontend/tflite/test_forward.py | 26 ++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index bba7d3b1789a..2065d60a299e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -87,6 +87,7 @@ def __init__(self, model, subgraph, exp_tab): 'GREATER': self.convert_greater, 'HARD_SWISH': self.convert_hard_swish, 'L2_NORMALIZATION': self.convert_l2_normalization, + 'L2_POOL_2D': self.convert_l2_pool2d, 'LESS_EQUAL': self.convert_less_equal, 'LESS': self.convert_less, 'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn, @@ -332,6 +333,10 @@ def convert_max_pool2d(self, op): """Convert TFLite max pool2d""" return self.convert_pool2d(op, "max") + def convert_l2_pool2d(self, op): + """Convert TFLite l2 pool2d""" + return self.convert_pool2d(op, "l2") + def convert_reshape(self, op): """Convert TFLite reshape""" try: @@ -1770,6 +1775,16 @@ def convert_pool2d(self, op, pool_type): assert self.has_same_qnn_params(input_tensor, output_tensor), \ "qnn.op.max_pool2d requires input and output qnn params to be same" out = _op.nn.max_pool2d(in_expr, **params) + elif pool_type == "l2": + # L2_POOL_2D is equivalent to square_root(avg_pool(square(in_data))) + # TFLite does not have support for quantised L2_POOL_2D op. + assert not input_tensor.qnn_params, \ + "As TFLite does not have support for quantized L2_POOL_2D, \ + Quantized input is not expected." + exp_type = self.get_tensor_type_str(output_tensor.tensor.Type()) + square_exp = _op.power(in_expr, relay.const(2, exp_type)) + avg_pool_exp = _op.nn.avg_pool2d(square_exp, **params) + out = _op.sqrt(avg_pool_exp) else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b9602a50c3e1..eb65d82a6546 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -561,6 +561,31 @@ def test_forward_pooling(): strides=[2, 1]) +def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None): + x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 + + with tf.Graph().as_default(): + in_data = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + out = tf.sqrt(tf.nn.avg_pool( + tf.square(in_data), ksize=ksize, strides=strides, + padding=padding, data_format=data_format)) + out = with_fused_activation_function(out, fused_func_name) + + compare_tflite_with_tvm(x, 'input', [in_data], [out]) + + +def test_forward_l2_pool2d(): + _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6") + _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'SAME', "NHWC") + _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC", "RELU") + _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'VALID', "NHWC", "RELU6") + + ####################################################################### # Convolution # ----------- @@ -1938,6 +1963,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_transpose_conv() test_forward_logistic() test_forward_pooling() + test_forward_l2_pool2d() test_forward_softmax() test_forward_tanh() test_forward_relu()