From bd978816f71496dc012013cd197c9c717ffb3ef2 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 12 Mar 2020 12:39:45 +0530 Subject: [PATCH] [TFLITE]Round op parsing support added (#5022) --- python/tvm/relay/frontend/tflite.py | 8 ++++++++ tests/python/frontend/tflite/test_forward.py | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index c8207130ff82..6faf8d9d44c6 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -109,6 +109,7 @@ def __init__(self, model, subgraph, exp_tab): 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, + 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, @@ -676,6 +677,13 @@ def convert_floor(self, op): 'TFlite quantized FLOOR operator is not supported yet.') return self._convert_unary_elemwise(_op.floor, op) + def convert_round(self, op): + """Convert TFLite ROUND""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized ROUND operator is not supported yet.') + return self._convert_unary_elemwise(_op.round, op) + def convert_exp(self, op): """Convert TFLite EXP""" if self.is_quantized(op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 78d6c3e72fe7..cde8a74405eb 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -694,6 +694,15 @@ def _test_ceil(data): def _test_floor(data): """ One iteration of floor """ return _test_unary_elemwise(math_ops.floor, data) + +####################################################################### +# Round +# ----- + +def _test_round(data): + """ One iteration of round """ + return _test_unary_elemwise(math_ops.round, data) + ####################################################################### # Exp # --- @@ -787,6 +796,7 @@ def test_all_unary_elemwise(): if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): _test_forward_unary_elemwise(_test_ceil) _test_forward_unary_elemwise(_test_cos) + _test_forward_unary_elemwise(_test_round) _test_forward_unary_elemwise(_test_tan) _test_forward_unary_elemwise(_test_elu)