Keep backing buffer alive for tflite models
michalpiszczek committed May 13, 2020
Parent: 079978e · Commit: 2a34bb2
Showing 2 changed files with 7 additions and 1 deletion.
src/runtime/contrib/tflite/tflite_runtime.cc (5 additions, 1 deletion)
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
 void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
   const char* buffer = tflite_model_bytes.c_str();
   size_t buffer_size = tflite_model_bytes.size();
+  // The buffer used to construct the model must be kept alive for
+  // dependent interpreters to be used.
+  flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
+  std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
   std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+      tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
   tflite::ops::builtin::BuiltinOpResolver resolver;
   // Build interpreter
   TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
src/runtime/contrib/tflite/tflite_runtime.h (2 additions, 0 deletions)
@@ -93,6 +93,8 @@ class TFLiteRuntime : public ModuleNode {
    */
   NDArray GetOutput(int index) const;
 
+  // Buffer backing the interpreter's model
+  std::unique_ptr<char[]> flatBuffersBuffer_;
   // TFLite interpreter
   std::unique_ptr<tflite::Interpreter> interpreter_;
   // TVM context
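For context, here is a minimal standalone sketch of the lifetime rule this commit enforces: tflite::FlatBufferModel::BuildFromBuffer reads from the caller-owned buffer rather than copying it, so the bytes must stay alive for as long as the model and any interpreter built from it are in use. The OwnedTfLiteModel struct and LoadTfLiteModel helper below are hypothetical names used only for illustration, not part of the TVM runtime; the sketch assumes the standard TFLite C++ headers.

#include <cstring>
#include <memory>
#include <string>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Hypothetical holder that ties the backing buffer's lifetime to the
// interpreter's, mirroring the role flatBuffersBuffer_ plays in TFLiteRuntime.
struct OwnedTfLiteModel {
  std::unique_ptr<char[]> buffer;  // backing bytes, must outlive the interpreter
  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;
};

// Copies the serialized model bytes, then builds the model and interpreter
// on top of the owned copy so the caller may free its own string immediately.
std::unique_ptr<OwnedTfLiteModel> LoadTfLiteModel(const std::string& bytes) {
  auto owned = std::make_unique<OwnedTfLiteModel>();
  owned->buffer.reset(new char[bytes.size()]);
  std::memcpy(owned->buffer.get(), bytes.data(), bytes.size());

  owned->model =
      tflite::FlatBufferModel::BuildFromBuffer(owned->buffer.get(), bytes.size());
  if (owned->model == nullptr) return nullptr;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  if (tflite::InterpreterBuilder(*owned->model, resolver)(&owned->interpreter) !=
      kTfLiteOk) {
    return nullptr;
  }
  return owned;
}

Without such an owned copy, the interpreter could end up reading from memory that was freed once the caller's tflite_model_bytes string was destroyed, which is exactly the situation the new flatBuffersBuffer_ member avoids.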
