Skip to content

Commit

Permalink
Bugfix/transforms (#384)
Browse files Browse the repository at this point in the history
* allow comma separated lists for args

* make torch/numpy agnostic

* convert keys to list

* lint

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Jul 2, 2024
1 parent 3cbdd91 commit 4b4e37a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
9 changes: 7 additions & 2 deletions cyto_dl/image/io/aicsimage_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
scene_key : str = "scene"
Key for the scene number
kwargs_keys : List = ["dimension_order_out", "C", "T"]
Keys for the kwargs to pass to BioImage.get_image_dask_data
Keys for the kwargs to pass to BioImage.get_image_dask_data. Values in the csv can be comma separated list.
out_key : str = "raw"
Key for the output image
allow_missing_keys : bool = False
Expand All @@ -49,6 +49,11 @@ def __init__(
self.dtype = dtype
self.dask_load = dask_load

def split_args(self, arg):
if "," in str(arg):
return list(map(int, arg.split(",")))
return arg

def __call__(self, data):
# copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time
data = data.copy()
Expand All @@ -58,7 +63,7 @@ def __call__(self, data):
img = BioImage(path)
if self.scene_key in data:
img.set_scene(data[self.scene_key])
kwargs = {k: data[k] for k in self.kwargs_keys}
kwargs = {k: self.split_args(data[k]) for k in self.kwargs_keys if k in data}
if self.dask_load:
img = img.get_image_dask_data(**kwargs).compute()
else:
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/image/transforms/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __call__(self, img):
low = percentile(img, low)
high = percentile(img, high)

return torch.clip(img, low, high)
return clip(img, low, high)


class Clipd(Transform):
Expand Down
3 changes: 2 additions & 1 deletion cyto_dl/image/transforms/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from monai.transforms import Transform
from omegaconf import ListConfig


class MaxProjectd(Transform):
Expand All @@ -24,7 +25,7 @@ def __init__(
Whether to raise error if specified key is missing
"""
super().__init__()
self.keys = keys
self.keys = keys if isinstance(keys, (list, ListConfig)) else [keys]
self.allow_missing_keys = allow_missing_keys
self.projection_dim = projection_dim

Expand Down

0 comments on commit 4b4e37a

Please sign in to comment.