Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evaluation #121

Merged
merged 8 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_erosion

import sirf.STIR as STIR
from petric import QualityMetrics
Expand All @@ -25,21 +24,6 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'):
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


def plot_metrics(iters: Iterable[int], m: np.ndarray, labels=None, suffix=""):
"""Make 2 subplots of metrics"""
if labels is None:
Expand Down
4 changes: 2 additions & 2 deletions SIRF_data_preparation/plot_BSREM_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petric import OUTDIR, SRCDIR, QualityMetrics, get_data
from SIRF_data_preparation import data_QC
from SIRF_data_preparation.dataset_settings import get_settings
from SIRF_data_preparation.evaluation_utilities import get_metrics, pass_index, plot_metrics, read_objectives
from SIRF_data_preparation.evaluation_utilities import get_metrics, plot_metrics, read_objectives

if not all((SRCDIR.is_dir(), OUTDIR.is_dir())):
PETRICDIR = Path('~/devel/PETRIC').expanduser()
Expand Down Expand Up @@ -112,7 +112,7 @@
fig.savefig(outdir / f'{scanID}_metrics_BSREM.png')

# %%
idx = pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10)
idx = QualityMetrics.pass_index(m, numpy.array([.01, .01] + [.005 for i in range(len(data.voi_masks))]), 10)
iter = iters[idx]
print(iter)
image = STIR.ImageData(str(datadir / f"iter_{iter:04d}.hv"))
Expand Down
69 changes: 54 additions & 15 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import csv
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path, PurePath
from time import time
from typing import Iterable

import numpy as np
from scipy.ndimage import binary_erosion
from skimage.metrics import mean_squared_error as mse
from tensorboardX import SummaryWriter

Expand All @@ -43,7 +46,7 @@ class Callback(cil_callbacks.Callback):
CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
TODO: backport this class to CIL.
"""
def __init__(self, interval: int = 3, **kwargs):
def __init__(self, interval: int = 1, **kwargs):
super().__init__(**kwargs)
self.interval = interval

Expand Down Expand Up @@ -86,19 +89,20 @@ def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=Non
def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
t = self._time_
log.debug("logging iter %d...", algo.iteration)
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
if log.getEffectiveLevel() <= logging.DEBUG:
self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
self.x_prev = algo.x.clone()
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration,
t)
Expand All @@ -110,21 +114,36 @@ def __call__(self, algo: Algorithm):

class QualityMetrics(ImageQualityCallback, Callback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 3, **kwargs):
THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01}

def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1,
threshold_window: int = 10, **kwargs):
# TODO: drop multiple inheritance once `interval` included in CIL
Callback.__init__(self, interval=interval)
ImageQualityCallback.__init__(self, reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask.as_array())
self.background_indices = np.where(background_mask.as_array())
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()
self.threshold_window = threshold_window
self.threshold_iters = 0

def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
for tag, value in self.evaluate(algo.x).items():
t = self._time_
# log metrics
metrics = self.evaluate(algo.x)
for tag, value in metrics.items():
self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)
# stop if `all(metrics < THRESHOLD)` for `threshold_window` iters
# NB: need to strip suffix from "AEM_VOI" tags
if all(value <= self.THRESHOLD[re.sub("^(AEM_VOI)_.*", r"\1", tag)] for tag, value in metrics.items()):
self.threshold_iters += 1
if self.threshold_iters >= self.threshold_window:
raise StopIteration
else:
self.threshold_iters = 0

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
Expand All @@ -143,10 +162,25 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
def keys(self):
return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]

@staticmethod
def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


class MetricsWithTimeout(cil_callbacks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None,
def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None,
**kwargs):
super().__init__(**kwargs)
self._seconds = seconds
Expand All @@ -158,16 +192,18 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)
def reset(self):
self.offset = 0
self.limit = (now := time()) + self._seconds
self.tb.add_scalar("reset", 0, -1, now) # for relative timing calculation

def __call__(self, algo: Algorithm):
if (now := time()) > self.limit + self.offset:
if (time_excluding_metrics := (now := time()) - self.offset) > self.limit:
log.warning("Timeout reached. Stopping algorithm.")
self.tb.add_scalar("reset", 0, algo.iteration, time_excluding_metrics)
raise StopIteration
for c in self.callbacks:
c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
c._time_ = time_excluding_metrics
c(algo)
self.offset += time() - now

Expand Down Expand Up @@ -292,8 +328,11 @@ def get_image(fname):
from traceback import print_exc

from docopt import docopt
from tqdm.contrib.logging import logging_redirect_tqdm
args = docopt(__doc__)
logging.basicConfig(level=getattr(logging, args["--log"].upper()))
redir = logging_redirect_tqdm()
redir.__enter__()
from main import Submission, submission_callbacks
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
Expand Down