Skip to content

Commit

Permalink
Explicit target shape argument in the HCS data module (#212)
Browse files Browse the repository at this point in the history
* explicit target shape argument in the HCS data module

* update docstring

* update test cases
  • Loading branch information
ziw-liu authored and edyoshikun committed Dec 19, 2024
1 parent a845847 commit a6e2318
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 31 deletions.
7 changes: 4 additions & 3 deletions tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio
batch_size=batch_size,
num_workers=0,
augmentations=transforms,
architecture="3D",
target_2d=False,
split_ratio=split_ratio,
yx_patch_size=yx_patch_size,
)
Expand All @@ -78,9 +78,9 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio
)


def test_datamodule_setup_predict(preprocessed_hcs_dataset):
@mark.parametrize("z_window_size", [1, 5])
def test_datamodule_setup_predict(preprocessed_hcs_dataset, z_window_size):
data_path = preprocessed_hcs_dataset
z_window_size = 5
channel_split = 2
with open_ome_zarr(data_path) as dataset:
channel_names = dataset.channel_names
Expand All @@ -91,6 +91,7 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset):
source_channel=channel_names[:channel_split],
target_channel=channel_names[channel_split:],
z_window_size=z_window_size,
target_2d=bool(z_window_size == 1),
batch_size=2,
num_workers=0,
)
Expand Down
75 changes: 47 additions & 28 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,33 +274,52 @@ def __getitem__(self, index: int) -> Sample:


class HCSDataModule(LightningDataModule):
"""Lightning data module for a preprocessed HCS NGFF Store.
:param str data_path: path to the data store
:param str | Sequence[str] source_channel: name(s) of the source channel,
e.g. ``'Phase'``
:param str | Sequence[str] target_channel: name(s) of the target channel,
e.g. ``['Nuclei', 'Membrane']``
:param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
:param float split_ratio: split ratio of the training subset in the fit stage,
e.g. 0.8 means a 80/20 split between training/validation,
by default 0.8
:param int batch_size: batch size, defaults to 16
:param int num_workers: number of data-loading workers, defaults to 8
:param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture,
defaults to "2.5D"
:param tuple[int, int] yx_patch_size: patch size in (Y, X),
defaults to (256, 256)
:param list[MapTransform] normalizations: MONAI dictionary transforms
applied to selected channels, defaults to [] (no normalization)
:param list[MapTransform] augmentations: MONAI dictionary transforms
applied to the training set, defaults to [] (no augmentation)
:param bool caching: whether to decompress all the images and cache the result,
will store in ``/tmp/$SLURM_JOB_ID/`` if available,
defaults to False
:param Path | None ground_truth_masks: path to the ground truth masks,
"""
Lightning data module for a preprocessed HCS NGFF Store.
Parameters
----------
data_path : str
Path to the data store.
source_channel : str or Sequence[str]
Name(s) of the source channel, e.g. 'Phase'.
target_channel : str or Sequence[str]
Name(s) of the target channel, e.g. ['Nuclei', 'Membrane'].
z_window_size : int
Z window size of the 2.5D U-Net, 1 for 2D.
split_ratio : float, optional
Split ratio of the training subset in the fit stage,
e.g. 0.8 means an 80/20 split between training/validation,
by default 0.8.
batch_size : int, optional
Batch size, defaults to 16.
num_workers : int, optional
Number of data-loading workers, defaults to 8.
target_2d : bool, optional
Whether the target is 2D (e.g. in a 2.5D model),
defaults to False.
yx_patch_size : tuple[int, int], optional
Patch size in (Y, X), defaults to (256, 256).
normalizations : list of MapTransform, optional
MONAI dictionary transforms applied to selected channels,
defaults to ``[]`` (no normalization).
augmentations : list of MapTransform, optional
MONAI dictionary transforms applied to the training set,
defaults to ``[]`` (no augmentation).
caching : bool, optional
Whether to decompress all the images and cache the result,
will store in `/tmp/$SLURM_JOB_ID/` if available,
defaults to False.
ground_truth_masks : Path or None, optional
Path to the ground truth masks,
used in the test stage to compute segmentation metrics,
defaults to None
defaults to None.
persistent_workers : bool, optional
Whether to keep the workers alive between fitting epochs,
defaults to False.
prefetch_factor : int or None, optional
Number of samples loaded in advance by each worker during fitting,
defaults to None (2 per PyTorch default).
"""

def __init__(
Expand All @@ -312,7 +331,7 @@ def __init__(
split_ratio: float = 0.8,
batch_size: int = 16,
num_workers: int = 8,
architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D",
target_2d: bool = False,
yx_patch_size: tuple[int, int] = (256, 256),
normalizations: list[MapTransform] = [],
augmentations: list[MapTransform] = [],
Expand All @@ -327,7 +346,7 @@ def __init__(
self.target_channel = _ensure_channel_list(target_channel)
self.batch_size = batch_size
self.num_workers = num_workers
self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True
self.target_2d = target_2d
self.z_window_size = z_window_size
self.split_ratio = split_ratio
self.yx_patch_size = yx_patch_size
Expand Down

0 comments on commit a6e2318

Please sign in to comment.