Skip to content

Commit

Permalink
TensorFlow FrontEnd Refactoring (#9173)
Browse files Browse the repository at this point in the history
* Move tensorflow fe to openvino subfolder; renaming; refactoring

* codestyle

* delete redundant file

* fix Win build

* fix missprint
  • Loading branch information
itikhono authored Dec 21, 2021
1 parent 0100060 commit 9334f34
Show file tree
Hide file tree
Showing 95 changed files with 549 additions and 528 deletions.
26 changes: 13 additions & 13 deletions src/core/tests/frontend/tensorflow/transpose_sinking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using namespace std;
using namespace ov;
using namespace opset8;
using namespace frontend::tf::pass;
using namespace frontend::tensorflow::pass;

template <class T>
int64_t count_ops_of_type(const shared_ptr<Model>& f) {
Expand All @@ -27,7 +27,7 @@ int64_t count_ops_of_type(const shared_ptr<Model>& f) {
}

TEST(TransposeSinkingTest, PassProperty) {
auto pass = std::make_shared<TransposeSinkingOVTF>();
auto pass = std::make_shared<TransposeSinking>();
ASSERT_TRUE(pass->get_property(ov::pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_FALSE(pass->get_property(ov::pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
Expand All @@ -51,7 +51,7 @@ TEST(TransposeSinkingTest, EdgeSplitting) {
size_t before_count = count_ops_of_type<Transpose>(func);

ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

ASSERT_EQ(before_count, 1);
Expand Down Expand Up @@ -107,7 +107,7 @@ TEST(TransposeSinkingTest, PoolAdd1) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func);
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func);
Expand Down Expand Up @@ -159,7 +159,7 @@ TEST(TransposeSinkingTest, PoolAdd2) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func); // 3
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func); // 4
Expand All @@ -171,7 +171,7 @@ TEST(TransposeSinkingTest, PoolAdd2) {
ASSERT_EQ(new_transpose->get_output_shape(0), (ngraph::Shape{1, 3, 3, 1}));
}

// Different rank constant input to Add1. After TransposeSinkingOVTF the const
// Different rank constant input to Add1. After TransposeSinking the const
// would need a Reshape to have the same order as the other input to
// Add1.
TEST(TransposeSinkingTest, PoolAdd3) {
Expand Down Expand Up @@ -203,7 +203,7 @@ TEST(TransposeSinkingTest, PoolAdd3) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func);
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func);
Expand All @@ -228,7 +228,7 @@ TEST(TransposeSinkingTest, Concat) {
auto func = make_shared<ngraph::Function>(ngraph::OutputVector{c}, ngraph::ParameterVector{a, b});

ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t transpose_count = count_ops_of_type<Transpose>(func);
Expand Down Expand Up @@ -260,7 +260,7 @@ TEST(TransposeSinkingTest, Concat_DummyShape) {
auto func = make_shared<ngraph::Function>(ngraph::OutputVector{out}, ngraph::ParameterVector{a1, a2, a3, a4});

ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t transpose_count = count_ops_of_type<Transpose>(func); // 1
Expand Down Expand Up @@ -306,7 +306,7 @@ TEST(TransposeSinkingTest, Pad) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func); // 2
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func); // 2
Expand Down Expand Up @@ -337,7 +337,7 @@ TEST(TransposeSinkingTest, SimpleUnary) {
size_t before_count = count_ops_of_type<Transpose>(func); // 2

ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func); // 0
Expand Down Expand Up @@ -396,7 +396,7 @@ TEST(TransposeSinkingTest, MultiOutput) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func); // 3
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func); // 4
Expand Down Expand Up @@ -503,7 +503,7 @@ TEST(TransposeSinkingTest, AlexnetPattern) {

ov::pass::Manager pass_manager;
size_t before_count = count_ops_of_type<Transpose>(func);
pass_manager.register_pass<TransposeSinkingOVTF>();
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);

size_t after_count = count_ops_of_type<Transpose>(func);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

#pragma once

#include <openvino/core/any.hpp>

#include "tensorflow_frontend/utility.hpp"
#include "openvino/core/any.hpp"
#include "openvino/frontend/tensorflow/visibility.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

class TF_API DecoderBase {
class TENSORFLOW_API DecoderBase {
public:
/// \brief Get attribute value by name and requested type
///
Expand Down Expand Up @@ -41,5 +41,7 @@ class TF_API DecoderBase {
/// \brief Destructor
virtual ~DecoderBase() = default;
};

} // namespace tensorflow
} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,24 @@
#include "openvino/frontend/extension/telemetry.hpp"
#include "openvino/frontend/frontend.hpp"
#include "openvino/frontend/input_model.hpp"
#include "tensorflow_frontend/utility.hpp"
#include "openvino/frontend/tensorflow/visibility.hpp"

namespace ov {
namespace frontend {
namespace tf {
namespace tensorflow {

class NodeContext;
}
} // namespace frontend
} // namespace ov

namespace ov {
namespace frontend {
class TF_API FrontEndTF : public ov::frontend::FrontEnd {
class TENSORFLOW_API FrontEnd : public ov::frontend::FrontEnd {
public:
using CreatorFunction = std::function<::ov::OutputVector(const ::ov::frontend::tf::NodeContext&)>;
using CreatorFunction = std::function<::ov::OutputVector(const ::ov::frontend::tensorflow::NodeContext&)>;
using TranslatorDictionaryType = std::map<const std::string, const CreatorFunction>;

private:
TranslatorDictionaryType m_op_translators;

public:
FrontEndTF();
FrontEnd();

/// \brief Completely convert the model
/// \return fully converted ov Model
Expand Down Expand Up @@ -86,5 +82,7 @@ class TF_API FrontEndTF : public ov::frontend::FrontEnd {
std::shared_ptr<TelemetryExtension> m_telemetry;
std::vector<std::shared_ptr<DecoderTransformationExtension>> m_transformation_extensions;
};

} // namespace tensorflow
} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
#pragma once

#include "openvino/core/any.hpp"
#include "tensorflow_frontend/decoder.hpp"
#include "tensorflow_frontend/utility.hpp"
#include "openvino/frontend/tensorflow/decoder.hpp"
#include "openvino/frontend/visibility.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {

/// Abstract representation for an input model graph that gives nodes in topologically sorted order
class TF_API GraphIterator : ov::RuntimeAttribute {
class TENSORFLOW_API GraphIterator : ov::RuntimeAttribute {
public:
OPENVINO_RTTI("Variant::GraphIterator");

Expand All @@ -35,5 +37,7 @@ class TF_API GraphIterator : ov::RuntimeAttribute {
/// \brief Destructor
virtual ~GraphIterator() = default;
};

} // namespace tensorflow
} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
#include "openvino/frontend/exception.hpp"

#ifdef OPENVINO_STATIC_LIBRARY
# define TF_API
# define TF_C_API
# define TENSORFLOW_API
# define TENSORFLOW_C_API
#else
# ifdef ov_tensorflow_frontend_EXPORTS
# define TF_API OPENVINO_CORE_EXPORTS
# define TF_C_API OPENVINO_EXTERN_C OPENVINO_CORE_EXPORTS
# define TENSORFLOW_API OPENVINO_CORE_EXPORTS
# define TENSORFLOW_C_API OPENVINO_EXTERN_C OPENVINO_CORE_EXPORTS
# else
# define TF_API OPENVINO_CORE_IMPORTS
# define TF_C_API OPENVINO_EXTERN_C OPENVINO_CORE_IMPORTS
# define TENSORFLOW_API OPENVINO_CORE_IMPORTS
# define TENSORFLOW_C_API OPENVINO_EXTERN_C OPENVINO_CORE_IMPORTS
# endif // ov_tensorflow_frontend_EXPORTS
#endif // OPENVINO_STATIC_LIBRARY
20 changes: 10 additions & 10 deletions src/frontends/tensorflow/src/decoder_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace ov {
namespace frontend {
namespace tf {
namespace tensorflow {

namespace {
const std::map<::tensorflow::DataType, ov::element::Type>& TYPE_MAP() {
Expand All @@ -27,7 +27,7 @@ const std::map<::tensorflow::DataType, ov::element::Type>& TYPE_MAP() {
}
} // namespace

ov::Any DecoderTFProto::get_attribute(const std::string& name, const std::type_info& type_info) const {
ov::Any DecoderProto::get_attribute(const std::string& name, const std::type_info& type_info) const {
auto attrs = decode_attribute_helper(name);
if (attrs.empty()) {
return {};
Expand Down Expand Up @@ -85,13 +85,13 @@ ov::Any DecoderTFProto::get_attribute(const std::string& name, const std::type_i
return {};
}

size_t DecoderTFProto::get_input_size() const {
size_t DecoderProto::get_input_size() const {
return m_node_def->input_size();
}

void DecoderTFProto::get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const {
void DecoderProto::get_input_node(size_t input_port_idx,
std::string& producer_name,
size_t& producer_output_port_index) const {
// TODO: handle body graph nodes with a couple of columns
std::string producer_port_name = m_node_def->input(input_port_idx);
auto delim_pos = producer_port_name.find(':');
Expand All @@ -104,15 +104,15 @@ void DecoderTFProto::get_input_node(size_t input_port_idx,
producer_output_port_index = 0;
}

const std::string& DecoderTFProto::get_op_type() const {
const std::string& DecoderProto::get_op_type() const {
return m_node_def->op();
}

const std::string& DecoderTFProto::get_op_name() const {
const std::string& DecoderProto::get_op_name() const {
return m_node_def->name();
}

std::vector<::tensorflow::AttrValue> DecoderTFProto::decode_attribute_helper(const std::string& name) const {
std::vector<::tensorflow::AttrValue> DecoderProto::decode_attribute_helper(const std::string& name) const {
auto attr_map = m_node_def->attr();
FRONT_END_GENERAL_CHECK(attr_map.contains(name),
"An error occurred while parsing the ",
Expand All @@ -123,6 +123,6 @@ std::vector<::tensorflow::AttrValue> DecoderTFProto::decode_attribute_helper(con
auto value = m_node_def->attr().at(name);
return {value};
}
} // namespace tf
} // namespace tensorflow
} // namespace frontend
} // namespace ov
10 changes: 5 additions & 5 deletions src/frontends/tensorflow/src/decoder_proto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

#include "attr_value.pb.h"
#include "node_def.pb.h"
#include "tensorflow_frontend/decoder.hpp"
#include "openvino/frontend/tensorflow/decoder.hpp"
#include "types.pb.h"

namespace ov {
namespace frontend {
namespace tf {
namespace tensorflow {

class DecoderTFProto : public DecoderBase {
class DecoderProto : public ov::frontend::tensorflow::DecoderBase {
public:
explicit DecoderTFProto(const ::tensorflow::NodeDef* node_def) : m_node_def(node_def) {}
explicit DecoderProto(const ::tensorflow::NodeDef* node_def) : m_node_def(node_def) {}

ov::Any get_attribute(const std::string& name, const std::type_info& type_info) const override;

Expand All @@ -36,6 +36,6 @@ class DecoderTFProto : public DecoderBase {
std::vector<::tensorflow::AttrValue> decode_attribute_helper(const std::string& name) const;
const ::tensorflow::NodeDef* m_node_def;
};
} // namespace tf
} // namespace tensorflow
} // namespace frontend
} // namespace ov
6 changes: 3 additions & 3 deletions src/frontends/tensorflow/src/exceptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

namespace ov {
namespace frontend {
namespace tf {
std::string OpValidationFailureTF::get_error_msg_prefix_tf(const tf::NodeContext& node) {
namespace tensorflow {
std::string OpValidationFailure::get_error_msg_prefix_tf(const tensorflow::NodeContext& node) {
std::stringstream ss;
ss << "While validating node '" << node.get_op_type() << '\'';
return ss.str();
}
} // namespace tf
} // namespace tensorflow
} // namespace frontend
} // namespace ov
16 changes: 8 additions & 8 deletions src/frontends/tensorflow/src/exceptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@

namespace ov {
namespace frontend {
namespace tf {
namespace tensorflow {

class NodeContext;

class OpValidationFailureTF : public ov::frontend::OpValidationFailure {
class OpValidationFailure : public ov::frontend::OpValidationFailure {
public:
OpValidationFailureTF(const CheckLocInfo& check_loc_info, const NodeContext& node, const std::string& explanation)
: OpValidationFailure(check_loc_info, get_error_msg_prefix_tf(node), explanation) {}
OpValidationFailure(const CheckLocInfo& check_loc_info, const NodeContext& node, const std::string& explanation)
: ov::frontend::OpValidationFailure(check_loc_info, get_error_msg_prefix_tf(node), explanation) {}

private:
static std::string get_error_msg_prefix_tf(const NodeContext& node);
};
} // namespace tf
} // namespace tensorflow
} // namespace frontend

/// \brief Macro to check whether a boolean condition holds.
Expand All @@ -28,7 +28,7 @@ class OpValidationFailureTF : public ov::frontend::OpValidationFailure {
/// \param ... Additional error message info to be added to the error message via the `<<`
/// stream-insertion operator. Note that the expressions here will be evaluated lazily,
/// i.e., only if the `cond` evalutes to `false`.
/// \throws ::ov::OpValidationFailureTF if `cond` is false.
#define TF_OP_VALIDATION_CHECK(node_context, ...) \
OPENVINO_ASSERT_HELPER(::ov::frontend::tf::OpValidationFailureTF, (node_context), __VA_ARGS__)
/// \throws ::ov::OpValidationFailure if `cond` is false.
#define TENSORFLOW_OP_VALIDATION(node_context, ...) \
OPENVINO_ASSERT_HELPER(::ov::frontend::tensorflow::OpValidationFailure, (node_context), __VA_ARGS__)
} // namespace ov
Loading

0 comments on commit 9334f34

Please sign in to comment.