Skip to content

Commit

Permalink
Merge branch '1-dev/mc_rise' into 'main'
Browse files Browse the repository at this point in the history
Implement MC-RISE classifier algorithm

Closes #1

See merge request jatic/kitware/xaitk-saliency!6
  • Loading branch information
Brandon RichardWebster committed Dec 12, 2024
2 parents fe960b3 + b22cfb8 commit 8441df4
Show file tree
Hide file tree
Showing 14 changed files with 941 additions and 2 deletions.
15 changes: 15 additions & 0 deletions docs/implementations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ trade-offs, or results implications.
Image Perturbation
------------------

Class: MCRISEGrid
---------------
.. autoclass:: xaitk_saliency.impls.perturb_image.mc_rise.MCRISEGrid
:members:

Class: RandomGrid
-----------------
.. autoclass:: xaitk_saliency.impls.perturb_image.random_grid.RandomGrid
Expand Down Expand Up @@ -43,6 +48,11 @@ Class: DRISEScoring
.. autoclass:: xaitk_saliency.impls.gen_detector_prop_sal.drise_scoring.DRISEScoring
:members:

Class: MCRISEScoring
------------------
.. autoclass:: xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring.MCRISEScoring
:members:

Class: OcclusionScoring
-----------------------
.. autoclass:: xaitk_saliency.impls.gen_classifier_conf_sal.occlusion_scoring.OcclusionScoring
Expand Down Expand Up @@ -70,6 +80,11 @@ End-to-End Saliency Generation
Image Classification
--------------------

Class: MCRISEStack
~~~~~~~~~~~~~~~~
.. autoclass:: xaitk_saliency.impls.gen_image_classifier_blackbox_sal.mc_rise.MCRISEStack
:members:

Class: PerturbationOcclusion
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: xaitk_saliency.impls.gen_image_classifier_blackbox_sal.occlusion_based.PerturbationOcclusion
Expand Down
2 changes: 2 additions & 0 deletions docs/release_notes/pending_release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Code

* Updated hundreds of typing and linting needs.

* Added MC-RISE classification saliency algorithm.

Documentation

* Removed a deprecated badge from the README.
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ sal-on-coco-dets= "xaitk_saliency.utils.bin.sal_on_coco_dets:sal_on_coco_dets"
[tool.poetry.plugins."smqtk_plugins"]
# Add implementation sub-module exposure here.
"impls.perturb_image.sliding_window" = "xaitk_saliency.impls.perturb_image.sliding_window"
"impls.perturb_image.mc_rise" = "xaitk_saliency.impls.perturb_image.mc_rise"
"impls.perturb_image.rise" = "xaitk_saliency.impls.perturb_image.rise"
"impls.perturb_image.random_grid" = "xaitk_saliency.impls.perturb_image.random_grid"
"impls.perturb_image.sliding_radial" = "xaitk_saliency.impls.perturb_image.sliding_radial"
"impls.gen_classifier_conf_sal.occlusion_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.occlusion_scoring"
"impls.gen_classifier_conf_sal.mc_rise_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring"
"impls.gen_classifier_conf_sal.rise_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.rise_scoring"
"impls.gen_descriptor_sim_sal.similarity_scoring" = "xaitk_saliency.impls.gen_descriptor_sim_sal.similarity_scoring"
"impls.gen_image_classifier_blackbox_sal.occlusion_based" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.occlusion_based"
"impls.gen_image_classifier_blackbox_sal.mc_rise" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.mc_rise"
"impls.gen_image_classifier_blackbox_sal.rise" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise"
"impls.gen_detector_prop_sal.drise_scoring" = "xaitk_saliency.impls.gen_detector_prop_sal.drise_scoring"
"impls.gen_object_detector_blackbox_sal.drise" = "xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise"
Expand Down
124 changes: 124 additions & 0 deletions src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Implementation of MC-RISE scorer"""

from typing import Any

import numpy as np
from sklearn.preprocessing import maxabs_scale
from typing_extensions import override

from xaitk_saliency import GenerateClassifierConfidenceSaliency
from xaitk_saliency.utils.masking import weight_regions_by_scalar


class MCRISEScoring(GenerateClassifierConfidenceSaliency):
"""
Saliency map generation based on the MC-RISE implementation.
This version utilizes only the input perturbed image confidence predictions
and does not utilize reference image confidences.
This implementation also takes influence from debiased RISE and may take an
optional debias probability, `p1` (0 by default).
In the original paper this is paired with the same probability used in RISE
perturbation mask generation (see the `p1` parameter in
:class:`xaitk_saliency.impls.perturb_image.mc_rise.MCRISEGrid`).
Based on Hatakeyama et. al:
https://openaccess.thecvf.com/content/ACCV2020/papers/Hatakeyama_Visualizing_Color-wise_Saliency_of_Black-Box_Image_Classification_Models_ACCV_2020_paper.pdf
"""

def __init__(
self,
k: int,
p1: float = 0.0,
) -> None:
"""
:param k: int
Number of colors to used during perturbation.
:param p1: float
Debias probability, typically paired with the same probability used in mask generation.
:raises: ValueError
If p1 not in in [0, 1].
:raises: ValueError
If k < 1.
"""
if p1 < 0 or p1 > 1:
raise ValueError(f"Input p1 value of {p1} is not within the expected [0,1] range.")
self.p1 = p1

if k < 1:
raise ValueError(f"Input k value of {k} is not within the expected >0 range.")
self.k = k

@override
def generate(
self,
image_conf: np.ndarray,
perturbed_conf: np.ndarray,
perturbed_masks: np.ndarray,
) -> np.ndarray:
"""
Warning: this implementation returns a different shape than typically expected by
this interface. Instead of `[nClasses x H x W]`, saliency maps of shape
`[kColors x nClasses x H x W]` are generated, one per color per class.
:param image_conf: np.ndarray
Reference image predicted class-confidence vector, as a
`numpy.ndarray`, for all classes that require saliency map
generation.
This should have a shape `[nClasses]`, be float-typed and with
values in the [0,1] range.
:param perturbed_conf: np.ndarray
Perturbed image predicted class confidence matrix.
Classes represented in this matrix should be congruent to classes
represented in the `image_conf` vector.
This should have a shape `[nMasks x nClasses]`, be float-typed and
with values in the [0,1] range.
:param perturbed_masks: np.ndarray
Perturbation masks `numpy.ndarray` over the reference image.
This should be parallel in association to the classification
results input into the `perturbed_conf` parameter.
This should have a shape `[kColors x nMasks x H x W]`, and values in range
[0, 1], where a value closer to 1 indicate areas of the image that
are *unperturbed*.
:return: np.ndarray
Generated visual saliency heatmap for each input class as a
float-type `numpy.ndarray` of shape `[kColors x nClasses x H x W]`.
:raises: ValueError
If number of perturbations masks and respective confidence lengths do not match.
"""
if len(perturbed_conf) != perturbed_masks.shape[1]:
raise ValueError("Number of perturbation masks and respective confidence lengths do not match.")

sal_maps = []
# Compute unmasked regions
perturbed_masks = 1 - perturbed_masks
m0 = 1 - perturbed_masks.sum(axis=0)
for k_masks in perturbed_masks:
# Debias based on the MC-RISE paper
sal = weight_regions_by_scalar(
scalar_vec=perturbed_conf,
# Factoring out denominator from self.k * k_masks / self.p1 - m0 / (1 - self.p1)
# to avoid divide by zero. Only acts as a normalization factor
masks=self.k * k_masks * (1 - self.p1) - m0 * self.p1,
inv_masks=False,
normalize=False,
)

# Normalize final saliency map
sal = maxabs_scale(sal.reshape(sal.shape[0], -1), axis=1).reshape(sal.shape)

# Ensure saliency map in range [-1, 1]
sal = np.clip(sal, -1, 1)

sal_maps.append(sal)

return np.asarray(sal_maps)

@override
def get_config(self) -> dict[str, Any]:
return {
"p1": self.p1,
"k": self.k,
}
175 changes: 175 additions & 0 deletions src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Implementation of MC-RISE saliency stack"""

import itertools
from collections.abc import Generator, Iterable, Sequence
from typing import Any

import numpy as np
from smqtk_classifier import ClassifyImage
from smqtk_descriptors.utils import parallel_map
from typing_extensions import override

from xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring import MCRISEScoring
from xaitk_saliency.impls.perturb_image.mc_rise import MCRISEGrid
from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import GenerateImageClassifierBlackboxSaliency


class MCRISEStack(GenerateImageClassifierBlackboxSaliency):
"""
Encapsulation of the perturbation-occlusion method using specifically the
MC-RISE implementations of the component algorithms.
This more specifically encapsulates the MC-RISE method as presented
in their paper and code. See references in the :class:`MCRISEGrid`
and :class:`MCRISEScoring` documentation.
This implementation shares the `p1` probability and 'k' number colors
with the internal `MCRISEScoring` instance use, to make use of the
debiasing described in the MC-RISE paper. Debiasing is always on.
"""

def __init__(
self,
n: int,
s: int,
p1: float,
fill_colors: Sequence[Sequence[int]],
seed: int = None,
threads: int = 0,
) -> None:
"""
:param n: int
Number of random masks used in the algorithm. E.g. 1000.
:param s: int
Spatial resolution of the small masking grid. E.g. 8.
Assumes square grid.
:param p1: float
Probability of the grid cell being set to 1 (otherwise 0).
This should be a float value in the [0, 1] range. E.g. 0.5.
:param fill_colors: Sequence[Sequence[int]]
The fill colors to be used when generating masks.
:param seed: Optional[int]
A seed to pass into the constructed random number generator to allow
for reproducibility
:param threads: int
The number of threads to utilize when generating masks.
If this is <=0 or None, no threading is used and processing
is performed in-line serially.
:raises: ValueError
If no fill colors are provided.
:raises: ValueError
If provided fill colors have differing numbers of channels.
"""
if len(fill_colors) == 0:
raise ValueError("At least one fill color must be provided")

