[TFLite] Custom attribute reading and While operation support #17932

Merged: 8 commits, Jun 12, 2023

Changes from 6 commits
@@ -12,6 +12,9 @@ namespace ov {
namespace frontend {
namespace tensorflow_lite {

using SubGraphFuncs = std::vector<std::function<std::shared_ptr<ov::Model>()>>;
using SubGraphFuncsPtr = std::shared_ptr<SubGraphFuncs>;

/// Keep necessary data for a single node in the original FW graph to facilitate
/// conversion process in the rules code.
class TENSORFLOW_LITE_API NodeContext : public ov::frontend::NodeContext {
@@ -20,43 +23,79 @@ class TENSORFLOW_LITE_API NodeContext : public ov::frontend::NodeContext {
NodeContext(const std::shared_ptr<DecoderBase>& decoder, const OutputVector& inputs)
: ov::frontend::NodeContext(decoder->get_op_type()),
m_decoder(decoder),
m_inputs(inputs) {}
m_inputs(inputs),
m_subgraph_functions(nullptr) {}

NodeContext(const std::shared_ptr<DecoderBase>& decoder,
const OutputVector& inputs,
const SubGraphFuncsPtr& subgraph_functions)
: ov::frontend::NodeContext(decoder->get_op_type()),
m_decoder(decoder),
m_inputs(inputs),
m_subgraph_functions(subgraph_functions) {}

/// \brief Returns a number of inputs
size_t get_input_size() const override {
return m_inputs.size();
}

/// \brief Returns exactly one input with a given idx; throws if there are no inputs or
/// more than one input for that idx
Output<Node> get_input(int port_index) const override {
return m_inputs.at(port_index);
}

/// Detects if there is an input attached at the given port index
bool has_input(const size_t& port_index) const {
return port_index < m_inputs.size();
}

Output<Node> get_input(int port_index) const override {
return m_inputs.at(port_index);
/// \brief Get a node name
const std::string& get_name() const override {
return m_decoder->get_op_name();
}

OutputVector get_inputs() const {
return m_inputs;
}

size_t get_input_size() const override {
return m_inputs.size();
/// \brief Returns node attribute by name as ov::Any.
ov::Any get_attribute_as_any(const std::string& name) const override {
return m_decoder->get_attribute(name);
}

/// \brief Get a node name
const std::string& get_name() const override {
return m_decoder->get_op_name();
/// \brief Returns the number of sub-graphs that can be enumerated with get_subgraph
size_t get_subgraph_size() const override {
if (!m_subgraph_functions)
return 0;
return m_subgraph_functions->size();
}

/// \brief Returns subgraph converted on demand by the first access
/// If there is no query for specific sub-graph it shouldn't be converted
/// idx should be in range 0..get_subgraph_size()-1
std::shared_ptr<Model> get_subgraph(int idx) const override {
FRONT_END_GENERAL_CHECK(m_subgraph_functions != nullptr,
"Requested subgraph while subgraphs are not configured");
int size = static_cast<int>(get_subgraph_size());
FRONT_END_GENERAL_CHECK(idx >= 0 && idx < size,
"Incorrect subgraph idx ",
idx,
". There are only ",
get_subgraph_size(),
"subgraphs currently");
return m_subgraph_functions->operator[](idx)();
}

/// \brief Get a decoder
std::shared_ptr<DecoderBase> get_decoder() const {
return m_decoder;
}

ov::Any get_attribute_as_any(const std::string& name) const override {
auto res = m_decoder->get_attribute(name);
return res;
}

private:
std::shared_ptr<DecoderBase> m_decoder;
const OutputVector& m_inputs;
SubGraphFuncsPtr m_subgraph_functions;
};

using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::tensorflow_lite::NodeContext&)>;
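The accessors above are what op translators use to reach nested sub-graphs. Below is a minimal sketch of a translator consuming the new API, assuming node_context.hpp and the frontend check macros are already included, as in the real translator sources; translate_while, its error text, and the pass-through return value are illustrative assumptions and not part of this diff.

// Illustrative sketch only: a translator using the new NodeContext sub-graph accessors.
ov::OutputVector translate_while(const ov::frontend::tensorflow_lite::NodeContext& node) {
    // In the TFLite schema a WHILE operator references two sub-graphs: condition and body.
    FRONT_END_GENERAL_CHECK(node.get_subgraph_size() == 2,
                            "WHILE expects a condition and a body sub-graph");
    // Conversion is lazy: each sub-graph is translated only on first access.
    std::shared_ptr<ov::Model> condition = node.get_subgraph(0);
    std::shared_ptr<ov::Model> body = node.get_subgraph(1);
    // A real translator would wire condition/body into a loop operation; omitted here.
    return node.get_inputs();
}

The signature matches CreatorFunction above, so a function of this shape could be registered in the translator dictionary.
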
40 changes: 37 additions & 3 deletions src/frontends/tensorflow_lite/src/decoder_flatbuffer.cpp
@@ -4,6 +4,11 @@

#include "decoder_flatbuffer.h"

#ifdef FLATBUFFERS_LOCALE_INDEPENDENT
# undef FLATBUFFERS_LOCALE_INDEPENDENT
#endif
#define FLATBUFFERS_LOCALE_INDEPENDENT 0
#include "flatbuffers/flexbuffers.h"
#include "schema_generated.h"
#include "utils.hpp"

@@ -87,9 +92,38 @@
ov::frontend::tensorflow_lite::get_ov_type(tensor->type()),
names,
ov::frontend::tensorflow_lite::get_quantization(tensor->quantization()),
tensor_info.input_idx,
tensor_info.output_idx,
(tensor_info.buffer->data() ? tensor_info.buffer->data()->data() : nullptr));
(tensor_info.buffer && tensor_info.buffer->data() ? tensor_info.buffer->data()->data() : nullptr));
}

ov::Any get_value_as_ov_any(const flexbuffers::Reference& value) {
#define CASE_MACRO(fbt, as_stmt) \
case flexbuffers::fbt: \
return {value.as_stmt()};
switch (value.GetType()) {
CASE_MACRO(FBT_INT, AsInt32)
CASE_MACRO(FBT_INDIRECT_INT, AsInt32)
CASE_MACRO(FBT_UINT, AsUInt32)
CASE_MACRO(FBT_INDIRECT_UINT, AsUInt32)
CASE_MACRO(FBT_FLOAT, AsFloat)
CASE_MACRO(FBT_INDIRECT_FLOAT, AsFloat)
CASE_MACRO(FBT_STRING, AsString)
CASE_MACRO(FBT_BOOL, AsBool)
default:
return {};
}
return {};
}

ov::Any DecoderFlatBuffer::get_attribute(const std::string& name) const {
const auto opts = m_node_def->custom_options();
if (opts == nullptr)
return {};
const flexbuffers::Map& m = flexbuffers::GetRoot(opts->Data(), opts->size()).AsMap();
try {
return get_value_as_ov_any(m[name]);
} catch (...) {
return {};
}
}

} // namespace tensorflow_lite
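For context, custom options in a TFLite flatbuffer are stored as a FlexBuffer map, which is what get_attribute above parses. The standalone sketch below writes and reads such a map with the FlexBuffers API; the keys and values are made up, only the flexbuffers calls themselves are real API.

#include <cstdint>
#include <iostream>
#include <vector>

#include "flatbuffers/flexbuffers.h"

int main() {
    // Build a FlexBuffer map the way a TFLite converter stores custom op options.
    flexbuffers::Builder fbb;
    fbb.Map([&]() {
        fbb.Int("adj_x", 1);           // illustrative keys, not tied to a specific op
        fbb.Float("tolerance", 0.5f);
        fbb.Bool("time_major", false);
        fbb.String("mode", "REFLECT");
    });
    fbb.Finish();
    const std::vector<uint8_t>& buf = fbb.GetBuffer();

    // Read it back the way DecoderFlatBuffer::get_attribute does: take the root map
    // and dispatch on the stored FlexBuffer type of each entry.
    auto map = flexbuffers::GetRoot(buf).AsMap();
    std::cout << map["adj_x"].AsInt32() << " "
              << map["tolerance"].AsFloat() << " "
              << map["time_major"].AsBool() << " "
              << map["mode"].AsString().str() << "\n";
    return 0;
}
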
26 changes: 21 additions & 5 deletions src/frontends/tensorflow_lite/src/decoder_flatbuffer.h
@@ -19,6 +19,7 @@ namespace tensorflow_lite {
class TensorLitePlace;
struct TensorInfo;


class DecoderFlatBuffer : public ov::frontend::DecoderBase {
public:
explicit DecoderFlatBuffer(const tflite::Operator* node_def,
@@ -43,9 +44,7 @@ class DecoderFlatBuffer : public ov::frontend::DecoderBase {
return (opts->*member)();
}

ov::Any get_attribute(const std::string& name) const override {
return {};
}
ov::Any get_attribute(const std::string& name) const override;

size_t get_input_size() const override;
size_t get_output_size() const;
@@ -68,15 +67,32 @@ class DecoderFlatBuffer : public ov::frontend::DecoderBase {
std::shared_ptr<ov::frontend::tensorflow_lite::TensorLitePlace> decode_output_tensor(size_t idx,
const ov::frontend::InputModel& model) const;

private:
protected:
std::shared_ptr<ov::frontend::tensorflow_lite::TensorLitePlace> decode_tensor(
const ov::frontend::tensorflow_lite::TensorInfo& tensor_info, const InputModel& model) const;
const ov::frontend::tensorflow_lite::TensorInfo& tensor_info, const ov::frontend::InputModel& model) const;

const tflite::Operator* m_node_def;
std::string m_type, m_name;
std::map<size_t, ov::frontend::tensorflow_lite::TensorInfo> m_input_info, m_output_info;
};

class DecoderFlatBufferTensors : public DecoderFlatBuffer {
Contributor: Please add a comment about why we need this class.

Contributor (Author): I could handle this in the next PR.

public:
DecoderFlatBufferTensors(const TensorInfo &info, int64_t input_idx, int64_t output_idx) :
DecoderFlatBuffer(nullptr, "", "", {}, {}), m_info{info}, m_input_idx(input_idx), m_output_idx(output_idx) {};

std::shared_ptr<ov::frontend::tensorflow_lite::TensorLitePlace> decode_tensor(const ov::frontend::InputModel& model) const {
auto tensor = DecoderFlatBuffer::decode_tensor(m_info, model);
tensor->set_input_index(m_input_idx);
tensor->set_output_index(m_output_idx);
return tensor;
}

private:
TensorInfo m_info;
int64_t m_input_idx, m_output_idx;
};

} // namespace tensorflow_lite
} // namespace frontend
} // namespace ov
21 changes: 20 additions & 1 deletion src/frontends/tensorflow_lite/src/frontend.cpp
@@ -168,6 +168,22 @@ void FrontEnd::translate_graph(const InputModel::Ptr& model,
const auto& model_lite = std::dynamic_pointer_cast<ov::frontend::tensorflow_lite::InputModel>(model);
FRONT_END_GENERAL_CHECK(model_lite, "nullptr for InputModel is given for translation into OV Model");

auto subgraphs_as_input_models = model_lite->get_subgraphs();
auto input_to_ov_model = [&](const std::shared_ptr<ov::frontend::tensorflow_lite::InputModel>& in_model) {
auto simple_lambda = [&]() -> std::shared_ptr<ov::Model> {
std::shared_ptr<ov::Model> model;
if (in_model)
translate_graph(in_model, fail_fast, no_conversion, model);
return model;
};
return simple_lambda;
};
auto submodel_translation_functions = std::make_shared<std::vector<std::function<std::shared_ptr<ov::Model>()>>>();
submodel_translation_functions->reserve(subgraphs_as_input_models.size());
for (const auto& subgraph : subgraphs_as_input_models) {
submodel_translation_functions->emplace_back(input_to_ov_model(subgraph));
}

const auto& translate_map =
no_conversion ? ov::frontend::tensorflow_lite::TranslatorDictionaryType{} : m_op_translators;

@@ -220,7 +236,7 @@ void FrontEnd::translate_graph(const InputModel::Ptr& model,
FRONT_END_OP_CONVERSION_CHECK(translate_map.count(decoder->get_op_type()),
"No translator found for " + decoder->get_op_type() + " node.");
auto op_fun = &(translate_map.at(decoder->get_op_type()));
ov::frontend::tensorflow_lite::NodeContext node_context(decoder, inputs);
ov::frontend::tensorflow_lite::NodeContext node_context(decoder, inputs, submodel_translation_functions);
ov_outputs = (*op_fun)(node_context);
} catch (...) {
if (fail_fast) {
@@ -250,6 +266,9 @@ void FrontEnd::translate_graph(const InputModel::Ptr& model,
tensor != nullptr,
"Inputs of ov::frontend::tensorflow_lite::InputModel must be TensorLitePlace instances");
const auto name = tensor->get_names()[0];
if (!all_tensor_values.count(name)) {
continue;
}
const auto& output_value = all_tensor_values[name];
const auto& result = std::make_shared<ov::opset1::Result>(output_value);
auto input = result->output(0);
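The lambdas above give each sub-graph a translation thunk that runs only when a translator actually calls get_subgraph(). The standalone sketch below shows the same lazy pattern with std::function; FakeModel and the printed messages are placeholders, not OpenVINO types.

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Placeholder for ov::Model, just to keep the sketch self-contained.
struct FakeModel {
    int id;
};

int main() {
    std::vector<std::function<std::shared_ptr<FakeModel>()>> factories;
    for (int i = 0; i < 3; ++i) {
        factories.emplace_back([i]() {
            // This body runs only when the factory is invoked, i.e. on first request.
            std::cout << "translating sub-model " << i << "\n";
            return std::make_shared<FakeModel>(FakeModel{i});
        });
    }
    // Nothing has been translated yet; request only the second sub-model.
    std::shared_ptr<FakeModel> m = factories[1]();
    std::cout << "got sub-model " << m->id << "\n";
    return 0;
}
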
132 changes: 80 additions & 52 deletions src/frontends/tensorflow_lite/src/graph_iterator_flatbuffer.cpp
@@ -25,63 +25,91 @@ GraphIteratorFlatBuffer::GraphIteratorFlatBuffer(const std::string& path) {
model_file.close();

m_model = tflite::GetModel(m_data.data());
const auto subgraphs = m_model->subgraphs();
FRONT_END_GENERAL_CHECK(subgraphs->size() == 1,
"Number of sub-graphs in the model is ",
subgraphs->size(),
". Supported number of sub-graphs is 1.");
const auto graph = *subgraphs->begin();
const auto operators = graph->operators();
m_nodes = {operators->begin(), operators->end()};
auto sub_graphs = m_model->subgraphs();
m_subgraphs = {sub_graphs->begin(), sub_graphs->end()};
m_graph = m_subgraphs[0];
const auto operators = m_graph->operators();
auto operators_vec = std::vector<const tflite::Operator*>{operators->begin(), operators->end()};

m_nodes.assign(operators_vec.begin(), operators_vec.end());
auto outputs = m_graph->outputs();
auto inputs = m_graph->inputs();
m_nodes.insert(m_nodes.begin(), outputs->begin(), outputs->end());
m_nodes.insert(m_nodes.begin(), inputs->begin(), inputs->end());
}

size_t GraphIteratorFlatBuffer::get_subgraph_size() const {
return m_subgraphs.size();
}

std::shared_ptr<GraphIteratorFlatBuffer> GraphIteratorFlatBuffer::get_subgraph(const size_t& idx) const {
FRONT_END_GENERAL_CHECK(m_subgraphs.size() > idx, "There is no subgraph with idx ", idx);
auto iterator = std::make_shared<GraphIteratorFlatBuffer>();
iterator->node_index = 0;
iterator->m_model = m_model;
iterator->m_subgraphs = {}; // TODO: check if we need to pass all sub-graphs here (while in a while situation)
iterator->m_graph = m_subgraphs[idx];
const auto operators = iterator->m_graph->operators();
auto operators_vec = std::vector<const tflite::Operator*>{operators->begin(), operators->end()};
iterator->m_nodes.assign(operators_vec.begin(), operators_vec.end());
auto outputs = iterator->m_graph->outputs();
auto inputs = iterator->m_graph->inputs();
iterator->m_nodes.insert(iterator->m_nodes.begin(), outputs->begin(), outputs->end());
iterator->m_nodes.insert(iterator->m_nodes.begin(), inputs->begin(), inputs->end());
return iterator;
}

std::shared_ptr<DecoderFlatBuffer> GraphIteratorFlatBuffer::get_decoder() const {
auto inputs_vec = (*m_model->subgraphs()->begin())->inputs();
auto outputs_vec = (*m_model->subgraphs()->begin())->outputs();
auto inputs = std::set<int32_t>{inputs_vec->begin(), inputs_vec->end()};
auto outputs = std::set<int32_t>{outputs_vec->begin(), outputs_vec->end()};

auto buffers = m_model->buffers();
auto tensors = m_model->subgraphs()->begin()->tensors();

std::map<size_t, TensorInfo> input_info = {}, output_info = {};
size_t i = 0;
for (auto input : *m_nodes[node_index]->inputs()) {
if (input == -1) {
continue;
auto any_item = m_nodes[node_index];
bool is_op = any_item.is<const tflite::Operator*>();
FRONT_END_GENERAL_CHECK(is_op || any_item.is<int32_t>());
Contributor: please put a message about the error

Contributor (Author): I will handle it in the next PR

auto tensors = m_graph->tensors();

if (is_op) {
auto node = m_nodes[node_index].as<const tflite::Operator*>();
auto buffers = m_model->buffers();

std::map<size_t, TensorInfo> input_info = {}, output_info = {};
Contributor (suggested change):
-    std::map<size_t, TensorInfo> input_info = {}, output_info = {};
+    std::unordered_map<size_t, TensorInfo> input_info = {}, output_info = {};

Contributor (Author): This isn't a change of current PR. Getting decoder for operations was supported before. So I will change it in the next PR as it doesn't break the behavior

size_t i = 0;
for (auto input : *node->inputs()) {
if (input == -1)
continue;
auto buffer = (*buffers)[(*tensors)[input]->buffer()];
Contributor: very complex construction :) Can we have a check for each step: 1. input in tensors; 2. (*tensors)[input]->buffer() is valid; 3. (*tensors)[input]->buffer() in buffers

Contributor (Author): All the asserts are in place by flatbuffers.

auto tensor = (*tensors)[input];
input_info[i++] = TensorInfo{tensor, buffer};
}
auto buffer = (*buffers)[(*tensors)[input]->buffer()];
auto is_input = inputs.find(input) != inputs.end();
int64_t input_idx =
!is_input ? -1 : std::find(inputs_vec->begin(), inputs_vec->end(), input) - inputs_vec->begin();
auto is_output = outputs.find(input) != outputs.end();
int64_t output_idx =
!is_output ? -1 : std::find(outputs_vec->begin(), outputs_vec->end(), input) - outputs_vec->begin();
input_info[i++] = TensorInfo{input_idx, output_idx, (*tensors)[input], buffer};
}
i = 0;
// If we have any m_nodes[node_index]->intermediates() than trigger internal smth? no
// put all the info in Decoder as a sub-graph!
i = 0;
for (auto output : *node->outputs()) {
auto buffer = (*buffers)[(*tensors)[output]->buffer()];
auto tensor = (*tensors)[output];
output_info[i++] = TensorInfo{tensor, buffer};
}
auto op_codes = m_model->operator_codes();
auto operator_code = (*op_codes)[node->opcode_index()];
Contributor: please add a check

Contributor (Author): All the asserts are in place by flatbuffers.

std::string type;
if (operator_code->deprecated_builtin_code() <
tflite::BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) {
type = tflite::EnumNamesBuiltinOperator()[operator_code->deprecated_builtin_code()];
} else {
Contributor (on lines +90 to +93): please add a comment about why deprecated_builtin_code() is used

Contributor (Author): This code was committed some time ago. Since this isn't critical for the actual code outcome -- I will add comment in the next PR

type = tflite::EnumNamesBuiltinOperator()[operator_code->builtin_code()];
}
if (type == "CUSTOM") {
Contributor: please add a comment what is CUSTOM type

Contributor (Author): I find it quite self-explanatory, however if there is a need to explain it even more -- I will add comment in the next PR

type = operator_code->custom_code()->str();
}
return std::make_shared<DecoderFlatBuffer>(node, type, std::to_string(node_index), input_info, output_info);
} else {
auto tensor_id = m_nodes[node_index].as<int32_t>();
auto tensor = (*tensors)[tensor_id];
auto info = TensorInfo{tensor, nullptr};
auto inputs = m_graph->inputs();
auto outputs = m_graph->outputs();

for (auto output : *m_nodes[node_index]->outputs()) {
auto buffer = (*buffers)[(*tensors)[output]->buffer()];
auto is_output = outputs.find(output) != outputs.end();
auto input_it = std::find(inputs->begin(), inputs->end(), tensor_id);
auto output_it = std::find(outputs->begin(), outputs->end(), tensor_id);
int64_t input_idx =
input_it == inputs->end() ? -1 : static_cast<int64_t>(std::distance(inputs->begin(), input_it));
int64_t output_idx =
!is_output ? -1 : std::find(outputs_vec->begin(), outputs_vec->end(), output) - outputs_vec->begin();
output_info[i++] = TensorInfo{-1, output_idx, (*tensors)[output], buffer};
}
auto op_codes = m_model->operator_codes();
auto operator_code = (*op_codes)[m_nodes[node_index]->opcode_index()];
std::string type;
if (operator_code->deprecated_builtin_code() <
tflite::BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) {
type = tflite::EnumNamesBuiltinOperator()[operator_code->deprecated_builtin_code()];
} else {
type = tflite::EnumNamesBuiltinOperator()[operator_code->builtin_code()];
output_it == outputs->end() ? -1 : static_cast<int64_t>(std::distance(outputs->begin(), output_it));
return std::make_shared<DecoderFlatBufferTensors>(info, input_idx, output_idx);
Contributor: shouldn't we check that input_idx/output_idx is not -1?

Contributor (Author): here we handle input and/or output tensors. we explicitly trigger them only for these types of tensors. so having -1 is okay

}
return std::make_shared<DecoderFlatBuffer>(m_nodes[node_index],
type,
std::to_string(node_index),
input_info,
output_info);
}
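The get_decoder() changes above rely on m_nodes holding a mix of operators and boundary tensor ids. The standalone sketch below illustrates that heterogeneous-list pattern with ov::Any; the Operator struct is a stand-in for tflite::Operator, and only ov::Any's is<T>()/as<T>() calls are real API.

#include <iostream>
#include <vector>

#include <openvino/core/any.hpp>

// Stand-in for tflite::Operator, just to make the sketch self-contained.
struct Operator {
    int opcode_index;
};

int main() {
    Operator op{42};
    std::vector<ov::Any> nodes;
    nodes.emplace_back(int32_t{0});                         // graph input tensor id
    nodes.emplace_back(int32_t{3});                         // graph output tensor id
    nodes.emplace_back(static_cast<const Operator*>(&op));  // a regular operator node

    for (auto& item : nodes) {
        if (item.is<const Operator*>()) {
            std::cout << "operator, opcode index " << item.as<const Operator*>()->opcode_index << "\n";
        } else {
            std::cout << "boundary tensor id " << item.as<int32_t>() << "\n";
        }
    }
    return 0;
}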