From 8e4a1522a445dd4043e72954648d724b9c9b58ad Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Sun, 21 Nov 2021 21:33:11 +0000 Subject: [PATCH] address comments Change-Id: I6167cf73b2722902212717c5243cd19edc3489b7 --- python/tvm/relay/op/contrib/ethosu.py | 4 +- .../contrib/test_ethosu/test_codegen.py | 78 +++++++++++++++++++ .../contrib/test_ethosu/test_legalize.py | 1 + 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 03d2fee36c1e..a2916e46dbb9 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -538,8 +538,8 @@ def is_valid(self): """ if np.dtype(self.ofm) == np.int32 and self.activation is not None: return False - # Due to identity operator requiring ifm != int32 for now - if np.dtype(self.ifm) == np.int32 and len(self.ifm.shape) < 4 or len(self.ifm2.shape) < 4: + # Due to identity operator requiring ofm != int32 for now + if np.dtype(self.ofm) == np.int32 and len(self.ofm.shape) < 4: return False if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4: return False diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 6b27d468cb64..81bcbe6b7c5c 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -436,6 +436,84 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([4], [4]), + ([4], [1, 2, 3, 4]), + ([1, 4, 4], [4, 1]), + ], +) +def test_binary_add_with_non_4d_shapes( + accel_type, + ifm_shape, + ifm2_shape, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, lhs, rhs): + return tf.math.add(lhs, rhs) + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape}, + dtype_dict={"ifm": dtype, "ifm2": dtype}, + ) + mod = partition_for_ethosu(mod, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + output_tolerance=0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + @pytest.mark.parametrize("accel_type", ACCEL_TYPES) def test_binary_add_from_constant_scalar(accel_type): dtype = "uint8" diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 7ab3962358fa..8612b90adbe3 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -565,6 +565,7 @@ def verify(ext_func): ([1, 2, 3, 4], [1, 2, 3, 4], False), ([1, 2, 3, 4], [1, 1, 3, 1], False), ([1, 1, 3, 1], [1, 2, 3, 4], True), + ([1, 4, 4], [4, 1], False), ([4], [4], False), ([4], [1, 2, 3, 4], True), ([1, 4, 4], [4, 1], False),