Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/image saving callback #407

Merged
merged 19 commits into from
Aug 1, 2024
15 changes: 15 additions & 0 deletions configs/experiment/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ callbacks:

early_stopping:
monitor: val/loss/generator_loss

# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
15 changes: 15 additions & 0 deletions configs/experiment/im2im/gan_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ callbacks:

early_stopping:
monitor: val/loss/generator_loss

# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
15 changes: 15 additions & 0 deletions configs/experiment/im2im/instance_seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,18 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the saving key here? Shouldnt callbacks just be a list of callbacks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they all have a key :

model_checkpoint:

_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ model:
_aux:
filters: MUST_OVERRIDE
overlap: 0

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/segmentation_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,19 @@ data:
# patch_shape: [64, 64]
# 3D
patch_shape: [16, 32, 32]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
16 changes: 16 additions & 0 deletions configs/experiment/im2im/vit_segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,19 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.image_saver.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.image_saver.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
save_input: True
1 change: 0 additions & 1 deletion configs/model/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ _aux:
scales: 1
reconstruction_loss:
_target_: torch.nn.MSELoss
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/gan_superres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ _aux:
scales: 1
reconstruction_loss:
_target_: torch.nn.MSELoss
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/instance_seg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ _aux:
loss:
_target_: cyto_dl.models.im2im.utils.InstanceSegLoss
dim: ${spatial_dims}
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,3 @@ _aux:
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
save_input: True
1 change: 0 additions & 1 deletion configs/model/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,3 @@ _aux:
activation:
_target_: torch.nn.Sigmoid
rescale_dtype: numpy.uint8
save_input: True
2 changes: 0 additions & 2 deletions configs/model/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ task_heads:
_target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold
method: "threshold_otsu"

save_input: True

optimizer:
generator:
_partial_: True
Expand Down
3 changes: 1 addition & 2 deletions configs/model/im2im/segmentation_superres.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_target_: cyto_dl.models.im2im.MultiTaskIm2Im

save_images_every_n_epochs: 1
save_images_every_n_epochs: 10
save_dir: ${paths.output_dir}

x_key: ${source_col}
Expand Down Expand Up @@ -54,7 +54,6 @@ _aux:
activation:
_target_: torch.nn.Sigmoid
rescale_dtype: numpy.uint8
save_input: True
in_channels: 1
out_channels: 1
upsample_ratio: 4
Expand Down
1 change: 0 additions & 1 deletion configs/model/im2im/vit_segmentation_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ _aux:
loss:
_target_: cyto_dl.models.im2im.utils.InstanceSegLoss
dim: ${spatial_dims}
save_input: True
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
Expand Down
80 changes: 80 additions & 0 deletions cyto_dl/callbacks/image_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from pathlib import Path
from typing import List, Union

from bioio.writers import OmeTiffWriter
from lightning.pytorch.callbacks import Callback

VALID_STAGES = ("train", "val", "test", "predict")


class ImageSaver(Callback):
def __init__(
self,
save_dir: Union[str, Path],
save_every_n_epochs: int = 1,
stages: List[str] = ["train", "val"],
save_input: bool = False,
):
"""Callback for saving images after postprocessing by eads.

Parameters
----------
save_dir: Union[str, Path]
Directory to save images
save_every_n_epochs:int=1
Frequency to save images
stages:List[str]=["train", "val"]
Stages to save images
save_input:bool =False
Whether to save input images
"""
self.save_dir = Path(save_dir)
for stage in stages:
assert stage in VALID_STAGES, f"Invalid stage {stage}, must be one of {VALID_STAGES}"
self.save_every_n_epochs = save_every_n_epochs
self.stages = stages
self.save_keys = ["pred", "target"]
if save_input:
self.save_keys.append("input")

def _save(self, fn, data):
fn.parent.mkdir(exist_ok=True, parents=True)
OmeTiffWriter.save(uri=fn, data=data)

def on_predict_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
if "predict" in self.stages:
io_map, outputs = outputs
if outputs is None:
# image has already been saved
return
for i, head_io_map in enumerate(io_map.values()):
for k, save_path in head_io_map.items():
self._save(save_path, outputs[k]["pred"][i])

# train/test/val
def save(self, outputs, stage=None, step=None):
for k in self.save_keys:
for head in outputs[k]:
self._save(
self.save_dir / f"{stage}_images/{step}_{head}_{k}.tif", outputs[k][head]
)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
if "test" in self.stages:
# save all test outputs
self.save(outputs, "test", trainer.global_step)

def _should_save(self, batch_idx, epoch):
return batch_idx == (epoch + 1) % self.save_every_n_epochs == 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
if "train" in self.stages and self._should_save(batch_idx, trainer.current_epoch):
self.save(outputs, "train", trainer.global_step)

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
if "val" in self.stages and self._should_save(batch_idx, trainer.current_epoch):
self.save(outputs, "val", trainer.global_step)
25 changes: 17 additions & 8 deletions cyto_dl/models/im2im/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ def _extract_loss(self, outs, loss_type):

def model_step(self, stage, batch, batch_idx):
run_heads, _ = self._get_run_heads(batch, stage, batch_idx)
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)

batch = self._to_tensor(batch)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)
outs = self.run_forward(batch, stage, n_postprocess, run_heads)

loss_D = self._extract_loss(outs, "loss_D")
loss_G = self._extract_loss(outs, "loss_G")
Expand All @@ -157,17 +159,24 @@ def model_step(self, stage, batch, batch_idx):
d_opt.zero_grad()
self.manual_backward(loss_D["loss"])
d_opt.step()
loss_dict = {f"discriminator_{key}": loss for key, loss in loss_D.items()}
loss_dict.update({f"generator_{key}": loss for key, loss in loss_G.items()})
loss_dict["loss"] = loss_dict["generator_loss"]
results = {f"discriminator_{key}": loss for key, loss in loss_D.items()}
results.update({f"generator_{key}": loss for key, loss in loss_G.items()})
results["loss"] = results["generator_loss"]

if n_postprocess > 0:
# add postprocessed images to return dict
for k in ("pred", "target", "input"):
results[k] = self.get_per_head(outs, k)

return loss_dict, None, None
self.compute_metrics(results, None, None, stage)
return results

def predict_step(self, batch, batch_idx):
stage = "predict"
run_heads, io_map = self._get_run_heads(batch, stage, batch_idx)
outs = None
if len(run_heads) > 0:
n_postprocess = self.get_n_postprocess_image(batch, batch_idx, stage)
batch = self._to_tensor(batch)
save_image = self.should_save_image(batch_idx, stage)
self.run_forward(batch, stage, save_image, run_heads)
return io_map
outs = self.run_forward(batch, stage, n_postprocess, run_heads)
return io_map, outs
Loading
Loading