Skip to content

Commit

Permalink
Dynamic attribute getters and setters. (openvinotoolkit#964)
Browse files Browse the repository at this point in the history
  • Loading branch information
arogowie-intel authored Jun 26, 2020
1 parent 5aa9ffb commit d0be6b1
Show file tree
Hide file tree
Showing 8 changed files with 865 additions and 220 deletions.
1 change: 1 addition & 0 deletions ngraph/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def cpp_flag(compiler):
"pyngraph/axis_vector.cpp",
"pyngraph/coordinate.cpp",
"pyngraph/coordinate_diff.cpp",
"pyngraph/dict_attribute_visitor.cpp",
"pyngraph/dimension.cpp",
"pyngraph/function.cpp",
"pyngraph/node.cpp",
Expand Down
89 changes: 89 additions & 0 deletions ngraph/python/src/ngraph/utils/node_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, List, Optional

from _pyngraph import NodeFactory as _NodeFactory
Expand All @@ -21,6 +22,8 @@ def create(
) -> Node:
"""Create node object from provided description.
The user does not have to provide all node's attributes, but only required ones.
:param op_type_name: The operator type name.
:param arguments: The operator arguments.
:param attributes: The operator attributes.
Expand All @@ -30,4 +33,90 @@ def create(
if attributes is None:
attributes = {}
node = self.factory.create(op_type_name, arguments, attributes)

# Currently we don't support any attribute getters & setters for TensorIterator node.
if node.get_type_name() == "TensorIterator":
return node

# Set getters and setters for each node's attribute.
# node.get_attribute_name()
# node.set_attribute_name()
# For compound (with more than one level of nesting) attributes of form ie.:
# node.class_member_name.some_metric.attr_name:
# node.get_some_metric_attr_name()
# node.set_some_metric_attr_name()
# Please see test_dyn_attributes.py for more usage examples.
all_attributes = node._get_attributes()
for attr_name in all_attributes.keys():
setattr(node,
self._normalize_attr_name_getter(attr_name),
partial(NodeFactory._get_node_attr_value, node, attr_name))
setattr(node,
self._normalize_attr_name_setter(attr_name),
partial(NodeFactory._set_node_attr_value, node, attr_name))

# Setup helper members for caching attribute values.
# The cache would be lazily populated at first access attempt.
setattr(node, "_attr_cache", dict())
setattr(node, "_attr_cache_valid", bool(False))

return node

@staticmethod
def _normalize_attr_name(attr_name: str, prefix: str) -> str:
"""Normalizes attribute name.
:param attr_name: The attribute name.
:param prefix: The prefix to attach to attribute name.
:returns: The modified attribute name.
"""
# Trim first part of the name if there is only one level of attribute hierarchy.
if attr_name.count(".") == 1:
attr_name = attr_name[attr_name.find(".") + 1:]
return prefix + attr_name.replace(".", "_")

@classmethod
def _normalize_attr_name_getter(cls, attr_name: str) -> str:
"""Normalizes atr name to be suitable for getter function name.
:param attr_name: The attribute name to normalize
:returns: The appropriate getter function name.
"""
return cls._normalize_attr_name(attr_name, "get_")

@classmethod
def _normalize_attr_name_setter(cls, attr_name: str) -> str:
"""Normalizes atr name to be suitable for setter function name.
:param attr_name: The attribute name to normalize
:returns: The appropriate setter function name.
"""
return cls._normalize_attr_name(attr_name, "set_")

@staticmethod
def _get_node_attr_value(node: Node, attr_name: str) -> Any:
"""Gets provided node attribute value.
:param node: The node we retrieve attribute value from.
:param attr_name: The attribute name.
:returns: The node attribute value.
"""
if not node._attr_cache_valid:
node._attr_cache = node._get_attributes()
node._attr_cache_valid = True
return node._attr_cache[attr_name]

@staticmethod
def _set_node_attr_value(node: Node, attr_name: str, value: Any) -> None:
"""Sets the node attribute value.
:param node: The node we change attribute value for.
:param attr_name: The attribute name.
:param value: The new attribute value.
"""
node._set_attribute(attr_name, value)
node._attr_cache[attr_name] = value
Loading

0 comments on commit d0be6b1

Please sign in to comment.