From 83195a5b6230d8c0556593dd6fc05ee273fbf0dd Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 14:37:12 -0700 Subject: [PATCH 1/6] Add missing include so tflite RT builds --- src/runtime/contrib/tflite/tflite_runtime.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index f61f6ee37e0b..2e7fd42db5c2 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -26,6 +26,7 @@ #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ #include +#include #include #include From 82f6bbbccc7f3a3dc06cb3c1b001a9c6f7c1f91a Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Wed, 13 May 2020 14:00:11 -0700 Subject: [PATCH 2/6] Keep backing buffer alive for tflite models --- src/runtime/contrib/tflite/tflite_runtime.cc | 6 +++++- src/runtime/contrib/tflite/tflite_runtime.h | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 53d7754be946..a40fd04959f8 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); + // The buffer used to construct the model must be kept alive for + // dependent interpreters to be used. + flatBuffersBuffer_ = std::unique_ptr(new char[buffer_size]); + std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; // Build interpreter TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_); diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 2e7fd42db5c2..f3e3bd90bba4 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -94,6 +94,8 @@ class TFLiteRuntime : public ModuleNode { */ NDArray GetOutput(int index) const; + // Buffer backing the interpreter's model + std::unique_ptr flatBuffersBuffer_; // TFLite interpreter std::unique_ptr interpreter_; // TVM context From 75543d87f6a076af606d51b370686bf817d612d9 Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 15:40:13 -0700 Subject: [PATCH 3/6] Re-enable tflite runtime unit tests with guard --- src/runtime/contrib/tflite/tflite_runtime.cc | 3 + src/runtime/module.cc | 2 + tests/python/contrib/test_tflite_runtime.py | 189 +++++++++++-------- tests/scripts/task_config_build_cpu.sh | 3 + 4 files changed, 115 insertions(+), 82 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index a40fd04959f8..fa5bb66e9b5d 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -177,5 +177,8 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = TFLiteRuntimeCreate(args[0], args[1]); }); + +TVM_REGISTER_GLOBAL("target.runtime.tflite") +.set_body_typed(TFLiteRuntime); } // namespace runtime } // namespace tvm diff --git a/src/runtime/module.cc b/src/runtime/module.cc index be75ff265ccb..46ef6fab082b 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) { f_name = "device_api.opencl"; } else if (target == "mtl" || target == "metal") { f_name = "device_api.metal"; + } else if (target == "tflite") { + f_name = "target.runtime.tflite"; } else if (target == "vulkan") { f_name = "device_api.vulkan"; } else if (target == "stackvm") { diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 8c883b031a89..91803d9232d2 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -14,92 +14,117 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm from tvm import te import numpy as np from tvm import rpc from tvm.contrib import util, tflite_runtime -# import tensorflow as tf -# import tflite_runtime.interpreter as tflite - - -def skipped_test_tflite_runtime(): - - def create_tflite_model(): - root = tf.Module() - root.const = tf.constant([1., 2.], tf.float32) - root.f = tf.function(lambda x: root.const * x) - - input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) - concrete_func = root.f.get_concrete_function(input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - tflite_model = converter.convert() - return tflite_model - - - def check_local(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via tvm tflite runtime - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input)) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - - def check_remote(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via remote tvm tflite runtime - server = rpc.Server("localhost") - remote = rpc.connect(server.host, server.port) - ctx = remote.cpu(0) - a = remote.upload(tflite_model_path) - - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - check_local() - check_remote() + + +def _create_tflite_model(): + root = tf.Module() + root.const = tf.constant([1., 2.], tf.float32) + root.f = tf.function(lambda x: root.const * x) + + input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) + concrete_func = root.f.get_concrete_function(input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + tflite_model = converter.convert() + return tflite_model + + +@pytest.mark.skip('skip because accessing output tensor is flakey') +def test_local(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via tvm tflite runtime + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + +def test_remote(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via remote tvm tflite runtime + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a = remote.upload(tflite_model_path) + + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + server.terminate() + if __name__ == "__main__": - # skipped_test_tflite_runtime() - pass + test_local() + test_remote() diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 9c1cf2870399..ce545bde6609 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_TFLITE ON\) >> config.cmake +echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake +echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake From 3255deb649b733c0dd80d6d88ffab0525161ed97 Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 15:57:42 -0700 Subject: [PATCH 4/6] Satisfy clang-format --- src/runtime/contrib/tflite/tflite_runtime.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index fa5bb66e9b5d..df3486d5e590 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -178,7 +178,6 @@ TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRe *rv = TFLiteRuntimeCreate(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("target.runtime.tflite") -.set_body_typed(TFLiteRuntime); +TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntime); } // namespace runtime } // namespace tvm From 87d1c27294f6304a9d2f9e282c7005b71af1493e Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 16:10:31 -0700 Subject: [PATCH 5/6] Fix value of guard global --- src/runtime/contrib/tflite/tflite_runtime.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index df3486d5e590..8b34e90312b0 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -178,6 +178,6 @@ TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRe *rv = TFLiteRuntimeCreate(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntime); +TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); } // namespace runtime } // namespace tvm From a0e3a59f737fd21fe274db46da8e4c871a29c9a4 Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 16:26:29 -0700 Subject: [PATCH 6/6] Add tflite guard to test util --- tests/python/contrib/test_tflite_runtime.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 91803d9232d2..1b911b7eb632 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -24,6 +24,19 @@ def _create_tflite_model(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + root = tf.Module() root.const = tf.constant([1., 2.], tf.float32) root.f = tf.function(lambda x: root.const * x)