Commit from GitHub Actions (Build Notebooks)
edyoshikun committed Aug 25, 2024
1 parent 9a9e917 commit fc6e937
Showing 2 changed files with 151 additions and 70 deletions.
107 changes: 74 additions & 33 deletions exercise.ipynb
@@ -4,6 +4,7 @@
"cell_type": "markdown",
"id": "fe88fdbe",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
},
"source": [
@@ -13,16 +14,21 @@
"\n",
"## Overview\n",
"\n",
"In this exercise, we will predict fluorescence images of\n",
"nuclei and plasma membrane markers from quantitative phase images of cells,\n",
"i.e., we will _virtually stain_ the nuclei and plasma membrane\n",
"visible in the phase image.\n",
"This is an example of an image translation task.\n",
"We will apply spatial and intensity augmentations to train robust models\n",
"and evaluate their performance using a regression approach.\n",
"In this exercise, we will _virtually stain_ the nuclei and plasma membrane from the quantitative phase image (QPI), i.e., translate QPI images into fluoresence images of nuclei and plasma membranes.\n",
"QPI encodes multiple cellular structures and virtual staining decomposes these structures. After the model is trained, one only needs to acquire label-free QPI data.\n",
"This strategy solves the problem as \"multi-spectral imaging\", but is more compatible with live cell imaging and high-throughput screening.\n",
"Virtual staining is often a step towards multiple downstream analyses: segmentation, tracking, and cell state phenotyping.\n",
"\n",
"In this exercise, you will:\n",
"- Train a model to predict the fluorescence images of nuclei and plasma membranes from QPI images\n",
"- Make it robust to variations in imaging conditions using data augmentions\n",
"- Segment the cells\n",
"- Use regression and segmentation metrics to evalute the models\n",
"- Visualize the image transform learned by the model\n",
"- Understand the failure modes of the trained model\n",
"\n",
"[![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755)\n",
"(Click on image to play video)"
"(Click on image to play video)\n"
]
},
{
@@ -34,30 +40,34 @@
"source": [
"### Goals\n",
"\n",
"#### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard.\n",
"#### Part 1: Train a virtual staining model\n",
"\n",
" - Use a OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549),\n",
" each FOV has 3 channels (phase, nuclei, and cell membrane).\n",
" The nuclei were stained with DAPI and the cell membrane with Cellmask.\n",
" - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html)\n",
" and the high-content-screen (HCS) format.\n",
" - Use [MONAI](https://monai.io/) to implement data augmentations.\n",
"\n",
"#### Part 2: Train and evaluate the model to translate phase into fluorescence.\n",
" - Train a 2D UNeXt2 model to predict nuclei and membrane from phase images.\n",
" - Compare the performance of the trained model and a pre-trained model.\n",
" - Use our `viscy.data.HCSDataloader()` dataloader and explore the 3 channel (phase, fluoresecence nuclei and cell membrane) \n",
" A549 cell dataset. \n",
" - Implement data augmentations [MONAI](https://monai.io/) to train a robust model to imaging parameters and conditions. \n",
" - Use tensorboard to log the augmentations, training and validation losses and batches\n",
" - Start the training of the UNeXt2 model to predict nuclei and membrane from phase images.\n",
"\n",
"#### Part 2:Evaluate the model to translate phase into fluorescence.\n",
" - Compare the performance of your trained model with the _VSCyto2D_ pre-trained model.\n",
" - Evaluate the model using pixel-level and instance-level metrics.\n",
"\n",
"#### Part 3: Visualize the image transforms learned by the model and explore the model's regime of validity\n",
" - Visualize the first 3 principal componets mapped to a color space in each encoder and decoder block.\n",
" - Explore the model's regime of validity by applying blurring and scaling transforms to the input phase image.\n",
"\n",
"Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos),\n",
"#### For more information:\n",
"Checkout [VisCy](https://github.com/mehta-lab/VisCy),\n",
"our deep learning pipeline for training and deploying computer vision models\n",
"for image-based phenotyping including the robust virtual staining of landmark organelles.\n",
"\n",
"VisCy exploits recent advances in data and metadata formats\n",
"([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks,\n",
"[PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/).\n",
"\n",
"### References\n",
"\n",
"- [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf)\n",
"- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502)"
]
@@ -69,12 +79,13 @@
"tags": []
},
"source": [
"<div class=\"alert alert-info\">\n",
"The exercise is organized in 2 parts\n",
"<div class=\"alert alert-success\">\n",
"The exercise is organized in 3 parts:\n",
"\n",
"<ul>\n",
"<li><b>Part 1</b> - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.</li>\n",
"<li><b>Part 2</b> - Train and evaluate the model to translate phase into fluorescence.</li>\n",
"<li><b>Part 1</b> - Train a virtual staining model using iohub (I/O library), VisCy dataloaders, and tensorboard</li>\n",
"<li><b>Part 2</b> - Evaluate the model to translate phase into fluorescence.</li>\n",
"<li><b>Part 3</b> - Visualize the image transforms learned by the model and explore the model's regime of validity.</li>\n",
"</ul>\n",
"\n",
"</div>"
@@ -98,7 +109,7 @@
"id": "b67b4355",
"metadata": {},
"source": [
"## Part 1: Log training data to tensorboard, start training a model.\n",
"# Part 1: Log training data to tensorboard, start training a model.\n",
"---------\n",
"Learning goals:\n",
"\n",
@@ -280,6 +291,8 @@
"specified by the Open Microscopy Environment Next Generation File Format\n",
"(OME-NGFF).\n",
"\n",
"The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with DAPI and the cell membrane with Cellmask.\n",
"\n",
"- The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`\n",
"- These datasets only have 1 level in the pyramid (highest resolution) which is '0'."
]
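For readers following along, a minimal iohub sketch for browsing this layout might look like the following (the store path is a placeholder, and the printed channel names depend on your copy of the data):

```python
from iohub import open_ome_zarr

# Placeholder path -- point this at the A549 OME-Zarr store on your system.
data_path = "training/a549_hoechst_cellmask_train_val.zarr"

with open_ome_zarr(data_path, layout="hcs", mode="r") as plate:
    print("Channels:", plate.channel_names)
    # positions() walks the row/col/field hierarchy; each position is one FOV.
    for fov_name, fov in plate.positions():
        image = fov["0"]  # pyramid level 0; shape (T, C, Z, Y, X)
        print(fov_name, image.shape)
        break  # inspect just the first FOV
```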
@@ -486,6 +499,7 @@
" p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
" batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
"\n",
" n_channels = batch[\"target\"].shape[1] + batch[\"source\"].shape[1]\n",
" plt.figure()\n",
" fig, axes = plt.subplots(\n",
" batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)\n",
@@ -689,10 +703,16 @@
"\n",
"normalizations = [\n",
" NormalizeSampled(\n",
" keys=source_channel + target_channel,\n",
" keys=source_channel,\n",
" level=\"fov_statistics\",\n",
" subtrahend=\"mean\",\n",
" divisor=\"std\",\n",
" ),\n",
" NormalizeSampled(\n",
" keys=target_channel,\n",
" level=\"fov_statistics\",\n",
" subtrahend=\"median\",\n",
" divisor=\"iqr\",\n",
" )\n",
"]\n",
"\n",
@@ -788,7 +808,7 @@
"# Create a 2D UNet.\n",
"GPU_ID = 0\n",
"\n",
"BATCH_SIZE = 12\n",
"BATCH_SIZE = 16\n",
"YX_PATCH_SIZE = (256, 256)\n",
"\n",
"# #######################\n",
@@ -814,7 +834,7 @@
" model_config=phase2fluor_config.copy(),\n",
" loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),\n",
" schedule=\"WarmupCosine\",\n",
" lr=2e-4,\n",
" lr=6e-4,\n",
" log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.\n",
" freeze_encoder=False,\n",
")\n",
@@ -842,7 +862,7 @@
")\n",
"phase2fluor_2D_data.setup(\"fit\")\n",
"# fast_dev_run runs a single batch of data through the model to check for errors.\n",
"trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n",
"trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], precision='16-mixed' ,fast_dev_run=True)\n",
"\n",
"# trainer class takes the model and the data module as inputs.\n",
"trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)"
@@ -887,7 +907,7 @@
")\n",
"phase2fluor_2D_data.setup(\"fit\")\n",
"# fast_dev_run runs a single batch of data through the model to check for errors.\n",
"trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n",
"trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID],precision='16-mixed', fast_dev_run=True)\n",
"\n",
"# trainer class takes the model and the data module as inputs.\n",
"trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)"
@@ -984,12 +1004,13 @@
"\n",
"n_samples = len(phase2fluor_2D_data.train_dataset)\n",
"steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n",
"n_epochs = 25 # Set this to 25-30 or the number of epochs you want to train for.\n",
"n_epochs = 80 # Set this to 80-100 or the number of epochs you want to train for.\n",
"\n",
"trainer = VSTrainer(\n",
" accelerator=\"gpu\",\n",
" devices=[GPU_ID],\n",
" max_epochs=n_epochs,\n",
" precision='16-mixed',\n",
" log_every_n_steps=steps_per_epoch // 2,\n",
" # log losses and image samples 2 times per epoch.\n",
" logger=TensorBoardLogger(\n",
@@ -1035,7 +1056,7 @@
"tags": []
},
"source": [
"## Part 2: Assess your trained model\n",
"# Part 2: Assess your trained model\n",
"\n",
"Now we will look at some metrics of performance of previous model.\n",
"We typically evaluate the model performance on a held out test data.\n",
@@ -1115,7 +1136,7 @@
"\n",
"```python\n",
"phase2fluor_model_ckpt = natsorted(glob(\n",
" str(top_dir/\"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt\")\n",
" str(top_dir/\"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt\")\n",
"))[-1]\n",
"````\n",
"</div>"
@@ -1789,6 +1810,20 @@
"</div>"
]
},
{
"cell_type": "markdown",
"id": "1aed3296",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
},
"source": [
"# Part 3: Visualizing the encoder and decoder features & exploring the model's range of validity\n",
"\n",
"- In this section, we will visualize the encoder and decoder features of the model you trained.\n",
"- We will also explore the model's range of validity by looking at the feature maps of the encoder and decoder.\n"
]
},
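A common recipe for the visualization described above: flatten each feature map to (pixels x channels), fit a 3-component PCA, and display the components as RGB. A sketch, assuming `features` is a (C, H, W) array captured from an encoder or decoder block (e.g., with a forward hook):

```python
import numpy as np
from sklearn.decomposition import PCA

def pca_to_rgb(features: np.ndarray) -> np.ndarray:
    # Project a (C, H, W) feature map onto its first 3 principal components
    # and rescale each to [0, 1] so the result can be shown as an RGB image.
    c, h, w = features.shape
    flat = features.reshape(c, -1).T               # (H*W, C): one sample per pixel
    pcs = PCA(n_components=3).fit_transform(flat)  # (H*W, 3)
    rgb = pcs.reshape(h, w, 3)
    rgb -= rgb.min(axis=(0, 1))
    rgb /= rgb.max(axis=(0, 1)) + 1e-8
    return rgb
```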
{
"cell_type": "markdown",
"id": "a82a8dfd",
@@ -2079,10 +2114,16 @@
"\n",
"normalizations = [\n",
" NormalizeSampled(\n",
" keys=source_channel + target_channel,\n",
" keys=source_channel,\n",
" level=\"fov_statistics\",\n",
" subtrahend=\"mean\",\n",
" divisor=\"std\",\n",
" ),\n",
" NormalizeSampled(\n",
" keys=target_channel,\n",
" level=\"fov_statistics\",\n",
" subtrahend=\"median\",\n",
" divisor=\"iqr\",\n",
" )\n",
"]\n",
"\n",