Skip to content

Commit

Permalink
OVWeightUpdateCommand copies a constant (#2277)
Browse files Browse the repository at this point in the history
### Changes

- Update an implementation of `get_const_value()` method. This method
was initially implemented as a workaround in [PR
1654](#1654), but it is no
longer needed and can be replaced with `const_op.data`.
- Replace `const_op.get_data()` with `const_op.data` because the
`const_op.get_data()` legacy method and `.data` property should be used.
- Set the `shared_memory` parameter to `True` during constant creation
to avoid copying the constant.

### Reason for changes

- The `const_od.get_data()` method copies constant.
- Do not copy the constant value during constant creation in
`opset.constant(...)`.

### Related tickets

Ref: 122922

### Tests

N/A

---------

Co-authored-by: Nikita Malinin <[email protected]>
  • Loading branch information
andrey-churkin and KodiaqQ authored Nov 23, 2023
1 parent 5eee3bc commit 9479779
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
9 changes: 6 additions & 3 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,12 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
if const_node is None:
raise RuntimeError("Constant node was expected but could not find it.")

const_shape = const_node.get_data().shape
const_value = np.reshape(const_value, const_shape)
new_const_node = opset.constant(const_value, dtype=const_node.get_element_type())
const_shape = const_node.data.shape
const_dtype = const_node.data.dtype
const_value = np.reshape(const_value, const_shape).astype(const_dtype)

# TODO(andrey-churkin): Replace on opset13.constant() in a future release
new_const_node = ov.op.Constant(const_value, shared_memory=True)
new_const_node.set_friendly_name(const_node.get_friendly_name())
const_port.replace_source_output(new_const_node.output(0))

Expand Down
2 changes: 1 addition & 1 deletion nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_const_value(const_node: ov.Node) -> np.ndarray:
:param const_node: OpenVINO node.
:return: The constant value.
"""
return const_node.get_vector().reshape(const_node.get_output_shape(0))
return const_node.data


def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def get_fq_nodes_stats_algo(model):
nodes = {}
for op in model.get_ops():
if op.get_type_name() == "FakeQuantize":
input_low = op.input_value(1).get_node().get_data()
input_high = op.input_value(2).get_node().get_data()
output_low = op.input_value(3).get_node().get_data()
output_high = op.input_value(4).get_node().get_data()
input_low = op.input_value(1).get_node().data
input_high = op.input_value(2).get_node().data
output_low = op.input_value(3).get_node().data
output_high = op.input_value(4).get_node().data

nodes[op.get_friendly_name()] = {
"input_low": input_low,
Expand Down
10 changes: 5 additions & 5 deletions tests/openvino/native/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ def check_inplace_op(target_node, ref_types, ref_vals, inplace_branches_num, out
if ref_val is not None:
const = get_prev_node(node, 1)
if ref_val == []:
assert const.get_data().shape == (0,)
assert const.data.shape == (0,)
elif not isinstance(ref_val, tuple):
assert const.get_data() == ref_val
assert const.data == ref_val
else:
res = np.equal(const.get_data(), np.array(ref_val))
res = np.equal(const.data, np.array(ref_val))
assert all(res)
assert const.get_data().shape == np.array(ref_val).shape
assert const.data.shape == np.array(ref_val).shape

nodes = get_next_nodes(node, 0)
assert len(nodes) == 1
Expand Down Expand Up @@ -296,7 +296,7 @@ def test_inplace_reduce_fn_dynamic_shapes(input_shape, raise_error):
op = fn(input_1, 0)
# check_const
ref_const = np.array([0, 1, 2, 3])
assert all(np.equal(get_prev_node(op, 1).get_data(), ref_const))
assert all(np.equal(get_prev_node(op, 1).data, ref_const))


@pytest.mark.parametrize("reduction_axes", [None, np.array([], dtype=np.int64)])
Expand Down

0 comments on commit 9479779

Please sign in to comment.