Explicit target shape argument in the HCS data module #212

Merged · 3 commits · Nov 27, 2024
7 changes: 4 additions & 3 deletions tests/data/test_data.py
@@ -58,7 +58,7 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio
         batch_size=batch_size,
         num_workers=0,
         augmentations=transforms,
-        architecture="3D",
+        target_2d=False,
         split_ratio=split_ratio,
         yx_patch_size=yx_patch_size,
     )
@@ -78,9 +78,9 @@ def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentatio
     )
 
 
-def test_datamodule_setup_predict(preprocessed_hcs_dataset):
+@mark.parametrize("z_window_size", [1, 5])
+def test_datamodule_setup_predict(preprocessed_hcs_dataset, z_window_size):
     data_path = preprocessed_hcs_dataset
-    z_window_size = 5
     channel_split = 2
     with open_ome_zarr(data_path) as dataset:
         channel_names = dataset.channel_names
@@ -91,6 +91,7 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset):
         source_channel=channel_names[:channel_split],
         target_channel=channel_names[channel_split:],
         z_window_size=z_window_size,
+        target_2d=bool(z_window_size == 1),
         batch_size=2,
         num_workers=0,
     )
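For callers migrating from the removed `architecture` argument: the module previously derived the target shape from the architecture name (see the `viscy/data/hcs.py` hunk below). A minimal sketch of the equivalent mapping for updating existing call sites; the helper name is hypothetical:

# Hypothetical migration helper mirroring the removed heuristic:
# "UNeXt2", "3D", and "fcmae" trained with 3D targets, while
# "2D" and "2.5D" trained with a single 2D target slice.
def target_2d_from_architecture(architecture: str) -> bool:
    return architecture not in ("UNeXt2", "3D", "fcmae")

# Before: HCSDataModule(..., architecture="2.5D")
# After:  HCSDataModule(..., target_2d=target_2d_from_architecture("2.5D"))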
75 changes: 47 additions & 28 deletions viscy/data/hcs.py
@@ -274,33 +274,52 @@ def __getitem__(self, index: int) -> Sample:
 
 
 class HCSDataModule(LightningDataModule):
-    """Lightning data module for a preprocessed HCS NGFF Store.
-
-    :param str data_path: path to the data store
-    :param str | Sequence[str] source_channel: name(s) of the source channel,
-        e.g. ``'Phase'``
-    :param str | Sequence[str] target_channel: name(s) of the target channel,
-        e.g. ``['Nuclei', 'Membrane']``
-    :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
-    :param float split_ratio: split ratio of the training subset in the fit stage,
-        e.g. 0.8 means a 80/20 split between training/validation,
-        by default 0.8
-    :param int batch_size: batch size, defaults to 16
-    :param int num_workers: number of data-loading workers, defaults to 8
-    :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture,
-        defaults to "2.5D"
-    :param tuple[int, int] yx_patch_size: patch size in (Y, X),
-        defaults to (256, 256)
-    :param list[MapTransform] normalizations: MONAI dictionary transforms
-        applied to selected channels, defaults to [] (no normalization)
-    :param list[MapTransform] augmentations: MONAI dictionary transforms
-        applied to the training set, defaults to [] (no augmentation)
-    :param bool caching: whether to decompress all the images and cache the result,
-        will store in ``/tmp/$SLURM_JOB_ID/`` if available,
-        defaults to False
-    :param Path | None ground_truth_masks: path to the ground truth masks,
+    """
+    Lightning data module for a preprocessed HCS NGFF Store.
+
+    Parameters
+    ----------
+    data_path : str
+        Path to the data store.
+    source_channel : str or Sequence[str]
+        Name(s) of the source channel, e.g. 'Phase'.
+    target_channel : str or Sequence[str]
+        Name(s) of the target channel, e.g. ['Nuclei', 'Membrane'].
+    z_window_size : int
+        Z window size of the 2.5D U-Net, 1 for 2D.
+    split_ratio : float, optional
+        Split ratio of the training subset in the fit stage,
+        e.g. 0.8 means an 80/20 split between training/validation,
+        by default 0.8.
+    batch_size : int, optional
+        Batch size, defaults to 16.
+    num_workers : int, optional
+        Number of data-loading workers, defaults to 8.
+    target_2d : bool, optional
+        Whether the target is 2D (e.g. in a 2.5D model),
+        defaults to False.
+    yx_patch_size : tuple[int, int], optional
+        Patch size in (Y, X), defaults to (256, 256).
+    normalizations : list of MapTransform, optional
+        MONAI dictionary transforms applied to selected channels,
+        defaults to ``[]`` (no normalization).
+    augmentations : list of MapTransform, optional
+        MONAI dictionary transforms applied to the training set,
+        defaults to ``[]`` (no augmentation).
+    caching : bool, optional
+        Whether to decompress all the images and cache the result,
+        will store in `/tmp/$SLURM_JOB_ID/` if available,
+        defaults to False.
+    ground_truth_masks : Path or None, optional
+        Path to the ground truth masks,
         used in the test stage to compute segmentation metrics,
-        defaults to None
+        defaults to None.
+    persistent_workers : bool, optional
+        Whether to keep the workers alive between fitting epochs,
+        defaults to False.
+    prefetch_factor : int or None, optional
+        Number of samples loaded in advance by each worker during fitting,
+        defaults to None (2 per PyTorch default).
     """
 
     def __init__(
@@ -312,7 +331,7 @@ def __init__(
         split_ratio: float = 0.8,
         batch_size: int = 16,
         num_workers: int = 8,
-        architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D",
+        target_2d: bool = False,
         yx_patch_size: tuple[int, int] = (256, 256),
         normalizations: list[MapTransform] = [],
         augmentations: list[MapTransform] = [],
@@ -327,7 +346,7 @@ def __init__(
         self.target_channel = _ensure_channel_list(target_channel)
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True
+        self.target_2d = target_2d
         self.z_window_size = z_window_size
         self.split_ratio = split_ratio
         self.yx_patch_size = yx_patch_size
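With the target shape made explicit, a 2.5D configuration is now spelled out at the call site. A minimal sketch, assuming the docstring's example channel names and a placeholder store path:

from viscy.data.hcs import HCSDataModule

# 2.5D setup: a 5-slice Z window as input, a single 2D slice as target.
# data_path and channel names are placeholders taken from the docstring.
dm = HCSDataModule(
    data_path="plate.zarr",
    source_channel="Phase",
    target_channel=["Nuclei", "Membrane"],
    z_window_size=5,
    target_2d=True,
    batch_size=16,
    num_workers=8,
)

# A fully 2D model uses z_window_size=1 with target_2d=True, as in the
# parametrized predict test above; 3D-to-3D models (e.g. UNeXt2) keep
# target_2d=False.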