diff --git a/examples/demos/demo_vscyto2d.py b/examples/demos/demo_vscyto2d.py
index e154cb05..eb480e4e 100644
--- a/examples/demos/demo_vscyto2d.py
+++ b/examples/demos/demo_vscyto2d.py
@@ -27,25 +27,34 @@
 """
 
 # %%
+# Set download paths
+root_dir = Path("")
 # Download from
 # https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr
-input_data_path = "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr"
+input_data_path = (
+    root_dir / "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr"
+)
 # Download from GitHub release page of v0.1.0
-model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt"
+model_ckpt_path = (
+    root_dir / "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt"
+)
 # Zarr store to save the predictions
-output_path = "./a549_prediction.zarr"
+output_path = root_dir / "./a549_prediction.zarr"
 # FOV of interest
 fov = "0/0/0"
 
-input_data_path = Path(input_data_path) / fov
+input_data_path = input_data_path / fov
 
 
 # %%
 # Create the VSCyto2D network
-# NOTE: Change the following parameters as needed.
-BATCH_SIZE = 10
-YX_PATCH_SIZE = (384, 384)
-NUM_WORKERS = 8
+# Reduce the batch size if encountering out-of-memory errors
+BATCH_SIZE = 8
+# NOTE: Set the number of workers to 0 for Windows and macOS
+# since multiprocessing only works with a
+# `if __name__ == '__main__':` guard.
+# On Linux, set it to the number of CPU cores to maximize performance.
+NUM_WORKERS = 0
 phase_channel_name = "Phase3D"
 
 # %%[markdown]
@@ -66,8 +75,7 @@
     split_ratio=0.8,
     batch_size=BATCH_SIZE,
     num_workers=NUM_WORKERS,
-    architecture="2D",
-    yx_patch_size=YX_PATCH_SIZE,
+    architecture="fcmae",
     normalizations=[
         NormalizeSampled(
             [phase_channel_name],
diff --git a/examples/demos/demo_vscyto3d.py b/examples/demos/demo_vscyto3d.py
index c25752dd..0aa57e2d 100644
--- a/examples/demos/demo_vscyto3d.py
+++ b/examples/demos/demo_vscyto3d.py
@@ -43,10 +43,13 @@
 
 # %%
 # Create the VSCyto3D model
-# NOTE: Change the following parameters as needed.
+# Reduce the batch size if encountering out-of-memory errors
 BATCH_SIZE = 2
-YX_PATCH_SIZE = (384, 384)
-NUM_WORKERS = 8
+# NOTE: Set the number of workers to 0 for Windows and macOS
+# since multiprocessing only works with a
+# `if __name__ == '__main__':` guard.
+# On Linux, set it to the number of CPU cores to maximize performance.
+NUM_WORKERS = 0
 phase_channel_name = "Phase3D"
 
 # %%[markdown]
@@ -68,7 +71,6 @@
     batch_size=BATCH_SIZE,
     num_workers=NUM_WORKERS,
     architecture="UNeXt2",
-    yx_patch_size=YX_PATCH_SIZE,
     normalizations=[
         NormalizeSampled(
             [phase_channel_name],
diff --git a/examples/demos/demo_vsneuromast.py b/examples/demos/demo_vsneuromast.py
index 1eb471d5..b5ec7c9d 100644
--- a/examples/demos/demo_vsneuromast.py
+++ b/examples/demos/demo_vsneuromast.py
@@ -42,10 +42,13 @@
 
 # %%
 # Create the VSNeuromast model
-# NOTE: Change the following parameters as needed.
+# Reduce the batch size if encountering out-of-memory errors
 BATCH_SIZE = 2
-YX_PATCH_SIZE = (384, 384)
-NUM_WORKERS = 8
+# NOTE: Set the number of workers to 0 for Windows and macOS
+# since multiprocessing only works with a
+# `if __name__ == '__main__':` guard.
+# On Linux, set it to the number of CPU cores to maximize performance.
+NUM_WORKERS = 0
 phase_channel_name = "Phase3D"
 
 # %%[markdown]
@@ -65,7 +68,6 @@
     batch_size=BATCH_SIZE,
     num_workers=NUM_WORKERS,
     architecture="UNeXt2",
-    yx_patch_size=YX_PATCH_SIZE,
     normalizations=[
         NormalizeSampled(
             [phase_channel_name],
diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py
index ffc827ed..b3e946b0 100644
--- a/viscy/data/hcs.py
+++ b/viscy/data/hcs.py
@@ -76,6 +76,23 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
     return collated
 
 
+def _read_norm_meta(fov: Position) -> NormMeta | None:
+    """
+    Read normalization metadata from the FOV.
+    Convert to float32 tensors to avoid automatic casting to float64.
+    """
+    norm_meta = fov.zattrs.get("normalization", None)
+    if norm_meta is None:
+        return None
+    for channel, channel_values in norm_meta.items():
+        for level, level_values in channel_values.items():
+            for stat, value in level_values.items():
+                norm_meta[channel][level][stat] = torch.tensor(
+                    value, dtype=torch.float32
+                )
+    return norm_meta
+
+
 class SlidingWindowDataset(Dataset):
     """Torch dataset where each element is a window of
     (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. 
@@ -124,7 +141,7 @@ def _get_windows(self) -> None:
             w += ts * zs
             self.window_keys.append(w)
             self.window_arrays.append(img_arr)
-            self.window_norm_meta.append(fov.zattrs.get("normalization", None))
+            self.window_norm_meta.append(_read_norm_meta(fov))
         self._max_window = w
 
     def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]:
@@ -162,7 +179,9 @@ def __len__(self) -> int:
         return self._max_window
 
     def _stack_channels(
-        self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str
+        self,
+        sample_images: list[dict[str, Tensor]] | dict[str, Tensor],
+        key: str,
     ) -> Tensor | list[Tensor]:
         """Stack single-channel images into a multi-channel tensor."""
         if not isinstance(sample_images, list):
diff --git a/viscy/data/typing.py b/viscy/data/typing.py
index 1eabba75..d02463d8 100644
--- a/viscy/data/typing.py
+++ b/viscy/data/typing.py
@@ -10,10 +10,10 @@
 
 
 class LevelNormStats(TypedDict):
-    mean: float
-    std: float
-    median: float
-    iqr: float
+    mean: Tensor
+    std: Tensor
+    median: Tensor
+    iqr: Tensor
 
 
 class ChannelNormStats(TypedDict):