From bd15888249c935380563e10ad54bca299320d72d Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Thu, 1 Aug 2024 12:59:34 -0700 Subject: [PATCH] Feature/image saving callback (#407) * refactor to do image saving in callback * update segmentation config * add gan * add gan * add instance seg * update configs * add model configs * add n_postprocess to mask head ; * remove save input from res blocks head * simplify image saving; * update vicreg and mae heads * remove multimae head * update plugin exp. * precommit * add warning * precommit * add ostats * add callback to init * update ostats --------- Co-authored-by: Benjamin Morris Co-authored-by: Benjamin Morris --- configs/experiment/im2im/gan.yaml | 15 +++ configs/experiment/im2im/gan_superres.yaml | 15 +++ configs/experiment/im2im/instance_seg.yaml | 15 +++ configs/experiment/im2im/labelfree.yaml | 16 +++ configs/experiment/im2im/mae.yaml | 16 +++ configs/experiment/im2im/segmentation.yaml | 16 +++ .../experiment/im2im/segmentation_plugin.yaml | 16 +++ .../im2im/segmentation_superres.yaml | 16 +++ .../experiment/im2im/vit_segmentation.yaml | 16 +++ configs/model/im2im/gan.yaml | 1 - configs/model/im2im/gan_superres.yaml | 1 - configs/model/im2im/instance_seg.yaml | 1 - configs/model/im2im/labelfree.yaml | 1 - configs/model/im2im/segmentation.yaml | 1 - configs/model/im2im/segmentation_plugin.yaml | 2 - .../model/im2im/segmentation_superres.yaml | 3 +- .../model/im2im/vit_segmentation_decoder.yaml | 1 - cyto_dl/callbacks/__init__.py | 1 + cyto_dl/callbacks/image_saver.py | 80 +++++++++++++ cyto_dl/callbacks/outlier_detection.py | 2 +- cyto_dl/models/im2im/gan.py | 25 ++-- cyto_dl/models/im2im/multi_task.py | 112 +++++++++++------- cyto_dl/nn/head/base_head.py | 68 ++++------- cyto_dl/nn/head/gan_head.py | 24 ++-- cyto_dl/nn/head/gan_head_superres.py | 6 - cyto_dl/nn/head/mae_head.py | 11 +- cyto_dl/nn/head/mask_head.py | 25 ++-- cyto_dl/nn/head/res_blocks_head.py | 7 +- cyto_dl/nn/head/vic_reg.py | 4 +- pyproject.toml | 3 +- 30 files changed, 370 insertions(+), 150 deletions(-) create mode 100644 cyto_dl/callbacks/image_saver.py diff --git a/configs/experiment/im2im/gan.yaml b/configs/experiment/im2im/gan.yaml index f4c65eebd..58e7d3072 100644 --- a/configs/experiment/im2im/gan.yaml +++ b/configs/experiment/im2im/gan.yaml @@ -42,3 +42,18 @@ callbacks: early_stopping: monitor: val/loss/generator_loss + + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/gan_superres.yaml b/configs/experiment/im2im/gan_superres.yaml index 0fa6e0adb..db8a3e78e 100644 --- a/configs/experiment/im2im/gan_superres.yaml +++ b/configs/experiment/im2im/gan_superres.yaml @@ -42,3 +42,18 @@ callbacks: early_stopping: monitor: val/loss/generator_loss + + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/instance_seg.yaml b/configs/experiment/im2im/instance_seg.yaml index 82f4bdc01..8c8b95029 100644 --- a/configs/experiment/im2im/instance_seg.yaml +++ b/configs/experiment/im2im/instance_seg.yaml @@ -37,3 +37,18 @@ data: # patch_shape: [64, 64] # 3D patch_shape: [16, 32, 32] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/labelfree.yaml b/configs/experiment/im2im/labelfree.yaml index e018a4f57..5e505edb6 100644 --- a/configs/experiment/im2im/labelfree.yaml +++ b/configs/experiment/im2im/labelfree.yaml @@ -35,3 +35,19 @@ data: # patch_shape: [64, 64] # 3D patch_shape: [16, 32, 32] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/mae.yaml b/configs/experiment/im2im/mae.yaml index 459335748..4c65417e7 100644 --- a/configs/experiment/im2im/mae.yaml +++ b/configs/experiment/im2im/mae.yaml @@ -35,3 +35,19 @@ data: # patch_shape: [16, 16] # 3D patch_shape: [16, 16, 16] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/segmentation.yaml b/configs/experiment/im2im/segmentation.yaml index 68946ffdb..64c4c3fc5 100644 --- a/configs/experiment/im2im/segmentation.yaml +++ b/configs/experiment/im2im/segmentation.yaml @@ -34,3 +34,19 @@ data: # patch_shape: [64, 64] # 3D patch_shape: [16, 32, 32] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/segmentation_plugin.yaml b/configs/experiment/im2im/segmentation_plugin.yaml index 1e8cb38c3..f8c5a3ac5 100644 --- a/configs/experiment/im2im/segmentation_plugin.yaml +++ b/configs/experiment/im2im/segmentation_plugin.yaml @@ -61,3 +61,19 @@ model: _aux: filters: MUST_OVERRIDE overlap: 0 + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/segmentation_superres.yaml b/configs/experiment/im2im/segmentation_superres.yaml index 7db7dbba9..849ba0251 100644 --- a/configs/experiment/im2im/segmentation_superres.yaml +++ b/configs/experiment/im2im/segmentation_superres.yaml @@ -34,3 +34,19 @@ data: # patch_shape: [64, 64] # 3D patch_shape: [16, 32, 32] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/experiment/im2im/vit_segmentation.yaml b/configs/experiment/im2im/vit_segmentation.yaml index 8c9c834b6..e0e2e4720 100644 --- a/configs/experiment/im2im/vit_segmentation.yaml +++ b/configs/experiment/im2im/vit_segmentation.yaml @@ -37,3 +37,19 @@ data: # patch_shape: [16, 16] # 3D patch_shape: [16, 16, 16] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] + save_input: True diff --git a/configs/model/im2im/gan.yaml b/configs/model/im2im/gan.yaml index 5e49ec94f..4b6afa7f8 100644 --- a/configs/model/im2im/gan.yaml +++ b/configs/model/im2im/gan.yaml @@ -66,7 +66,6 @@ _aux: scales: 1 reconstruction_loss: _target_: torch.nn.MSELoss - save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/configs/model/im2im/gan_superres.yaml b/configs/model/im2im/gan_superres.yaml index d391a491a..a7d454e58 100644 --- a/configs/model/im2im/gan_superres.yaml +++ b/configs/model/im2im/gan_superres.yaml @@ -66,7 +66,6 @@ _aux: scales: 1 reconstruction_loss: _target_: torch.nn.MSELoss - save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/configs/model/im2im/instance_seg.yaml b/configs/model/im2im/instance_seg.yaml index 31bc89fb4..48ad27a2e 100644 --- a/configs/model/im2im/instance_seg.yaml +++ b/configs/model/im2im/instance_seg.yaml @@ -44,7 +44,6 @@ _aux: loss: _target_: cyto_dl.models.im2im.utils.InstanceSegLoss dim: ${spatial_dims} - save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/configs/model/im2im/labelfree.yaml b/configs/model/im2im/labelfree.yaml index 1f0c15a8f..92655456f 100644 --- a/configs/model/im2im/labelfree.yaml +++ b/configs/model/im2im/labelfree.yaml @@ -51,4 +51,3 @@ _aux: prediction: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel rescale_dtype: numpy.uint8 - save_input: True diff --git a/configs/model/im2im/segmentation.yaml b/configs/model/im2im/segmentation.yaml index 873bf6950..a3172567e 100644 --- a/configs/model/im2im/segmentation.yaml +++ b/configs/model/im2im/segmentation.yaml @@ -54,4 +54,3 @@ _aux: activation: _target_: torch.nn.Sigmoid rescale_dtype: numpy.uint8 - save_input: True diff --git a/configs/model/im2im/segmentation_plugin.yaml b/configs/model/im2im/segmentation_plugin.yaml index 306d02297..c6c35a2d2 100644 --- a/configs/model/im2im/segmentation_plugin.yaml +++ b/configs/model/im2im/segmentation_plugin.yaml @@ -32,8 +32,6 @@ task_heads: _target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold method: "threshold_otsu" - save_input: True - optimizer: generator: _partial_: True diff --git a/configs/model/im2im/segmentation_superres.yaml b/configs/model/im2im/segmentation_superres.yaml index 5b175caee..6bdb3414f 100644 --- a/configs/model/im2im/segmentation_superres.yaml +++ b/configs/model/im2im/segmentation_superres.yaml @@ -1,6 +1,6 @@ _target_: cyto_dl.models.im2im.MultiTaskIm2Im -save_images_every_n_epochs: 1 +save_images_every_n_epochs: 10 save_dir: ${paths.output_dir} x_key: ${source_col} @@ -54,7 +54,6 @@ _aux: activation: _target_: torch.nn.Sigmoid rescale_dtype: numpy.uint8 - save_input: True in_channels: 1 out_channels: 1 upsample_ratio: 4 diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 25366e638..62ce7e064 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -47,7 +47,6 @@ _aux: loss: _target_: cyto_dl.models.im2im.utils.InstanceSegLoss dim: ${spatial_dims} - save_input: True postprocess: input: _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel diff --git a/cyto_dl/callbacks/__init__.py b/cyto_dl/callbacks/__init__.py index c9dfca430..ba183d148 100644 --- a/cyto_dl/callbacks/__init__.py +++ b/cyto_dl/callbacks/__init__.py @@ -2,6 +2,7 @@ # from .callback_utils import GetKLDRanks # from .latent_walk import LatentWalk +from .image_saver import ImageSaver from .layer_freeze import LayerFreeze # raise NotImplementedError("TODO: refactor callbacks") diff --git a/cyto_dl/callbacks/image_saver.py b/cyto_dl/callbacks/image_saver.py new file mode 100644 index 000000000..4458016f9 --- /dev/null +++ b/cyto_dl/callbacks/image_saver.py @@ -0,0 +1,80 @@ +from pathlib import Path +from typing import List, Union + +from bioio.writers import OmeTiffWriter +from lightning.pytorch.callbacks import Callback + +VALID_STAGES = ("train", "val", "test", "predict") + + +class ImageSaver(Callback): + def __init__( + self, + save_dir: Union[str, Path], + save_every_n_epochs: int = 1, + stages: List[str] = ["train", "val"], + save_input: bool = False, + ): + """Callback for saving images after postprocessing by eads. + + Parameters + ---------- + save_dir: Union[str, Path] + Directory to save images + save_every_n_epochs:int=1 + Frequency to save images + stages:List[str]=["train", "val"] + Stages to save images + save_input:bool =False + Whether to save input images + """ + self.save_dir = Path(save_dir) + for stage in stages: + assert stage in VALID_STAGES, f"Invalid stage {stage}, must be one of {VALID_STAGES}" + self.save_every_n_epochs = save_every_n_epochs + self.stages = stages + self.save_keys = ["pred", "target"] + if save_input: + self.save_keys.append("input") + + def _save(self, fn, data): + fn.parent.mkdir(exist_ok=True, parents=True) + OmeTiffWriter.save(uri=fn, data=data) + + def on_predict_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0 + ): + if "predict" in self.stages: + io_map, outputs = outputs + if outputs is None: + # image has already been saved + return + for i, head_io_map in enumerate(io_map.values()): + for k, save_path in head_io_map.items(): + self._save(save_path, outputs[k]["pred"][i]) + + # train/test/val + def save(self, outputs, stage=None, step=None): + for k in self.save_keys: + for head in outputs[k]: + self._save( + self.save_dir / f"{stage}_images/{step}_{head}_{k}.tif", outputs[k][head] + ) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + if "test" in self.stages: + # save all test outputs + self.save(outputs, "test", trainer.global_step) + + def _should_save(self, batch_idx, epoch): + return batch_idx == (epoch + 1) % self.save_every_n_epochs == 0 + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): + if "train" in self.stages and self._should_save(batch_idx, trainer.current_epoch): + self.save(outputs, "train", trainer.global_step) + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0 + ): + if "val" in self.stages and self._should_save(batch_idx, trainer.current_epoch): + self.save(outputs, "val", trainer.global_step) diff --git a/cyto_dl/callbacks/outlier_detection.py b/cyto_dl/callbacks/outlier_detection.py index 789ca4090..9e66694f2 100644 --- a/cyto_dl/callbacks/outlier_detection.py +++ b/cyto_dl/callbacks/outlier_detection.py @@ -6,7 +6,7 @@ import pandas as pd from einops import reduce from lightning.pytorch.callbacks import Callback -from ostats import add_sample +from online_stats import add_sample from scipy.spatial.distance import mahalanobis diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 0c8e72456..30d8a704a 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -141,8 +141,10 @@ def _extract_loss(self, outs, loss_type): def model_step(self, stage, batch, batch_idx): run_heads, _ = self._get_run_heads(batch, stage, batch_idx) + n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage) + batch = self._to_tensor(batch) - outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) + outs = self.run_forward(batch, stage, n_postprocess, run_heads) loss_D = self._extract_loss(outs, "loss_D") loss_G = self._extract_loss(outs, "loss_G") @@ -157,17 +159,24 @@ def model_step(self, stage, batch, batch_idx): d_opt.zero_grad() self.manual_backward(loss_D["loss"]) d_opt.step() - loss_dict = {f"discriminator_{key}": loss for key, loss in loss_D.items()} - loss_dict.update({f"generator_{key}": loss for key, loss in loss_G.items()}) - loss_dict["loss"] = loss_dict["generator_loss"] + results = {f"discriminator_{key}": loss for key, loss in loss_D.items()} + results.update({f"generator_{key}": loss for key, loss in loss_G.items()}) + results["loss"] = results["generator_loss"] + + if n_postprocess > 0: + # add postprocessed images to return dict + for k in ("pred", "target", "input"): + results[k] = self.get_per_head(outs, k) - return loss_dict, None, None + self.compute_metrics(results, None, None, stage) + return results def predict_step(self, batch, batch_idx): stage = "predict" run_heads, io_map = self._get_run_heads(batch, stage, batch_idx) + outs = None if len(run_heads) > 0: + n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage) batch = self._to_tensor(batch) - save_image = self.should_save_image(batch_idx, stage) - self.run_forward(batch, stage, save_image, run_heads) - return io_map + outs = self.run_forward(batch, stage, n_postprocess, run_heads) + return io_map, outs diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index dabed0931..00b891f89 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -1,5 +1,5 @@ import sys -from contextlib import suppress +import warnings from pathlib import Path from typing import Dict, List, Union @@ -11,6 +11,8 @@ from cyto_dl.models.base_model import BaseModel +warnings.simplefilter("once", UserWarning) + class MultiTaskIm2Im(BaseModel): def __init__( @@ -102,19 +104,20 @@ def configure_optimizers(self): scheds.append(scheduler) return (opts, scheds) - def _train_forward(self, batch, stage, save_image, run_heads): + def _train_forward(self, batch, stage, n_postprocess, run_heads): """during training we are only dealing with patches,so we can calculate per-patch loss, metrics, postprocessing etc.""" z = self.backbone(batch[self.hparams.x_key]) return { - task: self.task_heads[task].run_head(z, batch, stage, save_image) for task in run_heads + task: self.task_heads[task].run_head(z, batch, stage, n_postprocess) + for task in run_heads } def forward(self, x, run_heads): z = self.backbone(x) return {task: self.task_heads[task](z) for task in run_heads} - def _inference_forward(self, batch, stage, save_image, run_heads): + def _inference_forward(self, batch, stage, n_postprocess, run_heads): """during inference, we need to calculate per-fov loss/metrics/postprocessing. To avoid storing and passing to each head the intermediate results of the backbone, we need @@ -133,23 +136,31 @@ def _inference_forward(self, batch, stage, save_image, run_heads): None, batch, stage, - save_image, + n_postprocess, run_forward=False, y_hat=raw_pred_images[head_name], ) for head_name in run_heads } - def run_forward(self, batch, stage, save_image, run_heads): + def run_forward(self, batch, stage, n_postprocess, run_heads): if stage in ("train", "val"): - return self._train_forward(batch, stage, save_image, run_heads) - return self._inference_forward(batch, stage, save_image, run_heads) + return self._train_forward(batch, stage, n_postprocess, run_heads) + return self._inference_forward(batch, stage, n_postprocess, run_heads) - def should_save_image(self, batch_idx, stage): - return stage in ("test", "predict") or ( - batch_idx < len(self.task_heads) # noqa: FURB124 - and (self.current_epoch + 1) % self.hparams.save_images_every_n_epochs == 0 - ) + def get_n_postprocess_image(self, batch, batch_idx, stage): + # save first batch every n epochs during train/val + if ( + stage in ("train", "val") + and batch_idx + == (self.current_epoch + 1) % self.hparams.save_images_every_n_epochs + == 0 + ): + return 1 + # postprocess all images in batch for predict/test + elif stage in ("predict", "test"): + return batch[self.hparams.x_key].shape[0] + return 0 def _sum_losses(self, losses): losses["loss"] = torch.sum(torch.stack(list(losses.values()))) @@ -181,20 +192,22 @@ def _get_run_heads(self, batch, stage, batch_idx): """Get heads that are either specified as inference heads and don't have outputs yet or all heads.""" run_heads = self.inference_heads - if stage in ("train", "val"): + if stage in ("train", "val", "test"): run_heads = [key for key in self.task_heads.keys() if key in batch] - - io_map = { - h: self.task_heads[h].generate_io_map( - batch[self.hparams.x_key].meta, stage, batch_idx, self.global_step + return run_heads, None + filenames = batch[self.hparams.x_key].meta.get("filename_or_obj", None) + if filenames is None: + warnings.warn( + 'Batch MetaTensors must have "filename_or_obj" to be saved out. Returning array prediction instead...', + UserWarning, ) - for h in run_heads - } + return run_heads, None - if stage == "predict": - # only run heads that don't have outputs yet for prediction - run_heads = self._get_unrun_heads(io_map) - io_map = self._combine_io_maps(io_map) + # IO_map is only generated for prediction + io_map = {h: self.task_heads[h].generate_io_map(filenames) for h in run_heads} + # only run heads that don't have outputs yet for prediction + run_heads = self._get_unrun_heads(io_map) + io_map = self._combine_io_maps(io_map) return run_heads, io_map @@ -205,32 +218,43 @@ def _to_tensor(self, batch): batch[k] = v.as_tensor() return batch + def get_per_head(self, outs, key): + return {head_name: head_result[key] for head_name, head_result in outs.items()} + def model_step(self, stage, batch, batch_idx): run_heads, _ = self._get_run_heads(batch, stage, batch_idx) + + n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage) batch = self._to_tensor(batch) - save_image = self.should_save_image(batch_idx, stage) - outs = self.run_forward(batch, stage, save_image, run_heads) - losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} - return self._sum_losses(losses), None, None + outs = self.run_forward(batch, stage, n_postprocess, run_heads) + # aggregate losses across heads + results = self.get_per_head(outs, "loss") + results = self._sum_losses(results) + + if n_postprocess > 0: + # add postprocessed images to return dict + for k in ("pred", "target", "input"): + results[k] = self.get_per_head(outs, k) + + self.compute_metrics(results, None, None, stage) + + return results + + def training_step(self, batch, batch_idx): + return self.model_step("train", batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.model_step("val", batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.model_step("test", batch, batch_idx) def predict_step(self, batch, batch_idx): stage = "predict" run_heads, io_map = self._get_run_heads(batch, stage, batch_idx) + outs = None if len(run_heads) > 0: + n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage) batch = self._to_tensor(batch) - save_image = self.should_save_image(batch_idx, stage) - self.run_forward(batch, stage, save_image, run_heads) - return io_map - - # utils for smartcache training - def on_train_start(self): - with suppress(AttributeError): - self.trainer.datamodule.train_dataloader().dataset.start() - - def on_train_epoch_end(self): - with suppress(AttributeError): - self.trainer.datamodule.train_dataloader().dataset.update_cache() - - def on_train_end(self, *args, **kwargs): - with suppress(AttributeError): - self.trainer.datamodule.train_dataloader().dataset.shutdown() + outs = self.run_forward(batch, stage, n_postprocess, run_heads) + return io_map, outs diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index 09265b157..8737f8fed 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -2,7 +2,6 @@ from pathlib import Path import torch -from bioio.writers import OmeTiffWriter from cyto_dl.models.im2im.utils.postprocessing import detach @@ -14,7 +13,6 @@ def __init__( self, loss, postprocess={"input": detach, "prediction": detach}, - save_input=False, ): """ Parameters @@ -23,15 +21,12 @@ def __init__( Loss function for task postprocess={"input": detach, "prediction": detach} Postprocessing for `input` and `predictions` of head - save_input=False - Whether to save out example input images during training """ super().__init__() self.loss = loss self.postprocess = postprocess self.model = torch.nn.Sequential(torch.nn.Identity()) - self.save_input = save_input def update_params(self, params): for k, v in params.items(): @@ -40,40 +35,23 @@ def update_params(self, params): def _calculate_loss(self, y_hat, y): return self.loss(y_hat, y) - def _postprocess(self, img, img_type): - return [self.postprocess[img_type](img[i]) for i in range(img.shape[0])] + def _postprocess(self, img, img_type, n_postprocess=1): + return [self.postprocess[img_type](img[i]) for i in range(n_postprocess)] - def generate_io_map(self, meta, stage, batch_idx, step): - """generates map between input files and output files for a head.""" - # filename is determined by step in training during train/val and by its source filename for prediction/testing - filename_map = {"input": meta.get("filename_or_obj", [batch_idx])} - if stage in ("train", "val", "test"): - out_paths = [Path(self.save_dir) / f"{stage}_images" / f"{step}_{self.head_name}.tif"] - else: - out_paths = [ - Path(self.save_dir) / self.head_name / f"{Path(fn).stem}.tif" - for fn in filename_map["input"] - ] - # create output directory if it doesn't exist - out_paths[0].parent.mkdir(exist_ok=True, parents=True) + def generate_io_map(self, input_filenames): + """generates map between input files and output files for a head. - filename_map["output"] = out_paths - self.filename_map = filename_map + Only used for prediction + """ + filename_map = {"input": input_filenames} + filename_map["output"] = [ + Path(self.save_dir) / self.head_name / f"{Path(fn).stem}.tif" + for fn in filename_map["input"] + ] + # create output directory if it doesn't exist + filename_map["output"][0].parent.mkdir(exist_ok=True, parents=True) return filename_map - def save_image(self, y_hat, batch, stage): - y_hat_out = self._postprocess(y_hat, img_type="prediction") - y_out = None - for i, out_path in enumerate(self.filename_map["output"]): - OmeTiffWriter.save(data=y_hat_out[i], uri=out_path) - if stage in ("train", "val"): - y_out = self._postprocess(batch[self.head_name], img_type="input") - OmeTiffWriter.save(data=y_out[i], uri=str(out_path).replace(".t", "_label.t")) - if self.save_input: - raw_out = self._postprocess(batch[self.x_key][i : i + 1], img_type="input") - OmeTiffWriter.save(data=raw_out, uri=str(out_path).replace(".t", "_input.t")) - return y_hat_out, y_out - def forward(self, x): return self.model(x) @@ -82,7 +60,7 @@ def run_head( backbone_features, batch, stage, - save_image, + n_postprocess=1, run_forward=True, y_hat=None, ): @@ -98,12 +76,18 @@ def run_head( if stage != "predict": loss = self._calculate_loss(y_hat, batch[self.head_name]) - y_hat_out, y_out = None, None - if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage) - + # no need to postprocess input and target during prediction return { "loss": loss, - "y_hat_out": y_hat_out, - "y_out": y_out, + "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), + "target": self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, + "input": self._postprocess( + batch[self.x_key], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, } diff --git a/cyto_dl/nn/head/gan_head.py b/cyto_dl/nn/head/gan_head.py index 4043c254d..9072bb9bb 100644 --- a/cyto_dl/nn/head/gan_head.py +++ b/cyto_dl/nn/head/gan_head.py @@ -18,7 +18,6 @@ def __init__( reconstruction_loss=torch.nn.MSELoss(), reconstruction_loss_weight=100, postprocess={"input": detach, "prediction": detach}, - save_input=False, ): """ Parameters @@ -31,10 +30,8 @@ def __init__( Weighting of reconstruction loss postprocess={"input": detach, "prediction": detach} Postprocessing for `input` and `predictions` of head - save_input=False - Whether to save out example input images during training """ - super().__init__(None, postprocess, save_input) + super().__init__(None, postprocess) self.gan_loss = gan_loss self.reconstruction_loss = reconstruction_loss self.reconstruction_loss_weight = reconstruction_loss_weight @@ -61,7 +58,7 @@ def run_head( backbone_features, batch, stage, - save_image, + n_postprocess=1, discriminator=None, run_forward=True, y_hat=None, @@ -80,13 +77,18 @@ def run_head( ) loss_D, loss_G = self._calculate_loss(y_hat, batch, discriminator) - y_hat_out, y_out = None, None - if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage) - return { "loss_D": loss_D, "loss_G": loss_G, - "y_hat_out": y_hat_out, - "y_out": y_out, + "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), + "target": self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, + "input": self._postprocess( + batch[self.x_key], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, } diff --git a/cyto_dl/nn/head/gan_head_superres.py b/cyto_dl/nn/head/gan_head_superres.py index d6d109d47..bd21ef45e 100644 --- a/cyto_dl/nn/head/gan_head_superres.py +++ b/cyto_dl/nn/head/gan_head_superres.py @@ -1,9 +1,7 @@ -import math from typing import Callable import numpy as np import torch -from monai.networks.blocks import DenseBlock, UnetOutBlock, UnetResBlock, UpSample from cyto_dl.models.im2im.utils.postprocessing import detach from cyto_dl.nn.losses import Pix2PixHD @@ -23,7 +21,6 @@ def __init__( reconstruction_loss=torch.nn.MSELoss(), reconstruction_loss_weight=100, postprocess={"input": detach, "prediction": detach}, - save_input=False, final_act: Callable = torch.nn.Identity(), resolution="lr", spatial_dims=3, @@ -45,8 +42,6 @@ def __init__( Weighting of reconstruction loss postprocess={"input": detach, "prediction": detach} Postprocessing for `input` and `predictions` of head - save_input=False - Whether to save out example input images during training """ ResBlocksHead.__init__( self, @@ -55,7 +50,6 @@ def __init__( out_channels=out_channels, final_act=final_act, postprocess=postprocess, - save_input=save_input, resolution=resolution, spatial_dims=spatial_dims, n_convs=n_convs, diff --git a/cyto_dl/nn/head/mae_head.py b/cyto_dl/nn/head/mae_head.py index 928cc2f9e..bda74983b 100644 --- a/cyto_dl/nn/head/mae_head.py +++ b/cyto_dl/nn/head/mae_head.py @@ -7,7 +7,7 @@ def run_head( backbone_features, batch, stage, - save_image, + n_postprocess=1, run_forward=True, y_hat=None, ): @@ -24,12 +24,9 @@ def run_head( else: loss = loss.mean() - y_hat_out, y_out = None, None - if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage) - return { "loss": loss, - "y_hat_out": y_hat_out, - "y_out": y_out, + "pred": y_hat, + "target": batch[self.head_name], + "input": batch[self.head_name], } diff --git a/cyto_dl/nn/head/mask_head.py b/cyto_dl/nn/head/mask_head.py index cbddcedbe..a255ec2d0 100644 --- a/cyto_dl/nn/head/mask_head.py +++ b/cyto_dl/nn/head/mask_head.py @@ -12,7 +12,6 @@ def __init__( loss, mask_key: str = "mask", postprocess={"input": detach, "prediction": detach}, - save_input=False, ): """ Parameters @@ -23,10 +22,8 @@ def __init__( Postprocessing for `input` and `predictions` of head calculate_metric=False Whether to calculate a metric during training. Not used by GAN head. - save_input=False - Whether to save out example input images during training """ - super().__init__(loss, postprocess=postprocess, save_input=save_input) + super().__init__(loss, postprocess=postprocess) self.mask_key = mask_key self.model = torch.nn.Sequential(torch.nn.Identity()) @@ -39,7 +36,7 @@ def run_head( backbone_features, batch, stage, - save_image, + n_postprocess, run_forward=True, y_hat=None, ): @@ -55,12 +52,18 @@ def run_head( if stage != "predict": loss = self._calculate_loss(y_hat, batch[self.head_name], batch[self.mask_key]) - y_hat_out, y_out = None, None - if save_image: - y_hat_out, y_out = self.save_image(y_hat, batch, stage) - + # no need to postprocess input and target during prediction return { "loss": loss, - "y_hat_out": y_hat_out, - "y_out": y_out, + "pred": self._postprocess(y_hat, img_type="prediction", n_postprocess=n_postprocess), + "target": self._postprocess( + batch[self.head_name], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, + "input": self._postprocess( + batch[self.x_key], img_type="input", n_postprocess=n_postprocess + ) + if stage != "predict" + else None, } diff --git a/cyto_dl/nn/head/res_blocks_head.py b/cyto_dl/nn/head/res_blocks_head.py index 9c7de6082..787842cb7 100644 --- a/cyto_dl/nn/head/res_blocks_head.py +++ b/cyto_dl/nn/head/res_blocks_head.py @@ -1,7 +1,5 @@ -import math from typing import Callable -import numpy as np import torch from monai.networks.blocks import DenseBlock, UnetOutBlock, UnetResBlock, UpSample @@ -20,7 +18,6 @@ def __init__( out_channels: int, final_act: Callable = torch.nn.Identity(), postprocess={"input": detach, "prediction": detach}, - save_input=False, resolution="lr", spatial_dims=3, n_convs=1, @@ -41,8 +38,6 @@ def __init__( Final activation applied to logits postprocess={"input": detach, "prediction": detach} Postprocessing functions for ground truth and model predictions - save_input=False - Whether to save raw image examples during training resolution="lr" Resolution of output image. If `lr`, no upsampling is done. If `hr`, `upsample_method` and `upsample_ratio` are used to determine how to perform upsampling. @@ -61,7 +56,7 @@ def __init__( dense=False Whether to use dense connections between convolutional layers """ - super().__init__(loss, postprocess, save_input) + super().__init__(loss, postprocess) self.resolution = resolution conv_input_channels = in_channels diff --git a/cyto_dl/nn/head/vic_reg.py b/cyto_dl/nn/head/vic_reg.py index 8b65aa614..25ca5685f 100644 --- a/cyto_dl/nn/head/vic_reg.py +++ b/cyto_dl/nn/head/vic_reg.py @@ -1,7 +1,5 @@ from typing import List -from torch import nn - from cyto_dl.nn import MLP from cyto_dl.nn.head import BaseHead @@ -28,7 +26,7 @@ def run_head( backbone_features, batch, stage, - save_image=False, + n_postprocess=1, run_forward=True, y_hat=None, ): diff --git a/pyproject.toml b/pyproject.toml index 69783749c..95f251f88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,8 @@ dependencies = [ "bioio>=1.0.1", "bioio-czi", "bioio-ome-tiff", - "bioio-tifffile" + "bioio-tifffile", + "online-stats>=2023", ] requires-python = ">=3.9,<3.11"