Skip to content

Commit

Permalink
return correspondence between input and generated files (#292)
Browse files Browse the repository at this point in the history
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Sep 25, 2023
1 parent b643714 commit 6ca7278
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 21 deletions.
7 changes: 3 additions & 4 deletions cyto_dl/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
def evaluate(cfg: DictConfig) -> Tuple[dict, dict, dict]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in optional @task_wrapper decorator which applies extra utilities
Expand Down Expand Up @@ -84,11 +84,10 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:

log.info("Starting testing!")
method = trainer.test if cfg.get("test", False) else trainer.predict
method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path)

output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path)
metric_dict = trainer.callback_metrics

return metric_dict, object_dict
return metric_dict, object_dict, output


@hydra.main(
Expand Down
6 changes: 3 additions & 3 deletions cyto_dl/models/im2im/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _get_run_heads(self, batch, stage):
if stage not in ("test", "predict"):
run_heads = [key for key in self.task_heads.keys() if key in batch]
else:
run_heads = self.task_heads.keys()
run_heads = list(self.task_heads.keys())
return run_heads

def _extract_loss(self, outs, loss_type):
Expand Down Expand Up @@ -216,5 +216,5 @@ def predict_step(self, batch, batch_idx):
batch[k] = v.as_tensor()
stage = "predict"
run_heads = self._get_run_heads(batch, stage)
self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)
return (None, None, None)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)
return outs[run_heads[0]]["save_path"]
5 changes: 2 additions & 3 deletions cyto_dl/models/im2im/multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
self.backbone = backbone
self.task_heads = torch.nn.ModuleDict(task_heads)

self.inference_heads = inference_heads or self.task_heads.keys()
self.inference_heads = inference_heads or list(self.task_heads.keys())

for k, head in self.task_heads.items():
head.update_params({"head_name": k, "x_key": x_key, "save_dir": save_dir})
Expand Down Expand Up @@ -192,5 +192,4 @@ def predict_step(self, batch, batch_idx):
batch[k] = v.as_tensor()
run_heads = self._get_run_heads(batch, stage)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)
preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()}
return None, preds, None
return outs[run_heads[0]]["save_path"]
25 changes: 16 additions & 9 deletions cyto_dl/nn/head/base_head.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import math
from abc import ABC
from pathlib import Path

import numpy as np
import torch
from aicsimageio.writers import OmeTiffWriter
from monai.networks.blocks import Convolution, UnetOutBlock, UnetResBlock, UpSample

from cyto_dl.models.im2im.utils.postprocessing import detach

Expand Down Expand Up @@ -51,11 +48,13 @@ def _postprocess(self, img, img_type):
return [self.postprocess[img_type](img[i]) for i in range(img.shape[0])]

def _save(self, fn, img, stage):
out_path = Path(self.save_dir) / f"{stage}_images" / fn
OmeTiffWriter().save(
uri=Path(self.save_dir) / f"{stage}_images" / fn,
uri=out_path,
data=img.squeeze(),
dims_order="STCZYX"[-len(img.shape)],
)
return out_path

def _calculate_metric(self, y_hat, y):
raise NotImplementedError
Expand All @@ -70,6 +69,7 @@ def save_image(self, y_hat, batch, stage, global_step):
raw_out = self._postprocess(batch[self.x_key], img_type="input")
try:
metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"]
filename_map = {"input": metadata_filenames, "output": []}
metadata_filenames = [
f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames
]
Expand All @@ -84,13 +84,14 @@ def save_image(self, y_hat, batch, stage, global_step):
)
n_save = len(y_hat_out) if stage in ("test", "predict") else 1
for i in range(n_save):
self._save(save_name[i].replace(".tif", "_pred.tif"), y_hat_out[i], stage)
out_path = self._save(save_name[i].replace(".tif", "_pred.tif"), y_hat_out[i], stage)
filename_map["output"].append(out_path)
if stage in ("train", "val"):
self._save(save_name[i], y_out[i], stage)
if self.save_raw:
self._save(save_name[i].replace(".tif", "_raw.tif"), raw_out[i], stage)

return y_hat_out, y_out
return y_hat_out, y_out, filename_map

def forward(self, x):
return self.model(x)
Expand All @@ -117,11 +118,17 @@ def run_head(
if stage != "predict":
loss = self._calculate_loss(y_hat, batch[self.head_name])

y_hat_out, y_out = None, None
y_hat_out, y_out, out_paths = None, None, None
if save_image:
y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step)
y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step)

metric = None
if self.calculate_metric and stage in ("val", "test"):
metric = self._calculate_metric(y_hat, batch[self.head_name])
return {"loss": loss, "metric": metric, "y_hat_out": y_hat_out, "y_out": y_out}
return {
"loss": loss,
"metric": metric,
"y_hat_out": y_hat_out,
"y_out": y_out,
"save_path": out_paths,
}
5 changes: 3 additions & 2 deletions cyto_dl/nn/head/gan_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def run_head(
)
loss_D, loss_G = self._calculate_loss(y_hat, batch, discriminator)

y_hat_out, y_out = None, None
y_hat_out, y_out, out_paths = None, None, None
if save_image:
y_hat_out, y_out = self.save_image(y_hat, batch, stage, global_step)
y_hat_out, y_out, out_paths = self.save_image(y_hat, batch, stage, global_step)

metric = None
if self.calculate_metric and stage in ("val", "test"):
Expand All @@ -95,4 +95,5 @@ def run_head(
"metric": metric,
"y_hat_out": y_hat_out,
"y_out": y_out,
"save_path": out_paths,
}

0 comments on commit 6ca7278

Please sign in to comment.