Skip to content

Commit

Permalink
Feature/image saving callback (#407)
Browse files Browse the repository at this point in the history
* refactor to do image saving in callback

* update segmentation config

* add gan

* add gan

* add instance seg

* update configs

* add model configs

* add n_postprocess to mask head
;

* remove save input from res blocks head

* simplify image saving;

* update vicreg and mae heads

* remove multimae head

* update plugin exp.

* precommit

* add warning

* precommit

* add ostats

* add callback to init

* update ostats

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
3 people authored Aug 1, 2024
1 parent c65dbbf commit bd15888
Show file tree
Hide file tree
Showing 30 changed files with 370 additions and 150 deletions.
15 changes: 15 additions & 0 deletions configs/experiment/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ callbacks:

early_stopping:
monitor: val/loss/generator_loss

# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
15 changes: 15 additions & 0 deletions configs/experiment/im2im/gan_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ callbacks:

early_stopping:
monitor: val/loss/generator_loss

# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
15 changes: 15 additions & 0 deletions configs/experiment/im2im/instance_seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,18 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ model:
_aux:
filters: MUST_OVERRIDE
overlap: 0

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/vit_segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,19 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
1 change: 0 additions & 1 deletion configs/model/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ _aux:
scales: 1
reconstruction_loss:
_target_: torch.nn.MSELoss
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/gan_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ _aux:
scales: 1
reconstruction_loss:
_target_: torch.nn.MSELoss
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/instance_seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ _aux:
loss:
_target_: cyto_dl.models.im2im.utils.InstanceSegLoss
dim: ${spatial_dims}
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,3 @@ _aux:
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
save_input: True
1 change: 0 additions & 1 deletion configs/model/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,3 @@ _aux:
activation:
_target_: torch.nn.Sigmoid
rescale_dtype: numpy.uint8
save_input: True
2 changes: 0 additions & 2 deletions configs/model/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ task_heads:
_target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold
method: "threshold_otsu"

save_input: True

optimizer:
generator:
_partial_: True
Expand Down
3 changes: 1 addition & 2 deletions configs/model/im2im/segmentation_superres.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_target_: cyto_dl.models.im2im.MultiTaskIm2Im

save_images_every_n_epochs: 1
save_images_every_n_epochs: 10
save_dir: ${paths.output_dir}

x_key: ${source_col}
Expand Down Expand Up @@ -54,7 +54,6 @@ _aux:
activation:
_target_: torch.nn.Sigmoid
rescale_dtype: numpy.uint8
save_input: True
in_channels: 1
out_channels: 1
upsample_ratio: 4
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/vit_segmentation_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ _aux:
loss:
_target_: cyto_dl.models.im2im.utils.InstanceSegLoss
dim: ${spatial_dims}
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 1 addition & 0 deletions cyto_dl/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# from .callback_utils import GetKLDRanks
# from .latent_walk import LatentWalk

from .image_saver import ImageSaver
from .layer_freeze import LayerFreeze

# raise NotImplementedError("TODO: refactor callbacks")
Expand Down
80 changes: 80 additions & 0 deletions cyto_dl/callbacks/image_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from pathlib import Path
from typing import List, Union

from bioio.writers import OmeTiffWriter
from lightning.pytorch.callbacks import Callback

VALID_STAGES = ("train", "val", "test", "predict")


class ImageSaver(Callback):
def __init__(
self,
save_dir: Union[str, Path],
save_every_n_epochs: int = 1,
stages: List[str] = ["train", "val"],
save_input: bool = False,
):
"""Callback for saving images after postprocessing by eads.
Parameters
----------
save_dir: Union[str, Path]
Directory to save images
save_every_n_epochs:int=1
Frequency to save images
stages:List[str]=["train", "val"]
Stages to save images
save_input:bool =False
Whether to save input images
"""
self.save_dir = Path(save_dir)
for stage in stages:
assert stage in VALID_STAGES, f"Invalid stage {stage}, must be one of {VALID_STAGES}"
self.save_every_n_epochs = save_every_n_epochs
self.stages = stages
self.save_keys = ["pred", "target"]
if save_input:
self.save_keys.append("input")

def _save(self, fn, data):
fn.parent.mkdir(exist_ok=True, parents=True)
OmeTiffWriter.save(uri=fn, data=data)

def on_predict_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
if "predict" in self.stages:
io_map, outputs = outputs
if outputs is None:
# image has already been saved
return
for i, head_io_map in enumerate(io_map.values()):
for k, save_path in head_io_map.items():
self._save(save_path, outputs[k]["pred"][i])

# train/test/val
def save(self, outputs, stage=None, step=None):
for k in self.save_keys:
for head in outputs[k]:
self._save(
self.save_dir / f"{stage}_images/{step}_{head}_{k}.tif", outputs[k][head]
)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
if "test" in self.stages:
# save all test outputs
self.save(outputs, "test", trainer.global_step)

def _should_save(self, batch_idx, epoch):
return batch_idx == (epoch + 1) % self.save_every_n_epochs == 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
if "train" in self.stages and self._should_save(batch_idx, trainer.current_epoch):
self.save(outputs, "train", trainer.global_step)

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
if "val" in self.stages and self._should_save(batch_idx, trainer.current_epoch):
self.save(outputs, "val", trainer.global_step)
2 changes: 1 addition & 1 deletion cyto_dl/callbacks/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from einops import reduce
from lightning.pytorch.callbacks import Callback
from ostats import add_sample
from online_stats import add_sample
from scipy.spatial.distance import mahalanobis


Expand Down
25 changes: 17 additions & 8 deletions cyto_dl/models/im2im/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ def _extract_loss(self, outs, loss_type):

def model_step(self, stage, batch, batch_idx):
run_heads, _ = self._get_run_heads(batch, stage, batch_idx)
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)

batch = self._to_tensor(batch)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)
outs = self.run_forward(batch, stage, n_postprocess, run_heads)

loss_D = self._extract_loss(outs, "loss_D")
loss_G = self._extract_loss(outs, "loss_G")
Expand All @@ -157,17 +159,24 @@ def model_step(self, stage, batch, batch_idx):
d_opt.zero_grad()
self.manual_backward(loss_D["loss"])
d_opt.step()
loss_dict = {f"discriminator_{key}": loss for key, loss in loss_D.items()}
loss_dict.update({f"generator_{key}": loss for key, loss in loss_G.items()})
loss_dict["loss"] = loss_dict["generator_loss"]
results = {f"discriminator_{key}": loss for key, loss in loss_D.items()}
results.update({f"generator_{key}": loss for key, loss in loss_G.items()})
results["loss"] = results["generator_loss"]

if n_postprocess > 0:
# add postprocessed images to return dict
for k in ("pred", "target", "input"):
results[k] = self.get_per_head(outs, k)

return loss_dict, None, None
self.compute_metrics(results, None, None, stage)
return results

def predict_step(self, batch, batch_idx):
stage = "predict"
run_heads, io_map = self._get_run_heads(batch, stage, batch_idx)
outs = None
if len(run_heads) > 0:
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
batch = self._to_tensor(batch)
save_image = self.should_save_image(batch_idx, stage)
self.run_forward(batch, stage, save_image, run_heads)
return io_map
outs = self.run_forward(batch, stage, n_postprocess, run_heads)
return io_map, outs
Loading

0 comments on commit bd15888

Please sign in to comment.