From a31b4f698bf3f0f546d0c8174bf78b804c48f659 Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Fri, 21 Feb 2020 04:10:45 +0000 Subject: [PATCH] Fix tests for tflite unary elemwise operations (#4913) * add TFLite version check for 'ceil' and 'cos' * fix name check of test_op for positive inputs * add error message for operator not found in the installed fbs schema --- python/tvm/relay/frontend/tflite.py | 6 +++++- tests/python/frontend/tflite/test_forward.py | 15 ++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index e92e4cef205d0..dd3587125aec6 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -166,7 +166,11 @@ def get_op_code_str(self, op): op_code_list_idx = op.OpcodeIndex() op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode() - op_code_str = self.builtin_op_code[op_code_id] + try: + op_code_str = self.builtin_op_code[op_code_id] + except KeyError: + raise NotImplementedError('TFLite operator with code ' + str(op_code_id) + \ + ' is not supported by this version of the fbs schema.') if op_code_id == BuiltinOperator.CUSTOM: # Custom operator custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index e88226c1b1255..427d4bfe28106 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -669,7 +669,7 @@ def _test_unary_elemwise(math_op, data): with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in') out = math_op(in_data) - compare_tflite_with_tvm(data, ['in:0'], in_data, [out]) + compare_tflite_with_tvm(data, ['in:0'], [in_data], [out]) ####################################################################### # Abs @@ -745,23 +745,24 @@ def _test_neg(data): def _test_forward_unary_elemwise(test_op): # functions that need positive input - if test_op in {'_test_log', '_test_sqrt', '_test_rsqrt'}: - test_op(np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))) - test_op(np.arange(6.0, dtype=np.int32).reshape((2, 1, 3))) + if test_op.__name__ in {'_test_log', '_test_sqrt', '_test_rsqrt'}: + test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))) else: - np.array(np.random.uniform(-5, 5, (3, 1)), dtype=np.int32) + test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32)) def test_all_unary_elemwise(): _test_forward_unary_elemwise(_test_abs) - _test_forward_unary_elemwise(_test_ceil) _test_forward_unary_elemwise(_test_floor) _test_forward_unary_elemwise(_test_exp) _test_forward_unary_elemwise(_test_log) _test_forward_unary_elemwise(_test_sin) - _test_forward_unary_elemwise(_test_cos) _test_forward_unary_elemwise(_test_sqrt) _test_forward_unary_elemwise(_test_rsqrt) _test_forward_unary_elemwise(_test_neg) + # ceil and cos come with TFLite 1.14.0.post1 fbs schema + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + _test_forward_unary_elemwise(_test_ceil) + _test_forward_unary_elemwise(_test_cos) ####################################################################### # Element-wise