diff --git a/docs/figures/phase_to_nuclei_membrane.png b/docs/figures/phase_to_nuclei_membrane.png
new file mode 100644
index 00000000..52c627ce
Binary files /dev/null and b/docs/figures/phase_to_nuclei_membrane.png differ
diff --git a/examples/demo_dlmbl/convert-solution.py b/examples/demo_dlmbl/convert-solution.py
new file mode 100644
index 00000000..279f7874
--- /dev/null
+++ b/examples/demo_dlmbl/convert-solution.py
@@ -0,0 +1,41 @@
+import argparse
+from traitlets.config import Config
+import nbformat as nbf
+from nbconvert.preprocessors import TagRemovePreprocessor, ClearOutputPreprocessor
+from nbconvert.exporters import NotebookExporter
+
+
+def get_arg_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('input_file')
+ parser.add_argument('output_file')
+
+ return parser
+
+
+def convert(input_file, output_file):
+ c = Config()
+ c.TagRemovePreprocessor.remove_cell_tags = ("solution",)
+ c.TagRemovePreprocessor.enabled = True
+ c.ClearOutputPreprocesser.enabled = True
+ c.NotebookExporter.preprocessors = [
+ "nbconvert.preprocessors.TagRemovePreprocessor",
+ "nbconvert.preprocessors.ClearOutputPreprocessor"
+ ]
+
+ exporter = NotebookExporter(config=c)
+ exporter.register_preprocessor(TagRemovePreprocessor(config=c), True)
+ exporter.register_preprocessor(ClearOutputPreprocessor(), True)
+
+ output = NotebookExporter(config=c).from_filename(input_file)
+ with open(output_file, 'w') as f:
+ f.write(output[0])
+
+
+if __name__ == "__main__":
+ parser = get_arg_parser()
+ args = parser.parse_args()
+
+ convert(args.input_file, args.output_file)
+ print(f'Converted {args.input_file} to {args.output_file}')
diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py
new file mode 100644
index 00000000..34b8c84a
--- /dev/null
+++ b/examples/demo_dlmbl/debug_log_graph.py
@@ -0,0 +1,96 @@
+
+# %%
+# %% Imports and paths
+
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torchview
+import torchvision
+from iohub import open_ome_zarr
+from lightning.pytorch import seed_everything
+from lightning.pytorch.loggers import CSVLogger
+
+# pytorch lightning wrapper for Tensorboard.
+from tensorboard import notebook # for viewing tensorboard in notebook
+from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard
+
+# HCSDataModule makes it easy to load data during training.
+from viscy.light.data import HCSDataModule
+
+# Trainer class and UNet.
+from viscy.light.engine import VSTrainer, VSUNet
+
+seed_everything(42, workers=True)
+
+# Paths to data and log directory
+data_path = Path(
+ Path("~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/")
+).expanduser()
+
+log_dir = Path("~/data/04_image_translation/logs/").expanduser()
+
+# Create log directory if needed, and launch tensorboard
+log_dir.mkdir(parents=True, exist_ok=True)
+
+# fmt: off
+%reload_ext tensorboard
+%tensorboard --logdir {log_dir} --port 6007 --bind_all
+# fmt: on
+
+# %% The entire training loop is contained in this cell.
+
+GPU_ID = 0
+BATCH_SIZE = 10
+YX_PATCH_SIZE = (512, 512)
+
+
+# Dictionary that specifies key parameters of the model.
+phase2fluor_config = {
+ "architecture": "2D",
+ "num_filters": [24, 48, 96, 192, 384],
+ "in_channels": 1,
+ "out_channels": 2,
+ "residual": True,
+ "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.
+ "task": "reg", # reg = regression task.
+}
+
+phase2fluor_model = VSUNet(
+ model_config=phase2fluor_config.copy(),
+ batch_size=BATCH_SIZE,
+ loss_function=torch.nn.functional.l1_loss,
+ schedule="WarmupCosine",
+ log_num_samples=10, # Number of samples from each batch to log to tensorboard.
+ example_input_yx_shape=YX_PATCH_SIZE,
+)
+
+# Reinitialize the data module.
+phase2fluor_data = HCSDataModule(
+ data_path,
+ source_channel="Phase",
+ target_channel=["Nuclei", "Membrane"],
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ architecture="2D",
+ yx_patch_size=YX_PATCH_SIZE,
+ augment=True,
+)
+phase2fluor_data.setup("fit")
+
+
+# Train for 3 epochs to see if you can log graph.
+trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], max_epochs=3, default_root_dir=log_dir)
+
+# trainer class takes the model and the data module as inputs.
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
+
+# %% Is exmple_input_array present?
+print(f'{phase2fluor_model.example_input_array.shape},{phase2fluor_model.example_input_array.dtype}')
+trainer.logger.log_graph(phase2fluor_model, phase2fluor_model.example_input_array)
+# %%
diff --git a/examples/demo_dlmbl/exercise.ipynb b/examples/demo_dlmbl/exercise.ipynb
new file mode 100644
index 00000000..b6a3a256
--- /dev/null
+++ b/examples/demo_dlmbl/exercise.ipynb
@@ -0,0 +1,754 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "c9438eb5",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "# Image translation\n",
+ "---\n",
+ "\n",
+ "Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco.\n",
+ "---\n",
+ "\n",
+ "In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. \n",
+ "\n",
+ "Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy).\n",
+ "\n",
+ "VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy.\n",
+ "![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg)\n",
+ "\n",
+ "[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning\n",
+ ". eLife](https://elifesciences.org/articles/55502).\n",
+ "\n",
+ "VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). Our previous pipeline, [microDL](https://github.com/mehta-lab/microDL), is deprecated and is now a public archive."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b36463af",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels.\n",
+ "![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true)\n",
+ "\n",
+ "
\n",
+ "The exercise is organized in 3 parts.\n",
+ "\n",
+ "* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.\n",
+ "* Lunch break - The model will continue training during lunch.\n",
+ "* **Part 2** - Evaluate the training with tensorboard. Train another model.\n",
+ "* **Part 3** - Tune the models to improve performance.\n",
+ "
\n",
+ "\n",
+ "📖 As you work through parts 2 and 3, please share the layouts of the models you train and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.\n",
+ "\n",
+ "\n",
+ "Our guesstimate is that each of the three parts will take ~1.5 hours, but don't rush parts 1 and 2 if you need more time with them.\n",
+ "We will discuss your observations on google doc after checkpoints 2 and 3. The exercise is focused on understanding information contained in data, process of training and evaluating image translation models, and parameter exploration.\n",
+ "There are a few coding tasks sprinkled in.\n",
+ "\n",
+ "\n",
+ "Before you start,\n",
+ "\n",
+ "\n",
+ "Task 1.2\n",
+ "Setup the data loader and log several batches to tensorboard.\n",
+ "
`\n",
+ "\n",
+ "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n",
+ " \n",
+ "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n",
+ "- `source`: the input image, a tensor of size 1*1*Y*X\n",
+ "- `target`: the target image, a tensor of size 2*1*Y*X\n",
+ "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "05ce4dbb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define a function to write a batch to tensorboard log.\n",
+ "\n",
+ "\n",
+ "def log_batch_tensorboard(batch, batchno, writer, card_name):\n",
+ " \"\"\"\n",
+ " Logs a batch of images to TensorBoard.\n",
+ "\n",
+ " Args:\n",
+ " batch (dict): A dictionary containing the batch of images to be logged.\n",
+ " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n",
+ " card_name (str): The name of the card to be displayed in TensorBoard.\n",
+ "\n",
+ " Returns:\n",
+ " None\n",
+ " \"\"\"\n",
+ " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
+ " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
+ " 1\n",
+ " ) # batch_size x 1 x Y x X tensor.\n",
+ " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
+ " 1\n",
+ " ) # batch_size x 1 x Y x X tensor.\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
+ " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
+ " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
+ " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " [N, C, H, W] = batch_phase.shape\n",
+ " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n",
+ " interleaved_images[0::3, :] = batch_phase\n",
+ " interleaved_images[1::3, :] = batch_nuclei\n",
+ " interleaved_images[2::3, :] = batch_membrane\n",
+ "\n",
+ " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n",
+ "\n",
+ " # add the grid to tensorboard\n",
+ " writer.add_image(card_name, grid, batchno)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f627e8e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Initialize the data module.\n",
+ "\n",
+ "BATCH_SIZE = 42\n",
+ "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n",
+ "# More seriously, batch size does not have to be a power of 2.\n",
+ "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n",
+ "\n",
+ "data_module = HCSDataModule(\n",
+ " data_path,\n",
+ " source_channel=\"Phase\",\n",
+ " target_channel=[\"Nuclei\", \"Membrane\"],\n",
+ " z_window_size=1,\n",
+ " split_ratio=0.8,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " num_workers=8,\n",
+ " architecture=\"2D\",\n",
+ " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n",
+ " augment=False, # Turn off augmentation for now.\n",
+ ")\n",
+ "data_module.setup(\"fit\")\n",
+ "\n",
+ "print(\n",
+ " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n",
+ ")\n",
+ "train_dataloader = data_module.train_dataloader()\n",
+ "\n",
+ "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n",
+ "\n",
+ "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
+ "# Draw a batch and write to tensorboard.\n",
+ "batch = next(iter(train_dataloader))\n",
+ "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n",
+ "\n",
+ "# Iterate through all the batches and log them to tensorboard.\n",
+ "for i, batch in enumerate(train_dataloader):\n",
+ " log_batch_tensorboard(batch, i, writer, \"augmentation/none\")\n",
+ "writer.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6ffc7d1",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "There are multiple ways of seeing the tensorboard.\n",
+ "1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard.\n",
+ "2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension.\n",
+ "3. If you want to see tensorboard in a specific cell, use the following code.\n",
+ "```\n",
+ "notebook.list() # View open TensorBoard instances\n",
+ "notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port.\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed1448ff",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "## View augmentations using tensorboard.\n",
+ "\n",
+ "\n",
+ "Task 1.2\n",
+ "Setup the data loader and log several batches to tensorboard.\n",
+ "
`\n",
+ "\n",
+ "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n",
+ " \n",
+ "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n",
+ "- `source`: the input image, a tensor of size 1*1*Y*X\n",
+ "- `target`: the target image, a tensor of size 2*1*Y*X\n",
+ "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "05ce4dbb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define a function to write a batch to tensorboard log.\n",
+ "\n",
+ "\n",
+ "def log_batch_tensorboard(batch, batchno, writer, card_name):\n",
+ " \"\"\"\n",
+ " Logs a batch of images to TensorBoard.\n",
+ "\n",
+ " Args:\n",
+ " batch (dict): A dictionary containing the batch of images to be logged.\n",
+ " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n",
+ " card_name (str): The name of the card to be displayed in TensorBoard.\n",
+ "\n",
+ " Returns:\n",
+ " None\n",
+ " \"\"\"\n",
+ " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
+ " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
+ " 1\n",
+ " ) # batch_size x 1 x Y x X tensor.\n",
+ " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
+ " 1\n",
+ " ) # batch_size x 1 x Y x X tensor.\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
+ " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
+ " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
+ " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
+ "\n",
+ " [N, C, H, W] = batch_phase.shape\n",
+ " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n",
+ " interleaved_images[0::3, :] = batch_phase\n",
+ " interleaved_images[1::3, :] = batch_nuclei\n",
+ " interleaved_images[2::3, :] = batch_membrane\n",
+ "\n",
+ " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n",
+ "\n",
+ " # add the grid to tensorboard\n",
+ " writer.add_image(card_name, grid, batchno)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f627e8e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Initialize the data module.\n",
+ "\n",
+ "BATCH_SIZE = 42\n",
+ "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n",
+ "# More seriously, batch size does not have to be a power of 2.\n",
+ "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n",
+ "\n",
+ "data_module = HCSDataModule(\n",
+ " data_path,\n",
+ " source_channel=\"Phase\",\n",
+ " target_channel=[\"Nuclei\", \"Membrane\"],\n",
+ " z_window_size=1,\n",
+ " split_ratio=0.8,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " num_workers=8,\n",
+ " architecture=\"2D\",\n",
+ " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n",
+ " augment=False, # Turn off augmentation for now.\n",
+ ")\n",
+ "data_module.setup(\"fit\")\n",
+ "\n",
+ "print(\n",
+ " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n",
+ ")\n",
+ "train_dataloader = data_module.train_dataloader()\n",
+ "\n",
+ "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n",
+ "\n",
+ "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
+ "# Draw a batch and write to tensorboard.\n",
+ "batch = next(iter(train_dataloader))\n",
+ "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n",
+ "\n",
+ "# Iterate through all the batches and log them to tensorboard.\n",
+ "for i, batch in enumerate(train_dataloader):\n",
+ " log_batch_tensorboard(batch, i, writer, \"augmentation/none\")\n",
+ "writer.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6ffc7d1",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "There are multiple ways of seeing the tensorboard.\n",
+ "1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard.\n",
+ "2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension.\n",
+ "3. If you want to see tensorboard in a specific cell, use the following code.\n",
+ "```\n",
+ "notebook.list() # View open TensorBoard instances\n",
+ "notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port.\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed1448ff",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "## View augmentations using tensorboard.\n",
+ "\n",
+ "\n",
+ "Task 1.3\n",
+ "Turn on augmentation and view the batch in tensorboard."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "245ae4dd",
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "##########################\n",
+ "######## TODO ########\n",
+ "##########################\n",
+ "\n",
+ "# Write code to turn on augmentations, change batch sizes and log them to tensorboard.\n",
+ "# See how the training data changes as a function of these parameters.\n",
+ "# Remember to call `data_module.setup(\"fit\")` after changing the parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e14f4faf",
+ "metadata": {
+ "tags": [
+ "solution"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "##########################\n",
+ "######## Solution ########\n",
+ "##########################\n",
+ "\n",
+ "data_module.augment = True\n",
+ "data_module.batch_size = 21\n",
+ "data_module.split_ratio = 0.8\n",
+ "data_module.setup(\"fit\")\n",
+ "\n",
+ "train_dataloader = data_module.train_dataloader()\n",
+ "# Draw batches and write to tensorboard\n",
+ "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
+ "for i, batch in enumerate(train_dataloader):\n",
+ " log_batch_tensorboard(batch, i, writer, \"augmentation/some\")\n",
+ "writer.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "85743be9",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "## Construct a 2D U-Net for image translation.\n",
+ "See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.\n",
+ "We setup a fresh data module and instantiate the trainer class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "44fddf02",
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# The entire training loop is contained in this cell.\n",
+ "\n",
+ "GPU_ID = 0\n",
+ "BATCH_SIZE = 10\n",
+ "YX_PATCH_SIZE = (512, 512)\n",
+ "\n",
+ "\n",
+ "# Dictionary that specifies key parameters of the model.\n",
+ "phase2fluor_config = {\n",
+ " \"architecture\": \"2D\",\n",
+ " \"num_filters\": [24, 48, 96, 192, 384],\n",
+ " \"in_channels\": 1,\n",
+ " \"out_channels\": 2,\n",
+ " \"residual\": True,\n",
+ " \"dropout\": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.\n",
+ " \"task\": \"reg\", # reg = regression task.\n",
+ "}\n",
+ "\n",
+ "phase2fluor_model = VSUNet(\n",
+ " model_config=phase2fluor_config.copy(),\n",
+ " batch_size=BATCH_SIZE,\n",
+ " loss_function=torch.nn.functional.l1_loss,\n",
+ " schedule=\"WarmupCosine\",\n",
+ " log_num_samples=10, # Number of samples from each batch to log to tensorboard.\n",
+ " example_input_yx_shape=YX_PATCH_SIZE,\n",
+ ")\n",
+ "\n",
+ "# Reinitialize the data module.\n",
+ "phase2fluor_data = HCSDataModule(\n",
+ " data_path,\n",
+ " source_channel=\"Phase\",\n",
+ " target_channel=[\"Nuclei\", \"Membrane\"],\n",
+ " z_window_size=1,\n",
+ " split_ratio=0.8,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " num_workers=8,\n",
+ " architecture=\"2D\",\n",
+ " yx_patch_size=YX_PATCH_SIZE,\n",
+ " augment=True,\n",
+ ")\n",
+ "phase2fluor_data.setup(\"fit\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "292cb22a",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "
\n",
+ "Task 1.4\n",
+ "Setup the training for ~30 epochs\n",
+ "
\n",
+ "\n",
+ "Tips:\n",
+ "- Set ``default_root_dir`` to store the logs and checkpoints\n",
+ "in a specific directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b9cf215b",
+ "metadata": {
+ "lines_to_next_cell": 2,
+ "title": "Setup trainer and check for errors."
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# fast_dev_run runs a single batch of data through the model to check for errors.\n",
+ "trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n",
+ "\n",
+ "# trainer class takes the model and the data module as inputs.\n",
+ "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5f11d604",
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "GPU_ID = 0\n",
+ "n_samples = len(phase2fluor_data.train_dataset)\n",
+ "steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n",
+ "n_epochs = 30\n",
+ "\n",
+ "trainer = VSTrainer(\n",
+ " accelerator=\"gpu\",\n",
+ " devices=[GPU_ID],\n",
+ " max_epochs=n_epochs,\n",
+ " log_every_n_steps=steps_per_epoch // 2,\n",
+ " # log losses and image samples 2 times per epoch.\n",
+ " default_root_dir=Path(\n",
+ " log_dir, \"phase2fluor\"\n",
+ " ), # lightning trainer transparently saves logs and model checkpoints in this directory.\n",
+ ")\n",
+ "\n",
+ "# Log graph\n",
+ "trainer.logger.log_graph(phase2fluor_model, phase2fluor_data.train_dataset[0][\"source\"])\n",
+ "# Launch training.\n",
+ "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "157a989c",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "
\n",
+ "Checkpoint 1\n",
+ "\n",
+ "Now the training has started,\n",
+ "we can come back after a while and evaluate the performance!\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d68dab93",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "incorrectly_encoded_metadata": "id='1_fluor2phase'>",
+ "title": "
\n",
+ "Checkpoint 2\n",
+ "Please summarize hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)\n",
+ "\n",
+ "Now that you have trained two models, let's think about the following questions:\n",
+ "- What is the information content of each channel in the dataset?\n",
+ "- How would you use image translation models?\n",
+ "- What can you try to improve the performance of each model?\n",
+ "\n",
+ "\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7d0c4204",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "incorrectly_encoded_metadata": "id='3_tuning'>",
+ "title": "
+Task 1.3
+Turn on augmentation and view the batch in tensorboard.
+"""
+# %%
+##########################
+######## TODO ########
+##########################
+
+# Write code to turn on augmentations, change batch sizes and log them to tensorboard.
+# See how the training data changes as a function of these parameters.
+# Remember to call `data_module.setup("fit")` after changing the parameters.
+
+
+# %% tags=["solution"]
+##########################
+######## Solution ########
+##########################
+
+data_module.augment = True
+data_module.batch_size = 21
+data_module.split_ratio = 0.8
+data_module.setup("fit")
+
+train_dataloader = data_module.train_dataloader()
+# Draw batches and write to tensorboard
+writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
+for i, batch in enumerate(train_dataloader):
+ log_batch_tensorboard(batch, i, writer, "augmentation/some")
+writer.close()
+
+# %% [markdown]
+"""
+## Construct a 2D U-Net for image translation.
+See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.
+We setup a fresh data module and instantiate the trainer class.
+"""
+
+# %%
+
+# The entire training loop is contained in this cell.
+
+GPU_ID = 0
+BATCH_SIZE = 10
+YX_PATCH_SIZE = (512, 512)
+
+
+# Dictionary that specifies key parameters of the model.
+phase2fluor_config = {
+ "architecture": "2D",
+ "num_filters": [24, 48, 96, 192, 384],
+ "in_channels": 1,
+ "out_channels": 2,
+ "residual": True,
+ "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.
+ "task": "reg", # reg = regression task.
+}
+
+phase2fluor_model = VSUNet(
+ model_config=phase2fluor_config.copy(),
+ batch_size=BATCH_SIZE,
+ loss_function=torch.nn.functional.l1_loss,
+ schedule="WarmupCosine",
+ log_num_samples=10, # Number of samples from each batch to log to tensorboard.
+ example_input_yx_shape=YX_PATCH_SIZE,
+)
+
+# Reinitialize the data module.
+phase2fluor_data = HCSDataModule(
+ data_path,
+ source_channel="Phase",
+ target_channel=["Nuclei", "Membrane"],
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ architecture="2D",
+ yx_patch_size=YX_PATCH_SIZE,
+ augment=True,
+)
+phase2fluor_data.setup("fit")
+
+
+# %% [markdown]
+"""
+
+Task 1.4
+Setup the training for ~30 epochs
+
+
+Tips:
+- Set ``default_root_dir`` to store the logs and checkpoints
+in a specific directory.
+"""
+
+# %% Setup trainer and check for errors.
+
+# fast_dev_run runs a single batch of data through the model to check for errors.
+trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
+
+# trainer class takes the model and the data module as inputs.
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
+
+
+# %%
+
+GPU_ID = 0
+n_samples = len(phase2fluor_data.train_dataset)
+steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
+n_epochs = 30
+
+trainer = VSTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ max_epochs=n_epochs,
+ # log losses and image samples 2 times per epoch.
+ log_every_n_steps=steps_per_epoch // 2,
+ # lightning trainer transparently saves logs and model checkpoints in this directory
+ logger=TensorBoardLogger(
+ save_dir=log_dir,
+ name="phase2fluor",
+ log_graph=True,
+ ),
+)
+
+# Launch training.
+trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
+
+
+# %% [markdown]
+"""
+
+Checkpoint 1
+
+Now the training has started,
+we can come back after a while and evaluate the performance!
+
+"""
+
+# %% [markdown]
+"""
+# Part 2: Assess previous model, train fluorescence to phase contrast translation model.
+--------------------------------------------------
+
+Learning goals:
+- Visualize the previous model and training with tensorboard
+- Train fluorescence to phase contrast translation model
+- Compare the performance of the two models.
+
+"""
+
+# %%
+
+# PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging.
+
+# visualize graph.
+model_graph_phase2fluor = torchview.draw_graph(
+ phase2fluor_model,
+ phase2fluor_data.train_dataset[0]["source"],
+ depth=2, # adjust depth to zoom in.
+ device="cpu",
+)
+# Increase the depth to zoom in.
+model_graph_phase2fluor.visual_graph
+
+# %% tags = ["solution"]
+fluor2phase_data = HCSDataModule(
+ data_path,
+ source_channel="Nuclei",
+ target_channel="Phase",
+ z_window_size=1,
+ split_ratio=0.8,
+ batch_size=BATCH_SIZE,
+ num_workers=8,
+ architecture="2D",
+ yx_patch_size=YX_PATCH_SIZE,
+ augment=True,
+)
+fluor2phase_data.setup("fit")
+
+# Dictionary that specifies key parameters of the model.
+fluor2phase_config = {
+ "architecture": "2D",
+ "in_channels": 1,
+ "out_channels": 1,
+ "residual": True,
+ "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.
+ "task": "reg", # reg = regression task.
+ "num_filters": [24, 48, 96, 192, 384],
+}
+
+fluor2phase_model = VSUNet(
+ model_config=fluor2phase_config.copy(),
+ batch_size=BATCH_SIZE,
+ loss_function=torch.nn.functional.mse_loss,
+ schedule="WarmupCosine",
+ log_num_samples=10,
+ example_input_yx_shape=YX_PATCH_SIZE,
+)
+
+n_samples = len(fluor2phase_data.train_dataset)
+steps_per_epoch = n_samples // BATCH_SIZE
+n_epochs = 30
+
+trainer = VSTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ max_epochs=n_epochs,
+ log_every_n_steps=steps_per_epoch,
+ logger=TensorBoardLogger(
+ save_dir=log_dir,
+ name="fluor2phase",
+ log_graph=True,
+ ),
+)
+trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
+
+# %%
+# Visualize the graph of fluor2phase model.
+model_graph_fluor2phase = torchview.draw_graph(
+ phase2fluor_model,
+ phase2fluor_data.train_dataset[0]["source"],
+ depth=2, # adjust depth to zoom in.
+ device="cpu",
+)
+model_graph_fluor2phase.visual_graph
+
+# %% [markdown]
+"""
+We now look at some metrics of performance. Loss is a differentiable metric. But, several non-differentiable metrics are useful to assess the performance of the model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model:
+- [Coefficient of determination](https://en.wikipedia.org/wiki/Coefficient_of_determination): $R^2$
+- [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
+- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM):
+
+"""
+# %%
+
+# TODO: set following parameters, specifically path to checkpoint, and log the metrics.
+test_data_path = Path(
+ "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr"
+).expanduser()
+model_version = "phase2fluor"
+save_dir = Path(log_dir, "test")
+ckpt_path = Path(
+ r"/home/mehtas/data/04_image_translation/logs/phase2fluor/lightning_logs/version_0/checkpoints/epoch=29-step=720.ckpt"
+) # prefix the string with 'r' to avoid the need for escape characters.
+### END TODO
+
+test_data = HCSDataModule(
+ test_data_path,
+ source_channel="Phase",
+ target_channel=["Nuclei", "Membrane"],
+ z_window_size=1,
+ batch_size=1,
+ num_workers=8,
+ architecture="2D",
+)
+test_data.setup("test")
+trainer = VSTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ logger=CSVLogger(save_dir=save_dir, version=model_version),
+)
+trainer.test(
+ phase2fluor_model,
+ datamodule=test_data,
+ ckpt_path=ckpt_path,
+)
+# read metrics and plot
+metrics = pd.read_csv(Path(save_dir, "lightning_logs", model_version, "metrics.csv"))
+metrics.boxplot(
+ column=[
+ "test_metrics/r2_step",
+ "test_metrics/pearson_step",
+ "test_metrics/SSIM_step",
+ ],
+ rot=30,
+)
+# %% [markdown]
+"""
+
+Checkpoint 2
+Please summarize hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
+
+Now that you have trained two models, let's think about the following questions:
+- What is the information content of each channel in the dataset?
+- How would you use image translation models?
+- What can you try to improve the performance of each model?
+
+
+
+"""
+
+# %% [markdown]
+"""
+# Part 3: Tune the models.
+--------------------------------------------------
+
+Learning goals:
+
+- Tweak model hyperparameters, such as number of filters at each depth.
+- Adjust learning rate to improve performance.
+"""
+
+# %%
+# %%
+##########################
+######## TODO ########
+##########################
+
+# Choose a model you want to train (phase2fluor or fluor2phase).
+# Create a config to double the number of filters at each stage.
+# Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop.
+
+
+# %% tags = ["solution"]
+
+##########################
+######## Solution ########
+##########################
+
+phase2fluor_wider_config = {
+ "architecture": "2D",
+ # double the number of filters at each stage
+ "num_filters": [48, 96, 192, 384, 768],
+ "in_channels": 1,
+ "out_channels": 2,
+ "residual": True,
+ "dropout": 0.1,
+ "task": "reg",
+}
+
+phase2fluor_wider_model = VSUNet(
+ model_config=phase2fluor_wider_config.copy(),
+ batch_size=BATCH_SIZE,
+ loss_function=torch.nn.functional.l1_loss,
+ schedule="WarmupCosine",
+ log_num_samples=10,
+ example_input_yx_shape=YX_PATCH_SIZE,
+)
+
+
+trainer = VSTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ max_epochs=n_epochs,
+ log_every_n_steps=steps_per_epoch,
+ logger=TensorBoardLogger(
+ save_dir=log_dir,
+ name="phase2fluor",
+ version="wider",
+ log_graph=True,
+ ),
+ fast_dev_run=True,
+) # Set fast_dev_run to False to train the model.
+trainer.fit(phase2fluor_wider_model, datamodule=phase2fluor_data)
+
+# %%
+##########################
+######## TODO ########
+##########################
+
+# Choose a model you want to train (phase2fluor or fluor2phase).
+# Train it with lower learning rate to see how the performance changes.
+
+
+# %% tags = ["solution"]
+
+##########################
+######## Solution ########
+##########################
+
+phase2fluor_slow_model = VSUNet(
+ model_config=phase2fluor_config.copy(),
+ batch_size=BATCH_SIZE,
+ loss_function=torch.nn.functional.l1_loss,
+ # lower learning rate by 5 times
+ lr=2e-4,
+ schedule="WarmupCosine",
+ log_num_samples=10,
+ example_input_yx_shape=YX_PATCH_SIZE,
+)
+
+trainer = VSTrainer(
+ accelerator="gpu",
+ devices=[GPU_ID],
+ max_epochs=n_epochs,
+ log_every_n_steps=steps_per_epoch,
+ logger=TensorBoardLogger(
+ save_dir=log_dir,
+ name="phase2fluor",
+ version="low_lr",
+ log_graph=True,
+ ),
+ fast_dev_run=True,
+)
+trainer.fit(phase2fluor_slow_model, datamodule=phase2fluor_data)
+
+
+# %% [markdown]
+"""
+
+Checkpoint 3
+
+Congratulations! You have trained several image translation models now!
+Please document hyperparameters, snapshots of predictioons on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
+
+"""
diff --git a/pyproject.toml b/pyproject.toml
index c0f73bd2..485f823f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
"ipykernel", # used by demo_dlmbl
"graphviz", # used by demo_dlmbl
"torchview", # used by demo_dlmbl
-]
+ ]
dynamic = ["version"]
[project.optional-dependencies]
@@ -33,6 +33,7 @@ metrics = [
"scikit-learn>=1.1.3",
"scipy>=1.8.0",
"torchmetrics[detection]>=1.0.0",
+ "ptflops>=0.7",
]
dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"]
diff --git a/tests/evaluation/test_evaluation_metrics.py b/tests/evaluation/test_evaluation_metrics.py
index 30ef15b3..78a7233f 100644
--- a/tests/evaluation/test_evaluation_metrics.py
+++ b/tests/evaluation/test_evaluation_metrics.py
@@ -2,6 +2,7 @@
import pytest
import torch
from skimage import data, measure
+from skimage.util import img_as_float
from viscy.evaluation.evaluation_metrics import (
POD_metric,
@@ -9,6 +10,8 @@
labels_to_detection,
labels_to_masks,
mean_average_precision,
+ ms_ssim_25d,
+ ssim_25d,
)
@@ -101,3 +104,37 @@ def test_mean_average_precision(labels_tensor: torch.ShortTensor):
assert coco_metrics["map"] == 1
assert _is_within_unit(coco_metrics["mar_1"])
assert coco_metrics["mar_10"] == 1
+
+
+def test_ssim_25d():
+ img = torch.from_numpy(img_as_float(data.camera()[np.newaxis, np.newaxis]))
+ img = torch.stack([img] * 5, dim=2)
+ # comparing to self should be almost 1
+ ssim_self = ssim_25d(img, img)
+ assert torch.allclose(ssim_self, torch.tensor(1.0))
+ # add $\mathcal{U}(0, 1)$ additive noise to mimic prediction
+ # should still be positive correlation
+ img_pred = img + torch.rand(img.shape) - 0.5
+ ssim_pred = ssim_25d(img_pred, img)
+ assert _is_within_unit(ssim_pred)
+ # inverted should be negative
+ img_inv = 1 - img
+ ssim_inv = ssim_25d(img_inv, img)
+ assert _is_within_unit(-ssim_inv)
+
+
+def test_ms_ssim_25d():
+ img = torch.from_numpy(img_as_float(data.camera()[np.newaxis, np.newaxis]))
+ img = torch.stack([img] * 5, dim=2)
+ # comparing to self should be almost 1
+ ssim_self = ms_ssim_25d(img, img)
+ assert torch.allclose(ssim_self, torch.tensor(1.0))
+ # add $\mathcal{U}(0, 1)$ additive noise to mimic prediction
+ # should still be positive correlation
+ noise = torch.rand(img.shape)
+ img_pred = img + noise - 0.5
+ ssim_pred = ms_ssim_25d(img_pred, img)
+ assert _is_within_unit(ssim_pred)
+ # Negative correlation should be zero when clamped
+ ssim_inv = ms_ssim_25d(1 - img, img, clamp=True)
+ assert torch.allclose(ssim_inv, torch.tensor(0.0))
\ No newline at end of file
diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py
index ef1cd48b..37179d13 100644
--- a/viscy/cli/cli.py
+++ b/viscy/cli/cli.py
@@ -21,9 +21,8 @@ def subcommands() -> dict[str, set[str]]:
return subcommands
def add_arguments_to_parser(self, parser):
- parser.link_arguments("data.batch_size", "model.batch_size")
parser.link_arguments("data.yx_patch_size", "model.example_input_yx_shape")
- parser.link_arguments("model.model_config.architecture", "data.architecture")
+ parser.link_arguments("model.architecture", "data.architecture")
parser.set_defaults(
{
"trainer.logger": lazy_instance(
diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py
index deed3394..b1f8b4ae 100644
--- a/viscy/evaluation/evaluation_metrics.py
+++ b/viscy/evaluation/evaluation_metrics.py
@@ -1,7 +1,12 @@
"""Metrics for model evaluation"""
+from typing import Sequence, Union
+from warnings import warn
+
import numpy as np
import torch
+import torch.nn.functional as F
from lapsolver import solve_dense
+from monai.metrics.regression import compute_ssim_and_cs
from skimage.measure import label, regionprops
from torchmetrics.detection import MeanAveragePrecision
from torchvision.ops import masks_to_boxes
@@ -169,3 +174,89 @@ def mean_average_precision(
[labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
)
return map_metric.compute()
+
+
+def ssim_25d(
+ preds: torch.Tensor,
+ target: torch.Tensor,
+ in_plane_window_size: tuple[int, int] = (11, 11),
+ return_contrast_sensitivity: bool = False,
+) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ """Multi-scale SSIM loss function for 2.5D volumes (3D with small depth).
+ Uses uniform kernel (windows), depth-dimension window size equals to depth size.
+
+ :param torch.Tensor preds: predicted batch (B, C, D, W, H)
+ :param torch.Tensor target: target batch
+ :param tuple[int, int] in_plane_window_size: kernel width and height,
+ by default (11, 11)
+ :param bool return_contrast_sensitivity: whether to return contrast sensitivity
+ :return torch.Tensor: SSIM for the batch
+ :return Optional[torch.Tensor]: contrast sensitivity
+ """
+ if preds.ndim != 5:
+ raise ValueError(
+ f"Input shape must be (B, C, D, W, H), got input shape {preds.shape}"
+ )
+ depth = preds.shape[2]
+ if depth > 15:
+ warn(f"Input depth {depth} is potentially too large for 2.5D SSIM.")
+ ssim_img, cs_img = compute_ssim_and_cs(
+ preds,
+ target,
+ 3,
+ kernel_sigma=None,
+ kernel_size=(depth, *in_plane_window_size),
+ data_range=target.max(),
+ kernel_type="uniform",
+ )
+ # aggregate to one scalar per batch
+ ssim = ssim_img.view(ssim_img.shape[0], -1).mean(1)
+ if return_contrast_sensitivity:
+ return ssim, cs_img.view(cs_img.shape[0], -1).mean(1)
+ else:
+ return ssim
+
+
+def ms_ssim_25d(
+ preds: torch.Tensor,
+ target: torch.Tensor,
+ in_plane_window_size: tuple[int, int] = (11, 11),
+ clamp: bool = False,
+ betas: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
+) -> torch.Tensor:
+ """Multi-scale SSIM for 2.5D volumes (3D with small depth).
+ Uses uniform kernel (windows), depth-dimension window size equals to depth size.
+ Depth dimension is not downsampled.
+
+ Adapted from torchmetrics@99d6d9d6ac4eb1b3398241df558604e70521e6b0
+ Original license:
+ Copyright The Lightning team, http://www.apache.org/licenses/LICENSE-2.0
+
+ :param torch.Tensor preds: predicted images
+ :param torch.Tensor target: target images
+ :param tuple[int, int] in_plane_window_size: kernel width and height,
+ defaults to (11, 11)
+ :param bool clamp: clamp with ReLU for training stability when used in loss,
+ defaults to False
+ :param Sequence[float] betas: exponents of each resolution,
+ defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+ :return torch.Tensor: multi-scale SSIM
+ """
+ mcs_list = []
+ for _ in range(len(betas)):
+ ssim, contrast_sensitivity = ssim_25d(
+ preds, target, in_plane_window_size, return_contrast_sensitivity=True
+ )
+ if clamp:
+ contrast_sensitivity = torch.relu(contrast_sensitivity)
+ mcs_list.append(contrast_sensitivity)
+ # do not downsample along depth
+ preds = F.avg_pool3d(preds, (1, 2, 2))
+ target = F.avg_pool3d(target, (1, 2, 2))
+ if clamp:
+ ssim = torch.relu(ssim)
+ mcs_list[-1] = ssim
+ mcs_stack = torch.stack(mcs_list)
+ betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
+ mcs_weighted = mcs_stack**betas
+ return torch.prod(mcs_weighted, axis=0).mean()
diff --git a/viscy/light/data.py b/viscy/light/data.py
index 870cdede..41eb62f6 100644
--- a/viscy/light/data.py
+++ b/viscy/light/data.py
@@ -13,6 +13,7 @@
from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr
from lightning.pytorch import LightningDataModule
from monai.data import set_track_meta
+from monai.data.utils import collate_meta_tensor
from monai.transforms import (
CenterSpatialCropd,
Compose,
@@ -59,9 +60,18 @@ class ChannelMap(TypedDict, total=False):
class Sample(TypedDict, total=False):
index: tuple[str, int, int]
# optional
- source: torch.Tensor
- target: torch.Tensor
- labels: torch.Tensor
+ source: Union[torch.Tensor, Sequence[torch.Tensor]]
+ target: Union[torch.Tensor, Sequence[torch.Tensor]]
+ labels: Union[torch.Tensor, Sequence[torch.Tensor]]
+
+
+def _collate_samples(batch: Sequence[Sample]) -> Sample:
+ elemment = batch[0]
+ collated = {}
+ for key in elemment.keys():
+ data: list[list[torch.Tensor]] = [sample[key] for sample in batch]
+ collated[key] = collate_meta_tensor([im for imgs in data for im in imgs])
+ return collated
class NormalizeSampled(MapTransform, InvertibleTransform):
@@ -178,9 +188,15 @@ def __len__(self) -> int:
return self._max_window
def _stack_channels(
- self, sample_images: dict[str, torch.Tensor], key: str
+ self, sample_images: list[dict[str, torch.Tensor]], key: str
) -> torch.Tensor:
- return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
+ if not isinstance(sample_images, list):
+ return torch.stack([sample_images[ch][0] for ch in self.channels[key]])
+ # training time
+ return [
+ torch.stack([im[ch][0] for ch in self.channels[key]])
+ for im in sample_images
+ ]
def __getitem__(self, index: int) -> Sample:
img, tz = self._find_window(index)
@@ -198,8 +214,8 @@ def __getitem__(self, index: int) -> Sample:
sample_images["weight"] = sample_images[self.channels["target"][0]]
if self.transform:
sample_images = self.transform(sample_images)
- if isinstance(sample_images, list):
- sample_images = sample_images[0]
+ # if isinstance(sample_images, list):
+ # sample_images = sample_images[0]
if "weight" in sample_images:
del sample_images["weight"]
sample = {
@@ -275,7 +291,7 @@ class HCSDataModule(LightningDataModule):
by default 0.8
:param int batch_size: batch size, defaults to 16
:param int num_workers: number of data-loading workers, defaults to 8
- :param Literal["2.5D", "2D", "3D"] architecture: U-Net architecture,
+ :param Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] architecture: U-Net architecture,
defaults to "2.5D"
:param tuple[int, int] yx_patch_size: patch size in (Y, X),
defaults to (256, 256)
@@ -292,6 +308,8 @@ class HCSDataModule(LightningDataModule):
:param float train_noise_std: Upper bound of the standard deviation
of the Gaussian noise added to source images during training,
defaults to 0.0
+ :param int train_patches_per_stack: number of patches to sample
+ from each stack during training, defaults to 1
"""
def __init__(
@@ -303,7 +321,7 @@ def __init__(
split_ratio: float = 0.8,
batch_size: int = 16,
num_workers: int = 8,
- architecture: Literal["2.5D", "2D", "3D"] = "2.5D",
+ architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] = "2.5D",
yx_patch_size: tuple[int, int] = (256, 256),
augment: bool = True,
caching: bool = False,
@@ -311,6 +329,7 @@ def __init__(
ground_truth_masks: str = None,
train_z_scale_range: tuple[float, float] = [0, 0],
train_noise_std: float = 0.0,
+ train_patches_per_stack: int = 1,
):
super().__init__()
self.data_path = data_path
@@ -318,7 +337,7 @@ def __init__(
self.target_channel = _ensure_channel_list(target_channel)
self.batch_size = batch_size
self.num_workers = num_workers
- self.target_2d = False if architecture == "3D" else True
+ self.target_2d = False if architecture in ["2.2D", "3D"] else True
self.z_window_size = z_window_size
self.split_ratio = split_ratio
self.yx_patch_size = yx_patch_size
@@ -333,6 +352,12 @@ def __init__(
raise ValueError(f"Invalid scaling range: {train_z_scale_range}")
self.train_z_scale_range = train_z_scale_range
self.train_noise_std = train_noise_std
+ if batch_size % train_patches_per_stack != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of patches per stack. "
+ f"Got {batch_size} and {train_patches_per_stack}."
+ )
+ self.train_patches_per_stack = train_patches_per_stack
def prepare_data(self):
if not self.caching:
@@ -414,8 +439,13 @@ def _setup_fit(self, dataset_settings: dict):
num_train_fovs = int(len(positions) * self.split_ratio)
# training set needs to sample more Z range for augmentation
train_dataset_settings = dataset_settings.copy()
- expanded_z = math.ceil(self.z_window_size * (1 + self.train_z_scale_range[1]))
- train_dataset_settings["z_window_size"] = max(1, expanded_z - expanded_z % 2)
+ z_scale_low, z_scale_high = self.train_z_scale_range
+ if z_scale_high <= 0.0:
+ expanded_z = self.z_window_size
+ else:
+ expanded_z = math.ceil(self.z_window_size * (1 + z_scale_high))
+ expanded_z -= expanded_z % 2
+ train_dataset_settings["z_window_size"] = expanded_z
# train/val split
self.train_dataset = SlidingWindowDataset(
positions[:num_train_fovs],
@@ -430,12 +460,19 @@ def _setup_test(self, dataset_settings):
if self.batch_size != 1:
logging.warning(f"Ignoring batch size {self.batch_size} in test stage.")
plate, normalize_transform = self._setup_eval(dataset_settings)
- self.test_dataset = MaskTestDataset(
- [p for _, p in plate.positions()],
- transform=normalize_transform,
- ground_truth_masks=self.ground_truth_masks,
- **dataset_settings,
- )
+ if self.ground_truth_masks:
+ self.test_dataset = MaskTestDataset(
+ [p for _, p in plate.positions()],
+ transform=normalize_transform,
+ ground_truth_masks=self.ground_truth_masks,
+ **dataset_settings,
+ )
+ else:
+ self.test_dataset = SlidingWindowDataset(
+ [p for _, p in plate.positions()],
+ transform=normalize_transform,
+ **dataset_settings,
+ )
def _setup_predict(self, dataset_settings: dict):
# track metadata for inverting transform
@@ -474,10 +511,11 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample
def train_dataloader(self):
return DataLoader(
self.train_dataset,
- batch_size=self.batch_size,
+ batch_size=self.batch_size // self.train_patches_per_stack,
num_workers=self.num_workers,
shuffle=True,
persistent_workers=bool(self.num_workers),
+ collate_fn=_collate_samples,
)
def val_dataloader(self):
@@ -523,7 +561,7 @@ def _train_transform(self) -> list[Callable]:
keys=self.source_channel + self.target_channel,
w_key="weight",
spatial_size=(-1, self.yx_patch_size[0] * 2, self.yx_patch_size[1] * 2),
- num_samples=1,
+ num_samples=self.train_patches_per_stack,
)
]
if self.augment:
diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index 074fee98..3f6c8ce1 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -1,10 +1,9 @@
import logging
import os
-from typing import Callable, Literal, Sequence
+from typing import Literal, Sequence, Union
import numpy as np
import torch
-import torch.nn.functional as F
from imageio import imwrite
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized
@@ -12,6 +11,8 @@
from monai.optimizers import WarmupCosineSchedule
from monai.transforms import DivisiblePad
from skimage.exposure import rescale_intensity
+from torch import nn
+from torch.nn import functional as F
from torch.onnx import OperatorExportTypes
from torch.optim.lr_scheduler import ConstantLR
from torchmetrics.functional import (
@@ -26,7 +27,7 @@
structural_similarity_index_measure,
)
-from viscy.evaluation.evaluation_metrics import mean_average_precision
+from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d
from viscy.light.data import Sample
from viscy.unet.networks.Unet2D import Unet2d
from viscy.unet.networks.Unet21D import Unet21d
@@ -41,10 +42,49 @@
_UNET_ARCHITECTURE = {
"2D": Unet2d,
"2.1D": Unet21d,
+ # same class with out_stack_depth > 1
+ "2.2D": Unet21d,
"2.5D": Unet25d,
}
+class MixedLoss(nn.Module):
+ """Mixed reconstruction loss.
+ Adapted from Zhao et al, https://arxiv.org/pdf/1511.08861.pdf
+ Reduces to simple distances if only one weight is non-zero.
+
+ :param float l1_alpha: L1 loss weight, defaults to 0.5
+ :param float l2_alpha: L2 loss weight, defaults to 0.0
+ :param float ms_dssim_alpha: MS-DSSIM weight, defaults to 0.5
+ """
+
+ def __init__(
+ self, l1_alpha: float = 0.5, l2_alpha: float = 0.0, ms_dssim_alpha: float = 0.5
+ ):
+ super().__init__()
+ if not any([l1_alpha, l2_alpha, ms_dssim_alpha]):
+ raise ValueError("Loss term weights cannot be all zero!")
+ self.l1_alpha = l1_alpha
+ self.l2_alpha = l2_alpha
+ self.ms_dssim_alpha = ms_dssim_alpha
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, preds, target):
+ loss = 0
+ if self.l1_alpha:
+ # the gaussian in the reference is not used
+ # because the SSIM here uses a uniform window
+ loss += F.l1_loss(preds, target) * self.l1_alpha
+ if self.l2_alpha:
+ loss += F.mse_loss(preds, target) * self.l2_alpha
+ if self.ms_dssim_alpha:
+ ms_ssim = ms_ssim_25d(preds, target, clamp=True)
+ # the 1/2 factor in the original DSSIM is not used
+ # since the MS-SSIM here is stabilized with ReLU
+ loss += (1 - ms_ssim) * self.ms_dssim_alpha
+ return loss
+
+
class VSTrainer(Trainer):
def export(
self,
@@ -104,14 +144,17 @@ class VSUNet(LightningModule):
:param dict model_config: model config,
defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D`
- :param int batch_size: batch size, defaults to 16
- :param Callable[[torch.Tensor, torch.Tensor], torch.Tensor] loss_function:
- loss function in training/validation, defaults to L2 (mean squared error)
+ :param Union[nn.Module, MixedLoss] loss_function:
+ loss function in training/validation,
+ if a dictionary, should specify weights of each term
+ ('l1_alpha', 'l2_alpha', 'ssim_alpha')
+ defaults to L2 (mean squared error)
:param float lr: learning rate in training, defaults to 1e-3
:param Literal['WarmupCosine', 'Constant'] schedule:
learning rate scheduler, defaults to "Constant"
:param int log_num_samples:
- number of image samples to log each training/validation epoch, defaults to 8
+ number of image samples to log each training/validation epoch,
+ has to be smaller than batch size, defaults to 8
:param Sequence[int] example_input_yx_shape:
XY shape of the example input for network graph tracing, defaults to (256, 256)
:param str test_cellpose_model_path:
@@ -126,9 +169,9 @@ class VSUNet(LightningModule):
def __init__(
self,
+ architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"],
model_config: dict = {},
- batch_size: int = 16,
- loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
+ loss_function: Union[nn.Module, MixedLoss] = None,
lr: float = 1e-3,
schedule: Literal["WarmupCosine", "Constant"] = "Constant",
log_num_samples: int = 8,
@@ -138,28 +181,30 @@ def __init__(
test_evaluate_cellpose: bool = False,
) -> None:
super().__init__()
- arch = model_config.pop("architecture")
- net_class = _UNET_ARCHITECTURE.get(arch)
- if not arch:
- raise ValueError(f"Architecture {arch} not in {_UNET_ARCHITECTURE.keys()}")
+ net_class = _UNET_ARCHITECTURE.get(architecture)
+ if not net_class:
+ raise ValueError(
+ f"Architecture {architecture} not in {_UNET_ARCHITECTURE.keys()}"
+ )
+ if architecture == "2.2D":
+ model_config["out_stack_depth"] = model_config["in_stack_depth"]
self.model = net_class(**model_config)
# TODO: handle num_outputs in metrics
# self.out_channels = self.model.terminal_block.out_filters
- self.batch_size = batch_size
- self.loss_function = loss_function if loss_function else F.mse_loss
+ self.loss_function = loss_function if loss_function else nn.MSELoss()
self.lr = lr
self.schedule = schedule
self.log_num_samples = log_num_samples
self.training_step_outputs = []
self.validation_step_outputs = []
# required to log the graph
- if arch == "2D":
+ if architecture == "2D":
example_depth = 1
else:
example_depth = model_config.get("in_stack_depth") or 5
self.example_input_array = torch.rand(
1,
- 1,
+ model_config.get("in_channels") or 1,
example_depth,
*example_input_yx_shape,
)
@@ -182,11 +227,10 @@ def training_step(self, batch: Sample, batch_idx: int):
on_epoch=True,
prog_bar=True,
logger=True,
- batch_size=self.batch_size,
sync_dist=True,
)
- if batch_idx < self.log_num_samples:
- self.training_step_outputs.append(
+ if batch_idx == 0:
+ self.training_step_outputs.extend(
self._detach_sample((source, target, pred))
)
return loss
@@ -196,9 +240,9 @@ def validation_step(self, batch: Sample, batch_idx: int):
target = batch["target"]
pred = self.forward(source)
loss = self.loss_function(pred, target)
- self.log("loss/validate", loss, batch_size=self.batch_size, sync_dist=True)
- if batch_idx < self.log_num_samples:
- self.validation_step_outputs.append(
+ self.log("loss/validate", loss, sync_dist=True)
+ if batch_idx == 0:
+ self.validation_step_outputs.extend(
self._detach_sample((source, target, pred))
)
@@ -309,11 +353,17 @@ def on_validation_epoch_end(self):
def on_test_start(self):
"""Load CellPose model for segmentation."""
if CellposeModel is None:
- raise ImportError(
+ # raise ImportError(
+ # "CellPose not installed. "
+ # "Please install the metrics dependency with "
+ # '`pip install viscy".[metrics]"`'
+ # )
+ logging.warning(
"CellPose not installed. "
"Please install the metrics dependency with "
- '`pip install viscy".[metrics]"`'
+ '`pip install viscy"[metrics]"`'
)
+
if self.test_cellpose_model_path is not None:
self.cellpose_model = CellposeModel(
model_type=self.test_cellpose_model_path, device=self.device
@@ -341,9 +391,12 @@ def configure_optimizers(self):
)
return [optimizer], [scheduler]
- @staticmethod
- def _detach_sample(imgs: Sequence[torch.Tensor]):
- return [np.squeeze(img[0].detach().cpu().numpy().max(axis=(1))) for img in imgs]
+ def _detach_sample(self, imgs: Sequence[torch.Tensor]):
+ num_samples = min(imgs[0].shape[0], self.log_num_samples)
+ return [
+ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
+ for i in range(num_samples)
+ ]
def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
images_grid = []
diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 45245387..e94db68f 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -2,11 +2,12 @@
import os
from typing import Literal, Optional, Sequence
+import numpy as np
import torch
from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
-from numpy.typing import DTypeLike
+from numpy.typing import DTypeLike, NDArray
from viscy.light.data import HCSDataModule, Sample
@@ -14,20 +15,35 @@
_logger = logging.getLogger("lightning.pytorch")
-def _resize_image(image: ImageArray, t_index: int, z_index: int):
- """Resize image array if incoming T and Z index is not within bounds."""
- if image.shape[0] <= t_index or image.shape[2] <= z_index:
+def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None:
+ """Resize image array if incoming (1, C, Z, Y, X) stack is not within bounds."""
+ if image.shape[0] <= t_index or image.shape[2] < z_slice.stop:
_logger.debug(
- f"Resizing image '{image.name}' {image.shape} for T={t_index}, Z={z_index}."
+ f"Resizing image '{image.name}' {image.shape} for "
+ f"T={t_index}, Z-sclice={z_slice}."
)
image.resize(
max(t_index + 1, image.shape[0]),
image.channels,
- max(z_index + 1, image.shape[2]),
+ max(z_slice.stop, image.shape[2]),
*image.shape[-2:],
)
+def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None:
+ if z_slice.start == 0:
+ return new_stack
+ depth = z_slice.stop - z_slice.start
+ # relevant predictions to integrate
+ samples = min(z_slice.start + 1, depth)
+ factors = []
+ for i in reversed(list(range(depth))):
+ factors.append(min(i + 1, samples))
+ _logger.debug(f"Blending with factors {factors}.")
+ factors = np.array(factors)[np.newaxis :, np.newaxis, np.newaxis]
+ return old_stack * (factors - 1) / factors + new_stack / factors
+
+
class HCSPredictionWriter(BasePredictionWriter):
"""Callback to store virtual staining predictions as HCS OME-Zarr.
@@ -109,10 +125,11 @@ def write_sample(
z_index = int(z_index)
# account for lost slices in 2.5D
z_index += self.z_padding
+ z_slice = slice(z_index, z_index + sample_prediction.shape[-3])
image = self._create_image(
img_name, sample_prediction.shape, sample_prediction.dtype
)
- _resize_image(image, t_index, z_index)
+ _resize_image(image, t_index, z_slice)
if self.write_input:
source_stack = batch["source"][sample_index].cpu()
center_slice_index = source_stack.shape[-3] // 2
@@ -123,8 +140,11 @@ def write_sample(
image[t_index, self.target_index, z_index] = batch["target"][
sample_index
][:, center_slice_index].cpu()
- # write C1YX
- image.oindex[t_index, self.prediction_index, z_index] = sample_prediction[:, 0]
+ # write CZYX
+ if self.z_padding == 0 and sample_prediction.shape[-3] > 1:
+ old_stack = image.oindex[t_index, self.prediction_index, z_slice]
+ sample_prediction = _blend_in(old_stack, sample_prediction, z_slice)
+ image.oindex[t_index, self.prediction_index, z_slice] = sample_prediction
def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
if img_name in self.plate.zgroup:
diff --git a/viscy/scripts/count_flops.py b/viscy/scripts/count_flops.py
new file mode 100644
index 00000000..c96f5a6b
--- /dev/null
+++ b/viscy/scripts/count_flops.py
@@ -0,0 +1,26 @@
+# %%
+import torch
+from ptflops import get_model_complexity_info
+
+from viscy.light.engine import VSUNet
+
+# %%
+model = VSUNet(
+ architecture="2.2D",
+ model_config={
+ "in_channels": 1,
+ "out_channels": 2,
+ "in_stack_depth": 5,
+ "backbone": "convnextv2_tiny",
+ "stem_kernel_size": (5, 4, 4),
+ },
+)
+
+# %%
+with torch.cuda.device(0):
+ macs, params = get_model_complexity_info(
+ model,
+ (1, 5, 2048, 2048), # print_per_layer_stat=False
+ )
+print(macs, params)
+# %%
diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py
index 06c4d16a..94b56f1b 100644
--- a/viscy/scripts/network_diagram.py
+++ b/viscy/scripts/network_diagram.py
@@ -3,32 +3,96 @@
from viscy.light.engine import VSUNet
-# %%
+# %% 2D UNet
model = VSUNet(
+ architecture="2D",
+ model_config={
+ "in_channels": 2,
+ "out_channels": 1,
+ },
+)
+
+model_graph = draw_graph(
+ model,
+ model.example_input_array,
+ graph_name="2D UNet",
+ roll=True,
+ depth=2,
+ # graph_dir="LR",
+ # save_graph=True,
+)
+
+graph2d = model_graph.visual_graph
+graph2d
+
+# %% 2.5D UNet
+model = VSUNet(
+ architecture="2.5D",
model_config={
- "architecture": "2.1D",
"in_channels": 1,
- "out_channels": 2,
+ "out_channels": 3,
"in_stack_depth": 9,
- "backbone": "convnextv2_femto",
- "stem_kernel_size": (3, 4, 4),
},
- batch_size=32,
)
+
+model_graph = draw_graph(
+ model,
+ model.example_input_array,
+ graph_name="2.5D UNet",
+ roll=True,
+ depth=2,
+)
+
+graph25d = model_graph.visual_graph
+graph25d
+
# %%
+# 2.1D UNet without upsampling in Z.
+model = VSUNet(
+ architecture="2.1D",
+ model_config={
+ "in_channels": 2,
+ "out_channels": 1,
+ "in_stack_depth": 9,
+ "backbone": "convnextv2_tiny",
+ "stem_kernel_size": (3, 1, 1),
+ "decoder_mode": "pixelshuffle",
+ },
+)
+
model_graph = draw_graph(
model,
model.example_input_array,
- # model.example_input_array,
graph_name="2.1D UNet",
roll=True,
- depth=2,
- # graph_dir="LR",
- directory="/hpc/projects/comp.micro/virtual_staining/models/HEK_phase_to_nuc_mem/",
- # save_graph=True,
+ depth=3,
)
-graph = model_graph.visual_graph
-graph
+graph21d = model_graph.visual_graph
+graph21d
# %%
-model_graph.visual_graph.render(format="svg")
+# 2.1D UNet with upsampling in Z.
+model = VSUNet(
+ architecture="2.2D",
+ model_config={
+ "in_channels": 2,
+ "out_channels": 1,
+ "in_stack_depth": 9,
+ "backbone": "convnextv2_tiny",
+ "stem_kernel_size": (3, 2, 2),
+ "decoder_mode": "deconv",
+ },
+)
+
+model_graph = draw_graph(
+ model,
+ model.example_input_array,
+ graph_name="2.2D UNet",
+ roll=True,
+ depth=3,
+)
+
+graph22d = model_graph.visual_graph
+graph22d
+# %% If you want to save the graphs as SVG files:
+# model_graph.visual_graph.render(format="svg")
diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py
index a971eff0..4cdbf4e8 100644
--- a/viscy/unet/networks/Unet21D.py
+++ b/viscy/unet/networks/Unet21D.py
@@ -1,12 +1,25 @@
-from typing import Sequence, Union
+from typing import Literal, Sequence
import timm
import torch
-from monai.networks.blocks import ResidualUnit, UnetrUpBlock
+from monai.networks.blocks import ResidualUnit, UpSample
from monai.networks.blocks.dynunet_block import get_conv_layer
from torch import nn
+def _get_convnext_stage(in_channels: int, out_channels: int, depth: int):
+ return timm.models.convnext.ConvNeXtStage(
+ in_chs=in_channels,
+ out_chs=out_channels,
+ stride=1,
+ depth=depth,
+ ls_init_value=None,
+ use_grn=True,
+ norm_layer=timm.layers.LayerNorm2d,
+ norm_layer_cl=timm.layers.LayerNorm,
+ )
+
+
class Conv21dStem(nn.Module):
"""Stem for 2.1D networks."""
@@ -34,57 +47,137 @@ def forward(self, x: torch.Tensor):
return x.reshape(b, c * d, h, w)
+class Unet2dUpStage(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: int,
+ mode: Literal["deconv", "pixelshuffle"],
+ conv_blocks: int,
+ norm_name: str,
+ ) -> None:
+ super().__init__()
+ spatial_dims = 2
+ if mode == "deconv":
+ self.upsample = (
+ get_conv_layer(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=scale_factor,
+ kernel_size=scale_factor,
+ norm=norm_name,
+ is_transposed=True,
+ ),
+ )
+ self.conv = nn.Sequential(
+ ResidualUnit(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ norm=norm_name,
+ ),
+ nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)),
+ )
+ elif mode == "pixelshuffle":
+ self.upsample = UpSample(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=scale_factor,
+ mode=mode,
+ pre_conv="default",
+ apply_pad_pool=True,
+ )
+ self.conv = _get_convnext_stage(
+ out_channels + out_channels, out_channels, conv_blocks
+ )
+ self.conv.apply(timm.models.convnext._init_weights)
+
+ def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+ """
+ :param torch.Tensor inp: Low resolution features
+ :param torch.Tensor skip: High resolution skip connection features
+ :return torch.Tensor: High resolution features
+ """
+ inp = self.upsample(inp)
+ inp = torch.cat([inp, skip], dim=1)
+ return self.conv(inp)
+
+
+class PixelToVoxelHead(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ out_stack_depth: int,
+ ) -> None:
+ super().__init__()
+ self.norm = timm.layers.LayerNorm2d(num_channels=in_channels)
+ self.gelu = nn.GELU()
+ self.conv = nn.Conv3d(
+ in_channels // out_stack_depth,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ self.out_stack_depth = out_stack_depth
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = self.gelu(x)
+ b, c, h, w = x.shape
+ x = x.reshape((b, c // self.out_stack_depth, self.out_stack_depth, h, w))
+ x = self.conv(x)
+ return x
+
+
+class UnsqueezeHead(nn.Module):
+ """Unsqueeze 2D (B, C, H, W) feature map to 3D (B, C, 1, H, W) output"""
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.unsqueeze(2)
+ return x
+
+
class Unet2dDecoder(nn.Module):
def __init__(
self,
num_channels: list[int],
out_channels: int,
- res_block: bool,
norm_name: str,
- kernel_size: Union[int, tuple[int, int]],
- last_kernel_size: Union[int, tuple[int, int]],
- dropout: float = 0,
+ mode: Literal["deconv", "pixelshuffle"],
+ conv_blocks: int,
+ strides: list[int],
) -> None:
super().__init__()
- decoder_stages = []
+ self.decoder_stages = nn.ModuleList([])
stages = len(num_channels)
- num_channels.append(out_channels)
- stride = 2
+ num_channels.append(num_channels[-1])
for i in range(stages):
- stage = UnetrUpBlock(
- spatial_dims=2,
+ stride = strides[i]
+ stage = Unet2dUpStage(
in_channels=num_channels[i],
out_channels=num_channels[i + 1],
- kernel_size=kernel_size,
- upsample_kernel_size=stride,
+ scale_factor=stride,
+ mode=mode,
+ conv_blocks=conv_blocks,
norm_name=norm_name,
- res_block=res_block,
)
- decoder_stages.append(stage)
- self.decoder_stages = nn.ModuleList(decoder_stages)
- self.head = nn.Sequential(
- get_conv_layer(
- spatial_dims=2,
- in_channels=num_channels[-2],
- out_channels=num_channels[-2],
- stride=last_kernel_size,
- kernel_size=last_kernel_size,
- norm=norm_name,
- is_transposed=True,
- ),
- ResidualUnit(
- spatial_dims=2,
- in_channels=num_channels[-2],
- out_channels=num_channels[-2],
- kernel_size=kernel_size,
- norm=norm_name,
- dropout=dropout,
- ),
- nn.Conv2d(
- num_channels[-2],
- out_channels,
- kernel_size=(1, 1),
- ),
+ self.decoder_stages.append(stage)
+ self.head = UpSample(
+ spatial_dims=2,
+ in_channels=num_channels[-1],
+ out_channels=out_channels,
+ scale_factor=strides[-1],
+ mode=mode,
+ pre_conv="default",
+ apply_pad_pool=False,
)
def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
@@ -101,12 +194,15 @@ def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
- in_stack_depth: int = 9,
+ in_stack_depth: int = 5,
+ out_stack_depth: int = 1,
backbone: str = "convnextv2_tiny",
pretrained: bool = False,
- stem_kernel_size: tuple[int, int, int] = (3, 4, 4),
- decoder_res_block: bool = True,
+ stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
+ decoder_mode: Literal["deconv", "pixelshuffle"] = "pixelshuffle",
+ decoder_conv_blocks: int = 2,
decoder_norm_layer: str = "instance",
+ drop_path_rate: float = 0.0,
) -> None:
super().__init__()
if in_stack_depth % stem_kernel_size[0] != 0:
@@ -114,8 +210,17 @@ def __init__(
f"Input stack depth {in_stack_depth} is not divisible "
f"by stem kernel depth {stem_kernel_size[0]}."
)
+ if not (in_stack_depth == out_stack_depth or out_stack_depth == 1):
+ raise ValueError(
+ "`out_stack_depth` must be either 1 or "
+ f"the same as `input_stack_depth` ({in_stack_depth}), "
+ f"but got {out_stack_depth}."
+ )
multi_scale_encoder = timm.create_model(
- backbone, pretrained=pretrained, features_only=True
+ backbone,
+ pretrained=pretrained,
+ features_only=True,
+ drop_path_rate=drop_path_rate,
)
num_channels = multi_scale_encoder.feature_info.channels()
# replace first convolution layer with a projection tokenizer
@@ -126,21 +231,34 @@ def __init__(
)
decoder_channels = num_channels
decoder_channels.reverse()
+ if out_stack_depth == 1:
+ decoder_out_channels = out_channels
+ self.head = UnsqueezeHead()
+ else:
+ decoder_out_channels = (
+ out_stack_depth * decoder_channels[-1] // stem_kernel_size[-1] ** 2
+ )
+ self.head = PixelToVoxelHead(
+ decoder_out_channels, out_channels, out_stack_depth
+ )
self.decoder = Unet2dDecoder(
decoder_channels,
- out_channels,
- res_block=decoder_res_block,
+ decoder_out_channels,
norm_name=decoder_norm_layer,
- kernel_size=3,
- last_kernel_size=stem_kernel_size[-2:],
+ mode=decoder_mode,
+ conv_blocks=decoder_conv_blocks,
+ strides=[2] * len(num_channels) + [stem_kernel_size[-1]],
)
- # shape compatibility
- self.num_blocks = 6
+ self.out_stack_depth = out_stack_depth
+
+ @property
+ def num_blocks(self) -> int:
+ """2-times downscaling factor of the smallest feature map"""
+ return 6
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.stem(x)
x: list = self.encoder_stages(x)
x.reverse()
x = self.decoder(x)
- # add Z/depth back
- return x.unsqueeze(2)
+ return self.head(x)