Skip to content

Commit

Permalink
Fix action tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Jun 12, 2023
1 parent e0564fd commit b78e9da
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
26 changes: 13 additions & 13 deletions tests/unit/algorithms/action/adapters/openvino/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#

import os
import pathlib
from typing import Any, Dict

import numpy as np
Expand All @@ -24,6 +25,7 @@
AnnotationSceneEntity,
AnnotationSceneKind,
)
from otx.api.entities.dataset_item import DatasetItemEntity
from otx.api.entities.label import Domain
from otx.api.entities.label_schema import LabelGroup, LabelGroupType, LabelSchemaEntity
from otx.api.entities.model import (
Expand Down Expand Up @@ -337,31 +339,31 @@ class MockPipeline:
def run(self, model):
return model

def mock_save_model(model, tempdir, model_name):
def mock_save_model(model, output_xml):
"""Mock function for save_model function."""
with open(os.path.join(tempdir, "model.xml"), "wb") as f:
with open(output_xml, "wb") as f:
f.write(np.ndarray(1).tobytes())
with open(os.path.join(tempdir, "model.bin"), "wb") as f:
bin_path = pathlib.Path(output_xml).parent / pathlib.Path(str(pathlib.Path(output_xml).stem) + ".bin")
with open(bin_path, "wb") as f:
f.write(np.ndarray(1).tobytes())

mocker.patch("otx.algorithms.action.adapters.openvino.task.ov.Core.read_model", autospec=True)
mocker.patch("otx.algorithms.action.adapters.openvino.task.ov.serialize", new=mock_save_model)
fake_quantize = mocker.patch("otx.algorithms.action.adapters.openvino.task.nncf.quantize", autospec=True)

mocker.patch(
"otx.algorithms.action.adapters.openvino.task.get_ovdataloader", return_value=MockDataloader(self.dataset)
)
mocker.patch(
"otx.algorithms.action.adapters.openvino.task.DataLoaderWrapper", return_value=MockDataloader(self.dataset)
)
mocker.patch("otx.algorithms.action.adapters.openvino.task.load_model", return_value=self.model)
mocker.patch("otx.algorithms.action.adapters.openvino.task.get_nodes_by_type", return_value=False)
mocker.patch("otx.algorithms.action.adapters.openvino.task.IEEngine", return_value=True)
mocker.patch("otx.algorithms.action.adapters.openvino.task.create_pipeline", return_value=MockPipeline())
mocker.patch("otx.algorithms.action.adapters.openvino.task.compress_model_weights", return_value=True)
mocker.patch("otx.algorithms.action.adapters.openvino.task.save_model", side_effect=mock_save_model)
mocker.patch(
"otx.algorithms.action.adapters.openvino.task.ActionOpenVINOTask.load_inferencer",
return_value=MockOVInferencer(),
)
task = ActionOpenVINOTask(self.task_environment)
task.optimize(OptimizationType.POT, self.dataset, self.model, OptimizationParameters())
fake_quantize.assert_called_once()
assert self.model.get_data("openvino.xml") is not None
assert self.model.get_data("openvino.bin") is not None
assert self.model.model_format == ModelFormat.OPENVINO
Expand Down Expand Up @@ -390,7 +392,5 @@ def test_getitem(self) -> None:
"""Test __getitem__ function."""

out = self.dataloader[0]
assert out[0][0] == 0
assert isinstance(out[0][1], AnnotationSceneEntity)
assert len(out[1]) == 10
assert isinstance(out[2], dict)
assert isinstance(out[1], AnnotationSceneEntity)
assert len(out[0]) == 10
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def patch_save_model(model, output_xml):
self.cls_ov_task.model.set_data("openvino.xml", b"bar")
mocker.patch("otx.algorithms.classification.adapters.openvino.task.ov.Core.read_model", autospec=True)
mocker.patch("otx.algorithms.classification.adapters.openvino.task.ov.serialize", new=patch_save_model)
fake_quantize = mocker.patch("otx.algorithms.classification.adapters.openvino.task.nncf.quantize", autospec=True)
fake_quantize = mocker.patch(
"otx.algorithms.classification.adapters.openvino.task.nncf.quantize", autospec=True
)
self.cls_ov_task.optimize(OptimizationType.POT, dataset=self.dataset, output_model=output_model)

fake_quantize.assert_called_once()
Expand Down

0 comments on commit b78e9da

Please sign in to comment.