From 900fd49d3c85bc45a112e8ec2eb6ab25f22e0c4a Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:24:42 +0100 Subject: [PATCH 1/8] per-iteration metrics --- petric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/petric.py b/petric.py index 31e2d20..8425671 100755 --- a/petric.py +++ b/petric.py @@ -43,7 +43,7 @@ class Callback(cil_callbacks.Callback): CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`. TODO: backport this class to CIL. """ - def __init__(self, interval: int = 3, **kwargs): + def __init__(self, interval: int = 1, **kwargs): super().__init__(**kwargs) self.interval = interval @@ -110,7 +110,7 @@ def __call__(self, algo: Algorithm): class QualityMetrics(ImageQualityCallback, Callback): """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 3, **kwargs): + def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, **kwargs): # TODO: drop multiple inheritance once `interval` included in CIL Callback.__init__(self, interval=interval) ImageQualityCallback.__init__(self, reference_image, **kwargs) From bc3a25075a3e5f65f7e98b90c7e380e8df55fbb2 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:25:23 +0100 Subject: [PATCH 2/8] better logging <=> tqdm interaction --- petric.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/petric.py b/petric.py index 8425671..b8e0fe9 100755 --- a/petric.py +++ b/petric.py @@ -292,8 +292,11 @@ def get_image(fname): from traceback import print_exc from docopt import docopt + from tqdm.contrib.logging import logging_redirect_tqdm args = docopt(__doc__) logging.basicConfig(level=getattr(logging, args["--log"].upper())) + redir = logging_redirect_tqdm() + redir.__enter__() from main import Submission, submission_callbacks assert issubclass(Submission, Algorithm) for srcdir, outdir, metrics in data_dirs_metrics: From 7b44d2c75dc306ee8a53bf2da97c1cc3d8df6586 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:26:03 +0100 Subject: [PATCH 3/8] metrics: skip objective & normalised_change --- petric.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/petric.py b/petric.py index b8e0fe9..696973b 100755 --- a/petric.py +++ b/petric.py @@ -94,11 +94,12 @@ def __call__(self, algo: Algorithm): self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice self.vmax = algo.x.max() if self.vmax is None else self.vmax - self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) - if self.x_prev is not None: - normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() - self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t) - self.x_prev = algo.x.clone() + if log.getEffectiveLevel() <= logging.DEBUG: + self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t) + if self.x_prev is not None: + normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm() + self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t) + self.x_prev = algo.x.clone() x_arr = algo.x.as_array() self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration, t) From a70c78dd9c1e9821771d93a19b1648743660769b Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:27:00 +0100 Subject: [PATCH 4/8] fix walltime override --- petric.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/petric.py b/petric.py index 696973b..ef963e1 100755 --- a/petric.py +++ b/petric.py @@ -86,7 +86,7 @@ def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=Non def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return - t = getattr(self, '__time', None) or time() + t = self._time_ log.debug("logging iter %d...", algo.iteration) # initialise `None` values self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice @@ -123,7 +123,7 @@ def __init__(self, reference_image, whole_object_mask, background_mask, interval def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return - t = getattr(self, '__time', None) or time() + t = self._time_ for tag, value in self.evaluate(algo.x).items(): self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t) @@ -159,16 +159,16 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter self.reset() - def reset(self, seconds=None): - self.limit = time() + (self._seconds if seconds is None else seconds) + def reset(self): + self.limit = time() + self._seconds self.offset = 0 def __call__(self, algo: Algorithm): - if (now := time()) > self.limit + self.offset: + if (time_excluding_metrics := (now := time()) - self.offset) > self.limit: log.warning("Timeout reached. Stopping algorithm.") raise StopIteration for c in self.callbacks: - c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks + c._time_ = time_excluding_metrics c(algo) self.offset += time() - now From f543f9c1c17bc64692a8889eac752878052deed4 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:41:08 +0100 Subject: [PATCH 5/8] 1h timeout --- petric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petric.py b/petric.py index ef963e1..8226558 100755 --- a/petric.py +++ b/petric.py @@ -147,7 +147,7 @@ def keys(self): class MetricsWithTimeout(cil_callbacks.Callback): """Stops the algorithm after `seconds`""" - def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, + def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, **kwargs): super().__init__(**kwargs) self._seconds = seconds From ff9ea8db97eb12feda3a804ba980fab9cca776fe Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:56:56 +0100 Subject: [PATCH 6/8] early stopping on meeting threshold --- SIRF_data_preparation/evaluation_utilities.py | 2 +- petric.py | 20 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 20a371e..41f87d4 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -25,7 +25,7 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'): list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int: +def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: """ Returns first index of `metrics` with value <= `thresh`. The values must remain below the respective thresholds for at least `window` number of entries. diff --git a/petric.py b/petric.py index 8226558..a3d4d65 100755 --- a/petric.py +++ b/petric.py @@ -17,6 +17,7 @@ import csv import logging import os +import re from dataclasses import dataclass from pathlib import Path, PurePath from time import time @@ -111,7 +112,10 @@ def __call__(self, algo: Algorithm): class QualityMetrics(ImageQualityCallback, Callback): """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, **kwargs): + THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01} + + def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, + threshold_window: int = 10, **kwargs): # TODO: drop multiple inheritance once `interval` included in CIL Callback.__init__(self, interval=interval) ImageQualityCallback.__init__(self, reference_image, **kwargs) @@ -119,13 +123,25 @@ def __init__(self, reference_image, whole_object_mask, background_mask, interval self.background_indices = np.where(background_mask.as_array()) self.ref_im_arr = reference_image.as_array() self.norm = self.ref_im_arr[self.background_indices].mean() + self.threshold_window = threshold_window + self.threshold_iters = 0 def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return t = self._time_ - for tag, value in self.evaluate(algo.x).items(): + # log metrics + metrics = self.evaluate(algo.x) + for tag, value in metrics.items(): self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t) + # stop if `all(metrics < THRESHOLD)` for `threshold_window` iters + # NB: need to strip suffix from "AEM_VOI" tags + if all(value <= self.THRESHOLD[re.sub("^(AEM_VOI)_.*", r"\1", tag)] for tag, value in metrics.items()): + self.threshold_iters += 1 + if self.threshold_iters >= self.threshold_window: + raise StopIteration + else: + self.threshold_iters = 0 def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: assert not any(self.filter.values()), "Filtering not implemented" From 4f6df44a9613b2254996031c85edc0cafd661ea0 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 22:19:03 +0100 Subject: [PATCH 7/8] metrics: add relative timing calculation helper --- petric.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/petric.py b/petric.py index a3d4d65..8f3af63 100755 --- a/petric.py +++ b/petric.py @@ -176,12 +176,14 @@ def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_s self.reset() def reset(self): - self.limit = time() + self._seconds self.offset = 0 + self.limit = (now := time()) + self._seconds + self.tb.add_scalar("reset", 0, -1, now) # for relative timing calculation def __call__(self, algo: Algorithm): if (time_excluding_metrics := (now := time()) - self.offset) > self.limit: log.warning("Timeout reached. Stopping algorithm.") + self.tb.add_scalar("reset", 0, algo.iteration, time_excluding_metrics) raise StopIteration for c in self.callbacks: c._time_ = time_excluding_metrics From 999c7c8c1bedfa23120007f622ae92b31ecf5b5e Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 23:03:06 +0100 Subject: [PATCH 8/8] move pass_index --- SIRF_data_preparation/evaluation_utilities.py | 16 ---------------- SIRF_data_preparation/plot_BSREM_metrics.py | 4 ++-- petric.py | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 41f87d4..791d629 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.ndimage import binary_erosion import sirf.STIR as STIR from petric import QualityMetrics @@ -25,21 +24,6 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'): list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: - """ - Returns first index of `metrics` with value <= `thresh`. - The values must remain below the respective thresholds for at least `window` number of entries. - Otherwise raises IndexError. - """ - thr_arr = np.asanyarray(thresh) - assert metrics.ndim == 2 - assert thr_arr.ndim == 1 - assert metrics.shape[1] == thr_arr.shape[0] - passed = (metrics <= thr_arr[None]).all(axis=1) - res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) - return np.where(res)[0][0] - - def plot_metrics(iters: Iterable[int], m: np.ndarray, labels=None, suffix=""): """Make 2 subplots of metrics""" if labels is None: diff --git a/SIRF_data_preparation/plot_BSREM_metrics.py b/SIRF_data_preparation/plot_BSREM_metrics.py index 39bb828..55ae8ff 100644 --- a/SIRF_data_preparation/plot_BSREM_metrics.py +++ b/SIRF_data_preparation/plot_BSREM_metrics.py @@ -11,7 +11,7 @@ from petric import OUTDIR, SRCDIR, QualityMetrics, get_data from SIRF_data_preparation import data_QC from SIRF_data_preparation.dataset_settings import get_settings -from SIRF_data_preparation.evaluation_utilities import get_metrics, pass_index, plot_metrics, read_objectives +from SIRF_data_preparation.evaluation_utilities import get_metrics, plot_metrics, read_objectives if not all((SRCDIR.is_dir(), OUTDIR.is_dir())): PETRICDIR = Path('~/devel/PETRIC').expanduser() @@ -112,7 +112,7 @@ fig.savefig(outdir / f'{scanID}_metrics_BSREM.png') # %% -idx = pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10) +idx = QualityMetrics.pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10) iter = iters[idx] print(iter) image = STIR.ImageData(str(datadir / f"iter_{iter:04d}.hv")) diff --git a/petric.py b/petric.py index 8f3af63..5b3cb23 100755 --- a/petric.py +++ b/petric.py @@ -21,8 +21,10 @@ from dataclasses import dataclass from pathlib import Path, PurePath from time import time +from typing import Iterable import numpy as np +from scipy.ndimage import binary_erosion from skimage.metrics import mean_squared_error as mse from tensorboardX import SummaryWriter @@ -160,6 +162,21 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: def keys(self): return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)] + @staticmethod + def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: + """ + Returns first index of `metrics` with value <= `thresh`. + The values must remain below the respective thresholds for at least `window` number of entries. + Otherwise raises IndexError. + """ + thr_arr = np.asanyarray(thresh) + assert metrics.ndim == 2 + assert thr_arr.ndim == 1 + assert metrics.shape[1] == thr_arr.shape[0] + passed = (metrics <= thr_arr[None]).all(axis=1) + res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) + return np.where(res)[0][0] + class MetricsWithTimeout(cil_callbacks.Callback): """Stops the algorithm after `seconds`"""