Skip to content

Commit

Permalink
reduce num_iters, rename DetClassProbMapHook
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Mar 27, 2023
1 parent 8e6c279 commit 76a93f5
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 39 deletions.
4 changes: 2 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#

from .det_class_probability_map_hook import DetClassProbabilityMap
from .det_class_probability_map_hook import DetClassProbabilityMapHook

__all__ = ["DetClassProbabilityMap"]
__all__ = ["DetClassProbabilityMapHook"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
# pylint: disable=too-many-locals


class DetClassProbabilityMap(BaseRecordingForwardHook):

class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""

def __init__(self, module: torch.nn.Module) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmdet.models.detectors.atss import ATSS

from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMap,
DetClassProbabilityMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
Expand Down Expand Up @@ -97,7 +97,7 @@ def custom_atss__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetClassProbabilityMap(self).func(feature_map=cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmdet.models.detectors.single_stage import SingleStageDetector

from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMap,
DetClassProbabilityMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
Expand Down Expand Up @@ -154,7 +154,7 @@ def custom_single_stage_detector__simple_test(ctx, self, img, img_metas, **kwarg
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetClassProbabilityMap(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmdet.models.detectors.yolox import YOLOX

from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMap,
DetClassProbabilityMapHook,
)
from otx.mpa.deploy.utils import is_mmdeploy_enabled
from otx.mpa.modules.hooks.recording_forward_hooks import FeatureVectorHook
Expand Down Expand Up @@ -135,7 +135,7 @@ def custom_yolox__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetClassProbabilityMap(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
4 changes: 2 additions & 2 deletions otx/mpa/det/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMap,
DetClassProbabilityMapHook,
)
from otx.mpa.modules.hooks.recording_forward_hooks import (
ActivationMapHook,
Expand All @@ -27,7 +27,7 @@

logger = get_logger()
EXPLAINER_HOOK_SELECTOR = {
"classwisesaliencymap": DetClassProbabilityMap,
"classwisesaliencymap": DetClassProbabilityMapHook,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}
Expand Down
4 changes: 2 additions & 2 deletions otx/mpa/det/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMap,
DetClassProbabilityMapHook,
)
from otx.mpa.modules.hooks.recording_forward_hooks import (
ActivationMapHook,
Expand Down Expand Up @@ -179,7 +179,7 @@ def infer(self, cfg, model_builder=None, eval=False, dump_features=False, dump_s
if isinstance(raw_model, TwoStageDetector):
saliency_hook = ActivationMapHook(feature_model)
else:
saliency_hook = DetClassProbabilityMap(feature_model)
saliency_hook = DetClassProbabilityMapHook(feature_model)

eval_predictions = []
with FeatureVectorHook(feature_model) if dump_features else nullcontext() as feature_vector_hook:
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_api_xai_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class TestOVDetXAIAPI(DetectionTaskAPIBase):
def test_inference_xai(self):
with tempfile.TemporaryDirectory() as temp_dir:
hyper_parameters, model_template = self.setup_configurable_parameters(
DEFAULT_DET_TEMPLATE_DIR, num_iters=100
DEFAULT_DET_TEMPLATE_DIR, num_iters=15
)
detection_environment, dataset = self.init_environment(hyper_parameters, model_template, 10)

Expand Down
28 changes: 16 additions & 12 deletions tests/test_suite/run_test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def xfail_templates(templates, xfail_template_ids_reasons):
return xfailed_templates


def otx_explain_testing(template, root, otx_dir, args):
def otx_explain_testing(template, root, otx_dir, args, trained=False):
template_work_dir = get_template_dir(template, root)
if "RCNN" in template.model_template_id:
test_algorithm = "ActivationMap"
Expand Down Expand Up @@ -683,8 +683,9 @@ def otx_explain_testing(template, root, otx_dir, args):
]
check_run(command_line)
assert os.path.exists(output_dir)
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])


def otx_explain_testing_all_classes(template, root, otx_dir, args):
Expand Down Expand Up @@ -731,7 +732,7 @@ def otx_explain_testing_all_classes(template, root, otx_dir, args):
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])


def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args):
def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args, trained=False):
template_work_dir = get_template_dir(template, root)
if "RCNN" in template.model_template_id:
test_algorithm = "ActivationMap"
Expand Down Expand Up @@ -765,11 +766,12 @@ def otx_explain_testing_process_saliency_maps(template, root, otx_dir, args):
]
check_run(command_line)
assert os.path.exists(output_dir)
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])


