From 3ffc8072466dbc9ad69b661a44afde47ad442d8a Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Mon, 18 Mar 2024 14:06:34 -0700 Subject: [PATCH] add dask_load option (#353) Co-authored-by: Benjamin Morris --- cyto_dl/datamodules/czi.py | 9 ++++++++- cyto_dl/image/io/aicsimage_loader.py | 10 +++++++++- cyto_dl/image/io/monai_bio_reader.py | 14 ++++++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/cyto_dl/datamodules/czi.py b/cyto_dl/datamodules/czi.py index 6c5a455a3..1d6071138 100644 --- a/cyto_dl/datamodules/czi.py +++ b/cyto_dl/datamodules/czi.py @@ -25,6 +25,7 @@ def __init__( time_stop_column: str = "stop", time_step_column: str = "step", transform: Optional[Callable] = None, + dask_load: bool = True, ): """ Parameters @@ -53,6 +54,8 @@ def __init__( If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted. transform: Optional[Callable] = None Callable to that accepts numpy array. For example, image normalization functions could be passed here. + dask_load: bool = True + Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints. """ super().__init__(None, transform) df = pd.read_csv(csv_path) @@ -66,6 +69,7 @@ def __init__( if spatial_dims not in (2, 3): raise ValueError(f"`spatial_dims` must be 2 or 3, got {spatial_dims}") self.spatial_dims = spatial_dims + self.dask_load = dask_load self.img_data = self.get_per_file_args(df) @@ -126,7 +130,10 @@ def _transform(self, index: int): original_path = img_data.pop("original_path") scene = img_data.pop("scene") img.set_scene(scene) - data_i = img.get_image_dask_data(**img_data).compute() + if self.dask_load: + data_i = img.get_image_dask_data(**img_data).compute() + else: + data_i = img.get_image_data(**img_data) img_data["scene"] = scene data_i = self._ensure_channel_first(data_i) output_img = ( diff --git a/cyto_dl/image/io/aicsimage_loader.py b/cyto_dl/image/io/aicsimage_loader.py index 07625ca73..281aee730 100644 --- a/cyto_dl/image/io/aicsimage_loader.py +++ b/cyto_dl/image/io/aicsimage_loader.py @@ -22,6 +22,7 @@ def __init__( out_key: str = "raw", allow_missing_keys=False, dtype: np.dtype = np.float16, + dask_load: bool = True, ): """ Parameters @@ -36,6 +37,8 @@ def __init__( Key for the output image allow_missing_keys : bool = False Whether to allow missing keys in the data dictionary + dask_load: bool = True + Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints. """ super().__init__() self.path_key = path_key @@ -44,6 +47,7 @@ def __init__( self.out_key = out_key self.scene_key = scene_key self.dtype = dtype + self.dask_load = dask_load def __call__(self, data): # copying prevents the dataset from being modified inplace - important when using partially cached datasets so that the memory use doesn't increase over time @@ -55,7 +59,11 @@ def __call__(self, data): if self.scene_key in data: img.set_scene(data[self.scene_key]) kwargs = {k: data[k] for k in self.kwargs_keys} - img = img.get_image_dask_data(**kwargs).compute().astype(self.dtype) + if self.dask_load: + img = img.get_image_dask_data(**kwargs).compute() + else: + img = img.get_image_data(**kwargs) + img = img.astype(self.dtype) data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs}) return data diff --git a/cyto_dl/image/io/monai_bio_reader.py b/cyto_dl/image/io/monai_bio_reader.py index b70083fe7..31195a996 100644 --- a/cyto_dl/image/io/monai_bio_reader.py +++ b/cyto_dl/image/io/monai_bio_reader.py @@ -13,13 +13,20 @@ @require_pkg(pkg_name="aicsimageio") class MonaiBioReader(ImageReader): - def __init__(self, **reader_kwargs): + def __init__(self, dask_load: bool = True, **reader_kwargs): + """ + dask_load: bool = True + Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints. + reader_kwargs: Dict + Additional keyword arguments to pass to AICSImage.get_image_data or AICSImage.get_image_dask_data + """ super().__init__() self.reader_kwargs = { k: OmegaConf.to_container(v) if isinstance(v, Container) else v for k, v in reader_kwargs.items() if v is not None } + self.dask_load = dask_load def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) @@ -33,7 +40,10 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]: img_array: List[np.ndarray] = [] for img_obj in ensure_tuple(img): - data = img_obj.get_image_dask_data(**self.reader_kwargs).compute() + if self.dask_load: + data = img_obj.get_image_dask_data(**self.reader_kwargs).compute() + else: + data = img_obj.get_image_data(**self.reader_kwargs) img_array.append(data) return _stack_images(img_array, {}), {}