diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf18ffc9e856..bb21db26bad6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,6 +63,8 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
 tvm_option(USE_RANDOM "Build with random support" OFF)
 tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
 tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
+tvm_option(USE_TFLITE "Build with tflite support" OFF)
+tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when using TFLite" none)
 
 # include directories
 include_directories(${CMAKE_INCLUDE_PATH})
@@ -257,6 +259,7 @@ include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
 include(cmake/modules/contrib/Sort.cmake)
 include(cmake/modules/contrib/NNPack.cmake)
 include(cmake/modules/contrib/HybridDump.cmake)
+include(cmake/modules/contrib/TFLite.cmake)
 
 if(NOT MSVC)
   include(CheckCXXCompilerFlag)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 1ef956c7ee18..25bf5516291b 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -145,6 +145,15 @@ set(USE_RANDOM OFF)
 # Whether use NNPack
 set(USE_NNPACK OFF)
 
+# Possible values:
+# - ON: enable tflite with cmake's find search
+# - OFF: disable tflite
+# - /path/to/libtensorflow-lite.a: use a specific path to the tensorflow lite library
+set(USE_TFLITE OFF)
+
+# /path/to/tensorflow: tensorflow root path when using the tflite library
+set(USE_TENSORFLOW_PATH none)
+
 # Whether use CuDNN
 set(USE_CUDNN OFF)
diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake
new file mode 100644
index 000000000000..9074def9dc8e
--- /dev/null
+++ b/cmake/modules/contrib/TFLite.cmake
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(NOT USE_TFLITE STREQUAL "OFF")
+  message(STATUS "Build with contrib.tflite")
+  if (USE_TENSORFLOW_PATH STREQUAL "none")
+    set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow)
+  endif()
+
+  file(GLOB TFLITE_CONTRIB_SRC src/runtime/contrib/tflite/*.cc)
+  list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
+  include_directories(${USE_TENSORFLOW_PATH})
+
+  if (USE_TFLITE STREQUAL "ON")
+    set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
+  endif()
+  find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE})
+
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers)
+endif()
diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py
new file mode 100644
index 000000000000..89a547f48f96
--- /dev/null
+++ b/python/tvm/contrib/tflite_runtime.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TFLite runtime that loads and runs TFLite models."""
+from .._ffi.function import get_global_func
+from ..rpc import base as rpc_base
+
+
+def create(tflite_model_bytes, ctx):
+    """Create a runtime executor module given a tflite model and context.
+
+    Parameters
+    ----------
+    tflite_model_bytes : bytes
+        The tflite model to be deployed in bytes string format.
+
+    ctx : TVMContext
+        The context to deploy the module. It can be local or remote when there
+        is only one TVMContext.
+
+    Returns
+    -------
+    tflite_runtime : TFLiteModule
+        Runtime tflite module that can be used to execute the tflite model.
+    """
+    device_type = ctx.device_type
+    if device_type >= rpc_base.RPC_SESS_MASK:
+        fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create")
+        return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
+    fcreate = get_global_func("tvm.tflite_runtime.create")
+    return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
+
+
+class TFLiteModule(object):
+    """Wrapper runtime module.
+
+    This is a thin wrapper of the underlying TVM module.
+    You can also directly call set_input, invoke, and get_output
+    on the underlying module functions.
+
+    Parameters
+    ----------
+    module : Module
+        The internal tvm module that holds the actual tflite functions.
+
+    Attributes
+    ----------
+    module : Module
+        The internal tvm module that holds the actual tflite functions.
+    """
+
+    def __init__(self, module):
+        self.module = module
+        self._set_input = module["set_input"]
+        self._invoke = module["invoke"]
+        self._get_output = module["get_output"]
+        self._allocate_tensors = module["allocate_tensors"]
+
+    def set_input(self, index, value):
+        """Set the index-th input to the module.
+
+        Parameters
+        ----------
+        index : int
+            The input index.
+
+        value : NDArray
+            The input value.
+        """
+        self._set_input(index, value)
+
+    def invoke(self):
+        """Invoke forward execution of the model."""
+        self._invoke()
+
+    def allocate_tensors(self):
+        """Allocate space for all tensors."""
+        self._allocate_tensors()
+
+    def get_output(self, index):
+        """Get the index-th output.
+
+        Parameters
+        ----------
+        index : int
+            The output index.
+
+        Returns
+        -------
+        out : NDArray
+            The index-th output.
+        """
+        return self._get_output(index)
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
new file mode 100644
index 000000000000..a32669d5f635
--- /dev/null
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tflite_runtime.cc
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/dtype.h>
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+
+
+#include "tflite_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+#define TVM_DTYPE_DISPATCH(type, DType, ...)            \
+  if (type == Float(64)) {                              \
+    typedef double DType;                               \
+    {__VA_ARGS__}                                       \
+  } else if (type == Float(32)) {                       \
+    typedef float DType;                                \
+    {__VA_ARGS__}                                       \
+  } else if (type == Float(16)) {                       \
+    typedef uint16_t DType;                             \
+    {__VA_ARGS__}                                       \
+  } else if (type == Int(64)) {                         \
+    typedef int64_t DType;                              \
+    {__VA_ARGS__}                                       \
+  } else if (type == Int(32)) {                         \
+    typedef int32_t DType;                              \
+    {__VA_ARGS__}                                       \
+  } else if (type == Int(16)) {                         \
+    typedef int16_t DType;                              \
+    {__VA_ARGS__}                                       \
+  } else if (type == Int(8)) {                          \
+    typedef int8_t DType;                               \
+    {__VA_ARGS__}                                       \
+  } else if (type == UInt(64)) {                        \
+    typedef uint64_t DType;                             \
+    {__VA_ARGS__}                                       \
+  } else if (type == UInt(32)) {                        \
+    typedef uint32_t DType;                             \
+    {__VA_ARGS__}                                       \
+  } else if (type == UInt(16)) {                        \
+    typedef uint16_t DType;                             \
+    {__VA_ARGS__}                                       \
+  } else if (type == UInt(8)) {                         \
+    typedef uint8_t DType;                              \
+    {__VA_ARGS__}                                       \
+  } else {                                              \
+    LOG(FATAL) << "unknown data type " << type;         \
+  }
+
+DataType TfLiteDType2TVMDType(TfLiteType dtype) {
+  switch (dtype) {
+    case kTfLiteFloat32:
+      return Float(32);
+    case kTfLiteInt32:
+      return Int(32);
+    case kTfLiteInt64:
+      return Int(64);
+    case kTfLiteInt16:
+      return Int(16);
+    case kTfLiteInt8:
+      return Int(8);
+    case kTfLiteUInt8:
+      return UInt(8);
+    case kTfLiteFloat16:
+      return Float(16);
+    default:
+      LOG(FATAL) << "tflite data type not supported yet: " << dtype;
+      return Float(32);
+  }
+}
+
+
+void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
+                         TVMContext ctx) {
+  const char* buffer = tflite_model_bytes.c_str();
+  size_t buffer_size = tflite_model_bytes.size();
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
+  ctx_ = ctx;
+}
+
+void TFLiteRuntime::AllocateTensors() {
+  interpreter_->AllocateTensors();
+}
+
+void TFLiteRuntime::Invoke() {
+  interpreter_->Invoke();
+}
+
+void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
+  DataType dtype(data_in->dtype);
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    DType* dest = interpreter_->typed_input_tensor<DType>(index);
+    DType* src = static_cast<DType*>(data_in->data);
+    CHECK(data_in->strides == NULL);
+    int64_t size = 1;
+    for (int64_t i = 0; i < data_in->ndim; ++i) {
+      size *= data_in->shape[i];
+    }
+    for (int64_t i = 0; i < size; ++i) {
+      dest[i] = src[i];
+    }
+  });
+}
+
+NDArray TFLiteRuntime::GetOutput(int index) const {
+  TfLiteTensor* output = interpreter_->output_tensor(index);
+  DataType dtype = TfLiteDType2TVMDType(output->type);
+  TfLiteIntArray* dims = output->dims;
+  int64_t size = 1;
+  std::vector<int64_t> shape;
+  for (int i = 0; i < dims->size; ++i) {
+    shape.push_back(dims->data[i]);
+    size *= dims->data[i];
+  }
+
+  NDArray ret = NDArray::Empty(shape, dtype, ctx_);
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    DType* dest = static_cast<DType*>(ret->data);
+    DType* src = interpreter_->typed_output_tensor<DType>(index);
+    for (int64_t i = 0; i < size; ++i) {
+      dest[i] = src[i];
+    }
+  });
+  return ret;
+}
+
+PackedFunc TFLiteRuntime::GetFunction(
+    const std::string& name,
+    const ObjectPtr<Object>& sptr_to_self) {
+  // Return member functions during query.
+  if (name == "set_input") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        int in_idx = args[0];
+        CHECK_GE(in_idx, 0);
+        this->SetInput(in_idx, args[1]);
+      });
+  } else if (name == "get_output") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->GetOutput(args[0]);
+      });
+  } else if (name == "invoke") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        this->Invoke();
+      });
+  } else if (name == "allocate_tensors") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        this->AllocateTensors();
+      });
+  } else {
+    return PackedFunc();
+  }
+}
+
+Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes,
+                           TVMContext ctx) {
+  auto exec = make_object<TFLiteRuntime>();
+  exec->Init(tflite_model_bytes, ctx);
+  return Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create")
+  .set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = TFLiteRuntimeCreate(args[0], args[1]);
+  });
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h
new file mode 100644
index 000000000000..4b08b97b6865
--- /dev/null
+++ b/src/runtime/contrib/tflite/tflite_runtime.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief TFLite runtime that can run TFLite models.
+ * \file tflite_runtime.h
+ */
+#ifndef TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
+
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/packed_func.h>
+#include <tensorflow/lite/interpreter.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+
+/*!
+ * \brief TFLite runtime.
+ *
+ * This runtime can be accessed in various languages via
+ * the TVM runtime PackedFunc API.
+ */
+class TFLiteRuntime : public ModuleNode {
+ public:
+  /*!
+   * \brief Get member function to front-end
+   * \param name The name of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The corresponding member function.
+   */
+  virtual PackedFunc GetFunction(const std::string& name,
+                                 const ObjectPtr<Object>& sptr_to_self);
+
+  /*!
+   * \return The type key of the executor.
+   */
+  const char* type_key() const final {
+    return "TFLiteRuntime";
+  }
+
+  /*!
+   * \brief Update allocations for all tensors. This is relatively expensive.
+   */
+  void AllocateTensors();
+  /*!
+   * \brief Invoke the internal tflite interpreter and run the whole model in
+   * dependency order.
+   */
+  void Invoke();
+
+  /*!
+   * \brief Initialize the tflite runtime with tflite model and context.
+   * \param tflite_model_bytes The tflite model.
+   * \param ctx The context on which the tflite model will be executed.
+   */
+  void Init(const std::string& tflite_model_bytes,
+            TVMContext ctx);
+
+  /*!
+   * \brief Set the index-th input to the model.
+   * \param index The input index.
+   * \param data_in The input data.
+   */
+  void SetInput(int index, DLTensor* data_in);
+  /*!
+   * \brief Return NDArray for given input index.
+   * \param index The input index.
+   *
+   * \return NDArray corresponding to given input node index.
+   */
+  NDArray GetInput(int index) const;
+  /*!
+   * \brief Return NDArray for given output index.
+   * \param index The output index.
+   *
+   * \return NDArray corresponding to given output node index.
+   */
+  NDArray GetOutput(int index) const;
+
+ private:
+  std::unique_ptr<tflite::Interpreter> interpreter_;
+  TVMContext ctx_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py
new file mode 100644
index 000000000000..e8bc66300e1a
--- /dev/null
+++ b/tests/python/contrib/test_tflite_runtime.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import numpy as np
+from tvm import rpc
+from tvm.contrib import util, tflite_runtime
+# import tensorflow as tf
+# import tflite_runtime.interpreter as tflite
+
+
+def skipped_test_tflite_runtime():
+
+    def create_tflite_model():
+        root = tf.Module()
+        root.const = tf.constant([1., 2.], tf.float32)
+        root.f = tf.function(lambda x: root.const * x)
+
+        input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
+        concrete_func = root.f.get_concrete_function(input_signature)
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def check_verify():
+        tflite_fname = "model.tflite"
+        tflite_model = create_tflite_model()
+        temp = util.tempdir()
+        tflite_model_path = temp.relpath(tflite_fname)
+        print(tflite_model_path)
+        open(tflite_model_path, 'wb').write(tflite_model)
+
+        # inference via tflite interpreter python apis
+        print('interpreter')
+        interpreter = tflite.Interpreter(model_path=tflite_model_path)
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+
+        input_shape = input_details[0]['shape']
+        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+        interpreter.set_tensor(input_details[0]['index'], tflite_input)
+        interpreter.invoke()
+        tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+        print('tvm tflite runtime')
+        # inference via tvm tflite runtime
+        with open(tflite_model_path, 'rb') as model_fin:
+            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
+        runtime.allocate_tensors()
+        runtime.set_input(0, tvm.nd.array(tflite_input))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    def check_remote():
+        tflite_fname = "model.tflite"
+        tflite_model = create_tflite_model()
+        temp = util.tempdir()
+        tflite_model_path = temp.relpath(tflite_fname)
+        open(tflite_model_path, 'wb').write(tflite_model)
+
+        # inference via tflite interpreter python apis
+        interpreter = tflite.Interpreter(model_path=tflite_model_path)
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+
+        input_shape = input_details[0]['shape']
+        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+        interpreter.set_tensor(input_details[0]['index'], tflite_input)
+        interpreter.invoke()
+        tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+        # inference via remote tvm tflite runtime
+        server = rpc.Server("localhost")
+        remote = rpc.connect(server.host, server.port)
+        ctx = remote.cpu(0)
+        a = remote.upload(tflite_model_path)
+
+        with open(tflite_model_path, 'rb') as model_fin:
+            runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
+        runtime.allocate_tensors()
+        runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    check_verify()
+    check_remote()
+
+
+if __name__ == "__main__":
+    # skipped_test_tflite_runtime()
+    pass
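
For quick reference, this is the call sequence the new Python wrapper expects, distilled from the test above into a minimal local-CPU sketch. It assumes TVM was built with USE_TFLITE enabled, and "model.tflite" is a placeholder for an already-converted flatbuffer on disk (not part of this patch):

    import numpy as np
    import tvm
    from tvm.contrib import tflite_runtime

    # Bind a converted TFLite flatbuffer to the local CPU context.
    with open("model.tflite", "rb") as f:
        runtime = tflite_runtime.create(f.read(), tvm.cpu(0))

    runtime.allocate_tensors()  # must precede set_input/invoke
    runtime.set_input(0, tvm.nd.array(np.array([3., 4.], dtype="float32")))
    runtime.invoke()            # runs the whole model
    out = runtime.get_output(0).asnumpy()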