def otx_explain_openvino_testing(template, root, otx_dir, args):
def otx_explain_openvino_testing(template, root, otx_dir, args, trained=False):
template_work_dir = get_template_dir(template, root)
if "RCNN" in template.model_template_id:
test_algorithm = "ActivationMap"
Expand Down Expand Up @@ -803,8 +805,9 @@ def otx_explain_openvino_testing(template, root, otx_dir, args):
assert os.path.exists(f"{template_work_dir}/exported_{template.model_template_id}_w_features/openvino.xml")
check_run(command_line)
assert os.path.exists(output_dir)
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])


def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args):
Expand Down Expand Up @@ -852,7 +855,7 @@ def otx_explain_all_classes_openvino_testing(template, root, otx_dir, args):
assert all([fname.split(".")[-1] == "tiff" for fname in os.listdir(output_dir)])


def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir, args):
def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir, args, trained=False):
template_work_dir = get_template_dir(template, root)
if "RCNN" in template.model_template_id:
test_algorithm = "ActivationMap"
Expand Down Expand Up @@ -887,8 +890,9 @@ def otx_explain_process_saliency_maps_openvino_testing(template, root, otx_dir,
assert os.path.exists(f"{template_work_dir}/exported_{template.model_template_id}_w_features/openvino.xml")
check_run(command_line)
assert os.path.exists(output_dir)
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])
if trained:
assert len(os.listdir(output_dir)) > 0
assert all([fname.split(".")[-1] == "png" for fname in os.listdir(output_dir)])


def otx_find_testing():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest
import torch

from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)
from otx.algorithms.detection.adapters.mmdet.models.heads.custom_atss_head import (
CustomATSSHead,
Expand All @@ -19,8 +19,8 @@
from tests.test_suite.e2e_test_system import e2e_pytest_unit


class TestDetSaliencyMapHook:
"""Test class for DetSaliencyMapHook."""
class TestDetClassProbabilityMapHook:
"""Test class for DetClassProbabilityMapHook."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
Expand All @@ -37,14 +37,14 @@ def forward(self, x):
return x

self.module = _MockModule()
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)

@e2e_pytest_unit
def test_func(self, mocker) -> None:
"""Test func function."""

mocker.patch.object(
DetSaliencyMapHook, "_get_cls_scores_from_feature_map", return_value=[torch.randn(1, 3, 14, 14)]
DetClassProbabilityMapHook, "_get_cls_scores_from_feature_map", return_value=[torch.randn(1, 3, 14, 14)]
)
assert self.hook.func(torch.randn(1, 3, 14, 14)) is not None

Expand All @@ -53,14 +53,14 @@ def test_get_cls_scores_from_feature_map(self) -> None:
"""Test _get_cls_scores_from_feature_map function."""

self.module.bbox_head = CustomATSSHead(num_classes=3, in_channels=64)
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)
assert self.hook._get_cls_scores_from_feature_map(torch.Tensor(1, 3, 64, 32, 32)) is not None
self.module.bbox_head = CustomYOLOXHead(num_classes=3, in_channels=64)
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)
assert self.hook._get_cls_scores_from_feature_map(torch.Tensor(1, 3, 64, 32, 32)) is not None
self.module.bbox_head = CustomVFNetHead(num_classes=3, in_channels=64)
self.module.bbox_head.anchor_generator.num_base_anchors = 1
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)
assert self.hook._get_cls_scores_from_feature_map(torch.Tensor(1, 3, 64, 32, 32)) is not None
self.module.bbox_head = CustomSSDHead(
anchor_generator=dict(
Expand All @@ -71,10 +71,10 @@ def test_get_cls_scores_from_feature_map(self) -> None:
),
act_cfg={},
)
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)
assert self.hook._get_cls_scores_from_feature_map(torch.Tensor(1, 3, 512, 32, 32)) is not None
self.module.bbox_head = torch.nn.Module()
self.module.bbox_head.cls_out_channels = 3
self.hook = DetSaliencyMapHook(self.module)
self.hook = DetClassProbabilityMapHook(self.module)
with pytest.raises(NotImplementedError):
self.hook._get_cls_scores_from_feature_map(torch.Tensor(1, 3, 512, 32, 32))
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmdet.models import build_detector

from otx.algorithms.classification.tasks import ClassificationInferenceTask # noqa
from otx.algorithms.detection.adapters.mmdet.hooks import DetClassProbabilityMap
from otx.algorithms.detection.adapters.mmdet.hooks import DetClassProbabilityMapHook
from otx.cli.registry import Registry
from otx.mpa.det.stage import DetectionStage # noqa
from otx.mpa.utils.config_utils import MPAConfig
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_saliency_map_det(self, template):
]
data = {"img_metas": [img_metas], "img": [img]}

with DetClassProbabilityMap(model) as det_hook:
with DetClassProbabilityMapHook(model) as det_hook:
with torch.no_grad():
_ = model(return_loss=False, rescale=True, **data)
saliency_maps = det_hook.records
Expand Down

0 comments on commit 76a93f5

Please sign in to comment.