The exercise is organized in 3 parts.
* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.
* **Part 2** - Evaluate the training with tensorboard. Train another model.
* **Part 3** - Tune the models to improve performance.
"""
# %% [markdown]
"""
📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.

Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~20 min on a typical AWS node.
We will discuss your observations on the google doc after checkpoints 2 and 3.
The focus of the exercise is on understanding the information content of the data, how to train and evaluate a 2D image translation model, and how to explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation.
"""
# %% [markdown]
"""
# Part 1: Log training data to tensorboard, start training a model.
---------
- Log some patches to tensorboard.
- Initialize a 2D U-Net model for virtual staining.
- Start training the model to predict nuclei and membrane from phase.
"""
# %% Imports and paths
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchview
from iohub import open_ome_zarr
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from skimage import metrics  # for image quality metrics such as SSIM.
from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard.
# HCSDataModule makes it easy to load data during training.
from viscy.light.data import HCSDataModule

# VSUNet and VSTrainer are used below; these import paths are assumed from the
# VisCy version linked in this exercise.
from viscy.light.engine import VSUNet
from viscy.light.trainer import VSTrainer
# Create log directory if needed.
log_dir.mkdir(parents=True, exist_ok=True)
# %% [markdown] tags=[]
"""
The next cell starts tensorboard within the notebook.
"""

# %% tags=[]
+%reload_ext tensorboard
+%tensorboard --logdir {log_dir}
# %% [markdown]
"""
## Load Dataset.
There should be 301 FOVs in the dataset (12 GB compressed).
Each FOV consists of 3 channels of 2048x2048 images.
The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x.
Notice that the labelling of the nuclei channel is not complete: some cells do not express the fluorescent protein.
"""
# %%
dataset = open_ome_zarr(data_path)
print(f"Number of positions: {len(list(dataset.positions()))}")
# Use the field and pyramid_level below to visualize data.
-row = "0"
-col = "0"
-field = "23"
+row = 0
+col = 0
+field = 23 # TODO: Change this to explore data.
# This dataset contains images at 3 resolutions.
# '0' is the highest resolution
# '1' is down-scaled 2x2,
# '2' is down-scaled 4x4.
# Such datasets are called image pyramids.
pyramid_level = 0
# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.
n_channels = len(dataset.channel_names)
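# %%
# A minimal sketch of viewing one FOV (an assumption-laden illustration: it relies on
# iohub's path-style indexing, `plate["row/col/field"]`, and the TCZYX axis order
# described above).
position = dataset[f"{row}/{col}/{field}"]
image = position[str(pyramid_level)]  # array with (T, C, Z, Y, X) axes.
fig, axes = plt.subplots(1, n_channels, figsize=(12, 4))
for c in range(n_channels):
    axes[c].imshow(image[0, c, 0], cmap="gray")  # t=0, channel c, z=0.
    axes[c].set_title(dataset.channel_names[c])
    axes[c].axis("off")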
plt.tight_layout()
# %% [markdown]
-"""
-## Initialize data loaders and see the samples in tensorboard.
+#
+#
+# ### Task 1.1
+#
+# Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining.
+#
`
# %% [markdown]
"""
## Explore the effects of augmentation on a batch.

VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule`, to make it intuitive to access data stored in the HCS layout.

Each sample produced by the dataloader is a dictionary with the following fields:
- `source`: the input image, a tensor of size 1*1*Y*X
- `target`: the target image, a tensor of size 2*1*Y*X
- `index`: the tuple of (location of field in HCS layout, time, and z-slice) of the sample.
"""
# %% [markdown]
#
# ### Task 1.2
#
# Set up the data loader and log several batches to tensorboard.
#
# Based on the tensorboard images, what are the two channels in the target image?
#
# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab.
#
# %%
# Define a function to write a batch to the tensorboard log.
def log_batch_tensorboard(batch, batchno, writer, card_name):
    """
    Logs a batch of images to TensorBoard.
    """
    # ... (body elided: normalizes each channel and assembles the batch into `grid`) ...
    writer.add_image(card_name, grid, batchno)
# %%
# Define a function to visualize a batch in Jupyter, in case tensorboard is finicky.

def log_batch_jupyter(batch):
    """
    Displays a batch of images in Jupyter using matplotlib.

    Args:
        batch (dict): A dictionary containing the batch of images to be displayed.

    Returns:
        None
    """
    batch_phase = batch["source"][:, :, 0, :, :]  # batch_size x 1 x Y x X tensor.
    batch_size = batch_phase.shape[0]
    batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.
    batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.

    # Clip each channel to its 0.1 and 99.9 percentiles and rescale to [0, 1].
    p_low, p_high = np.percentile(batch_membrane, (0.1, 99.9))
    batch_membrane = np.clip((batch_membrane - p_low) / (p_high - p_low), 0, 1)

    p_low, p_high = np.percentile(batch_nuclei, (0.1, 99.9))
    batch_nuclei = np.clip((batch_nuclei - p_low) / (p_high - p_low), 0, 1)

    p_low, p_high = np.percentile(batch_phase, (0.1, 99.9))
    batch_phase = np.clip((batch_phase - p_low) / (p_high - p_low), 0, 1)

    fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10))
    for sample_id in range(batch_size):
        axes[sample_id, 0].imshow(batch_phase[sample_id, 0])
        axes[sample_id, 1].imshow(batch_nuclei[sample_id, 0])
        axes[sample_id, 2].imshow(batch_membrane[sample_id, 0])

        for i in range(n_channels):
            axes[sample_id, i].axis("off")
            axes[sample_id, i].set_title(dataset.channel_names[i])
    plt.tight_layout()
    plt.show()

# %%
# Initialize the data module.
BATCH_SIZE = 4
# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.
# More seriously, batch size does not have to be a power of 2.
# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
data_module = HCSDataModule(
    data_path,
    source_channel="Phase",
    target_channel=["Membrane", "Nuclei"],
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    # ... (remaining arguments elided) ...
)
data_module.setup("fit")
train_dataloader = data_module.train_dataloader()
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
# Draw a batch and write to tensorboard.
batch = next(iter(train_dataloader))
log_batch_tensorboard(batch, 0, writer, "augmentation/none")
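# As a quick check, the printed shapes should match the sample structure described
# earlier: source is B x 1 x 1 x Y x X and target is B x 2 x 1 x Y x X.
print(batch["source"].shape, batch["target"].shape)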
writer.close()
# %% [markdown]
-"""
-There are multiple ways of seeing the tensorboard.
-1. Jupyter lab forwards the tensorboard port to the browser. Go to http://localhost:6006/ to see the tensorboard.
-2. You likely have an open viewer in the first cell where you loaded tensorboard jupyter extension.
-3. If you want to see tensorboard in a specific cell, use the following code.
-```
-notebook.list() # View open TensorBoard instances
-notebook.display(port=6006, height=800) # Display the TensorBoard instance specified by the port.
-```
-"""
+# Visualize directly on Jupyter ☄️, if your tensorboard is causing issues.
+
+# %%
+%matplotlib inline
+log_batch_jupyter(batch)
# %% [markdown]
"""
## View augmentations using tensorboard.
"""
# %%
# Here we turn on data augmentation and rerun setup.
data_module.augment = True
data_module.setup("fit")
# Get the new data loader with augmentation turned on.
augmented_train_dataloader = data_module.train_dataloader()

# Draw batches and write to tensorboard
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
augmented_batch = next(iter(augmented_train_dataloader))
log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
writer.close()
# %% [markdown]
-"""
-## Construct a 2D U-Net for image translation.
-See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.
-We setup a fresh data module and instantiate the trainer class.
-"""
+# Visualize directly on Jupyter ☄️
# %%
log_batch_jupyter(augmented_batch)

# %% [markdown]
#
# ### Task 1.3
# Can you tell which augmentations were applied by looking at the augmented images in tensorboard?
#
# Check your answer against the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529).
#
# %% [markdown]
"""
## Train a 2D U-Net model to predict nuclei and membrane from phase.

### Construct a 2D U-Net
See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details.
"""
# %%
# Create a 2D UNet.
GPU_ID = 0
BATCH_SIZE = 10
YX_PATCH_SIZE = (512, 512)
# ... (model configuration elided) ...
phase2fluor_model = VSUNet(
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.l1_loss,
    schedule="WarmupCosine",
    log_num_samples=5,  # Number of samples from each batch to log to tensorboard.
    example_input_yx_shape=YX_PATCH_SIZE,
)

# %% [markdown]
"""
### Instantiate the data module and trainer, and test that we are set up to launch training.
"""
# %%
# Set up the data module.
phase2fluor_data = HCSDataModule(
    data_path,
    source_channel="Phase",
    target_channel=["Membrane", "Nuclei"],
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    # ... (other arguments elided) ...
    augment=True,
)
phase2fluor_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)

# The trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
# %% [markdown]
-"""
-
-Task 1.4
-Setup the training for ~30 epochs
-
+# ## View model graph.
+#
+# PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging.
-Tips:
-- Set ``default_root_dir`` to store the logs and checkpoints
-in a specific directory.
-"""
+# %% [markdown]
+#
+#
+# ### Task 1.4
+# Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization?
+#
# %%
# Visualize the graph of the phase2fluor model as an image.
model_graph_phase2fluor = torchview.draw_graph(
    phase2fluor_model,
    phase2fluor_data.train_dataset[0]["source"],
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Display the image of the model.
model_graph_phase2fluor.visual_graph
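
# %%
# Optional sketch: torchview can also trace the graph from a random tensor of the
# same shape as a real sample (a B x C x Z x Y x X layout is assumed here), which
# is handy before any data is loaded.
random_input = torch.randn(1, 1, 1, *YX_PATCH_SIZE)
model_graph_random = torchview.draw_graph(
    phase2fluor_model,
    random_input,
    depth=1,  # coarser view of the same architecture.
    device="cpu",
)
model_graph_random.visual_graph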
# %% [markdown]
"""
### Task 1.5
Start training by running the following cell. Check the new logs on tensorboard.
"""
# %%
@@ -393,66 +465,187 @@ def log_batch_tensorboard(batch, batchno, writer, card_name):
GPU_ID = 0
n_samples = len(phase2fluor_data.train_dataset)
steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
n_epochs = 50  # Set this to the number of epochs you want to train for.
trainer = VSTrainer(
    accelerator="gpu",
    devices=[GPU_ID],
    max_epochs=n_epochs,
    # Log losses and image samples 2 times per epoch.
    log_every_n_steps=steps_per_epoch // 2,
    # Lightning trainer transparently saves logs and model checkpoints in this directory.
    logger=TensorBoardLogger(
        save_dir=log_dir,
        name="phase2fluor",
        log_graph=True,
    ),
)
# Launch training and check that loss and images are being logged on tensorboard.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
# %% [markdown]
"""

## Checkpoint 1
Now the training has started,
we can come back after a while and evaluate the performance!
"""
# %% [markdown]
"""
# Part 2: Assess previous model, train fluorescence to phase contrast translation model.
--------------------------------------------------
+"""
-Learning goals:
-- Visualize the previous model and training with tensorboard
-- Train fluorescence to phase contrast translation model
-- Compare the performance of the two models.
+# %% [markdown]
+"""
+We now look at some metrics of performance of previous model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model:
+- [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
+- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM).
+You should also look at the validation samples on tensorboard (hint: the experimental data in nuclei channel is imperfect.)
"""
# %% [markdown]
"""
### Task 2.1 Define metrics

For each of the above metrics, write a brief definition of what they are and what they mean for this image translation task.
"""

# %% [markdown]
# ```
# #######################
# ##### Todo ############
# #######################
# ```
#
# - Pearson correlation:
#
# - Structural similarity:
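
# %% [markdown]
"""
To make the metrics concrete before you write your definitions, the toy sketch below computes both on a synthetic pair of images (illustrative only; the evaluation on real predictions follows).
"""
# %%
rng = np.random.default_rng(0)
clean = rng.random((64, 64))
noisy = np.clip(clean + 0.1 * rng.standard_normal((64, 64)), 0, 1)
# Pearson correlation: linear pixel-wise agreement between the two images, in [-1, 1].
print("Pearson:", np.corrcoef(clean.flatten(), noisy.flatten())[0, 1])
# SSIM: local similarity of luminance, contrast, and structure, in [-1, 1].
print("SSIM:", metrics.structural_similarity(clean, noisy, data_range=1))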

# %% Compute metrics directly and plot here.
test_data_path = Path(
    "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr"
).expanduser()

test_data = HCSDataModule(
    test_data_path,
    source_channel="Phase",
    target_channel=["Membrane", "Nuclei"],
    z_window_size=1,
    batch_size=1,
    num_workers=8,
    architecture="2D",
)
test_data.setup("test")

test_metrics = pd.DataFrame(
    columns=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"]
)


def min_max_scale(input):
    # Rescale the input to span [0, 1].
    return (input - np.min(input)) / (np.max(input) - np.min(input))
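# Quick check of the helper above:
# min_max_scale(np.array([2.0, 4.0, 6.0])) -> array([0. , 0.5, 1. ])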

# %% Compute metrics directly and plot here.
for i, sample in enumerate(test_data.test_dataloader()):
    phase_image = sample["source"]
    with torch.inference_mode():  # turn off gradient computation.
        predicted_image = phase2fluor_model(phase_image)

    target_image = (
        sample["target"].cpu().numpy().squeeze(0)
    )  # Squeezing batch dimension.
    predicted_image = predicted_image.cpu().numpy().squeeze(0)
    phase_image = phase_image.cpu().numpy().squeeze(0)
    target_mem = min_max_scale(target_image[1, 0, :, :])
    target_nuc = min_max_scale(target_image[0, 0, :, :])
    # slicing channel dimension, squeezing z-dimension.
    predicted_mem = min_max_scale(predicted_image[1, :, :, :].squeeze(0))
    predicted_nuc = min_max_scale(predicted_image[0, :, :, :].squeeze(0))

    # Compute SSIM and Pearson correlation.
    ssim_nuc = metrics.structural_similarity(target_nuc, predicted_nuc, data_range=1)
    ssim_mem = metrics.structural_similarity(target_mem, predicted_mem, data_range=1)
    pearson_nuc = np.corrcoef(target_nuc.flatten(), predicted_nuc.flatten())[0, 1]
    pearson_mem = np.corrcoef(target_mem.flatten(), predicted_mem.flatten())[0, 1]

    test_metrics.loc[i] = {
        "pearson_nuc": pearson_nuc,
        "SSIM_nuc": ssim_nuc,
        "pearson_mem": pearson_mem,
        "SSIM_mem": ssim_mem,
    }

test_metrics.boxplot(
    column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"],
    rot=30,
)
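
# Optional: also summarize the metrics numerically; the median is robust to outlier FOVs.
print(test_metrics.median())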

# %% [markdown] tags=[]
"""
### Task 2.2 Train a fluorescence to phase contrast translation model

Instantiate a data module, model, and trainer for fluorescence to phase contrast translation. Copy over the code from the previous cells and update the parameters. Give the variables and paths a different name/suffix (fluor2phase) to avoid overwriting the objects used to train the phase2fluor models.
"""
# %% tags=[]
##########################
######## TODO ########
##########################

fluor2phase_data = HCSDataModule(
    # Your code here (copy from above and modify as needed)
)
fluor2phase_data.setup("fit")

# Dictionary that specifies key parameters of the model.
fluor2phase_config = {
    # Your config here
}

fluor2phase_model = VSUNet(
    # Your code here (copy from above and modify as needed)
)

trainer = VSTrainer(
    # Your code here (copy from above and modify as needed)
)
trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)


# Visualize the graph of the fluor2phase model as an image.
model_graph_fluor2phase = torchview.draw_graph(
    fluor2phase_model,
    fluor2phase_data.train_dataset[0]["source"],
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
model_graph_fluor2phase.visual_graph

# %% tags=["solution"]
##########################
######## Solution ########
##########################

# The entire training loop is contained in this cell.
fluor2phase_data = HCSDataModule(
    data_path,
    source_channel="Membrane",
    target_channel="Phase",
    z_window_size=1,
    split_ratio=0.8,
    # ... (other arguments elided) ...
)
fluor2phase_data.setup("fit")

# ... (model configuration elided) ...
fluor2phase_model = VSUNet(
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.mse_loss,
    schedule="WarmupCosine",
    log_num_samples=5,
    example_input_yx_shape=YX_PATCH_SIZE,
)
trainer = VSTrainer(
    accelerator="gpu",
    devices=[GPU_ID],
    max_epochs=n_epochs,
    log_every_n_steps=steps_per_epoch // 2,
    # Lightning trainer transparently saves logs and model checkpoints in this directory.
    logger=TensorBoardLogger(
        save_dir=log_dir,
        name="fluor2phase",
        log_graph=True,
    ),
)
trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)

# Visualize the graph of the fluor2phase model as an image.
model_graph_fluor2phase = torchview.draw_graph(
    fluor2phase_model,
    fluor2phase_data.train_dataset[0]["source"],
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
model_graph_fluor2phase.visual_graph
# %% [markdown] tags=[]
"""

### Task 2.3
While your model is training, let's think about the following questions:
- What is the information content of each channel in the dataset?
- How would you use image translation models?
- What can you try to improve the performance of each model?
"""
# %%
test_data_path = Path(
    "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr"
).expanduser()
test_data = HCSDataModule(
    test_data_path,
    source_channel="Nuclei",  # or "Membrane", depending on your choice of source.
    target_channel="Phase",
    z_window_size=1,
    batch_size=1,
    num_workers=8,
    architecture="2D",
)
test_data.setup("test")

test_metrics = pd.DataFrame(
    columns=["pearson_phase", "SSIM_phase"]
)


# Same helper as defined above.
def min_max_scale(input):
    return (input - np.min(input)) / (np.max(input) - np.min(input))


# %%
for i, sample in enumerate(test_data.test_dataloader()):
    source_image = sample["source"]
    with torch.inference_mode():  # turn off gradient computation.
        predicted_image = fluor2phase_model(source_image)

    target_image = (
        sample["target"].cpu().numpy().squeeze(0)
    )  # Squeezing batch dimension.
    predicted_image = predicted_image.cpu().numpy().squeeze(0)
    source_image = source_image.cpu().numpy().squeeze(0)
    target_phase = min_max_scale(target_image[0, 0, :, :])
    # slicing channel dimension, squeezing z-dimension.
    predicted_phase = min_max_scale(predicted_image[0, :, :, :].squeeze(0))

    # Compute SSIM and Pearson correlation.
    ssim_phase = metrics.structural_similarity(
        target_phase, predicted_phase, data_range=1
    )
    pearson_phase = np.corrcoef(target_phase.flatten(), predicted_phase.flatten())[0, 1]

    test_metrics.loc[i] = {
        "pearson_phase": pearson_phase,
        "SSIM_phase": ssim_phase,
    }

test_metrics.boxplot(
    column=["pearson_phase", "SSIM_phase"],
    rot=30,
)

# %% [markdown] tags=[]
"""
## Checkpoint 2
When your model finishes training, please summarize the hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z).
"""
# %% [markdown] tags=[]
"""
# Part 3: Tune the models.
--------------------------------------------------
Learning goals: Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model.
"""

# %% [markdown] tags=[]
"""

### Task 3.1

- Choose a model you want to train (phase2fluor or fluor2phase).
- Set up a configuration that you think will improve the performance of the model.
- Consider modifying the learning rate and see how it changes performance.
- Use the training loops illustrated in the previous cells to prototype your own training loop.
- Add code to evaluate the model using Pearson correlation and SSIM (a reusable sketch follows the TODO cell below).

As your model is training, please document hyperparameters, snapshots of predictions on the validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
"""
# %% tags=[]
##########################
######## TODO ########
##########################
tune_data = HCSDataModule(
    # Your code here (copy from above and modify as needed)
)
tune_data.setup("fit")
# Dictionary that specifies key parameters of the model.
tune_config = {
    # Your config here
}

tune_model = VSUNet(
    # Your code here (copy from above and modify as needed)
)

trainer = VSTrainer(
    # Your code here (copy from above and modify as needed)
)
trainer.fit(tune_model, datamodule=tune_data)


# Visualize the graph of the tuned model as an image.
model_graph_tune = torchview.draw_graph(
    tune_model,
    tune_data.train_dataset[0]["source"],
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
model_graph_tune.visual_graph
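
# %%
# A reusable sketch of the Part 2 evaluation loop (a hypothetical helper, assuming the
# same sample-dictionary structure), so tuned models can be scored the same way.
# It scores the first target channel only; extend it per channel as in Part 2.
def compute_test_metrics(model, dataloader):
    records = []
    for sample in dataloader:
        with torch.inference_mode():
            prediction = model(sample["source"])
        target = sample["target"].cpu().numpy().squeeze(0)
        prediction = prediction.cpu().numpy().squeeze(0)
        t = min_max_scale(target[0, 0, :, :])
        p = min_max_scale(prediction[0, 0, :, :])
        records.append(
            {
                "pearson": np.corrcoef(t.flatten(), p.flatten())[0, 1],
                "SSIM": metrics.structural_similarity(t, p, data_range=1),
            }
        )
    return pd.DataFrame(records)


# For example: compute_test_metrics(tune_model, test_data.test_dataloader()).boxplot()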


# %% tags=["solution"]
##########################
######## Solution ########
# ... (wider model configuration elided) ...
phase2fluor_wider_model = VSUNet(
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.l1_loss,
    schedule="WarmupCosine",
    log_num_samples=5,
    example_input_yx_shape=YX_PATCH_SIZE,
)
trainer = VSTrainer(
    # ... (other trainer arguments elided) ...
    fast_dev_run=True,
)  # Set fast_dev_run to False to train the model.
trainer.fit(phase2fluor_wider_model, datamodule=phase2fluor_data)

# %% tags=["solution"]
##########################
######## Solution ########
# ... (model configuration elided) ...
phase2fluor_slow_model = VSUNet(
    # lower learning rate by 5 times
    lr=2e-4,
    schedule="WarmupCosine",
    log_num_samples=5,
    example_input_yx_shape=YX_PATCH_SIZE,
)
trainer = VSTrainer(
    # ... (trainer arguments elided) ...
)
trainer.fit(phase2fluor_slow_model, datamodule=phase2fluor_data)
# %% [markdown] tags=[]
"""

## Checkpoint 3
Congratulations! You have trained several image translation models now!
Please document hyperparameters, snapshots of predictions on the validation set, and loss curves for your models, and add the final performance in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z). We'll discuss our combined results as a group.
"""