From 934f5df5772003d256dfaa9ca73e09eb0827eafd Mon Sep 17 00:00:00 2001
From: Nikita Malinin
Date: Tue, 23 Jul 2024 12:36:28 +0200
Subject: [PATCH] Fix shared_memory issue with cast (#2834) (#2835)

(cherry picked from commit b328e4d0e0a6e3b471ed7ad155caf3848ab703d1)
---
Review note: this copy of the patch was amended after the cherry-pick.
_convert_to_dtype now also handles integer element types (np.finfo raises
ValueError for them) and the _create_constant docstring names the dtype
parameter correctly. The post-image blob id on the index line still refers
to the upstream commit.

 nncf/openvino/graph/model_transformer.py | 35 ++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/nncf/openvino/graph/model_transformer.py b/nncf/openvino/graph/model_transformer.py
index c073ce7605c..a080dec1118 100644
--- a/nncf/openvino/graph/model_transformer.py
+++ b/nncf/openvino/graph/model_transformer.py
@@ -67,6 +67,27 @@ def __init__(self, model: TModel, inplace: bool = False):
             (OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
         ]
 
+    @staticmethod
+    def _convert_to_dtype(value: np.ndarray, dtype: np.dtype) -> np.ndarray:
+        """
+        Converts the value to the given data type with saturation.
+
+        :param value: Numpy value.
+        :param dtype: Target numpy data type.
+        :return: Value clipped to the representable range of `dtype` and cast to it.
+        """
+        dtype = np.dtype(dtype)
+        if np.issubdtype(dtype, np.floating):
+            type_info = np.finfo(dtype)
+        elif np.issubdtype(dtype, np.integer):
+            # np.finfo raises ValueError for integer types, so use np.iinfo for them.
+            type_info = np.iinfo(dtype)
+        else:
+            # No numeric range to saturate to (e.g. boolean).
+            return np.asarray(value).astype(dtype)
+        clip_data = np.clip(value, type_info.min, type_info.max)
+        return clip_data.astype(dtype)
+
     @staticmethod
     def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]:
         """
@@ -269,16 +290,22 @@ def _apply_convert_insertion_transformations(
         return model
 
     @staticmethod
-    def _create_constant(value: np.ndarray, dtype: ov.Type, name: str) -> ov.Node:
+    def _create_constant(value: np.ndarray, dtype: ov.Type, name: str, shared_memory: bool = False) -> ov.Node:
         """
         Creates constant using opset.
 
         :param value: Numpy value.
-        :param type: Constant type.
+        :param dtype: Constant type.
         :param name: Name for the constant.
+        :param shared_memory: Shared memory option.
         :return: ov.Node instance.
         """
-        return opset.constant(value, dtype=dtype, name=name)
+        constant_value = value
+        # BF16 does not support type conversion in numpy
+        if dtype != ov.Type.bf16 and constant_value.dtype != dtype.to_dtype():
+            constant_value = OVModelTransformer._convert_to_dtype(constant_value, dtype.to_dtype())
+
+        return opset.constant(constant_value, dtype=dtype, name=name, shared_memory=shared_memory)
 
     @staticmethod
     def _create_fake_quantize(
@@ -511,7 +538,7 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
             # Shared memory does not work for BF16 precision
             shared_memory = False
 
-        new_const_node = opset.constant(
+        new_const_node = OVModelTransformer._create_constant(
             const_value,
             dtype=const_port.get_element_type(),
             name=const_node.get_friendly_name(),