add dask_load option (#353)
Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Mar 18, 2024
1 parent 61d2015 commit 3ffc807
Showing 3 changed files with 29 additions and 4 deletions.
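The commit threads a dask_load flag through three image loaders so callers can choose between lazy (dask) and eager reads. As a minimal sketch of the two aicsimageio code paths being toggled (the file path and dimension selection are hypothetical):

from aicsimageio import AICSImage

img = AICSImage("example.czi")  # hypothetical file
img.set_scene(0)

# dask_load=True: build a lazy dask graph and pull only the requested
# dimensions/timepoints into memory when .compute() is called
lazy = img.get_image_dask_data("CZYX", T=0).compute()

# dask_load=False: read the image eagerly, then slice out the requested
# dimensions/timepoints from the in-memory array
eager = img.get_image_data("CZYX", T=0)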
9 changes: 8 additions & 1 deletion cyto_dl/datamodules/czi.py
@@ -25,6 +25,7 @@ def __init__(
time_stop_column: str = "stop",
time_step_column: str = "step",
transform: Optional[Callable] = None,
dask_load: bool = True,
):
"""
Parameters
@@ -53,6 +54,8 @@
If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
transform: Optional[Callable] = None
Callable that accepts a numpy array. For example, image normalization functions could be passed here.
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
"""
super().__init__(None, transform)
df = pd.read_csv(csv_path)
@@ -66,6 +69,7 @@
if spatial_dims not in (2, 3):
raise ValueError(f"`spatial_dims` must be 2 or 3, got {spatial_dims}")
self.spatial_dims = spatial_dims
self.dask_load = dask_load

self.img_data = self.get_per_file_args(df)

@@ -126,7 +130,10 @@ def _transform(self, index: int):
original_path = img_data.pop("original_path")
scene = img_data.pop("scene")
img.set_scene(scene)
data_i = img.get_image_dask_data(**img_data).compute()
if self.dask_load:
data_i = img.get_image_dask_data(**img_data).compute()
else:
data_i = img.get_image_data(**img_data)
img_data["scene"] = scene
data_i = self._ensure_channel_first(data_i)
output_img = (
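The _transform change above falls back to AICSImage.get_image_data when dask_load is False. A hedged instantiation sketch for the dataset this file defines; the class name CZIDataset and the manifest filename are assumptions for illustration, not confirmed by the commit:

from cyto_dl.datamodules.czi import CZIDataset  # class name assumed

dataset = CZIDataset(
    csv_path="manifest.csv",  # hypothetical CSV with path/scene/timepoint columns
    spatial_dims=3,
    transform=None,
    dask_load=False,  # new flag: load full images eagerly instead of through dask
)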
10 changes: 9 additions & 1 deletion cyto_dl/image/io/aicsimage_loader.py
@@ -22,6 +22,7 @@ def __init__(
out_key: str = "raw",
allow_missing_keys=False,
dtype: np.dtype = np.float16,
dask_load: bool = True,
):
"""
Parameters
@@ -36,6 +37,8 @@
Key for the output image
allow_missing_keys : bool = False
Whether to allow missing keys in the data dictionary
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
"""
super().__init__()
self.path_key = path_key
@@ -44,6 +47,7 @@
self.out_key = out_key
self.scene_key = scene_key
self.dtype = dtype
self.dask_load = dask_load

def __call__(self, data):
# copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time
@@ -55,7 +59,11 @@ def __call__(self, data):
if self.scene_key in data:
img.set_scene(data[self.scene_key])
kwargs = {k: data[k] for k in self.kwargs_keys}
img = img.get_image_dask_data(**kwargs).compute().astype(self.dtype)
if self.dask_load:
img = img.get_image_dask_data(**kwargs).compute()
else:
img = img.get_image_data(**kwargs)
img = img.astype(self.dtype)
data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

return data
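A hedged sketch of using the dictionary-style loader above; the class name AICSImageLoaderd, the key names, and the file path are illustrative assumptions based on the parameters visible in this diff:

from cyto_dl.image.io.aicsimage_loader import AICSImageLoaderd  # class name assumed

loader = AICSImageLoaderd(
    path_key="path",
    kwargs_keys=["dimension_order_out", "C", "T"],  # values pulled from the data dict and forwarded to get_image_(dask_)data
    out_key="raw",
    dask_load=False,  # new flag: eager read, then cast to dtype
)
sample = loader({"path": "img.ome.tiff", "dimension_order_out": "ZYX", "C": 0, "T": 0})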
14 changes: 12 additions & 2 deletions cyto_dl/image/io/monai_bio_reader.py
@@ -13,13 +13,20 @@

@require_pkg(pkg_name="aicsimageio")
class MonaiBioReader(ImageReader):
def __init__(self, **reader_kwargs):
def __init__(self, dask_load: bool = True, **reader_kwargs):
"""
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
reader_kwargs: Dict
Additional keyword arguments to pass to AICSImage.get_image_data or AICSImage.get_image_dask_data
"""
super().__init__()
self.reader_kwargs = {
k: OmegaConf.to_container(v) if isinstance(v, Container) else v
for k, v in reader_kwargs.items()
if v is not None
}
self.dask_load = dask_load

def read(self, data: Union[Sequence[PathLike], PathLike]):
filenames: Sequence[PathLike] = ensure_tuple(data)
@@ -33,7 +40,10 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]:
img_array: List[np.ndarray] = []

for img_obj in ensure_tuple(img):
data = img_obj.get_image_dask_data(**self.reader_kwargs).compute()
if self.dask_load:
data = img_obj.get_image_dask_data(**self.reader_kwargs).compute()
else:
data = img_obj.get_image_data(**self.reader_kwargs)
img_array.append(data)

return _stack_images(img_array, {}), {}
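A hedged usage sketch for the reader above: passing an ImageReader instance to MONAI's LoadImaged is standard usage, but the reader kwargs and file path here are illustrative:

from monai.transforms import LoadImaged
from cyto_dl.image.io.monai_bio_reader import MonaiBioReader

load = LoadImaged(
    keys=["image"],
    reader=MonaiBioReader(dimension_order_out="CZYX", C=0, dask_load=False),  # kwargs forwarded to get_image_data when dask_load=False
)
data = load({"image": "example.ome.tiff"})  # hypothetical path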
