diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 20a371e..791d629 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.ndimage import binary_erosion import sirf.STIR as STIR from petric import QualityMetrics @@ -25,21 +24,6 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'): list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int: - """ - Returns first index of `metrics` with value <= `thresh`. - The values must remain below the respective thresholds for at least `window` number of entries. - Otherwise raises IndexError. - """ - thr_arr = np.asanyarray(thresh) - assert metrics.ndim == 2 - assert thr_arr.ndim == 1 - assert metrics.shape[1] == thr_arr.shape[0] - passed = (metrics <= thr_arr[None]).all(axis=1) - res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) - return np.where(res)[0][0] - - def plot_metrics(iters: Iterable[int], m: np.ndarray, labels=None, suffix=""): """Make 2 subplots of metrics""" if labels is None: diff --git a/SIRF_data_preparation/plot_BSREM_metrics.py b/SIRF_data_preparation/plot_BSREM_metrics.py index 39bb828..55ae8ff 100644 --- a/SIRF_data_preparation/plot_BSREM_metrics.py +++ b/SIRF_data_preparation/plot_BSREM_metrics.py @@ -11,7 +11,7 @@ from petric import OUTDIR, SRCDIR, QualityMetrics, get_data from SIRF_data_preparation import data_QC from SIRF_data_preparation.dataset_settings import get_settings -from SIRF_data_preparation.evaluation_utilities import get_metrics, pass_index, plot_metrics, read_objectives +from SIRF_data_preparation.evaluation_utilities import get_metrics, plot_metrics, read_objectives if not all((SRCDIR.is_dir(), OUTDIR.is_dir())): PETRICDIR = Path('~/devel/PETRIC').expanduser() @@ -112,7 +112,7 @@ fig.savefig(outdir / f'{scanID}_metrics_BSREM.png') # %% -idx = pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10) +idx = QualityMetrics.pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10) iter = iters[idx] print(iter) image = STIR.ImageData(str(datadir / f"iter_{iter:04d}.hv")) diff --git a/petric.py b/petric.py index 31e2d20..5b3cb23 100755 --- a/petric.py +++ b/petric.py @@ -17,11 +17,14 @@ import csv import logging import os +import re from dataclasses import dataclass from pathlib import Path, PurePath from time import time +from typing import Iterable import numpy as np +from scipy.ndimage import binary_erosion from skimage.metrics import mean_squared_error as mse from tensorboardX import SummaryWriter @@ -43,7 +46,7 @@ class Callback(cil_callbacks.Callback): CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`. TODO: backport this class to CIL. """ - def __init__(self, interval: int = 3, **kwargs): + def __init__(self, interval: int = 1, **kwargs): super().__init__(**kwargs) self.interval = interval @@ -86,7 +89,7 @@ def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=Non def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return - t = getattr(self, '__time', None) or time() + t = self._time_ log.debug("logging iter %d...", algo.iteration) # initialise `None` values self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice @@ -94,11 +97,12 @@ def __call__(self, algo: Algorithm): self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice self.vmax = algo.x.max() if self.vmax is None else self.vmax - self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) - if self.x_prev is not None: - normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() - self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t) - self.x_prev = algo.x.clone() + if log.getEffectiveLevel() <= logging.DEBUG: + self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) + if self.x_prev is not None: + normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() + self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t) + self.x_prev = algo.x.clone() x_arr = algo.x.as_array() self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration, t) @@ -110,7 +114,10 @@ def __call__(self, algo: Algorithm): class QualityMetrics(ImageQualityCallback, Callback): """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 3, **kwargs): + THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01} + + def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, + threshold_window: int = 10, **kwargs): # TODO: drop multiple inheritance once `interval` included in CIL Callback.__init__(self, interval=interval) ImageQualityCallback.__init__(self, reference_image, **kwargs) @@ -118,13 +125,25 @@ def __init__(self, reference_image, whole_object_mask, background_mask, interval self.background_indices = np.where(background_mask.as_array()) self.ref_im_arr = reference_image.as_array() self.norm = self.ref_im_arr[self.background_indices].mean() + self.threshold_window = threshold_window + self.threshold_iters = 0 def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return - t = getattr(self, '__time', None) or time() - for tag, value in self.evaluate(algo.x).items(): + t = self._time_ + # log metrics + metrics = self.evaluate(algo.x) + for tag, value in metrics.items(): self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t) + # stop if `all(metrics < THRESHOLD)` for `threshold_window` iters + # NB: need to strip suffix from "AEM_VOI" tags + if all(value <= self.THRESHOLD[re.sub("^(AEM_VOI)_.*", r"\1", tag)] for tag, value in metrics.items()): + self.threshold_iters += 1 + if self.threshold_iters >= self.threshold_window: + raise StopIteration + else: + self.threshold_iters = 0 def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: assert not any(self.filter.values()), "Filtering not implemented" @@ -143,10 +162,25 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: def keys(self): return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)] + @staticmethod + def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: + """ + Returns first index of `metrics` with value <= `thresh`. + The values must remain below the respective thresholds for at least `window` number of entries. + Otherwise raises IndexError. + """ + thr_arr = np.asanyarray(thresh) + assert metrics.ndim == 2 + assert thr_arr.ndim == 1 + assert metrics.shape[1] == thr_arr.shape[0] + passed = (metrics <= thr_arr[None]).all(axis=1) + res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) + return np.where(res)[0][0] + class MetricsWithTimeout(cil_callbacks.Callback): """Stops the algorithm after `seconds`""" - def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, + def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, **kwargs): super().__init__(**kwargs) self._seconds = seconds @@ -158,16 +192,18 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter self.reset() - def reset(self, seconds=None): - self.limit = time() + (self._seconds if seconds is None else seconds) + def reset(self): self.offset = 0 + self.limit = (now := time()) + self._seconds + self.tb.add_scalar("reset", 0, -1, now) # for relative timing calculation def __call__(self, algo: Algorithm): - if (now := time()) > self.limit + self.offset: + if (time_excluding_metrics := (now := time()) - self.offset) > self.limit: log.warning("Timeout reached. Stopping algorithm.") + self.tb.add_scalar("reset", 0, algo.iteration, time_excluding_metrics) raise StopIteration for c in self.callbacks: - c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks + c._time_ = time_excluding_metrics c(algo) self.offset += time() - now @@ -292,8 +328,11 @@ def get_image(fname): from traceback import print_exc from docopt import docopt + from tqdm.contrib.logging import logging_redirect_tqdm args = docopt(__doc__) logging.basicConfig(level=getattr(logging, args["--log"].upper())) + redir = logging_redirect_tqdm() + redir.__enter__() from main import Submission, submission_callbacks assert issubclass(Submission, Algorithm) for srcdir, outdir, metrics in data_dirs_metrics: