Skip to content

Commit

Permalink
[Frontend][TFLite] L2_POOL_2D operator (#5452)
Browse files Browse the repository at this point in the history
* TFLITE fill and splitv ops

* l2_pool_2d op changes in comment

* TFLite l2_pool_2d op added test case in main

* TFLite L2_POOL_2D added check for quantized input
  • Loading branch information
maheshambule authored Apr 29, 2020
1 parent 2b32c95 commit a38f61a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, model, subgraph, exp_tab):
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
'L2_NORMALIZATION': self.convert_l2_normalization,
'L2_POOL_2D': self.convert_l2_pool2d,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
Expand Down Expand Up @@ -332,6 +333,10 @@ def convert_max_pool2d(self, op):
"""Convert TFLite max pool2d"""
return self.convert_pool2d(op, "max")

def convert_l2_pool2d(self, op):
"""Convert TFLite l2 pool2d"""
return self.convert_pool2d(op, "l2")

def convert_reshape(self, op):
"""Convert TFLite reshape"""
try:
Expand Down Expand Up @@ -1770,6 +1775,16 @@ def convert_pool2d(self, op, pool_type):
assert self.has_same_qnn_params(input_tensor, output_tensor), \
"qnn.op.max_pool2d requires input and output qnn params to be same"
out = _op.nn.max_pool2d(in_expr, **params)
elif pool_type == "l2":
# L2_POOL_2D is equivalent to square_root(avg_pool(square(in_data)))
# TFLite does not have support for quantised L2_POOL_2D op.
assert not input_tensor.qnn_params, \
"As TFLite does not have support for quantized L2_POOL_2D, \
Quantized input is not expected."
exp_type = self.get_tensor_type_str(output_tensor.tensor.Type())
square_exp = _op.power(in_expr, relay.const(2, exp_type))
avg_pool_exp = _op.nn.avg_pool2d(square_exp, **params)
out = _op.sqrt(avg_pool_exp)
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
Expand Down
26 changes: 26 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,31 @@ def test_forward_pooling():
strides=[2, 1])


def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None):
x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1

with tf.Graph().as_default():
in_data = tf.placeholder(
dtype=tf.float32, name="input", shape=input_shape)
out = tf.sqrt(tf.nn.avg_pool(
tf.square(in_data), ksize=ksize, strides=strides,
padding=padding, data_format=data_format))
out = with_fused_activation_function(out, fused_func_name)

compare_tflite_with_tvm(x, 'input', [in_data], [out])


def test_forward_l2_pool2d():
_test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6")
_test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6")
_test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC")
_test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'SAME', "NHWC")
_test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC", "RELU")
_test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC")
_test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC")
_test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'VALID', "NHWC", "RELU6")


#######################################################################
# Convolution
# -----------
Expand Down Expand Up @@ -1938,6 +1963,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_transpose_conv()
test_forward_logistic()
test_forward_pooling()
test_forward_l2_pool2d()
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
Expand Down

0 comments on commit a38f61a

Please sign in to comment.