Masked autoencoder pre-training for virtual staining models (#67)
* refactor data loading into its own module

* update type annotations

* move the logging module out

* move old logging into utils

* rename tests to match module name

* bump torch

* draft fcmae encoder

* add stem to the encoder

* wip: masked stem layernorm

* wip: patchify masked features for linear
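
A minimal sketch of the masking step drafted in these commits (ConvNeXt-V2-style fully convolutional MAE; hypothetical helper names, not the actual viscy code): mask whole patches of the input before the stem, so the encoder only sees visible regions.

```python
import torch


def random_patch_mask(batch: int, grid: int, mask_ratio: float) -> torch.Tensor:
    """Boolean (batch, grid, grid) mask over the patch grid; True = masked."""
    n_masked = int(grid * grid * mask_ratio)
    idx = torch.rand(batch, grid * grid).argsort(dim=1)[:, :n_masked]
    mask = torch.zeros(batch, grid * grid, dtype=torch.bool)
    mask.scatter_(1, idx, True)
    return mask.reshape(batch, grid, grid)


x = torch.rand(2, 1, 128, 128)  # (B, C, H, W) input
mask = random_patch_mask(2, 4, mask_ratio=0.6)  # 4x4 grid of 32x32 patches
pixel_mask = mask.repeat_interleave(32, 1).repeat_interleave(32, 2)
x = x.masked_fill(pixel_mask.unsqueeze(1), 0.0)  # zero out masked patches
```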

* use mlp from timm

* hack: POC training script for FCMAE

* fix mask for fitting

* remove training script

* default architecture

* fine-tuning options

* fix cli for finetuning

* draft combined data module

* fix import

* manual validation loss reduction
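
Lightning 2.x no longer passes aggregated step outputs to the epoch-end hooks, so the module has to collect and reduce validation losses itself; a minimal sketch of that pattern (toy module for illustration, not the actual viscy engine):

```python
import torch
from lightning.pytorch import LightningModule


class ToyModule(LightningModule):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)
        self.validation_losses: list[torch.Tensor] = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.validation_losses.append(loss.detach())
        return loss

    def on_validation_epoch_end(self):
        # manual reduction: stack per-batch losses and average them
        mean_loss = torch.stack(self.validation_losses).mean()
        self.log("loss/validate", mean_loss, sync_dist=True)
        self.validation_losses.clear()
```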

* update linting
new black version has different rules

* update development guide

* update type hints

* bump iohub

* draft ctmc v1 dataset

* update tests

* move test_data

* remove path conversion

* configurable normalizations (#68)

* initial commit adding the normalization.

* adding dataset_statistics to each fov to facilitate the configurable augmentations

* fix indentation

* ruff

* test preprocessing

* remove redundant field

* cleanup

---------

Co-authored-by: Ziwen Liu <[email protected]>

* fix ctmc dataloading

* add example ctmc v1 loading script

* change the normalization and augmentation defaults from None to an empty list.

* invert intensity transform
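
The invert transform flips image contrast (e.g. to match intensity conventions between label-free and fluorescence channels); a plausible dictionary-transform version, assuming a MONAI-style interface rather than viscy's exact implementation:

```python
from monai.transforms import MapTransform


class InvertIntensityd(MapTransform):  # assumed form, for illustration
    """Invert contrast so bright structures become dark and vice versa."""

    def __call__(self, sample: dict) -> dict:
        for key in self.keys:
            img = sample[key]
            sample[key] = img.max() - img  # reflect values about the maximum
        return sample
```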

* concatenated data module
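
A concatenated data module serves several datasets behind a single loader; a rough sketch of the idea with torch's ConcatDataset (hypothetical class, the real ConcatDataModule API may differ):

```python
from lightning.pytorch import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader


class ToyConcatDataModule(LightningDataModule):  # hypothetical sketch
    def __init__(self, data_modules: list[LightningDataModule], batch_size: int = 4):
        super().__init__()
        self.data_modules = data_modules
        self.batch_size = batch_size

    def setup(self, stage: str):
        for dm in self.data_modules:
            dm.setup(stage)
        # chain the child train datasets into one indexable dataset
        self.train_dataset = ConcatDataset(
            [dm.train_dataset for dm in self.data_modules]
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
```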

* subsample videos

* livecell dataset

* all sample fields are optional

* fix multi-dataloader validation

* lint

* fix preprocessing for varying array shapes (e.g. the AICS dataset)

* update loading scripts

* fix CombineMode

* compose normalizations for predict and test stages

* black

* fix normalization in example config

* fix collate when multi-sample transform is not used
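
Multi-sample transforms (e.g. RandSpatialCropSamplesd in the tests below) return a list of patches per dataset item, while single-sample pipelines return one dict, so collation has to handle both; a sketch of such a collate function (an assumption, not the exact fix in this commit):

```python
from torch.utils.data import default_collate


def flexible_collate(batch: list) -> dict:
    """Flatten per-item sample lists before collating."""
    flat = []
    for item in batch:
        # multi-sample transforms yield a list of dicts per item
        flat.extend(item if isinstance(item, list) else [item])
    return default_collate(flat)
```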

* ddp caching fixes

* fix caching when using combined loader

* move log values to GPU before syncing
Lightning-AI/pytorch-lightning#18803
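
That is, values passed to `self.log(..., sync_dist=True)` should live on the module's device, since syncing a CPU tensor across GPU processes can fail (see the linked issue). A minimal sketch:

```python
import torch
from lightning.pytorch import LightningModule


class LoggingExample(LightningModule):  # hypothetical, for illustration only
    def on_validation_epoch_end(self):
        metric = torch.tensor(0.5)  # e.g. a value reduced on the CPU
        # move to self.device so the distributed sync does not
        # all-reduce a CPU tensor against GPU tensors
        self.log("loss/validate", metric.to(self.device), sync_dist=True)
```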

* removing normalize_source from configs.

* typing fixes

* fix test data path

* fix test dataset

* add docstring for ConcatDataModule

* format

---------

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
ziw-liu and edyoshikun authored Apr 8, 2024
1 parent 582a4c8 commit 0536d29
Showing 40 changed files with 1,544 additions and 280 deletions.
14 changes: 13 additions & 1 deletion CONTRIBUTING.md
@@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies:
 pip install -e ".[dev,visual,metrics]"
 ```
 
-## Testing
+## CI requirements
+
+Lint with Ruff:
+
+```sh
+ruff check viscy
+```
+
+Format the code with Black:
+
+```sh
+black viscy
+```
 
 Run tests with `pytest`:
 
16 changes: 14 additions & 2 deletions examples/configs/fit_example.yml
@@ -37,6 +37,19 @@ data:
     batch_size: 32
     num_workers: 16
     yx_patch_size: [256, 256]
+    normalizations:
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [source]
+          level: "fov_statistics"
+          subtrahend: "mean"
+          divisor: "std"
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [target_1]
+          level: "fov_statistics"
+          subtrahend: "median"
+          divisor: "iqr"
     augmentations:
       - class_path: viscy.transforms.RandWeightedCropd
         init_args:
@@ -74,5 +87,4 @@ data:
         sigma_z: [0.25, 1.5]
         sigma_y: [0.25, 1.5]
         sigma_x: [0.25, 1.5]
-    caching: false
-    normalize_source: true
+    caching: false
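
Each `normalizations` entry above reads named statistics from the dataset's normalization metadata and applies `(image - subtrahend) / divisor`; a sketch of that behavior (assumed, the real `viscy.transforms.NormalizeSampled` may differ):

```python
def normalize_sampled(sample: dict, key: str, level: str, subtrahend: str, divisor: str) -> dict:
    """Normalize sample[key] with statistics from e.g. fov_statistics metadata."""
    stats = sample["norm_meta"][key][level]  # e.g. {"mean": ..., "std": ...}
    sample[key] = (sample[key] - stats[subtrahend]) / stats[divisor]
    return sample


# e.g. normalize_sampled(sample, "source", "fov_statistics", "mean", "std")
```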
1 change: 0 additions & 1 deletion examples/configs/predict_example.yml
@@ -62,7 +62,6 @@ predict:
         - 256
         - 256
       caching: false
-      normalize_source: false
       predict_scale_source: null
   return_predictions: false
   ckpt_path: null
1 change: 0 additions & 1 deletion examples/configs/test_example.yml
@@ -61,7 +61,6 @@ data:
       - 256
       - 256
     caching: false
-    normalize_source: false
     ground_truth_masks: null
 ckpt_path: null
 verbose: true
2 changes: 1 addition & 1 deletion examples/demo_dlmbl/debug_log_graph.py
@@ -19,7 +19,7 @@
 from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
 
 # HCSDataModule makes it easy to load data during training.
-from viscy.light.data import HCSDataModule
+from viscy.data.hcs import HCSDataModule
 
 # Trainer class and UNet.
 from viscy.light.engine import VSUNet
2 changes: 1 addition & 1 deletion examples/demo_dlmbl/solution.py
@@ -83,7 +83,7 @@
 from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
 
 # HCSDataModule makes it easy to load data during training.
-from viscy.light.data import HCSDataModule
+from viscy.data.hcs import HCSDataModule
 
 # training augmentations
 from viscy.transforms import (
21 changes: 13 additions & 8 deletions pyproject.toml
@@ -10,8 +10,8 @@ requires-python = ">=3.10"
 license = { file = "LICENSE" }
 authors = [{ name = "CZ Biohub SF", email = "[email protected]" }]
 dependencies = [
-    "iohub==0.1.0rc0",
-    "torch>=2.0.0",
+    "iohub==0.1.0",
+    "torch>=2.1.2",
     "timm>=0.9.5",
     "tensorboard>=2.13.0",
     "lightning>=2.0.1",
@@ -30,7 +30,15 @@ metrics = [
     "ptflops>=0.7",
 ]
 visual = ["ipykernel", "graphviz", "torchview"]
-dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"]
+dev = [
+    "pytest",
+    "pytest-cov",
+    "hypothesis",
+    "ruff",
+    "black",
+    "profilehooks",
+    "onnxruntime",
+]
 
 [project.scripts]
 viscy = "viscy.cli.cli:main"
@@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main"
 write_to = "viscy/_version.py"
 
 [tool.black]
-src = ["viscy"]
 line-length = 88
 
 [tool.ruff]
 src = ["viscy", "tests"]
-extend-select = ["I001"]
-
-[tool.ruff.isort]
-known-first-party = ["viscy"]
+lint.extend-select = ["I001"]
+lint.isort.known-first-party = ["viscy"]
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names}
     with open_ome_zarr(dataset_path, mode="r+") as dataset:
         dataset.zattrs["normalization"] = norm_meta
+        for _, fov in dataset.positions():
+            fov.zattrs["normalization"] = norm_meta
     return dataset_path


Empty file added tests/data/__init__.py
Empty file.
105 changes: 105 additions & 0 deletions tests/data/test_data.py
@@ -0,0 +1,105 @@
from pathlib import Path

from iohub import open_ome_zarr
from monai.transforms import RandSpatialCropSamplesd
from pytest import mark

from viscy.data.hcs import HCSDataModule
from viscy.light.trainer import VSTrainer


@mark.parametrize("default_channels", [True, False])
def test_preprocess(small_hcs_dataset: Path, default_channels: bool):
    data_path = small_hcs_dataset
    if default_channels:
        channel_names = -1
    else:
        with open_ome_zarr(data_path) as dataset:
            channel_names = dataset.channel_names
    trainer = VSTrainer(accelerator="cpu")
    trainer.preprocess(data_path, channel_names=channel_names, num_workers=2)
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
        for channel in channel_names:
            assert "dataset_statistics" in dataset.zattrs["normalization"][channel]
        for _, fov in dataset.positions():
            norm_metadata = fov.zattrs["normalization"]
            for channel in channel_names:
                assert channel in norm_metadata
                assert "dataset_statistics" in norm_metadata[channel]
                assert "fov_statistics" in norm_metadata[channel]


@mark.parametrize("multi_sample_augmentation", [True, False])
def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentation):
    data_path = preprocessed_hcs_dataset
    z_window_size = 5
    channel_split = 2
    split_ratio = 0.8
    yx_patch_size = [128, 96]
    batch_size = 4
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
    if multi_sample_augmentation:
        transforms = [
            RandSpatialCropSamplesd(
                keys=channel_names,
                roi_size=[z_window_size, *yx_patch_size],
                num_samples=2,
            )
        ]
    else:
        transforms = []
    dm = HCSDataModule(
        data_path=data_path,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=batch_size,
        num_workers=0,
        augmentations=transforms,
        architecture="3D",
        split_ratio=split_ratio,
        yx_patch_size=yx_patch_size,
    )
    dm.setup(stage="fit")
    for batch in dm.train_dataloader():
        assert batch["source"].shape == (
            batch_size,
            channel_split,
            z_window_size,
            *yx_patch_size,
        )
        assert batch["target"].shape == (
            batch_size,
            len(channel_names) - channel_split,
            z_window_size,
            *yx_patch_size,
        )


def test_datamodule_setup_predict(preprocessed_hcs_dataset):
    data_path = preprocessed_hcs_dataset
    z_window_size = 5
    channel_split = 2
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
        img = next(dataset.positions())[1][0]
        total_p = len(list(dataset.positions()))
    dm = HCSDataModule(
        data_path=data_path,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=2,
        num_workers=0,
    )
    dm.setup(stage="predict")
    dataset = dm.predict_dataset
    assert len(dataset) == total_p * 2 * (img.slices - z_window_size + 1)
    assert dataset[0]["source"].shape == (
        channel_split,
        z_window_size,
        img.height,
        img.width,
    )
70 changes: 0 additions & 70 deletions tests/light/test_data.py

This file was deleted.

7 changes: 7 additions & 0 deletions tests/light/test_engine.py
@@ -0,0 +1,7 @@
from viscy.light.engine import FcmaeUNet


def test_fcmae_vsunet() -> None:
    model = FcmaeUNet(
        model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6
    )
Empty file added tests/unet/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.