diff --git a/pyproject.toml b/pyproject.toml
index 8f6978de..5f0a184f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dynamic = ["version"]
 metrics = [
     "cellpose==2.1.0",
     "scikit-learn>=1.1.3",
-    "torchmetrics[detection]>=1.0.0",
+    "torchmetrics[detection]>=1.3.1",
     "ptflops>=0.7",
 ]
 visual = ["ipykernel", "graphviz", "torchview"]
diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py
index 4ed441b4..f22efa4c 100644
--- a/tests/unet/test_fcmae.py
+++ b/tests/unet/test_fcmae.py
@@ -6,6 +6,7 @@
     MaskedConvNeXtV2Block,
     MaskedConvNeXtV2Stage,
     MaskedMultiscaleEncoder,
+    PixelToVoxelShuffleHead,
     generate_mask,
     masked_patchify,
     masked_unpatchify,
@@ -104,6 +105,13 @@ def test_masked_multiscale_encoder():
         assert afeat.shape[2] == afeat.shape[3] == xy_size // stride


+def test_pixel_to_voxel_shuffle_head():
+    head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
+    x = torch.rand(2, 240, 16, 16)
+    y = head(x)
+    assert y.shape == (2, 3, 5, 64, 64)
+
+
 def test_fcmae():
     x = torch.rand(2, 3, 5, 128, 128)
     model = FullyConvolutionalMAE(3, 3)
@@ -113,3 +121,16 @@ def test_fcmae():
     y, m = model(x, mask_ratio=0.6)
     assert y.shape == x.shape
     assert m.shape == (2, 1, 128, 128)
+
+
+def test_fcmae_head_conv():
+    x = torch.rand(2, 3, 5, 128, 128)
+    model = FullyConvolutionalMAE(
+        3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True
+    )
+    y, m = model(x)
+    assert y.shape == x.shape
+    assert m is None
+    y, m = model(x, mask_ratio=0.6)
+    assert y.shape == x.shape
+    assert m.shape == (2, 1, 128, 128)
diff --git a/viscy/data/combined.py b/viscy/data/combined.py
index db3803a5..31ea9f6c 100644
--- a/viscy/data/combined.py
+++ b/viscy/data/combined.py
@@ -79,7 +79,6 @@ class ConcatDataModule(LightningDataModule):
     The concatenated data module will have the same
     batch size and number of workers as the first data module.
     Each element will be sampled uniformly regardless of their original data module.
-
     :param Sequence[LightningDataModule] data_modules: data modules to concatenate
     """
@@ -93,9 +92,11 @@ def __init__(self, data_modules: Sequence[LightningDataModule]):
                 raise ValueError("Inconsistent number of workers")
             if dm.batch_size != self.batch_size:
                 raise ValueError("Inconsistent batch size")
+        self.prepare_data_per_node = True

     def prepare_data(self):
         for dm in self.data_modules:
+            dm.trainer = self.trainer
             dm.prepare_data()

     def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py
index d666fdcb..727f241f 100644
--- a/viscy/data/ctmc_v1.py
+++ b/viscy/data/ctmc_v1.py
@@ -10,9 +10,8 @@
 class CTMCv1ValidationDataset(SlidingWindowDataset):
-    subsample_rate: int = 30
-    def __len__(self) -> int:
+    def __len__(self, subsample_rate: int = 30) -> int:
         # sample every 30th frame in the videos
         return super().__len__() // self.subsample_rate
diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py
index f33b6121..77bcc1ed 100644
--- a/viscy/data/hcs.py
+++ b/viscy/data/hcs.py
@@ -191,8 +191,6 @@ def __getitem__(self, index: int) -> Sample:
             sample_images["norm_meta"] = norm_meta
         if self.transform:
             sample_images = self.transform(sample_images)
-        # if isinstance(sample_images, list):
-        #     sample_images = sample_images[0]
         if "weight" in sample_images:
             del sample_images["weight"]
         sample = {
diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py
index 5d83f099..bb8bb56c 100644
--- a/viscy/data/livecell.py
+++ b/viscy/data/livecell.py
@@ -3,9 +3,11 @@
 import torch
 from lightning.pytorch import LightningDataModule
-from monai.transforms import Compose, Transform
+from monai.transforms import Compose, MapTransform
+from pycocotools.coco import COCO
 from tifffile import imread
 from torch.utils.data import DataLoader, Dataset
+from torchvision.ops import box_convert

 from viscy.data.typing import Sample
@@ -15,10 +17,10 @@ class LiveCellDataset(Dataset):
     LiveCell dataset.

     :param list[Path] images: List of paths to single-page, single-channel TIFF files.
-    :param Transform | Compose transform: Transform to apply to the dataset
+    :param MapTransform | Compose transform: Transform to apply to the dataset
     """

-    def __init__(self, images: list[Path], transform: Transform | Compose) -> None:
+    def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None:
         self.images = images
         self.transform = transform
@@ -32,14 +34,70 @@ def __getitem__(self, idx: int) -> Sample:
         return {"source": image, "target": image}


+class LiveCellTestDataset(Dataset):
+    """
+    LiveCell test dataset.
+
+    :param Path image_dir: directory containing the single-page, single-channel TIFF files
+    :param MapTransform | Compose transform: Transform to apply to the dataset
+    :param Path annotations: path to the COCO-format annotation file
+    :param bool load_target: whether to also return the image as the target
+    :param bool load_labels: whether to load the COCO instance segmentation labels
+    """
+
+    def __init__(
+        self,
+        image_dir: Path,
+        transform: MapTransform | Compose,
+        annotations: Path,
+        load_target: bool = False,
+        load_labels: bool = False,
+    ) -> None:
+        self.image_dir = image_dir
+        self.transform = transform
+        self.coco = COCO(str(annotations))
+        self.image_ids = list(self.coco.imgs.keys())
+        self.load_target = load_target
+        self.load_labels = load_labels
+
+    def __len__(self) -> int:
+        return len(self.image_ids)
+
+    def __getitem__(self, idx: int) -> Sample:
+        image_id = self.image_ids[idx]
+        file_name = self.coco.imgs[image_id]["file_name"]
+        image_path = self.image_dir / file_name
+        image = imread(image_path)[None, None]
+        image = torch.from_numpy(image).to(torch.float32)
+        sample = Sample(source=image)
+        if self.load_target:
+            sample["target"] = image
+        if self.load_labels:
+            anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or []
+            boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns]
+            masks = [
+                torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool)
+                for ann in anns
+            ]
+            dets = {
+                "boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"),
+                "labels": torch.zeros(len(anns)).to(torch.uint8),
+                "masks": torch.stack(masks),
+            }
+            sample["detections"] = dets
+            sample["file_name"] = file_name
+        sample = self.transform(sample)
+        return sample
+
+
 class LiveCellDataModule(LightningDataModule):
     def __init__(
         self,
-        train_val_images: Path,
-        train_annotations: Path,
-        val_annotations: Path,
-        train_transforms: list[Transform],
-        val_transforms: list[Transform],
+        train_val_images: Path | None = None,
+        test_images: Path | None = None,
+        train_annotations: Path | None = None,
+        val_annotations: Path | None = None,
+        test_annotations: Path | None = None,
+        train_transforms: list[MapTransform] = [],
+        val_transforms: list[MapTransform] = [],
+        test_transforms: list[MapTransform] = [],
         batch_size: int = 16,
         num_workers: int = 8,
     ) -> None:
@@ -47,21 +105,29 @@ def __init__(
         self.train_val_images = Path(train_val_images)
         if not self.train_val_images.is_dir():
             raise NotADirectoryError(str(train_val_images))
+        self.test_images = Path(test_images)
+        if not self.test_images.is_dir():
+            raise NotADirectoryError(str(test_images))
         self.train_annotations = Path(train_annotations)
         if not self.train_annotations.is_file():
             raise FileNotFoundError(str(train_annotations))
         self.val_annotations = Path(val_annotations)
         if not self.val_annotations.is_file():
             raise FileNotFoundError(str(val_annotations))
+        self.test_annotations = Path(test_annotations)
+        if not self.test_annotations.is_file():
+            raise FileNotFoundError(str(test_annotations))
         self.train_transforms = Compose(train_transforms)
         self.val_transforms = Compose(val_transforms)
+        self.test_transforms = Compose(test_transforms)
         self.batch_size = batch_size
         self.num_workers = num_workers

     def setup(self, stage: str) -> None:
-        if stage != "fit":
-            raise NotImplementedError("Only fit stage is supported")
-        self._setup_fit()
+        if stage == "fit":
+            self._setup_fit()
+        elif stage == "test":
+            self._setup_test()

     def _parse_image_names(self, annotations: Path) -> list[Path]:
         with open(annotations) as f:
@@ -80,6 +146,14 @@ def _setup_fit(self) -> None:
             transform=self.val_transforms,
         )

+    def _setup_test(self) -> None:
+        self.test_dataset = LiveCellTestDataset(
+            self.test_images,
+            transform=self.test_transforms,
+            annotations=self.test_annotations,
+            load_labels=True,
+        )
+
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             self.train_dataset,
@@ -96,3 +170,8 @@ def val_dataloader(self) -> DataLoader:
             num_workers=self.num_workers,
             persistent_workers=bool(self.num_workers),
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
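As a side note (not part of the patch), here is a minimal sketch of how the new LiveCellTestDataset could be instantiated; the LIVECell paths are placeholders that assume a local download of the dataset, and an identity Compose is used for illustration:

from pathlib import Path

from monai.transforms import Compose

from viscy.data.livecell import LiveCellTestDataset

# placeholder paths to a local LIVECell download
image_dir = Path("/data/livecell/images/livecell_test_images")
annotations = Path("/data/livecell/annotations/livecell_coco_test.json")

dataset = LiveCellTestDataset(
    image_dir=image_dir,
    transform=Compose([]),  # no-op transform for illustration
    annotations=annotations,
    load_labels=True,
)
sample = dataset[0]
# source is a (1, 1, H, W) float32 tensor; detections follow the COCO/torchmetrics convention
print(sample["source"].shape, sample["detections"]["boxes"].shape)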
diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py
index 921b0e4e..bb89858f 100644
--- a/viscy/evaluation/evaluation_metrics.py
+++ b/viscy/evaluation/evaluation_metrics.py
@@ -9,7 +9,7 @@
 from monai.metrics.regression import compute_ssim_and_cs
 from scipy.optimize import linear_sum_assignment
 from skimage.measure import label, regionprops
-from torchmetrics.detection import MeanAveragePrecision
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.ops import masks_to_boxes


@@ -172,7 +172,12 @@ def mean_average_precision(
     :py:class:`torchmetrics.detection.MeanAveragePrecision`
     :return dict[str, torch.Tensor]: COCO-style metrics
     """
-    map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs)
+    defaults = dict(
+        iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
+    )
+    if not kwargs:
+        kwargs = {}
+    map_metric = MeanAveragePrecision(**(defaults | kwargs))
     map_metric.update(
         [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
     )
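For reference, a standalone snippet (not from the codebase) illustrating the keyword merging used in mean_average_precision above: the dict union keeps the defaults unless the caller overrides them, because the right-hand operand of | wins on key collisions.

defaults = dict(
    iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
)
kwargs = {"max_detection_thresholds": [1, 10, 100]}  # hypothetical caller override
merged = defaults | kwargs
assert merged["iou_type"] == "segm"
assert merged["max_detection_thresholds"] == [1, 10, 100]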
diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index 08b85319..ac15c208 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -146,6 +146,7 @@ def __init__(
         self.log_batches_per_epoch = log_batches_per_epoch
         self.log_samples_per_batch = log_samples_per_batch
         self.training_step_outputs = []
+        self.validation_losses = []
         self.validation_step_outputs = []
         # required to log the graph
         if architecture == "2D":
@@ -170,32 +171,49 @@ def forward(self, x: Tensor) -> Tensor:
         return self.model(x)

-    def training_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]
-        target = batch["target"]
-        pred = self.forward(source)
-        loss = self.loss_function(pred, target)
+    def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
+        losses = []
+        batch_size = 0
+        if not isinstance(batch, Sequence):
+            batch = [batch]
+        for b in batch:
+            source = b["source"]
+            target = b["target"]
+            pred = self.forward(source)
+            loss = self.loss_function(pred, target)
+            losses.append(loss)
+            batch_size += source.shape[0]
+            if batch_idx < self.log_batches_per_epoch:
+                self.training_step_outputs.extend(
+                    self._detach_sample((source, target, pred))
+                )
+        loss_step = torch.stack(losses).mean()
         self.log(
             "loss/train",
-            loss,
+            loss_step.to(self.device),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
             logger=True,
             sync_dist=True,
+            batch_size=batch_size,
         )
-        if batch_idx < self.log_batches_per_epoch:
-            self.training_step_outputs.extend(
-                self._detach_sample((source, target, pred))
-            )
-        return loss
+        return loss_step

     def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
-        source = batch["source"]
-        target = batch["target"]
+        source: Tensor = batch["source"]
+        target: Tensor = batch["target"]
         pred = self.forward(source)
         loss = self.loss_function(pred, target)
-        self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False)
+        if dataloader_idx + 1 > len(self.validation_losses):
+            self.validation_losses.append([])
+        self.validation_losses[dataloader_idx].append(loss.detach())
+        self.log(
+            f"loss/val/{dataloader_idx}",
+            loss.to(self.device),
+            sync_dist=True,
+            batch_size=source.shape[0],
+        )
         if batch_idx < self.log_batches_per_epoch:
             self.validation_step_outputs.extend(
                 self._detach_sample((source, target, pred))
@@ -305,8 +323,16 @@ def on_train_epoch_end(self):
         self.training_step_outputs = []

     def on_validation_epoch_end(self):
+        super().on_validation_epoch_end()
         self._log_samples("val_samples", self.validation_step_outputs)
         self.validation_step_outputs = []
+        # average within each dataloader
+        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
+        self.log(
+            "loss/validate",
+            torch.tensor(loss_means).mean().to(self.device),
+            sync_dist=True,
+        )

     def on_test_start(self):
         """Load CellPose model for segmentation."""
@@ -382,7 +408,6 @@ class FcmaeUNet(VSUNet):
     def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
         super().__init__(architecture="fcmae", **kwargs)
         self.fit_mask_ratio = fit_mask_ratio
-        self.validation_losses = []

     def forward(self, x: Tensor, mask_ratio: float = 0.0):
         return self.model(x, mask_ratio)
@@ -434,13 +459,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
             self.validation_step_outputs.extend(
                 self._detach_sample((source, target * mask.unsqueeze(2), pred))
             )
-
-    def on_validation_epoch_end(self):
-        super().on_validation_epoch_end()
-        # average within each dataloader
-        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
-        self.log(
-            "loss/validate",
-            torch.tensor(loss_means).mean().to(self.device),
-            sync_dist=True,
-        )
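A standalone sketch (not part of the patch) of the aggregation now done in on_validation_epoch_end: per-dataloader means are computed first and then averaged, so the combined "loss/validate" weights every validation dataloader equally rather than every batch.

import torch

# two validation dataloaders with a different number of logged batches
validation_losses = [
    [torch.tensor(0.2), torch.tensor(0.4)],  # dataloader 0
    [torch.tensor(1.0)],                     # dataloader 1
]
loss_means = [torch.tensor(losses).mean() for losses in validation_losses]
overall = torch.tensor(loss_means).mean()
print(loss_means)  # [tensor(0.3000), tensor(1.)]
print(overall)     # tensor(0.6500), not the batch-weighted mean of ~0.5333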
""" - def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: - MapTransform.__init__(self, keys) + def __init__( + self, + keys: Union[str, Iterable[str]], + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys) RandomizableTransform.__init__(self, prob) def __call__(self, sample: Sample) -> Sample: diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 51345906..6c2f6f45 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,9 +5,11 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ +import math from typing import Sequence import torch +from monai.networks.blocks import UpSample from timm.models.convnext import ( Downsample, DropPath, @@ -18,7 +20,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead +from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder def _init_weights(module: nn.Module) -> None: @@ -337,7 +339,9 @@ def __init__( self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> tuple[list[Tensor], BoolTensor | None]: """ :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, @@ -362,6 +366,35 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: return features, mask +class PixelToVoxelShuffleHead(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + out_stack_depth: int = 5, + xy_scaling: int = 4, + pool: bool = False, + ) -> None: + super().__init__() + self.out_channels = out_channels + self.out_stack_depth = out_stack_depth + self.upsample = UpSample( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_stack_depth * out_channels, + scale_factor=xy_scaling, + mode="pixelshuffle", + pre_conv=None, + apply_pad_pool=pool, + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.upsample(x) + b, _, h, w = x.shape + x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w) + return x + + class FullyConvolutionalMAE(nn.Module): def __init__( self, @@ -370,11 +403,13 @@ def __init__( encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, - head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, decoder_conv_blocks: int = 1, pretraining: bool = True, + head_conv: bool = False, + head_conv_expansion_ratio: int = 4, + head_conv_pool: bool = True, ) -> None: super().__init__() self.encoder = MaskedMultiscaleEncoder( @@ -387,9 +422,14 @@ def __init__( ) decoder_channels = list(dims) decoder_channels.reverse() - decoder_channels[-1] = ( - (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio - ) + if head_conv: + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_conv_expansion_ratio + ) + else: + decoder_channels[-1] = ( + out_channels * in_stack_depth * stem_kernel_size[-1] ** 2 + ) self.decoder = Unet2dDecoder( decoder_channels, norm_name="instance", @@ -398,18 +438,25 @@ def __init__( strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], upsample_pre_conv=None, ) - if in_stack_depth == 1: - self.head = UnsqueezeHead() - else: + if head_conv: self.head = 
diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py
index 51345906..6c2f6f45 100644
--- a/viscy/unet/networks/fcmae.py
+++ b/viscy/unet/networks/fcmae.py
@@ -5,9 +5,11 @@
 and timm's dense implementation of the encoder in ``timm.models.convnext``
 """

+import math
 from typing import Sequence

 import torch
+from monai.networks.blocks import UpSample
 from timm.models.convnext import (
     Downsample,
     DropPath,
@@ -18,7 +20,7 @@
 )
 from torch import BoolTensor, Size, Tensor, nn

-from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead
+from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder


 def _init_weights(module: nn.Module) -> None:
@@ -337,7 +339,9 @@
         self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1)
         self.apply(_init_weights)

-    def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]:
+    def forward(
+        self, x: Tensor, mask_ratio: float = 0.0
+    ) -> tuple[list[Tensor], BoolTensor | None]:
         """
         :param Tensor x: input tensor (BCDHW)
         :param float mask_ratio: ratio of the feature maps to mask,
@@ -362,6 +366,35 @@
         return features, mask


+class PixelToVoxelShuffleHead(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        out_stack_depth: int = 5,
+        xy_scaling: int = 4,
+        pool: bool = False,
+    ) -> None:
+        super().__init__()
+        self.out_channels = out_channels
+        self.out_stack_depth = out_stack_depth
+        self.upsample = UpSample(
+            spatial_dims=2,
+            in_channels=in_channels,
+            out_channels=out_stack_depth * out_channels,
+            scale_factor=xy_scaling,
+            mode="pixelshuffle",
+            pre_conv=None,
+            apply_pad_pool=pool,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.upsample(x)
+        b, _, h, w = x.shape
+        x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w)
+        return x
+
+
 class FullyConvolutionalMAE(nn.Module):
     def __init__(
         self,
@@ -370,11 +403,13 @@
         encoder_blocks: Sequence[int] = [3, 3, 9, 3],
         dims: Sequence[int] = [96, 192, 384, 768],
         encoder_drop_path_rate: float = 0.0,
-        head_expansion_ratio: int = 4,
         stem_kernel_size: Sequence[int] = (5, 4, 4),
         in_stack_depth: int = 5,
         decoder_conv_blocks: int = 1,
         pretraining: bool = True,
+        head_conv: bool = False,
+        head_conv_expansion_ratio: int = 4,
+        head_conv_pool: bool = True,
     ) -> None:
         super().__init__()
         self.encoder = MaskedMultiscaleEncoder(
@@ -387,9 +422,14 @@
         )
         decoder_channels = list(dims)
         decoder_channels.reverse()
-        decoder_channels[-1] = (
-            (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio
-        )
+        if head_conv:
+            decoder_channels[-1] = (
+                (in_stack_depth + 2) * in_channels * 2**2 * head_conv_expansion_ratio
+            )
+        else:
+            decoder_channels[-1] = (
+                out_channels * in_stack_depth * stem_kernel_size[-1] ** 2
+            )
         self.decoder = Unet2dDecoder(
             decoder_channels,
             norm_name="instance",
@@ -398,18 +438,25 @@
             strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]],
             upsample_pre_conv=None,
         )
-        if in_stack_depth == 1:
-            self.head = UnsqueezeHead()
-        else:
+        if head_conv:
             self.head = PixelToVoxelHead(
                 in_channels=decoder_channels[-1],
                 out_channels=out_channels,
                 out_stack_depth=in_stack_depth,
-                expansion_ratio=head_expansion_ratio,
+                expansion_ratio=head_conv_expansion_ratio,
+                pool=head_conv_pool,
+            )
+        else:
+            self.head = PixelToVoxelShuffleHead(
+                in_channels=decoder_channels[-1],
+                out_channels=out_channels,
+                out_stack_depth=in_stack_depth,
+                xy_scaling=stem_kernel_size[-1],
                 pool=True,
             )
         self.out_stack_depth = in_stack_depth
-        self.num_blocks = 6
+        # TODO: replace num_blocks with explicit strides for all models
+        self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1]))
         self.pretraining = pretraining

     def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor:
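To make the new head's shape bookkeeping concrete, here is a plain-PyTorch sketch (not the viscy module, and ignoring the optional pad-pool smoothing) of what PixelToVoxelShuffleHead does: pixel shuffle trades channels for XY resolution, and the remaining channels are folded into the depth axis. The numbers mirror test_pixel_to_voxel_shuffle_head, where 240 = 3 channels x 5 slices x 4^2.

import torch
import torch.nn.functional as F

b, out_channels, out_stack_depth, scale = 2, 3, 5, 4
x = torch.rand(b, out_channels * out_stack_depth * scale**2, 16, 16)  # (2, 240, 16, 16)

# (B, C*r^2, H, W) -> (B, C, H*r, W*r), with C = out_channels * out_stack_depth
y = F.pixel_shuffle(x, upscale_factor=scale)  # (2, 15, 64, 64)
# fold the remaining channels into the depth axis to form a BCDHW stack
y = y.reshape(b, out_channels, out_stack_depth, *y.shape[-2:])
assert y.shape == (2, 3, 5, 64, 64)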