Skip to content

Commit

Permalink
Enhance XAI (CLI, ExplainParameters, test) (#1941)
Browse files Browse the repository at this point in the history
* xai_enhance

* Update IExplainTask signature

* pylint fix

* logger

* reduce num_iters, rename DetClassProbMapHook

* reply to comments

* rebase + docs

* rebase
  • Loading branch information
negvet authored Apr 5, 2023
1 parent 9d6edc7 commit 40094c2
Show file tree
Hide file tree
Showing 25 changed files with 810 additions and 344 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.

- Clean up and refactor the output of the OTX CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1946>)
- Enhance DetCon logic and SupCon for semantic segmentation(<https://github.com/openvinotoolkit/training_extensions/pull/1958>)
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)

### Bug fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ The command below will evaluate the trained model on the provided dataset:
Explanation
***********

``otx explain`` runs the explainable AI (XAI) algorithm of a model on the specific dataset. It helps explain the model's decision-making process in a way that is easily understood by humans.
``otx explain`` runs the explainable AI (XAI) algorithm on a specific model-dataset pair. It helps explain the model's decision-making process in a way that is easily understood by humans.

With the ``--help`` command, you can list additional information, such as its parameters common to all model templates:

Expand All @@ -422,8 +422,12 @@ With the ``--help`` command, you can list additional information, such as its pa
Load model weights from previously saved checkpoint.
--explain-algorithm EXPLAIN_ALGORITHM
Explain algorithm name, currently support ['activationmap', 'eigencam', 'classwisesaliencymap']. For Openvino task, default method will be selected.
--process-saliency-maps PROCESS_SALIENCY_MAPS
Processing of saliency map includes (1) resizing to input image resolution and (2) applying a colormap. Depending on the number of targets to explain, this might take significant time.
--explain-all-classes EXPLAIN_ALL_CLASSES
Provides explanations for all classes. Otherwise, explains only predicted classes. This feature is supported by algorithms that can generate explanations per each class.
--overlay-weight OVERLAY_WEIGHT
Weight of the saliency map when overlaying the saliency map.
Weight of the saliency map when overlaying the input image with saliency map.
The command below will generate saliency maps (heatmaps with red colored areas of focus) of the trained model on the provided dataset and save the resulting images to ``save-explanation-to`` path:
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/classification/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from otx.algorithms.common.utils import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -162,7 +163,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Classification Task."""
logger.info("called explain()")
Expand Down
9 changes: 8 additions & 1 deletion otx/algorithms/classification/tasks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from otx.api.entities.annotation import AnnotationSceneEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -254,7 +255,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Explain function of ClassificationOpenVINOTask."""

Expand All @@ -269,6 +270,12 @@ def explain(
dataset_size = len(dataset)
for i, dataset_item in enumerate(dataset, 1):
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
if saliency_map is None:
raise RuntimeError(
"There is no Saliency Map in OpenVINO IR model output. "
"Please export model to OpenVINO IR with dump_features"
)

item_labels = predicted_scene.annotations[0].get_labels()
dataset_item.append_labels(item_labels)
add_saliency_maps_to_dataset_item(
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/common/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from otx.algorithms.common.utils import UncopiableDefaultDict
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import LabelEntity
from otx.api.entities.metrics import MetricsGroup
Expand Down Expand Up @@ -208,7 +209,7 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Task."""
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#

from .det_saliency_map_hook import DetSaliencyMapHook
from .det_class_probability_map_hook import DetClassProbabilityMapHook

__all__ = ["DetSaliencyMapHook"]
__all__ = ["DetClassProbabilityMapHook"]
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# pylint: disable=too-many-locals


class DetSaliencyMapHook(BaseRecordingForwardHook):
class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""

def __init__(self, module: torch.nn.Module) -> None:
Expand Down Expand Up @@ -116,7 +116,7 @@ def forward_single(x, cls_convs, conv_cls):
else:
raise NotImplementedError(
"Not supported detection head provided. "
"DetSaliencyMapHook supports only the following single stage detectors: "
"DetClassProbabilityMap supports only the following single stage detectors: "
"YOLOXHead, ATSSHead, SSDHead, VFNetHead."
)
return cls_scores
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -99,7 +99,7 @@ def custom_atss__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -157,7 +157,7 @@ def custom_single_stage_detector__simple_test(ctx, self, img, img_metas, **kwarg
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -137,7 +137,7 @@ def custom_yolox__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
11 changes: 6 additions & 5 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
SemiSLDetectionConfigurer,
)
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)
from otx.algorithms.detection.adapters.mmdet.utils.builder import build_detector
from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
Expand All @@ -74,6 +74,7 @@
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import (
ModelEntity,
Expand Down Expand Up @@ -392,7 +393,7 @@ def hook(module, inp, outp):
if isinstance(raw_model, TwoStageDetector):
saliency_hook = ActivationMapHook(feature_model)
else:
saliency_hook = DetSaliencyMapHook(feature_model)
saliency_hook = DetClassProbabilityMapHook(feature_model)

if not dump_features:
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
Expand Down Expand Up @@ -541,12 +542,12 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of MMDetectionTask."""

explainer_hook_selector = {
"classwisesaliencymap": DetSaliencyMapHook,
"classwisesaliencymap": DetClassProbabilityMapHook,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}
Expand Down
8 changes: 7 additions & 1 deletion otx/algorithms/detection/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from otx.api.configuration.helper.utils import config_to_bytes
from otx.api.entities.annotation import AnnotationSceneEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -434,7 +435,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Explain function of OpenVINODetectionTask."""
logger.info("Start OpenVINO explain")
Expand All @@ -453,6 +454,11 @@ def explain(
dataset_item.append_annotations(predicted_scene.annotations)
update_progress_callback(int(i / dataset_size * 100), None)
_, saliency_map = features
if saliency_map is None:
raise RuntimeError(
"There is no Saliency Map in OpenVINO IR model output. "
"Please export model to OpenVINO IR with dump_features"
)

labels = self.task_environment.get_labels().copy()
if saliency_map.shape[0] == len(labels) + 1:
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.annotation import Annotation
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.id import ID
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import Domain, LabelEntity
Expand Down Expand Up @@ -259,7 +260,7 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Task."""
raise NotImplementedError
Expand Down
33 changes: 33 additions & 0 deletions otx/api/entities/explain_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""This module define the Explain entity."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#


from dataclasses import dataclass
from typing import Any, Callable, Optional


# pylint: disable=unused-argument
def default_progress_callback(progress: int, score: Optional[float] = None):
"""This is the default progress callback for OptimizationParameters."""


@dataclass
class ExplainParameters:
"""Explain parameters.
Attributes:
explainer: Explain algorithm to be used in explanation mode.
Will be converted automatically to lowercase.
process_saliency_maps: Processing of saliency map includes (1) resize to input image resolution
and (2) apply a colormap.
explain_predicted_classes: Provides explanations only for predicted classes.
Otherwise, explain all classes.
"""

update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback

explainer: str = ""
process_saliency_maps: bool = False
explain_predicted_classes: bool = True
1 change: 0 additions & 1 deletion otx/api/entities/inference_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class InferenceParameters:
is_evaluation: bool = False
update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback

# TODO(negvet): use separate ExplainParameters dataclass for this
explainer: str = ""
process_saliency_maps: bool = False
explain_predicted_classes: bool = True
4 changes: 2 additions & 2 deletions otx/api/usecases/tasks/interfaces/explain_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import abc

from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.explain_parameters import ExplainParameters


class IExplainTask(metaclass=abc.ABCMeta):
Expand All @@ -18,7 +18,7 @@ class IExplainTask(metaclass=abc.ABCMeta):
def explain(
self,
dataset: DatasetEntity,
explain_parameters: InferenceParameters,
explain_parameters: ExplainParameters,
) -> DatasetEntity:
"""This is the method that is called upon explanation.
Expand Down
Loading

0 comments on commit 40094c2

Please sign in to comment.