From 9a9e917005661deb8c58bec4c7f9617f88932154 Mon Sep 17 00:00:00 2001
From: edyoshikun
Date: Sun, 25 Aug 2024 03:03:50 +0000
Subject: [PATCH] - fixing n_channels - adding path to backup trained model -
 increased learning rate

---
 solution.py | 127 +++++++++++++++++++++++++++++++++-------------------
 1 file changed, 82 insertions(+), 45 deletions(-)

diff --git a/solution.py b/solution.py
index 40945e1..2c79463 100644
--- a/solution.py
+++ b/solution.py
@@ -4,55 +4,65 @@
 # Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco

 # ## Overview
-
-# In this exercise, we will predict fluorescence images of
-# nuclei and plasma membrane markers from quantitative phase images of cells,
-# i.e., we will _virtually stain_ the nuclei and plasma membrane
-# visible in the phase image.
-# This is an example of an image translation task.
-# We will apply spatial and intensity augmentations to train robust models
-# and evaluate their performance using a regression approach.
-
+#
+# In this exercise, we will _virtually stain_ the nuclei and plasma membrane from the quantitative phase image (QPI), i.e., translate QPI images into fluorescence images of nuclei and plasma membranes.
+# QPI encodes multiple cellular structures, and virtual staining decomposes these structures. Once the model is trained, only label-free QPI data needs to be acquired.
+# This strategy solves the same problem as multi-spectral fluorescence imaging, but is more compatible with live-cell imaging and high-throughput screening.
+# Virtual staining is often a step towards multiple downstream analyses: segmentation, tracking, and cell-state phenotyping.
+#
+# In this exercise, you will:
+# - Train a model to predict the fluorescence images of nuclei and plasma membranes from QPI images
+# - Make it robust to variations in imaging conditions using data augmentations
+# - Segment the cells
+# - Use regression and segmentation metrics to evaluate the models
+# - Visualize the image transform learned by the model
+# - Understand the failure modes of the trained model
+#
 # [![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755)
 # (Click on image to play video)
-
+#
 # %% [markdown] tags=[]
 # ### Goals

-# #### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard.
-
-# - Use a OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549),
-# each FOV has 3 channels (phase, nuclei, and cell membrane).
-# The nuclei were stained with DAPI and the cell membrane with Cellmask.
+# #### Part 1: Train a virtual staining model
+#
 # - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html)
 # and the high-content-screen (HCS) format.
-# - Use [MONAI](https://monai.io/) to implement data augmentations.
-
-# #### Part 2: Train and evaluate the model to translate phase into fluorescence.
-# - Train a 2D UNeXt2 model to predict nuclei and membrane from phase images.
-# - Compare the performance of the trained model and a pre-trained model.
+# - Use our `HCSDataModule` dataloader and explore the 3-channel (phase, fluorescence nuclei, and cell membrane)
+# A549 cell dataset.
+# - Implement data augmentations with [MONAI](https://monai.io/) to train a model that is robust to varying imaging parameters and conditions.
+# - Use TensorBoard to log the augmentations, the training and validation losses, and sample batches
+# - Start the training of the UNeXt2 model to predict nuclei and membrane from phase images.
+#
+# #### Part 2: Evaluate the model to translate phase into fluorescence.
+# - Compare the performance of your trained model with the _VSCyto2D_ pre-trained model.
 # - Evaluate the model using pixel-level and instance-level metrics.
-
-
-# Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos),
+#
+# #### Part 3: Visualize the image transforms learned by the model and explore the model's regime of validity
+# - Visualize the first 3 principal components of the features in each encoder and decoder block, mapped to a color space.
+# - Explore the model's regime of validity by applying blurring and scaling transforms to the input phase image.
+#
+# #### For more information:
+# Check out [VisCy](https://github.com/mehta-lab/VisCy),
 # our deep learning pipeline for training and deploying computer vision models
 # for image-based phenotyping including the robust virtual staining of landmark organelles.
+#
 # VisCy exploits recent advances in data and metadata formats
 # ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks,
 # [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/).

 # ### References
-
 # - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf)
 # - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502)

 # %% [markdown] tags=[]
-# <div class="alert alert-info">
-# The exercise is organized in 2 parts
+# <div class="alert alert-info">
+# The exercise is organized in 3 parts:
 # <ul>
-# <li>Part 1 - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.</li>
-# <li>Part 2 - Train and evaluate the model to translate phase into fluorescence.</li>
+# <li>Part 1 - Train a virtual staining model using iohub (I/O library), VisCy dataloaders, and tensorboard.</li>
+# <li>Part 2 - Evaluate the model to translate phase into fluorescence.</li>
+# <li>Part 3 - Visualize the image transforms learned by the model and explore the model's regime of validity.</li>
 # </ul>
 # </div>
@@ -62,7 +72,7 @@ # Set your python kernel to 06_image_translation #
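+# %% [markdown] tags=[]
+# (Editor's note) A minimal sketch, added here before Part 1, of browsing the OME-Zarr HCS
+# plate with [iohub](https://czbiohub-sf.github.io/iohub/main/index.html), as mentioned in
+# the goals. The dataset path below is a placeholder; point it to wherever the A549 data
+# lives on disk.
+
+# %%
+from iohub import open_ome_zarr
+
+# Open the plate read-only; the HCS layout on disk is row/col/field/pyramid_level/t/c/z/y/x.
+with open_ome_zarr("path/to/a549_plate.zarr", mode="r") as plate:  # hypothetical path
+    print("Channels:", plate.channel_names)
+    for name, position in plate.positions():
+        # Pyramid level "0" holds the highest-resolution (T, C, Z, Y, X) array.
+        print(name, position["0"].shape)
+        break  # inspect just the first FOV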
 # %% [markdown]
-# ## Part 1: Log training data to tensorboard, start training a model.
+# # Part 1: Log training data to TensorBoard, start training a model.
 # ---------

 # Learning goals:

@@ -188,12 +198,14 @@ def launch_tensorboard(log_dir):

 # ## Load OME-Zarr Dataset

 # There should be 34 FOVs in the dataset.
-
+#
 # Each FOV consists of 3 channels of 2048x2048 images,
 # saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout)
 # specified by the Open Microscopy Environment Next Generation File Format
 # (OME-NGFF).
-
+#
+# The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with DAPI and the cell membrane with CellMask.
+#
 # - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`
 # - These datasets only have 1 level in the pyramid (highest resolution) which is '0'.

@@ -344,6 +356,7 @@ def log_batch_jupyter(batch):
     p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
     batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)

+    n_channels = batch["target"].shape[1] + batch["source"].shape[1]
     plt.figure()
     fig, axes = plt.subplots(
         batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)
@@ -525,10 +538,16 @@ def log_batch_jupyter(batch):

 normalizations = [
     NormalizeSampled(
-        keys=source_channel + target_channel,
+        keys=source_channel,
         level="fov_statistics",
         subtrahend="mean",
         divisor="std",
+    ),
+    NormalizeSampled(
+        keys=target_channel,
+        level="fov_statistics",
+        subtrahend="median",
+        divisor="iqr",
     )
 ]

@@ -580,10 +599,16 @@ def log_batch_jupyter(batch):

 normalizations = [
     NormalizeSampled(
-        keys=source_channel + target_channel,
+        keys=source_channel,
         level="fov_statistics",
         subtrahend="mean",
         divisor="std",
+    ),
+    NormalizeSampled(
+        keys=target_channel,
+        level="fov_statistics",
+        subtrahend="median",
+        divisor="iqr",
     )
 ]

@@ -636,7 +661,7 @@ def log_batch_jupyter(batch):

 # Create a 2D UNet.
 GPU_ID = 0

-BATCH_SIZE = 12
+BATCH_SIZE = 16
 YX_PATCH_SIZE = (256, 256)

 # #######################
@@ -662,7 +687,7 @@ def log_batch_jupyter(batch):
     model_config=phase2fluor_config.copy(),
     loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
     schedule="WarmupCosine",
-    lr=2e-4,
+    lr=6e-4,
     log_batches_per_epoch=5,  # Number of samples from each batch to log to tensorboard.
     freeze_encoder=False,
 )
@@ -690,7 +715,7 @@ def log_batch_jupyter(batch):
 )
 phase2fluor_2D_data.setup("fit")
 # fast_dev_run runs a single batch of data through the model to check for errors.
-trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
+trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)

 # trainer class takes the model and the data module as inputs.
 trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)

@@ -701,7 +726,7 @@ def log_batch_jupyter(batch):

 # Here we are creating a 2D UNet.
 GPU_ID = 0
-BATCH_SIZE = 12
+BATCH_SIZE = 16
 YX_PATCH_SIZE = (256, 256)

 # Dictionary that specifies key parameters of the model.
@@ -724,7 +749,7 @@ def log_batch_jupyter(batch):
     model_config=phase2fluor_config.copy(),
     loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
     schedule="WarmupCosine",
-    lr=2e-4,
+    lr=6e-4,
     log_batches_per_epoch=5,  # Number of samples from each batch to log to tensorboard.
     freeze_encoder=False,
 )
@@ -751,7 +776,7 @@ def log_batch_jupyter(batch):
 )
 phase2fluor_2D_data.setup("fit")
 # fast_dev_run runs a single batch of data through the model to check for errors.
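+# NOTE (editor): precision="16-mixed" requests PyTorch Lightning's automatic mixed
+# precision: activations are computed in float16 while the master weights stay in
+# float32, which roughly halves activation memory and speeds up training on most GPUs.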
-trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
+trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)

 # trainer class takes the model and the data module as inputs.
 trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)

@@ -811,12 +836,13 @@ def log_batch_jupyter(batch):

 n_samples = len(phase2fluor_2D_data.train_dataset)
 steps_per_epoch = n_samples // BATCH_SIZE  # steps per epoch.
-n_epochs = 25  # Set this to 25-30 or the number of epochs you want to train for.
+n_epochs = 80  # Set this to 80-100 or the number of epochs you want to train for.

 trainer = VSTrainer(
     accelerator="gpu",
     devices=[GPU_ID],
     max_epochs=n_epochs,
+    precision="16-mixed",
     log_every_n_steps=steps_per_epoch // 2,
     # log losses and image samples 2 times per epoch.
     logger=TensorBoardLogger(
@@ -846,7 +872,7 @@ def log_batch_jupyter(batch):
 #

 # %% [markdown] tags=[]
-# ## Part 2: Assess your trained model
+# # Part 2: Assess your trained model

 # Now we will look at some performance metrics of the previous model.
 # We typically evaluate the model performance on held-out test data.

@@ -897,7 +923,7 @@ def log_batch_jupyter(batch):
 #
 #```python
 #phase2fluor_model_ckpt = natsorted(glob(
-# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
+# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
 #))[-1]
 #```
 #

@@ -1084,10 +1110,9 @@ def process_image(image):
 # NOTE: if your model didn't get past epoch 5, you lost your checkpoint, or you didn't train anything,
 # uncomment the next lines
 #phase2fluor_model_ckpt = natsorted(glob(
-# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
+# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
 #))[-1]
-
 phase2fluor_config = dict(
     in_channels=1,
     out_channels=2,

@@ -1488,6 +1513,12 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
 #
 #

+# %% [markdown] tags=[]
+# # Part 3: Visualizing the encoder and decoder features & exploring the model's regime of validity
+#
+# - In this section, we will visualize the encoder and decoder features of the model you trained.
+# - We will also explore the model's regime of validity by perturbing the input phase image, e.g., with blurring and intensity scaling.
+#
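+# %% [markdown] tags=[]
+# (Editor's note) A self-contained sketch of the PCA-to-color trick described above:
+# project the channel dimension of a feature map onto its first 3 principal components
+# and display them as RGB. Here `features` is a random stand-in for the output of an
+# encoder or decoder block of your trained model.
+
+# %%
+import numpy as np
+from sklearn.decomposition import PCA
+
+features = np.random.rand(64, 32, 32)  # stand-in feature map of shape (C, H, W)
+C, H, W = features.shape
+flat = features.reshape(C, H * W).T  # one row per pixel, one column per channel
+rgb = PCA(n_components=3).fit_transform(flat)  # (H * W, 3)
+lo, hi = rgb.min(axis=0), rgb.max(axis=0)
+rgb = (rgb - lo) / (hi - lo + 1e-8)  # min-max scale each component to [0, 1]
+rgb = rgb.reshape(H, W, 3)  # ready for plt.imshow(rgb)

 # %% [markdown] tags=[]
 #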
#

Task 3.1: Let's look at what the model is learning

@@ -1696,10 +1727,16 @@ def clip_highlight(image: np.ndarray) -> np.ndarray: normalizations = [ NormalizeSampled( - keys=source_channel + target_channel, + keys=source_channel, level="fov_statistics", subtrahend="mean", divisor="std", + ), + NormalizeSampled( + keys=target_channel, + level="fov_statistics", + subtrahend="median", + divisor="iqr", ) ]
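+# %% [markdown] tags=[]
+# (Editor's note) A minimal NumPy sketch of the two normalization schemes used above:
+# mean/std for the phase (source) channel and the outlier-robust median/IQR for the
+# fluorescence (target) channels. Bright fluorescent debris inflates the mean and std,
+# but barely moves the median and interquartile range.
+
+# %%
+import numpy as np
+
+img = np.random.lognormal(sigma=1.0, size=(256, 256))  # skewed, fluorescence-like intensities
+
+zscore = (img - img.mean()) / img.std()  # sensitive to bright outliers
+
+q25, q75 = np.percentile(img, (25, 75))
+robust = (img - np.median(img)) / (q75 - q25)  # median/IQR scaling, as in the transform above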