From 0536d299d0cb0d724d1dfc33c5e5a82871997972 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Mon, 8 Apr 2024 09:22:07 -0700 Subject: [PATCH] Masked autoencoder pre-training for virtual staining models (#67) * refactor data loading into its own module * update type annotations * move the logging module out * move old logging into utils * rename tests to match module name * bump torch * draft fcmae encoder * add stem to the encoder * wip: masked stem layernorm * wip: patchify masked features for linear * use mlp from timm * hack: POC training script for FCMAE * fix mask for fitting * remove training script * default architecture * fine-tuning options * fix cli for finetuning * draft combined data module * fix import * manual validation loss reduction * update linting new black version has different rules * update development guide * update type hints * bump iohub * draft ctmc v1 dataset * update tests * move test_data * remove path conversion * configurable normalizations (#68) * inital commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu * fix ctmc dataloading * add example ctmc v1 loading script * changing the normalization and augmentations default from None to empty list. * invert intensity transform * concatenated data module * subsample videos * livecell dataset * all sample fields are optional * fix multi-dataloader validation * lint * fixing preprocessing for varying array shapes (i.e aics dataset) * update loading scripts * fix CombineMode * compose normalizations for predict and test stages * black * fix normalization in example config * fix collate when multi-sample transform is not used * ddp caching fixes * fix caching when using combined loader * move log values to GPU before syncing https://github.com/Lightning-AI/pytorch-lightning/issues/18803 * removing normalize_source from configs. 
* typing fixes * fix test data path * fix test dataset * add docstring for ConcatDataModule * format --------- Co-authored-by: Eduardo Hirata-Miyasaki --- CONTRIBUTING.md | 14 +- examples/configs/fit_example.yml | 16 +- examples/configs/predict_example.yml | 1 - examples/configs/test_example.yml | 1 - examples/demo_dlmbl/debug_log_graph.py | 2 +- examples/demo_dlmbl/solution.py | 2 +- pyproject.toml | 21 +- tests/conftest.py | 2 + tests/data/__init__.py | 0 tests/data/test_data.py | 105 +++++ tests/light/test_data.py | 70 --- tests/light/test_engine.py | 7 + tests/unet/__init__.py | 0 .../networks/Unet25D_tests.py | 0 .../networks/Unet2D_tests.py | 0 .../networks/layers/ConvBlock2D_tests.py | 0 .../networks/layers/ConvBlock3D_tests.py | 0 tests/unet/test_fcmae.py | 115 +++++ viscy/cli/cli.py | 2 +- viscy/data/__init__.py | 0 viscy/data/combined.py | 134 ++++++ viscy/data/ctmc_v1.py | 96 ++++ viscy/{light/data.py => data/hcs.py} | 238 ++++------ viscy/data/livecell.py | 98 ++++ viscy/data/typing.py | 63 +++ viscy/evaluation/evaluation_metrics.py | 1 + viscy/light/engine.py | 120 ++++- viscy/light/predict_writer.py | 2 +- viscy/preprocessing/generate_masks.py | 1 + viscy/preprocessing/preprocessing.md | 16 +- viscy/scripts/load_ctmc_v1.py | 84 ++++ viscy/scripts/load_livecell.py | 85 ++++ viscy/scripts/profiling.py | 2 +- viscy/transforms.py | 64 +++ viscy/unet/networks/Unet21D.py | 30 +- viscy/unet/networks/fcmae.py | 422 ++++++++++++++++++ viscy/utils/image_utils.py | 4 +- viscy/{unet => }/utils/logging.py | 0 viscy/utils/meta_utils.py | 5 +- viscy/utils/normalize.py | 1 + 40 files changed, 1544 insertions(+), 280 deletions(-) create mode 100644 tests/data/__init__.py create mode 100644 tests/data/test_data.py delete mode 100644 tests/light/test_data.py create mode 100644 tests/light/test_engine.py create mode 100644 tests/unet/__init__.py rename tests/{torch_unet => unet}/networks/Unet25D_tests.py (100%) rename tests/{torch_unet => unet}/networks/Unet2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py (100%) create mode 100644 tests/unet/test_fcmae.py create mode 100644 viscy/data/__init__.py create mode 100644 viscy/data/combined.py create mode 100644 viscy/data/ctmc_v1.py rename viscy/{light/data.py => data/hcs.py} (75%) create mode 100644 viscy/data/livecell.py create mode 100644 viscy/data/typing.py create mode 100644 viscy/scripts/load_ctmc_v1.py create mode 100644 viscy/scripts/load_livecell.py create mode 100644 viscy/unet/networks/fcmae.py rename viscy/{unet => }/utils/logging.py (100%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b40b075..44db5bbc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies: pip install -e ".[dev,visual,metrics]" ``` -## Testing +## CI requirements + +Lint with Ruff: + +```sh +ruff check viscy +``` + +Format the code with Black: + +```sh +black viscy +``` Run tests with `pytest`: diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index 017c57f0..85ff4c67 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -37,6 +37,19 @@ data: batch_size: 32 num_workers: 16 yx_patch_size: [256, 256] + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [source] + level: "fov_statistics" + subtrahend: "mean" + divisor: "std" + - class_path: 
viscy.transforms.NormalizeSampled + init_args: + keys: [target_1] + level: "fov_statistics" + subtrahend: "median" + divisor: "iqr" augmentations: - class_path: viscy.transforms.RandWeightedCropd init_args: @@ -74,5 +87,4 @@ data: sigma_z: [0.25, 1.5] sigma_y: [0.25, 1.5] sigma_x: [0.25, 1.5] - caching: false - normalize_source: true + caching: false \ No newline at end of file diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index 789613e6..b2556139 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -62,7 +62,6 @@ predict: - 256 - 256 caching: false - normalize_source: false predict_scale_source: null return_predictions: false ckpt_path: null diff --git a/examples/configs/test_example.yml b/examples/configs/test_example.yml index 2e750d73..6c7130a2 100644 --- a/examples/configs/test_example.yml +++ b/examples/configs/test_example.yml @@ -61,7 +61,6 @@ data: - 256 - 256 caching: false - normalize_source: false ground_truth_masks: null ckpt_path: null verbose: true diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py index 1819b02f..ec987118 100644 --- a/examples/demo_dlmbl/debug_log_graph.py +++ b/examples/demo_dlmbl/debug_log_graph.py @@ -19,7 +19,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # Trainer class and UNet. from viscy.light.engine import VSUNet diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 933f939d..2c81aa6f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -83,7 +83,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. 
-from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # training augmentations from viscy.transforms import ( diff --git a/pyproject.toml b/pyproject.toml index 8d60ee1d..8f6978de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ requires-python = ">=3.10" license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ - "iohub==0.1.0rc0", - "torch>=2.0.0", + "iohub==0.1.0", + "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.0.1", @@ -30,7 +30,15 @@ metrics = [ "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] -dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"] +dev = [ + "pytest", + "pytest-cov", + "hypothesis", + "ruff", + "black", + "profilehooks", + "onnxruntime", +] [project.scripts] viscy = "viscy.cli.cli:main" @@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main" write_to = "viscy/_version.py" [tool.black] -src = ["viscy"] line-length = 88 [tool.ruff] src = ["viscy", "tests"] -extend-select = ["I001"] - -[tool.ruff.isort] -known-first-party = ["viscy"] +lint.extend-select = ["I001"] +lint.isort.known-first-party = ["viscy"] diff --git a/tests/conftest.py b/tests/conftest.py index 9ad6630c..198e51ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names} with open_ome_zarr(dataset_path, mode="r+") as dataset: dataset.zattrs["normalization"] = norm_meta + for _, fov in dataset.positions(): + fov.zattrs["normalization"] = norm_meta return dataset_path diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/test_data.py b/tests/data/test_data.py new file mode 100644 index 00000000..8eb06352 --- /dev/null +++ b/tests/data/test_data.py @@ -0,0 +1,105 @@ +from pathlib import Path + +from iohub import open_ome_zarr +from monai.transforms import RandSpatialCropSamplesd +from pytest import mark + +from viscy.data.hcs import HCSDataModule +from viscy.light.trainer import VSTrainer + + +@mark.parametrize("default_channels", [True, False]) +def test_preprocess(small_hcs_dataset: Path, default_channels: bool): + data_path = small_hcs_dataset + if default_channels: + channel_names = -1 + else: + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + trainer = VSTrainer(accelerator="cpu") + trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + for channel in channel_names: + assert "dataset_statistics" in dataset.zattrs["normalization"][channel] + for _, fov in dataset.positions(): + norm_metadata = fov.zattrs["normalization"] + for channel in channel_names: + assert channel in norm_metadata + assert "dataset_statistics" in norm_metadata[channel] + assert "fov_statistics" in norm_metadata[channel] + + +@mark.parametrize("multi_sample_augmentation", [True, False]) +def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentation): + data_path = preprocessed_hcs_dataset + z_window_size = 5 + channel_split = 2 + split_ratio = 0.8 + yx_patch_size = [128, 96] + batch_size = 4 + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + if multi_sample_augmentation: + transforms = [ + RandSpatialCropSamplesd( + 
keys=channel_names, + roi_size=[z_window_size, *yx_patch_size], + num_samples=2, + ) + ] + else: + transforms = [] + dm = HCSDataModule( + data_path=data_path, + source_channel=channel_names[:channel_split], + target_channel=channel_names[channel_split:], + z_window_size=z_window_size, + batch_size=batch_size, + num_workers=0, + augmentations=transforms, + architecture="3D", + split_ratio=split_ratio, + yx_patch_size=yx_patch_size, + ) + dm.setup(stage="fit") + for batch in dm.train_dataloader(): + assert batch["source"].shape == ( + batch_size, + channel_split, + z_window_size, + *yx_patch_size, + ) + assert batch["target"].shape == ( + batch_size, + len(channel_names) - channel_split, + z_window_size, + *yx_patch_size, + ) + + +def test_datamodule_setup_predict(preprocessed_hcs_dataset): + data_path = preprocessed_hcs_dataset + z_window_size = 5 + channel_split = 2 + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + img = next(dataset.positions())[1][0] + total_p = len(list(dataset.positions())) + dm = HCSDataModule( + data_path=data_path, + source_channel=channel_names[:channel_split], + target_channel=channel_names[channel_split:], + z_window_size=z_window_size, + batch_size=2, + num_workers=0, + ) + dm.setup(stage="predict") + dataset = dm.predict_dataset + assert len(dataset) == total_p * 2 * (img.slices - z_window_size + 1) + assert dataset[0]["source"].shape == ( + channel_split, + z_window_size, + img.height, + img.width, + ) diff --git a/tests/light/test_data.py b/tests/light/test_data.py deleted file mode 100644 index 263f8f90..00000000 --- a/tests/light/test_data.py +++ /dev/null @@ -1,70 +0,0 @@ -from pathlib import Path - -import torch -from iohub import open_ome_zarr -from pytest import mark - -from viscy.light.data import HCSDataModule -from viscy.light.trainer import VSTrainer - - -@mark.parametrize("default_channels", [True, False]) -def test_preprocess(small_hcs_dataset: Path, default_channels: bool): - data_path = small_hcs_dataset - if default_channels: - channel_names = -1 - else: - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - trainer = VSTrainer(accelerator="cpu") - trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) - - -def test_datamodule_setup_predict(preprocessed_hcs_dataset): - data_path = preprocessed_hcs_dataset - z_window_size = 5 - channel_split = 2 - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - img = next(dataset.positions())[1][0] - total_p = len(list(dataset.positions())) - dm = HCSDataModule( - data_path=data_path, - source_channel=channel_names[:channel_split], - target_channel=channel_names[channel_split:], - z_window_size=z_window_size, - batch_size=2, - num_workers=0, - ) - dm.setup(stage="predict") - dataset = dm.predict_dataset - assert len(dataset) == total_p * 2 * (img.slices - z_window_size + 1) - assert dataset[0]["source"].shape == ( - channel_split, - z_window_size, - img.height, - img.width, - ) - - -def test_datamodule_predict_scales(preprocessed_hcs_dataset): - data_path = preprocessed_hcs_dataset - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - - def get_normalized_stack(predict_scale_source): - factor = 1 if predict_scale_source is None else predict_scale_source - dm = HCSDataModule( - data_path=data_path, - source_channel=channel_names[:2], - target_channel=channel_names[2:], - z_window_size=5, - batch_size=2, - num_workers=0, - 
predict_scale_source=predict_scale_source, - normalize_source=True, - ) - dm.setup(stage="predict") - return dm.predict_dataset[0]["source"] / factor - - assert torch.allclose(get_normalized_stack(None), get_normalized_stack(2)) diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py new file mode 100644 index 00000000..9ce182f5 --- /dev/null +++ b/tests/light/test_engine.py @@ -0,0 +1,7 @@ +from viscy.light.engine import FcmaeUNet + + +def test_fcmae_vsunet() -> None: + model = FcmaeUNet( + model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6 + ) diff --git a/tests/unet/__init__.py b/tests/unet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/torch_unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet25D_tests.py rename to tests/unet/networks/Unet25D_tests.py diff --git a/tests/torch_unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet2D_tests.py rename to tests/unet/networks/Unet2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock2D_tests.py rename to tests/unet/networks/layers/ConvBlock2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock3D_tests.py rename to tests/unet/networks/layers/ConvBlock3D_tests.py diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py new file mode 100644 index 00000000..4ed441b4 --- /dev/null +++ b/tests/unet/test_fcmae.py @@ -0,0 +1,115 @@ +import torch + +from viscy.unet.networks.fcmae import ( + FullyConvolutionalMAE, + MaskedAdaptiveProjection, + MaskedConvNeXtV2Block, + MaskedConvNeXtV2Stage, + MaskedMultiscaleEncoder, + generate_mask, + masked_patchify, + masked_unpatchify, + upsample_mask, +) + + +def test_generate_mask(): + w = 64 + s = 16 + m = 0.75 + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu") + assert mask.shape == (2, 1, w // s, w // s) + assert mask.dtype == torch.bool + ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] + assert torch.allclose(ratio, torch.ones_like(ratio) * m) + + +def test_masked_patchify(): + b, c, h, w = 2, 3, 4, 8 + x = torch.rand(b, c, h, w) + mask_ratio = 0.75 + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device) + mask = upsample_mask(mask, x.shape) + feat = masked_patchify(x, ~mask) + assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) + + +def test_unmasked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + y = masked_unpatchify(masked_patchify(x), out_shape=x.shape) + assert torch.allclose(x, y) + + +def test_masked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device) + mask = upsample_mask(mask, x.shape) + y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) + assert torch.all((y == 0) ^ (x == y)) + assert torch.all((y == 0)[:, 0:1] == mask) + + +def test_masked_convnextv2_block() -> None: + x = torch.rand(2, 3, 4, 5) + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device) + block = MaskedConvNeXtV2Block(3, 3 * 2) + unmasked_out = block(x) + assert len(unmasked_out.unique()) == x.numel() * 2 + 
all_unmasked = torch.ones_like(mask) + empty_masked_out = block(x, all_unmasked) + assert torch.allclose(unmasked_out, empty_masked_out) + block = MaskedConvNeXtV2Block(3, 3) + masked_out = block(x, mask) + assert len(masked_out.unique()) == mask.sum() * x.shape[1] + 1 + + +def test_masked_convnextv2_stage(): + x = torch.rand(2, 3, 16, 16) + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device) + stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) + out = stage(x) + assert out.shape == (2, 3, 8, 8) + masked_out = stage(x, mask) + assert not torch.allclose(masked_out, out) + + +def test_adaptive_projection(): + proj = MaskedAdaptiveProjection( + 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 + ) + assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) + assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu") + masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) + assert masked_out.shape == (1, 12, 4, 4) + proj = MaskedAdaptiveProjection( + 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 + ) + assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) + + +def test_masked_multiscale_encoder(): + xy_size = 64 + dims = [12, 24, 48, 96] + x = torch.rand(2, 3, 5, xy_size, xy_size) + encoder = MaskedMultiscaleEncoder(3, dims=dims) + auto_masked_features, _ = encoder(x, mask_ratio=0.5) + target_shape = list(x.shape) + target_shape.pop(1) + assert len(auto_masked_features) == 4 + for i, (dim, afeat) in enumerate(zip(dims, auto_masked_features)): + assert afeat.shape[0] == x.shape[0] + assert afeat.shape[1] == dim + stride = 2 * 2 ** (i + 1) + assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + + +def test_fcmae(): + x = torch.rand(2, 3, 5, 128, 128) + model = FullyConvolutionalMAE(3, 3) + y, m = model(x) + assert y.shape == x.shape + assert m is None + y, m = model(x, mask_ratio=0.6) + assert y.shape == x.shape + assert m.shape == (2, 1, 128, 128) diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py index 0946bb0f..f9a55f12 100644 --- a/viscy/cli/cli.py +++ b/viscy/cli/cli.py @@ -9,7 +9,7 @@ from lightning.pytorch.cli import LightningCLI from lightning.pytorch.loggers import TensorBoardLogger -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.engine import VSUNet from viscy.light.trainer import VSTrainer diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/data/combined.py b/viscy/data/combined.py new file mode 100644 index 00000000..db3803a5 --- /dev/null +++ b/viscy/data/combined.py @@ -0,0 +1,134 @@ +from enum import Enum +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from lightning.pytorch.utilities.combined_loader import CombinedLoader +from torch.utils.data import ConcatDataset, DataLoader + +from viscy.data.hcs import _collate_samples + + +class CombineMode(Enum): + MIN_SIZE = "min_size" + MAX_SIZE_CYCLE = "max_size_cycle" + MAX_SIZE = "max_size" + SEQUENTIAL = "sequential" + + +class CombinedDataModule(LightningDataModule): + """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. 
+ + :param Sequence[LightningDataModule] data_modules: data modules to combine + :param str train_mode: mode in training stage, defaults to "max_size_cycle" + :param str val_mode: mode in validation stage, defaults to "sequential" + :param str test_mode: mode in testing stage, defaults to "sequential" + :param str predict_mode: mode in prediction stage, defaults to "sequential" + """ + + def __init__( + self, + data_modules: Sequence[LightningDataModule], + train_mode: CombineMode = CombineMode.MAX_SIZE_CYCLE, + val_mode: CombineMode = CombineMode.SEQUENTIAL, + test_mode: CombineMode = CombineMode.SEQUENTIAL, + predict_mode: CombineMode = CombineMode.SEQUENTIAL, + ): + super().__init__() + self.data_modules = data_modules + self.train_mode = CombineMode(train_mode).value + self.val_mode = CombineMode(val_mode).value + self.test_mode = CombineMode(test_mode).value + self.predict_mode = CombineMode(predict_mode).value + self.prepare_data_per_node = True + + def prepare_data(self): + for dm in self.data_modules: + dm.trainer = self.trainer + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + for dm in self.data_modules: + dm.setup(stage) + + def train_dataloader(self): + return CombinedLoader( + [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode + ) + + def val_dataloader(self): + return CombinedLoader( + [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode + ) + + def test_dataloader(self): + return CombinedLoader( + [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode + ) + + def predict_dataloader(self): + return CombinedLoader( + [dm.predict_dataloader() for dm in self.data_modules], + mode=self.predict_mode, + ) + + +class ConcatDataModule(LightningDataModule): + """ + Concatenate multiple data modules. + The concatenated data module will have the same + batch size and number of workers as the first data module. + Each element will be sampled uniformly regardless of their original data module. 
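As a rough usage sketch (the store paths and channel names below are placeholders, not part of this PR), both wrappers are built from already-configured data modules; `ConcatDataModule` additionally requires that all members share the same batch size and number of workers, and currently only implements the fit stage:

```python
from viscy.data.combined import CombinedDataModule, ConcatDataModule
from viscy.data.hcs import HCSDataModule

# Hypothetical plates; any LightningDataModule with matching settings works.
dm_a = HCSDataModule(
    data_path="plate_a.zarr",
    source_channel="Phase3D",
    target_channel=["Deconvolved-Nuc"],
    z_window_size=5,
    batch_size=16,
    num_workers=8,
)
dm_b = HCSDataModule(
    data_path="plate_b.zarr",
    source_channel="Phase3D",
    target_channel=["Deconvolved-Nuc"],
    z_window_size=5,
    batch_size=16,  # must match dm_a, otherwise ConcatDataModule raises ValueError
    num_workers=8,
)

# Sample uniformly across both datasets from one concatenated dataset.
concat = ConcatDataModule([dm_a, dm_b])

# Alternatively, keep separate loaders and let Lightning combine them per stage.
combined = CombinedDataModule([dm_a, dm_b], train_mode="max_size_cycle")
```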
+ + :param Sequence[LightningDataModule] data_modules: data modules to concatenate + """ + + def __init__(self, data_modules: Sequence[LightningDataModule]): + super().__init__() + self.data_modules = data_modules + self.num_workers = data_modules[0].num_workers + self.batch_size = data_modules[0].batch_size + for dm in data_modules: + if dm.num_workers != self.num_workers: + raise ValueError("Inconsistent number of workers") + if dm.batch_size != self.batch_size: + raise ValueError("Inconsistent batch size") + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + self.train_patches_per_stack = 0 + for dm in self.data_modules: + dm.setup(stage) + if patches := getattr(dm, "train_patches_per_stack", 0): + if self.train_patches_per_stack == 0: + self.train_patches_per_stack = patches + elif self.train_patches_per_stack != patches: + raise ValueError("Inconsistent patches per stack") + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self.train_dataset = ConcatDataset( + [dm.train_dataset for dm in self.data_modules] + ) + self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules]) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size // self.train_patches_per_stack, + num_workers=self.num_workers, + shuffle=True, + persistent_workers=bool(self.num_workers), + collate_fn=_collate_samples, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=bool(self.num_workers), + ) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py new file mode 100644 index 00000000..d666fdcb --- /dev/null +++ b/viscy/data/ctmc_v1.py @@ -0,0 +1,96 @@ +from pathlib import Path + +from iohub.ngff import open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, MapTransform +from torch.utils.data import DataLoader + +from viscy.data.hcs import ChannelMap, SlidingWindowDataset +from viscy.data.typing import Sample + + +class CTMCv1ValidationDataset(SlidingWindowDataset): + subsample_rate: int = 30 + + def __len__(self) -> int: + # sample every 30th frame in the videos + return super().__len__() // self.subsample_rate + + def __getitem__(self, index: int) -> Sample: + index = index * self.subsample_rate + return super().__getitem__(index) + + +class CTMCv1DataModule(LightningDataModule): + """ + Autoregression data module for the CTMCv1 dataset. + Training and validation datasets are stored in separate HCS OME-Zarr stores. 
+ + :param str | Path train_data_path: Path to the training dataset + :param str | Path val_data_path: Path to the validation dataset + :param list[MapTransform] train_transforms: List of transforms for training + :param list[MapTransform] val_transforms: List of transforms for validation + :param int batch_size: Batch size, defaults to 16 + :param int num_workers: Number of workers, defaults to 8 + :param str channel_name: Name of the DIC channel, defaults to "DIC" + """ + + def __init__( + self, + train_data_path: str | Path, + val_data_path: str | Path, + train_transforms: list[MapTransform], + val_transforms: list[MapTransform], + batch_size: int = 16, + num_workers: int = 8, + channel_name: str = "DIC", + ) -> None: + super().__init__() + self.train_data_path = train_data_path + self.val_data_path = val_data_path + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self._setup_fit() + + def _setup_fit(self) -> None: + train_plate = open_ome_zarr(self.train_data_path) + val_plate = open_ome_zarr(self.val_data_path) + train_positions = [p for _, p in train_plate.positions()] + val_positions = [p for _, p in val_plate.positions()] + self.train_dataset = SlidingWindowDataset( + train_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.train_transforms), + ) + self.val_dataset = CTMCv1ValidationDataset( + val_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.val_transforms), + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=False, + ) diff --git a/viscy/light/data.py b/viscy/data/hcs.py similarity index 75% rename from viscy/light/data.py rename to viscy/data/hcs.py index 01191db1..f33b6121 100644 --- a/viscy/light/data.py +++ b/viscy/data/hcs.py @@ -5,7 +5,7 @@ import tempfile from glob import glob from pathlib import Path -from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union +from typing import Callable, Literal, Optional, Sequence, Union import numpy as np import torch @@ -18,15 +18,17 @@ from monai.transforms import ( CenterSpatialCropd, Compose, - InvertibleTransform, MapTransform, MultiSampleTrait, RandAffined, ) +from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.typing import ChannelMap, HCSStackIndex, NormMeta, Sample -def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]): + +def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. 
@@ -54,24 +56,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: raise ValueError(f"Cannot find pattern {pattern} in {file_name}.") -class ChannelMap(TypedDict, total=False): - """Source and target channel names.""" - - source: Union[str, Sequence[str]] - # optional - target: Union[str, Sequence[str]] - - -class Sample(TypedDict, total=False): - """Image sample type for mini-batches.""" - - index: tuple[str, int, int] - # optional - source: Union[torch.Tensor, Sequence[torch.Tensor]] - target: Union[torch.Tensor, Sequence[torch.Tensor]] - labels: Union[torch.Tensor, Sequence[torch.Tensor]] - - def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -80,46 +64,18 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: as is the case with ``train_patches_per_stack > 1``. :return Sample: Batch sample (dictionary of tensors) """ - elemment = batch[0] - collated = {} - for key in elemment.keys(): - data: list[list[torch.Tensor]] = [sample[key] for sample in batch] - collated[key] = collate_meta_tensor([im for imgs in data for im in imgs]) + collated: Sample = {} + for key in batch[0].keys(): + data = [] + for sample in batch: + if isinstance(sample[key], Sequence): + data.extend(sample[key]) + else: + data.append(sample[key]) + collated[key] = collate_meta_tensor(data) return collated -class NormalizeSampled(MapTransform, InvertibleTransform): - """Dictionary transform to only normalize target (fluorescence) channel. - - :param Union[str, Iterable[str]] keys: keys to normalize - :param dict[str, dict] norm_meta: Plate normalization metadata - written in preprocessing - """ - - def __init__( - self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict] - ) -> None: - if set(keys) > set(norm_meta.keys()): - raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}") - super().__init__(keys, allow_missing_keys=False) - self.norm_meta = norm_meta - - def _stat(self, key: str) -> dict: - # FIXME: hard-coded key - return self.norm_meta[key]["dataset_statistics"] - - def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] - return d - - def inverse(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] - - class SlidingWindowDataset(Dataset): """Torch dataset where each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. @@ -128,7 +84,7 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. 
``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] | None transform: a callable that transforms data, defaults to None """ @@ -137,7 +93,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ) -> None: super().__init__() self.positions = positions @@ -160,33 +116,36 @@ def _get_windows(self) -> None: w = 0 self.window_keys = [] self.window_arrays = [] + self.window_norm_meta: list[NormMeta | None] = [] for fov in self.positions: - img_arr = fov["0"] + img_arr: ImageArray = fov["0"] ts = img_arr.frames zs = img_arr.slices - self.z_window_size + 1 w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) + self.window_norm_meta.append(fov.zattrs.get("normalization", None)) self._max_window = w - def _find_window(self, index: int) -> tuple[int, int]: + def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]: """Look up window given index.""" window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) w = self.window_keys[window_idx] tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index - return self.window_arrays[self.window_keys.index(w)], tz + norm_meta = self.window_norm_meta[self.window_keys.index(w)] + return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta) def _read_img_window( - self, img: ImageArray, ch_idx: list[str], tz: int - ) -> tuple[tuple[torch.Tensor], tuple[str, int, int]]: + self, img: ImageArray, ch_idx: list[int], tz: int + ) -> tuple[list[Tensor], HCSStackIndex]: """Read image window as tensor. :param ImageArray img: NGFF image array - :param list[int] channels: list of channel indices to read, + :param list[int] ch_idx: list of channel indices to read, output channel ordering will reflect the sequence :param int tz: window index within the FOV, counted Z-first - :return tuple[torch.Tensor], tuple[str, int, int]: - tuple of (C=1, Z, Y, X) image tensors, + :return list[Tensor], HCSStackIndex: + list of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index """ zs = img.shape[-3] - self.z_window_size + 1 @@ -203,8 +162,8 @@ def __len__(self) -> int: return self._max_window def _stack_channels( - self, sample_images: list[dict[str, torch.Tensor]], key: str - ) -> torch.Tensor: + self, sample_images: list[dict[str, Tensor]] | dict[str, Tensor], key: str + ) -> Tensor | list[Tensor]: """Stack single-channel images into a multi-channel tensor.""" if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) @@ -215,7 +174,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: - img, tz = self._find_window(index) + img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -228,6 +187,8 @@ def __getitem__(self, index: int) -> Sample: # since adding a reference to a tensor does not copy # maybe write a weight map in preprocessing to use more information? 
sample_images["weight"] = sample_images[self.channels["target"][0]] + if norm_meta is not None: + sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) # if isinstance(sample_images, list): @@ -237,15 +198,12 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), + "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") return sample - def __del__(self): - """Close the Zarr store when the dataset instance gets GC'ed.""" - self.positions[0].zgroup.store.close() - class MaskTestDataset(SlidingWindowDataset): """Torch dataset where each element is a window of @@ -258,7 +216,7 @@ class MaskTestDataset(SlidingWindowDataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -267,7 +225,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ground_truth_masks: str = None, ) -> None: super().__init__(positions, channels, z_window_size, transform) @@ -311,18 +269,16 @@ class HCSDataModule(LightningDataModule): defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) - :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms - applied to the training set, defaults to None (no augmentation) + :param list[MapTransform] normalizations: MONAI dictionary transforms + applied to selected channels, defaults to [] (no normalization) + :param list[MapTransform] augmentations: MONAI dictionary transforms + applied to the training set, defaults to [] (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False - :param bool normalize_source: whether to normalize the source channel, - defaults to False :param Optional[Path] ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None - :param Optional[float] predict_scale_source: scale the source channel intensity, - defaults to None (no scaling) """ def __init__( @@ -334,13 +290,12 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] = "2.5D", + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), - augmentations: Optional[list[MapTransform]] = None, + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], caching: bool = False, - normalize_source: bool = False, ground_truth_masks: Optional[Path] = None, - predict_scale_source: Optional[float] = None, ): super().__init__() self.data_path = Path(data_path) @@ -348,25 +303,32 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in 
["2.2D", "3D"] else True + self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size + self.normalizations = normalizations self.augmentations = augmentations self.caching = caching - self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks - self.tmp_zarr = None - if predict_scale_source is not None: - if not normalize_source: - raise ValueError( - "Intensity scaling must be applied to normalized source channels." - ) - if predict_scale_source <= 0: - raise ValueError( - f"Intensity scaling {predict_scale_source} should be positive." - ) - self.predict_scale_source = predict_scale_source + self.prepare_data_per_node = True + + @property + def cache_path(self): + return Path( + tempfile.gettempdir(), + os.getenv("SLURM_JOB_ID", "viscy_cache"), + self.data_path.name, + ) + + def _data_log_path(self) -> Path: + log_dir = Path.cwd() + if self.trainer: + if self.trainer.logger: + if self.trainer.logger.log_dir: + log_dir = Path(self.trainer.logger.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + return log_dir / "data.log" def prepare_data(self): if not self.caching: @@ -378,20 +340,11 @@ def prepare_data(self): console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) logger.addHandler(console_handler) - os.mkdir(self.trainer.logger.log_dir) - file_handler = logging.FileHandler( - os.path.join(self.trainer.logger.log_dir, "data.log") - ) + file_handler = logging.FileHandler(self._data_log_path()) file_handler.setLevel(logging.DEBUG) logger.addHandler(file_handler) - # cache in temporary directory - self.tmp_zarr = os.path.join( - tempfile.gettempdir(), - os.getenv("SLURM_JOB_ID"), - os.path.basename(self.data_path), - ) - logger.info(f"Caching dataset at {self.tmp_zarr}.") - tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr) + logger.info(f"Caching dataset at {self.cache_path}.") + tmp_store = zarr.NestedDirectoryStore(self.cache_path) with open_ome_zarr(self.data_path, mode="r") as lazy_plate: _, skipped, _ = zarr.copy( lazy_plate.zgroup, @@ -418,31 +371,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): else: raise NotImplementedError(f"{stage} stage") - def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]: - """Setup stages where the target is available (evaluating performance).""" - dataset_settings["channels"]["target"] = self.target_channel - data_path = self.tmp_zarr if self.tmp_zarr else self.data_path - plate = open_ome_zarr(data_path, mode="r") - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # define training stage transforms - norm_keys = self.target_channel.copy() - if self.normalize_source: - norm_keys += self.source_channel - normalize_transform = NormalizeSampled( - norm_keys, - plate.zattrs["normalization"], - ) - return plate, normalize_transform - def _setup_fit(self, dataset_settings: dict): """Set up the training and validation datasets.""" - plate, normalize_transform = self._setup_eval(dataset_settings) + # Setup the transformations + # TODO: These have a fixed order for now... 
(normalization->augmentation->fit_transform) fit_transform = self._fit_transform() train_transform = Compose( - [normalize_transform] + self._train_transform() + fit_transform + self.normalizations + self._train_transform() + fit_transform ) - val_transform = Compose([normalize_transform] + fit_transform) + val_transform = Compose(self.normalizations + fit_transform) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.cache_path if self.caching else self.data_path + plate = open_ome_zarr(data_path, mode="r") + + # disable metadata tracking in MONAI for performance + set_track_meta(False) # shuffle positions, randomness is handled globally positions = [pos for _, pos in plate.positions()] shuffled_indices = torch.randperm(len(positions)) @@ -464,25 +408,31 @@ def _setup_fit(self, dataset_settings: dict): **train_dataset_settings, ) self.val_dataset = SlidingWindowDataset( - positions[num_train_fovs:], transform=val_transform, **dataset_settings + positions[num_train_fovs:], + transform=val_transform, + **dataset_settings, ) def _setup_test(self, dataset_settings: dict): """Set up the test stage.""" if self.batch_size != 1: logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") - plate, normalize_transform = self._setup_eval(dataset_settings) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.cache_path if self.caching else self.data_path + plate = open_ome_zarr(data_path, mode="r") + test_transform = Compose(self.normalizations) if self.ground_truth_masks: self.test_dataset = MaskTestDataset( [p for _, p in plate.positions()], - transform=normalize_transform, + transform=test_transform, ground_truth_masks=self.ground_truth_masks, **dataset_settings, ) else: self.test_dataset = SlidingWindowDataset( [p for _, p in plate.positions()], - transform=normalize_transform, + transform=test_transform, **dataset_settings, ) @@ -505,16 +455,7 @@ def _setup_predict(self, dataset_settings: dict): positions = [plate[fov_name]] elif isinstance(dataset, Plate): positions = [p for _, p in dataset.positions()] - norm_meta = dataset.zattrs["normalization"].copy() - if self.predict_scale_source is not None: - for ch in self.source_channel: - # FIXME: hard-coded key - norm_meta[ch]["dataset_statistics"]["iqr"] /= self.predict_scale_source - predict_transform = ( - NormalizeSampled(self.source_channel, norm_meta) - if self.normalize_source - else None - ) + predict_transform = Compose(self.normalizations) self.predict_dataset = SlidingWindowDataset( positions=positions, transform=predict_transform, @@ -527,7 +468,7 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample if self.trainer: if self.trainer.predicting: predicting = True - if predicting or isinstance(batch, torch.Tensor): + if predicting or isinstance(batch, Tensor): # skipping example input array return batch if self.target_2d: @@ -543,7 +484,9 @@ def train_dataloader(self): num_workers=self.num_workers, shuffle=True, persistent_workers=bool(self.num_workers), + prefetch_factor=4 if self.num_workers else None, collate_fn=_collate_samples, + drop_last=True, ) def val_dataloader(self): @@ -552,6 +495,7 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + prefetch_factor=4 if self.num_workers else None, persistent_workers=bool(self.num_workers), ) diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py new file mode 100644 index 00000000..5d83f099 --- /dev/null +++ b/viscy/data/livecell.py 
@@ -0,0 +1,98 @@ +import json +from pathlib import Path + +import torch +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, Transform +from tifffile import imread +from torch.utils.data import DataLoader, Dataset + +from viscy.data.typing import Sample + + +class LiveCellDataset(Dataset): + """ + LiveCell dataset. + + :param list[Path] images: List of paths to single-page, single-channel TIFF files. + :param Transform | Compose transform: Transform to apply to the dataset + """ + + def __init__(self, images: list[Path], transform: Transform | Compose) -> None: + self.images = images + self.transform = transform + + def __len__(self) -> int: + return len(self.images) + + def __getitem__(self, idx: int) -> Sample: + image = imread(self.images[idx])[None, None] + image = torch.from_numpy(image).to(torch.float32) + image = self.transform(image) + return {"source": image, "target": image} + + +class LiveCellDataModule(LightningDataModule): + def __init__( + self, + train_val_images: Path, + train_annotations: Path, + val_annotations: Path, + train_transforms: list[Transform], + val_transforms: list[Transform], + batch_size: int = 16, + num_workers: int = 8, + ) -> None: + super().__init__() + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + self.train_transforms = Compose(train_transforms) + self.val_transforms = Compose(val_transforms) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self._setup_fit() + + def _parse_image_names(self, annotations: Path) -> list[Path]: + with open(annotations) as f: + images = [f["file_name"] for f in json.load(f)["images"]] + return sorted(images) + + def _setup_fit(self) -> None: + train_images = self._parse_image_names(self.train_annotations) + val_images = self._parse_image_names(self.val_annotations) + self.train_dataset = LiveCellDataset( + [self.train_val_images / f for f in train_images], + transform=self.train_transforms, + ) + self.val_dataset = LiveCellDataset( + [self.train_val_images / f for f in val_images], + transform=self.val_transforms, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + ) diff --git a/viscy/data/typing.py b/viscy/data/typing.py new file mode 100644 index 00000000..1eabba75 --- /dev/null +++ b/viscy/data/typing.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple, Sequence, TypedDict, TypeVar + +if TYPE_CHECKING: + from torch import Tensor + +T = TypeVar("T") +OneOrSeq = T | Sequence[T] + + +class LevelNormStats(TypedDict): + mean: float + std: float + median: float + iqr: float + + +class ChannelNormStats(TypedDict): + dataset_statistics: LevelNormStats + fov_statistics: LevelNormStats 
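For orientation, a dictionary satisfying these typed definitions (the numbers are illustrative, not computed from real data) has one entry per channel, each carrying both statistics levels:

```python
# Illustrative NormMeta value; in practice this is written to .zattrs by preprocessing.
example_norm_meta = {
    "Phase3D": {
        "dataset_statistics": {"mean": 0.0, "std": 1.0, "median": 0.0, "iqr": 1.3},
        "fov_statistics": {"mean": 0.1, "std": 0.9, "median": 0.05, "iqr": 1.1},
    },
}
```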
+ + +NormMeta = dict[str, ChannelNormStats] + + +class HCSStackIndex(NamedTuple): + """HCS stack index.""" + + # name of the image array, e.g. "A/1/0/0" + image: str + time: int + z: int + + +class Sample(TypedDict, total=False): + """ + Image sample type for mini-batches. + All fields are optional. + """ + + index: HCSStackIndex + # Image data + source: OneOrSeq[Tensor] + target: OneOrSeq[Tensor] + weight: OneOrSeq[Tensor] + # Instance segmentation masks + labels: OneOrSeq[Tensor] + # None: not available + norm_meta: NormMeta + + +class _ChannelMap(TypedDict): + """Source channel names.""" + + source: OneOrSeq[str] + + +class ChannelMap(_ChannelMap, total=False): + """Source and target channel names.""" + + # TODO: use typing.NotRequired when upgrading to Python 3.11 + target: OneOrSeq[str] diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 589370bd..fb83c06b 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -1,4 +1,5 @@ """Metrics for model evaluation""" + from typing import Sequence, Union from warnings import warn diff --git a/viscy/light/engine.py b/viscy/light/engine.py index b263998d..3dc92b74 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -10,7 +10,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad from skimage.exposure import rescale_intensity -from torch import nn +from torch import Tensor, nn from torch.nn import functional as F from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -25,8 +25,9 @@ structural_similarity_index_measure, ) +from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d -from viscy.light.data import Sample +from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d @@ -43,6 +44,7 @@ # same class with out_stack_depth > 1 "2.2D": Unet21d, "2.5D": Unet25d, + "fcmae": FullyConvolutionalMAE, } @@ -117,11 +119,12 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"], + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", + freeze_encoder: bool = False, ckpt_path: str = None, log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, @@ -162,13 +165,13 @@ def __init__( self.test_cellpose_model_path = test_cellpose_model_path self.test_cellpose_diameter = test_cellpose_diameter self.test_evaluate_cellpose = test_evaluate_cellpose - + self.freeze_encoder = freeze_encoder if ckpt_path is not None: self.load_state_dict( torch.load(ckpt_path)["state_dict"] ) # loading only weights - def forward(self, x) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.model(x) def training_step(self, batch: Sample, batch_idx: int): @@ -191,12 +194,12 @@ def training_step(self, batch: Sample, batch_idx: int): ) return loss - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = batch["source"] target = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True) + 
self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -233,7 +236,7 @@ def test_step(self, batch: Sample, batch_idx: int): else: self._log_segmentation_metrics(None, None) - def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): + def _log_regression_metrics(self, pred: Tensor, target: Tensor): # paired image translation metrics self.log_dict( { @@ -256,7 +259,7 @@ def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): on_epoch=True, ) - def _cellpose_predict(self, pred: torch.Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -275,19 +278,19 @@ def _log_segmentation_metrics( self.log_dict( { # semantic segmentation - "test_metrics/accuracy": accuracy( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, - "test_metrics/dice": dice(pred_binary, target_binary) - if compute - else -1, - "test_metrics/jaccard": jaccard_index( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, + "test_metrics/accuracy": ( + accuracy(pred_binary, target_binary, task="binary") + if compute + else -1 + ), + "test_metrics/dice": ( + dice(pred_binary, target_binary) if compute else -1 + ), + "test_metrics/jaccard": ( + jaccard_index(pred_binary, target_binary, task="binary") + if compute + else -1 + ), "test_metrics/mAP": coco_metrics["map"] if compute else -1, "test_metrics/mAP_50": coco_metrics["map_50"] if compute else -1, "test_metrics/mAP_75": coco_metrics["map_75"] if compute else -1, @@ -336,6 +339,9 @@ def on_predict_start(self): self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def configure_optimizers(self): + if self.freeze_encoder: + self.model: FullyConvolutionalMAE + self.model.encoder.requires_grad_(False) optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) if self.schedule == "WarmupCosine": scheduler = WarmupCosineSchedule( @@ -350,7 +356,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _detach_sample(self, imgs: Sequence[torch.Tensor]): + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) return [ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] @@ -374,3 +380,71 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + +class FcmaeUNet(VSUNet): + def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): + super().__init__(architecture="fcmae", **kwargs) + self.fit_mask_ratio = fit_mask_ratio + self.validation_losses = [] + + def forward(self, x: Tensor, mask_ratio: float = 0.0): + return self.model(x, mask_ratio) + + def forward_fit(self, batch: Sample) -> tuple[Tensor]: + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + return source, target, pred, mask, loss + + def training_step(self, batch: Sequence[Sample], batch_idx: int): + losses = [] + batch_size = 0 + for b in batch: + source, target, pred, mask, loss = self.forward_fit(b) + 
losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target * mask.unsqueeze(2), pred)) + ) + loss_step = torch.stack(losses).mean() + self.log( + "loss/train", + loss_step.to(self.device), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + ) + return loss_step + + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source, target, pred, mask, loss = self.forward_fit(batch) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss.to(self.device), + sync_dist=True, + batch_size=source.shape[0], + ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target * mask.unsqueeze(2), pred)) + ) + + def on_validation_epoch_end(self): + super().on_validation_epoch_end() + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log( + "loss/validate", + torch.tensor(loss_means).mean().to(self.device), + sync_dist=True, + ) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index a6ae88cb..7a58009c 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -9,7 +9,7 @@ from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import DTypeLike, NDArray -from viscy.light.data import HCSDataModule, Sample +from viscy.data.hcs import HCSDataModule, Sample __all__ = ["HCSPredictionWriter"] _logger = logging.getLogger("lightning.pytorch") diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index f88f8fbe..491bc406 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,4 +1,5 @@ """Generate masks from sum of flurophore channels""" + import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils diff --git a/viscy/preprocessing/preprocessing.md b/viscy/preprocessing/preprocessing.md index 76d508c5..809b456f 100644 --- a/viscy/preprocessing/preprocessing.md +++ b/viscy/preprocessing/preprocessing.md @@ -87,11 +87,17 @@ The statistics are added as dictionaries into the .zattrs file. 
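A minimal sketch of reading these attributes back with iohub (the store path and channel name are placeholders; the key layout follows the metadata written by preprocessing):

```python
from iohub import open_ome_zarr

with open_ome_zarr("preprocessed_plate.zarr", mode="r") as plate:
    # dataset-level statistics are stored on the plate-level .zattrs
    dataset_stats = plate.zattrs["normalization"]["Phase3D"]["dataset_statistics"]
    # every position additionally carries its own FOV-level statistics
    for name, fov in plate.positions():
        fov_stats = fov.zattrs["normalization"]["Phase3D"]["fov_statistics"]
        print(name, dataset_stats["median"], fov_stats["iqr"])
```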
An example of pl } ``` -FOV level statistics added to every position: +FOV level statistics added to every position as well as the dataset_statistics to read dataset statistics: ```json "normalization": { "Deconvolved-Nuc": { + "dataset_statistics": { + "iqr": 149.7620086669922, + "mean": 262.2070617675781, + "median": 65.5246353149414, + "std": 890.0471801757812 + }, "fov_statistics": { "iqr": 450.4745788574219, "mean": 486.3854064941406, @@ -99,7 +105,13 @@ FOV level statistics added to every position: "std": 976.02392578125 } }, - "Phase3D": { + "Phase3D": { + "dataset_statistics": { + "iqr": 0.0011349652777425945, + "mean": -1.9603044165705796e-06, + "median": 3.388232289580628e-05, + "std": 0.005480962339788675 + }, "fov_statistics": { "iqr": 0.006403466919437051, "mean": 0.0010083537781611085, diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py new file mode 100644 index 00000000..d4326b81 --- /dev/null +++ b/viscy/scripts/load_ctmc_v1.py @@ -0,0 +1,84 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCropd, + NormalizeIntensityd, + RandAdjustContrastd, + RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, +) +from tqdm import tqdm + +from viscy.data.ctmc_v1 import CTMCv1DataModule + +# %% +channel = "DIC" +data_path = Path("/hpc/reference/imaging/ctmc") + +normalize_transform = NormalizeIntensityd(keys=[channel], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=[channel], roi_size=[1, 224, 224]) + +data = CTMCv1DataModule( + train_data_path=data_path / "CTMCV1_test.zarr", + val_data_path=data_path / "CTMCV1_train.zarr", + train_transforms=[ + normalize_transform, + RandAffined( + keys=[channel], + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, [-0.6, 0.1], [-0.6, 0.1]], + prob=0.8, + padding_mode="zeros", + ), + RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1, 2)), + RandAdjustContrastd(keys=[channel], prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=[channel], factors=0.3, prob=0.5), + RandGaussianNoised(keys=[channel], prob=0.5, mean=0.0, std=0.2), + RandGaussianSmoothd( + keys=[channel], + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, + ), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=32, + num_workers=0, + channel_name=channel, +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["source"] + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(5, 5, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(5, 5, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% diff --git a/viscy/scripts/load_livecell.py b/viscy/scripts/load_livecell.py new file mode 100644 index 00000000..cfaf2dfe --- /dev/null +++ b/viscy/scripts/load_livecell.py @@ -0,0 +1,85 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCrop, + NormalizeIntensity, + RandAdjustContrast, + RandAffine, + RandFlip, + RandGaussianNoise, + RandGaussianSmooth, + RandScaleIntensity, + RandSpatialCrop, +) +from tqdm import 
tqdm + +from viscy.data.livecell import LiveCellDataModule + +# %% +data_path = Path("/hpc/reference/imaging/livecell") + +normalize_transform = NormalizeIntensity(channel_wise=True) +crop_transform = CenterSpatialCrop(roi_size=[1, 224, 224]) + +data = LiveCellDataModule( + train_val_images=data_path / "images" / "livecell_train_val_images", + train_annotations=data_path + / "annotations" + / "livecell_coco_train_images_only.json", + val_annotations=data_path / "annotations" / "livecell_coco_val_images_only.json", + train_transforms=[ + normalize_transform, + RandSpatialCrop(roi_size=[1, 384, 384]), + RandAffine( + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, [-0.2, 0.8], [-0.2, 0.8]], + prob=0.8, + padding_mode="zeros", + ), + RandFlip(prob=0.5, spatial_axis=(1, 2)), + RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensity(factors=0.3, prob=0.5), + RandGaussianNoise(prob=0.5, mean=0.0, std=0.3), + RandGaussianSmooth( + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, + ), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=16, + num_workers=0, +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["target"] + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(4, 4, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py index 0c947f45..a0c3ca6d 100644 --- a/viscy/scripts/profiling.py +++ b/viscy/scripts/profiling.py @@ -2,7 +2,7 @@ from profilehooks import profile -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule dataset = "/path/to/dataset.zarr" diff --git a/viscy/transforms.py b/viscy/transforms.py index cb3d2622..88e7f738 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -3,13 +3,20 @@ from typing import Sequence, Union from monai.transforms import ( + MapTransform, RandAdjustContrastd, RandAffined, RandGaussianNoised, RandGaussianSmoothd, + RandomizableTransform, RandScaleIntensityd, RandWeightedCropd, ) +from monai.transforms.transform import Randomizable +from numpy.random.mtrand import RandomState as RandomState +from typing_extensions import Iterable, Literal + +from viscy.data.typing import Sample class RandWeightedCropd(RandWeightedCropd): @@ -118,3 +125,60 @@ def __init__( sigma_z=sigma_z, **kwargs, ) + + +class NormalizeSampled(MapTransform): + """ + Normalize the sample + :param Union[str, Iterable[str]] keys: keys to normalize + :param str fov: fov path with respect to Plate + :param str subtrahend: subtrahend for normalization, defaults to "mean" + :param str divisor: divisor for normalization, defaults to "std" + """ + + def __init__( + self, + keys: Union[str, Iterable[str]], + level: Literal["fov_statistics", "dataset_statistics"], + subtrahend="mean", + divisor="std", + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self.subtrahend = subtrahend + self.divisor = divisor + self.level = level + + # TODO: need to implement the case where the preprocessing already exists + def __call__(self, 
sample: Sample) -> Sample: + for key in self.keys: + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val + return sample + + def _normalize(): + NotImplementedError("_normalization() not implemented") + + +class RandInvertIntensityd(MapTransform, RandomizableTransform): + """ + Randomly invert the intensity of the image. + """ + + def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) + + def __call__(self, sample: Sample) -> Sample: + self.randomize(None) + for key in self.keys: + if key in sample: + sample[key] = -sample[key] + return sample + + def set_random_state( + self, seed: int | None = None, state: RandomState | None = None + ) -> Randomizable: + super().set_random_state(seed, state) + return self diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 7c32e34b..c4320240 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -1,18 +1,18 @@ -from typing import Callable, Literal, Optional, Sequence, Union +from typing import Callable, Literal, Sequence import timm import torch from monai.networks.blocks import Convolution, ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer from monai.networks.utils import normal_init -from torch import nn +from torch import Tensor, nn def icnr_init( conv: nn.Module, upsample_factor: int, upsample_dims: int, - init=nn.init.kaiming_normal_, + init: Callable = nn.init.kaiming_normal_, ): """ ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , @@ -45,7 +45,7 @@ def _get_convnext_stage( in_channels: int, out_channels: int, depth: int, - upsample_factor: Optional[int] = None, + upsample_factor: int | None = None, ) -> nn.Module: stage = timm.models.convnext.ConvNeXtStage( in_chs=in_channels, @@ -83,7 +83,7 @@ def __init__( stride=kernel_size, ) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -101,7 +101,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, norm_name: str, - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() spatial_dims = 2 @@ -145,11 +145,11 @@ def __init__( upsample_factor=conv_weight_init_factor, ) - def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + def forward(self, inp: Tensor, skip: Tensor) -> Tensor: """ - :param torch.Tensor inp: Low resolution features - :param torch.Tensor skip: High resolution skip connection features - :return torch.Tensor: High resolution features + :param Tensor inp: Low resolution features + :param Tensor skip: High resolution skip connection features + :return Tensor: High resolution features """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -192,7 +192,7 @@ def __init__( self.out = nn.PixelShuffle(2) self.out_stack_depth = out_stack_depth - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -209,7 +209,7 @@ class UnsqueezeHead(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: + 
def forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(2) return x @@ -222,7 +222,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, strides: list[int], - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() self.decoder_stages = nn.ModuleList([]) @@ -240,7 +240,7 @@ def __init__( ) self.decoder_stages.append(stage) - def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: + def forward(self, features: Sequence[Tensor]) -> Tensor: feat = features[0] # padding features.append(None) @@ -328,7 +328,7 @@ def num_blocks(self) -> int: """2-times downscaling factor of the smallest feature map""" return 6 - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py new file mode 100644 index 00000000..97771365 --- /dev/null +++ b/viscy/unet/networks/fcmae.py @@ -0,0 +1,422 @@ +""" +Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 +based on the official JAX example in +https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax +and timm's dense implementation of the encoder in ``timm.models.convnext`` +""" + +from typing import Sequence + +import torch +from timm.models.convnext import ( + Downsample, + DropPath, + GlobalResponseNormMlp, + LayerNorm2d, + create_conv2d, + trunc_normal_, +) +from torch import BoolTensor, Size, Tensor, nn + +from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead + + +def _init_weights(module: nn.Module) -> None: + """Initialize weights of the given module.""" + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + +def generate_mask( + target: Size, stride: int, mask_ratio: float, device: str +) -> BoolTensor: + """ + :param Size target: target shape + :param int stride: total stride + :param float mask_ratio: ratio of the pixels to mask + :return BoolTensor: boolean mask (B1HW) + """ + m_height = target[-2] // stride + m_width = target[-1] // stride + mask_numel = m_height * m_width + masked_elements = int(mask_numel * mask_ratio) + mask = torch.rand(target[0], mask_numel, device=device).argsort(1) < masked_elements + return mask.reshape(target[0], 1, m_height, m_width) + + +def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + """ + :param BoolTensor mask: low-resolution boolean mask (B1HW) + :param Size target: target size (BCHW) + :return BoolTensor: upsampled boolean mask (B1HW) + """ + if target[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(target, mask.shape)): + raise ValueError( + f"feature map shape {target} must be divisible by " + f"mask shape {mask.shape}." 
+ ) + mask = mask.repeat_interleave( + target[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(target[-1] // mask.shape[-1], dim=-1) + return mask + + +def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: + """ + :param Tensor features: input image features (BCHW) + :param BoolTensor unmasked: boolean foreground mask (B1HW) + :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) + """ + if unmasked is None: + return features.flatten(2).permute(0, 2, 1) + b, c = features.shape[:2] + # (B, C, H, W) -> (B, H, W, C) + features = features.permute(0, 2, 3, 1) + # (B, H, W, C) -> (B * L, C) -> (B, L, C) + features = features[unmasked[:, 0]].reshape(b, -1, c) + return features + + +def masked_unpatchify( + features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None +) -> Tensor: + """ + :param Tensor features: dense channel-last features (BLC) + :param Size out_shape: output shape (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: masked features (BCHW) + """ + if unmasked is None: + return features.permute(0, 2, 1).reshape(out_shape) + b, c, w, h = out_shape + out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) + # (B, L, C) -> (B * L, C) + features = features.reshape(-1, c) + out[unmasked[:, 0]] = features + # (B, H, W, C) -> (B, C, H, W) + return out.permute(0, 3, 1, 2) + + +class MaskedConvNeXtV2Block(nn.Module): + """Masked ConvNeXt V2 Block. + + :param int in_channels: input channels + :param int | None out_channels: output channels, defaults to None + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsample stride, defaults to 1 + :param int mlp_ratio: MLP expansion ratio, defaults to 4 + :param float drop_path: drop path rate, defaults to 0.0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 7, + stride: int = 1, + mlp_ratio: int = 4, + drop_path: float = 0.0, + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + self.dwconv = create_conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + depthwise=True, + ) + self.layernorm = nn.LayerNorm(out_channels) + mid_channels = mlp_ratio * out_channels + self.mlp = GlobalResponseNormMlp( + in_features=out_channels, + hidden_features=mid_channels, + out_features=out_channels, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if in_channels != out_channels or stride > 1: + self.shortcut = Downsample(in_channels, out_channels, stride=stride) + else: + self.shortcut = nn.Identity() + + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + shortcut = self.shortcut(x) + if unmasked is not None: + x *= unmasked + x = self.dwconv(x) + if unmasked is not None: + x *= unmasked + out_shape = x.shape + x = masked_patchify(x, unmasked=unmasked) + x = self.layernorm(x) + x = self.mlp(x.unsqueeze(1)).squeeze(1) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) + x = self.drop_path(x) + shortcut + return x + + +class MaskedConvNeXtV2Stage(nn.Module): + """Masked ConvNeXt V2 Stage. 
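The patchify helpers above pack only the visible (unmasked) pixels into a dense `(B, L, C)` tensor, so the block's LayerNorm and MLP never see masked positions, and then scatter the result back into a zero-filled feature map. A small round-trip sketch with toy shapes, assuming an equal number of visible pixels per sample (which `generate_mask` guarantees):

```python
import torch

from viscy.unet.networks.fcmae import masked_patchify, masked_unpatchify

b, c, h, w = 2, 8, 4, 4  # toy sizes for illustration
features = torch.randn(b, c, h, w)
keep = 12  # visible pixels per sample (same count for every sample)
unmasked = (torch.rand(b, h * w).argsort(1) < keep).reshape(b, 1, h, w)

packed = masked_patchify(features, unmasked=unmasked)  # (B, 12, C)
restored = masked_unpatchify(packed, out_shape=features.shape, unmasked=unmasked)
# Visible positions survive the round trip; masked positions are zero-filled.
assert torch.equal(restored * unmasked, features * unmasked)
```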
+ + :param int in_channels: input channels + :param int out_channels: output channels + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsampling factor of this stage, defaults to 2 + :param int num_blocks: number of residual blocks, defaults to 2 + :param Sequence[float] | None drop_path_rates: drop path rates of each block, + defaults to None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 7, + stride: int = 2, + num_blocks: int = 2, + drop_path_rates: Sequence[float] | None = None, + ) -> None: + super().__init__() + if drop_path_rates is None: + drop_path_rates = [0.0] * num_blocks + elif len(drop_path_rates) != num_blocks: + raise ValueError( + "length of drop_path_rates must be equal to " + f"the number of blocks {num_blocks}, got {len(drop_path_rates)}." + ) + if in_channels != out_channels or stride > 1: + downsample_kernel_size = stride if stride > 1 else 1 + self.downsample = nn.Sequential( + LayerNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=downsample_kernel_size, + stride=stride, + padding=0, + ), + ) + in_channels = out_channels + else: + self.downsample = nn.Identity() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + MaskedConvNeXtV2Block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + drop_path=drop_path_rates[i], + ) + ) + in_channels = out_channels + + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + x = self.downsample(x) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) + for block in self.blocks: + x = block(x, unmasked) + return x + + +class MaskedAdaptiveProjection(nn.Module): + """ + Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. 
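Each stage first applies its (optional) downsampling, then upsamples the coarse boolean mask to the new feature resolution and threads it through every block, so all blocks keep operating on visible pixels only. A minimal usage sketch with illustrative shapes:

```python
import torch

from viscy.unet.networks.fcmae import MaskedConvNeXtV2Stage, generate_mask

stage = MaskedConvNeXtV2Stage(32, 64, kernel_size=7, stride=2, num_blocks=2)
x = torch.randn(2, 32, 64, 64)  # toy feature map
# Coarse mask at 1/16 of the input resolution, half of the positions masked.
mask = generate_mask(x.shape, stride=16, mask_ratio=0.5, device="cpu")  # (2, 1, 4, 4)
out = stage(x, unmasked=~mask)  # the mask is upsampled internally to (2, 1, 32, 32)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```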
+ + :param int in_channels: input channels + :param int out_channels: output channels + :param Sequence[int, int] | int kernel_size_2d: kernel width and height + :param int kernel_depth: kernel depth for 3D input + :param int in_stack_depth: input stack depth for 3D input + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size_2d: tuple[int, int] | int = 4, + kernel_depth: int = 5, + in_stack_depth: int = 5, + ) -> None: + super().__init__() + ratio = in_stack_depth // kernel_depth + if isinstance(kernel_size_2d, int): + kernel_size_2d = [kernel_size_2d] * 2 + kernel_size_3d = [kernel_depth, *kernel_size_2d] + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels // ratio, + kernel_size=kernel_size_3d, + stride=kernel_size_3d, + ) + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_2d, + stride=kernel_size_2d, + ) + self.norm = nn.LayerNorm(out_channels) + + def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: + """ + :param Tensor x: input tensor (BCDHW) + :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None + :return Tensor: output tensor (BCHW) + """ + # no need to mask before convolutions since patches do not spill over + if x.shape[2] > 1: + x = self.conv3d(x) + b, c, d, h, w = x.shape + # project Z/depth into channels + # return a view when possible (contiguous) + x = x.reshape(b, c * d, h, w) + else: + x = self.conv2d(x.squeeze(2)) + out_shape = x.shape + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) + x = masked_patchify(x, unmasked=unmasked) + x = self.norm(x) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) + return x + + +class MaskedMultiscaleEncoder(nn.Module): + def __init__( + self, + in_channels: int, + stage_blocks: Sequence[int] = (3, 3, 9, 3), + dims: Sequence[int] = (96, 192, 384, 768), + drop_path_rate: float = 0.0, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, + ) -> None: + super().__init__() + self.stem = MaskedAdaptiveProjection( + in_channels, + dims[0], + kernel_size_2d=stem_kernel_size[1:], + kernel_depth=stem_kernel_size[0], + in_stack_depth=in_stack_depth, + ) + self.stages = nn.ModuleList() + chs = [dims[0], *dims] + for i, num_blocks in enumerate(stage_blocks): + stride = 1 if i == 0 else 2 + self.stages.append( + MaskedConvNeXtV2Stage( + chs[i], + chs[i + 1], + kernel_size=7, + stride=stride, + num_blocks=num_blocks, + drop_path_rates=[drop_path_rate] * num_blocks, + ) + ) + self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) + self.apply(_init_weights) + + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: + """ + :param Tensor x: input tensor (BCDHW) + :param float mask_ratio: ratio of the feature maps to mask, + defaults to 0.0 (no masking) + :return list[Tensor]: output tensors (list of BCHW) + :return BoolTensor | None: boolean foreground mask, None if no masking + """ + if mask_ratio > 0.0: + mask = generate_mask( + x.shape, self.total_stride, mask_ratio, device=x.device + ) + b, c, d, h, w = x.shape + unmasked = ~mask + mask = upsample_mask(mask, (b, 1, h, w)) + else: + mask = unmasked = None + x = self.stem(x) + features = [] + for stage in self.stages: + x = stage(x, unmasked=unmasked) + features.append(x) + return features, mask + + +class FullyConvolutionalMAE(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + encoder_blocks: Sequence[int] = [3, 3, 9, 3], + 
dims: Sequence[int] = [96, 192, 384, 768], + encoder_drop_path_rate: float = 0.0, + head_expansion_ratio: int = 4, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, + decoder_conv_blocks: int = 1, + pretraining: bool = True, + ) -> None: + super().__init__() + self.encoder = MaskedMultiscaleEncoder( + in_channels=in_channels, + stage_blocks=encoder_blocks, + dims=dims, + drop_path_rate=encoder_drop_path_rate, + stem_kernel_size=stem_kernel_size, + in_stack_depth=in_stack_depth, + ) + decoder_channels = list(dims) + decoder_channels.reverse() + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + self.decoder = Unet2dDecoder( + decoder_channels, + norm_name="instance", + mode="pixelshuffle", + conv_blocks=decoder_conv_blocks, + strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], + upsample_pre_conv=None, + ) + if in_stack_depth == 1: + self.head = UnsqueezeHead() + else: + self.head = PixelToVoxelHead( + in_channels=decoder_channels[-1], + out_channels=out_channels, + out_stack_depth=in_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=True, + ) + self.out_stack_depth = in_stack_depth + self.num_blocks = 6 + self.pretraining = pretraining + + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + x, mask = self.encoder(x, mask_ratio=mask_ratio) + x.reverse() + x = self.decoder(x) + x = self.head(x) + if self.pretraining: + return x, mask + return x diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index f9020dc9..a9569116 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -21,9 +21,7 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): / (limit[1] - limit[0] + sys.float_info.epsilon) * (2**bit - 1) ) - im = np.clip( - im, 0, 2**bit - 1 - ) # clip the values to avoid wrap-around by np.astype + im = np.clip(im, 0, 2**bit - 1) # clip the values to avoid wrap-around by np.astype if bit == 8: im = im.astype(np.uint8, copy=False) # convert to 8 bit else: diff --git a/viscy/unet/utils/logging.py b/viscy/utils/logging.py similarity index 100% rename from viscy/unet/utils/logging.py rename to viscy/utils/logging.py diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index d644dadf..961b6696 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -104,8 +104,9 @@ def generate_normalization_metadata( positions, fov_sample_values = mp_utils.mp_sample_im_pixels( this_channels_args, num_workers ) - dataset_sample_values = np.stack(fov_sample_values, 0) - + dataset_sample_values = np.concatenate( + [arr.flatten() for arr in fov_sample_values] + ) fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 93c11713..73753acb 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -1,4 +1,5 @@ """Image normalization related functions""" + import sys import numpy as np
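As a closing illustration, a minimal sketch of the pre-training forward pass through the new model (input shape and mask ratio are illustrative; the `FcmaeUNet` module added to `viscy.light.engine` drives the same call via its `fit_mask_ratio` setting):

```python
import torch

from viscy.unet.networks.fcmae import FullyConvolutionalMAE

model = FullyConvolutionalMAE(in_channels=1, out_channels=1)  # defaults from the signature above
x = torch.randn(2, 1, 5, 128, 128)  # (B, C, D, H, W) with the default 5-slice stack
# With the default pretraining=True, the model returns the reconstruction and the mask.
pred, mask = model(x, mask_ratio=0.5)
print(pred.shape)  # reconstructed stack, BCDHW
print(mask.shape)  # boolean mask at the input XY resolution, B1HW
```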