Skip to content

Commit

Permalink
Fix segmentation ptq test
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Jun 12, 2023
1 parent 6ead50b commit e0564fd
Showing 1 changed file with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
import os
import pathlib

import numpy as np
import pytest
Expand Down Expand Up @@ -194,22 +195,23 @@ def test_deploy(self, otx_model):

@e2e_pytest_unit
def test_optimize(self, mocker, otx_model):
def patch_save_model(model, dir_path, model_name):
with open(f"{dir_path}/{model_name}.xml", "wb") as f:
def patch_save_model(model, output_xml):
with open(output_xml, "wb") as f:
f.write(b"foo")
with open(f"{dir_path}/{model_name}.bin", "wb") as f:
bin_path = pathlib.Path(output_xml).parent / pathlib.Path(str(pathlib.Path(output_xml).stem) + ".bin")
with open(bin_path, "wb") as f:
f.write(b"bar")

dataset = generate_otx_dataset()
output_model = copy.deepcopy(otx_model)
self.seg_ov_task.model.set_data("openvino.bin", b"foo")
self.seg_ov_task.model.set_data("openvino.xml", b"bar")
mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.load_model", autospec=True)
mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.create_pipeline", autospec=True)
mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.save_model", new=patch_save_model)
spy_compress = mocker.spy(otx.algorithms.segmentation.adapters.openvino.task, "compress_model_weights")

mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.ov.Core.read_model", autospec=True)
mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.ov.serialize", new=patch_save_model)
fake_quantize = mocker.patch("otx.algorithms.segmentation.adapters.openvino.task.nncf.quantize", autospec=True)
self.seg_ov_task.optimize(OptimizationType.POT, dataset=dataset, output_model=output_model)

spy_compress.assert_called_once()
fake_quantize.assert_called_once()
assert self.seg_ov_task.model.get_data("openvino.bin")
assert self.seg_ov_task.model.get_data("openvino.xml")

0 comments on commit e0564fd

Please sign in to comment.