From 4b299a193091a8db4635aa59acaed9bfea3670c2 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 21 Nov 2024 14:53:54 -0800 Subject: [PATCH 1/3] explicit target shape argument in the HCS data module --- viscy/data/hcs.py | 74 +++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 8831bdd7..37148d69 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -274,33 +274,51 @@ def __getitem__(self, index: int) -> Sample: class HCSDataModule(LightningDataModule): - """Lightning data module for a preprocessed HCS NGFF Store. - - :param str data_path: path to the data store - :param str | Sequence[str] source_channel: name(s) of the source channel, - e.g. ``'Phase'`` - :param str | Sequence[str] target_channel: name(s) of the target channel, - e.g. ``['Nuclei', 'Membrane']`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param float split_ratio: split ratio of the training subset in the fit stage, - e.g. 0.8 means a 80/20 split between training/validation, - by default 0.8 - :param int batch_size: batch size, defaults to 16 - :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, - defaults to "2.5D" - :param tuple[int, int] yx_patch_size: patch size in (Y, X), - defaults to (256, 256) - :param list[MapTransform] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to [] (no normalization) - :param list[MapTransform] augmentations: MONAI dictionary transforms - applied to the training set, defaults to [] (no augmentation) - :param bool caching: whether to decompress all the images and cache the result, - will store in ``/tmp/$SLURM_JOB_ID/`` if available, - defaults to False - :param Path | None ground_truth_masks: path to the ground truth masks, + """ + Lightning data module for a preprocessed HCS NGFF Store. + + Parameters + ---------- + data_path : str + Path to the data store. + source_channel : str or Sequence[str] + Name(s) of the source channel, e.g. 'Phase'. + target_channel : str or Sequence[str] + Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D. + split_ratio : float, optional + Split ratio of the training subset in the fit stage, + e.g. 0.8 means an 80/20 split between training/validation, + by default 0.8. + batch_size : int, optional + Batch size, defaults to 16. + num_workers : int, optional + Number of data-loading workers, defaults to 8. + architecture : Literal["2D", "UNeXt2", "2.5D", "3D"], optional + U-Net architecture, defaults to "2.5D". + yx_patch_size : tuple[int, int], optional + Patch size in (Y, X), defaults to (256, 256). + normalizations : list of MapTransform, optional + MONAI dictionary transforms applied to selected channels, + defaults to ``[]`` (no normalization). + augmentations : list of MapTransform, optional + MONAI dictionary transforms applied to the training set, + defaults to ``[]`` (no augmentation). + caching : bool, optional + Whether to decompress all the images and cache the result, + will store in `/tmp/$SLURM_JOB_ID/` if available, + defaults to False. + ground_truth_masks : Path or None, optional + Path to the ground truth masks, used in the test stage to compute segmentation metrics, - defaults to None + defaults to None. + persistent_workers : bool, optional + Whether to keep the workers alive between fitting epochs, + defaults to False. + prefetch_factor : int or None, optional + Number of samples loaded in advance by each worker during fitting, + defaults to None (2 per PyTorch default). """ def __init__( @@ -312,7 +330,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", + target_2d: bool = False, yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -327,7 +345,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.target_2d = target_2d self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size From dd28f33abb620b4ef9bc6ec9524a16bebcd1f41e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 21 Nov 2024 14:58:02 -0800 Subject: [PATCH 2/3] update docstring --- viscy/data/hcs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 37148d69..7f736066 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -295,8 +295,9 @@ class HCSDataModule(LightningDataModule): Batch size, defaults to 16. num_workers : int, optional Number of data-loading workers, defaults to 8. - architecture : Literal["2D", "UNeXt2", "2.5D", "3D"], optional - U-Net architecture, defaults to "2.5D". + target_2d : bool, optional + Whether the target is 2D (e.g. in a 2.5D model), + defaults to False. yx_patch_size : tuple[int, int], optional Patch size in (Y, X), defaults to (256, 256). normalizations : list of MapTransform, optional From aff24961ea31fc638e794f8b276b5d9e61774059 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 21 Nov 2024 15:13:30 -0800 Subject: [PATCH 3/3] update test cases --- tests/data/test_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index a75f8da8..c71488c4 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -58,7 +58,7 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio batch_size=batch_size, num_workers=0, augmentations=transforms, - architecture="3D", + target_2d=False, split_ratio=split_ratio, yx_patch_size=yx_patch_size, ) @@ -78,9 +78,9 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio ) -def test_datamodule_setup_predict(preprocessed_hcs_dataset): +@mark.parametrize("z_window_size", [1, 5]) +def test_datamodule_setup_predict(preprocessed_hcs_dataset, z_window_size): data_path = preprocessed_hcs_dataset - z_window_size = 5 channel_split = 2 with open_ome_zarr(data_path) as dataset: channel_names = dataset.channel_names @@ -91,6 +91,7 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): source_channel=channel_names[:channel_split], target_channel=channel_names[channel_split:], z_window_size=z_window_size, + target_2d=bool(z_window_size == 1), batch_size=2, num_workers=0, )