diff --git a/src/otx/algorithms/common/utils/ir.py b/src/otx/algorithms/common/utils/ir.py index e342c41f771..5ab7a98b8a0 100644 --- a/src/otx/algorithms/common/utils/ir.py +++ b/src/otx/algorithms/common/utils/ir.py @@ -1,14 +1,18 @@ """Collections of IR-related utils for common OTX algorithms.""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # from pathlib import Path from typing import Any, Dict, Tuple +from openvino import Type +from openvino.preprocess import PrePostProcessor from openvino.runtime import Core, save_model +from otx.algorithms.common.utils.utils import is_xpu_available + def check_if_quantized(model: Any) -> bool: """Checks if OpenVINO model is already quantized.""" @@ -32,6 +36,14 @@ def embed_ir_model_data(xml_file: str, data_items: Dict[Tuple[str, str], Any]) - for k, data in data_items.items(): model.set_rt_info(data, list(k)) + # workaround for CVS-138901 + if is_xpu_available(): + pre_post_processor = PrePostProcessor(model) + for output in model.outputs: + if "labels" in output.get_names() and output.get_element_type() == Type.f32: + pre_post_processor.output("labels").tensor().set_element_type(Type.i64) + model = pre_post_processor.build() + # workaround for CVS-110054 tmp_xml_path = Path(Path(xml_file).parent) / "tmp.xml" save_model(model, str(tmp_xml_path), compress_to_fp16=False) diff --git a/tests/unit/algorithms/detection/utils/test_detection_utils.py b/tests/unit/algorithms/detection/utils/test_detection_utils.py index 68836f27646..d67a3a02d91 100644 --- a/tests/unit/algorithms/detection/utils/test_detection_utils.py +++ b/tests/unit/algorithms/detection/utils/test_detection_utils.py @@ -2,12 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 # +import tempfile import addict +import pytest +from otx.algorithms.common.utils.ir import embed_ir_model_data +from otx.algorithms.common.utils.utils import is_xpu_available from otx.algorithms.detection.utils.utils import ( generate_label_schema, get_det_model_api_configuration, ) + +from openvino import Type +from openvino.preprocess import PrePostProcessor +import openvino.runtime as ov from otx.api.entities.model_template import TaskType, task_type_to_label_domain from tests.test_suite.e2e_test_system import e2e_pytest_unit @@ -42,3 +50,20 @@ def test_get_det_model_api_configuration(): assert model_api_cfg[("model_info", "iou_threshold")] == "0.4" assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "labels")].split()) assert len(label_schema.get_labels(include_empty=False)) == len(model_api_cfg[("model_info", "label_ids")].split()) + + +@e2e_pytest_unit +@pytest.mark.skipif(not is_xpu_available(), reason="This test is valid on XPU only") +def test_det_model_ir_patching(): + param_node = ov.op.Parameter(ov.Type.f32, ov.Shape([1])) + model = ov.Model(param_node, [param_node]) + model.outputs[0].tensor.set_names({"labels"}) + assert model.outputs[0].get_element_type() == Type.f32 + + with tempfile.TemporaryDirectory() as tmpdir: + model_path = tmpdir + "/model.xml" + ov.save_model(model, model_path) + embed_ir_model_data(model_path, {}) + core = ov.Core() + model_updated = core.read_model(model_path) + assert model_updated.outputs[0].get_element_type() == Type.i64