Skip to content

Commit

Permalink
3
Browse files Browse the repository at this point in the history
  • Loading branch information
nshchego committed Oct 30, 2024
1 parent d67237b commit c432ec7
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 155 deletions.
21 changes: 11 additions & 10 deletions src/core/include/openvino/core/attribute_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>
#include <unordered_map>
#include <utility>
#include <deque>

#include "openvino/core/type.hpp"

Expand Down Expand Up @@ -99,13 +100,13 @@ class OPENVINO_API AttributeVisitor {

/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
/// to a ValueAccessor<U> for one of the on_adpater methods.
//template <typename AT>
//void on_attribute(const char* name, AT& value) {
// AttributeAdapter<AT> adapter(value);
// start_structure(name);
// on_adapter(get_name_with_context(), adapter);
// finish_structure();
//}
template <typename AT>
void on_attribute(const char* name, AT& value) {
AttributeAdapter<AT> adapter(value);
start_structure(name);
on_adapter(get_name_with_context(), adapter);
finish_structure();
}
template <typename AT>
void on_attribute(const std::string& name, AT& value) {
AttributeAdapter<AT> adapter(value);
Expand All @@ -114,14 +115,14 @@ class OPENVINO_API AttributeVisitor {
finish_structure();
}
/// \returns The nested context of visits
const std::vector<std::string>& get_context() const {
const std::deque<std::string>& get_context() const {
return m_context;
}
/// \returns context prepended to names
virtual std::string get_name_with_context();
/// \brief Start visiting a nested structure
virtual void start_structure(const std::string& name);
//virtual void start_structure(const char* name);
virtual void start_structure(const char* name);
/// \brief Finish visiting a nested structure
virtual std::string finish_structure();
using node_id_t = std::string;
Expand All @@ -137,7 +138,7 @@ class OPENVINO_API AttributeVisitor {
virtual node_id_t get_registered_node_id(const std::shared_ptr<Node>& node);

protected:
std::vector<std::string> m_context;
std::deque<std::string> m_context;
std::unordered_map<std::shared_ptr<Node>, node_id_t> m_node_id_map;
std::unordered_map<node_id_t, std::shared_ptr<Node>> m_id_node_map;
};
Expand Down
117 changes: 79 additions & 38 deletions src/core/src/attribute_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,71 @@
#include "openvino/core/model.hpp"
#include "openvino/core/node.hpp"

using namespace std;

void ov::AttributeVisitor::start_structure(const string& name) {
void ov::AttributeVisitor::start_structure(const std::string& name) {
m_context.push_back(name);
}

//void ov::AttributeVisitor::start_structure(const char* name) {
// m_context.push_back(name);
//}
void ov::AttributeVisitor::start_structure(const char* name) {
m_context.push_back(name);
}

string ov::AttributeVisitor::finish_structure() {
string result = m_context.back();
std::string ov::AttributeVisitor::finish_structure() {
std::string result = m_context.back();
m_context.pop_back();
return result;
}

string ov::AttributeVisitor::get_name_with_context() {
ostringstream result;
string sep = "";
std::string ov::AttributeVisitor::get_name_with_context() {
//std::stringstream result;
//result.iword(512);
//static const char sep = '.';
//for (const auto& c : m_context) {
// result << c << sep;
//}
//auto strt = result.str();
//strt.pop_back();
//return strt;


//std::cout << "get_name_with_context: \"" << strt << "\" tellp: " << result.tellp() <<
// std::endl;
//const std::streamsize size = static_cast<std::streamsize>(result.tellp());
//std::string res;
//res.reserve(size + 5);
//char tmp[5];
//char* tmp = &(res[0]);
//result.get(tmp, size);
//result.read(tmp, size);
//result.get(&(res[0]), size - 1);
//result.get(const_cast<char*>(res.data()), size - 1l);
//std::cout << "get_name_with_context: \"" << strt << "\" size: " << strt.size() << std::endl;
//printf("get_name_with_context: \"%s\"; tellp: \"%lu\"\n", result.str().data(), result.tellp());
//return strt;


//std::ostringstream result;
//std::string sep = "";
//for (const auto& c : m_context) {
// result << sep << c;
// sep = ".";
//}
//return result.str();


std::string result;
//std::cout << "result capacity: " << result.capacity() << std::endl;
//result.reserve(64);
static const char sep = '.';
for (const auto& c : m_context) {
result << sep << c;
sep = ".";
result.append(c).push_back(sep);
}
return result.str();
result.pop_back();
//if (result.size() >= 128) {
//std::cout << "result: \"" << result << "\"; size: " << result.size() << std::endl;
//}

return result;
}

void ov::AttributeVisitor::on_adapter(const std::string& name, VisitorAdapter& adapter) {
Expand All @@ -42,99 +83,99 @@ void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<voi
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<string>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
};

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<bool>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<bool>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
};

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<int8_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<int8_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<int16_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<int16_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<int32_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<int32_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<int64_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<uint8_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<uint8_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<uint16_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<uint16_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<uint32_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<uint32_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<uint64_t>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<uint64_t>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<float>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<float>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<double>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<double>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<int8_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<int16_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<int32_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<int64_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<uint8_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<uint16_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<uint32_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<uint64_t>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<float>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<double>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<string>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

void ov::AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::shared_ptr<ov::Model>>& adapter) {
void ov::AttributeVisitor::on_adapter(const std::string& name, ValueAccessor<std::shared_ptr<ov::Model>>& adapter) {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}

Expand All @@ -150,7 +191,7 @@ void ov::AttributeVisitor::register_node(const std::shared_ptr<ov::Node>& node,

std::shared_ptr<ov::Node> ov::AttributeVisitor::get_registered_node(node_id_t id) {
auto it = m_id_node_map.find(id);
return it == m_id_node_map.end() ? shared_ptr<ov::Node>() : it->second;
return it == m_id_node_map.end() ? std::shared_ptr<ov::Node>() : it->second;
}

ov::AttributeVisitor::node_id_t ov::AttributeVisitor::get_registered_node_id(const std::shared_ptr<ov::Node>& node) {
Expand Down
48 changes: 24 additions & 24 deletions src/core/src/op/util/op_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,60 +26,60 @@
#include "openvino/op/xor.hpp"

bool ov::op::util::is_unary_elementwise_arithmetic(const ov::Node* node) {
return dynamic_cast<const ov::op::util::UnaryElementwiseArithmetic*>(node) != nullptr;
return ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(node);
}

bool ov::op::util::is_binary_elementwise_arithmetic(const ov::Node* node) {
return dynamic_cast<const ov::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
return ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(node);
}

bool ov::op::util::is_binary_elementwise_comparison(const ov::Node* node) {
return dynamic_cast<const ov::op::util::BinaryElementwiseComparison*>(node) != nullptr;
return ov::is_type<ov::op::util::BinaryElementwiseComparison>(node);
}

bool ov::op::util::is_binary_elementwise_logical(const ov::Node* node) {
return dynamic_cast<const ov::op::util::BinaryElementwiseLogical*>(node) != nullptr;
return ov::is_type<ov::op::util::BinaryElementwiseLogical>(node);
}

bool ov::op::util::supports_auto_broadcast(const ov::Node* node) {
return dynamic_cast<const ov::op::v1::Select*>(node) != nullptr ||
dynamic_cast<const ov::op::v0::SquaredDifference*>(node) != nullptr ||
dynamic_cast<const ov::op::util::BinaryElementwiseComparison*>(node) != nullptr ||
dynamic_cast<const ov::op::util::BinaryElementwiseLogical*>(node) != nullptr ||
dynamic_cast<const ov::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
return ov::is_type<ov::op::v1::Select>(node) ||
ov::is_type<ov::op::v0::SquaredDifference>(node) ||
ov::is_type<ov::op::util::BinaryElementwiseComparison>(node) ||
ov::is_type<ov::op::util::BinaryElementwiseLogical>(node) ||
ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(node);
}

bool ov::op::util::is_op(const ov::Node* node) {
return dynamic_cast<const ov::op::Op*>(node) != nullptr;
return ov::is_type<ov::op::Op>(node);
}

bool ov::op::util::is_parameter(const ov::Node* node) {
return dynamic_cast<const ov::op::v0::Parameter*>(node) != nullptr;
return ov::is_type<ov::op::v0::Parameter>(node);
}

bool ov::op::util::is_output(const ov::Node* node) {
return dynamic_cast<const ov::op::v0::Result*>(node) != nullptr;
return ov::is_type<ov::op::v0::Result>(node);
}

bool ov::op::util::is_sink(const ov::Node* node) {
return dynamic_cast<const ov::op::Sink*>(node) != nullptr;
return ov::is_type<ov::op::Sink>(node);
}

bool ov::op::util::is_constant(const ov::Node* node) {
return dynamic_cast<const ov::op::v0::Constant*>(node) != nullptr;
return ov::is_type<ov::op::v0::Constant>(node);
}

bool ov::op::util::is_commutative(const ov::Node* node) {
return dynamic_cast<const ov::op::v1::Add*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::Maximum*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::Equal*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::NotEqual*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::LogicalAnd*>(node) != nullptr ||
dynamic_cast<const ov::op::v0::Xor*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::LogicalXor*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::Minimum*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::Multiply*>(node) != nullptr ||
dynamic_cast<const ov::op::v1::LogicalOr*>(node) != nullptr;
return ov::is_type<ov::op::v1::Add>(node) ||
ov::is_type<ov::op::v1::Equal>(node) ||
ov::is_type<ov::op::v1::LogicalAnd>(node) ||
ov::is_type<ov::op::v1::LogicalOr>(node) ||
ov::is_type<ov::op::v1::LogicalXor>(node) ||
ov::is_type<ov::op::v1::Maximum>(node) ||
ov::is_type<ov::op::v1::Minimum>(node) ||
ov::is_type<ov::op::v1::Multiply>(node) ||
ov::is_type<ov::op::v1::NotEqual>(node) ||
ov::is_type<ov::op::v0::Xor>(node);
}

bool ov::op::util::is_unary_elementwise_arithmetic(const std::shared_ptr<ov::Node>& node) {
Expand Down
Loading

0 comments on commit c432ec7

Please sign in to comment.