diff --git a/petric.py b/petric.py index debf617..5f51f83 100755 --- a/petric.py +++ b/petric.py @@ -27,6 +27,7 @@ from scipy.ndimage import binary_erosion from skimage.metrics import mean_squared_error as mse from tensorboardX import SummaryWriter +from tqdm.auto import tqdm import sirf.STIR as STIR from cil.optimisation.algorithms import Algorithm @@ -157,7 +158,8 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: f"AEM_VOI_{voi_name}": np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / self.norm for voi_name, voi_indices in sorted(self.voi_indices.items())} - return {**whole, **local} + self._evaluate_cache = {**whole, **local} + return self._evaluate_cache def keys(self): return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)] @@ -181,11 +183,11 @@ def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: class MetricsWithTimeout(Callback): """Stops the algorithm after `seconds`""" def __init__(self, seconds=3600, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, sagittal_slice=None, - **kwargs): + tqdm_class=tqdm, **kwargs): super().__init__(**kwargs) self._seconds = seconds self.callbacks = [ - cil_callbacks.ProgressCallback(), + cil_callbacks.ProgressCallback(desc=f"{TEAM}/{VERSION}/{outdir.name}", tqdm_class=tqdm_class), SaveIters(outdir=outdir, **kwargs), (tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice, sagittal_slice=sagittal_slice, **kwargs))] @@ -205,6 +207,10 @@ def __call__(self, algo: Algorithm): for c in self.callbacks: c._time_ = time_excluding_metrics c(algo) + if isinstance(self.callbacks[-1], QualityMetrics) and isinstance(self.callbacks[0], + cil_callbacks.ProgressCallback): + self.callbacks[0].pbar.set_postfix( + RMSE_whole_object=self.callbacks[-1]._evaluate_cache['RMSE_whole_object'], refresh=False) self.offset += time() - now @staticmethod