From 55026e4d1458024de88f1aba6e5fab4eeae283b0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:01:40 -0800 Subject: [PATCH] Explicit target shape argument in the HCS data module (#212) * explicit target shape argument in the HCS data module * update docstring * update test cases --- tests/data/test_data.py | 7 ++-- viscy/data/hcs.py | 75 ++++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/tests/data/test_data.py b/tests/data/test_data.py index a75f8da8..c71488c4 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -58,7 +58,7 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio batch_size=batch_size, num_workers=0, augmentations=transforms, - architecture="3D", + target_2d=False, split_ratio=split_ratio, yx_patch_size=yx_patch_size, ) @@ -78,9 +78,9 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio ) -def test_datamodule_setup_predict(preprocessed_hcs_dataset): +@mark.parametrize("z_window_size", [1, 5]) +def test_datamodule_setup_predict(preprocessed_hcs_dataset, z_window_size): data_path = preprocessed_hcs_dataset - z_window_size = 5 channel_split = 2 with open_ome_zarr(data_path) as dataset: channel_names = dataset.channel_names @@ -91,6 +91,7 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): source_channel=channel_names[:channel_split], target_channel=channel_names[channel_split:], z_window_size=z_window_size, + target_2d=bool(z_window_size == 1), batch_size=2, num_workers=0, ) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index c4087941..88111cc3 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -274,33 +274,52 @@ def __getitem__(self, index: int) -> Sample: class HCSDataModule(LightningDataModule): - """Lightning data module for a preprocessed HCS NGFF Store. 
- - :param str data_path: path to the data store - :param str | Sequence[str] source_channel: name(s) of the source channel, - e.g. ``'Phase'`` - :param str | Sequence[str] target_channel: name(s) of the target channel, - e.g. ``['Nuclei', 'Membrane']`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param float split_ratio: split ratio of the training subset in the fit stage, - e.g. 0.8 means a 80/20 split between training/validation, - by default 0.8 - :param int batch_size: batch size, defaults to 16 - :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, - defaults to "2.5D" - :param tuple[int, int] yx_patch_size: patch size in (Y, X), - defaults to (256, 256) - :param list[MapTransform] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to [] (no normalization) - :param list[MapTransform] augmentations: MONAI dictionary transforms - applied to the training set, defaults to [] (no augmentation) - :param bool caching: whether to decompress all the images and cache the result, - will store in ``/tmp/$SLURM_JOB_ID/`` if available, - defaults to False - :param Path | None ground_truth_masks: path to the ground truth masks, + """ + Lightning data module for a preprocessed HCS NGFF Store. + + Parameters + ---------- + data_path : str + Path to the data store. + source_channel : str or Sequence[str] + Name(s) of the source channel, e.g. 'Phase'. + target_channel : str or Sequence[str] + Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D. + split_ratio : float, optional + Split ratio of the training subset in the fit stage, + e.g. 0.8 means an 80/20 split between training/validation, + by default 0.8. + batch_size : int, optional + Batch size, defaults to 16. + num_workers : int, optional + Number of data-loading workers, defaults to 8. 
+ target_2d : bool, optional + Whether the target is 2D (e.g. in a 2.5D model), + defaults to False. + yx_patch_size : tuple[int, int], optional + Patch size in (Y, X), defaults to (256, 256). + normalizations : list of MapTransform, optional + MONAI dictionary transforms applied to selected channels, + defaults to ``[]`` (no normalization). + augmentations : list of MapTransform, optional + MONAI dictionary transforms applied to the training set, + defaults to ``[]`` (no augmentation). + caching : bool, optional + Whether to decompress all the images and cache the result, + will store in `/tmp/$SLURM_JOB_ID/` if available, + defaults to False. + ground_truth_masks : Path or None, optional + Path to the ground truth masks, used in the test stage to compute segmentation metrics, - defaults to None + defaults to None. + persistent_workers : bool, optional + Whether to keep the workers alive between fitting epochs, + defaults to False. + prefetch_factor : int or None, optional + Number of batches loaded in advance by each worker during fitting, + defaults to None (PyTorch's own default, 2 when ``num_workers > 0``). """ def __init__( @@ -312,7 +331,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", + target_2d: bool = False, yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -327,7 +346,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.target_2d = target_2d self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size