From 97ff42a180d934529c0fc2b028f4cb62cde06a4c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Sun, 6 Aug 2023 22:15:05 -0700 Subject: [PATCH 01/30] training time z scaling augmentation --- viscy/light/data.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index b95d41c1..12cdc4ed 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -1,4 +1,5 @@ import logging +import math import os import re import tempfile @@ -281,6 +282,9 @@ class HCSDataModule(LightningDataModule): defaults to False :param str ground_truth_masks: path to the ground truth segmentation masks, defaults to None + :param tuple[float, float] train_z_scale_range: Z scaling augmentation range, + passed to MONAI's ``RandAffined`` transform, + defaults to [-0.4, 0.2] """ def __init__( @@ -298,6 +302,7 @@ def __init__( caching: bool = False, normalize_source: bool = False, ground_truth_masks: str = None, + train_z_scale_range: tuple[float, float] = [-0.2, 1], ): super().__init__() self.data_path = data_path @@ -314,6 +319,9 @@ def __init__( self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks self.tmp_zarr = None + if train_z_scale_range[0] > 1 or train_z_scale_range[1] < 1: + raise ValueError(f"Invalid scaling range: {train_z_scale_range}") + self.train_z_scale_range = train_z_scale_range def prepare_data(self): if not self.caching: @@ -391,11 +399,15 @@ def _setup_fit(self, dataset_settings: dict): shuffled_indices = torch.randperm(len(positions)) positions = list(positions[i] for i in shuffled_indices) num_train_fovs = int(len(positions) * self.split_ratio) + # training set needs to sample more Z range for augmentation + train_dataset_settings = dataset_settings.copy() + expanded_z = math.ceil(self.z_window_size * (1 + self.train_z_scale_range[1])) + train_dataset_settings["z_window_size"] = expanded_z - expanded_z // 2 # train/val split self.train_dataset = SlidingWindowDataset( positions[:num_train_fovs], transform=train_transform, - **dataset_settings, + **train_dataset_settings, ) self.val_dataset = SlidingWindowDataset( positions[num_train_fovs:], transform=val_transform, **dataset_settings @@ -481,7 +493,7 @@ def _fit_transform(self): CenterSpatialCropd( keys=self.source_channel + self.target_channel, roi_size=( - -1, + self.z_window_size, self.yx_patch_size[0], self.yx_patch_size[1], ), @@ -505,7 +517,7 @@ def _train_transform(self) -> list[Callable]: prob=0.5, rotate_range=(np.pi, 0, 0), shear_range=(0, (0.05), (0.05)), - scale_range=(0, 0.3, 0.3), + scale_range=(self.train_z_scale_range, 0.3, 0.3), ), RandAdjustContrastd( keys=self.source_channel, prob=0.3, gamma=(0.75, 1.5) From 721369268de847267f6c04467a514d4f9f51a301 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Sun, 6 Aug 2023 22:16:33 -0700 Subject: [PATCH 02/30] plotting script for network diagram --- viscy/scripts/network_diagram.py | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 viscy/scripts/network_diagram.py diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py new file mode 100644 index 00000000..48587a1d --- /dev/null +++ b/viscy/scripts/network_diagram.py @@ -0,0 +1,35 @@ +# %% +from torchview import draw_graph + +from viscy.light.engine import VSUNet + +# %% +model = VSUNet( + model_config={ + "architecture": "2.5D", + "in_channels": 1, + "out_channels": 2, + "in_stack_depth": 5, + "residual": True, + "task": "reg", + "dropout": 0.1, + }, + batch_size=32, +) +# %% + 
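+# Note: torchview's `draw_graph` traces the model's forward pass on the
+# example input and renders the module hierarchy with graphviz. `depth`
+# limits how many levels of nested modules are drawn, and `roll=True`
+# collapses repeated blocks into a single node.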
+model_graph = draw_graph( + model, + model.example_input_array, + graph_name="2.5D UNet", + roll=True, + depth=4, + # graph_dir="LR", + directory="/hpc/projects/comp.micro/virtual_staining/models/HEK_phase_to_nuc_mem/", + save_graph=True, +) + +graph = model_graph.visual_graph +graph +# %% +model_graph.visual_graph.render(format="svg") From 87c2e044ebff2774426011cc9918f56151d08eee Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 8 Aug 2023 17:10:21 -0700 Subject: [PATCH 03/30] support different log level --- viscy/cli/cli.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py index acb72774..ef1cd48b 100644 --- a/viscy/cli/cli.py +++ b/viscy/cli/cli.py @@ -1,3 +1,5 @@ +import logging +import os from datetime import datetime import torch @@ -35,6 +37,9 @@ def add_arguments_to_parser(self, parser): def main(): + """Main Lightning CLI entry point.""" + log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO) + logging.getLogger("lightning.pytorch").setLevel((log_level)) torch.set_float32_matmul_precision("high") _ = VSLightningCLI( model_class=VSUNet, From 9957e17e86ba3dbc63dba3d71b0d0beffb430ab1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 8 Aug 2023 23:20:08 -0700 Subject: [PATCH 04/30] fix multi-timepoint prediction writing --- viscy/light/predict_writer.py | 36 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 5e830aa0..5abebd42 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -1,6 +1,6 @@ import logging import os -from typing import Literal, Sequence +from typing import Literal, Optional, Sequence import torch from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr @@ -10,17 +10,19 @@ from viscy.light.data import Sample +_logger = logging.getLogger("lightning.pytorch") + def _resize_image(image: ImageArray, t_index: int, z_index: int): """Resize image array if incoming T and Z index is not within bounds.""" if image.shape[0] <= t_index or image.shape[2] <= z_index: - logging.debug( + _logger.debug( f"Resizing image '{image.name}' {image.shape} for T={t_index}, Z={z_index}." 
) image.resize( max(t_index + 1, image.shape[0]), image.channels, - max(z_index + 1, image.shape[1]), + max(z_index + 1, image.shape[2]), *image.shape[-2:], ) @@ -63,11 +65,11 @@ def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None else: channel_names = prediction_channel if self.write_input: - channel_names = source_channel + target_channel + channel_names + channel_names = source_channel + channel_names self.plate = open_ome_zarr( self.output_store, layout="hcs", mode="a", channel_names=channel_names ) - logging.info(f"Writing prediction to: '{self.plate.zgroup.store.path}'.") + _logger.info(f"Writing prediction to: '{self.plate.zgroup.store.path}'.") if self.write_input: self.source_index = self._get_channel_indices(source_channel) self.target_index = self._get_channel_indices(target_channel) @@ -81,12 +83,12 @@ def write_on_batch_end( trainer: Trainer, pl_module: LightningModule, prediction: torch.Tensor, - batch_indices: Sequence[int] | None, + batch_indices: Optional[Sequence[int]], batch: Sample, batch_idx: int, dataloader_idx: int, ) -> None: - logging.debug(f"Writing batch {batch_idx}.") + _logger.debug(f"Writing batch {batch_idx}.") for sample_index, _ in enumerate(batch["index"][0]): self.write_sample(batch, prediction[sample_index], sample_index) @@ -96,7 +98,7 @@ def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: def write_sample( self, batch: Sample, sample_prediction: torch.Tensor, sample_index: int ) -> None: - logging.debug(f"Writing sample {sample_index}.") + _logger.debug(f"Writing sample {sample_index}.") sample_prediction = sample_prediction.cpu().numpy() img_name, t_index, z_index = [batch["index"][i][sample_index] for i in range(3)] t_index = int(t_index) @@ -106,20 +108,22 @@ def write_sample( ) _resize_image(image, t_index, z_index) if self.write_input: - # FIXME: should write center sclice of source - image[t_index, self.source_index, z_index] = batch["source"][ - sample_index - ].cpu()[:, 0] - image[t_index, self.target_index, z_index] = batch["target"][ - sample_index - ].cpu()[:, 0] + source_stack = batch["source"][sample_index].cpu() + center_slice_index = source_stack.shape[-3] // 2 + image[t_index, self.source_index, z_index] = source_stack[ + :, center_slice_index + ] + if "target" in batch: + image[t_index, self.target_index, z_index] = batch["target"][ + sample_index + ][:, center_slice_index].cpu() # write C1YX image.oindex[t_index, self.prediction_index, z_index] = sample_prediction[:, 0] def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike): if img_name in self.plate.zgroup: return self.plate[img_name] - logging.debug(f"Creating image '{img_name}'") + _logger.debug(f"Creating image '{img_name}'") _, row_name, col_name, pos_name, arr_name = img_name.split("/") position = self.plate.create_position(row_name, col_name, pos_name) shape = [1] + list(shape) From e76988c64e804c06389f038adc3da35c6bb40a23 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 00:39:03 -0700 Subject: [PATCH 05/30] add 2.1D U-Net --- viscy/light/data.py | 17 +++- viscy/light/engine.py | 13 ++- viscy/scripts/network_diagram.py | 20 +++-- viscy/unet/networks/Unet21D.py | 144 +++++++++++++++++++++++++++++++ viscy/unet/utils/model.py | 123 -------------------------- 5 files changed, 179 insertions(+), 138 deletions(-) create mode 100644 viscy/unet/networks/Unet21D.py delete mode 100644 viscy/unet/utils/model.py diff --git a/viscy/light/data.py b/viscy/light/data.py index 12cdc4ed..ada93eb4 100644 
--- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -22,6 +22,7 @@ RandAffined, RandGaussianSmoothd, RandWeightedCropd, + ScaleIntensityRangePercentilesd, ) from torch.utils.data import DataLoader, Dataset @@ -169,7 +170,7 @@ def _read_img_window( slice(t, t + 1), [int(i) for i in ch_idx], slice(z, z + self.z_window_size), - ] + ].astype(np.float32) return torch.from_numpy(data).unbind(dim=1), (img.name, t, z) def __len__(self) -> int: @@ -269,7 +270,8 @@ class HCSDataModule(LightningDataModule): e.g. ``['Nuclei', 'Membrane']`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D :param float split_ratio: split ratio of the training subset in the fit stage, - e.g. 0.8 means a 80/20 split between training/validation + e.g. 0.8 means a 80/20 split between training/validation, + by default 0.8 :param int batch_size: batch size, defaults to 16 :param int num_workers: number of data-loading workers, defaults to 8 :param Literal["2.5D", "2D", "3D"] architecture: U-Net architecture, @@ -293,7 +295,7 @@ def __init__( source_channel: Union[str, Sequence[str]], target_channel: Union[str, Sequence[str]], z_window_size: int, - split_ratio: float, + split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, architecture: Literal["2.5D", "2D", "3D"] = "2.5D", @@ -497,7 +499,14 @@ def _fit_transform(self): self.yx_patch_size[0], self.yx_patch_size[1], ), - ) + ), + ScaleIntensityRangePercentilesd( + keys=self.target_channel, + lower=5, + upper=95, + b_min=None, + b_max=None, + ), ] def _train_transform(self) -> list[Callable]: diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 658d1831..a73993fe 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -29,8 +29,13 @@ from viscy.evaluation.evaluation_metrics import mean_average_precision from viscy.light.data import Sample +from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d -from viscy.unet.utils.model import ModelDefaults25D, define_model + +_UNET_ARCHITECTURE = { + "2.5D": Unet25d, + "2.1D": Unet21d, +} class VSTrainer(Trainer): @@ -126,7 +131,11 @@ def __init__( test_evaluate_cellpose: bool = False, ) -> None: super().__init__() - self.model = define_model(Unet25d, ModelDefaults25D(), model_config) + arch = model_config.pop("architecture") + net_class = _UNET_ARCHITECTURE.get(arch) + if not arch: + raise ValueError(f"Architecture {arch} not in {_UNET_ARCHITECTURE.keys()}") + self.model = net_class(**model_config) # TODO: handle num_outputs in metrics # self.out_channels = self.model.terminal_block.out_filters self.batch_size = batch_size diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index 48587a1d..7fe5dbeb 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -6,30 +6,32 @@ # %% model = VSUNet( model_config={ - "architecture": "2.5D", + "architecture": "2.1D", "in_channels": 1, "out_channels": 2, - "in_stack_depth": 5, - "residual": True, - "task": "reg", - "dropout": 0.1, + "in_stack_depth": 9, + "backbone": "convnextv2_femto", + "stem_kernel_size": (5, 4, 4), }, batch_size=32, ) # %% - model_graph = draw_graph( model, model.example_input_array, - graph_name="2.5D UNet", + # model.example_input_array, + graph_name="2.1D UNet", roll=True, - depth=4, + depth=3, # graph_dir="LR", directory="/hpc/projects/comp.micro/virtual_staining/models/HEK_phase_to_nuc_mem/", - save_graph=True, + # save_graph=True, ) graph = model_graph.visual_graph graph # %% 
model_graph.visual_graph.render(format="svg") + +# %% +import matplotlib.pyplot as plt diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py new file mode 100644 index 00000000..d8afb1bd --- /dev/null +++ b/viscy/unet/networks/Unet21D.py @@ -0,0 +1,144 @@ +from typing import Sequence, Union + +import timm +import torch +from monai.networks.blocks import Convolution, ResidualUnit, UnetrUpBlock +from monai.networks.blocks.dynunet_block import get_conv_layer +from torch import nn + + +class Conv21dStem(nn.Module): + """Stem for 2.1D networks.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, int, int], + in_stack_depth: int, + ) -> None: + super().__init__() + ratio = in_stack_depth // kernel_size[0] + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels // ratio, + kernel_size=kernel_size, + stride=kernel_size, + ) + + def forward(self, x: torch.Tensor): + x = self.conv(x) + b, c, d, h, w = x.shape + # project Z/depth into channels + # return a view when possible (contiguous) + return x.view(b, c * d, h, w) + + +class Unet2dDecoder(nn.Module): + def __init__( + self, + num_channels: list[int], + out_channels: int, + res_block: bool, + norm_name: str, + kernel_size: Union[int, tuple[int, int]], + last_kernel_size: Union[int, tuple[int, int]], + dropout: float = 0, + ) -> None: + super().__init__() + decoder_stages = [] + stages = len(num_channels) + num_channels.append(out_channels) + stride = 2 + for i in range(stages): + stage = UnetrUpBlock( + spatial_dims=2, + in_channels=num_channels[i], + out_channels=num_channels[i + 1], + kernel_size=kernel_size, + upsample_kernel_size=stride, + norm_name=norm_name, + res_block=res_block, + ) + decoder_stages.append(stage) + self.decoder_stages = nn.ModuleList(decoder_stages) + self.head = nn.Sequential( + get_conv_layer( + spatial_dims=2, + in_channels=num_channels[-2], + out_channels=num_channels[-2], + stride=last_kernel_size, + kernel_size=last_kernel_size, + norm=norm_name, + is_transposed=True, + ), + ResidualUnit( + spatial_dims=2, + in_channels=num_channels[-2], + out_channels=num_channels[-2], + kernel_size=kernel_size, + norm=norm_name, + dropout=dropout, + ), + nn.Conv2d( + num_channels[-2], + out_channels, + kernel_size=(1, 1), + ), + ) + + def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: + feat = features[0] + # padding + features.append(None) + for skip, stage in zip(features[1:], self.decoder_stages[:-1]): + feat = stage(feat, skip) + return self.head(feat) + + +class Unet21d(nn.Module): + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + in_stack_depth: int = 9, + backbone: str = "convnextv2_tiny", + pretrained: bool = False, + stem_kernel_size: tuple[int, int, int] = (3, 4, 4), + decoder_res_block: bool = True, + decoder_norm_layer: str = "instance", + ) -> None: + super().__init__() + if in_stack_depth % stem_kernel_size[0] != 0: + raise ValueError( + f"Input stack depth {in_stack_depth} is not divisible " + f"by stem kernel depth {stem_kernel_size[0]}." 
+ ) + multi_scale_encoder = timm.create_model( + backbone, pretrained=pretrained, features_only=True + ) + num_channels = multi_scale_encoder.feature_info.channels() + # replace first convolution layer with a projection tokenizer + multi_scale_encoder.stem_0 = nn.Identity() + self.encoder_stages = multi_scale_encoder + self.stem = Conv21dStem( + in_channels, num_channels[0], stem_kernel_size, in_stack_depth + ) + decoder_channels = num_channels + decoder_channels.reverse() + self.decoder = Unet2dDecoder( + decoder_channels, + out_channels, + res_block=decoder_res_block, + norm_name=decoder_norm_layer, + kernel_size=3, + last_kernel_size=stem_kernel_size[-2:], + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem(x) + x: list = self.encoder_stages(x) + x.reverse() + x = self.decoder(x) + # add Z/depth back + return x.unsqueeze(2) diff --git a/viscy/unet/utils/model.py b/viscy/unet/utils/model.py deleted file mode 100644 index cf888ea8..00000000 --- a/viscy/unet/utils/model.py +++ /dev/null @@ -1,123 +0,0 @@ -import torch - -import viscy.unet.networks.Unet2D as Unet2D -import viscy.unet.networks.Unet25D as Unet25D - - -def model_init(network_config, device=torch.device("cuda"), debug_mode=False): - """ - Initializes network model from a configuration dictionary. - - :param dict network_config: dict containing the configuration parameters for - the model - :param torch.device device: device to store model parameters on (must be same - as data) - """ - if device == "gpu": - device = "cuda" - - assert ( - "architecture" in network_config - ), "Must specify network architecture: 2D, 2.5D" - - if network_config["architecture"] == "2.5D": - default_model = ModelDefaults25D() - model_class = Unet25D.Unet25d - model = define_model( - model_class, - default_model, - network_config, - ) - elif network_config["architecture"] == "2D": - default_model = ModelDefaults2D() - model_class = Unet2D.Unet2d - model = define_model( - model_class, - default_model, - network_config, - ) - else: - raise NotImplementedError("Only 2.5D and 2D architectures available.") - - model.debug_mode = debug_mode - - model.to(device) - - return model - - -def define_model(model_class, model_defaults, config): - """ - Returns an instance of the model given the parameter config and specified - defaults. The model weights are not on cpu at this point. - - :param nn.Module model_class: actual model class to pass defaults into - :param ModelDefaults model_defaults: default model corresponding to config - :param dict config: _description_ - """ - kwargs = {} - for param_name in vars(model_defaults): - if param_name in config: - kwargs[param_name] = config[param_name] - else: - kwargs[param_name] = model_defaults.get(param_name) - - return model_class(**kwargs) - - -class ModelDefaults: - def __init__(self): - """ - Parent class of the model defaults objects. 
- """ - - def get(self, varname): - """ - Logic for getting an attribute of the default parameters class - - :param str varname: name of attribute - """ - return getattr(self, varname) - - -class ModelDefaults2D(ModelDefaults): - def __init__(self): - """ - Instance of model defaults class, containing all of the default - hyper-parameters for the 2D unet - - All parameters in this default model CAN be accessed by name through - the model config - """ - super(ModelDefaults, self).__init__() - self.in_channels = 1 - self.out_channels = 1 - self.kernel_size = (3, 3) - self.residual = False - self.dropout = 0.2 - self.num_blocks = 4 - self.num_block_layers = 2 - self.num_filters = [] - self.task = "reg" - - -class ModelDefaults25D(ModelDefaults): - def __init__(self): - """ - Instance of default model class, containing all of the default - hyper-parameters for the 2D unet. - - All parameters in this default model CAN be accessed by name through - the model config - """ - self.in_channels = 1 - self.out_channels = 1 - self.in_stack_depth = 5 - self.out_stack_depth = 1 - self.xy_kernel_size = (3, 3) - self.residual = False - self.dropout = 0.2 - self.num_blocks = 4 - self.num_block_layers = 2 - self.num_filters = [] - self.task = "reg" From 9ca8727ff8aa3def50735a78f7aad0047a77e07e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 00:39:37 -0700 Subject: [PATCH 06/30] cleanup --- viscy/scripts/network_diagram.py | 3 --- viscy/unet/networks/Unet21D.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index 7fe5dbeb..7de226a9 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -32,6 +32,3 @@ graph # %% model_graph.visual_graph.render(format="svg") - -# %% -import matplotlib.pyplot as plt diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index d8afb1bd..6d9dd560 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -2,7 +2,7 @@ import timm import torch -from monai.networks.blocks import Convolution, ResidualUnit, UnetrUpBlock +from monai.networks.blocks import ResidualUnit, UnetrUpBlock from monai.networks.blocks.dynunet_block import get_conv_layer from torch import nn From 209948933bd083cc81e9c373b76bc16801926852 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 00:50:00 -0700 Subject: [PATCH 07/30] fix datamodule arg parsing --- viscy/light/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index ada93eb4..30898ad6 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -312,7 +312,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = True if architecture == "2.5D" else False + self.target_2d = False if architecture == "3D" else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size @@ -321,7 +321,7 @@ def __init__( self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks self.tmp_zarr = None - if train_z_scale_range[0] > 1 or train_z_scale_range[1] < 1: + if train_z_scale_range[0] > 0 or train_z_scale_range[1] < 0: raise ValueError(f"Invalid scaling range: {train_z_scale_range}") self.train_z_scale_range = train_z_scale_range From 76886246488cd988749d56d7d4522f3514308d96 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 01:12:08 
-0700 Subject: [PATCH 08/30] cache data in unique dir --- viscy/light/data.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index 30898ad6..9af48489 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -281,6 +281,7 @@ class HCSDataModule(LightningDataModule): :param bool augment: whether to apply augmentation in training, defaults to True :param bool caching: whether to decompress all the images and cache the result, + will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False :param str ground_truth_masks: path to the ground truth segmentation masks, defaults to None @@ -343,7 +344,9 @@ def prepare_data(self): logger.addHandler(file_handler) # cache in temporary directory self.tmp_zarr = os.path.join( - tempfile.gettempdir(), os.path.basename(self.data_path) + tempfile.gettempdir(), + os.getenv("SLURM_JOB_ID"), + os.path.basename(self.data_path), ) logger.info(f"Caching dataset at {self.tmp_zarr}.") tmp_store = zarr.NestedDirectoryStore(self.tmp_zarr) @@ -500,13 +503,13 @@ def _fit_transform(self): self.yx_patch_size[1], ), ), - ScaleIntensityRangePercentilesd( - keys=self.target_channel, - lower=5, - upper=95, - b_min=None, - b_max=None, - ), + # ScaleIntensityRangePercentilesd( + # keys=self.target_channel, + # lower=5, + # upper=95, + # b_min=None, + # b_max=None, + # ), ] def _train_transform(self) -> list[Callable]: From e9b3144cae172ea5418b0d2c5be2d27014ab4d4a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 09:36:35 -0700 Subject: [PATCH 09/30] ensure contiguous --- viscy/unet/networks/Unet21D.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 6d9dd560..96bac7a8 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -31,7 +31,7 @@ def forward(self, x: torch.Tensor): b, c, d, h, w = x.shape # project Z/depth into channels # return a view when possible (contiguous) - return x.view(b, c * d, h, w) + return x.reshape(b, c * d, h, w) class Unet2dDecoder(nn.Module): From be7ad557958f624d6a60be19568b111e3e523777 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 11:01:03 -0700 Subject: [PATCH 10/30] fix diagram --- viscy/scripts/network_diagram.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index 7de226a9..06c4d16a 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -11,7 +11,7 @@ "out_channels": 2, "in_stack_depth": 9, "backbone": "convnextv2_femto", - "stem_kernel_size": (5, 4, 4), + "stem_kernel_size": (3, 4, 4), }, batch_size=32, ) @@ -22,7 +22,7 @@ # model.example_input_array, graph_name="2.1D UNet", roll=True, - depth=3, + depth=2, # graph_dir="LR", directory="/hpc/projects/comp.micro/virtual_staining/models/HEK_phase_to_nuc_mem/", # save_graph=True, From 4595a745d9c2227a0f2ff212263d8d6e48b55ab5 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 11:01:25 -0700 Subject: [PATCH 11/30] fix shape normalization --- viscy/light/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index 9af48489..df1ecf7c 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -407,7 +407,7 @@ def _setup_fit(self, dataset_settings: dict): # training set needs to sample more Z range for augmentation train_dataset_settings = dataset_settings.copy() expanded_z 
= math.ceil(self.z_window_size * (1 + self.train_z_scale_range[1])) - train_dataset_settings["z_window_size"] = expanded_z - expanded_z // 2 + train_dataset_settings["z_window_size"] = expanded_z - expanded_z % 2 # train/val split self.train_dataset = SlidingWindowDataset( positions[:num_train_fovs], From 2a47a2e0c5ec1306897efb55c6fd1bbfb1a7f7cb Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 11:08:32 -0700 Subject: [PATCH 12/30] AdamW --- viscy/light/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index a73993fe..7241c5c0 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -310,7 +310,7 @@ def on_predict_start(self): self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) if self.schedule == "WarmupCosine": scheduler = WarmupCosineSchedule( optimizer, From f0b71f23a5c2ca5f32bf1740f5dbd9c62ea9e63d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 11:30:26 -0700 Subject: [PATCH 13/30] depend on timm --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fd9ec29b..e073f000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,12 @@ authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0.dev3", "torch>=2.0.0", - "torchvision>=0.15.1", + "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.0.1", "monai>=1.2.0", "jsonargparse[signatures]>=4.20.1", - "scikit-image>=0.19.2", + "scikit-image", "matplotlib", ] dynamic = ["version"] From 534a30581ecb3b9a6f782d12a15eb17cdf17b971 Mon Sep 17 00:00:00 2001 From: Shalin Mehta <2934183+mattersoflight@users.noreply.github.com> Date: Fri, 28 Jul 2023 16:29:17 -0700 Subject: [PATCH 14/30] link to microDL and paper --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 21e1ddfe..6a301208 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,13 @@ # viscy -viscy is a deep learning pipeline for training and deploying computer vision models for high-throughput imaging and image-based phenotyping with single cell resolution. +viscy is a deep learning pipeline for training and deploying computer vision models for image-based phenotyping at single cell resolution. -The current focus of the pipeline is on the image translation models for virutal staining of multiple cellular compartments from label-free images. We are building these models for screening fields of view during imaging and for simultaneous segmentation of nuclei and membrane for single-cell phenotyping. The pipeline provides utilities to export the models to onnx format for use during runtime. We will grow the collection of the models suitable for high-throughput imaging and phenotyping. +The current focus of the pipeline is on the image translation models for virtual staining of multiple cellular compartments from label-free images. We are building these models for simultaneous segmentation of nuclei and membrane, which are the first steps in a single-cell phenotyping pipeline. Our pipeline also provides utilities to export the models to onnx format for use at runtime. We will grow the collection of the models suitable for high-throughput imaging and phenotyping. 
![virtual_staining](docs/figures/phase_to_nuclei_membrane.svg) +This pipeline evolved from the [TensorFlow version of virtual staining pipeline](https://github.com/mehta-lab/microDL), which we reported in [this paper in 2020](https://elifesciences.org/articles/55502). The previous pipeline is now a public archive, and we will be focusing our efforts on viscy. + ## Installation (Optional) create a new virtual/Conda environment. @@ -14,7 +16,8 @@ Clone this repository and install viscy: ```sh git clone https://github.com/mehta-lab/viscy.git -pip install viscy +cd viscy +pip install . ``` Verify installation by accessing the CLI help message: From abab7ba2a682ddc879c9e20c612568bd74cc3e93 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 9 Aug 2023 11:26:32 -0700 Subject: [PATCH 15/30] bump iohub version (#25) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e073f000..388019c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.9,!=3.11" license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ - "iohub==0.1.0.dev3", + "iohub==0.1.0.dev4", "torch>=2.0.0", "timm>=0.9.5", "tensorboard>=2.13.0", From 17b485048a63d85c8c895c58b60d0df1381178f2 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 21:41:45 -0700 Subject: [PATCH 16/30] fix 0 depth when input is 1 --- viscy/light/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/light/data.py b/viscy/light/data.py index df1ecf7c..492a8e0d 100644 --- a/viscy/light/data.py +++ b/viscy/light/data.py @@ -305,7 +305,7 @@ def __init__( caching: bool = False, normalize_source: bool = False, ground_truth_masks: str = None, - train_z_scale_range: tuple[float, float] = [-0.2, 1], + train_z_scale_range: tuple[float, float] = [0, 0], ): super().__init__() self.data_path = data_path @@ -407,7 +407,7 @@ def _setup_fit(self, dataset_settings: dict): # training set needs to sample more Z range for augmentation train_dataset_settings = dataset_settings.copy() expanded_z = math.ceil(self.z_window_size * (1 + self.train_z_scale_range[1])) - train_dataset_settings["z_window_size"] = expanded_z - expanded_z % 2 + train_dataset_settings["z_window_size"] = max(1, expanded_z - expanded_z % 2) # train/val split self.train_dataset = SlidingWindowDataset( positions[:num_train_fovs], From 9f55843d6ff61114eb30f6d23f4d55cb309d6fce Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 21:41:59 -0700 Subject: [PATCH 17/30] add 2D to engine --- viscy/light/engine.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 7241c5c0..f9791864 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -29,12 +29,14 @@ from viscy.evaluation.evaluation_metrics import mean_average_precision from viscy.light.data import Sample +from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d _UNET_ARCHITECTURE = { - "2.5D": Unet25d, + "2D": Unet2d, "2.1D": Unet21d, + "2.5D": Unet25d, } @@ -146,10 +148,14 @@ def __init__( self.training_step_outputs = [] self.validation_step_outputs = [] # required to log the graph + if arch == "2D": + example_depth = 1 + else: + example_depth = model_config.get("in_stack_depth") or 5 self.example_input_array = torch.rand( 1, 1, - 
(model_config.get("in_stack_depth") or 5), + example_depth, *example_input_yx_shape, ) self.test_cellpose_model_path = test_cellpose_model_path From 0e2b575cf5c01a912a86634d2cd702ba6fd0a6ee Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 21:42:22 -0700 Subject: [PATCH 18/30] normalize shape for 2D nets --- viscy/unet/networks/Unet2D.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py index 36abb26b..8698e892 100644 --- a/viscy/unet/networks/Unet2D.py +++ b/viscy/unet/networks/Unet2D.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from viscy.unet.networks.layers import ConvBlock2D +from viscy.unet.networks.layers.ConvBlock2D import ConvBlock2D class Unet2d(nn.Module): @@ -189,7 +189,7 @@ def forward(self, x, validate_input=False): f"Input channels must equal network" f" input channels: {self.in_channels}" ) - + x = x.squeeze(2) # encoder skip_tensors = [] for i in range(self.num_blocks): @@ -209,7 +209,7 @@ def forward(self, x, validate_input=False): # output channel collapsing layer x = self.terminal_block(x) - return x + return x.unsqueeze(2) def register_modules(self, module_list, name): """ From 79380960644470b7ccdfbe64d65622ae2589f626 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 22:20:02 -0700 Subject: [PATCH 19/30] fix args check --- viscy/unet/networks/Unet2D.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py index 8698e892..05d9a17a 100644 --- a/viscy/unet/networks/Unet2D.py +++ b/viscy/unet/networks/Unet2D.py @@ -59,9 +59,9 @@ def __init__( # ----- Standardize Filter Sequence -----# if len(num_filters) != 0: - assert len(num_filters) == num_blocks, ( - "Length of num_filters must be equal to num_blo" - "cks + 1 (number of convolutional blocks per path)." + assert len(num_filters) == num_blocks + 1, ( + "Length of num_filters must be equal to num_blocks + 1 " + "(number of convolutional blocks per path)." ) self.num_filters = num_filters else: From aee1ca0ba63a130015d64809111d1336329f1fd7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Aug 2023 22:33:17 -0700 Subject: [PATCH 20/30] excercise 1 --- examples/demo_dlmbl/python/excercise_1.py | 241 ++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 examples/demo_dlmbl/python/excercise_1.py diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py new file mode 100644 index 00000000..23d8f9e8 --- /dev/null +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -0,0 +1,241 @@ +# %% [markdown] +""" +# Image translation excercise part 1 + +In this exercise, we will solve an image translation task of +reconstructing nuclei and membrane markers from phase images of cells. +Here, the source domain is label-free microscopy (average material density), +and the target domain is fluorescence microscopy (fluorophore density). + +Learning goals of part 1: + +- Load the and visualize the images from OME-Zarr +- Configure the data loaders +- Initialize a 2D U-Net model for virtual staining + + +
+<div class="alert alert-danger">
+Set your python kernel to <code>004-image-translation</code>
+</div>
+"""
+
+# %%
+import matplotlib.pyplot as plt
+import torch
+from iohub import open_ome_zarr
+from tensorboard import notebook
+from torchview import draw_graph
+
+from viscy.light.data import HCSDataModule
+from viscy.light.engine import VSTrainer, VSUNet
+
+BATCH_SIZE = 32
+
+# %% [markdown]
+"""
+Load Dataset.
+
+Task 1.1 + +Use +iohub.open_ome_zarr to read the dataset. +Run open_ome_zarr? in a cell to see the docstring. + +There should be 301 FOVs in the dataset (9.3 GB compressed). + +""" + +# %% +# set dataset path here +data_path = ( + "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane.zarr" +) + +dataset = open_ome_zarr(data_path) + +print(len(list(dataset.positions()))) + +# %% [markdown] +""" +View images with matplotlib. + +Note that labelling is not perfect, +as some cells are not expressing the fluorophore. +""" + +# %% +image = dataset["0/0/0/0"].numpy() +print(image.shape) + +figure, axes = plt.subplots(1, 3, figsize=(9, 3)) + +for ax, channel in zip(axes, image[0, :, 0]): + ax.imshow(channel, cmap="gray") + ax.axis("off") + +plt.tight_layout() + +# %% [markdown] +""" +Configure the data loaders for training and validation. +""" + +# %% +data_module = HCSDataModule( + data_path, + source_channel="Phase", + target_channel=["Nuclei", "Membrane"], + z_window_size=1, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + architecture="2D", + yx_patch_size=(256, 256), +) + +data_module.setup("fit") + +print(len(data_module.train_dataset), len(data_module.val_dataset)) + +# %% [markdown] +""" +
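+The data module above samples 256x256 XY patches from single Z slices
+(``z_window_size=1``) and splits the FOVs 80/20 between training and
+validation (``split_ratio=0.8``).
+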
+<div class="alert alert-info">
+Task 1.2
+
+Validate that the data can be loaded in batches correctly.
+</div>
+"""
+
+# %%
+train_dataloader = data_module.train_dataloader()
+
+for i, batch in enumerate(train_dataloader):
+    ...
+    # check some batches and break
+    break
+
+# %% tags=["solution"]
+train_dataloader = data_module.train_dataloader()
+
+for i, batch in enumerate(train_dataloader):
+    print(f"Batch {i}:")
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            summary = (v.shape, v.dtype)
+        else:
+            summary = v
+        print(k, summary)
+    if i > 2:
+        break
+
+# %% [markdown]
+"""
+Construct a 2D U-Net for image translation.
+
+See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.
+Increase the ``depth`` in ``draw_graph`` to zoom in.
+"""
+
+# %%
+model_config = {
+    "architecture": "2D",
+    "residual": True,
+    "dropout": 0.1,
+    "task": "reg",
+}
+
+model = VSUNet(
+    model_config=model_config.copy(),
+    batch_size=BATCH_SIZE,
+    loss_function=torch.nn.functional.mse_loss,
+    schedule="WarmupCosine",
+    log_num_samples=10,
+)
+
+# visualize graph
+model_graph = draw_graph(model, model.example_input_array, depth=2)
+graph = model_graph.visual_graph
+graph
+
+# %% [markdown]
+"""
+Configure trainer class.
+Here we use the ``fast_dev_run`` flag to run a sanity check first.
+"""
+
+# %%
+trainer = VSTrainer(accelerator="gpu", fast_dev_run=True)
+
+trainer.fit(model, datamodule=data_module)
+
+# %% [markdown]
+"""
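+``fast_dev_run`` runs a single batch of training and validation without
+checkpointing, so configuration errors surface before a long run.
+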
+<div class="alert alert-info">
+Task 1.3
+
+Modify the trainer to train the model for 20 epochs.
+</div>
+"""
+
+# %% [markdown]
+"""
+Tips:
+
+- See ``VSTrainer?`` for all the available parameters.
+- Set ``default_root_dir`` to store the logs and checkpoints
+in a specific directory.
+"""
+
+# %% [markdown]
+"""
+Bonus:
+
+- Tweak model hyperparameters
+- Adjust batch size to fully utilize the VRAM
+"""
+
+# %% tags=["solution"]
+wider_config = model_config | {"num_filters": [24, 48, 96, 192, 384]}
+
+model = VSUNet(
+    model_config=wider_config.copy(),
+    batch_size=BATCH_SIZE,
+    loss_function=torch.nn.functional.mse_loss,
+    schedule="WarmupCosine",
+    log_num_samples=10,
+)
+
+trainer = VSTrainer(
+    accelerator="gpu", max_epochs=20, log_every_n_steps=10, default_root_dir=""
+)
+
+trainer.fit(model, datamodule=data_module)
+
+# %% [markdown]
+"""
+Launch TensorBoard with:
+
+```
+%load_ext tensorboard
+%tensorboard --logdir /path/to/lightning_logs
+```
+"""
+
+# %%
+notebook.list()
+
+# %%
+notebook.display(port=6006, height=800)
+
+# %% [markdown]
+"""
+<div class="alert alert-success">
+Checkpoint 1
+
+Now the training has started,
+we can come back after a while and evaluate the performance!
+""" + +# %% From 7069b82ac2d357fe73717470673b72a0e9c57f42 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 10 Aug 2023 13:50:21 -0700 Subject: [PATCH 21/30] readme + dependencies tested with python 3.10 (#30) * readme + dependencies tested with python 3.10 * match package description with README * cellpose (metrics -> required) * misc formatting and language edits * fix optional dependency on cellpose --------- Co-authored-by: Ziwen Liu --- README.md | 71 ++++++++++++++++++++----------------------- pyproject.toml | 2 +- viscy/light/engine.py | 13 +++++++- 3 files changed, 46 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 6a301208..8f1bcf2d 100644 --- a/README.md +++ b/README.md @@ -2,62 +2,57 @@ viscy is a deep learning pipeline for training and deploying computer vision models for image-based phenotyping at single cell resolution. -The current focus of the pipeline is on the image translation models for virtual staining of multiple cellular compartments from label-free images. We are building these models for simultaneous segmentation of nuclei and membrane, which are the first steps in a single-cell phenotyping pipeline. Our pipeline also provides utilities to export the models to onnx format for use at runtime. We will grow the collection of the models suitable for high-throughput imaging and phenotyping. +The current focus of the pipeline is on the image translation models for virtual staining of multiple cellular compartments from label-free images. +We are building these models for simultaneous segmentation of nuclei and membrane, which are the first steps in a single-cell phenotyping pipeline. +Our pipeline also provides utilities to export the models to ONNX format for use at runtime. +We will grow the collection of the models suitable for high-throughput imaging and phenotyping. +Expect rough edges until we release a PyPI package. ![virtual_staining](docs/figures/phase_to_nuclei_membrane.svg) This pipeline evolved from the [TensorFlow version of virtual staining pipeline](https://github.com/mehta-lab/microDL), which we reported in [this paper in 2020](https://elifesciences.org/articles/55502). The previous pipeline is now a public archive, and we will be focusing our efforts on viscy. -## Installation +## Installing viscy -(Optional) create a new virtual/Conda environment. +1. We highly encourage using new Conda/virtual environment. + ([Mamba](https://github.com/mamba-org/mamba) is a faster re-implementation Conda.) -Clone this repository and install viscy: + ```sh + mamba create --name viscy python=3.10 + # OR + mamba create --prefix /path/to/conda/envs/viscy python=3.10 + ``` -```sh -git clone https://github.com/mehta-lab/viscy.git -cd viscy -pip install . -``` +2. Clone this repository and install with pip: -Verify installation by accessing the CLI help message: + ```sh + git clone https://github.com/mehta-lab/viscy.git + # change to project root directory (parent folder of pyproject.toml) + cd viscy + pip install . + ``` -```sh -viscy --help -``` + If evaluating virtually stained images for segmentation tasks, + additional dependencies need to be installed: + + ```sh + pip install ".[metrics]" + ``` + +3. Verify installation by accessing the CLI help message: + + ```sh + viscy --help + ``` For development installation, see [the contributing guide](CONTRIBUTING.md). 
-The pipeline is built using the [pytorch lightning](https://www.pytorchlightning.ai/index.html) framework and [iohub](https://github.com/czbiohub-sf/iohub) library for reading and writing data in [ome-zarr](https://www.nature.com/articles/s41592-021-01326-w) format. +The pipeline is built using the [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) framework and [iohub](https://github.com/czbiohub-sf/iohub) library for reading and writing data in [OME-Zarr](https://www.nature.com/articles/s41592-021-01326-w) format. The full functionality is tested only on Linux `x86_64` with NVIDIA Ampere GPUs (CUDA 12.0). Some features (e.g. mixed precision and distributed training) may not work with other setups, see [PyTorch documentation](https://pytorch.org) for details. -Following dependencies will allow use and development of the pipeline, while the pypi package is pending: - -``` -iohub==0.1.0.dev3 -torch>=2.0.0 -torchvision>=0.15.1 -tensorboard>=2.13.0 -lightning>=2.0.1 -monai>=1.2.0 -jsonargparse[signatures]>=4.20.1 -scikit-image>=0.19.2 -matplotlib -cellpose==2.1.0 -lapsolver==1.1.0 -scikit-learn>=1.1.3 -scipy>=1.8.0 -torchmetrics[detection]>=1.0.0 -pytest -pytest-cov -hypothesis -profilehooks -onnxruntime -``` - ## Virtual staining of cellular compartments from label-free images Predicting sub-cellular landmarks such as nuclei and membrane from label-free (e.g. phase) images diff --git a/pyproject.toml b/pyproject.toml index 388019c9..6a85a05a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "viscy" -description = "Learning vision for cells" +description = "computer vision for image-based phenotyping of single cells" readme = "README.md" # cannot build on 3.11 due to https://github.com/cheind/py-lapsolver/pull/18 requires-python = ">=3.9,!=3.11" diff --git a/viscy/light/engine.py b/viscy/light/engine.py index f9791864..d8a11909 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -5,7 +5,6 @@ import numpy as np import torch import torch.nn.functional as F -from cellpose.models import CellposeModel from imageio import imwrite from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized @@ -33,6 +32,12 @@ from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d +try: + from cellpose.models import CellposeModel +except ImportError: + CellposeModel = None + + _UNET_ARCHITECTURE = { "2D": Unet2d, "2.1D": Unet21d, @@ -303,6 +308,12 @@ def on_validation_epoch_end(self): def on_test_start(self): """Load CellPose model for segmentation.""" + if CellposeModel is None: + raise ImportError( + "CellPose not installed. 
" + "Please install the metrics dependency with " + "`pip install viscy\".[metrics]\"`" + ) if self.test_cellpose_model_path is not None: self.cellpose_model = CellposeModel( model_type=self.test_cellpose_model_path, device=self.device From e19341c738258c8934c2815b64edaf2ccd261ac0 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 10 Aug 2023 14:06:56 -0700 Subject: [PATCH 22/30] tested data loading --- examples/demo_dlmbl/python/excercise_1.py | 16 ++++++++++++++-- pyproject.toml | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index 23d8f9e8..8364cace 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -40,10 +40,13 @@ Use iohub.open_ome_zarr to read the dataset. -Run open_ome_zarr? in a cell to see the docstring. There should be 301 FOVs in the dataset (9.3 GB compressed). +Each FOV consists of 3 channels of 2048x2048 images, saved in a [HCS layout](https://ngff.openmicroscopy.org/latest/#hcs-layout) format specified by the to the Open Microscopy Environment (OME) next generation file format (NGFF). + +Run open_ome_zarr? in a cell to see the docstring. + """ # %% @@ -56,16 +59,25 @@ print(len(list(dataset.positions()))) + # %% [markdown] """ View images with matplotlib. +The layout on the disk is: row/col/field/resolution/timepoint/channel/z/y/x. + + Note that labelling is not perfect, as some cells are not expressing the fluorophore. """ # %% -image = dataset["0/0/0/0"].numpy() + +row = "0" +col = "0" +field = "0" +resolution = "0" # 0 is the highest resolution, 1 is 2x2 binned, 2 is 4x4 binned, etc. +image = dataset[f'{row}/{col}/{field}/{resolution}'].numpy() print(image.shape) figure, axes = plt.subplots(1, 3, figsize=(9, 3)) diff --git a/pyproject.toml b/pyproject.toml index 388019c9..962201d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ dependencies = [ "jsonargparse[signatures]>=4.20.1", "scikit-image", "matplotlib", + "ipykernel", # used by demo_dlmbl + "graphviz", # used by demo_dlmbl + "torchview", # used by demo_dlmbl ] dynamic = ["version"] From b4dea7205543ba69e05e18745dc3f4802c4fcdb6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 10 Aug 2023 14:16:23 -0700 Subject: [PATCH 23/30] fix HTML link --- examples/demo_dlmbl/python/excercise_1.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index 8364cace..1a69177c 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -43,7 +43,11 @@ There should be 301 FOVs in the dataset (9.3 GB compressed). -Each FOV consists of 3 channels of 2048x2048 images, saved in a [HCS layout](https://ngff.openmicroscopy.org/latest/#hcs-layout) format specified by the to the Open Microscopy Environment (OME) next generation file format (NGFF). +Each FOV consists of 3 channels of 2048x2048 images, +saved in the +High-Content Screening (HCS) layout +specified by the Open Microscopy Environment Next Generation File Format +(OME-NGFF). Run open_ome_zarr? in a cell to see the docstring. 
From d94b338b658fa5a73014df3af22a5d641a492ebb Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 10 Aug 2023 15:13:45 -0700 Subject: [PATCH 24/30] show examples from the batches as images --- examples/demo_dlmbl/python/excercise_1.py | 33 +++++++++++++++++------ 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index 8364cace..09e7dc15 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -130,16 +130,33 @@ # %% tags=["solution"] train_dataloader = data_module.train_dataloader() + +fig, axs = plt.subplots(3, 8, figsize=(20, 6)) + +# Draw 8 batches, each with 32 images. Show the first image in each batch. + for i, batch in enumerate(train_dataloader): - print(f"Batch {i}:") - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - summary = (v.shape, v.dtype) - else: - summary = v - print(k, summary) - if i > 2: + # The batch is a dictionary consisting of three keys: 'index', 'source', 'target'. + if i >= 8: break + FOV = batch['index'][0][0] + input_tensor = batch['source'][0, 0, :, :].squeeze() + target_membrane_tensor = batch['target'][0, 0, :, :].squeeze() + target_nuclei_tensor = batch['target'][0, 1, :, :].squeeze() + + axs[0, i].imshow(input_tensor, cmap='gray') + axs[1, i].imshow(target_nuclei_tensor, cmap='gray') + axs[2, i].imshow(target_membrane_tensor, cmap='gray') + axs[0, i].set_title(f'input@{FOV}') + axs[1, i].set_title('target-nuclei') + axs[2, i].set_title('target-membrane') + axs[0, i].axis('off') + axs[1, i].axis('off') + axs[2, i].axis('off') + +plt.tight_layout() +plt.show() + # %% [markdown] """ From 413b429309e4182c68c816899d09e85ca11d22da Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 10 Aug 2023 16:09:57 -0700 Subject: [PATCH 25/30] clarified data module --- examples/demo_dlmbl/python/excercise_1.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index 7adff170..e87ef43f 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -128,7 +128,7 @@ for i, batch in enumerate(train_dataloader): ... - # check some batches and break + # plot one image from each of the batch and break break # %% tags=["solution"] @@ -141,6 +141,10 @@ for i, batch in enumerate(train_dataloader): # The batch is a dictionary consisting of three keys: 'index', 'source', 'target'. 
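+    # (The per-sample shapes noted below assume the data module above:
+    # z_window_size=1 and yx_patch_size=(256, 256).)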
+ # index is the tuple consisting of (image name, time, and z-slice) + # source is the tensor of size 1x1x256x256 + # target is the tensor of size 2x1x256x256 + if i >= 8: break FOV = batch['index'][0][0] From f8e2991a5f6cabd70e942872e09059a3b2c0be0a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 10 Aug 2023 16:35:12 -0700 Subject: [PATCH 26/30] update pyramid dataset --- examples/demo_dlmbl/python/excercise_1.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index e87ef43f..be0c9f11 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -56,7 +56,7 @@ # %% # set dataset path here data_path = ( - "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane.zarr" + "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane_pyramid.zarr" ) dataset = open_ome_zarr(data_path) @@ -80,7 +80,9 @@ row = "0" col = "0" field = "0" -resolution = "0" # 0 is the highest resolution, 1 is 2x2 binned, 2 is 4x4 binned, etc. +# '0' is the highest resolution +# '1' is 2x2 down-scaled, '2' is 4x4 down-scaled, etc. +resolution = "0" image = dataset[f'{row}/{col}/{field}/{resolution}'].numpy() print(image.shape) From 998bfde9861958d541a52e3f8094f2e3a7284a34 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 10 Aug 2023 16:36:29 -0700 Subject: [PATCH 27/30] format --- examples/demo_dlmbl/python/excercise_1.py | 36 +++++++++++------------ 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index be0c9f11..b7bf6138 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -55,9 +55,7 @@ # %% # set dataset path here -data_path = ( - "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane_pyramid.zarr" -) +data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane_pyramid.zarr" dataset = open_ome_zarr(data_path) @@ -83,7 +81,7 @@ # '0' is the highest resolution # '1' is 2x2 down-scaled, '2' is 4x4 down-scaled, etc. 
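+# e.g. with 2048x2048 images at level '0', level '1' here is 1024x1024;
+# the available levels depend on how the pyramid was written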
resolution = "0" -image = dataset[f'{row}/{col}/{field}/{resolution}'].numpy() +image = dataset[f"{row}/{col}/{field}/{resolution}"].numpy() print(image.shape) figure, axes = plt.subplots(1, 3, figsize=(9, 3)) @@ -146,23 +144,23 @@ # index is the tuple consisting of (image name, time, and z-slice) # source is the tensor of size 1x1x256x256 # target is the tensor of size 2x1x256x256 - + if i >= 8: break - FOV = batch['index'][0][0] - input_tensor = batch['source'][0, 0, :, :].squeeze() - target_membrane_tensor = batch['target'][0, 0, :, :].squeeze() - target_nuclei_tensor = batch['target'][0, 1, :, :].squeeze() - - axs[0, i].imshow(input_tensor, cmap='gray') - axs[1, i].imshow(target_nuclei_tensor, cmap='gray') - axs[2, i].imshow(target_membrane_tensor, cmap='gray') - axs[0, i].set_title(f'input@{FOV}') - axs[1, i].set_title('target-nuclei') - axs[2, i].set_title('target-membrane') - axs[0, i].axis('off') - axs[1, i].axis('off') - axs[2, i].axis('off') + FOV = batch["index"][0][0] + input_tensor = batch["source"][0, 0, :, :].squeeze() + target_membrane_tensor = batch["target"][0, 0, :, :].squeeze() + target_nuclei_tensor = batch["target"][0, 1, :, :].squeeze() + + axs[0, i].imshow(input_tensor, cmap="gray") + axs[1, i].imshow(target_nuclei_tensor, cmap="gray") + axs[2, i].imshow(target_membrane_tensor, cmap="gray") + axs[0, i].set_title(f"input@{FOV}") + axs[1, i].set_title("target-nuclei") + axs[2, i].set_title("target-membrane") + axs[0, i].axis("off") + axs[1, i].axis("off") + axs[2, i].axis("off") plt.tight_layout() plt.show() From b81fd0e9ddde5c6db3724ab4050bd3778dd75d44 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 10 Aug 2023 16:38:45 -0700 Subject: [PATCH 28/30] fix model arch --- examples/demo_dlmbl/python/excercise_1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index b7bf6138..f2af135c 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -177,6 +177,8 @@ # %% model_config = { "architecture": "2D", + "in_channels": 1, + "out_channels": 2, "residual": True, "dropout": 0.1, "task": "reg", From 84205708b11f4a2df926f125031313edef59a17d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 10 Aug 2023 16:53:51 -0700 Subject: [PATCH 29/30] fix device issue --- examples/demo_dlmbl/python/excercise_1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index f2af135c..04bc48ee 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -30,6 +30,7 @@ from viscy.light.engine import VSTrainer, VSUNet BATCH_SIZE = 32 +GPU_ID = 0 # %% [markdown] """ @@ -193,7 +194,7 @@ ) # visualize graph -model_graph = draw_graph(model, model.example_input_array, depth=2) +model_graph = draw_graph(model, model.example_input_array, depth=2, device="cpu") graph = model_graph.visual_graph graph @@ -204,7 +205,7 @@ """ # %% -trainer = VSTrainer(accelerator="gpu", fast_dev_run=True) +trainer = VSTrainer(accelerator="gpu", device=GPU_ID, fast_dev_run=True) trainer.fit(model, datamodule=data_module) From a3fbd0bd8856e5743f9026c1399641148380425c Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Thu, 10 Aug 2023 18:00:10 -0700 Subject: [PATCH 30/30] small fixes and final comments on exercise 1 --- examples/demo_dlmbl/python/excercise_1.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git 
a/examples/demo_dlmbl/python/excercise_1.py b/examples/demo_dlmbl/python/excercise_1.py index 04bc48ee..0904b75a 100644 --- a/examples/demo_dlmbl/python/excercise_1.py +++ b/examples/demo_dlmbl/python/excercise_1.py @@ -25,6 +25,8 @@ from iohub import open_ome_zarr from tensorboard import notebook from torchview import draw_graph +import os + from viscy.light.data import HCSDataModule from viscy.light.engine import VSTrainer, VSUNet @@ -150,8 +152,9 @@ break FOV = batch["index"][0][0] input_tensor = batch["source"][0, 0, :, :].squeeze() - target_membrane_tensor = batch["target"][0, 0, :, :].squeeze() - target_nuclei_tensor = batch["target"][0, 1, :, :].squeeze() + target_nuclei_tensor = batch["target"][0, 0, :, :].squeeze() + target_membrane_tensor = batch["target"][0, 1, :, :].squeeze() + axs[0, i].imshow(input_tensor, cmap="gray") axs[1, i].imshow(target_nuclei_tensor, cmap="gray") @@ -205,7 +208,7 @@ """ # %% -trainer = VSTrainer(accelerator="gpu", device=GPU_ID, fast_dev_run=True) +trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True) trainer.fit(model, datamodule=data_module) @@ -246,8 +249,9 @@ log_num_samples=10, ) + trainer = VSTrainer( - accelerator="gpu", max_epochs=20, log_every_n_steps=10, default_root_dir="" + accelerator="gpu", max_epochs=20, log_every_n_steps=8, default_root_dir=os.path.expanduser("~") ) trainer.fit(model, datamodule=data_module)