Skip to content

Commit

Permalink
Serialization of old API map in nGraph. (#7840)
Browse files Browse the repository at this point in the history
* Added serialization of old API map in ngraph.

* Changed order type to int64_t.

* Fixed uint64_t error, added comments.

* Apply suggestions from code review

Co-authored-by: Gleb Kazantaev <[email protected]>

* Added tests with undefined type and empty order.

* Added set, get and has methods.

* Fix in tests.

* Apply suggestions from code review

Co-authored-by: Ilya Churaev <[email protected]>

* Made inline methods, changed to shared_ptr.

* Small fix.

* Moved methods to header file.

* Small fix.

Co-authored-by: Gleb Kazantaev <[email protected]>
Co-authored-by: Ilya Churaev <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2021
1 parent e20cefb commit 623117f
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <transformations/rt_info/disable_constant_folding.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <transformations/rt_info/nms_selected_indices.hpp>
#include <transformations/rt_info/old_api_map_attribute.hpp>
#include <transformations/rt_info/primitives_priority_attribute.hpp>
#include <transformations/rt_info/strides_property.hpp>

Expand All @@ -32,6 +33,7 @@ class TRANSFORMATIONS_API Attributes {
register_factory<DisableConstantFolding>();
register_factory<NmsSelectedIndices>();
register_factory<StridesPropagation>();
register_factory<OldApiMap>();
}

Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

/**
* @brief Defines old API map attribute
* @file old_api_map_attribute.hpp
*/

#pragma once

#include <assert.h>

#include <functional>
#include <memory>
#include <ngraph/attribute_visitor.hpp>
#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
#include <openvino/core/rtti.hpp>
#include <set>
#include <string>
#include <transformations_visibility.hpp>
#include <utility>

namespace ov {

class OldApiMap;
/**
* @ingroup ie_runtime_attr_api
* @brief OldApiMapAttr class stores the value of OldApiMap class.
*
* OldApiMap stores the following information.
* Parameter:
* Order of the transpose which should be applied to Parameter with old API layout to
* obtain Parameter with new API layout.
* Element type of the Parameter in old API.
*
* Result:
* Order of the transpose which should be applied to Result with new API layout to
* obtain Result with old API layout.
*
*/
class TRANSFORMATIONS_API OldApiMapAttr {
private:
std::vector<uint64_t> m_order;
ngraph::element::Type m_legacy_type = ngraph::element::Type_t::undefined;

public:
friend class OldApiMap;

/**
* A default constructor
*/
OldApiMapAttr() = default;

/**
* @brief Constructs a new OldApiMapAttr object.
* @param[in] order Transpose order.
* @param[in] legacy_type Legacy type.
*/
explicit OldApiMapAttr(std::vector<uint64_t> order, const ngraph::element::Type& legacy_type)
: m_order(std::move(order)), m_legacy_type(legacy_type) {}

/**
* @brief Returns the transpose order that should be used for obtain a node with old API layout.
* @return transpose order.
*/
const std::vector<uint64_t> & get_order() const {
return m_order;
}

/**
* @brief Returns the legacy type of the node.
* @return legacy type.
*/
ngraph::element::Type get_type() const {
return m_legacy_type;
}
};

/**
* @ingroup ie_runtime_attr_api
* @brief OldApiMap class represents runtime info attribute that stores legacy type
* and order of the transpose that is required for obtaining IR in old API.
*/
class TRANSFORMATIONS_API OldApiMap : public VariantImpl<OldApiMapAttr> {
public:
OPENVINO_RTTI("old_api_map", "0");

/**
* A default constructor
*/
OldApiMap() = default;

/**
* Constructs a new OldApiMap object.
* @param[in] value The object that stores values of OldApiMap.
*/
OldApiMap(const value_type& value) : VariantImpl<value_type>(value) {}

bool is_copyable() const override {
return false;
}

bool visit_attributes(AttributeVisitor& visitor) override;
};

inline bool has_old_api_map(const std::shared_ptr<Node>& node) {
const auto& rt_map = node->get_rt_info();
return rt_map.count(OldApiMap::get_type_info_static());
}

inline OldApiMap get_old_api_map(const std::shared_ptr<Node>& node) {
const auto& rt_map = node->get_rt_info();
const auto& var = rt_map.at(OldApiMap::get_type_info_static());
return ngraph::as_type_ptr<OldApiMap>(var)->get();
}

inline void set_old_api_map(std::shared_ptr<Node>& node, const OldApiMap& old_api_map) {
auto& rt_map = node->get_rt_info();
rt_map[OldApiMap::get_type_info_static()] = std::make_shared<OldApiMap>(old_api_map);
}

} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/rt_info/old_api_map_attribute.hpp"

using namespace ov;

bool OldApiMap::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("order", m_value.m_order);
visitor.on_attribute("element_type", m_value.m_legacy_type);
return true;
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <inference_engine.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <transformations/rt_info/old_api_map_attribute.hpp>

using namespace ngraph;

Expand Down Expand Up @@ -205,6 +206,7 @@ TEST(RTInfoDeserialization, NodeV11) {
<data element_type="f32" shape="1,3,22,22"/>
<rt_info>
<attribute name="fused_names" version="0" value="in1"/>
<attribute name="old_api_map" version="0" order="0,2,3,1" element_type="f32"/>
</rt_info>
<output>
<port id="0" precision="FP32">
Expand Down Expand Up @@ -238,6 +240,9 @@ TEST(RTInfoDeserialization, NodeV11) {
</output>
</layer>
<layer name="output" type="Result" id="2" version="opset8">
<rt_info>
<attribute name="old_api_map" version="0" order="0,3,1,2" element_type="undefined"/>
</rt_info>
<input>
<port id="0" precision="FP32">
<dim>1</dim>
Expand Down Expand Up @@ -266,6 +271,15 @@ TEST(RTInfoDeserialization, NodeV11) {
ASSERT_EQ(fused_names_attr->get().getNames(), names);
};

auto check_old_api_map = [](const RTMap & info, const std::vector<uint64_t> & order, const ngraph::element::Type& type) {
const std::string & old_api_map_key = ov::OldApiMap::get_type_info_static();
ASSERT_TRUE(info.count(old_api_map_key));
auto old_api_map_attr = std::dynamic_pointer_cast<ov::OldApiMap>(info.at(old_api_map_key));
ASSERT_TRUE(old_api_map_attr);
auto old_api_map_attr_val = old_api_map_attr->get();
ASSERT_EQ(old_api_map_attr_val.get_order(), order);
ASSERT_EQ(old_api_map_attr_val.get_type(), type);
};
auto check_version = [](const std::shared_ptr<ov::Function>& f) {
auto& rt_info = f->get_rt_info();
ASSERT_TRUE(rt_info.count("version"));
Expand All @@ -277,8 +291,14 @@ TEST(RTInfoDeserialization, NodeV11) {

auto param = f->get_parameters()[0];
check_fused_names(param->get_rt_info(), "in1");
check_old_api_map(param->get_rt_info(),
std::vector<uint64_t>({0, 2, 3, 1}),
ngraph::element::Type_t::f32);

auto result = f->get_results()[0];
check_old_api_map(result->get_rt_info(),
std::vector<uint64_t>({0, 3, 1, 2}),
ngraph::element::Type_t::undefined);
auto round = result->get_input_node_ptr(0);
check_fused_names(round->get_rt_info(), "Round1,Round2");
}
Expand All @@ -289,6 +309,9 @@ TEST(RTInfoDeserialization, InputAndOutputV11) {
<layers>
<layer name="in1" type="Parameter" id="0" version="opset8">
<data element_type="f32" shape="1,3,22,22"/>
<rt_info>
<attribute name="old_api_map" version="0" order="" element_type="u8"/>
</rt_info>
<output>
<port id="0" precision="FP32">
<rt_info>
Expand Down Expand Up @@ -376,8 +399,21 @@ TEST(RTInfoDeserialization, InputAndOutputV11) {
ASSERT_EQ(fused_names_attr->get().getNames(), names);
};

auto check_old_api_map = [](const RTMap & info, const std::vector<uint64_t> & order, ngraph::element::Type type) {
const std::string & old_api_map_key = ov::OldApiMap::get_type_info_static();
ASSERT_TRUE(info.count(old_api_map_key));
auto old_api_map_attr = std::dynamic_pointer_cast<ov::OldApiMap>(info.at(old_api_map_key));
ASSERT_TRUE(old_api_map_attr);
auto old_api_map_attr_val = old_api_map_attr->get();
ASSERT_EQ(old_api_map_attr_val.get_order(), order);
ASSERT_EQ(old_api_map_attr_val.get_type(), type);
};

auto param = f->get_parameters()[0];
check_fused_names(param->output(0).get_rt_info(), "test1,test2");
check_old_api_map(param->get_rt_info(),
std::vector<uint64_t>({}),
ngraph::element::Type_t::u8);

auto result = f->get_results()[0];
check_fused_names(result->input(0).get_rt_info(), "test5,test6");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ TEST_F(RTInfoSerializationTest, all_attributes_latest) {
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("priority");
info[ov::OldApiMap::get_type_info_static()] = std::make_shared<ov::OldApiMap>(
ov::OldApiMapAttr(std::vector<uint64_t>{0, 2, 3, 1}, ngraph::element::Type_t::f32));
};

std::shared_ptr<ngraph::Function> function;
Expand Down Expand Up @@ -67,6 +69,14 @@ TEST_F(RTInfoSerializationTest, all_attributes_latest) {
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
ASSERT_TRUE(primitives_priority_attr);
ASSERT_EQ(primitives_priority_attr->get(), "priority");

const std::string & old_api_map_key = ov::OldApiMap::get_type_info_static();
ASSERT_TRUE(info.count(old_api_map_key));
auto old_api_map_attr = std::dynamic_pointer_cast<ov::OldApiMap>(info.at(old_api_map_key));
ASSERT_TRUE(old_api_map_attr);
auto old_api_map_attr_val = old_api_map_attr->get();
ASSERT_EQ(old_api_map_attr_val.get_order(), std::vector<uint64_t>({0, 2, 3, 1}));
ASSERT_EQ(old_api_map_attr_val.get_type(), ngraph::element::Type_t::f32);
};

auto add = f->get_results()[0]->get_input_node_ptr(0);
Expand Down
10 changes: 10 additions & 0 deletions ngraph/frontend/ir/src/rt_info_deserializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ class RTInfoDeserializer : public ngraph::AttributeVisitor {
adapter.set(value);
}

void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override {
check_attribute_name(name);
std::string val;
if (!getStrAttribute(m_node, name, val))
return;
std::vector<uint64_t> value;
str_to_container(val, value);
adapter.set(value);
}

void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
check_attribute_name(name);
std::string val;
Expand Down

0 comments on commit 623117f

Please sign in to comment.