From d0be6b1d2f1630a804385262c7d197ffdf07c0f4 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 26 Jun 2020 16:35:00 +0200 Subject: [PATCH] Dynamic attribute getters and setters. (#964) --- ngraph/python/setup.py | 1 + .../python/src/ngraph/utils/node_factory.py | 89 +++++ .../src/pyngraph/dict_attribute_visitor.cpp | 351 ++++++++++++++++++ .../src/pyngraph/dict_attribute_visitor.hpp | 158 ++++++++ ngraph/python/src/pyngraph/node.cpp | 28 +- ngraph/python/src/pyngraph/node_factory.cpp | 215 +---------- ngraph/python/test/ngraph/test_create_op.py | 2 +- .../python/test/ngraph/test_dyn_attributes.py | 241 ++++++++++++ 8 files changed, 865 insertions(+), 220 deletions(-) create mode 100644 ngraph/python/src/pyngraph/dict_attribute_visitor.cpp create mode 100644 ngraph/python/src/pyngraph/dict_attribute_visitor.hpp create mode 100644 ngraph/python/test/ngraph/test_dyn_attributes.py diff --git a/ngraph/python/setup.py b/ngraph/python/setup.py index 288fe4d16229ec..8c0d5896f51915 100644 --- a/ngraph/python/setup.py +++ b/ngraph/python/setup.py @@ -182,6 +182,7 @@ def cpp_flag(compiler): "pyngraph/axis_vector.cpp", "pyngraph/coordinate.cpp", "pyngraph/coordinate_diff.cpp", + "pyngraph/dict_attribute_visitor.cpp", "pyngraph/dimension.cpp", "pyngraph/function.cpp", "pyngraph/node.cpp", diff --git a/ngraph/python/src/ngraph/utils/node_factory.py b/ngraph/python/src/ngraph/utils/node_factory.py index 70750ce16af292..d07ac3db6a5fb7 100644 --- a/ngraph/python/src/ngraph/utils/node_factory.py +++ b/ngraph/python/src/ngraph/utils/node_factory.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any, Dict, List, Optional from _pyngraph import NodeFactory as _NodeFactory @@ -21,6 +22,8 @@ def create( ) -> Node: """Create node object from provided description. + The user does not have to provide all node's attributes, but only required ones. + :param op_type_name: The operator type name. :param arguments: The operator arguments. :param attributes: The operator attributes. @@ -30,4 +33,90 @@ def create( if attributes is None: attributes = {} node = self.factory.create(op_type_name, arguments, attributes) + + # Currently we don't support any attribute getters & setters for TensorIterator node. + if node.get_type_name() == "TensorIterator": + return node + + # Set getters and setters for each node's attribute. + # node.get_attribute_name() + # node.set_attribute_name() + # For compound (with more than one level of nesting) attributes of form ie.: + # node.class_member_name.some_metric.attr_name: + # node.get_some_metric_attr_name() + # node.set_some_metric_attr_name() + # Please see test_dyn_attributes.py for more usage examples. + all_attributes = node._get_attributes() + for attr_name in all_attributes.keys(): + setattr(node, + self._normalize_attr_name_getter(attr_name), + partial(NodeFactory._get_node_attr_value, node, attr_name)) + setattr(node, + self._normalize_attr_name_setter(attr_name), + partial(NodeFactory._set_node_attr_value, node, attr_name)) + + # Setup helper members for caching attribute values. + # The cache would be lazily populated at first access attempt. + setattr(node, "_attr_cache", dict()) + setattr(node, "_attr_cache_valid", bool(False)) + return node + + @staticmethod + def _normalize_attr_name(attr_name: str, prefix: str) -> str: + """Normalizes attribute name. + + :param attr_name: The attribute name. + :param prefix: The prefix to attach to attribute name. + + :returns: The modified attribute name. + """ + # Trim first part of the name if there is only one level of attribute hierarchy. + if attr_name.count(".") == 1: + attr_name = attr_name[attr_name.find(".") + 1:] + return prefix + attr_name.replace(".", "_") + + @classmethod + def _normalize_attr_name_getter(cls, attr_name: str) -> str: + """Normalizes atr name to be suitable for getter function name. + + :param attr_name: The attribute name to normalize + + :returns: The appropriate getter function name. + """ + return cls._normalize_attr_name(attr_name, "get_") + + @classmethod + def _normalize_attr_name_setter(cls, attr_name: str) -> str: + """Normalizes atr name to be suitable for setter function name. + + :param attr_name: The attribute name to normalize + + :returns: The appropriate setter function name. + """ + return cls._normalize_attr_name(attr_name, "set_") + + @staticmethod + def _get_node_attr_value(node: Node, attr_name: str) -> Any: + """Gets provided node attribute value. + + :param node: The node we retrieve attribute value from. + :param attr_name: The attribute name. + + :returns: The node attribute value. + """ + if not node._attr_cache_valid: + node._attr_cache = node._get_attributes() + node._attr_cache_valid = True + return node._attr_cache[attr_name] + + @staticmethod + def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None: + """Sets the node attribute value. + + :param node: The node we change attribute value for. + :param attr_name: The attribute name. + :param value: The new attribute value. + """ + node._set_attribute(attr_name, value) + node._attr_cache[attr_name] = value diff --git a/ngraph/python/src/pyngraph/dict_attribute_visitor.cpp b/ngraph/python/src/pyngraph/dict_attribute_visitor.cpp new file mode 100644 index 00000000000000..246ca8066c7e24 --- /dev/null +++ b/ngraph/python/src/pyngraph/dict_attribute_visitor.cpp @@ -0,0 +1,351 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +// These are not used here, but needed in order to not violate ODR, since +// these are included in other translation units, and specialize some types. +// Related: https://github.com/pybind/pybind11/issues/1055 +#include +#include + +#include "dict_attribute_visitor.hpp" + +namespace py = pybind11; + +util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes) + : m_attributes(attributes) +{ +} + +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} +void util::DictAttributeDeserializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + if (m_attributes.contains(name)) + { + adapter.set(m_attributes[name.c_str()].cast>()); + } +} + +util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr& node) +{ + node->visit_attributes(*this); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + if (m_attributes.contains(name)) + { + NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name); + } +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter( + const std::string& name, ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} +void util::DictAttributeSerializer::on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) +{ + m_attributes[name.c_str()] = adapter.get(); +} diff --git a/ngraph/python/src/pyngraph/dict_attribute_visitor.hpp b/ngraph/python/src/pyngraph/dict_attribute_visitor.hpp new file mode 100644 index 00000000000000..21978cc8dfaff4 --- /dev/null +++ b/ngraph/python/src/pyngraph/dict_attribute_visitor.hpp @@ -0,0 +1,158 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include + +#include "ngraph/attribute_visitor.hpp" +#include "ngraph/node.hpp" + +#include + +namespace py = pybind11; + +namespace util +{ + class DictAttributeDeserializer : public ngraph::AttributeVisitor + { + public: + DictAttributeDeserializer(const py::dict& attributes); + + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + + protected: + const py::dict& m_attributes; + }; + + class DictAttributeSerializer : public ngraph::AttributeVisitor + { + public: + DictAttributeSerializer(const std::shared_ptr& node); + + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + virtual void on_adapter(const std::string& name, + ngraph::ValueAccessor>& adapter) override; + + template + T get_attribute(const std::string& name) + { + NGRAPH_CHECK(m_attributes.contains(name), + "Couldn't find attribute \"", + name, + "\" in serialized node attribute dictionary."); + return m_attributes[name.c_str()].cast(); + } + + py::dict get_attributes() const { return m_attributes; } + protected: + py::dict m_attributes; + }; +} diff --git a/ngraph/python/src/pyngraph/node.cpp b/ngraph/python/src/pyngraph/node.cpp index f00205f387e8fd..9db7b4de9fa506 100644 --- a/ngraph/python/src/pyngraph/node.cpp +++ b/ngraph/python/src/pyngraph/node.cpp @@ -14,21 +14,21 @@ // limitations under the License. //***************************************************************************** -#include #include -#include "ngraph/node.hpp" // ngraph::Node -#include "ngraph/op/add.hpp" // ngraph::op::Add -#include "ngraph/op/divide.hpp" // ngraph::op::Divide -#include "ngraph/op/multiply.hpp" // ngraph::op::Multiply -#include "ngraph/op/subtract.hpp" // ngraph::op::Subtract +#include "dict_attribute_visitor.hpp" +#include "ngraph/node.hpp" +#include "ngraph/op/add.hpp" +#include "ngraph/op/divide.hpp" +#include "ngraph/op/multiply.hpp" +#include "ngraph/op/subtract.hpp" #include "pyngraph/node.hpp" namespace py = pybind11; void regclass_pyngraph_Node(py::module m) { - py::class_> node(m, "Node"); + py::class_> node(m, "Node", py::dynamic_attr()); node.doc() = "ngraph.impl.Node wraps ngraph::Node"; node.def("__add__", [](const std::shared_ptr& a, const std::shared_ptr b) { @@ -79,4 +79,18 @@ void regclass_pyngraph_Node(py::module m) node.def("get_unique_name", &ngraph::Node::get_name); node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name); + node.def_property_readonly("shape", &ngraph::Node::get_shape); + + node.def("_get_attributes", [](const std::shared_ptr& self) { + util::DictAttributeSerializer dict_serializer(self); + return dict_serializer.get_attributes(); + }); + node.def( + "_set_attribute", + [](std::shared_ptr& self, const std::string& atr_name, py::object value) { + py::dict attr_dict; + attr_dict[atr_name.c_str()] = value; + util::DictAttributeDeserializer dict_deserializer(attr_dict); + self->visit_attributes(dict_deserializer); + }); } diff --git a/ngraph/python/src/pyngraph/node_factory.cpp b/ngraph/python/src/pyngraph/node_factory.cpp index ea54e00d50a5b6..d7cfca76d02885 100644 --- a/ngraph/python/src/pyngraph/node_factory.cpp +++ b/ngraph/python/src/pyngraph/node_factory.cpp @@ -26,14 +26,11 @@ #include #include -#include "ngraph/attribute_visitor.hpp" +#include "dict_attribute_visitor.hpp" #include "ngraph/check.hpp" -#include "ngraph/enum_names.hpp" #include "ngraph/except.hpp" #include "ngraph/node.hpp" -#include "ngraph/op/constant.hpp" #include "ngraph/opsets/opset.hpp" -#include "ngraph/util.hpp" #include "node_factory.hpp" #include "tensor_iterator_builder.hpp" @@ -41,212 +38,6 @@ namespace py = pybind11; namespace { - class DictAttributeDeserializer : public ngraph::AttributeVisitor - { - public: - DictAttributeDeserializer(const py::dict& attributes) - : m_attributes(attributes) - { - } - - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - NGRAPH_CHECK( - false, "No AttributeVisitor support for accessing attribute named: ", name); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - virtual void on_adapter(const std::string& name, - ngraph::ValueAccessor>& adapter) override - { - if (m_attributes.contains(name)) - { - adapter.set(m_attributes[name.c_str()].cast>()); - } - } - - protected: - const py::dict& m_attributes; - }; - class NodeFactory { public: @@ -270,12 +61,12 @@ namespace if (op_type_name == "TensorIterator") { - // TODO: how to differentiate opsets? + // XXX: How to differentiate opsets? return util::TensorIteratorBuilder(arguments, attributes) .configure(std::static_pointer_cast(op_node)); } - DictAttributeDeserializer visitor(attributes); + util::DictAttributeDeserializer visitor(attributes); op_node->set_arguments(arguments); op_node->visit_attributes(visitor); diff --git a/ngraph/python/test/ngraph/test_create_op.py b/ngraph/python/test/ngraph/test_create_op.py index c815eaf4bb9861..24b7ccab657a34 100644 --- a/ngraph/python/test/ngraph/test_create_op.py +++ b/ngraph/python/test/ngraph/test_create_op.py @@ -508,7 +508,7 @@ def test_roi_pooling(): node = ng.roi_pooling(inputs, coords, [6, 6], 0.0625, "Max") assert node.get_type_name() == "ROIPooling" - assert node.get_output_size() == 1 + assert node.get_output_size() == [6, 6] assert list(node.get_output_shape(0)) == [150, 3, 6, 6] assert node.get_output_element_type(0) == Type.f32 diff --git a/ngraph/python/test/ngraph/test_dyn_attributes.py b/ngraph/python/test/ngraph/test_dyn_attributes.py new file mode 100644 index 00000000000000..8b6fb8a303994a --- /dev/null +++ b/ngraph/python/test/ngraph/test_dyn_attributes.py @@ -0,0 +1,241 @@ +# ****************************************************************************** +# Copyright 2017-2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ****************************************************************************** +import numpy as np +import pytest + +import ngraph as ng + + +@pytest.fixture() +def _proposal_node(): + attributes = { + "attrs.base_size": np.uint16(1), + "attrs.pre_nms_topn": np.uint16(20), + "attrs.post_nms_topn": np.uint16(64), + "attrs.nms_thresh": np.float64(0.34), + "attrs.feat_stride": np.uint16(16), + "attrs.min_size": np.uint16(32), + "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=np.float64), + "attrs.scale": np.array([2, 3, 3, 4], dtype=np.float64), + } + batch_size = 7 + + class_probs = ng.parameter([batch_size, 12, 34, 62], np.float64, "class_probs") + class_logits = ng.parameter([batch_size, 24, 34, 62], np.float64, "class_logits") + image_shape = ng.parameter([3], np.float64, "image_shape") + return ng.proposal(class_probs, class_logits, image_shape, attributes) + + +def test_dynamic_attributes_softmax(): + axis = 2 + data = ng.parameter([1, 2, 3, 4], np.float32, "data_in") + node = ng.softmax(data, axis) + + assert node.get_axis() == axis + node.set_axis(3) + assert node.get_axis() == 3 + + +@pytest.mark.parametrize( + "int_dtype, fp_dtype", + [ + (np.int8, np.float32), + (np.int16, np.float32), + (np.int32, np.float32), + (np.int64, np.float32), + (np.uint8, np.float32), + (np.uint16, np.float32), + (np.uint32, np.float32), + (np.uint64, np.float32), + (np.int32, np.float16), + (np.int32, np.float64), + ], +) +def test_dynamic_get_attribute_value(int_dtype, fp_dtype): + attributes = { + "attrs.num_classes": int_dtype(85), + "attrs.background_label_id": int_dtype(13), + "attrs.top_k": int_dtype(16), + "attrs.variance_encoded_in_target": True, + "attrs.keep_top_k": np.array([64, 32, 16, 8], dtype=int_dtype), + "attrs.code_type": "pytorch.some_parameter_name", + "attrs.share_location": False, + "attrs.nms_threshold": fp_dtype(0.645), + "attrs.confidence_threshold": fp_dtype(0.111), + "attrs.clip_after_nms": True, + "attrs.clip_before_nms": False, + "attrs.decrease_label_id": True, + "attrs.normalized": True, + "attrs.input_height": int_dtype(86), + "attrs.input_width": int_dtype(79), + "attrs.objectness_score": fp_dtype(0.77), + } + + box_logits = ng.parameter([4, 1, 5, 5], fp_dtype, "box_logits") + class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "class_preds") + proposals = ng.parameter([2, 1, 4, 5], fp_dtype, "proposals") + aux_class_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_class_preds") + aux_box_preds = ng.parameter([2, 1, 4, 5], fp_dtype, "aux_box_preds") + + node = ng.detection_output( + box_logits, class_preds, proposals, attributes, aux_class_preds, aux_box_preds + ) + + assert node.get_num_classes() == int_dtype(85) + assert node.get_background_label_id() == int_dtype(13) + assert node.get_top_k() == int_dtype(16) + assert node.get_variance_encoded_in_target() == True + assert np.all(np.equal(node.get_keep_top_k(), np.array([64, 32, 16, 8], dtype=int_dtype))) + assert node.get_code_type() == "pytorch.some_parameter_name" + assert node.get_share_location() == False + assert np.isclose(node.get_nms_threshold(), fp_dtype(0.645)) + assert np.isclose(node.get_confidence_threshold(), fp_dtype(0.111)) + assert node.get_clip_after_nms() == True + assert node.get_clip_before_nms() == False + assert node.get_decrease_label_id() == True + assert node.get_normalized() == True + assert node.get_input_height() == int_dtype(86) + assert node.get_input_width() == int_dtype(79) + assert np.isclose(node.get_objectness_score(), fp_dtype(0.77)) + assert node.get_num_classes() == int_dtype(85) + + +@pytest.mark.parametrize( + "int_dtype, fp_dtype", + [ + (np.uint8, np.float32), + (np.uint16, np.float32), + (np.uint32, np.float32), + (np.uint64, np.float32), + (np.uint32, np.float16), + (np.uint32, np.float64), + ], +) +def test_dynamic_set_attribute_value(int_dtype, fp_dtype): + attributes = { + "attrs.base_size": int_dtype(1), + "attrs.pre_nms_topn": int_dtype(20), + "attrs.post_nms_topn": int_dtype(64), + "attrs.nms_thresh": fp_dtype(0.34), + "attrs.feat_stride": int_dtype(16), + "attrs.min_size": int_dtype(32), + "attrs.ratio": np.array([0.1, 1.5, 2.0, 2.5], dtype=fp_dtype), + "attrs.scale": np.array([2, 3, 3, 4], dtype=fp_dtype), + } + batch_size = 7 + + class_probs = ng.parameter([batch_size, 12, 34, 62], fp_dtype, "class_probs") + class_logits = ng.parameter([batch_size, 24, 34, 62], fp_dtype, "class_logits") + image_shape = ng.parameter([3], fp_dtype, "image_shape") + node = ng.proposal(class_probs, class_logits, image_shape, attributes) + + node.set_base_size(int_dtype(15)) + node.set_pre_nms_topn(int_dtype(7)) + node.set_post_nms_topn(int_dtype(33)) + node.set_nms_thresh(fp_dtype(1.55)) + node.set_feat_stride(int_dtype(8)) + node.set_min_size(int_dtype(123)) + node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype)) + node.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype)) + node.set_clip_before_nms(True) + node.set_clip_after_nms(True) + node.set_normalize(True) + node.set_box_size_scale(fp_dtype(1.34)) + node.set_box_coordinate_scale(fp_dtype(0.88)) + node.set_framework("OpenVINO") + + assert node.get_base_size() == int_dtype(15) + assert node.get_pre_nms_topn() == int_dtype(7) + assert node.get_post_nms_topn() == int_dtype(33) + assert np.isclose(node.get_nms_thresh(), fp_dtype(1.55)) + assert node.get_feat_stride() == int_dtype(8) + assert node.get_min_size() == int_dtype(123) + assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=fp_dtype)) + assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=fp_dtype)) + assert node.get_clip_before_nms() == True + assert node.get_clip_after_nms() == True + assert node.get_normalize() == True + assert np.isclose(node.get_box_size_scale(), fp_dtype(1.34)) + assert np.isclose(node.get_box_coordinate_scale(), fp_dtype(0.88)) + assert node.get_framework() == "OpenVINO" + + +def test_dynamic_attr_cache(_proposal_node): + node = _proposal_node + + assert not node._attr_cache_valid + node.set_nms_thresh(1.3453678102) + assert not node._attr_cache_valid + assert np.isclose(node.get_nms_thresh(), np.float64(1.3453678102)) + assert node._attr_cache_valid + + +def test_dynamic_attr_transitivity(_proposal_node): + node = _proposal_node + node2 = node + + node.set_ratio(np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64)) + assert np.allclose(node.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64)) + assert np.allclose(node2.get_ratio(), np.array([1.1, 2.5, 3.0, 4.5], dtype=np.float64)) + + node2.set_scale(np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64)) + assert np.allclose(node2.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64)) + assert np.allclose(node.get_scale(), np.array([2.1, 3.2, 3.3, 4.4], dtype=np.float64)) + + +def test_dynamic_attributes_simple(): + batch_size = 1 + input_size = 16 + hidden_size = 128 + + X_shape = [batch_size, input_size] + H_t_shape = [batch_size, hidden_size] + W_shape = [3 * hidden_size, input_size] + R_shape = [3 * hidden_size, hidden_size] + B_shape = [4 * hidden_size] + + parameter_X = ng.parameter(X_shape, name="X", dtype=np.float32) + parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=np.float32) + parameter_W = ng.parameter(W_shape, name="W", dtype=np.float32) + parameter_R = ng.parameter(R_shape, name="R", dtype=np.float32) + parameter_B = ng.parameter(B_shape, name="B", dtype=np.float32) + + activations = ["tanh", "relu"] + activations_alpha = [1.0, 2.0] + activations_beta = [1.0, 2.0] + clip = 0.5 + linear_before_reset = True + + node = ng.gru_cell( + parameter_X, + parameter_H_t, + parameter_W, + parameter_R, + parameter_B, + hidden_size, + activations, + activations_alpha, + activations_beta, + clip, + linear_before_reset, + ) + + assert node.get_hidden_size() == hidden_size + assert all(map(lambda x, y: x == y, node.get_activations(), activations)) + assert all(np.equal(node.get_activations_alpha(), activations_alpha)) + assert all(np.equal(node.get_activations_beta(), activations_beta)) + assert node.get_linear_before_reset() == linear_before_reset + assert np.isclose(node.get_clip(), clip)