diff --git a/cyto_dl/eval.py b/cyto_dl/eval.py index 8020578e9..1fe1e92c2 100644 --- a/cyto_dl/eval.py +++ b/cyto_dl/eval.py @@ -19,7 +19,7 @@ @utils.task_wrapper -def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: +def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]: """Evaluates given checkpoint on a datamodule testset. This method is wrapped in optional @task_wrapper decorator which applies extra utilities @@ -84,11 +84,10 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: log.info("Starting testing!") method = trainer.test if cfg.get("test", False) else trainer.predict - method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path) - + output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path) metric_dict = trainer.callback_metrics - return metric_dict, object_dict + return metric_dict, object_dict, output @hydra.main( diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 355febc06..12a2d5f94 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -163,7 +163,7 @@ def _get_run_heads(self, batch, stage): if stage not in ("test", "predict"): run_heads = [key for key in self.task_heads.keys() if key in batch] else: - run_heads = self.task_heads.keys() + run_heads = list(self.task_heads.keys()) return run_heads def _extract_loss(self, outs, loss_type): @@ -216,5 +216,5 @@ def predict_step(self, batch, batch_idx): batch[k] = v.as_tensor() stage = "predict" run_heads = self._get_run_heads(batch, stage) - self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - return (None, None, None) + outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) + return outs[run_heads[0]]["save_path"] diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index 71122011d..88ff6ddc5 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -77,7 +77,7 @@ def __init__( self.backbone = backbone self.task_heads = torch.nn.ModuleDict(task_heads) - self.inference_heads = inference_heads or self.task_heads.keys() + self.inference_heads = inference_heads or list(self.task_heads.keys()) for k, head in self.task_heads.items(): head.update_params({"head_name": k, "x_key": x_key, "save_dir": save_dir}) @@ -192,5 +192,4 @@ def predict_step(self, batch, batch_idx): batch[k] = v.as_tensor() run_heads = self._get_run_heads(batch, stage) outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()} - return None, preds, None + return outs[run_heads[0]]["save_path"] diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index aa16e08bb..d65e7242c 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -1,11 +1,8 @@ -import math from abc import ABC from pathlib import Path -import numpy as np import torch from aicsimageio.writers import OmeTiffWriter -from monai.networks.blocks import Convolution, UnetOutBlock, UnetResBlock, UpSample from cyto_dl.models.im2im.utils.postprocessing import detach @@ -51,11 +48,13 @@ def _postprocess(self, img, img_type): return [self.postprocess[img_type](img[i]) for i in range(img.shape[0])] def _save(self, fn, img, stage): + out_path = Path(self.save_dir) / f"{stage}_images" / fn OmeTiffWriter().save( - uri=Path(self.save_dir) / f"{stage}_images" / fn, + uri=out_path, data=img.squeeze(), dims_order="STCZYX"[-len(img.shape)], ) + return out_path def _calculate_metric(self, y_hat, y): raise NotImplementedError @@ -70,6 +69,7 @@ def save_image(self, y_hat, batch, stage, global_step): raw_out = self._postprocess(batch[self.x_key], img_type="input") try: metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"] + filename_map = {"input": metadata_filenames, "output": []} metadata_filenames = [ f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames ] @@ -84,13 +84,14 @@ def save_image(self, y_hat, batch, stage, global_step): ) n_save = len(y_hat_out) if stage in ("test", "predict") else 1 for i in range(n_save): - self._save(save_name[i].replace(".tif", "_pred.tif"), y_hat_out[i], stage) + out_path = self._save(save_name[i].replace(".tif", "_pred.tif"), y_hat_out[i], stage) + filename_map["output"].append(out_path) if stage in ("train", "val"): self._save(save_name[i], y_out[i], stage) if self.save_raw: self._save(save_name[i].replace(".tif", "_raw.tif"), raw_out[i], stage) - return y_hat_out, y_out + return y_hat_out, y_out, filename_map def forward(self, x): return self.model(x) @@ -117,11 +118,17 @@ def run_head( if stage != "predict": loss = self._calculate_loss(y_hat, batch[self.head_name]) - y_hat_out, y_out = None, None + y_hat_out, y_out, out_paths = None, None, None if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) + y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) metric = None if self.calculate_metric and stage in ("val", "test"): metric = self._calculate_metric(y_hat, batch[self.head_name]) - return {"loss": loss, "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out} + return { + "loss": loss, + "metric": metric, + "y_hat_out": y_hat_out, + "y_out": y_out, + "save_path": out_paths, + } diff --git a/cyto_dl/nn/head/gan_head.py b/cyto_dl/nn/head/gan_head.py index 4332fa7cf..c241a3b03 100644 --- a/cyto_dl/nn/head/gan_head.py +++ b/cyto_dl/nn/head/gan_head.py @@ -81,9 +81,9 @@ def run_head( ) loss_D, loss_G = self._calculate_loss(y_hat, batch, discriminator) - y_hat_out, y_out = None, None + y_hat_out, y_out, out_paths = None, None, None if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step) + y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step) metric = None if self.calculate_metric and stage in ("val", "test"): @@ -95,4 +95,5 @@ def run_head( "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out, + "save_path": out_paths, }