diff --git a/configs/data/im2im/gan.yaml b/configs/data/im2im/gan.yaml index ac10ff891..3b4e7bde0 100644 --- a/configs/data/im2im/gan.yaml +++ b/configs/data/im2im/gan.yaml @@ -31,7 +31,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -77,7 +78,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -110,7 +112,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${source_col} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -143,7 +146,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} diff --git a/configs/data/im2im/labelfree.yaml b/configs/data/im2im/labelfree.yaml index 922768a2f..469603184 100644 --- a/configs/data/im2im/labelfree.yaml +++ b/configs/data/im2im/labelfree.yaml @@ -29,7 +29,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 5 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -76,7 +77,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -96,7 +98,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 5 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -125,7 +128,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} diff --git a/configs/data/im2im/omnipose.yaml b/configs/data/im2im/omnipose.yaml index 14f078937..8000e53df 100644 --- a/configs/data/im2im/omnipose.yaml +++ b/configs/data/im2im/omnipose.yaml @@ -29,7 +29,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -79,7 +80,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -103,7 +105,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 5 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -132,7 +135,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} diff --git a/configs/data/im2im/segmentation.yaml b/configs/data/im2im/segmentation.yaml index 6e60e9420..a430a68d7 100644 --- a/configs/data/im2im/segmentation.yaml +++ b/configs/data/im2im/segmentation.yaml @@ -29,7 +29,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -81,7 +82,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -107,7 +109,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 5 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${source_col} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -136,7 +139,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} diff --git a/configs/data/im2im/skoots.yaml b/configs/data/im2im/skoots.yaml index 819661d31..4b916b116 100644 --- a/configs/data/im2im/skoots.yaml +++ b/configs/data/im2im/skoots.yaml @@ -29,7 +29,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -79,7 +80,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -103,7 +105,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 5 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} @@ -132,7 +135,8 @@ transforms: - _target_: cyto_dl.image.io.MonaiBioReader dimension_order_out: "ZYX" C: 0 - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: ${data.columns} - _target_: monai.transforms.Zoomd keys: ${data.columns} diff --git a/configs/data/private/variance_npm1.yaml b/configs/data/private/variance_npm1.yaml index 432e2d7a6..8386544f2 100644 --- a/configs/data/private/variance_npm1.yaml +++ b/configs/data/private/variance_npm1.yaml @@ -31,7 +31,8 @@ transforms: threshold: -2. above: true cval: -2. - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image - _target_: monai.transforms.ToTensord keys: [image] diff --git a/configs/experiment/private/lattice_nuc_sdf.yaml b/configs/experiment/private/lattice_nuc_sdf.yaml index cb0d26dc4..ae42a8226 100644 --- a/configs/experiment/private/lattice_nuc_sdf.yaml +++ b/configs/experiment/private/lattice_nuc_sdf.yaml @@ -33,7 +33,8 @@ data: - _target_: monai.transforms.Transposed keys: [image] indices: [2, 1, 0] - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image - _target_: monai.transforms.ThresholdIntensityd keys: [image] diff --git a/configs/experiment/private/npm1_classical_scale_inv.yaml b/configs/experiment/private/npm1_classical_scale_inv.yaml index 9eb1933e9..327c6e03b 100644 --- a/configs/experiment/private/npm1_classical_scale_inv.yaml +++ b/configs/experiment/private/npm1_classical_scale_inv.yaml @@ -30,7 +30,8 @@ data: - _target_: cyto_dl.image.io.ReadNumpyFile keys: [image] remote: false - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image - _target_: monai.transforms.NormalizeIntensityd keys: image diff --git a/configs/experiment/private/npm1_scale_inv_canon_so3.yaml b/configs/experiment/private/npm1_scale_inv_canon_so3.yaml index 34f8082ce..5b629e734 100644 --- a/configs/experiment/private/npm1_scale_inv_canon_so3.yaml +++ b/configs/experiment/private/npm1_scale_inv_canon_so3.yaml @@ -30,7 +30,8 @@ data: - _target_: cyto_dl.image.io.ReadNumpyFile keys: [image] remote: false - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image - _target_: monai.transforms.NormalizeIntensityd keys: image diff --git a/configs/experiment/private/npm1_so2_scale_inv.yaml b/configs/experiment/private/npm1_so2_scale_inv.yaml index 310e64c17..2e80b216d 100644 --- a/configs/experiment/private/npm1_so2_scale_inv.yaml +++ b/configs/experiment/private/npm1_so2_scale_inv.yaml @@ -33,7 +33,8 @@ data: - _target_: cyto_dl.image.io.ReadNumpyFile keys: [image] remote: false - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image # - _target_: monai.transforms.ThresholdIntensityd # keys: image diff --git a/configs/experiment/private/npm1_so3_scale_inv.yaml b/configs/experiment/private/npm1_so3_scale_inv.yaml index 237b3ac4c..1b1d5cae5 100644 --- a/configs/experiment/private/npm1_so3_scale_inv.yaml +++ b/configs/experiment/private/npm1_so3_scale_inv.yaml @@ -30,7 +30,8 @@ data: - _target_: cyto_dl.image.io.ReadNumpyFile keys: [image] remote: false - - _target_: monai.transforms.AddChanneld + - _target_: monai.transforms.EnsureChannelFirstd + channel_dim: "no_channel" keys: image - _target_: monai.transforms.NormalizeIntensityd keys: image diff --git a/cyto_dl/callbacks/outlier_detection.py b/cyto_dl/callbacks/outlier_detection.py index ef706ff0e..789ca4090 100644 --- a/cyto_dl/callbacks/outlier_detection.py +++ b/cyto_dl/callbacks/outlier_detection.py @@ -118,7 +118,7 @@ def on_predict_epoch_start(self, trainer, pl_module): def _inference_batch_end(self, batch): if self._run: - batch_names = batch["raw_meta_dict"]["filename_or_obj"] + batch_names = batch["raw"].meta["filename_or_obj"] # activations are saved per-patch distances_per_image = len(self.mahalanobis_distances[self.layer_names[0]]) // len( batch_names diff --git a/cyto_dl/datamodules/czi.py b/cyto_dl/datamodules/czi.py index 30ba4f9ed..dc1f582f2 100644 --- a/cyto_dl/datamodules/czi.py +++ b/cyto_dl/datamodules/czi.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from aicsimageio.aics_image import AICSImage -from monai.data import DataLoader, Dataset +from monai.data import DataLoader, Dataset, MetaTensor from monai.transforms import Compose, apply_transform from omegaconf import ListConfig @@ -129,13 +129,17 @@ def _transform(self, index: int): img.set_scene(img_data.pop("scene")) data_i = img.get_image_dask_data(**img_data).compute() data_i = self._ensure_channel_first(data_i) + output_img = ( + apply_transform(self.transform, data_i) if self.transform is not None else data_i + ) + return { - self.out_key: apply_transform(self.transform, data_i) - if self.transform is not None - else data_i, - f"{self.out_key}_meta_dict": { - "filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data)) - }, + self.out_key: MetaTensor( + output_img, + meta={ + "filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data)) + }, + ) } def __len__(self): diff --git a/cyto_dl/image/io/aicsimage_loader.py b/cyto_dl/image/io/aicsimage_loader.py index 43ce6d0ac..eb40a1277 100644 --- a/cyto_dl/image/io/aicsimage_loader.py +++ b/cyto_dl/image/io/aicsimage_loader.py @@ -1,6 +1,7 @@ from typing import List from aicsimageio import AICSImage +from monai.data import MetaTensor from monai.transforms import Transform @@ -52,7 +53,6 @@ def __call__(self, data): img.set_scene(data[self.scene_key]) kwargs = {k: data[k] for k in self.kwargs_keys} img = img.get_image_dask_data(**kwargs).compute() - data[self.out_key] = img - data[f"{self.out_key}_meta_dict"] = {"filename_or_obj": path, "kwargs": kwargs} + data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs}) return data diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 12a2d5f94..b1e40b7a0 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -174,6 +174,7 @@ def _extract_loss(self, outs, loss_type): return self._sum_losses(loss) def model_step(self, stage, batch, batch_idx): + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): @@ -210,6 +211,7 @@ def model_step(self, stage, batch, batch_idx): return loss_dict, None, None def predict_step(self, batch, batch_idx): + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index 88ff6ddc5..24a9180c3 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -168,6 +168,7 @@ def _get_run_heads(self, batch, stage): def model_step(self, stage, batch, batch_idx): # convert monai metatensors to tensors + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] for k, v in batch.items(): if isinstance(v, MetaTensor): batch[k] = v.as_tensor() @@ -175,17 +176,13 @@ def model_step(self, stage, batch, batch_idx): run_heads = self._get_run_heads(batch, stage) outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - if stage != "predict": - losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} - losses = self._sum_losses(losses) - return losses, None, None - - preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()} - - return None, preds, None + losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} + losses = self._sum_losses(losses) + return losses, None, None def predict_step(self, batch, batch_idx): stage = "predict" + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index d65e7242c..179f5ad30 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -68,15 +68,13 @@ def save_image(self, y_hat, batch, stage, global_step): if self.save_raw: raw_out = self._postprocess(batch[self.x_key], img_type="input") try: - metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"] + metadata_filenames = batch["filenames"] filename_map = {"input": metadata_filenames, "output": []} metadata_filenames = [ f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames ] except KeyError: - raise ValueError( - f"Please ensure your batches contain key `{self.x_key}_meta_dict['filename_or_obj']`" - ) + raise ValueError("Please ensure your batches have key `filenames`") save_name = ( [f"{global_step}_{self.head_name}.tif"] if stage in ("train", "val") diff --git a/tests/test_alternating_batch_sampler.py b/tests/test_alternating_batch_sampler.py index 48acdb7e0..929590828 100644 --- a/tests/test_alternating_batch_sampler.py +++ b/tests/test_alternating_batch_sampler.py @@ -33,12 +33,14 @@ def __iter__(self): def test_alternating_batch_sampler(shuffle, sampler): root_dir = pyrootutils.find_root() transforms = monai.transforms.Compose( - RemoveNaNKeysd(), - monai.transforms.LoadImaged( - keys=["raw", "seg1", "seg2"], - reader=MonaiBioReader(dimension_order_out="CZYX"), - allow_missing_keys=True, - ), + [ + RemoveNaNKeysd(), + monai.transforms.LoadImaged( + keys=["raw", "seg1", "seg2"], + reader=MonaiBioReader(dimension_order_out="CZYX"), + allow_missing_keys=True, + ), + ] ) transform_dict = {key: transforms for key in ["train", "test", "val", "predict"]} data = make_multiple_dataframe_splits(root_dir / "tests" / "resources", transform_dict)