Skip to content

Commit

Permalink
Fix demos on other platforms (#95)
Browse files Browse the repository at this point in the history
* load statistics explicitly to avoid autocasting to float64
this allows the GPU transfer on MPS
see Lightning-AI/pytorch-lightning#16213

* disable multiprocssing by default to avoid fork/spawn difference

* remove unused patch size argument

* black
  • Loading branch information
ziw-liu authored Jun 25, 2024
1 parent a5b51c3 commit 2747e83
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 24 deletions.
28 changes: 18 additions & 10 deletions examples/demos/demo_vscyto2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,34 @@
"""

# %%
# Set download paths
root_dir = Path("")
# Download from
# https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr
input_data_path = "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr"
input_data_path = (
root_dir / "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr"
)
# Download from GitHub release page of v0.1.0
model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt"
model_ckpt_path = (
root_dir / "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt"
)
# Zarr store to save the predictions
output_path = "./a549_prediction.zarr"
output_path = root_dir / "./a549_prediction.zarr"
# FOV of interest
fov = "0/0/0"

input_data_path = Path(input_data_path) / fov
input_data_path = input_data_path / fov

# %%
# Create the VSCyto2D network

# NOTE: Change the following parameters as needed.
BATCH_SIZE = 10
YX_PATCH_SIZE = (384, 384)
NUM_WORKERS = 8
# Reduce the batch size if encountering out-of-memory errors
BATCH_SIZE = 8
# NOTE: Set the number of workers to 0 for Windows and macOS
# since multiprocessing only works with a
# `if __name__ == '__main__':` guard.
# On Linux, set it to the number of CPU cores to maximize performance.
NUM_WORKERS = 0
phase_channel_name = "Phase3D"

# %%[markdown]
Expand All @@ -66,8 +75,7 @@
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
architecture="2D",
yx_patch_size=YX_PATCH_SIZE,
architecture="fcmae",
normalizations=[
NormalizeSampled(
[phase_channel_name],
Expand Down
10 changes: 6 additions & 4 deletions examples/demos/demo_vscyto3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@
# %%
# Create the VSCyto3D model

# NOTE: Change the following parameters as needed.
# Reduce the batch size if encountering out-of-memory errors
BATCH_SIZE = 2
YX_PATCH_SIZE = (384, 384)
NUM_WORKERS = 8
# NOTE: Set the number of workers to 0 for Windows and macOS
# since multiprocessing only works with a
# `if __name__ == '__main__':` guard.
# On Linux, set it to the number of CPU cores to maximize performance.
NUM_WORKERS = 0
phase_channel_name = "Phase3D"

# %%[markdown]
Expand All @@ -68,7 +71,6 @@
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
architecture="UNeXt2",
yx_patch_size=YX_PATCH_SIZE,
normalizations=[
NormalizeSampled(
[phase_channel_name],
Expand Down
10 changes: 6 additions & 4 deletions examples/demos/demo_vsneuromast.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@
# %%
# Create the VSNeuromast model

# NOTE: Change the following parameters as needed.
# Reduce the batch size if encountering out-of-memory errors
BATCH_SIZE = 2
YX_PATCH_SIZE = (384, 384)
NUM_WORKERS = 8
# NOTE: Set the number of workers to 0 for Windows and macOS
# since multiprocessing only works with a
# `if __name__ == '__main__':` guard.
# On Linux, set it to the number of CPU cores to maximize performance.
NUM_WORKERS = 0
phase_channel_name = "Phase3D"

# %%[markdown]
Expand All @@ -65,7 +68,6 @@
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
architecture="UNeXt2",
yx_patch_size=YX_PATCH_SIZE,
normalizations=[
NormalizeSampled(
[phase_channel_name],
Expand Down
23 changes: 21 additions & 2 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample:
return collated


def _read_norm_meta(fov: Position) -> NormMeta | None:
"""
Read normalization metadata from the FOV.
Convert to float32 tensors to avoid automatic casting to float64.
"""
norm_meta = fov.zattrs.get("normalization", None)
if norm_meta is None:
return None
for channel, channel_values in norm_meta.items():
for level, level_values in channel_values.items():
for stat, value in level_values.items():
norm_meta[channel][level][stat] = torch.tensor(
value, dtype=torch.float32
)
return norm_meta


class SlidingWindowDataset(Dataset):
"""Torch dataset where each element is a window of
(C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``.
Expand Down Expand Up @@ -124,7 +141,7 @@ def _get_windows(self) -> None:
w += ts * zs
self.window_keys.append(w)
self.window_arrays.append(img_arr)
self.window_norm_meta.append(fov.zattrs.get("normalization", None))
self.window_norm_meta.append(_read_norm_meta(fov))
self._max_window = w

def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]:
Expand Down Expand Up @@ -162,7 +179,9 @@ def __len__(self) -> int:
return self._max_window

def _stack_channels(
self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str
self,
sample_images: list[dict[str, Tensor]] | dict[str, Tensor],
key: str,
) -> Tensor | list[Tensor]:
"""Stack single-channel images into a multi-channel tensor."""
if not isinstance(sample_images, list):
Expand Down
8 changes: 4 additions & 4 deletions viscy/data/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@


class LevelNormStats(TypedDict):
mean: float
std: float
median: float
iqr: float
mean: Tensor
std: Tensor
median: Tensor
iqr: Tensor


class ChannelNormStats(TypedDict):
Expand Down

0 comments on commit 2747e83

Please sign in to comment.