From a9be90c353c6176cff3f4a1a5e739584f60e9097 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Mon, 25 Sep 2023 13:53:07 -0700 Subject: [PATCH] Feature/update transforms (#293) * off by one * add non-dict resize * fix configs * delete resize --------- Co-authored-by: Benjamin Morris --- configs/data/im2im/gan.yaml | 35 +++----- configs/data/im2im/labelfree.yaml | 30 +++---- configs/data/im2im/omnipose.yaml | 30 +++---- configs/data/im2im/segmentation.yaml | 30 +++---- configs/data/im2im/skoots.yaml | 30 +++---- cyto_dl/image/transforms/__init__.py | 2 +- .../image/transforms/multiscale_cropper.py | 2 +- cyto_dl/image/transforms/resize.py | 82 ------------------- 8 files changed, 62 insertions(+), 179 deletions(-) delete mode 100644 cyto_dl/image/transforms/resize.py diff --git a/configs/data/im2im/gan.yaml b/configs/data/im2im/gan.yaml index 4f90dd1b5..ac10ff891 100644 --- a/configs/data/im2im/gan.yaml +++ b/configs/data/im2im/gan.yaml @@ -33,12 +33,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} # GANs use Tanh as final activation, target has to be in range [-1,1] @@ -82,12 +79,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} # GANs use Tanh as final activation, target has to be in range [-1,1] @@ -118,11 +112,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${source_col} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${source_col} # input to synthetic image generation model is a semantic segmentation @@ -153,12 +145,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} # GANs use Tanh as final activation, target has to be in range [-1,1] diff --git a/configs/data/im2im/labelfree.yaml b/configs/data/im2im/labelfree.yaml index 6992c802f..922768a2f 100644 --- a/configs/data/im2im/labelfree.yaml +++ b/configs/data/im2im/labelfree.yaml @@ -31,12 +31,9 @@ transforms: C: 5 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd @@ -81,12 +78,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd @@ -104,6 +98,9 @@ transforms: C: 5 - _target_: monai.transforms.AddChanneld keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${source_col} - _target_: monai.transforms.NormalizeIntensityd @@ -130,12 +127,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd diff --git a/configs/data/im2im/omnipose.yaml b/configs/data/im2im/omnipose.yaml index 3e391401c..14f078937 100644 --- a/configs/data/im2im/omnipose.yaml +++ b/configs/data/im2im/omnipose.yaml @@ -31,12 +31,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd label_keys: ${target_col} dim: ${spatial_dims} @@ -84,12 +81,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd label_keys: ${target_col} dim: ${spatial_dims} @@ -111,6 +105,9 @@ transforms: C: 5 - _target_: monai.transforms.AddChanneld keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${source_col} - _target_: monai.transforms.NormalizeIntensityd @@ -137,12 +134,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.omnipose.OmniposePreprocessd label_keys: ${target_col} dim: ${spatial_dims} diff --git a/configs/data/im2im/segmentation.yaml b/configs/data/im2im/segmentation.yaml index 2015cf01c..6e60e9420 100644 --- a/configs/data/im2im/segmentation.yaml +++ b/configs/data/im2im/segmentation.yaml @@ -31,12 +31,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd @@ -86,12 +83,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd @@ -115,6 +109,9 @@ transforms: C: 5 - _target_: monai.transforms.AddChanneld keys: ${source_col} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${source_col} - _target_: monai.transforms.NormalizeIntensityd @@ -141,12 +138,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${data.columns} - _target_: monai.transforms.NormalizeIntensityd diff --git a/configs/data/im2im/skoots.yaml b/configs/data/im2im/skoots.yaml index 796968185..819661d31 100644 --- a/configs/data/im2im/skoots.yaml +++ b/configs/data/im2im/skoots.yaml @@ -31,12 +31,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd label_keys: ${target_col} dim: ${spatial_dims} @@ -84,12 +81,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd label_keys: ${target_col} dim: ${spatial_dims} @@ -111,6 +105,9 @@ transforms: C: 5 - _target_: monai.transforms.AddChanneld keys: ${data.columns} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: monai.transforms.ToTensord keys: ${source_col} - _target_: monai.transforms.NormalizeIntensityd @@ -137,12 +134,9 @@ transforms: C: 0 - _target_: monai.transforms.AddChanneld keys: ${data.columns} - - _target_: cyto_dl.image.transforms.Resized - keys: - - ${source_col} - - ${target_col} - scale_factor: 0.25 - spatial_dims: ${spatial_dims} + - _target_: monai.transforms.Zoomd + keys: ${data.columns} + zoom: 0.25 - _target_: cyto_dl.models.im2im.utils.SkootsPreprocessd label_keys: ${target_col} dim: ${spatial_dims} diff --git a/cyto_dl/image/transforms/__init__.py b/cyto_dl/image/transforms/__init__.py index d96597ebd..a5aa08a6f 100644 --- a/cyto_dl/image/transforms/__init__.py +++ b/cyto_dl/image/transforms/__init__.py @@ -2,7 +2,7 @@ from .contrastadjust import ContrastAdjustd from .multiscale_cropper import RandomMultiScaleCropd from .project import MaxProjectd -from .resize import Resized +from .resize import Resize, Resized from .save import Save, Saved try: diff --git a/cyto_dl/image/transforms/multiscale_cropper.py b/cyto_dl/image/transforms/multiscale_cropper.py index 7babed2bd..c25388475 100644 --- a/cyto_dl/image/transforms/multiscale_cropper.py +++ b/cyto_dl/image/transforms/multiscale_cropper.py @@ -109,7 +109,7 @@ def _get_max_start_indices(self, image_dict: Dict): max_start_indices = np.minimum(max_start_indices_img, max_start_indices) if np.any(max_start_indices < 0): raise ValueError(f"Crop size {roi_size} is too large for image size {shape}") - return max_start_indices + return max_start_indices + 1 # range doesn't include end def generate_slices(self, image_dict: Dict) -> Dict: """Generate dictionary of slices at all scales starting at random point.""" diff --git a/cyto_dl/image/transforms/resize.py b/cyto_dl/image/transforms/resize.py deleted file mode 100644 index 8175ac817..000000000 --- a/cyto_dl/image/transforms/resize.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Sequence, Union - -import numpy as np -import torch -from monai.data.meta_tensor import MetaTensor -from monai.transforms import Transform - - -class Resized(Transform): - """Transform to resize image by`scale_factor`""" - - def __init__( - self, - keys: Sequence[str], - scale_factor: Union[float, Sequence[float]], - spatial_dims: int = 3, - mode: str = "nearest-exact", - align_corners: Union[bool, None] = None, - recompute_scale_factor: bool = False, - antialias: bool = False, - allow_missing_keys: bool = False, - ): - """ - Parameters - ---------- - key: str - name of images to resize - scale_factor: int - output size will be `img.shape*scale_factor` - spatial_dims: int - whether inputs are 2d or 3d - mode: - interpolation method. For more details see: - https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html # noqa - align_corners: - see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html # noqa - recompute_scale_factor: - see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html # noqa - antialias: - see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html # noqa - """ - super().__init__() - assert spatial_dims in (2, 3), f"Patch must be 2D or 3D, got {spatial_dims}" - self.keys = keys - self.scale_factor = np.asarray(scale_factor) - assert self.scale_factor.size in ( - 1, - spatial_dims, - ), f"Scale factor must have length 1 or {spatial_dims}, got {len(self.scale_factor)}" - if self.scale_factor.size == 1: - self.scale_factor = np.tile(self.scale_factor, spatial_dims) - self.spatial_dims = spatial_dims - self.mode = mode - self.align_corners = align_corners - self.recompute_scale_factor = recompute_scale_factor - self.antialias = antialias - self.allow_missing_keys = allow_missing_keys - - def __call__(self, img): - for key in self.keys: - if key in img.keys(): - out_size = list( - map( - round, - np.asarray(img[key].shape[-self.spatial_dims :]) * self.scale_factor, - ) - ) - raw_img = img[key] - if len(raw_img.shape) != self.spatial_dims + 1: - raise ValueError("Images must have CZYX or CYX dimensions") - raw_img = raw_img.as_tensor() if isinstance(raw_img, MetaTensor) else raw_img - - img[key] = torch.nn.functional.interpolate( - input=raw_img.unsqueeze(0), - size=out_size, - mode=self.mode, - align_corners=self.align_corners, - antialias=self.antialias, - ).squeeze(0) - elif not self.allow_missing_keys: - raise KeyError(f"Key {key} not found in data. Available keys are {img.keys()}") - return img