Skip to content

Commit

Permalink
Fix shared_memory issue with cast (#2834) (#2835)
Browse files Browse the repository at this point in the history
(cherry picked from commit b328e4d)
  • Loading branch information
KodiaqQ authored Jul 23, 2024
1 parent dd0be43 commit 934f5df
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def __init__(self, model: TModel, inplace: bool = False):
(OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
]

@staticmethod
def _convert_to_dtype(value, dtype):
clip_data = np.clip(value, np.finfo(dtype).min, np.finfo(dtype).max)
return clip_data.astype(dtype)

@staticmethod
def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]:
"""
Expand Down Expand Up @@ -269,16 +274,22 @@ def _apply_convert_insertion_transformations(
return model

@staticmethod
def _create_constant(value: np.ndarray, dtype: ov.Type, name: str) -> ov.Node:
def _create_constant(value: np.ndarray, dtype: ov.Type, name: str, shared_memory: bool = False) -> ov.Node:
"""
Creates constant using opset.
:param value: Numpy value.
:param type: Constant type.
:param name: Name for the constant.
:param shared_memory: Shared memory option.
:return: ov.Node instance.
"""
return opset.constant(value, dtype=dtype, name=name)
constant_value = value
# BF16 does not support type conversion in numpy
if dtype != ov.Type.bf16 and constant_value.dtype != dtype.to_dtype():
constant_value = OVModelTransformer._convert_to_dtype(constant_value, dtype.to_dtype())

return opset.constant(constant_value, dtype=dtype, name=name, shared_memory=shared_memory)

@staticmethod
def _create_fake_quantize(
Expand Down Expand Up @@ -511,7 +522,7 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
# Shared memory does not work for BF16 precision
shared_memory = False

new_const_node = opset.constant(
new_const_node = OVModelTransformer._create_constant(
const_value,
dtype=const_port.get_element_type(),
name=const_node.get_friendly_name(),
Expand Down

0 comments on commit 934f5df

Please sign in to comment.