Skip to content

Commit

Permalink
Merge pull request #121 from SyneRBI/evaluation
Browse files Browse the repository at this point in the history
evaluation
  • Loading branch information
casperdcl authored Oct 2, 2024
2 parents 3f706cc + 999c7c8 commit 8f25b7f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 33 deletions.
16 changes: 0 additions & 16 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_erosion

import sirf.STIR as STIR
from petric import QualityMetrics
Expand All @@ -25,21 +24,6 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'):
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


def plot_metrics(iters: Iterable[int], m: np.ndarray, labels=None, suffix=""):
"""Make 2 subplots of metrics"""
if labels is None:
Expand Down
4 changes: 2 additions & 2 deletions SIRF_data_preparation/plot_BSREM_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petric import OUTDIR, SRCDIR, QualityMetrics, get_data
from SIRF_data_preparation import data_QC
from SIRF_data_preparation.dataset_settings import get_settings
from SIRF_data_preparation.evaluation_utilities import get_metrics, pass_index, plot_metrics, read_objectives
from SIRF_data_preparation.evaluation_utilities import get_metrics, plot_metrics, read_objectives

if not all((SRCDIR.is_dir(), OUTDIR.is_dir())):
PETRICDIR = Path('~/devel/PETRIC').expanduser()
Expand Down Expand Up @@ -112,7 +112,7 @@
fig.savefig(outdir / f'{scanID}_metrics_BSREM.png')

# %%
idx = pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10)
idx = QualityMetrics.pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10)
iter = iters[idx]
print(iter)
image = STIR.ImageData(str(datadir / f"iter_{iter:04d}.hv"))
Expand Down
69 changes: 54 additions & 15 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import csv
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path, PurePath
from time import time
from typing import Iterable

import numpy as np
from scipy.ndimage import binary_erosion
from skimage.metrics import mean_squared_error as mse
from tensorboardX import SummaryWriter

Expand All @@ -43,7 +46,7 @@ class Callback(cil_callbacks.Callback):
CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
TODO: backport this class to CIL.
"""
def __init__(self, interval: int = 3, **kwargs):
def __init__(self, interval: int = 1, **kwargs):
super().__init__(**kwargs)
self.interval = interval

Expand Down Expand Up @@ -86,19 +89,20 @@ def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=Non
def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
t = self._time_
log.debug("logging iter %d...", algo.iteration)
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
if log.getEffectiveLevel() <= logging.DEBUG:
self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration,
t)
Expand All @@ -110,21 +114,36 @@ def __call__(self, algo: Algorithm):

class QualityMetrics(ImageQualityCallback, Callback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 3, **kwargs):
THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01}

def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1,
threshold_window: int = 10, **kwargs):
# TODO: drop multiple inheritance once `interval` included in CIL
Callback.__init__(self, interval=interval)
ImageQualityCallback.__init__(self, reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask.as_array())
self.background_indices = np.where(background_mask.as_array())
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()
self.threshold_window = threshold_window
self.threshold_iters = 0

def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
for tag, value in self.evaluate(algo.x).items():
t = self._time_
# log metrics
metrics = self.evaluate(algo.x)
for tag, value in metrics.items():
self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)
# stop if `all(metrics < THRESHOLD)` for `threshold_window` iters
# NB: need to strip suffix from "AEM_VOI" tags
if all(value <= self.THRESHOLD[re.sub("^(AEM_VOI)_.*", r"\1", tag)] for tag, value in metrics.items()):
self.threshold_iters += 1
if self.threshold_iters >= self.threshold_window:
raise StopIteration
else:
self.threshold_iters = 0

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
Expand All @@ -143,10 +162,25 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
def keys(self):
return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]

@staticmethod
def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


class MetricsWithTimeout(cil_callbacks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None,
def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None,
**kwargs):
super().__init__(**kwargs)
self._seconds = seconds
Expand All @@ -158,16 +192,18 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)
def reset(self):
self.offset = 0
self.limit = (now := time()) + self._seconds
self.tb.add_scalar("reset", 0, -1, now) # for relative timing calculation

def __call__(self, algo: Algorithm):
if (now := time()) > self.limit + self.offset:
if (time_excluding_metrics := (now := time()) - self.offset) > self.limit:
log.warning("Timeout reached. Stopping algorithm.")
self.tb.add_scalar("reset", 0, algo.iteration, time_excluding_metrics)
raise StopIteration
for c in self.callbacks:
c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
c._time_ = time_excluding_metrics
c(algo)
self.offset += time() - now

Expand Down Expand Up @@ -292,8 +328,11 @@ def get_image(fname):
from traceback import print_exc

from docopt import docopt
from tqdm.contrib.logging import logging_redirect_tqdm
args = docopt(__doc__)
logging.basicConfig(level=getattr(logging, args["--log"].upper()))
redir = logging_redirect_tqdm()
redir.__enter__()
from main import Submission, submission_callbacks
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
Expand Down

0 comments on commit 8f25b7f

Please sign in to comment.