From acb34566dae24c3f0d14938d0ffc1b4aa9ecb633 Mon Sep 17 00:00:00 2001 From: Emily Veenhuis Date: Mon, 25 Nov 2024 17:09:42 -0600 Subject: [PATCH 1/2] Add MC-RISE implementations --- docs/implementations.rst | 15 ++ docs/release_notes/pending_release.rst | 2 + pyproject.toml | 3 + .../mc_rise_scoring.py | 124 +++++++++++++ .../mc_rise.py | 175 ++++++++++++++++++ .../impls/perturb_image/mc_rise.py | 158 ++++++++++++++++ .../gen_image_classifier_blackbox_sal.py | 17 +- .../__snapshots__/test_mc_rise_scoring.ambr | 12 ++ .../test_mc_rise_scoring.py | 57 ++++++ .../__snapshots__/test_mc_rise.ambr | 46 +++++ .../test_mc_rise.py | 80 ++++++++ .../__snapshots__/test_mc_rise.ambr | 77 ++++++++ tests/impls/perturb_image/test_mc_rise.py | 128 +++++++++++++ tests/test_utils.py | 49 +++++ 14 files changed, 941 insertions(+), 2 deletions(-) create mode 100644 src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py create mode 100644 src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py create mode 100644 src/xaitk_saliency/impls/perturb_image/mc_rise.py create mode 100644 tests/impls/gen_classifier_conf_sal/__snapshots__/test_mc_rise_scoring.ambr create mode 100644 tests/impls/gen_classifier_conf_sal/test_mc_rise_scoring.py create mode 100644 tests/impls/gen_image_classifier_blackbox_sal/__snapshots__/test_mc_rise.ambr create mode 100644 tests/impls/gen_image_classifier_blackbox_sal/test_mc_rise.py create mode 100644 tests/impls/perturb_image/__snapshots__/test_mc_rise.ambr create mode 100644 tests/impls/perturb_image/test_mc_rise.py create mode 100644 tests/test_utils.py diff --git a/docs/implementations.rst b/docs/implementations.rst index 29a2d3a2..c58f2348 100644 --- a/docs/implementations.rst +++ b/docs/implementations.rst @@ -14,6 +14,11 @@ trade-offs, or results implications. Image Perturbation ------------------ +Class: MCRISEGrid +----------------- +.. 
autoclass:: xaitk_saliency.impls.perturb_image.mc_rise.MCRISEGrid + :members: + Class: RandomGrid ----------------- .. autoclass:: xaitk_saliency.impls.perturb_image.random_grid.RandomGrid @@ -43,6 +48,11 @@ Class: DRISEScoring .. autoclass:: xaitk_saliency.impls.gen_detector_prop_sal.drise_scoring.DRISEScoring :members: +Class: MCRISEScoring +-------------------- +.. autoclass:: xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring.MCRISEScoring + :members: + Class: OcclusionScoring ----------------------- .. autoclass:: xaitk_saliency.impls.gen_classifier_conf_sal.occlusion_scoring.OcclusionScoring @@ -70,6 +80,11 @@ End-to-End Saliency Generation Image Classification -------------------- +Class: MCRISEStack +~~~~~~~~~~~~~~~~~~ +.. autoclass:: xaitk_saliency.impls.gen_image_classifier_blackbox_sal.mc_rise.MCRISEStack + :members: + Class: PerturbationOcclusion ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: xaitk_saliency.impls.gen_image_classifier_blackbox_sal.occlusion_based.PerturbationOcclusion diff --git a/docs/release_notes/pending_release.rst b/docs/release_notes/pending_release.rst index 25fb4d4d..ef0f40bb 100644 --- a/docs/release_notes/pending_release.rst +++ b/docs/release_notes/pending_release.rst @@ -8,6 +8,8 @@ Code * Updated hundreds of typing and linting needs. +* Added MC-RISE classification saliency algorithm. + Documentation * Removed a deprecated badge from the README. diff --git a/pyproject.toml b/pyproject.toml index 86acd9d3..f7c8bdcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,13 +79,16 @@ sal-on-coco-dets= "xaitk_saliency.utils.bin.sal_on_coco_dets:sal_on_coco_dets" [tool.poetry.plugins."smqtk_plugins"] # Add implementation sub-module exposure here. 
"impls.perturb_image.sliding_window" = "xaitk_saliency.impls.perturb_image.sliding_window" +"impls.perturb_image.mc_rise" = "xaitk_saliency.impls.perturb_image.mc_rise" "impls.perturb_image.rise" = "xaitk_saliency.impls.perturb_image.rise" "impls.perturb_image.random_grid" = "xaitk_saliency.impls.perturb_image.random_grid" "impls.perturb_image.sliding_radial" = "xaitk_saliency.impls.perturb_image.sliding_radial" "impls.gen_classifier_conf_sal.occlusion_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.occlusion_scoring" +"impls.gen_classifier_conf_sal.mc_rise_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring" "impls.gen_classifier_conf_sal.rise_scoring" = "xaitk_saliency.impls.gen_classifier_conf_sal.rise_scoring" "impls.gen_descriptor_sim_sal.similarity_scoring" = "xaitk_saliency.impls.gen_descriptor_sim_sal.similarity_scoring" "impls.gen_image_classifier_blackbox_sal.occlusion_based" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.occlusion_based" +"impls.gen_image_classifier_blackbox_sal.mc_rise" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.mc_rise" "impls.gen_image_classifier_blackbox_sal.rise" = "xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise" "impls.gen_detector_prop_sal.drise_scoring" = "xaitk_saliency.impls.gen_detector_prop_sal.drise_scoring" "impls.gen_object_detector_blackbox_sal.drise" = "xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise" diff --git a/src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py b/src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py new file mode 100644 index 00000000..87e812e1 --- /dev/null +++ b/src/xaitk_saliency/impls/gen_classifier_conf_sal/mc_rise_scoring.py @@ -0,0 +1,124 @@ +"""Implementation of MC-RISE scorer""" + +from typing import Any + +import numpy as np +from sklearn.preprocessing import maxabs_scale +from typing_extensions import override + +from xaitk_saliency import GenerateClassifierConfidenceSaliency 
+from xaitk_saliency.utils.masking import weight_regions_by_scalar + + +class MCRISEScoring(GenerateClassifierConfidenceSaliency): + """ + Saliency map generation based on the MC-RISE implementation. + This version utilizes only the input perturbed image confidence predictions + and does not utilize reference image confidences. + This implementation also takes influence from debiased RISE and may take an + optional debias probability, `p1` (0 by default). + In the original paper this is paired with the same probability used in RISE + perturbation mask generation (see the `p1` parameter in + :class:`xaitk_saliency.impls.perturb_image.mc_rise.MCRISEGrid`). + + Based on Hatakeyama et. al: + https://openaccess.thecvf.com/content/ACCV2020/papers/Hatakeyama_Visualizing_Color-wise_Saliency_of_Black-Box_Image_Classification_Models_ACCV_2020_paper.pdf + """ + + def __init__( + self, + k: int, + p1: float = 0.0, + ) -> None: + """ + :param k: int + Number of colors to use during perturbation. + :param p1: float + Debias probability, typically paired with the same probability used in mask generation. + + :raises: ValueError + If p1 not in [0, 1]. + :raises: ValueError + If k < 1. + """ + if p1 < 0 or p1 > 1: + raise ValueError(f"Input p1 value of {p1} is not within the expected [0,1] range.") + self.p1 = p1 + + if k < 1: + raise ValueError(f"Input k value of {k} is not within the expected >0 range.") + self.k = k + + @override + def generate( + self, + image_conf: np.ndarray, + perturbed_conf: np.ndarray, + perturbed_masks: np.ndarray, + ) -> np.ndarray: + """ + Warning: this implementation returns a different shape than typically expected by + this interface. Instead of `[nClasses x H x W]`, saliency maps of shape + `[kColors x nClasses x H x W]` are generated, one per color per class. + + :param image_conf: np.ndarray + Reference image predicted class-confidence vector, as a + `numpy.ndarray`, for all classes that require saliency map + generation. 
+ This should have a shape `[nClasses]`, be float-typed and with + values in the [0,1] range. + :param perturbed_conf: np.ndarray + Perturbed image predicted class confidence matrix. + Classes represented in this matrix should be congruent to classes + represented in the `image_conf` vector. + This should have a shape `[nMasks x nClasses]`, be float-typed and + with values in the [0,1] range. + :param perturbed_masks: np.ndarray + Perturbation masks `numpy.ndarray` over the reference image. + This should be parallel in association to the classification + results input into the `perturbed_conf` parameter. + This should have a shape `[kColors x nMasks x H x W]`, and values in range + [0, 1], where a value closer to 1 indicate areas of the image that + are *unperturbed*. + + :return: np.ndarray + Generated visual saliency heatmap for each input class as a + float-type `numpy.ndarray` of shape `[kColors x nClasses x H x W]`. + + :raises: ValueError + If number of perturbations masks and respective confidence lengths do not match. + """ + if len(perturbed_conf) != perturbed_masks.shape[1]: + raise ValueError("Number of perturbation masks and respective confidence lengths do not match.") + + sal_maps = [] + # Compute unmasked regions + perturbed_masks = 1 - perturbed_masks + m0 = 1 - perturbed_masks.sum(axis=0) + for k_masks in perturbed_masks: + # Debias based on the MC-RISE paper + sal = weight_regions_by_scalar( + scalar_vec=perturbed_conf, + # Factoring out denominator from self.k * k_masks / self.p1 - m0 / (1 - self.p1) + # to avoid divide by zero. 
Only acts as a normalization factor + masks=self.k * k_masks * (1 - self.p1) - m0 * self.p1, + inv_masks=False, + normalize=False, + ) + + # Normalize final saliency map + sal = maxabs_scale(sal.reshape(sal.shape[0], -1), axis=1).reshape(sal.shape) + + # Ensure saliency map in range [-1, 1] + sal = np.clip(sal, -1, 1) + + sal_maps.append(sal) + + return np.asarray(sal_maps) + + @override + def get_config(self) -> dict[str, Any]: + return { + "p1": self.p1, + "k": self.k, + } diff --git a/src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py b/src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py new file mode 100644 index 00000000..e351a6c8 --- /dev/null +++ b/src/xaitk_saliency/impls/gen_image_classifier_blackbox_sal/mc_rise.py @@ -0,0 +1,175 @@ +"""Implementation of MC-RISE saliency stack""" + +import itertools +from collections.abc import Generator, Iterable, Sequence +from typing import Any + +import numpy as np +from smqtk_classifier import ClassifyImage +from smqtk_descriptors.utils import parallel_map +from typing_extensions import override + +from xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring import MCRISEScoring +from xaitk_saliency.impls.perturb_image.mc_rise import MCRISEGrid +from xaitk_saliency.interfaces.gen_image_classifier_blackbox_sal import GenerateImageClassifierBlackboxSaliency + + +class MCRISEStack(GenerateImageClassifierBlackboxSaliency): + """ + Encapsulation of the perturbation-occlusion method using specifically the + MC-RISE implementations of the component algorithms. + + This more specifically encapsulates the MC-RISE method as presented + in their paper and code. See references in the :class:`MCRISEGrid` + and :class:`MCRISEScoring` documentation. + + This implementation shares the `p1` probability and 'k' number colors + with the internal `MCRISEScoring` instance use, to make use of the + debiasing described in the MC-RISE paper. Debiasing is always on. 
+ """ + + def __init__( + self, + n: int, + s: int, + p1: float, + fill_colors: Sequence[Sequence[int]], + seed: int = None, + threads: int = 0, + ) -> None: + """ + :param n: int + Number of random masks used in the algorithm. E.g. 1000. + :param s: int + Spatial resolution of the small masking grid. E.g. 8. + Assumes square grid. + :param p1: float + Probability of the grid cell being set to 1 (otherwise 0). + This should be a float value in the [0, 1] range. E.g. 0.5. + :param fill_colors: Sequence[Sequence[int]] + The fill colors to be used when generating masks. + :param seed: Optional[int] + A seed to pass into the constructed random number generator to allow + for reproducibility + :param threads: int + The number of threads to utilize when generating masks. + If this is <=0 or None, no threading is used and processing + is performed in-line serially. + + :raises: ValueError + If no fill colors are provided. + :raises: ValueError + If provided fill colors have differing numbers of channels. + """ + if len(fill_colors) == 0: + raise ValueError("At least one fill color must be provided") + for fill_color in fill_colors: + if len(fill_color) != len(fill_colors[0]): + raise ValueError("All fill colors must have the same number of channels") + + self._perturber = MCRISEGrid(n=n, s=s, p1=p1, k=len(fill_colors), seed=seed, threads=threads) + self._generator = MCRISEScoring(k=len(fill_colors), p1=p1) + self._threads = threads + self._fill_colors = fill_colors + + @staticmethod + def _work_func(*, ref_image: np.ndarray, i_: int, m: np.ndarray, f: np.ndarray) -> np.ndarray: + s: tuple = (...,) + if ref_image.ndim > 2: + s = (..., None) # add channel axis for multiplication + + # Just the [H x W] component. 
+ img_shape = ref_image.shape[:2] + + m_shape = m.shape + if m_shape != img_shape: + raise ValueError( + f"Input mask (position {i_}) did not the shape of the input image: {m_shape} != {img_shape}", + ) + img_m = np.empty_like(ref_image) + + np.add((m[s] * ref_image), f, out=img_m, casting="unsafe") + + return img_m + + @staticmethod + def _occlude_image_streaming( + *, + ref_image: np.ndarray, + masks: Iterable[np.ndarray], + fill_values: Iterable[np.ndarray], + threads: int = None, + ) -> Generator[np.ndarray, None, None]: + if threads is None or threads < 1: + for i, (mask, fill) in enumerate(zip(masks, fill_values)): + yield MCRISEStack._work_func(ref_image=ref_image, i_=i, m=mask, f=fill) + else: + yield from parallel_map( + MCRISEStack._work_func, + ref_image=ref_image, + i_=itertools.count(), + m=masks, + f=fill_values, + cores=threads, + use_multiprocessing=False, + ) + + @override + def _generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray: + """ + Warning: this implementation returns a different shape than is typically expected by this interface. + Instead of returning `[nClasses x H x W]`, `[kColors x nClasses x H x W]` saliency maps will be returned. + + :param ref_image: np.ndarray + Reference image over which visual saliency heatmaps will be generated. + :param blackbox: ClassifyImage + The black-box classifier handle to perform arbitrary operations on in order to deduce visual saliency. + + :return: np.ndarray + A number of visual saliency heatmaps equivalent in number to the quantity of class labels output + by the black-box classifier and configured number of colors. 
+ """ + k_perturbation_masks = self._perturber(ref_image) + + # Collapse k individual colored masks into n multi-color masks + perturbation_masks = 1 - np.sum(1 - k_perturbation_masks, axis=0) + fill_values = np.zeros((*k_perturbation_masks.shape[1:], len(self._fill_colors[0]))) + for mask_idx, k_masks in enumerate(np.swapaxes(k_perturbation_masks, 0, 1)): + k_masks = 1 - np.repeat(k_masks[..., np.newaxis], len(self._fill_colors[0]), axis=3) + for fill_idx, fill_value in enumerate(self._fill_colors): + fill_values[mask_idx] += k_masks[fill_idx] * fill_value + fill_values = np.clip(fill_values, 0, 255) + + class_list = blackbox.get_labels() + # Input one thing so assume output of one thing. + ref_conf_dict = list(blackbox.classify_images([ref_image]))[0] + ref_conf_vec = np.asarray([ref_conf_dict[la] for la in class_list]) + pert_conf_mat = np.empty((perturbation_masks.shape[0], ref_conf_vec.shape[0]), dtype=ref_conf_vec.dtype) + + pert_conf_it = blackbox.classify_images( + MCRISEStack._occlude_image_streaming( + ref_image=ref_image, + masks=perturbation_masks, + fill_values=fill_values, + threads=self._threads, + ), + ) + for i, pc in enumerate(pert_conf_it): + pert_conf_mat[i] = [pc[la] for la in class_list] + + # Compose classification results into a matrix for the generator + # algorithm. 
+ return self._generator( + image_conf=ref_conf_vec, + perturbed_conf=pert_conf_mat, + perturbed_masks=k_perturbation_masks, + ) + + @override + def get_config(self) -> dict[str, Any]: + # It turns out that our configuration here is nearly equivalent to that given + # and retrieved from the MCRISEGrid implementation + c = self._perturber.get_config() + del c["k"] + c["fill_colors"] = self._fill_colors + return c diff --git a/src/xaitk_saliency/impls/perturb_image/mc_rise.py b/src/xaitk_saliency/impls/perturb_image/mc_rise.py new file mode 100644 index 00000000..f63e4cdf --- /dev/null +++ b/src/xaitk_saliency/impls/perturb_image/mc_rise.py @@ -0,0 +1,158 @@ +"""Implementation of MC-RISE perturbation mask generation""" + +from typing import Any, Optional +from typing_extensions import override + +import numpy as np +from skimage.transform import resize +from smqtk_descriptors.utils import parallel_map + +from xaitk_saliency.interfaces.perturb_image import PerturbImage + + +class MCRISEGrid(PerturbImage): + """ + Based on Hatakeyama et. al: + https://openaccess.thecvf.com/content/ACCV2020/papers/Hatakeyama_Visualizing_Color-wise_Saliency_of_Black-Box_Image_Classification_Models_ACCV_2020_paper.pdf + """ + + def __init__( + self, + n: int, + s: int, + p1: float, + k: int, + seed: int = None, + threads: Optional[int] = 4, + ) -> None: + """ + Generate a set of random binary masks + + :param n: int + Number of random masks used in the algorithm. E.g. 1000. + :param s: int + Spatial resolution of the small masking grid. E.g. 8. + Assumes square grid. + :param p1: float + Probability of the grid cell being set to 1 (otherwise 0). + This should be a float value in the [0, 1] range. E.g. 0.5. + :param k: int + Number of colors to use. + :param seed: Optional[int] + A seed to pass into the constructed random number generator to allow + for reproducibility + :param threads: int + The number of threads to utilize when generating masks. 
+ If this is <=0 or None, no threading is used and processing + is performed in-line serially. + + :raises: ValueError + If p1 not in [0, 1]. + :raises: ValueError + If k < 1. + """ + if p1 < 0 or p1 > 1: + raise ValueError(f"Input p1 value of {p1} is not within the expected [0,1] range.") + + if k <= 0: + raise ValueError(f"Input k value of {k} is not within the expected >0 range.") + + self.n = n + self.s = s + self.p1 = p1 + self.k = k + self.seed = seed + self.threads = threads + + # Generate a set of random grids of small resolution + rng = np.random.default_rng(seed) + grid: np.ndarray = rng.random((n, s, s)) < p1 + + indiv_color_masks = np.empty((self.k, self.n, self.s, self.s)) + for g_idx, g in enumerate(grid): + g_shape = g.shape + # Randomly choose fill colors for each pixel + fill_mask = rng.integers(0, k, size=g_shape) + for fill_idx in range(self.k): + # Keep the masked pixels where the current idx (color) was selected + indiv_color_mask = np.where(g == False, np.where(fill_mask == fill_idx, False, True), True) + indiv_color_masks[fill_idx][g_idx] = indiv_color_mask + + self.grid = indiv_color_masks + self.grid.astype("float32") + + @override + def perturb(self, ref_image: np.ndarray) -> np.ndarray: + """ + Warning: this implementation returns a different shape than typically expected by this interface. + Instead of `[nMasks x Height x Width]`, masks of shape `[kColors x nMasks x Height x Width]` + are returned. + + :param ref_image: np.ndarray + Reference image to generate perturbations from. + + :return: np.ndarray + Mask matrix with shape `[kColors x nMasks x Height x Width]`. 
+ """ + input_size = np.shape(ref_image)[:2] + num_colors = self.k + num_masks = self.n + grid = self.grid + s = self.s + shift_rng = np.random.default_rng(self.seed) + # Shape format: [H x W], Inherits from `input_size` + cell_size = np.ceil(np.array(input_size) / s) + up_size = (s + 1) * cell_size + + masks = np.empty((num_colors, num_masks, *input_size), dtype=grid.dtype) + + # Expanding index accesses for repetition efficiency. + cell_h, cell_w = cell_size[:2] + input_h, input_w = input_size[:2] + + # flake8: noqa + def work_func(i_: int) -> np.ndarray: + # Random shifts + y = shift_rng.integers(0, cell_h) + x = shift_rng.integers(0, cell_w) + k_masks = np.empty((num_colors, input_h, input_w)) + for k_ in range(num_colors): + k_masks[k_] = resize( + grid[k_, i_, ...], + up_size, + order=1, + mode="reflect", + anti_aliasing=False, + )[ + y : y + input_h, + x : x + input_w, + ] + return k_masks + + threads = self.threads + if threads is None or threads < 1: + for i in range(num_masks): + masks[:, i, ...] = work_func(i) + else: + for i, m in enumerate( + parallel_map( + work_func, + range(num_masks), + cores=threads, + use_multiprocessing=False, + ), + ): + masks[:, i, ...] = m + + return masks + + @override + def get_config(self) -> dict[str, Any]: + return { + "n": self.n, + "s": self.s, + "p1": self.p1, + "k": self.k, + "seed": self.seed, + "threads": self.threads, + } diff --git a/src/xaitk_saliency/interfaces/gen_image_classifier_blackbox_sal.py b/src/xaitk_saliency/interfaces/gen_image_classifier_blackbox_sal.py index 30e768a2..222a04c1 100644 --- a/src/xaitk_saliency/interfaces/gen_image_classifier_blackbox_sal.py +++ b/src/xaitk_saliency/interfaces/gen_image_classifier_blackbox_sal.py @@ -1,3 +1,16 @@ +""" +Module for generating visual saliency heatmaps from image classifiers. + +This module provides the `GenerateImageClassifierBlackboxSaliency` class, which is used to generate per-class visual +saliency heatmaps for an image classifier black-box. 
The saliency maps indicate which regions of the image are most +influential in the classifier's decision-making process for each class label. + +The `GenerateImageClassifierBlackboxSaliency` class takes a reference image and a classifier that implements the +`smqtk_classifier.ClassifyImage` interface, and outputs saliency heatmaps for each predicted class. These heatmaps are +used to interpret the classifier's behavior, by showing how different parts of the image contribute to the classifier's +confidence in each class. +""" + import abc import numpy as np @@ -64,11 +77,11 @@ def generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray raise ValueError(f"Input image matrix has an unexpected number of dimensions: {ref_image.ndim}") output = self._generate(ref_image, blackbox) # Check that the saliency heatmaps' shape matches the reference image. - if output.shape[1:] != ref_image.shape[:2]: + if output.shape[-2:] != ref_image.shape[:2]: raise ShapeMismatchError( f"Output saliency heatmaps did not have matching height and " f"width shape components: " - f"(ref) {ref_image.shape[:2]} != {output.shape[1:]} (output)", + f"(ref) {ref_image.shape[:2]} != {output.shape[-2:]} (output)", ) return output diff --git a/tests/impls/gen_classifier_conf_sal/__snapshots__/test_mc_rise_scoring.ambr b/tests/impls/gen_classifier_conf_sal/__snapshots__/test_mc_rise_scoring.ambr new file mode 100644 index 00000000..57278e6d --- /dev/null +++ b/tests/impls/gen_classifier_conf_sal/__snapshots__/test_mc_rise_scoring.ambr @@ -0,0 +1,12 @@ +# serializer version: 1 +# name: TestMCRiseScoring.test_2class_scoring + array([[[[0. , 0. , 0.33, 0.33, 0.66, 0.66], + [0. , 0. , 0.33, 0.33, 0.66, 0.66], + [0.33, 0.33, 0.66, 0.66, 1. , 1. ], + [0.33, 0.33, 0.66, 0.66, 1. , 1. ]], + + [[1. , 1. , 0.66, 0.66, 0.33, 0.33], + [1. , 1. , 0.66, 0.66, 0.33, 0.33], + [0.66, 0.66, 0.33, 0.33, 0. , 0. ], + [0.66, 0.66, 0.33, 0.33, 0. , 0. 
]]]]) +# --- diff --git a/tests/impls/gen_classifier_conf_sal/test_mc_rise_scoring.py b/tests/impls/gen_classifier_conf_sal/test_mc_rise_scoring.py new file mode 100644 index 00000000..ba287251 --- /dev/null +++ b/tests/impls/gen_classifier_conf_sal/test_mc_rise_scoring.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest +from smqtk_core.configuration import configuration_test_helper +from syrupy.assertion import SnapshotAssertion + +from tests import EXPECTED_MASKS_4x6 +from tests.test_utils import CustomFloatSnapshotExtension +from xaitk_saliency.impls.gen_classifier_conf_sal.mc_rise_scoring import MCRISEScoring + + +@pytest.fixture +def snapshot_custom(snapshot: SnapshotAssertion) -> SnapshotAssertion: + return snapshot.use_extension(lambda: CustomFloatSnapshotExtension()) + + +class TestMCRiseScoring: + def test_init_outofrange_config(self) -> None: + """Test catching an out of range config value.""" + with pytest.raises(ValueError, match=r"Input p1 value of -0\.3 is not within the expected \[0,1\] range\."): + MCRISEScoring(k=4, p1=-0.3) + + with pytest.raises(ValueError, match=r"Input p1 value of 5 is not within the expected \[0,1\] range\."): + MCRISEScoring(k=1, p1=5) + + with pytest.raises(ValueError, match=r"Input k value of 0 is not within the expected >0 range\."): + MCRISEScoring(k=0, p1=0.2) + + def test_configuration(self) -> None: + """Test configuration aspects.""" + inst = MCRISEScoring(k=2, p1=0.747) + for inst_i in configuration_test_helper(inst): + assert np.allclose(inst_i.p1, 0.747) + assert inst_i.k == 2 + + def test_bad_alignment(self) -> None: + """Test passing misaligned perturbed confidence vector and masks input.""" + test_confs = np.ones((3, 2)) + test_masks = np.ones((3, 4, 3, 3)) + + inst = MCRISEScoring(k=3) + + with pytest.raises( + ValueError, + match=r"Number of perturbation masks and respective confidence lengths do not match", + ): + inst.generate(test_confs[0], test_confs, test_masks) + + def test_2class_scoring(self, 
snapshot_custom: SnapshotAssertion) -> None: + """Test for expected output when given known input data.""" + test_ref_confs = np.array([0, 0.2]) + # Mock classification results for the test masked regions. + test_pert_confs = np.array([[0.00, 0.33, 0.66, 0.33, 0.66, 1.00], [1.00, 0.66, 0.33, 0.66, 0.33, 0.00]]).T + + inst = MCRISEScoring(k=1) + sal = inst.generate(test_ref_confs, test_pert_confs, np.asarray([EXPECTED_MASKS_4x6])) + + snapshot_custom.assert_match(sal) diff --git a/tests/impls/gen_image_classifier_blackbox_sal/__snapshots__/test_mc_rise.ambr b/tests/impls/gen_image_classifier_blackbox_sal/__snapshots__/test_mc_rise.ambr new file mode 100644 index 00000000..98d7652f --- /dev/null +++ b/tests/impls/gen_image_classifier_blackbox_sal/__snapshots__/test_mc_rise.ambr @@ -0,0 +1,46 @@ +# serializer version: 1 +# name: TestMCRise.test_generation_rgb + array([[[[ 0.12224591, 0.06112296, 0.00710732, ..., 0.62295665, + 0.55899076, 0.51847903], + [ 0.19189765, 0.15351812, 0.1108742 , ..., 0.51421464, + 0.40760483, 0.33297797], + [ 0.2594172 , 0.23525231, 0.20113717, ..., 0.45024876, + 0.29246624, 0.17093106], + ..., + [-0.01954513, -0.04371002, -0.10767591, ..., 0.24022743, + 0.20895522, 0.16915423], + [-0.15174129, -0.17022033, -0.20575693, ..., 0.32267235, + 0.33688699, 0.30845771], + [-0.28393746, -0.29673063, -0.30383795, ..., 0.32835821, + 0.35394456, 0.30277186]]], + + + [[[-0.32026144, -0.39787582, -0.31535948, ..., -0.0253268 , + -0.01552288, 0.02124183], + [-0.30882353, -0.37009804, -0.30392157, ..., -0.16503268, + -0.14542484, -0.08905229], + [-0.27287582, -0.31781046, -0.26633987, ..., -0.21160131, + -0.20179739, -0.1503268 ], + ..., + [ 0.28145425, 0.44240196, 0.5625 , ..., -0.38071895, + -0.35784314, -0.3374183 ], + [ 0.12949346, 0.24795752, 0.35171569, ..., -0.37418301, + -0.33823529, -0.31454248], + [-0.02246732, 0.05351307, 0.14093137, ..., -0.38970588, + -0.3504902 , -0.33333333]]], + + + [[[-0.14629714, -0.2217174 , -0.24988642, ..., 
0.15038619, + -0.07314857, -0.17673785], + [-0.07360291, -0.15265788, -0.15538392, ..., 0.06860518, + -0.17673785, -0.25851886], + [ 0.13811904, 0.06633348, 0.10177192, ..., 0.16946842, + -0.13039527, -0.24488869], + ..., + [-0.0781463 , -0.13357565, -0.16901408, ..., -0.21626533, + -0.04906861, 0.11540209], + [-0.03452976, -0.07178555, -0.10358928, ..., -0.28532485, + -0.10358928, 0.06451613], + [ 0.00908678, -0.00999546, -0.03816447, ..., -0.37891867, + -0.19354839, -0.0327124 ]]]]) +# --- diff --git a/tests/impls/gen_image_classifier_blackbox_sal/test_mc_rise.py b/tests/impls/gen_image_classifier_blackbox_sal/test_mc_rise.py new file mode 100644 index 00000000..b96f9eee --- /dev/null +++ b/tests/impls/gen_image_classifier_blackbox_sal/test_mc_rise.py @@ -0,0 +1,80 @@ +from collections.abc import Hashable, Iterator, Sequence + +import numpy as np +import pytest +from smqtk_classifier import ClassifyImage +from smqtk_classifier.interfaces.classification_element import CLASSIFICATION_DICT_T +from smqtk_classifier.interfaces.classify_image import IMAGE_ITER_T +from smqtk_core.configuration import configuration_test_helper +from syrupy import SnapshotAssertion + +from tests.test_utils import CustomFloatSnapshotExtension +from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.mc_rise import ( + MCRISEGrid, + MCRISEScoring, + MCRISEStack, +) + + +@pytest.fixture +def snapshot_custom(snapshot: SnapshotAssertion) -> SnapshotAssertion: + return snapshot.use_extension(lambda: CustomFloatSnapshotExtension()) + + +class TestMCRise: + def test_configuration(self) -> None: + """Test standard config things.""" + fill_colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 255]] + inst = MCRISEStack( + n=444, + s=33, + p1=0.22, + fill_colors=fill_colors, + seed=42, + threads=99, + ) + for inst_i in configuration_test_helper(inst): + inst_p = inst_i._perturber + inst_g = inst_i._generator + assert isinstance(inst_p, MCRISEGrid) + assert isinstance(inst_g, MCRISEScoring) 
+ assert inst_p.n == 444 + assert inst_p.s == 33 + assert np.allclose(inst_p.p1, 0.22) + assert inst_p.k == len(fill_colors) + assert inst_p.seed == 42 + assert inst_p.threads == 99 + assert inst_g.k == len(fill_colors) + assert np.allclose(inst_g.p1, 0.22) + assert inst_i._threads == 99 + assert inst_i._fill_colors == fill_colors + + def test_generation_rgb(self, snapshot_custom: SnapshotAssertion) -> None: + """Test basic generation functionality with dummy image and blackbox""" + + class TestBlackBox(ClassifyImage): + """Dummy blackbox that yields a constant result.""" + + def get_labels(self) -> Sequence[Hashable]: + return [0] + + def classify_images(self, img_iter: IMAGE_ITER_T) -> Iterator[CLASSIFICATION_DICT_T]: + for _ in img_iter: + yield {0: 1.0} + + # Not implemented for stub + get_config = None # type: ignore + + test_image = np.full([32, 32, 3], fill_value=255, dtype=np.uint8) + test_bb = TestBlackBox() + + # The heatmap result of this is merely the sum of RISE mask generation + # normalized in the [-1,1] range as the generation stage does nothing + # given the constant blackbox response. + inst = MCRISEStack(n=5, s=8, p1=0.5, fill_colors=[[255, 255, 255], [255, 0, 0], [0, 0, 255]], seed=0) + # Results may be sensitive to changes in scikit-image. Version 0.19 + # introduces some changes to the resize function. Difference is + # expected to only be marginal. + res = inst.generate(test_image, test_bb) + + snapshot_custom.assert_match(res) diff --git a/tests/impls/perturb_image/__snapshots__/test_mc_rise.ambr b/tests/impls/perturb_image/__snapshots__/test_mc_rise.ambr new file mode 100644 index 00000000..0d244e82 --- /dev/null +++ b/tests/impls/perturb_image/__snapshots__/test_mc_rise.ambr @@ -0,0 +1,77 @@ +# serializer version: 1 +# name: TestMCRISEPerturbation.test_perturb_1channel + array([[[[0.98148148, 0.90740741, 0.83333333, 0.75925926, 0.68518519, + 0.72222222], + [1. , 1. , 1. , 1. , 1. , + 1. 
], + [0.98148148, 0.90740741, 0.83333333, 0.75925926, 0.68518519, + 0.72222222], + [0.96296296, 0.81481481, 0.66666667, 0.51851852, 0.37037037, + 0.44444444]], + + [[0.83333333, 0.94444444, 0.72222222, 0.5 , 0.27777778, + 0.05555556], + [0.61111111, 0.64814815, 0.57407407, 0.5 , 0.42592593, + 0.35185185], + [0.38888889, 0.35185185, 0.42592593, 0.5 , 0.57407407, + 0.64814815], + [0.16666667, 0.05555556, 0.27777778, 0.5 , 0.72222222, + 0.94444444]]], + + + [[[0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333]], + + [[1. , 1. , 1. , 1. , 1. , + 1. ], + [0.94444444, 0.98148148, 0.90740741, 0.83333333, 0.75925926, + 0.68518519], + [0.88888889, 0.96296296, 0.81481481, 0.66666667, 0.51851852, + 0.37037037], + [0.83333333, 0.94444444, 0.72222222, 0.5 , 0.27777778, + 0.05555556]]]]) +# --- +# name: TestMCRISEPerturbation.test_perturb_3channel + array([[[[0.98148148, 0.90740741, 0.83333333, 0.75925926, 0.68518519, + 0.72222222], + [1. , 1. , 1. , 1. , 1. , + 1. ], + [0.98148148, 0.90740741, 0.83333333, 0.75925926, 0.68518519, + 0.72222222], + [0.96296296, 0.81481481, 0.66666667, 0.51851852, 0.37037037, + 0.44444444]], + + [[0.83333333, 0.94444444, 0.72222222, 0.5 , 0.27777778, + 0.05555556], + [0.61111111, 0.64814815, 0.57407407, 0.5 , 0.42592593, + 0.35185185], + [0.38888889, 0.35185185, 0.42592593, 0.5 , 0.57407407, + 0.64814815], + [0.16666667, 0.05555556, 0.27777778, 0.5 , 0.72222222, + 0.94444444]]], + + + [[[0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333], + [0.05555556, 0.27777778, 0.5 , 0.72222222, 0.94444444, + 0.83333333]], + + [[1. , 1. , 1. , 1. , 1. , + 1. 
], + [0.94444444, 0.98148148, 0.90740741, 0.83333333, 0.75925926, + 0.68518519], + [0.88888889, 0.96296296, 0.81481481, 0.66666667, 0.51851852, + 0.37037037], + [0.83333333, 0.94444444, 0.72222222, 0.5 , 0.27777778, + 0.05555556]]]]) +# --- diff --git a/tests/impls/perturb_image/test_mc_rise.py b/tests/impls/perturb_image/test_mc_rise.py new file mode 100644 index 00000000..f4f23778 --- /dev/null +++ b/tests/impls/perturb_image/test_mc_rise.py @@ -0,0 +1,128 @@ +import numpy as np +import pytest +from smqtk_core.configuration import configuration_test_helper +from syrupy.assertion import SnapshotAssertion + +from tests.test_utils import CustomFloatSnapshotExtension +from xaitk_saliency.impls.perturb_image.mc_rise import MCRISEGrid + + +@pytest.fixture +def snapshot_custom(snapshot: SnapshotAssertion) -> SnapshotAssertion: + return snapshot.use_extension(lambda: CustomFloatSnapshotExtension()) + + +class TestMCRISEPerturbation: + def test_init_valued(self) -> None: + """Test that constructor values pass.""" + ex_n = 1000 + ex_s = 8 + ex_p1 = 0.5 + k = 2 + impl = MCRISEGrid(n=ex_n, s=ex_s, p1=ex_p1, k=k) + assert impl.n == ex_n + assert impl.s == ex_s + assert np.allclose(impl.p1, ex_p1) + assert impl.k == k + + def test_init_outofrange_p1(self) -> None: + """Test catching an out of range p1 value.""" + with pytest.raises(ValueError, match=r"Input p1 value of -0\.3 is not within the expected \[0,1\] range\."): + MCRISEGrid(10, 8, p1=-0.3, k=1) + + with pytest.raises(ValueError, match=r"Input p1 value of 5 is not within the expected \[0,1\] range\."): + MCRISEGrid(10, 8, p1=5, k=1) + + with pytest.raises(ValueError, match=r"Input k value of -1 is not within the expected >0 range\."): + MCRISEGrid(10, 8, p1=0.2, k=-1) + + def test_standard_config(self) -> None: + ex_n = 1000 + ex_s = 8 + ex_p1 = 0.5 + k = 1 + impl = MCRISEGrid(n=ex_n, s=ex_s, p1=ex_p1, k=k) + for inst in configuration_test_helper(impl): + assert inst.n == ex_n + assert inst.s == ex_s + assert 
np.allclose(inst.p1, ex_p1) + assert inst.k == k + + def test_if_random(self) -> None: + """Test that the perturbations are randomized""" + impl1 = MCRISEGrid(n=1000, s=8, p1=0.5, k=2) + impl2 = MCRISEGrid(n=1000, s=8, p1=0.5, k=2) + assert not np.array_equal(impl1.grid, impl2.grid) + + def test_seed(self) -> None: + """Test that passing a seed generates equivalent masks""" + impl1 = MCRISEGrid(n=1000, s=8, p1=0.5, k=2, seed=42) + impl2 = MCRISEGrid(n=1000, s=8, p1=0.5, k=2, seed=42) + assert np.array_equal(impl1.grid, impl2.grid) + + def test_perturb_1channel(self, snapshot_custom: SnapshotAssertion) -> None: + """ + Test basic perturbation on a known image with even windowing + stride. + Input image mode should not impact the masks output. + """ + # Image is slightly wide + white_image = np.full((4, 6), fill_value=255, dtype=np.uint8) + + # Setting threads=0 for serialized processing for deterministic + # results. + impl = MCRISEGrid(n=2, s=2, p1=0.5, k=2, seed=42, threads=0) + actual_masks = impl.perturb(white_image) + + snapshot_custom.assert_match(actual_masks) + + def test_call_idempotency(self) -> None: + """ + Test that, at least when seeded and single-threaded, perturbation + generation is idempotent. + """ + # Image is slightly wide + white_image = np.full((4, 6), fill_value=255, dtype=np.uint8) + # Setting threads=0 for serialized processing for deterministic + # results. When greater than 1 idempotency cannot be guaranteed due to + # thread interleaving. + # Also of course seeding otherwise random will do its random things. + impl = MCRISEGrid(n=2, s=2, p1=0.5, k=3, seed=42, threads=0) + + actual_masks1 = impl.perturb(white_image) + actual_masks2 = impl.perturb(white_image) + + assert np.allclose( + actual_masks1, + actual_masks2, + ) + + def test_perturb_3channel(self, snapshot_custom: SnapshotAssertion) -> None: + """ + Test basic perturbation on a known image with even windowing + stride. + Input image mode should not impact the masks output. 
+ """ + # Image is slightly wide + white_image = np.full((4, 6, 3), fill_value=255, dtype=np.uint8) + + # Setting threads=0 for serialized processing for deterministic + # results. + impl = MCRISEGrid(n=2, s=2, p1=0.5, k=2, seed=42, threads=0) + actual_masks = impl.perturb(white_image) + + snapshot_custom.assert_match(actual_masks) + + def test_multiple_image_sizes(self) -> None: + """ + Test that once we initialize a RISEPerturbation we can call it on + images of varying sizes + """ + impl = MCRISEGrid(n=2, s=2, p1=0.5, k=2, seed=42) + white_image_small = np.full((4, 6), fill_value=255, dtype=np.uint8) + white_image_large = np.full((41, 26), fill_value=255, dtype=np.uint8) + masks_small = impl.perturb(white_image_small) + assert len(masks_small) == 2 + assert masks_small.shape[2:] == white_image_small.shape + + masks_large = impl.perturb(white_image_large) + assert len(masks_large) == 2 + assert masks_large.shape[2:] == white_image_large.shape diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..6e538c11 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,49 @@ +import ast +import math +import re + +import numpy as np +from syrupy.extensions.amber import AmberSnapshotExtension + + +class CustomFloatSnapshotExtension(AmberSnapshotExtension): + def parse_snapshot_to_numpy_no_eval(self, snapshot: str) -> tuple[np.ndarray]: + # Remove metadata lines starting with `#` + snapshot = "\n".join(line for line in snapshot.splitlines() if not line.strip().startswith("#")) + + # Extract array strings using regex + array_pattern = r"array\((\[.*?\])\)" + matches = re.findall(array_pattern, snapshot, flags=re.S) + + # Parse each array string into a NumPy array + arrays = [] + for match in matches: + # Replace "..." 
with the repeating last row/column to avoid parsing errors + cleaned_array = match.replace("...,", "") + # Convert the array string into a NumPy array using `np.array` and `eval` + arrays.append( + np.array(ast.literal_eval(cleaned_array)), + ) # Use `eval` only for literals, not the whole snapshot + return tuple(arrays) + + def matches(self, *, serialized_data: str, snapshot_data: str) -> bool: + try: + # Convert serialized and snapshot data to floats and compare within tolerance + a = float(serialized_data) + b = float(snapshot_data) + return math.isclose(a, b, abs_tol=1e-4) + except ValueError: + # If conversion to float fails, fallback to default comparison + pass + try: + # Convert serialized and snapshot data to np arrays and compare within tolerance + a = self.parse_snapshot_to_numpy_no_eval(serialized_data) + b = self.parse_snapshot_to_numpy_no_eval(snapshot_data) + for array_a, array_b in zip(a, b): + if not all( + math.isclose(array_a[index], array_b[index], rel_tol=1e-4) for index in np.ndindex(array_a.shape) + ): + return False + return True + except ValueError: + return serialized_data == snapshot_data From 7681061466e5d91a72ea6eb35e46853c34223d28 Mon Sep 17 00:00:00 2001 From: Brandon RichardWebster Date: Thu, 12 Dec 2024 22:40:38 -0500 Subject: [PATCH 2/2] Updated a regex to be more secure. --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6e538c11..af567929 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,7 +12,7 @@ def parse_snapshot_to_numpy_no_eval(self, snapshot: str) -> tuple[np.ndarray]: snapshot = "\n".join(line for line in snapshot.splitlines() if not line.strip().startswith("#")) # Extract array strings using regex - array_pattern = r"array\((\[.*?\])\)" + array_pattern = r"array\((\[[^\]]*\])\)" matches = re.findall(array_pattern, snapshot, flags=re.S) # Parse each array string into a NumPy array