Skip to content

Commit

Permalink
2D FCMAE (#71)
Browse files Browse the repository at this point in the history
* refactor data loading into its own module

* update type annotations

* move the logging module out

* move old logging into utils

* rename tests to match module name

* bump torch

* draft fcmae encoder

* add stem to the encoder

* wip: masked stem layernorm

* wip: patchify masked features for linear

* use mlp from timm

* hack: POC training script for FCMAE

* fix mask for fitting

* remove training script

* default architecture

* fine-tuning options

* fix cli for finetuning

* draft combined data module

* fix import

* manual validation loss reduction

* update linting
new black version has different rules

* update development guide

* update type hints

* bump iohub

* draft ctmc v1 dataset

* update tests

* move test_data

* remove path conversion

* configurable normalizations (#68)

* inital commit adding the normalization.

* adding dataset_statistics to each fov to facilitate the configurable augmentations

* fix indentation

* ruff

* test preprocessing

* remove redundant field

* cleanup

---------

Co-authored-by: Ziwen Liu <[email protected]>

* fix ctmc dataloading

* add example ctmc v1 loading script

* changing the normalization and augmentations default from None to empty list.

* invert intensity transform

* concatenated data module

* subsample videos

* livecell dataset

* all sample fields are optional

* fix multi-dataloader validation

* lint

* fixing preprocessing for varying array shapes (i.e aics dataset)

* update loading scripts

* fix CombineMode

* always use untrainable head for FCMAE

* move log values to GPU before syncing
Lightning-AI/pytorch-lightning#18803

* custom head

* ddp caching fixes

* fix caching when using combined loader

* compose normalizations for predict and test stages

* black

* fix normalization in example config

* fix normalization in example config

* prefetch more in validation

* fix collate when multi-sample transform is not used

* ddp caching fixes

* fix caching when using combined loader

* typing fixes

* fix test dataset

* fix invert transform

* add ddp prepare flag for combined data module

* remove redundant operations

* filter empty detections

* pass trainer to underlying data modules in concatenated

* hack: add test dataloader for LiveCell dataset

* test datasets for livecell and ctmc

* fix merge error

* fix merge error

* fix mAP default for over 100 detections

* bump torchmetric

* fix combined loader training for virtual staining task

* fix non-combined data loader training

* add fcmae to graph script

* fix type hint

* format

* add back convolutiuon option for fcmae head

---------

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
  • Loading branch information
ziw-liu and edyoshikun committed Jun 12, 2024
1 parent 57ca6c4 commit 111cb7d
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dynamic = ["version"]
metrics = [
"cellpose==2.1.0",
"scikit-learn>=1.1.3",
"torchmetrics[detection]>=1.0.0",
"torchmetrics[detection]>=1.3.1",
"ptflops>=0.7",
]
visual = ["ipykernel", "graphviz", "torchview"]
Expand Down
21 changes: 21 additions & 0 deletions tests/unet/test_fcmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MaskedConvNeXtV2Block,
MaskedConvNeXtV2Stage,
MaskedMultiscaleEncoder,
PixelToVoxelShuffleHead,
generate_mask,
masked_patchify,
masked_unpatchify,
Expand Down Expand Up @@ -104,6 +105,13 @@ def test_masked_multiscale_encoder():
assert afeat.shape[2] == afeat.shape[3] == xy_size // stride


def test_pixel_to_voxel_shuffle_head():
head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
x = torch.rand(2, 240, 16, 16)
y = head(x)
assert y.shape == (2, 3, 5, 64, 64)


def test_fcmae():
x = torch.rand(2, 3, 5, 128, 128)
model = FullyConvolutionalMAE(3, 3)
Expand All @@ -113,3 +121,16 @@ def test_fcmae():
y, m = model(x, mask_ratio=0.6)
assert y.shape == x.shape
assert m.shape == (2, 1, 128, 128)


def test_fcmae_head_conv():
x = torch.rand(2, 3, 5, 128, 128)
model = FullyConvolutionalMAE(
3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True
)
y, m = model(x)
assert y.shape == x.shape
assert m is None
y, m = model(x, mask_ratio=0.6)
assert y.shape == x.shape
assert m.shape == (2, 1, 128, 128)
3 changes: 2 additions & 1 deletion viscy/data/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class ConcatDataModule(LightningDataModule):
The concatenated data module will have the same
batch size and number of workers as the first data module.
Each element will be sampled uniformly regardless of their original data module.
:param Sequence[LightningDataModule] data_modules: data modules to concatenate
"""

Expand All @@ -93,9 +92,11 @@ def __init__(self, data_modules: Sequence[LightningDataModule]):
raise ValueError("Inconsistent number of workers")
if dm.batch_size != self.batch_size:
raise ValueError("Inconsistent batch size")
self.prepare_data_per_node = True

def prepare_data(self):
for dm in self.data_modules:
dm.trainer = self.trainer
dm.prepare_data()

def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
Expand Down
3 changes: 1 addition & 2 deletions viscy/data/ctmc_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@


class CTMCv1ValidationDataset(SlidingWindowDataset):
subsample_rate: int = 30

def __len__(self) -> int:
def __len__(self, subsample_rate: int = 30) -> int:
# sample every 30th frame in the videos
return super().__len__() // self.subsample_rate

Expand Down
2 changes: 0 additions & 2 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def __getitem__(self, index: int) -> Sample:
sample_images["norm_meta"] = norm_meta
if self.transform:
sample_images = self.transform(sample_images)
# if isinstance(sample_images, list):
# sample_images = sample_images[0]
if "weight" in sample_images:
del sample_images["weight"]
sample = {
Expand Down
101 changes: 90 additions & 11 deletions viscy/data/livecell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import torch
from lightning.pytorch import LightningDataModule
from monai.transforms import Compose, Transform
from monai.transforms import Compose, MapTransform
from pycocotools.coco import COCO
from tifffile import imread
from torch.utils.data import DataLoader, Dataset
from torchvision.ops import box_convert

from viscy.data.typing import Sample

Expand All @@ -15,10 +17,10 @@ class LiveCellDataset(Dataset):
LiveCell dataset.
:param list[Path] images: List of paths to single-page, single-channel TIFF files.
:param Transform | Compose transform: Transform to apply to the dataset
:param MapTransform | Compose transform: Transform to apply to the dataset
"""

def __init__(self, images: list[Path], transform: Transform | Compose) -> None:
def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None:
self.images = images
self.transform = transform

Expand All @@ -32,36 +34,100 @@ def __getitem__(self, idx: int) -> Sample:
return {"source": image, "target": image}


class LiveCellTestDataset(Dataset):
"""
LiveCell dataset.
:param list[Path] images: List of paths to single-page, single-channel TIFF files.
:param MapTransform | Compose transform: Transform to apply to the dataset
"""

def __init__(
self,
image_dir: Path,
transform: MapTransform | Compose,
annotations: Path,
load_target: bool = False,
load_labels: bool = False,
) -> None:
self.image_dir = image_dir
self.transform = transform
self.coco = COCO(str(annotations))
self.image_ids = list(self.coco.imgs.keys())
self.load_target = load_target
self.load_labels = load_labels

def __len__(self) -> int:
return len(self.image_ids)

def __getitem__(self, idx: int) -> Sample:
image_id = self.image_ids[idx]
file_name = self.coco.imgs[image_id]["file_name"]
image_path = self.image_dir / file_name
image = imread(image_path)[None, None]
image = torch.from_numpy(image).to(torch.float32)
sample = Sample(source=image)
if self.load_target:
sample["target"] = image
if self.load_labels:
anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or []
boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns]
masks = [
torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool)
for ann in anns
]
dets = {
"boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"),
"labels": torch.zeros(len(anns)).to(torch.uint8),
"masks": torch.stack(masks),
}
sample["detections"] = dets
sample["file_name"] = file_name
self.transform(sample)
return sample


