-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch '1-dev/mc_rise' into 'main'
Implement MC-RISE classifier algorithm Closes #1 See merge request jatic/kitware/xaitk-saliency!6
- Loading branch information
Showing
14 changed files
with
941 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
124 changes: 124 additions & 0 deletions
124
src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
"""Implementation of MC-RISE scorer""" | ||
|
||
from typing import Any | ||
|
||
import numpy as np | ||
from sklearn.preprocessing import maxabs_scale | ||
from typing_extensions import override | ||
|
||
from xaitk_saliency import GenerateClassifierConfidenceSaliency | ||
from xaitk_saliency.utils.masking import weight_regions_by_scalar | ||
|
||
|
||
class MCRISEScoring(GenerateClassifierConfidenceSaliency): | ||
""" | ||
Saliency map generation based on the MC-RISE implementation. | ||
This version utilizes only the input perturbed image confidence predictions | ||
and does not utilize reference image confidences. | ||
This implementation also takes influence from debiased RISE and may take an | ||
optional debias probability, `p1` (0 by default). | ||
In the original paper this is paired with the same probability used in RISE | ||
perturbation mask generation (see the `p1` parameter in | ||
:class:`xaitk_saliency.impls.perturb_image.mc_rise.MCRISEGrid`). | ||
Based on Hatakeyama et. al: | ||
https://openaccess.thecvf.com/content/ACCV2020/papers/Hatakeyama_Visualizing_Color-wise_Saliency_of_Black-Box_Image_Classification_Models_ACCV_2020_paper.pdf | ||
""" | ||
|
||
def __init__( | ||
self, | ||
k: int, | ||
p1: float = 0.0, | ||
) -> None: | ||
""" | ||
:param k: int | ||
Number of colors to used during perturbation. | ||
:param p1: float | ||
Debias probability, typically paired with the same probability used in mask generation. | ||
:raises: ValueError | ||
If p1 not in in [0, 1]. | ||
:raises: ValueError | ||
If k < 1. | ||
""" | ||
if p1 < 0 or p1 > 1: | ||
raise ValueError(f"Input p1 value of {p1} is not within the expected [0,1] range.") | ||
self.p1 = p1 | ||
|
||
if k < 1: | ||
raise ValueError(f"Input k value of {k} is not within the expected >0 range.") | ||
self.k = k | ||
|
||
@override | ||
def generate( | ||
self, | ||
image_conf: np.ndarray, | ||
perturbed_conf: np.ndarray, | ||
perturbed_masks: np.ndarray, | ||
) -> np.ndarray: | ||
""" | ||
Warning: this implementation returns a different shape than typically expected by | ||
this interface. Instead of `[nClasses x H x W]`, saliency maps of shape | ||
`[kColors x nClasses x H x W]` are generated, one per color per class. | ||
:param image_conf: np.ndarray | ||
Reference image predicted class-confidence vector, as a | ||
`numpy.ndarray`, for all classes that require saliency map | ||
generation. | ||
This should have a shape `[nClasses]`, be float-typed and with | ||
values in the [0,1] range. | ||
:param perturbed_conf: np.ndarray | ||
Perturbed image predicted class confidence matrix. | ||
Classes represented in this matrix should be congruent to classes | ||
represented in the `image_conf` vector. | ||
This should have a shape `[nMasks x nClasses]`, be float-typed and | ||
with values in the [0,1] range. | ||
:param perturbed_masks: np.ndarray | ||
Perturbation masks `numpy.ndarray` over the reference image. | ||
This should be parallel in association to the classification | ||
results input into the `perturbed_conf` parameter. | ||
This should have a shape `[kColors x nMasks x H x W]`, and values in range | ||
[0, 1], where a value closer to 1 indicate areas of the image that | ||
are *unperturbed*. | ||
:return: np.ndarray | ||
Generated visual saliency heatmap for each input class as a | ||
float-type `numpy.ndarray` of shape `[kColors x nClasses x H x W]`. | ||
:raises: ValueError | ||
If number of perturbations masks and respective confidence lengths do not match. | ||
""" | ||
if len(perturbed_conf) != perturbed_masks.shape[1]: | ||
raise ValueError("Number of perturbation masks and respective confidence lengths do not match.") | ||
|
||
sal_maps = [] | ||
# Compute unmasked regions | ||
perturbed_masks = 1 - perturbed_masks | ||
m0 = 1 - perturbed_masks.sum(axis=0) | ||
for k_masks in perturbed_masks: | ||
# Debias based on the MC-RISE paper | ||
sal = weight_regions_by_scalar( | ||
scalar_vec=perturbed_conf, | ||
# Factoring out denominator from self.k * k_masks / self.p1 - m0 / (1 - self.p1) | ||
# to avoid divide by zero. Only acts as a normalization factor | ||
masks=self.k * k_masks * (1 - self.p1) - m0 * self.p1, | ||
inv_masks=False, | ||
normalize=False, | ||
) | ||
|
||
# Normalize final saliency map | ||
sal = maxabs_scale(sal.reshape(sal.shape[0], -1), axis=1).reshape(sal.shape) | ||
|
||
# Ensure saliency map in range [-1, 1] | ||
sal = np.clip(sal, -1, 1) | ||
|
||
sal_maps.append(sal) | ||
|
||
return np.asarray(sal_maps) | ||
|
||
@override | ||
def get_config(self) -> dict[str, Any]: | ||
return { | ||
"p1": self.p1, | ||
"k": self.k, | ||
} |
175 changes: 175 additions & 0 deletions
175
src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
"""Implementation of MC-RISE saliency stack""" | ||
|
||
import itertools | ||
from collections.abc import Generator, Iterable, Sequence | ||
from typing import Any | ||
|
||
import numpy as np | ||
from smqtk_classifier import ClassifyImage | ||
from smqtk_descriptors.utils import parallel_map | ||
from typing_extensions import override | ||
|
||
from xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring import MCRISEScoring | ||
from xaitk_saliency.impls.perturb_image.mc_rise import MCRISEGrid | ||
from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import GenerateImageClassifierBlackboxSaliency | ||
|
||
|
||
class MCRISEStack(GenerateImageClassifierBlackboxSaliency): | ||
""" | ||
Encapsulation of the perturbation-occlusion method using specifically the | ||
MC-RISE implementations of the component algorithms. | ||
This more specifically encapsulates the MC-RISE method as presented | ||
in their paper and code. See references in the :class:`MCRISEGrid` | ||
and :class:`MCRISEScoring` documentation. | ||
This implementation shares the `p1` probability and 'k' number colors | ||
with the internal `MCRISEScoring` instance use, to make use of the | ||
debiasing described in the MC-RISE paper. Debiasing is always on. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n: int, | ||
s: int, | ||
p1: float, | ||
fill_colors: Sequence[Sequence[int]], | ||
seed: int = None, | ||
threads: int = 0, | ||
) -> None: | ||
""" | ||
:param n: int | ||
Number of random masks used in the algorithm. E.g. 1000. | ||
:param s: int | ||
Spatial resolution of the small masking grid. E.g. 8. | ||
Assumes square grid. | ||
:param p1: float | ||
Probability of the grid cell being set to 1 (otherwise 0). | ||
This should be a float value in the [0, 1] range. E.g. 0.5. | ||
:param fill_colors: Sequence[Sequence[int]] | ||
The fill colors to be used when generating masks. | ||
:param seed: Optional[int] | ||
A seed to pass into the constructed random number generator to allow | ||
for reproducibility | ||
:param threads: int | ||
The number of threads to utilize when generating masks. | ||
If this is <=0 or None, no threading is used and processing | ||
is performed in-line serially. | ||
:raises: ValueError | ||
If no fill colors are provided. | ||
:raises: ValueError | ||
If provided fill colors have differing numbers of channels. | ||
""" | ||
if len(fill_colors) == 0: | ||
raise ValueError("At least one fill color must be provided") | ||
for fill_color in fill_colors: | ||
if len(fill_color) != len(fill_colors[0]): | ||
raise ValueError("All fill colors must have the same number of channels") | ||
|
||
self._perturber = MCRISEGrid(n=n, s=s, p1=p1, k=len(fill_colors), seed=seed, threads=threads) | ||
self._generator = MCRISEScoring(k=len(fill_colors), p1=p1) | ||
self._threads = threads | ||
self._fill_colors = fill_colors | ||
|
||
@staticmethod | ||
def _work_func(*, ref_image: np.ndarray, i_: int, m: np.ndarray, f: np.ndarray) -> np.ndarray: | ||
s: tuple = (...,) | ||
if ref_image.ndim > 2: | ||
s = (..., None) # add channel axis for multiplication | ||
|
||
# Just the [H x W] component. | ||
img_shape = ref_image.shape[:2] | ||
|
||
m_shape = m.shape | ||
if m_shape != img_shape: | ||
raise ValueError( | ||
f"Input mask (position {i_}) did not the shape of the input image: {m_shape} != {img_shape}", | ||
) | ||
img_m = np.empty_like(ref_image) | ||
|
||
np.add((m[s] * ref_image), f, out=img_m, casting="unsafe") | ||
|
||
return img_m | ||
|
||
@staticmethod | ||
def _occlude_image_streaming( | ||
*, | ||
ref_image: np.ndarray, | ||
masks: Iterable[np.ndarray], | ||
fill_values: Iterable[np.ndarray], | ||
threads: int = None, | ||
) -> Generator[np.ndarray, None, None]: | ||
if threads is None or threads < 1: | ||
for i, (mask, fill) in enumerate(zip(masks, fill_values)): | ||
yield MCRISEStack._work_func(ref_image=ref_image, i_=i, m=mask, f=fill) | ||
else: | ||
yield from parallel_map( | ||
MCRISEStack._work_func, | ||
ref_image=ref_image, | ||
i_=itertools.count(), | ||
m=masks, | ||
f=fill_values, | ||
cores=threads, | ||
use_multiprocessing=False, | ||
) | ||
|
||
@override | ||
def _generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray: | ||
""" | ||
Warning: this implementation returns a different shape than is typically expected by this interface. | ||
Instead of returning `[nClasses x H x W]`, `[kColors x nClasses x H x W] saliency maps will be returned. | ||
:param ref_image: np.ndarray | ||
Reference image over which visual saliency heatmaps will be generated. | ||
:param blackbox: ClassifyImage | ||
The black-box classifier handle to perform arbitrary operations on in order to deduce visual saliency. | ||
:return: np.ndarray | ||
A number of visual saliency heatmaps equivalent in number to the quantity of class labels output | ||
by the black-box classifier and configured number of colors. | ||
""" | ||
k_perturbation_masks = self._perturber(ref_image) | ||
|
||
# Collapse k individual colored masks into n multi-color masks | ||
perturbation_masks = 1 - np.sum(1 - k_perturbation_masks, axis=0) | ||
fill_values = np.zeros((*k_perturbation_masks.shape[1:], len(self._fill_colors[0]))) | ||
for mask_idx, k_masks in enumerate(np.swapaxes(k_perturbation_masks, 0, 1)): | ||
k_masks = 1 - np.repeat(k_masks[..., np.newaxis], len(self._fill_colors[0]), axis=3) | ||
for fill_idx, fill_value in enumerate(self._fill_colors): | ||
fill_values[mask_idx] += k_masks[fill_idx] * fill_value | ||
fill_values = np.clip(fill_values, 0, 255) | ||
|
||
class_list = blackbox.get_labels() | ||
# Input one thing so assume output of one thing. | ||
ref_conf_dict = list(blackbox.classify_images([ref_image]))[0] | ||
ref_conf_vec = np.asarray([ref_conf_dict[la] for la in class_list]) | ||
pert_conf_mat = np.empty((perturbation_masks.shape[0], ref_conf_vec.shape[0]), dtype=ref_conf_vec.dtype) | ||
|
||
pert_conf_it = blackbox.classify_images( | ||
MCRISEStack._occlude_image_streaming( | ||
ref_image=ref_image, | ||
masks=perturbation_masks, | ||
fill_values=fill_values, | ||
threads=self._threads, | ||
), | ||
) | ||
for i, pc in enumerate(pert_conf_it): | ||
pert_conf_mat[i] = [pc[la] for la in class_list] | ||
|
||
# Compose classification results into a matrix for the generator | ||
# algorithm. | ||
return self._generator( | ||
image_conf=ref_conf_vec, | ||
perturbed_conf=pert_conf_mat, | ||
perturbed_masks=k_perturbation_masks, | ||
) | ||
|
||
@override | ||
def get_config(self) -> dict[str, Any]: | ||
# It turns out that our configuration here is nearly equivalent to that given | ||
# and retrieved from the MCRISEGrid implementation | ||
c = self._perturber.get_config() | ||
del c["k"] | ||
c["fill_colors"] = self._fill_colors | ||
return c |
Oops, something went wrong.