Skip to content

Commit

Permalink
[PY API] Fix the preoblem that Node.get_attributes() cannot return al…
Browse files Browse the repository at this point in the history
…l attributes (openvinotoolkit#23530)

### Details:
- extend the `util::DictAttributeSerializer::on_adapter()` method,
making it compatible with `ov::PartialShape` and
`ov::op::util::Variable` types;
 - add extra tests to test the correctness of `Node.get_attributes()`

### Tickets:
 - openvinotoolkit#23455

---------

Co-authored-by: Jan Iwaszkiewicz <[email protected]>
  • Loading branch information
2 people authored and alvoron committed Apr 29, 2024
1 parent c04223f commit be4748c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ void util::DictAttributeSerializer::on_adapter(const std::string& name, ov::Valu
if (m_attributes.contains(name)) {
OPENVINO_THROW("No AttributeVisitor support for accessing attribute named: ", name);
}

if (auto _adapter = dynamic_cast<ov::AttributeAdapter<std::shared_ptr<ov::op::util::Variable>>*>(&adapter)) {
m_attributes[name.c_str()] = _adapter->get()->get_info().variable_id;
} else if (auto _adapter = dynamic_cast<ov::AttributeAdapter<ov::PartialShape>*>(&adapter)) {
auto partial_shape = _adapter->get();
std::vector<ov::Dimension::value_type> shape;
for (const auto& dim : partial_shape) {
shape.push_back(dim.is_dynamic() ? -1 : dim.get_length());
}
m_attributes[name.c_str()] = shape;
}
}
void util::DictAttributeSerializer::on_adapter(const std::string& name, ov::ValueAccessor<bool>& adapter) {
m_attributes[name.c_str()] = adapter.get();
Expand Down
13 changes: 13 additions & 0 deletions src/bindings/python/tests/test_graph/test_create_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,11 +1183,15 @@ def test_read_value():
init_value = ov.parameter([2, 2], name="init_value", dtype=np.int32)

node = ov.read_value(init_value, "var_id_667", np.int32, [2, 2])
read_value_attributes = node.get_attributes()

assert node.get_type_name() == "ReadValue"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [2, 2]
assert node.get_output_element_type(0) == Type.i32
assert read_value_attributes["variable_type"] == "i32"
assert read_value_attributes["variable_id"] == "var_id_667"
assert read_value_attributes["variable_shape"] == [2, 2]


def test_read_value_dyn_variable_pshape():
Expand All @@ -1205,11 +1209,13 @@ def test_assign():
input_data = ov.parameter([5, 7], name="input_data", dtype=np.int32)
rv = ov.read_value(input_data, "var_id_667", np.int32, [5, 7])
node = ov.assign(rv, "var_id_667")
assign_attributes = node.get_attributes()

assert node.get_type_name() == "Assign"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [5, 7]
assert node.get_output_element_type(0) == Type.i32
assert assign_attributes["variable_id"] == "var_id_667"


def test_extract_image_patches():
Expand Down Expand Up @@ -2353,3 +2359,10 @@ def test_topk_opset11():
assert node.get_output_size() == 2
assert list(node.get_output_shape(0)) == [1, 3, 3]
assert list(node.get_output_shape(1)) == [1, 3, 3]


def test_parameter_get_attributes():
parameter = ov.parameter([2, 2], dtype=np.float32, name="InputData")
parameter_attributes = parameter.get_attributes()
assert parameter_attributes["element_type"] == "f32"
assert parameter_attributes["shape"] == [2, 2]

0 comments on commit be4748c

Please sign in to comment.