class LiveCellDataModule(LightningDataModule):
def __init__(
self,
train_val_images: Path,
train_annotations: Path,
val_annotations: Path,
train_transforms: list[Transform],
val_transforms: list[Transform],
train_val_images: Path | None = None,
test_images: Path | None = None,
train_annotations: Path | None = None,
val_annotations: Path | None = None,
test_annotations: Path | None = None,
train_transforms: list[MapTransform] = [],
val_transforms: list[MapTransform] = [],
test_transforms: list[MapTransform] = [],
batch_size: int = 16,
num_workers: int = 8,
) -> None:
super().__init__()
self.train_val_images = Path(train_val_images)
if not self.train_val_images.is_dir():
raise NotADirectoryError(str(train_val_images))
self.test_images = Path(test_images)
if not self.test_images.is_dir():
raise NotADirectoryError(str(test_images))
self.train_annotations = Path(train_annotations)
if not self.train_annotations.is_file():
raise FileNotFoundError(str(train_annotations))
self.val_annotations = Path(val_annotations)
if not self.val_annotations.is_file():
raise FileNotFoundError(str(val_annotations))
self.test_annotations = Path(test_annotations)
if not self.test_annotations.is_file():
raise FileNotFoundError(str(test_annotations))
self.train_transforms = Compose(train_transforms)
self.val_transforms = Compose(val_transforms)
self.test_transforms = Compose(test_transforms)
self.batch_size = batch_size
self.num_workers = num_workers

