From cfd77068ceb67540df8c04780ecfb10866fc7fb1 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 10 Jul 2023 10:53:18 +0200 Subject: [PATCH] Fix semantic segmentation soft prediction dtype (#2322) * Fix semantic segmentation soft prediction dtype * relax ref sal vals check --------- Co-authored-by: Songki Choi --- src/otx/algorithms/segmentation/adapters/openvino/task.py | 2 ++ src/otx/algorithms/segmentation/task.py | 2 ++ .../classification/test_xai_classification_validity.py | 2 +- tests/unit/algorithms/detection/test_xai_detection_validity.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/otx/algorithms/segmentation/adapters/openvino/task.py b/src/otx/algorithms/segmentation/adapters/openvino/task.py index 268eb2e933a..270216505c6 100644 --- a/src/otx/algorithms/segmentation/adapters/openvino/task.py +++ b/src/otx/algorithms/segmentation/adapters/openvino/task.py @@ -240,6 +240,8 @@ def add_prediction( current_label_soft_prediction = soft_prediction[:, :, label_index] if process_soft_prediction: current_label_soft_prediction = get_activation_map(current_label_soft_prediction) + else: + current_label_soft_prediction = (current_label_soft_prediction * 255).astype(np.uint8) result_media = ResultMediaEntity( name=label.name, type="soft_prediction", diff --git a/src/otx/algorithms/segmentation/task.py b/src/otx/algorithms/segmentation/task.py index a3a8923454b..08860a3e168 100644 --- a/src/otx/algorithms/segmentation/task.py +++ b/src/otx/algorithms/segmentation/task.py @@ -290,6 +290,8 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, dump_soft_pre current_label_soft_prediction = soft_prediction[:, :, label_index] if process_soft_prediction: current_label_soft_prediction = get_activation_map(current_label_soft_prediction) + else: + current_label_soft_prediction = (current_label_soft_prediction * 255).astype(np.uint8) result_media = ResultMediaEntity( name=label.name, type="soft_prediction", diff --git a/tests/unit/algorithms/classification/test_xai_classification_validity.py b/tests/unit/algorithms/classification/test_xai_classification_validity.py index 8f15ede6133..1ec20d0c2c1 100644 --- a/tests/unit/algorithms/classification/test_xai_classification_validity.py +++ b/tests/unit/algorithms/classification/test_xai_classification_validity.py @@ -54,4 +54,4 @@ def test_saliency_map_cls(self, template): assert len(saliency_maps) == 2 assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == (1000, 7, 7) - assert (saliency_maps[0][0][0] == self.ref_saliency_vals_cls[template.name]).all() + assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_cls[template.name]) <= 1) diff --git a/tests/unit/algorithms/detection/test_xai_detection_validity.py b/tests/unit/algorithms/detection/test_xai_detection_validity.py index 6ed47397992..3e903a0b5d6 100644 --- a/tests/unit/algorithms/detection/test_xai_detection_validity.py +++ b/tests/unit/algorithms/detection/test_xai_detection_validity.py @@ -80,7 +80,7 @@ def test_saliency_map_det(self, template): assert len(saliency_maps) == 2 assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name] - assert (saliency_maps[0][0][0] == self.ref_saliency_vals_det[template.name]).all() + assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_det[template.name]) <= 1) @e2e_pytest_unit @pytest.mark.parametrize("template", templates_det, ids=templates_det_ids)