Check warning on line 65 in src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py

View check run for this annotation

Codecov / codecov/patch

src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py#L65

Added line #L65 was not covered by tests
for fill_color in fill_colors:
if len(fill_color) != len(fill_colors[0]):
raise ValueError("All fill colors must have the same number of channels")

Check warning on line 68 in src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py

View check run for this annotation

Codecov / codecov/patch

src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py#L68

Added line #L68 was not covered by tests

self._perturber = MCRISEGrid(n=n, s=s, p1=p1, k=len(fill_colors), seed=seed, threads=threads)
self._generator = MCRISEScoring(k=len(fill_colors), p1=p1)
self._threads = threads
self._fill_colors = fill_colors

@staticmethod
def _work_func(*, ref_image: np.ndarray, i_: int, m: np.ndarray, f: np.ndarray) -> np.ndarray:
s: tuple = (...,)
if ref_image.ndim > 2:
s = (..., None) # add channel axis for multiplication

# Just the [H x W] component.
img_shape = ref_image.shape[:2]

m_shape = m.shape
if m_shape != img_shape:
raise ValueError(

Check warning on line 86 in src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py

View check run for this annotation

Codecov / codecov/patch

src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py#L86

Added line #L86 was not covered by tests
f"Input mask (position {i_}) did not the shape of the input image: {m_shape} != {img_shape}",
)
img_m = np.empty_like(ref_image)

np.add((m[s] * ref_image), f, out=img_m, casting="unsafe")

return img_m

@staticmethod
def _occlude_image_streaming(
*,
ref_image: np.ndarray,
masks: Iterable[np.ndarray],
fill_values: Iterable[np.ndarray],
threads: int = None,
) -> Generator[np.ndarray, None, None]:
if threads is None or threads < 1:
for i, (mask, fill) in enumerate(zip(masks, fill_values)):
yield MCRISEStack._work_func(ref_image=ref_image, i_=i, m=mask, f=fill)
else:
yield from parallel_map(

Check warning on line 107 in src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py

View check run for this annotation

Codecov / codecov/patch

src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py#L107

Added line #L107 was not covered by tests
MCRISEStack._work_func,
ref_image=ref_image,
i_=itertools.count(),
m=masks,
f=fill_values,
cores=threads,
use_multiprocessing=False,
)

@override
def _generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray:
"""
Warning: this implementation returns a different shape than is typically expected by this interface.
Instead of returning `[nClasses x H x W]`, `[kColors x nClasses x H x W] saliency maps will be returned.
:param ref_image: np.ndarray
Reference image over which visual saliency heatmaps will be generated.
:param blackbox: ClassifyImage
The black-box classifier handle to perform arbitrary operations on in order to deduce visual saliency.
:return: np.ndarray
A number of visual saliency heatmaps equivalent in number to the quantity of class labels output
by the black-box classifier and configured number of colors.
"""
k_perturbation_masks = self._perturber(ref_image)

# Collapse k individual colored masks into n multi-color masks
perturbation_masks = 1 - np.sum(1 - k_perturbation_masks, axis=0)
fill_values = np.zeros((*k_perturbation_masks.shape[1:], len(self._fill_colors[0])))
for mask_idx, k_masks in enumerate(np.swapaxes(k_perturbation_masks, 0, 1)):
k_masks = 1 - np.repeat(k_masks[..., np.newaxis], len(self._fill_colors[0]), axis=3)
for fill_idx, fill_value in enumerate(self._fill_colors):
fill_values[mask_idx] += k_masks[fill_idx] * fill_value
fill_values = np.clip(fill_values, 0, 255)

class_list = blackbox.get_labels()
# Input one thing so assume output of one thing.
ref_conf_dict = list(blackbox.classify_images([ref_image]))[0]
ref_conf_vec = np.asarray([ref_conf_dict[la] for la in class_list])
pert_conf_mat = np.empty((perturbation_masks.shape[0], ref_conf_vec.shape[0]), dtype=ref_conf_vec.dtype)

pert_conf_it = blackbox.classify_images(
MCRISEStack._occlude_image_streaming(
ref_image=ref_image,
masks=perturbation_masks,
fill_values=fill_values,
threads=self._threads,
),
)
for i, pc in enumerate(pert_conf_it):
pert_conf_mat[i] = [pc[la] for la in class_list]

# Compose classification results into a matrix for the generator
# algorithm.
return self._generator(
image_conf=ref_conf_vec,
perturbed_conf=pert_conf_mat,
perturbed_masks=k_perturbation_masks,
)

@override
def get_config(self) -> dict[str, Any]:
# It turns out that our configuration here is nearly equivalent to that given
# and retrieved from the MCRISEGrid implementation
c = self._perturber.get_config()
del c["k"]
c["fill_colors"] = self._fill_colors
return c
Loading

0 comments on commit 8441df4

Please sign in to comment.