def setup(self, stage: str) -> None:
if stage != "fit":
raise NotImplementedError("Only fit stage is supported")
self._setup_fit()
if stage == "fit":
self._setup_fit()
elif stage == "test":
self._setup_test()

def _parse_image_names(self, annotations: Path) -> list[Path]:
with open(annotations) as f:
Expand All @@ -80,6 +146,14 @@ def _setup_fit(self) -> None:
transform=self.val_transforms,
)

def _setup_test(self) -> None:
self.test_dataset = LiveCellTestDataset(
self.test_images,
transform=self.test_transforms,
annotations=self.test_annotations,
load_labels=True,
)

def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
Expand All @@ -96,3 +170,8 @@ def val_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
persistent_workers=bool(self.num_workers),
)

def test_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
9 changes: 7 additions & 2 deletions viscy/evaluation/evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from monai.metrics.regression import compute_ssim_and_cs
from scipy.optimize import linear_sum_assignment
from skimage.measure import label, regionprops
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.ops import masks_to_boxes


Expand Down Expand Up @@ -172,7 +172,12 @@ def mean_average_precision(
:py:class:`torchmetrics.detection.MeanAveragePrecision`
:return dict[str, torch.Tensor]: COCO-style metrics
"""
map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs)
defaults = dict(
iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
)
if not kwargs:
kwargs = {}
map_metric = MeanAveragePrecision(**(defaults | kwargs))
map_metric.update(
[labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
)
Expand Down
65 changes: 40 additions & 25 deletions viscy/light/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
self.log_batches_per_epoch = log_batches_per_epoch
self.log_samples_per_batch = log_samples_per_batch
self.training_step_outputs = []
self.validation_losses = []
self.validation_step_outputs = []
# required to log the graph
if architecture == "2D":
Expand All @@ -170,32 +171,49 @@ def __init__(
def forward(self, x: Tensor) -> Tensor:
return self.model(x)

def training_step(self, batch: Sample, batch_idx: int):
source = batch["source"]
target = batch["target"]
pred = self.forward(source)
loss = self.loss_function(pred, target)
def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
losses = []
batch_size = 0
if not isinstance(batch, Sequence):
batch = [batch]
for b in batch:
source = b["source"]
target = b["target"]
pred = self.forward(source)
loss = self.loss_function(pred, target)
losses.append(loss)
batch_size += source.shape[0]
if batch_idx < self.log_batches_per_epoch:
self.training_step_outputs.extend(
self._detach_sample((source, target, pred))
)
loss_step = torch.stack(losses).mean()
self.log(
"loss/train",
loss,
loss_step.to(self.device),
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
sync_dist=True,
batch_size=batch_size,
)
if batch_idx < self.log_batches_per_epoch:
self.training_step_outputs.extend(
self._detach_sample((source, target, pred))
)
return loss
return loss_step

def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
source = batch["source"]
target = batch["target"]
source: Tensor = batch["source"]
target: Tensor = batch["target"]
pred = self.forward(source)
loss = self.loss_function(pred, target)
self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False)
if dataloader_idx + 1 > len(self.validation_losses):
self.validation_losses.append([])
self.validation_losses[dataloader_idx].append(loss.detach())
self.log(
f"loss/val/{dataloader_idx}",
loss.to(self.device),
sync_dist=True,
batch_size=source.shape[0],
)
if batch_idx < self.log_batches_per_epoch:
self.validation_step_outputs.extend(
self._detach_sample((source, target, pred))
Expand Down Expand Up @@ -305,8 +323,16 @@ def on_train_epoch_end(self):
self.training_step_outputs = []

def on_validation_epoch_end(self):
super().on_validation_epoch_end()
self._log_samples("val_samples", self.validation_step_outputs)
self.validation_step_outputs = []
# average within each dataloader
loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
self.log(
"loss/validate",
torch.tensor(loss_means).mean().to(self.device),
sync_dist=True,
)

def on_test_start(self):
"""Load CellPose model for segmentation."""
Expand Down Expand Up @@ -382,7 +408,6 @@ class FcmaeUNet(VSUNet):
def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
super().__init__(architecture="fcmae", **kwargs)
self.fit_mask_ratio = fit_mask_ratio
self.validation_losses = []

def forward(self, x: Tensor, mask_ratio: float = 0.0):
return self.model(x, mask_ratio)
Expand Down Expand Up @@ -434,13 +459,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
self.validation_step_outputs.extend(
self._detach_sample((source, target * mask.unsqueeze(2), pred))
)

def on_validation_epoch_end(self):
super().on_validation_epoch_end()
# average within each dataloader
loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
self.log(
"loss/validate",
torch.tensor(loss_means).mean().to(self.device),
sync_dist=True,
)
Loading

0 comments on commit 111cb7d

Please sign in to comment.