Skip to content

Commit

Permalink
update to new monai (#297)
Browse files Browse the repository at this point in the history
* addchanneld->ensurechannelfirst

* update metatensor metadata handling

* add brackets

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Oct 19, 2023
1 parent b5d1fe2 commit e5c64e9
Show file tree
Hide file tree
Showing 18 changed files with 83 additions and 54 deletions.
12 changes: 8 additions & 4 deletions configs/data/im2im/gan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -77,7 +78,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -110,7 +112,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${source_col}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -143,7 +146,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down
12 changes: 8 additions & 4 deletions configs/data/im2im/labelfree.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 5
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -76,7 +77,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand All @@ -96,7 +98,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 5
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -125,7 +128,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down
12 changes: 8 additions & 4 deletions configs/data/im2im/omnipose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -79,7 +80,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand All @@ -103,7 +105,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 5
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -132,7 +135,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down
12 changes: 8 additions & 4 deletions configs/data/im2im/segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -81,7 +82,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand All @@ -107,7 +109,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 5
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${source_col}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -136,7 +139,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down
12 changes: 8 additions & 4 deletions configs/data/im2im/skoots.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -79,7 +80,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand All @@ -103,7 +105,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 5
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down Expand Up @@ -132,7 +135,8 @@ transforms:
- _target_: cyto_dl.image.io.MonaiBioReader
dimension_order_out: "ZYX"
C: 0
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: ${data.columns}
- _target_: monai.transforms.Zoomd
keys: ${data.columns}
Expand Down
3 changes: 2 additions & 1 deletion configs/data/private/variance_npm1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ transforms:
threshold: -2.
above: true
cval: -2.
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
- _target_: monai.transforms.ToTensord
keys: [image]
Expand Down
3 changes: 2 additions & 1 deletion configs/experiment/private/lattice_nuc_sdf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ data:
- _target_: monai.transforms.Transposed
keys: [image]
indices: [2, 1, 0]
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
- _target_: monai.transforms.ThresholdIntensityd
keys: [image]
Expand Down
3 changes: 2 additions & 1 deletion configs/experiment/private/npm1_classical_scale_inv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ data:
- _target_: cyto_dl.image.io.ReadNumpyFile
keys: [image]
remote: false
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
- _target_: monai.transforms.NormalizeIntensityd
keys: image
Expand Down
3 changes: 2 additions & 1 deletion configs/experiment/private/npm1_scale_inv_canon_so3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ data:
- _target_: cyto_dl.image.io.ReadNumpyFile
keys: [image]
remote: false
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
- _target_: monai.transforms.NormalizeIntensityd
keys: image
Expand Down
3 changes: 2 additions & 1 deletion configs/experiment/private/npm1_so2_scale_inv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ data:
- _target_: cyto_dl.image.io.ReadNumpyFile
keys: [image]
remote: false
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
# - _target_: monai.transforms.ThresholdIntensityd
# keys: image
Expand Down
3 changes: 2 additions & 1 deletion configs/experiment/private/npm1_so3_scale_inv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ data:
- _target_: cyto_dl.image.io.ReadNumpyFile
keys: [image]
remote: false
- _target_: monai.transforms.AddChanneld
- _target_: monai.transforms.EnsureChannelFirstd
channel_dim: "no_channel"
keys: image
- _target_: monai.transforms.NormalizeIntensityd
keys: image
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/callbacks/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def on_predict_epoch_start(self, trainer, pl_module):

def _inference_batch_end(self, batch):
if self._run:
batch_names = batch["raw_meta_dict"]["filename_or_obj"]
batch_names = batch["raw"].meta["filename_or_obj"]
# activations are saved per-patch
distances_per_image = len(self.mahalanobis_distances[self.layer_names[0]]) // len(
batch_names
Expand Down
18 changes: 11 additions & 7 deletions cyto_dl/datamodules/czi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
from aicsimageio.aics_image import AICSImage
from monai.data import DataLoader, Dataset
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, apply_transform
from omegaconf import ListConfig

Expand Down Expand Up @@ -129,13 +129,17 @@ def _transform(self, index: int):
img.set_scene(img_data.pop("scene"))
data_i = img.get_image_dask_data(**img_data).compute()
data_i = self._ensure_channel_first(data_i)
output_img = (
apply_transform(self.transform, data_i) if self.transform is not None else data_i
)

return {
self.out_key: apply_transform(self.transform, data_i)
if self.transform is not None
else data_i,
f"{self.out_key}_meta_dict": {
"filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
},
self.out_key: MetaTensor(
output_img,
meta={
"filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
},
)
}

def __len__(self):
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/image/io/aicsimage_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

from aicsimageio import AICSImage
from monai.data import MetaTensor
from monai.transforms import Transform


Expand Down Expand Up @@ -52,7 +53,6 @@ def __call__(self, data):
img.set_scene(data[self.scene_key])
kwargs = {k: data[k] for k in self.kwargs_keys}
img = img.get_image_dask_data(**kwargs).compute()
data[self.out_key] = img
data[f"{self.out_key}_meta_dict"] = {"filename_or_obj": path, "kwargs": kwargs}
data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

return data
2 changes: 2 additions & 0 deletions cyto_dl/models/im2im/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _extract_loss(self, outs, loss_type):
return self._sum_losses(loss)

def model_step(self, stage, batch, batch_idx):
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down Expand Up @@ -210,6 +211,7 @@ def model_step(self, stage, batch, batch_idx):
return loss_dict, None, None

def predict_step(self, batch, batch_idx):
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down
13 changes: 5 additions & 8 deletions cyto_dl/models/im2im/multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,21 @@ def _get_run_heads(self, batch, stage):

def model_step(self, stage, batch, batch_idx):
# convert monai metatensors to tensors
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
for k, v in batch.items():
if isinstance(v, MetaTensor):
batch[k] = v.as_tensor()

run_heads = self._get_run_heads(batch, stage)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)

if stage != "predict":
losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()}
losses = self._sum_losses(losses)
return losses, None, None

preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()}

return None, preds, None
losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()}
losses = self._sum_losses(losses)
return losses, None, None

def predict_step(self, batch, batch_idx):
stage = "predict"
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down
6 changes: 2 additions & 4 deletions cyto_dl/nn/head/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,13 @@ def save_image(self, y_hat, batch, stage, global_step):
if self.save_raw:
raw_out = self._postprocess(batch[self.x_key], img_type="input")
try:
metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"]
metadata_filenames = batch["filenames"]
filename_map = {"input": metadata_filenames, "output": []}
metadata_filenames = [
f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames
]
except KeyError:
raise ValueError(
f"Please ensure your batches contain key `{self.x_key}_meta_dict['filename_or_obj']`"
)
raise ValueError("Please ensure your batches have key `filenames`")
save_name = (
[f"{global_step}_{self.head_name}.tif"]
if stage in ("train", "val")
Expand Down
Loading

0 comments on commit e5c64e9

Please sign in to comment.