From 897a0b63fd3df852d61290999c8fdd4021095588 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:04:36 -0700 Subject: [PATCH] 2D FCMAE (#71) * refactor data loading into its own module * update type annotations * move the logging module out * move old logging into utils * rename tests to match module name * bump torch * draft fcmae encoder * add stem to the encoder * wip: masked stem layernorm * wip: patchify masked features for linear * use mlp from timm * hack: POC training script for FCMAE * fix mask for fitting * remove training script * default architecture * fine-tuning options * fix cli for finetuning * draft combined data module * fix import * manual validation loss reduction * update linting new black version has different rules * update development guide * update type hints * bump iohub * draft ctmc v1 dataset * update tests * move test_data * remove path conversion * configurable normalizations (#68) * inital commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu * fix ctmc dataloading * add example ctmc v1 loading script * changing the normalization and augmentations default from None to empty list. * invert intensity transform * concatenated data module * subsample videos * livecell dataset * all sample fields are optional * fix multi-dataloader validation * lint * fixing preprocessing for varying array shapes (i.e aics dataset) * update loading scripts * fix CombineMode * always use untrainable head for FCMAE * move log values to GPU before syncing https://github.com/Lightning-AI/pytorch-lightning/issues/18803 * custom head * ddp caching fixes * fix caching when using combined loader * compose normalizations for predict and test stages * black * fix normalization in example config * fix normalization in example config * prefetch more in validation * fix collate when multi-sample transform is not used * ddp caching fixes * fix caching when using combined loader * typing fixes * fix test dataset * fix invert transform * add ddp prepare flag for combined data module * remove redundant operations * filter empty detections * pass trainer to underlying data modules in concatenated * hack: add test dataloader for LiveCell dataset * test datasets for livecell and ctmc * fix merge error * fix merge error * fix mAP default for over 100 detections * bump torchmetric * fix combined loader training for virtual staining task * fix non-combined data loader training * add fcmae to graph script * fix type hint * format * add back convolutiuon option for fcmae head --------- Co-authored-by: Eduardo Hirata-Miyasaki --- pyproject.toml | 2 +- tests/unet/test_fcmae.py | 21 +++++ viscy/data/combined.py | 3 +- viscy/data/ctmc_v1.py | 3 +- viscy/data/hcs.py | 2 - viscy/data/livecell.py | 101 ++++++++++++++++++++++--- viscy/evaluation/evaluation_metrics.py | 9 ++- viscy/light/engine.py | 65 ++++++++++------ viscy/scripts/network_diagram.py | 39 +++++++++- viscy/transforms.py | 9 ++- viscy/unet/networks/fcmae.py | 69 ++++++++++++++--- 11 files changed, 265 insertions(+), 58 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f6978de..5f0a184f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dynamic = ["version"] metrics = [ "cellpose==2.1.0", "scikit-learn>=1.1.3", - "torchmetrics[detection]>=1.0.0", + "torchmetrics[detection]>=1.3.1", "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 4ed441b4..f22efa4c 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -6,6 +6,7 @@ MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedMultiscaleEncoder, + PixelToVoxelShuffleHead, generate_mask, masked_patchify, masked_unpatchify, @@ -104,6 +105,13 @@ def test_masked_multiscale_encoder(): assert afeat.shape[2] == afeat.shape[3] == xy_size // stride +def test_pixel_to_voxel_shuffle_head(): + head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4) + x = torch.rand(2, 240, 16, 16) + y = head(x) + assert y.shape == (2, 3, 5, 64, 64) + + def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) model = FullyConvolutionalMAE(3, 3) @@ -113,3 +121,16 @@ def test_fcmae(): y, m = model(x, mask_ratio=0.6) assert y.shape == x.shape assert m.shape == (2, 1, 128, 128) + + +def test_fcmae_head_conv(): + x = torch.rand(2, 3, 5, 128, 128) + model = FullyConvolutionalMAE( + 3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True + ) + y, m = model(x) + assert y.shape == x.shape + assert m is None + y, m = model(x, mask_ratio=0.6) + assert y.shape == x.shape + assert m.shape == (2, 1, 128, 128) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index db3803a5..31ea9f6c 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -79,7 +79,6 @@ class ConcatDataModule(LightningDataModule): The concatenated data module will have the same batch size and number of workers as the first data module. Each element will be sampled uniformly regardless of their original data module. - :param Sequence[LightningDataModule] data_modules: data modules to concatenate """ @@ -93,9 +92,11 @@ def __init__(self, data_modules: Sequence[LightningDataModule]): raise ValueError("Inconsistent number of workers") if dm.batch_size != self.batch_size: raise ValueError("Inconsistent batch size") + self.prepare_data_per_node = True def prepare_data(self): for dm in self.data_modules: + dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index d666fdcb..727f241f 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -10,9 +10,8 @@ class CTMCv1ValidationDataset(SlidingWindowDataset): - subsample_rate: int = 30 - def __len__(self) -> int: + def __len__(self, subsample_rate: int = 30) -> int: # sample every 30th frame in the videos return super().__len__() // self.subsample_rate diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index f33b6121..77bcc1ed 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -191,8 +191,6 @@ def __getitem__(self, index: int) -> Sample: sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) - # if isinstance(sample_images, list): - # sample_images = sample_images[0] if "weight" in sample_images: del sample_images["weight"] sample = { diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index 5d83f099..bb8bb56c 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -3,9 +3,11 @@ import torch from lightning.pytorch import LightningDataModule -from monai.transforms import Compose, Transform +from monai.transforms import Compose, MapTransform +from pycocotools.coco import COCO from tifffile import imread from torch.utils.data import DataLoader, Dataset +from torchvision.ops import box_convert from viscy.data.typing import Sample @@ -15,10 +17,10 @@ class LiveCellDataset(Dataset): LiveCell dataset. :param list[Path] images: List of paths to single-page, single-channel TIFF files. - :param Transform | Compose transform: Transform to apply to the dataset + :param MapTransform | Compose transform: Transform to apply to the dataset """ - def __init__(self, images: list[Path], transform: Transform | Compose) -> None: + def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None: self.images = images self.transform = transform @@ -32,14 +34,70 @@ def __getitem__(self, idx: int) -> Sample: return {"source": image, "target": image} +class LiveCellTestDataset(Dataset): + """ + LiveCell dataset. + + :param list[Path] images: List of paths to single-page, single-channel TIFF files. + :param MapTransform | Compose transform: Transform to apply to the dataset + """ + + def __init__( + self, + image_dir: Path, + transform: MapTransform | Compose, + annotations: Path, + load_target: bool = False, + load_labels: bool = False, + ) -> None: + self.image_dir = image_dir + self.transform = transform + self.coco = COCO(str(annotations)) + self.image_ids = list(self.coco.imgs.keys()) + self.load_target = load_target + self.load_labels = load_labels + + def __len__(self) -> int: + return len(self.image_ids) + + def __getitem__(self, idx: int) -> Sample: + image_id = self.image_ids[idx] + file_name = self.coco.imgs[image_id]["file_name"] + image_path = self.image_dir / file_name + image = imread(image_path)[None, None] + image = torch.from_numpy(image).to(torch.float32) + sample = Sample(source=image) + if self.load_target: + sample["target"] = image + if self.load_labels: + anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or [] + boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns] + masks = [ + torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool) + for ann in anns + ] + dets = { + "boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"), + "labels": torch.zeros(len(anns)).to(torch.uint8), + "masks": torch.stack(masks), + } + sample["detections"] = dets + sample["file_name"] = file_name + self.transform(sample) + return sample + + class LiveCellDataModule(LightningDataModule): def __init__( self, - train_val_images: Path, - train_annotations: Path, - val_annotations: Path, - train_transforms: list[Transform], - val_transforms: list[Transform], + train_val_images: Path | None = None, + test_images: Path | None = None, + train_annotations: Path | None = None, + val_annotations: Path | None = None, + test_annotations: Path | None = None, + train_transforms: list[MapTransform] = [], + val_transforms: list[MapTransform] = [], + test_transforms: list[MapTransform] = [], batch_size: int = 16, num_workers: int = 8, ) -> None: @@ -47,21 +105,29 @@ def __init__( self.train_val_images = Path(train_val_images) if not self.train_val_images.is_dir(): raise NotADirectoryError(str(train_val_images)) + self.test_images = Path(test_images) + if not self.test_images.is_dir(): + raise NotADirectoryError(str(test_images)) self.train_annotations = Path(train_annotations) if not self.train_annotations.is_file(): raise FileNotFoundError(str(train_annotations)) self.val_annotations = Path(val_annotations) if not self.val_annotations.is_file(): raise FileNotFoundError(str(val_annotations)) + self.test_annotations = Path(test_annotations) + if not self.test_annotations.is_file(): + raise FileNotFoundError(str(test_annotations)) self.train_transforms = Compose(train_transforms) self.val_transforms = Compose(val_transforms) + self.test_transforms = Compose(test_transforms) self.batch_size = batch_size self.num_workers = num_workers def setup(self, stage: str) -> None: - if stage != "fit": - raise NotImplementedError("Only fit stage is supported") - self._setup_fit() + if stage == "fit": + self._setup_fit() + elif stage == "test": + self._setup_test() def _parse_image_names(self, annotations: Path) -> list[Path]: with open(annotations) as f: @@ -80,6 +146,14 @@ def _setup_fit(self) -> None: transform=self.val_transforms, ) + def _setup_test(self) -> None: + self.test_dataset = LiveCellTestDataset( + self.test_images, + transform=self.test_transforms, + annotations=self.test_annotations, + load_labels=True, + ) + def train_dataloader(self) -> DataLoader: return DataLoader( self.train_dataset, @@ -96,3 +170,8 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 921b0e4e..bb89858f 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -9,7 +9,7 @@ from monai.metrics.regression import compute_ssim_and_cs from scipy.optimize import linear_sum_assignment from skimage.measure import label, regionprops -from torchmetrics.detection import MeanAveragePrecision +from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.ops import masks_to_boxes @@ -172,7 +172,12 @@ def mean_average_precision( :py:class:`torchmetrics.detection.MeanAveragePrecision` :return dict[str, torch.Tensor]: COCO-style metrics """ - map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs) + defaults = dict( + iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000] + ) + if not kwargs: + kwargs = {} + map_metric = MeanAveragePrecision(**(defaults | kwargs)) map_metric.update( [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)] ) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 08b85319..ac15c208 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -146,6 +146,7 @@ def __init__( self.log_batches_per_epoch = log_batches_per_epoch self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] + self.validation_losses = [] self.validation_step_outputs = [] # required to log the graph if architecture == "2D": @@ -170,32 +171,49 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return self.model(x) - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - loss = self.loss_function(pred, target) + def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int): + losses = [] + batch_size = 0 + if not isinstance(batch, Sequence): + batch = [batch] + for b in batch: + source = b["source"] + target = b["target"] + pred = self.forward(source) + loss = self.loss_function(pred, target) + losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss, + loss_step.to(self.device), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, + batch_size=batch_size, ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target, pred)) - ) - return loss + return loss_step def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = batch["source"] - target = batch["target"] + source: Tensor = batch["source"] + target: Tensor = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss.to(self.device), + sync_dist=True, + batch_size=source.shape[0], + ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -305,8 +323,16 @@ def on_train_epoch_end(self): self.training_step_outputs = [] def on_validation_epoch_end(self): + super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) self.validation_step_outputs = [] + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log( + "loss/validate", + torch.tensor(loss_means).mean().to(self.device), + sync_dist=True, + ) def on_test_start(self): """Load CellPose model for segmentation.""" @@ -382,7 +408,6 @@ class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio - self.validation_losses = [] def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) @@ -434,13 +459,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) - - def on_validation_epoch_end(self): - super().on_validation_epoch_end() - # average within each dataloader - loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] - self.log( - "loss/validate", - torch.tensor(loss_means).mean().to(self.device), - sync_dist=True, - ) diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index dc436cdf..bcc1714f 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -1,7 +1,7 @@ # %% from torchview import draw_graph -from viscy.light.engine import VSUNet +from viscy.light.engine import FcmaeUNet, VSUNet # %% 2D UNet model = VSUNet( @@ -94,3 +94,40 @@ graph22d # %% If you want to save the graphs as SVG files: # model_graph.visual_graph.render(format="svg") + +# %% +model = FcmaeUNet( + model_config=dict( + in_channels=1, + out_channels=1, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=1, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + ), + fit_mask_ratio=0.5, + schedule="WarmupCosine", + lr=2e-4, + log_batches_per_epoch=2, + log_samples_per_batch=2, +) + +model_graph = draw_graph( + model, + (model.example_input_array), + graph_name="VSCyto2D", + roll=True, + depth=3, +) + +fcmae = model_graph.visual_graph +fcmae + +# %% + +model_graph.visual_graph.render( + format="svg", +) + +# %% diff --git a/viscy/transforms.py b/viscy/transforms.py index 88e7f738..3775d154 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -166,8 +166,13 @@ class RandInvertIntensityd(MapTransform, RandomizableTransform): Randomly invert the intensity of the image. """ - def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: - MapTransform.__init__(self, keys) + def __init__( + self, + keys: Union[str, Iterable[str]], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys) RandomizableTransform.__init__(self, prob) def __call__(self, sample: Sample) -> Sample: diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 51345906..6c2f6f45 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,9 +5,11 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ +import math from typing import Sequence import torch +from monai.networks.blocks import UpSample from timm.models.convnext import ( Downsample, DropPath, @@ -18,7 +20,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead +from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder def _init_weights(module: nn.Module) -> None: @@ -337,7 +339,9 @@ def __init__( self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> tuple[list[Tensor], BoolTensor | None]: """ :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, @@ -362,6 +366,35 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: return features, mask +class PixelToVoxelShuffleHead(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + out_stack_depth: int = 5, + xy_scaling: int = 4, + pool: bool = False, + ) -> None: + super().__init__() + self.out_channels = out_channels + self.out_stack_depth = out_stack_depth + self.upsample = UpSample( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_stack_depth * out_channels, + scale_factor=xy_scaling, + mode="pixelshuffle", + pre_conv=None, + apply_pad_pool=pool, + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.upsample(x) + b, _, h, w = x.shape + x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w) + return x + + class FullyConvolutionalMAE(nn.Module): def __init__( self, @@ -370,11 +403,13 @@ def __init__( encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, - head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, decoder_conv_blocks: int = 1, pretraining: bool = True, + head_conv: bool = False, + head_conv_expansion_ratio: int = 4, + head_conv_pool: bool = True, ) -> None: super().__init__() self.encoder = MaskedMultiscaleEncoder( @@ -387,9 +422,14 @@ def __init__( ) decoder_channels = list(dims) decoder_channels.reverse() - decoder_channels[-1] = ( - (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio - ) + if head_conv: + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_conv_expansion_ratio + ) + else: + decoder_channels[-1] = ( + out_channels * in_stack_depth * stem_kernel_size[-1] ** 2 + ) self.decoder = Unet2dDecoder( decoder_channels, norm_name="instance", @@ -398,18 +438,25 @@ def __init__( strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], upsample_pre_conv=None, ) - if in_stack_depth == 1: - self.head = UnsqueezeHead() - else: + if head_conv: self.head = PixelToVoxelHead( in_channels=decoder_channels[-1], out_channels=out_channels, out_stack_depth=in_stack_depth, - expansion_ratio=head_expansion_ratio, + expansion_ratio=head_conv_expansion_ratio, + pool=head_conv_pool, + ) + else: + self.head = PixelToVoxelShuffleHead( + in_channels=decoder_channels[-1], + out_channels=out_channels, + out_stack_depth=in_stack_depth, + xy_scaling=stem_kernel_size[-1], pool=True, ) self.out_stack_depth = in_stack_depth - self.num_blocks = 6 + # TODO: replace num_blocks with explicit strides for all models + self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1])) self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: