Skip to content

Commit

Permalink
Demo for VSCyto2D and VSCyto3D (#94)
Browse files Browse the repository at this point in the history
* first pass on VSCyto2D demo

* bumping to cellpose 3 (#92)

* adding fix for the scale metadata to handle positions as well. (#93)

* adding the working example. pending missing comments.

* adding vscyto3d example

* adding readme

* minor typo on readme.md

* updating scripts to show the experimental fluorescence and adding the neuromast inference demo.

* adding nueromast demo

* touch README

* update paths for vscyto2d

* update paths for vscyto3d

* update paths for vsneuromast

* use new artifact name

* adding plotting function.

---------

Co-authored-by: Ziwen Liu <[email protected]>
  • Loading branch information
edyoshikun and ziw-liu authored Jun 24, 2024
1 parent 6a3457e commit a5b51c3
Show file tree
Hide file tree
Showing 5 changed files with 526 additions and 0 deletions.
29 changes: 29 additions & 0 deletions examples/demos/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# VisCy usage examples

Examples scripts showcasing the usage of VisCy.

## Virtual staining

### Training

- WIP: DL@MBL notebooks

### Inference

- [Inference with VSCyto2D](./demo_vscyto2d.py):
2D inference example on 20x A549 cell data. (Phase to nuclei and plasma membrane).
- [Inference with VSCyto3D](./demos/demo_vscyto3d.py):
3D inference example on 63x HEK293T cell data. (Phase to nuclei and plasma membrane).
- [Inference VSNeuromast](./demo_vsneuromast.py):
3D inference example of 63x zebrafish neuromast data (Phase to nuclei and plasma membrane)

## Notes

To run the examples, execute each individual script, for example:

```sh
python demo_vscyto2d.py
```

These scripts can also be ran interactively in many IDEs as notebooks,
for example in VS Code, PyCharm, and Spyder.
137 changes: 137 additions & 0 deletions examples/demos/demo_vscyto2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# %% [markdown]
"""
# 2D Virtual Staining of A549 Cells
---
## Prediction using the VSCyto2D to predict nuclei and plasma membrane from phase.
This example shows how to virtually stain A549 cells using the _VSCyto2D_ model.
The model is trained to predict the membrane and nuclei channels from the phase channel.
"""
# %% Imports and paths
from pathlib import Path

from iohub import open_ome_zarr
from plot import plot_vs_n_fluor

# Viscy classes for the trainer and model
from viscy.data.hcs import HCSDataModule
from viscy.light.engine import FcmaeUNet
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.light.trainer import VSTrainer
from viscy.transforms import NormalizeSampled

# %% [markdown]
"""
## Data and Model Paths
The dataset and model checkpoint files need to be downloaded before running this example.
"""

# %%
# Download from
# https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr
input_data_path = "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr"
# Download from GitHub release page of v0.1.0
model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt"
# Zarr store to save the predictions
output_path = "./a549_prediction.zarr"
# FOV of interest
fov = "0/0/0"

input_data_path = Path(input_data_path) / fov

# %%
# Create the VSCyto2D network

# NOTE: Change the following parameters as needed.
BATCH_SIZE = 10
YX_PATCH_SIZE = (384, 384)
NUM_WORKERS = 8
phase_channel_name = "Phase3D"

# %%[markdown]
"""
For this example we will use the following parameters:
For more information on the VSCyto2D model,
see ``viscy.unet.networks.fcmae``
([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/fcmae.py#L398))
for configuration details.
"""
# %%
# Setup the data module.
data_module = HCSDataModule(
data_path=input_data_path,
source_channel=phase_channel_name,
target_channel=["Membrane", "Nuclei"],
z_window_size=1,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
architecture="2D",
yx_patch_size=YX_PATCH_SIZE,
normalizations=[
NormalizeSampled(
[phase_channel_name],
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
],
)
data_module.prepare_data()
data_module.setup(stage="predict")
# %%
# Setup the model.
# Dictionary that specifies key parameters of the model.
config_VSCyto2D = {
"in_channels": 1,
"out_channels": 2,
"encoder_blocks": [3, 3, 9, 3],
"dims": [96, 192, 384, 768],
"decoder_conv_blocks": 2,
"stem_kernel_size": [1, 2, 2],
"in_stack_depth": 1,
"pretraining": False,
}

model_VSCyto2D = FcmaeUNet.load_from_checkpoint(
model_ckpt_path, model_config=config_VSCyto2D
)
model_VSCyto2D.eval()

# %%
# Setup the Trainer
trainer = VSTrainer(
accelerator="gpu",
callbacks=[HCSPredictionWriter(output_path)],
)

# Start the predictions
trainer.predict(
model=model_VSCyto2D,
datamodule=data_module,
return_predictions=False,
)

# %%
# Open the output_zarr store and inspect the output
# Show the individual channels and the fused in a 1x3 plot
output_path = Path(output_path) / fov

# %%
# Open the predicted data
vs_store = open_ome_zarr(output_path, mode="r")
# Get the 2D images
vs_nucleus = vs_store[0][0, 0, 0] # (t,c,z,y,x)
vs_membrane = vs_store[0][0, 1, 0] # (t,c,z,y,x)
# Open the experimental fluorescence
fluor_store = open_ome_zarr(input_data_path, mode="r")
# Get the 2D images
# NOTE: Channel indeces hardcoded for this dataset
fluor_nucleus = fluor_store[0][0, 1, 0] # (t,c,z,y,x)
fluor_membrane = fluor_store[0][0, 2, 0] # (t,c,z,y,x)

# Plot
plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane)

vs_store.close()
fluor_store.close()
142 changes: 142 additions & 0 deletions examples/demos/demo_vscyto3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# %% [markdown]
"""
# 3D Virtual Staining of HEK293T Cells
---
## Prediction using the VSCyto3D to predict nuclei and membrane from phase.
This example shows how to virtually stain HEK293T cells using the _VSCyto3D_ model.
The model is trained to predict the membrane and nuclei channels from the phase channel.
"""
# %% Imports and paths
from pathlib import Path

from iohub import open_ome_zarr
from plot import plot_vs_n_fluor

from viscy.data.hcs import HCSDataModule

# Viscy classes for the trainer and model
from viscy.light.engine import VSUNet
from viscy.light.predict_writer import HCSPredictionWriter
from viscy.light.trainer import VSTrainer
from viscy.transforms import NormalizeSampled

# %% [markdown]
"""
## Data and Model Paths
The dataset and model checkpoint files need to be downloaded before running this example.
"""

# %%
# Download from
# https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto3D/hek_h2b_caax_63x.zarr
input_data_path = "datasets/testing/VSCyto3D/hek_h2b_caax_63x.zarr"
# Download from GitHub release page of v0.1.0
model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto3D/epoch=48-step=18130.ckpt"
# Zarr store to save the predictions
output_path = "./hek_prediction_3d.zarr"
# FOV of interest
fov = "plate/0/0"

input_data_path = Path(input_data_path) / fov

# %%
# Create the VSCyto3D model

# NOTE: Change the following parameters as needed.
BATCH_SIZE = 2
YX_PATCH_SIZE = (384, 384)
NUM_WORKERS = 8
phase_channel_name = "Phase3D"

# %%[markdown]
"""
For this example we will use the following parameters:
### For more information on the VSCyto3D model:
See ``viscy.unet.networks.fcmae``
([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/unext2.py#L252))
for configuration details.
"""
# %%
# Setup the data module.
data_module = HCSDataModule(
data_path=input_data_path,
source_channel=phase_channel_name,
target_channel=["Membrane", "Nuclei"],
z_window_size=5,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
architecture="UNeXt2",
yx_patch_size=YX_PATCH_SIZE,
normalizations=[
NormalizeSampled(
[phase_channel_name],
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
],
)
data_module.prepare_data()
data_module.setup(stage="predict")
# %%
# Setup the model.
# Dictionary that specifies key parameters of the model.
config_VSCyto3D = {
"in_channels": 1,
"out_channels": 2,
"in_stack_depth": 5,
"backbone": "convnextv2_tiny",
"stem_kernel_size": (5, 4, 4),
"decoder_mode": "pixelshuffle",
"head_expansion_ratio": 4,
"head_pool": True,
}

model_VSCyto3D = VSUNet.load_from_checkpoint(
model_ckpt_path, architecture="UNeXt2", model_config=config_VSCyto3D
)
model_VSCyto3D.eval()

# %%
# Setup the Trainer
trainer = VSTrainer(
accelerator="gpu",
callbacks=[HCSPredictionWriter(output_path)],
)

# Start the predictions
trainer.predict(
model=model_VSCyto3D,
datamodule=data_module,
return_predictions=False,
)

# %%
# Open the output_zarr store and inspect the output
# Show the individual channels and the fused in a 1x3 plot
output_path = Path(output_path) / fov

# %%
# Open the predicted data
vs_store = open_ome_zarr(output_path, mode="r")
T, C, Z, Y, X = vs_store.data.shape

# Get a z-slice
z_slice = Z // 2 # NOTE: using the middle slice of the stack. Change as needed.
vs_nucleus = vs_store[0][0, 0, z_slice] # (t,c,z,y,x)
vs_membrane = vs_store[0][0, 1, z_slice] # (t,c,z,y,x)
# Open the experimental fluorescence
fluor_store = open_ome_zarr(input_data_path, mode="r")
# Get the 2D images
# NOTE: Channel indeces hardcoded for this dataset
fluor_nucleus = fluor_store[0][0, 2, z_slice] # (t,c,z,y,x)
fluor_membrane = fluor_store[0][0, 1, z_slice] # (t,c,z,y,x)

# Plot
plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane)

# Close stores
vs_store.close()
fluor_store.close()
Loading

0 comments on commit a5b51c3

Please sign in to comment.