-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Demo for VSCyto2D and VSCyto3D (#94)
* first pass on VSCyto2D demo * bumping to cellpose 3 (#92) * adding fix for the scale metadata to handle positions as well. (#93) * adding the working example. pending missing comments. * adding vscyto3d example * adding readme * minor typo on readme.md * updating scripts to show the experimental fluorescence and adding the neuromast inference demo. * adding nueromast demo * touch README * update paths for vscyto2d * update paths for vscyto3d * update paths for vsneuromast * use new artifact name * adding plotting function. --------- Co-authored-by: Ziwen Liu <[email protected]>
- Loading branch information
1 parent
6a3457e
commit a5b51c3
Showing
5 changed files
with
526 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# VisCy usage examples | ||
|
||
Examples scripts showcasing the usage of VisCy. | ||
|
||
## Virtual staining | ||
|
||
### Training | ||
|
||
- WIP: DL@MBL notebooks | ||
|
||
### Inference | ||
|
||
- [Inference with VSCyto2D](./demo_vscyto2d.py): | ||
2D inference example on 20x A549 cell data. (Phase to nuclei and plasma membrane). | ||
- [Inference with VSCyto3D](./demos/demo_vscyto3d.py): | ||
3D inference example on 63x HEK293T cell data. (Phase to nuclei and plasma membrane). | ||
- [Inference VSNeuromast](./demo_vsneuromast.py): | ||
3D inference example of 63x zebrafish neuromast data (Phase to nuclei and plasma membrane) | ||
|
||
## Notes | ||
|
||
To run the examples, execute each individual script, for example: | ||
|
||
```sh | ||
python demo_vscyto2d.py | ||
``` | ||
|
||
These scripts can also be ran interactively in many IDEs as notebooks, | ||
for example in VS Code, PyCharm, and Spyder. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# %% [markdown] | ||
""" | ||
# 2D Virtual Staining of A549 Cells | ||
--- | ||
## Prediction using the VSCyto2D to predict nuclei and plasma membrane from phase. | ||
This example shows how to virtually stain A549 cells using the _VSCyto2D_ model. | ||
The model is trained to predict the membrane and nuclei channels from the phase channel. | ||
""" | ||
# %% Imports and paths | ||
from pathlib import Path | ||
|
||
from iohub import open_ome_zarr | ||
from plot import plot_vs_n_fluor | ||
|
||
# Viscy classes for the trainer and model | ||
from viscy.data.hcs import HCSDataModule | ||
from viscy.light.engine import FcmaeUNet | ||
from viscy.light.predict_writer import HCSPredictionWriter | ||
from viscy.light.trainer import VSTrainer | ||
from viscy.transforms import NormalizeSampled | ||
|
||
# %% [markdown] | ||
""" | ||
## Data and Model Paths | ||
The dataset and model checkpoint files need to be downloaded before running this example. | ||
""" | ||
|
||
# %% | ||
# Download from | ||
# https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr | ||
input_data_path = "datasets/testing/VSCyto2D/a549_hoechst_cellmask_test.zarr" | ||
# Download from GitHub release page of v0.1.0 | ||
model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto2D/epoch=399-step=23200.ckpt" | ||
# Zarr store to save the predictions | ||
output_path = "./a549_prediction.zarr" | ||
# FOV of interest | ||
fov = "0/0/0" | ||
|
||
input_data_path = Path(input_data_path) / fov | ||
|
||
# %% | ||
# Create the VSCyto2D network | ||
|
||
# NOTE: Change the following parameters as needed. | ||
BATCH_SIZE = 10 | ||
YX_PATCH_SIZE = (384, 384) | ||
NUM_WORKERS = 8 | ||
phase_channel_name = "Phase3D" | ||
|
||
# %%[markdown] | ||
""" | ||
For this example we will use the following parameters: | ||
For more information on the VSCyto2D model, | ||
see ``viscy.unet.networks.fcmae`` | ||
([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/fcmae.py#L398)) | ||
for configuration details. | ||
""" | ||
# %% | ||
# Setup the data module. | ||
data_module = HCSDataModule( | ||
data_path=input_data_path, | ||
source_channel=phase_channel_name, | ||
target_channel=["Membrane", "Nuclei"], | ||
z_window_size=1, | ||
split_ratio=0.8, | ||
batch_size=BATCH_SIZE, | ||
num_workers=NUM_WORKERS, | ||
architecture="2D", | ||
yx_patch_size=YX_PATCH_SIZE, | ||
normalizations=[ | ||
NormalizeSampled( | ||
[phase_channel_name], | ||
level="fov_statistics", | ||
subtrahend="median", | ||
divisor="iqr", | ||
) | ||
], | ||
) | ||
data_module.prepare_data() | ||
data_module.setup(stage="predict") | ||
# %% | ||
# Setup the model. | ||
# Dictionary that specifies key parameters of the model. | ||
config_VSCyto2D = { | ||
"in_channels": 1, | ||
"out_channels": 2, | ||
"encoder_blocks": [3, 3, 9, 3], | ||
"dims": [96, 192, 384, 768], | ||
"decoder_conv_blocks": 2, | ||
"stem_kernel_size": [1, 2, 2], | ||
"in_stack_depth": 1, | ||
"pretraining": False, | ||
} | ||
|
||
model_VSCyto2D = FcmaeUNet.load_from_checkpoint( | ||
model_ckpt_path, model_config=config_VSCyto2D | ||
) | ||
model_VSCyto2D.eval() | ||
|
||
# %% | ||
# Setup the Trainer | ||
trainer = VSTrainer( | ||
accelerator="gpu", | ||
callbacks=[HCSPredictionWriter(output_path)], | ||
) | ||
|
||
# Start the predictions | ||
trainer.predict( | ||
model=model_VSCyto2D, | ||
datamodule=data_module, | ||
return_predictions=False, | ||
) | ||
|
||
# %% | ||
# Open the output_zarr store and inspect the output | ||
# Show the individual channels and the fused in a 1x3 plot | ||
output_path = Path(output_path) / fov | ||
|
||
# %% | ||
# Open the predicted data | ||
vs_store = open_ome_zarr(output_path, mode="r") | ||
# Get the 2D images | ||
vs_nucleus = vs_store[0][0, 0, 0] # (t,c,z,y,x) | ||
vs_membrane = vs_store[0][0, 1, 0] # (t,c,z,y,x) | ||
# Open the experimental fluorescence | ||
fluor_store = open_ome_zarr(input_data_path, mode="r") | ||
# Get the 2D images | ||
# NOTE: Channel indeces hardcoded for this dataset | ||
fluor_nucleus = fluor_store[0][0, 1, 0] # (t,c,z,y,x) | ||
fluor_membrane = fluor_store[0][0, 2, 0] # (t,c,z,y,x) | ||
|
||
# Plot | ||
plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane) | ||
|
||
vs_store.close() | ||
fluor_store.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# %% [markdown] | ||
""" | ||
# 3D Virtual Staining of HEK293T Cells | ||
--- | ||
## Prediction using the VSCyto3D to predict nuclei and membrane from phase. | ||
This example shows how to virtually stain HEK293T cells using the _VSCyto3D_ model. | ||
The model is trained to predict the membrane and nuclei channels from the phase channel. | ||
""" | ||
# %% Imports and paths | ||
from pathlib import Path | ||
|
||
from iohub import open_ome_zarr | ||
from plot import plot_vs_n_fluor | ||
|
||
from viscy.data.hcs import HCSDataModule | ||
|
||
# Viscy classes for the trainer and model | ||
from viscy.light.engine import VSUNet | ||
from viscy.light.predict_writer import HCSPredictionWriter | ||
from viscy.light.trainer import VSTrainer | ||
from viscy.transforms import NormalizeSampled | ||
|
||
# %% [markdown] | ||
""" | ||
## Data and Model Paths | ||
The dataset and model checkpoint files need to be downloaded before running this example. | ||
""" | ||
|
||
# %% | ||
# Download from | ||
# https://public.czbiohub.org/comp.micro/viscy/datasets/testing/VSCyto3D/hek_h2b_caax_63x.zarr | ||
input_data_path = "datasets/testing/VSCyto3D/hek_h2b_caax_63x.zarr" | ||
# Download from GitHub release page of v0.1.0 | ||
model_ckpt_path = "VisCy-0.1.0-VS-models/VSCyto3D/epoch=48-step=18130.ckpt" | ||
# Zarr store to save the predictions | ||
output_path = "./hek_prediction_3d.zarr" | ||
# FOV of interest | ||
fov = "plate/0/0" | ||
|
||
input_data_path = Path(input_data_path) / fov | ||
|
||
# %% | ||
# Create the VSCyto3D model | ||
|
||
# NOTE: Change the following parameters as needed. | ||
BATCH_SIZE = 2 | ||
YX_PATCH_SIZE = (384, 384) | ||
NUM_WORKERS = 8 | ||
phase_channel_name = "Phase3D" | ||
|
||
# %%[markdown] | ||
""" | ||
For this example we will use the following parameters: | ||
### For more information on the VSCyto3D model: | ||
See ``viscy.unet.networks.fcmae`` | ||
([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/unext2.py#L252)) | ||
for configuration details. | ||
""" | ||
# %% | ||
# Setup the data module. | ||
data_module = HCSDataModule( | ||
data_path=input_data_path, | ||
source_channel=phase_channel_name, | ||
target_channel=["Membrane", "Nuclei"], | ||
z_window_size=5, | ||
split_ratio=0.8, | ||
batch_size=BATCH_SIZE, | ||
num_workers=NUM_WORKERS, | ||
architecture="UNeXt2", | ||
yx_patch_size=YX_PATCH_SIZE, | ||
normalizations=[ | ||
NormalizeSampled( | ||
[phase_channel_name], | ||
level="fov_statistics", | ||
subtrahend="median", | ||
divisor="iqr", | ||
) | ||
], | ||
) | ||
data_module.prepare_data() | ||
data_module.setup(stage="predict") | ||
# %% | ||
# Setup the model. | ||
# Dictionary that specifies key parameters of the model. | ||
config_VSCyto3D = { | ||
"in_channels": 1, | ||
"out_channels": 2, | ||
"in_stack_depth": 5, | ||
"backbone": "convnextv2_tiny", | ||
"stem_kernel_size": (5, 4, 4), | ||
"decoder_mode": "pixelshuffle", | ||
"head_expansion_ratio": 4, | ||
"head_pool": True, | ||
} | ||
|
||
model_VSCyto3D = VSUNet.load_from_checkpoint( | ||
model_ckpt_path, architecture="UNeXt2", model_config=config_VSCyto3D | ||
) | ||
model_VSCyto3D.eval() | ||
|
||
# %% | ||
# Setup the Trainer | ||
trainer = VSTrainer( | ||
accelerator="gpu", | ||
callbacks=[HCSPredictionWriter(output_path)], | ||
) | ||
|
||
# Start the predictions | ||
trainer.predict( | ||
model=model_VSCyto3D, | ||
datamodule=data_module, | ||
return_predictions=False, | ||
) | ||
|
||
# %% | ||
# Open the output_zarr store and inspect the output | ||
# Show the individual channels and the fused in a 1x3 plot | ||
output_path = Path(output_path) / fov | ||
|
||
# %% | ||
# Open the predicted data | ||
vs_store = open_ome_zarr(output_path, mode="r") | ||
T, C, Z, Y, X = vs_store.data.shape | ||
|
||
# Get a z-slice | ||
z_slice = Z // 2 # NOTE: using the middle slice of the stack. Change as needed. | ||
vs_nucleus = vs_store[0][0, 0, z_slice] # (t,c,z,y,x) | ||
vs_membrane = vs_store[0][0, 1, z_slice] # (t,c,z,y,x) | ||
# Open the experimental fluorescence | ||
fluor_store = open_ome_zarr(input_data_path, mode="r") | ||
# Get the 2D images | ||
# NOTE: Channel indeces hardcoded for this dataset | ||
fluor_nucleus = fluor_store[0][0, 2, z_slice] # (t,c,z,y,x) | ||
fluor_membrane = fluor_store[0][0, 1, z_slice] # (t,c,z,y,x) | ||
|
||
# Plot | ||
plot_vs_n_fluor(vs_nucleus, vs_membrane, fluor_nucleus, fluor_membrane) | ||
|
||
# Close stores | ||
vs_store.close() | ||
fluor_store.close() |
Oops, something went wrong.