From 3f46b3b8e210c478ecceb7b1ec05618c2abb045f Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Wed, 30 Aug 2023 22:26:14 -0400 Subject: [PATCH] dlmbl 2023 archive (#44) * final version of dlmbl 2023 demo * fix typos --- examples/demo_dlmbl/README.md | 48 ++ examples/demo_dlmbl/convert-solution.py | 41 -- examples/demo_dlmbl/exercise.ipynb | 754 ------------------- examples/demo_dlmbl/prepare-exercise.sh | 10 - examples/demo_dlmbl/setup.sh | 35 +- examples/demo_dlmbl/solution.ipynb | 929 ------------------------ examples/demo_dlmbl/solution.py | 635 ++++++++++------ 7 files changed, 497 insertions(+), 1955 deletions(-) create mode 100644 examples/demo_dlmbl/README.md delete mode 100644 examples/demo_dlmbl/convert-solution.py delete mode 100644 examples/demo_dlmbl/exercise.ipynb delete mode 100644 examples/demo_dlmbl/prepare-exercise.sh delete mode 100644 examples/demo_dlmbl/solution.ipynb diff --git a/examples/demo_dlmbl/README.md b/examples/demo_dlmbl/README.md new file mode 100644 index 00000000..eb7917c8 --- /dev/null +++ b/examples/demo_dlmbl/README.md @@ -0,0 +1,48 @@ +# Exercise 4: Image translation + +This demo script was developed for the DL@MBL 2023 course by Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). + + + + +## Setup + +Make sure that you are inside of the `image_translation` folder by using the `cd` command to change directories if needed. + +Make sure that you can use mamba to switch environments. + +```bash +mamba init +``` + +**Close your shell, and login again.** + +Run the setup script to create the environment for this exercise and download the dataset. +```bash +sh setup.sh +``` +Activate your environment +```bash +mamba activate 04_image_translation +``` + +## Use vscode + +Install vscode, install jupyter extension inside vscode, and setup [cell mode](https://code.visualstudio.com/docs/python/jupyter-support-py). Open [solution.py](solution.py) and run the script interactively. + +## Use Jupyter Notebook + +The matching exercise and solution notebooks can be found [here](https://github.com/dlmbl/image_translation/tree/28e0e515b4a8ad3f392a69c8341e105f730d204f) on the course repository. + +Launch a jupyter environment + +``` +jupyter notebook +``` + +...and continue with the instructions in the notebook. + +If 04_image_translation is not available as a kernel in jupyter, run +``` +python -m ipykernel install --user --name=04_image_translation +``` diff --git a/examples/demo_dlmbl/convert-solution.py b/examples/demo_dlmbl/convert-solution.py deleted file mode 100644 index 279f7874..00000000 --- a/examples/demo_dlmbl/convert-solution.py +++ /dev/null @@ -1,41 +0,0 @@ -import argparse -from traitlets.config import Config -import nbformat as nbf -from nbconvert.preprocessors import TagRemovePreprocessor, ClearOutputPreprocessor -from nbconvert.exporters import NotebookExporter - - -def get_arg_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument('input_file') - parser.add_argument('output_file') - - return parser - - -def convert(input_file, output_file): - c = Config() - c.TagRemovePreprocessor.remove_cell_tags = ("solution",) - c.TagRemovePreprocessor.enabled = True - c.ClearOutputPreprocesser.enabled = True - c.NotebookExporter.preprocessors = [ - "nbconvert.preprocessors.TagRemovePreprocessor", - "nbconvert.preprocessors.ClearOutputPreprocessor" - ] - - exporter = NotebookExporter(config=c) - exporter.register_preprocessor(TagRemovePreprocessor(config=c), True) - exporter.register_preprocessor(ClearOutputPreprocessor(), True) - - output = NotebookExporter(config=c).from_filename(input_file) - with open(output_file, 'w') as f: - f.write(output[0]) - - -if __name__ == "__main__": - parser = get_arg_parser() - args = parser.parse_args() - - convert(args.input_file, args.output_file) - print(f'Converted {args.input_file} to {args.output_file}') diff --git a/examples/demo_dlmbl/exercise.ipynb b/examples/demo_dlmbl/exercise.ipynb deleted file mode 100644 index b6a3a256..00000000 --- a/examples/demo_dlmbl/exercise.ipynb +++ /dev/null @@ -1,754 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c9438eb5", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "# Image translation\n", - "---\n", - "\n", - "Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco.\n", - "---\n", - "\n", - "In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. \n", - "\n", - "Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy).\n", - "\n", - "VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy.\n", - "![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg)\n", - "\n", - "[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning\n", - ". eLife](https://elifesciences.org/articles/55502).\n", - "\n", - "VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). Our previous pipeline, [microDL](https://github.com/mehta-lab/microDL), is deprecated and is now a public archive." - ] - }, - { - "cell_type": "markdown", - "id": "b36463af", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels.\n", - "![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true)\n", - "\n", - "
\n", - "The exercise is organized in 3 parts.\n", - "\n", - "* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.\n", - "* Lunch break - The model will continue training during lunch.\n", - "* **Part 2** - Evaluate the training with tensorboard. Train another model.\n", - "* **Part 3** - Tune the models to improve performance.\n", - "
\n", - "\n", - "📖 As you work through parts 2 and 3, please share the layouts of the models you train and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.\n", - "\n", - "\n", - "Our guesstimate is that each of the three parts will take ~1.5 hours, but don't rush parts 1 and 2 if you need more time with them.\n", - "We will discuss your observations on google doc after checkpoints 2 and 3. The exercise is focused on understanding information contained in data, process of training and evaluating image translation models, and parameter exploration.\n", - "There are a few coding tasks sprinkled in.\n", - "\n", - "\n", - "Before you start,\n", - "\n", - "
\n", - "Set your python kernel to 04-image-translation\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "ff42876f", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='1_phase2fluor'>", - "title": "\n", - "Task 1.1\n", - "Use \n", - "iohub.open_ome_zarr to read the dataset and explore several FOVs with matplotlib.\n", - "\n", - "\n", - "There should be 301 FOVs in the dataset (12 GB compressed).\n", - "\n", - "Each FOV consists of 3 channels of 2048x2048 images,\n", - "saved in the \n", - "High-Content Screening (HCS) layout\n", - "specified by the Open Microscopy Environment Next Generation File Format\n", - "(OME-NGFF).\n", - "\n", - "The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x.\n", - "Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e0241037", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "dataset = open_ome_zarr(data_path)\n", - "\n", - "print(f\"Number of positions: {len(list(dataset.positions()))}\")\n", - "\n", - "# Use the field and pyramid_level below to visualize data.\n", - "row = \"0\"\n", - "col = \"0\"\n", - "field = \"23\"\n", - "\n", - "# This dataset contains images at 3 resolutions.\n", - "# '0' is the highest resolution\n", - "# '1' is down-scaled 2x2,\n", - "# '2' is down-scaled 4x4.\n", - "# Such datasets are called image pyramids.\n", - "pyaramid_level = \"0\"\n", - "\n", - "# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.\n", - "n_channels = len(dataset.channel_names)\n", - "\n", - "image = dataset[f\"{row}/{col}/{field}/{pyaramid_level}\"].numpy()\n", - "print(f\"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}\")\n", - "\n", - "figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))\n", - "\n", - "for i in range(n_channels):\n", - " for i in range(n_channels):\n", - " channel_image = image[0, i, 0]\n", - " # Adjust contrast to 0.5th and 99.5th percentile of pixel values.\n", - " p_low, p_high = np.percentile(channel_image, (0.5, 99.5))\n", - " channel_image = np.clip(channel_image, p_low, p_high)\n", - " axes[i].imshow(channel_image, cmap=\"gray\")\n", - " axes[i].axis(\"off\")\n", - " axes[i].set_title(dataset.channel_names[i])\n", - "plt.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "89e676bf", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Initialize data loaders and see the samples in tensorboard.\n", - "\n", - "
\n", - "Task 1.2\n", - "Setup the data loader and log several batches to tensorboard.\n", - "
`\n", - "\n", - "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n", - " \n", - "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n", - "- `source`: the input image, a tensor of size 1*1*Y*X\n", - "- `target`: the target image, a tensor of size 2*1*Y*X\n", - "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05ce4dbb", - "metadata": {}, - "outputs": [], - "source": [ - "# Define a function to write a batch to tensorboard log.\n", - "\n", - "\n", - "def log_batch_tensorboard(batch, batchno, writer, card_name):\n", - " \"\"\"\n", - " Logs a batch of images to TensorBoard.\n", - "\n", - " Args:\n", - " batch (dict): A dictionary containing the batch of images to be logged.\n", - " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n", - " card_name (str): The name of the card to be displayed in TensorBoard.\n", - "\n", - " Returns:\n", - " None\n", - " \"\"\"\n", - " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n", - " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n", - " 1\n", - " ) # batch_size x 1 x Y x X tensor.\n", - " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n", - " 1\n", - " ) # batch_size x 1 x Y x X tensor.\n", - "\n", - " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n", - " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n", - "\n", - " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n", - " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n", - "\n", - " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n", - " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n", - "\n", - " [N, C, H, W] = batch_phase.shape\n", - " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n", - " interleaved_images[0::3, :] = batch_phase\n", - " interleaved_images[1::3, :] = batch_nuclei\n", - " interleaved_images[2::3, :] = batch_membrane\n", - "\n", - " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n", - "\n", - " # add the grid to tensorboard\n", - " writer.add_image(card_name, grid, batchno)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f627e8e8", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Initialize the data module.\n", - "\n", - "BATCH_SIZE = 42\n", - "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n", - "# More seriously, batch size does not have to be a power of 2.\n", - "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n", - "\n", - "data_module = HCSDataModule(\n", - " data_path,\n", - " source_channel=\"Phase\",\n", - " target_channel=[\"Nuclei\", \"Membrane\"],\n", - " z_window_size=1,\n", - " split_ratio=0.8,\n", - " batch_size=BATCH_SIZE,\n", - " num_workers=8,\n", - " architecture=\"2D\",\n", - " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n", - " augment=False, # Turn off augmentation for now.\n", - ")\n", - "data_module.setup(\"fit\")\n", - "\n", - "print(\n", - " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n", - ")\n", - "train_dataloader = data_module.train_dataloader()\n", - "\n", - "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n", - "\n", - "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n", - "# Draw a batch and write to tensorboard.\n", - "batch = next(iter(train_dataloader))\n", - "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n", - "\n", - "# Iterate through all the batches and log them to tensorboard.\n", - "for i, batch in enumerate(train_dataloader):\n", - " log_batch_tensorboard(batch, i, writer, \"augmentation/none\")\n", - "writer.close()" - ] - }, - { - "cell_type": "markdown", - "id": "f6ffc7d1", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "There are multiple ways of seeing the tensorboard.\n", - "1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard.\n", - "2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension.\n", - "3. If you want to see tensorboard in a specific cell, use the following code.\n", - "```\n", - "notebook.list() # View open TensorBoard instances\n", - "notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "ed1448ff", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "## View augmentations using tensorboard.\n", - "\n", - "
\n", - "Task 1.3\n", - "Turn on augmentation and view the batch in tensorboard." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "245ae4dd", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "##########################\n", - "######## TODO ########\n", - "##########################\n", - "\n", - "# Write code to turn on augmentations, change batch sizes and log them to tensorboard.\n", - "# See how the training data changes as a function of these parameters.\n", - "# Remember to call `data_module.setup(\"fit\")` after changing the parameters." - ] - }, - { - "cell_type": "markdown", - "id": "85743be9", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Construct a 2D U-Net for image translation.\n", - "See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.\n", - "We setup a fresh data module and instantiate the trainer class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44fddf02", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "\n", - "# The entire training loop is contained in this cell.\n", - "\n", - "GPU_ID = 0\n", - "BATCH_SIZE = 10\n", - "YX_PATCH_SIZE = (512, 512)\n", - "\n", - "\n", - "# Dictionary that specifies key parameters of the model.\n", - "phase2fluor_config = {\n", - " \"architecture\": \"2D\",\n", - " \"num_filters\": [24, 48, 96, 192, 384],\n", - " \"in_channels\": 1,\n", - " \"out_channels\": 2,\n", - " \"residual\": True,\n", - " \"dropout\": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.\n", - " \"task\": \"reg\", # reg = regression task.\n", - "}\n", - "\n", - "phase2fluor_model = VSUNet(\n", - " model_config=phase2fluor_config.copy(),\n", - " batch_size=BATCH_SIZE,\n", - " loss_function=torch.nn.functional.l1_loss,\n", - " schedule=\"WarmupCosine\",\n", - " log_num_samples=10, # Number of samples from each batch to log to tensorboard.\n", - " example_input_yx_shape=YX_PATCH_SIZE,\n", - ")\n", - "\n", - "# Reinitialize the data module.\n", - "phase2fluor_data = HCSDataModule(\n", - " data_path,\n", - " source_channel=\"Phase\",\n", - " target_channel=[\"Nuclei\", \"Membrane\"],\n", - " z_window_size=1,\n", - " split_ratio=0.8,\n", - " batch_size=BATCH_SIZE,\n", - " num_workers=8,\n", - " architecture=\"2D\",\n", - " yx_patch_size=YX_PATCH_SIZE,\n", - " augment=True,\n", - ")\n", - "phase2fluor_data.setup(\"fit\")" - ] - }, - { - "cell_type": "markdown", - "id": "292cb22a", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "Task 1.4\n", - "Setup the training for ~30 epochs\n", - "
\n", - "\n", - "Tips:\n", - "- Set ``default_root_dir`` to store the logs and checkpoints\n", - "in a specific directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9cf215b", - "metadata": { - "lines_to_next_cell": 2, - "title": "Setup trainer and check for errors." - }, - "outputs": [], - "source": [ - "\n", - "# fast_dev_run runs a single batch of data through the model to check for errors.\n", - "trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n", - "\n", - "# trainer class takes the model and the data module as inputs.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5f11d604", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "\n", - "GPU_ID = 0\n", - "n_samples = len(phase2fluor_data.train_dataset)\n", - "steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n", - "n_epochs = 30\n", - "\n", - "trainer = VSTrainer(\n", - " accelerator=\"gpu\",\n", - " devices=[GPU_ID],\n", - " max_epochs=n_epochs,\n", - " log_every_n_steps=steps_per_epoch // 2,\n", - " # log losses and image samples 2 times per epoch.\n", - " default_root_dir=Path(\n", - " log_dir, \"phase2fluor\"\n", - " ), # lightning trainer transparently saves logs and model checkpoints in this directory.\n", - ")\n", - "\n", - "# Log graph\n", - "trainer.logger.log_graph(phase2fluor_model, phase2fluor_data.train_dataset[0][\"source\"])\n", - "# Launch training.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)" - ] - }, - { - "cell_type": "markdown", - "id": "157a989c", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "Checkpoint 1\n", - "\n", - "Now the training has started,\n", - "we can come back after a while and evaluate the performance!\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "d68dab93", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='1_fluor2phase'>", - "title": "\n", - "Checkpoint 2\n", - "Please summarize hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)\n", - "\n", - "Now that you have trained two models, let's think about the following questions:\n", - "- What is the information content of each channel in the dataset?\n", - "- How would you use image translation models?\n", - "- What can you try to improve the performance of each model?\n", - "\n", - "\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "7d0c4204", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='3_tuning'>", - "title": "\n", - "Checkpoint 3\n", - "\n", - "Congratulations! You have trained several image translation models now!\n", - "Please document hyperparameters, snapshots of predictioons on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)\n", - "" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/demo_dlmbl/prepare-exercise.sh b/examples/demo_dlmbl/prepare-exercise.sh deleted file mode 100644 index 88f2ba1b..00000000 --- a/examples/demo_dlmbl/prepare-exercise.sh +++ /dev/null @@ -1,10 +0,0 @@ -# Run black on .py files -# black solution.py - -# Convert .py to ipynb -# "cell_metadata_filter": "all" preserve cell tags including our solution tags -jupytext --to ipynb --update-metadata '{"jupytext": {"cell_metadata_filter":"all"}}' solution.py - -# Create the exercise notebook by removing cell outputs and deleting cells tagged with "solution" -# There is a bug in the nbconvert cli so we need to use the python API instead -python convert-solution.py solution.ipynb exercise.ipynb diff --git a/examples/demo_dlmbl/setup.sh b/examples/demo_dlmbl/setup.sh index 6e1f1b57..9472b16b 100644 --- a/examples/demo_dlmbl/setup.sh +++ b/examples/demo_dlmbl/setup.sh @@ -1,35 +1,32 @@ #!/usr/bin/env -S bash -i +START_DIR=$(pwd) + # Create mamba environment -mamba create --name 04_image_translation python=3.10 +mamba create -y --name 04_image_translation python=3.10 # Install ipykernel in the environment. -mamba install -y ipykernel nbformat nbconvert black jupytext --name 04_image_translation +mamba install -y ipykernel nbformat nbconvert black jupytext ipywidgets --name 04_image_translation # Specifying the environment explicitly. # mamba activate sometimes doesn't work from within shell scripts. -mamba install -y nbformat --name 04_image_translation - +# install viscy and its dependencies`s in the environment using pip. +mkdir -p ~/code/ +cd ~/code/ +git clone https://github.com/mehta-lab/viscy.git +cd viscy +git checkout 7c5e4c1d68e70163cf514d22c475da8ea7dc3a88 # Exercise is tested with this commit of viscy # Find path to the environment - mamba activate doesn't work from within shell scripts. -ENV_PATH=$(conda info --envs | grep 04_image_translation | awk '{print $2}') -# install viscy and its dependencies in the environment using pip. -$ENV_PATH/bin/pip install "viscy[metrics] @ git+https://github.com/mehta-lab/viscy.git@dlmbl2023" - -# Store the code directory path. -CODE_DIR=$(pwd) - +ENV_PATH=$(conda info --envs | grep 04_image_translation | awk '{print $NF}') +$ENV_PATH/bin/pip install ."[metrics]" # Create data directory mkdir -p ~/data/04_image_translation cd ~/data/04_image_translation wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_data_pyramid.tar.gz -wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_test.tar.gz -tar -xzf DLMBL2023_image_translation_data_pyramid.tar.gz +wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_test.tar.gz +tar -xzf DLMBL2023_image_translation_data_pyramid.tar.gz tar -xzf DLMBL2023_image_translation_test.tar.gz -# Go back to the code directory -cd $CODE_DIR - - -# this didn't not work from within shell scripts on TA1 node even after mamba init. -# mamba activate 04_image_translation \ No newline at end of file +# Change back to the starting directory +cd $START_DIR diff --git a/examples/demo_dlmbl/solution.ipynb b/examples/demo_dlmbl/solution.ipynb deleted file mode 100644 index c97b8577..00000000 --- a/examples/demo_dlmbl/solution.ipynb +++ /dev/null @@ -1,929 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c9438eb5", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "# Image translation\n", - "---\n", - "\n", - "Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco.\n", - "---\n", - "\n", - "In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. \n", - "\n", - "Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy).\n", - "\n", - "VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy.\n", - "![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg)\n", - "\n", - "[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning\n", - ". eLife](https://elifesciences.org/articles/55502).\n", - "\n", - "VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). Our previous pipeline, [microDL](https://github.com/mehta-lab/microDL), is deprecated and is now a public archive." - ] - }, - { - "cell_type": "markdown", - "id": "b36463af", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels.\n", - "![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true)\n", - "\n", - "
\n", - "The exercise is organized in 3 parts.\n", - "\n", - "* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.\n", - "* Lunch break - The model will continue training during lunch.\n", - "* **Part 2** - Evaluate the training with tensorboard. Train another model.\n", - "* **Part 3** - Tune the models to improve performance.\n", - "
\n", - "\n", - "📖 As you work through parts 2 and 3, please share the layouts of the models you train and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.\n", - "\n", - "\n", - "Our guesstimate is that each of the three parts will take ~1.5 hours, but don't rush parts 1 and 2 if you need more time with them.\n", - "We will discuss your observations on google doc after checkpoints 2 and 3. The exercise is focused on understanding information contained in data, process of training and evaluating image translation models, and parameter exploration.\n", - "There are a few coding tasks sprinkled in.\n", - "\n", - "\n", - "Before you start,\n", - "\n", - "
\n", - "Set your python kernel to 04-image-translation\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "ff42876f", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='1_phase2fluor'>", - "title": "\n", - "Task 1.1\n", - "Use \n", - "iohub.open_ome_zarr to read the dataset and explore several FOVs with matplotlib.\n", - "\n", - "\n", - "There should be 301 FOVs in the dataset (12 GB compressed).\n", - "\n", - "Each FOV consists of 3 channels of 2048x2048 images,\n", - "saved in the \n", - "High-Content Screening (HCS) layout\n", - "specified by the Open Microscopy Environment Next Generation File Format\n", - "(OME-NGFF).\n", - "\n", - "The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x.\n", - "Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e0241037", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "dataset = open_ome_zarr(data_path)\n", - "\n", - "print(f\"Number of positions: {len(list(dataset.positions()))}\")\n", - "\n", - "# Use the field and pyramid_level below to visualize data.\n", - "row = \"0\"\n", - "col = \"0\"\n", - "field = \"23\"\n", - "\n", - "# This dataset contains images at 3 resolutions.\n", - "# '0' is the highest resolution\n", - "# '1' is down-scaled 2x2,\n", - "# '2' is down-scaled 4x4.\n", - "# Such datasets are called image pyramids.\n", - "pyaramid_level = \"0\"\n", - "\n", - "# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.\n", - "n_channels = len(dataset.channel_names)\n", - "\n", - "image = dataset[f\"{row}/{col}/{field}/{pyaramid_level}\"].numpy()\n", - "print(f\"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}\")\n", - "\n", - "figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))\n", - "\n", - "for i in range(n_channels):\n", - " for i in range(n_channels):\n", - " channel_image = image[0, i, 0]\n", - " # Adjust contrast to 0.5th and 99.5th percentile of pixel values.\n", - " p_low, p_high = np.percentile(channel_image, (0.5, 99.5))\n", - " channel_image = np.clip(channel_image, p_low, p_high)\n", - " axes[i].imshow(channel_image, cmap=\"gray\")\n", - " axes[i].axis(\"off\")\n", - " axes[i].set_title(dataset.channel_names[i])\n", - "plt.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "89e676bf", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Initialize data loaders and see the samples in tensorboard.\n", - "\n", - "
\n", - "Task 1.2\n", - "Setup the data loader and log several batches to tensorboard.\n", - "
`\n", - "\n", - "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n", - " \n", - "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n", - "- `source`: the input image, a tensor of size 1*1*Y*X\n", - "- `target`: the target image, a tensor of size 2*1*Y*X\n", - "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "05ce4dbb", - "metadata": {}, - "outputs": [], - "source": [ - "# Define a function to write a batch to tensorboard log.\n", - "\n", - "\n", - "def log_batch_tensorboard(batch, batchno, writer, card_name):\n", - " \"\"\"\n", - " Logs a batch of images to TensorBoard.\n", - "\n", - " Args:\n", - " batch (dict): A dictionary containing the batch of images to be logged.\n", - " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n", - " card_name (str): The name of the card to be displayed in TensorBoard.\n", - "\n", - " Returns:\n", - " None\n", - " \"\"\"\n", - " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n", - " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n", - " 1\n", - " ) # batch_size x 1 x Y x X tensor.\n", - " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n", - " 1\n", - " ) # batch_size x 1 x Y x X tensor.\n", - "\n", - " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n", - " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n", - "\n", - " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n", - " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n", - "\n", - " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n", - " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n", - "\n", - " [N, C, H, W] = batch_phase.shape\n", - " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n", - " interleaved_images[0::3, :] = batch_phase\n", - " interleaved_images[1::3, :] = batch_nuclei\n", - " interleaved_images[2::3, :] = batch_membrane\n", - "\n", - " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n", - "\n", - " # add the grid to tensorboard\n", - " writer.add_image(card_name, grid, batchno)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f627e8e8", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# Initialize the data module.\n", - "\n", - "BATCH_SIZE = 42\n", - "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n", - "# More seriously, batch size does not have to be a power of 2.\n", - "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n", - "\n", - "data_module = HCSDataModule(\n", - " data_path,\n", - " source_channel=\"Phase\",\n", - " target_channel=[\"Nuclei\", \"Membrane\"],\n", - " z_window_size=1,\n", - " split_ratio=0.8,\n", - " batch_size=BATCH_SIZE,\n", - " num_workers=8,\n", - " architecture=\"2D\",\n", - " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n", - " augment=False, # Turn off augmentation for now.\n", - ")\n", - "data_module.setup(\"fit\")\n", - "\n", - "print(\n", - " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n", - ")\n", - "train_dataloader = data_module.train_dataloader()\n", - "\n", - "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n", - "\n", - "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n", - "# Draw a batch and write to tensorboard.\n", - "batch = next(iter(train_dataloader))\n", - "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n", - "\n", - "# Iterate through all the batches and log them to tensorboard.\n", - "for i, batch in enumerate(train_dataloader):\n", - " log_batch_tensorboard(batch, i, writer, \"augmentation/none\")\n", - "writer.close()" - ] - }, - { - "cell_type": "markdown", - "id": "f6ffc7d1", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "There are multiple ways of seeing the tensorboard.\n", - "1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard.\n", - "2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension.\n", - "3. If you want to see tensorboard in a specific cell, use the following code.\n", - "```\n", - "notebook.list() # View open TensorBoard instances\n", - "notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "ed1448ff", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "## View augmentations using tensorboard.\n", - "\n", - "
\n", - "Task 1.3\n", - "Turn on augmentation and view the batch in tensorboard." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "245ae4dd", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "##########################\n", - "######## TODO ########\n", - "##########################\n", - "\n", - "# Write code to turn on augmentations, change batch sizes and log them to tensorboard.\n", - "# See how the training data changes as a function of these parameters.\n", - "# Remember to call `data_module.setup(\"fit\")` after changing the parameters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e14f4faf", - "metadata": { - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "##########################\n", - "######## Solution ########\n", - "##########################\n", - "\n", - "data_module.augment = True\n", - "data_module.batch_size = 21\n", - "data_module.split_ratio = 0.8\n", - "data_module.setup(\"fit\")\n", - "\n", - "train_dataloader = data_module.train_dataloader()\n", - "# Draw batches and write to tensorboard\n", - "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n", - "for i, batch in enumerate(train_dataloader):\n", - " log_batch_tensorboard(batch, i, writer, \"augmentation/some\")\n", - "writer.close()" - ] - }, - { - "cell_type": "markdown", - "id": "85743be9", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Construct a 2D U-Net for image translation.\n", - "See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.\n", - "We setup a fresh data module and instantiate the trainer class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44fddf02", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "\n", - "# The entire training loop is contained in this cell.\n", - "\n", - "GPU_ID = 0\n", - "BATCH_SIZE = 10\n", - "YX_PATCH_SIZE = (512, 512)\n", - "\n", - "\n", - "# Dictionary that specifies key parameters of the model.\n", - "phase2fluor_config = {\n", - " \"architecture\": \"2D\",\n", - " \"num_filters\": [24, 48, 96, 192, 384],\n", - " \"in_channels\": 1,\n", - " \"out_channels\": 2,\n", - " \"residual\": True,\n", - " \"dropout\": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.\n", - " \"task\": \"reg\", # reg = regression task.\n", - "}\n", - "\n", - "phase2fluor_model = VSUNet(\n", - " model_config=phase2fluor_config.copy(),\n", - " batch_size=BATCH_SIZE,\n", - " loss_function=torch.nn.functional.l1_loss,\n", - " schedule=\"WarmupCosine\",\n", - " log_num_samples=10, # Number of samples from each batch to log to tensorboard.\n", - " example_input_yx_shape=YX_PATCH_SIZE,\n", - ")\n", - "\n", - "# Reinitialize the data module.\n", - "phase2fluor_data = HCSDataModule(\n", - " data_path,\n", - " source_channel=\"Phase\",\n", - " target_channel=[\"Nuclei\", \"Membrane\"],\n", - " z_window_size=1,\n", - " split_ratio=0.8,\n", - " batch_size=BATCH_SIZE,\n", - " num_workers=8,\n", - " architecture=\"2D\",\n", - " yx_patch_size=YX_PATCH_SIZE,\n", - " augment=True,\n", - ")\n", - "phase2fluor_data.setup(\"fit\")" - ] - }, - { - "cell_type": "markdown", - "id": "292cb22a", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "Task 1.4\n", - "Setup the training for ~30 epochs\n", - "
\n", - "\n", - "Tips:\n", - "- Set ``default_root_dir`` to store the logs and checkpoints\n", - "in a specific directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9cf215b", - "metadata": { - "lines_to_next_cell": 2, - "title": "Setup trainer and check for errors." - }, - "outputs": [], - "source": [ - "\n", - "# fast_dev_run runs a single batch of data through the model to check for errors.\n", - "trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n", - "\n", - "# trainer class takes the model and the data module as inputs.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5f11d604", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "\n", - "GPU_ID = 0\n", - "n_samples = len(phase2fluor_data.train_dataset)\n", - "steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n", - "n_epochs = 30\n", - "\n", - "trainer = VSTrainer(\n", - " accelerator=\"gpu\",\n", - " devices=[GPU_ID],\n", - " max_epochs=n_epochs,\n", - " log_every_n_steps=steps_per_epoch // 2,\n", - " # log losses and image samples 2 times per epoch.\n", - " default_root_dir=Path(\n", - " log_dir, \"phase2fluor\"\n", - " ), # lightning trainer transparently saves logs and model checkpoints in this directory.\n", - ")\n", - "\n", - "# Log graph\n", - "trainer.logger.log_graph(phase2fluor_model, phase2fluor_data.train_dataset[0][\"source\"])\n", - "# Launch training.\n", - "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)" - ] - }, - { - "cell_type": "markdown", - "id": "157a989c", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "Checkpoint 1\n", - "\n", - "Now the training has started,\n", - "we can come back after a while and evaluate the performance!\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "d68dab93", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='1_fluor2phase'>", - "title": "\n", - "Checkpoint 2\n", - "Please summarize hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)\n", - "\n", - "Now that you have trained two models, let's think about the following questions:\n", - "- What is the information content of each channel in the dataset?\n", - "- How would you use image translation models?\n", - "- What can you try to improve the performance of each model?\n", - "\n", - "\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "7d0c4204", - "metadata": { - "cell_marker": "\"\"\"", - "incorrectly_encoded_metadata": "id='3_tuning'>", - "title": "\n", - "Checkpoint 3\n", - "\n", - "Congratulations! You have trained several image translation models now!\n", - "Please document hyperparameters, snapshots of predictioons on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)\n", - "" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 661a1fb4..c96aab7b 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -4,7 +4,6 @@ --- Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. ---- In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. @@ -16,16 +15,17 @@ [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning . eLife](https://elifesciences.org/articles/55502). -VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). Our previous pipeline, [microDL](https://github.com/mehta-lab/microDL), is deprecated and is now a public archive. - +VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). """ # %% [markdown] """ Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels. ![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true) - -
+""" +# %% [markdown] +""" +
The exercise is organized in 3 parts. * **Part 1** - Explore the data using tensorboard. Launch the training before lunch. @@ -33,22 +33,24 @@ * **Part 2** - Evaluate the training with tensorboard. Train another model. * **Part 3** - Tune the models to improve performance.
- -📖 As you work through parts 2 and 3, please share the layouts of the models you train and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖. - - -Our guesstimate is that each of the three parts will take ~1.5 hours, but don't rush parts 1 and 2 if you need more time with them. -We will discuss your observations on google doc after checkpoints 2 and 3. The exercise is focused on understanding information contained in data, process of training and evaluating image translation models, and parameter exploration. -There are a few coding tasks sprinkled in. +""" +# %% [markdown] +""" +📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖. -Before you start, +Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~20 min on a typical AWS node. +We will discuss your observations on google doc after checkpoints 2 and 3. +The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. +""" +# %% [markdown] +"""
-Set your python kernel to 04-image-translation +Set your python kernel to 04_image_translation
""" -# %% [markdown] +# %% """ # Part 1: Log training data to tensorboard, start training a model. --------- @@ -60,11 +62,9 @@ - Log some patches to tensorboard. - Initialize a 2D U-Net model for virtual staining - Start training the model to predict nuclei and membrane from phase. - """ # %% Imports and paths - from pathlib import Path import matplotlib.pyplot as plt @@ -76,9 +76,10 @@ from iohub import open_ome_zarr from lightning.pytorch import seed_everything from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger +from skimage import metrics # for metrics. +# %% Imports and paths # pytorch lightning wrapper for Tensorboard. -from tensorboard import notebook # for viewing tensorboard in notebook from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. @@ -99,22 +100,25 @@ # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) -# fmt: off -%reload_ext tensorboard -%tensorboard --logdir {log_dir} -# fmt: on +# %% [markdown] tags=[] +''' +The next cell starts tensorboard within the notebook. +
+If you launched jupyter lab from ssh terminal, add --host <your-server-name> to the tensorboard command below. <your-server-name> is the address of your compute node that ends in amazonaws.com. + +You can also launch tensorboard in an independent tab (instead of in the notebook) by changing the `%` to `!` +
+''' + +# %% Imports and paths tags=[] +%reload_ext tensorboard +%tensorboard --logdir {log_dir} # %% [markdown] """ ## Load Dataset. -
-Task 1.1 -Use -iohub.open_ome_zarr to read the dataset and explore several FOVs with matplotlib. -
- There should be 301 FOVs in the dataset (12 GB compressed). Each FOV consists of 3 channels of 2048x2048 images, @@ -125,26 +129,24 @@ The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x. Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein. - """ # %% - dataset = open_ome_zarr(data_path) print(f"Number of positions: {len(list(dataset.positions()))}") # Use the field and pyramid_level below to visualize data. -row = "0" -col = "0" -field = "23" +row = 0 +col = 0 +field = 23 # TODO: Change this to explore data. # This dataset contains images at 3 resolutions. # '0' is the highest resolution # '1' is down-scaled 2x2, # '2' is down-scaled 4x4. # Such datasets are called image pyramids. -pyaramid_level = "0" +pyaramid_level = 0 # `channel_names` is the metadata that is stored with data according to the OME-NGFF spec. n_channels = len(dataset.channel_names) @@ -166,13 +168,16 @@ plt.tight_layout() # %% [markdown] -""" -## Initialize data loaders and see the samples in tensorboard. +#
+# +# ### Task 1.1 +# +# Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining. +#
-
-Task 1.2 -Setup the data loader and log several batches to tensorboard. -
` +# %% [markdown] +""" +## Explore the effects of augmentation on batch. VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout. @@ -180,13 +185,23 @@ - `source`: the input image, a tensor of size 1*1*Y*X - `target`: the target image, a tensor of size 2*1*Y*X - `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample. - """ +# %% [markdown] +#
+# +# ### Task 1.2 +# +# Setup the data loader and log several batches to tensorboard. +# +# Based on the tensorboard images, what are the two channels in the target image? +# +# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab. +#
+ # %% # Define a function to write a batch to tensorboard log. - def log_batch_tensorboard(batch, batchno, writer, card_name): """ Logs a batch of images to TensorBoard. @@ -228,11 +243,57 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): writer.add_image(card_name, grid, batchno) +# %% +# Define a function to visualize a batch on jupyter, in case tensorboard is finicky + +def log_batch_jupyter(batch): + """ + Logs a batch of images on jupyter using ipywidget. + + Args: + batch (dict): A dictionary containing the batch of images to be logged. + + Returns: + None + """ + batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor. + batch_size = batch_phase.shape[0] + batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze( + 1 + ) # batch_size x 1 x Y x X tensor. + batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze( + 1 + ) # batch_size x 1 x Y x X tensor. + + p1, p99 = np.percentile(batch_membrane, (0.1, 99.9)) + batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9)) + batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1) + + p1, p99 = np.percentile(batch_phase, (0.1, 99.9)) + batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1) + + plt.figure() + fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10)) + [N, C, H, W] = batch_phase.shape + for sample_id in range(batch_size): + axes[sample_id, 0].imshow(batch_phase[sample_id,0]) + axes[sample_id, 1].imshow(batch_nuclei[sample_id,0]) + axes[sample_id, 2].imshow(batch_membrane[sample_id,0]) + + for i in range(n_channels): + axes[sample_id, i].axis("off") + axes[sample_id, i].set_title(dataset.channel_names[i]) + plt.tight_layout() + plt.show() + + # %% # Initialize the data module. -BATCH_SIZE = 42 +BATCH_SIZE = 4 # 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything. # More seriously, batch size does not have to be a power of 2. # See: https://sebastianraschka.com/blog/2022/batch-size-2.html @@ -240,7 +301,7 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): data_module = HCSDataModule( data_path, source_channel="Phase", - target_channel=["Nuclei", "Membrane"], + target_channel=["Membrane", "Nuclei"], z_window_size=1, split_ratio=0.8, batch_size=BATCH_SIZE, @@ -262,70 +323,58 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): # Draw a batch and write to tensorboard. batch = next(iter(train_dataloader)) log_batch_tensorboard(batch, 0, writer, "augmentation/none") - -# Iterate through all the batches and log them to tensorboard. -for i, batch in enumerate(train_dataloader): - log_batch_tensorboard(batch, i, writer, "augmentation/none") writer.close() + # %% [markdown] -""" -There are multiple ways of seeing the tensorboard. -1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard. -2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension. -3. If you want to see tensorboard in a specific cell, use the following code. -``` -notebook.list() # View open TensorBoard instances -notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port. -``` -""" +# Visualize directly on Jupyter ☄️, if your tensorboard is causing issues. + +# %% +%matplotlib inline +log_batch_jupyter(batch) # %% [markdown] """ ## View augmentations using tensorboard. - -
-Task 1.3 -Turn on augmentation and view the batch in tensorboard. """ # %% -########################## -######## TODO ######## -########################## - -# Write code to turn on augmentations, change batch sizes and log them to tensorboard. -# See how the training data changes as a function of these parameters. -# Remember to call `data_module.setup("fit")` after changing the parameters. - - -# %% tags=["solution"] -########################## -######## Solution ######## -########################## - +# Here we turn on data augmentation and rerun setup data_module.augment = True -data_module.batch_size = 21 -data_module.split_ratio = 0.8 data_module.setup("fit") -train_dataloader = data_module.train_dataloader() +# get the new data loader with augmentation turned on +augmented_train_dataloader = data_module.train_dataloader() + # Draw batches and write to tensorboard writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") -for i, batch in enumerate(train_dataloader): - log_batch_tensorboard(batch, i, writer, "augmentation/some") +augmented_batch = next(iter(augmented_train_dataloader)) +log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") writer.close() # %% [markdown] -""" -## Construct a 2D U-Net for image translation. -See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details. -We setup a fresh data module and instantiate the trainer class. -""" +# Visualize directly on Jupyter ☄️ # %% +log_batch_jupyter(augmented_batch) -# The entire training loop is contained in this cell. +# %% [markdown] +#
+# +# ### Task 1.3 +# Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? +# +# Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). +#
+# %% [markdown] +""" +## Train a 2D U-Net model to predict nuclei and membrane from phase. + +### Construct a 2D U-Net +See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details. +""" +# %% +# Create a 2D UNet. GPU_ID = 0 BATCH_SIZE = 10 YX_PATCH_SIZE = (512, 512) @@ -347,15 +396,21 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): batch_size=BATCH_SIZE, loss_function=torch.nn.functional.l1_loss, schedule="WarmupCosine", - log_num_samples=10, # Number of samples from each batch to log to tensorboard. + log_num_samples=5, # Number of samples from each batch to log to tensorboard. example_input_yx_shape=YX_PATCH_SIZE, ) -# Reinitialize the data module. + +# %% [markdown] +""" +### Instantiate data module and trainer, test that we are setup to launch training. +""" +# %% +# Setup the data module. phase2fluor_data = HCSDataModule( data_path, source_channel="Phase", - target_channel=["Nuclei", "Membrane"], + target_channel=["Membrane", "Nuclei"], z_window_size=1, split_ratio=0.8, batch_size=BATCH_SIZE, @@ -365,27 +420,44 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): augment=True, ) phase2fluor_data.setup("fit") +# fast_dev_run runs a single batch of data through the model to check for errors. +trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True) + +# trainer class takes the model and the data module as inputs. +trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) # %% [markdown] -""" -
-Task 1.4 -Setup the training for ~30 epochs -
+# ## View model graph. +# +# PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging. -Tips: -- Set ``default_root_dir`` to store the logs and checkpoints -in a specific directory. -""" +# %% [markdown] +#
+# +# ### Task 1.4 +# Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization? +#
-# %% Setup trainer and check for errors. +# %% +# visualize graph of phase2fluor model as image. +model_graph_phase2fluor = torchview.draw_graph( + phase2fluor_model, + phase2fluor_data.train_dataset[0]["source"], + depth=2, # adjust depth to zoom in. + device="cpu", +) +# Print the image of the model. +model_graph_phase2fluor.visual_graph -# fast_dev_run runs a single batch of data through the model to check for errors. -trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True) +# %% [markdown] +""" +
-# trainer class takes the model and the data module as inputs. -trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) +### Task 1.5 +Start training by running the following cell. Check the new logs on the tensorboard. +
+""" # %% @@ -393,66 +465,187 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): GPU_ID = 0 n_samples = len(phase2fluor_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. -n_epochs = 30 +n_epochs = 50 # Set this to 50 or the number of epochs you want to train for. trainer = VSTrainer( accelerator="gpu", devices=[GPU_ID], max_epochs=n_epochs, - # log losses and image samples 2 times per epoch. log_every_n_steps=steps_per_epoch // 2, - # lightning trainer transparently saves logs and model checkpoints in this directory + # log losses and image samples 2 times per epoch. logger=TensorBoardLogger( save_dir=log_dir, + # lightning trainer transparently saves logs and model checkpoints in this directory. name="phase2fluor", log_graph=True, - ), -) - -# Launch training. + ), + ) +# Launch training and check that loss and images are being logged on tensorboard. trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) - # %% [markdown] """
-Checkpoint 1 + +## Checkpoint 1 Now the training has started, we can come back after a while and evaluate the performance!
""" -# %% [markdown] +# %% """ # Part 2: Assess previous model, train fluorescence to phase contrast translation model. -------------------------------------------------- +""" -Learning goals: -- Visualize the previous model and training with tensorboard -- Train fluorescence to phase contrast translation model -- Compare the performance of the two models. +# %% [markdown] +""" +We now look at some metrics of performance of previous model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model: +- [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient). +- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM). +You should also look at the validation samples on tensorboard (hint: the experimental data in nuclei channel is imperfect.) """ -# %% +# %% [markdown] +""" +
-# PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging. +### Task 2.1 Define metrics -# visualize graph. -model_graph_phase2fluor = torchview.draw_graph( - phase2fluor_model, - phase2fluor_data.train_dataset[0]["source"], +For each of the above metrics, write a brief definition of what they are and what they mean for this image translation task. + +
+""" + +# %% [markdown] +# ``` +# ####################### +# ##### Todo ############ +# ####################### +# ``` +# +# - Pearson Correlation: +# +# - Structural similarity: + +# %% Compute metrics directly and plot here. +test_data_path = Path( + "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr" +).expanduser() + +test_data = HCSDataModule( + test_data_path, + source_channel="Phase", + target_channel=["Membrane", "Nuclei"], + z_window_size=1, + batch_size=1, + num_workers=8, + architecture="2D", +) +test_data.setup("test") + +test_metrics = pd.DataFrame( + columns=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"] +) + + +def min_max_scale(input): + return (input - np.min(input)) / (np.max(input) - np.min(input)) + + +# %% Compute metrics directly and plot here. +for i, sample in enumerate(test_data.test_dataloader()): + phase_image = sample["source"] + with torch.inference_mode(): # turn off gradient computation. + predicted_image = phase2fluor_model(phase_image) + + target_image = ( + sample["target"].cpu().numpy().squeeze(0) + ) # Squeezing batch dimension. + predicted_image = predicted_image.cpu().numpy().squeeze(0) + phase_image = phase_image.cpu().numpy().squeeze(0) + target_mem = min_max_scale(target_image[1, 0, :, :]) + target_nuc = min_max_scale(target_image[0, 0, :, :]) + # slicing channel dimension, squeezing z-dimension. + predicted_mem = min_max_scale(predicted_image[1, :, :, :].squeeze(0)) + predicted_nuc = min_max_scale(predicted_image[0, :, :, :].squeeze(0)) + + # Compute SSIM and pearson correlation. + ssim_nuc = metrics.structural_similarity(target_nuc, predicted_nuc, data_range=1) + ssim_mem = metrics.structural_similarity(target_mem, predicted_mem, data_range=1) + pearson_nuc = np.corrcoef(target_nuc.flatten(), predicted_nuc.flatten())[0, 1] + pearson_mem = np.corrcoef(target_mem.flatten(), predicted_mem.flatten())[0, 1] + + test_metrics.loc[i] = { + "pearson_nuc": pearson_nuc, + "SSIM_nuc": ssim_nuc, + "pearson_mem": pearson_mem, + "SSIM_mem": ssim_mem, + } + +test_metrics.boxplot( + column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"], + rot=30, +) + + +# %% [markdown] tags=[] +""" +
+ +### Task 2.2 Train fluorescence to phase contrast translation model + +Instantiate a data module, model, and trainer for fluorescence to phase contrast translation. Copy over the code from previous cells and update the parameters. Give the variables and paths a different name/suffix (fluor2phase) to avoid overwriting objects used to train phase2fluor models. +
+""" +# %% tags=[] +########################## +######## TODO ######## +########################## + +fluor2phase_data = HCSDataModule( + # Your code here (copy from above and modify as needed) +) +fluor2phase_data.setup("fit") + +# Dictionary that specifies key parameters of the model. +fluor2phase_config = { + # Your config here +} + +fluor2phase_model = VSUNet( + # Your code here (copy from above and modify as needed) +) + +trainer = VSTrainer( + # Your code here (copy from above and modify as needed) +) +trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) + + +# Visualize the graph of fluor2phase model as image. +model_graph_fluor2phase = torchview.draw_graph( + fluor2phase_model, + fluor2phase_data.train_dataset[0]["source"], depth=2, # adjust depth to zoom in. device="cpu", ) -# Increase the depth to zoom in. -model_graph_phase2fluor.visual_graph +model_graph_fluor2phase.visual_graph + +# %% tags=["solution"] + +########################## +######## Solution ######## +########################## + +# The entire training loop is contained in this cell. -# %% tags = ["solution"] fluor2phase_data = HCSDataModule( data_path, - source_channel="Nuclei", + source_channel="Membrane", target_channel="Phase", z_window_size=1, split_ratio=0.8, @@ -480,126 +673,172 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): batch_size=BATCH_SIZE, loss_function=torch.nn.functional.mse_loss, schedule="WarmupCosine", - log_num_samples=10, + log_num_samples=5, example_input_yx_shape=YX_PATCH_SIZE, ) -n_samples = len(fluor2phase_data.train_dataset) -steps_per_epoch = n_samples // BATCH_SIZE -n_epochs = 30 trainer = VSTrainer( accelerator="gpu", devices=[GPU_ID], max_epochs=n_epochs, - log_every_n_steps=steps_per_epoch, + log_every_n_steps=steps_per_epoch // 2, logger=TensorBoardLogger( save_dir=log_dir, + # lightning trainer transparently saves logs and model checkpoints in this directory. name="fluor2phase", log_graph=True, ), ) trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) -# %% -# Visualize the graph of fluor2phase model. + +# Visualize the graph of fluor2phase model as image. model_graph_fluor2phase = torchview.draw_graph( - phase2fluor_model, - phase2fluor_data.train_dataset[0]["source"], + fluor2phase_model, + fluor2phase_data.train_dataset[0]["source"], depth=2, # adjust depth to zoom in. device="cpu", ) model_graph_fluor2phase.visual_graph -# %% [markdown] +# %% [markdown] tags=[] """ -We now look at some metrics of performance. Loss is a differentiable metric. But, several non-differentiable metrics are useful to assess the performance of the model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model: -- [Coefficient of determination](https://en.wikipedia.org/wiki/Coefficient_of_determination): $R^2$ -- [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient). -- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM): +
+ +### Task 2.3 +While your model is training, let's think about the following questions: +- What is the information content of each channel in the dataset? +- How would you use image translation models? +- What can you try to improve the performance of each model? +
""" # %% - -# TODO: set following parameters, specifically path to checkpoint, and log the metrics. test_data_path = Path( "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr" ).expanduser() -model_version = "phase2fluor" -save_dir = Path(log_dir, "test") -ckpt_path = Path( - r"/home/mehtas/data/04_image_translation/logs/phase2fluor/lightning_logs/version_0/checkpoints/epoch=29-step=720.ckpt" -) # prefix the string with 'r' to avoid the need for escape characters. -### END TODO test_data = HCSDataModule( test_data_path, - source_channel="Phase", - target_channel=["Nuclei", "Membrane"], + source_channel="Nuclei", # or Membrane, depending on your choice of source + target_channel="Phase", z_window_size=1, batch_size=1, num_workers=8, architecture="2D", ) test_data.setup("test") -trainer = VSTrainer( - accelerator="gpu", - devices=[GPU_ID], - logger=CSVLogger(save_dir=save_dir, version=model_version), -) -trainer.test( - phase2fluor_model, - datamodule=test_data, - ckpt_path=ckpt_path, + +test_metrics = pd.DataFrame( + columns=["pearson_phase", "SSIM_phase"] ) -# read metrics and plot -metrics = pd.read_csv(Path(save_dir, "lightning_logs", model_version, "metrics.csv")) -metrics.boxplot( - column=[ - "test_metrics/r2_step", - "test_metrics/pearson_step", - "test_metrics/SSIM_step", - ], + + +def min_max_scale(input): + return (input - np.min(input)) / (np.max(input) - np.min(input)) + + +# %% +for i, sample in enumerate(test_data.test_dataloader()): + source_image = sample["source"] + with torch.inference_mode(): # turn off gradient computation. + predicted_image = fluor2phase_model(source_image) + + target_image = ( + sample["target"].cpu().numpy().squeeze(0) + ) # Squeezing batch dimension. + predicted_image = predicted_image.cpu().numpy().squeeze(0) + source_image = source_image.cpu().numpy().squeeze(0) + target_phase = min_max_scale(target_image[0, 0, :, :]) + # slicing channel dimension, squeezing z-dimension. + predicted_phase = min_max_scale(predicted_image[0, :, :, :].squeeze(0)) + + # Compute SSIM and pearson correlation. + ssim_phase = metrics.structural_similarity(target_phase, predicted_phase, data_range=1) + pearson_phase = np.corrcoef(target_phase.flatten(), predicted_phase.flatten())[0, 1] + + test_metrics.loc[i] = { + "pearson_phase": pearson_phase, + "SSIM_phase": ssim_phase, + } + +test_metrics.boxplot( + column=["pearson_phase", "SSIM_phase"], rot=30, ) -# %% [markdown] + +# %% [markdown] tags=[] """
-Checkpoint 2 -Please summarize hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) - -Now that you have trained two models, let's think about the following questions: -- What is the information content of each channel in the dataset? -- How would you use image translation models? -- What can you try to improve the performance of each model? +## Checkpoint 2 +When your model finishes training, please summarize hyperparameters and performance of your models in the [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
""" -# %% [markdown] +# %% tags=[] """ # Part 3: Tune the models. -------------------------------------------------- -Learning goals: +Learning goals: Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model. +""" + -- Tweak model hyperparameters, such as number of filters at each depth. -- Adjust learning rate to improve performance. +# %% [markdown] tags=[] """ +
-# %% -# %% +### Task 3.1 + +- Choose a model you want to train (phase2fluor or fluor2phase). +- Set up a configuration that you think will improve the performance of the model +- Consider modifying the learning rate and see how it changes performance +- Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop. +- Add code to evaluate the model using Pearson Correlation and SSIM + +As your model is training, please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) + +
+""" +# %% tags=[] ########################## ######## TODO ######## ########################## -# Choose a model you want to train (phase2fluor or fluor2phase). -# Create a config to double the number of filters at each stage. -# Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop. +tune_data = HCSDataModule( + # Your code here (copy from above and modify as needed) +) +tune_data.setup("fit") +# Dictionary that specifies key parameters of the model. +tune_config = { + # Your config here +} -# %% tags = ["solution"] +tune_model = VSUNet( + # Your code here (copy from above and modify as needed) +) + +trainer = VSTrainer( + # Your code here (copy from above and modify as needed) +) +trainer.fit(tune_model, datamodule=tune_data) + + +# Visualize the graph of fluor2phase model as image. +model_graph_tune = torchview.draw_graph( + tune_model, + tune_data.train_dataset[0]["source"], + depth=2, # adjust depth to zoom in. + device="cpu", +) +model_graph_tune.visual_graph + + +# %% tags=["solution"] ########################## ######## Solution ######## @@ -621,7 +860,7 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): batch_size=BATCH_SIZE, loss_function=torch.nn.functional.l1_loss, schedule="WarmupCosine", - log_num_samples=10, + log_num_samples=5, example_input_yx_shape=YX_PATCH_SIZE, ) @@ -641,16 +880,7 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): ) # Set fast_dev_run to False to train the model. trainer.fit(phase2fluor_wider_model, datamodule=phase2fluor_data) -# %% -########################## -######## TODO ######## -########################## - -# Choose a model you want to train (phase2fluor or fluor2phase). -# Train it with lower learning rate to see how the performance changes. - - -# %% tags = ["solution"] +# %% tags=["solution"] ########################## ######## Solution ######## @@ -663,7 +893,7 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): # lower learning rate by 5 times lr=2e-4, schedule="WarmupCosine", - log_num_samples=10, + log_num_samples=5, example_input_yx_shape=YX_PATCH_SIZE, ) @@ -683,12 +913,13 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): trainer.fit(phase2fluor_slow_model, datamodule=phase2fluor_data) -# %% [markdown] +# %% [markdown] tags=[] """
-Checkpoint 3 + +## Checkpoint 3 Congratulations! You have trained several image translation models now! -Please document hyperparameters, snapshots of predictioons on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) +Please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z). We'll discuss our combined results as a group.
"""