Merge pull request #27 from mehta-lab/3d-augmentation
3D augmentation and 2.1D U-Net
mattersoflight authored Aug 12, 2023
2 parents d6f51fe + 4e2c8b6 commit f9b4e16
Showing 11 changed files with 599 additions and 200 deletions.
73 changes: 35 additions & 38 deletions README.md
@@ -2,62 +2,59 @@

viscy is a deep learning pipeline for training and deploying computer vision models for image-based phenotyping at single cell resolution.

The current focus of the pipeline is on image translation models for virtual staining of multiple cellular compartments from label-free images.
We are building these models for simultaneous segmentation of nuclei and membrane, which are the first steps in a single-cell phenotyping pipeline.
Our pipeline also provides utilities to export the models to ONNX format for use at runtime.
We will grow the collection of models suitable for high-throughput imaging and phenotyping.
Expect rough edges until we release a PyPI package.


![virtual_staining](docs/figures/phase_to_nuclei_membrane.svg)

This pipeline evolved from the [TensorFlow version of virtual staining pipeline](https://github.com/mehta-lab/microDL), which we reported in [this paper in 2020](https://elifesciences.org/articles/55502). The previous pipeline is now a public archive, and we will be focusing our efforts on viscy.

## Installing viscy

1. We highly encourage using a new Conda/virtual environment.
([Mamba](https://github.com/mamba-org/mamba) is a faster re-implementation of Conda.)

```sh
mamba create --name viscy python=3.10
# OR
mamba create --prefix /path/to/conda/envs/viscy python=3.10
```
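
Then activate the environment before installing (shown for the named environment created above; if `mamba activate` is not configured in your shell, `conda activate viscy` works the same way):

```sh
mamba activate viscy
```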

2. Clone this repository and install with pip:

```sh
git clone https://github.com/mehta-lab/viscy.git
# change to project root directory (parent folder of pyproject.toml)
cd viscy
pip install .
```

If evaluating virtually stained images for segmentation tasks,
additional dependencies need to be installed:

```sh
pip install ".[metrics]"
```

3. Verify installation by accessing the CLI help message:

```sh
viscy --help
```

For development installation, see [the contributing guide](CONTRIBUTING.md).

The pipeline is built using the [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) framework and [iohub](https://github.com/czbiohub-sf/iohub) library for reading and writing data in [OME-Zarr](https://www.nature.com/articles/s41592-021-01326-w) format.
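
For example, a minimal sketch of reading an OME-Zarr HCS store with iohub (the path is a placeholder, and the layout is assumed to match the datasets used with this pipeline):

```python
from iohub import open_ome_zarr

# Open an existing OME-Zarr plate read-only; "plate.zarr" is a placeholder path.
dataset = open_ome_zarr("plate.zarr", mode="r")

# Iterate over fields of view in the HCS plate and print the shape of the
# full-resolution level ("0"), laid out as (T, C, Z, Y, X).
for name, position in dataset.positions():
    print(name, position["0"].shape)
```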

The full functionality is tested only on Linux `x86_64` with NVIDIA Ampere GPUs (CUDA 12.0).
Some features (e.g. mixed precision and distributed training) may not work with other setups,
see [PyTorch documentation](https://pytorch.org) for details.
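
As a rough sketch (untested here, and assuming `VSTrainer` forwards the standard Lightning `Trainer` arguments), mixed precision and multi-GPU training would be configured like this:

```python
from viscy.light.engine import VSTrainer

# Hypothetical settings: 16-bit mixed precision on two GPUs with DDP.
trainer = VSTrainer(
    accelerator="gpu",
    devices=2,
    precision="16-mixed",
    strategy="ddp",
)
```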

The following dependencies allow use and development of the pipeline while the PyPI package is pending:

```yaml
iohub==0.1.0.dev3
torch>=2.0.0
torchvision>=0.15.1
tensorboard>=2.13.0
lightning>=2.0.1
monai>=1.2.0
jsonargparse[signatures]>=4.20.1
scikit-image>=0.19.2
matplotlib
cellpose==2.1.0
lapsolver==1.1.0
scikit-learn>=1.1.3
scipy>=1.8.0
torchmetrics[detection]>=1.0.0
pytest
pytest-cov
hypothesis
profilehooks
onnxruntime
```
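
For instance, one way to set these up ahead of the PyPI release is to copy the pins above into a `requirements.txt` file (a hypothetical file, not shipped with the repository) and install it with pip:

```sh
pip install -r requirements.txt
```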

## Virtual staining of cellular compartments from label-free images

Predicting sub-cellular landmarks such as nuclei and membrane from label-free (e.g. phase) images
285 changes: 285 additions & 0 deletions examples/demo_dlmbl/python/excercise_1.py
@@ -0,0 +1,285 @@
# %% [markdown]
"""
# Image translation exercise part 1
In this exercise, we will solve an image translation task of
reconstructing nuclei and membrane markers from phase images of cells.
Here, the source domain is label-free microscopy (average material density),
and the target domain is fluorescence microscopy (fluorophore density).
Learning goals of part 1:
- Load and visualize the images from OME-Zarr
- Configure the data loaders
- Initialize a 2D U-Net model for virtual staining
<div class="alert alert-danger">
Set your python kernel to <code>004-image-translation</code>
</div>
"""

# %%
import os

import matplotlib.pyplot as plt
import torch
from iohub import open_ome_zarr
from tensorboard import notebook
from torchview import draw_graph


from viscy.light.data import HCSDataModule
from viscy.light.engine import VSTrainer, VSUNet

BATCH_SIZE = 32
GPU_ID = 0

# %% [markdown]
"""
Load Dataset.
<div class="alert alert-info">
Task 1.1
Use <a href=https://czbiohub-sf.github.io/iohub/main/api/ngff.html#open-ome-zarr>
<code>iohub.open_ome_zarr</code></a> to read the dataset.
There should be 301 FOVs in the dataset (9.3 GB compressed).
Each FOV consists of 3 channels of 2048x2048 images,
saved in the <a href="https://ngff.openmicroscopy.org/latest/#hcs-layout">
High-Content Screening (HCS) layout</a>
specified by the Open Microscopy Environment Next Generation File Format
(OME-NGFF).
Run <code>open_ome_zarr?</code> in a cell to see the docstring.
"""

# %%
# set dataset path here
data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/dlmbl/HEK_nuclei_membrane_pyramid.zarr"

dataset = open_ome_zarr(data_path)

print(len(list(dataset.positions())))


# %% [markdown]
"""
View images with matplotlib.
The layout on disk is: row/col/field/resolution/timepoint/channel/z/y/x.
Note that the labelling is not perfect,
as some cells do not express the fluorophore.
"""

# %%

row = "0"
col = "0"
field = "0"
# '0' is the highest resolution
# '1' is 2x2 down-scaled, '2' is 4x4 down-scaled, etc.
resolution = "0"
image = dataset[f"{row}/{col}/{field}/{resolution}"].numpy()
print(image.shape)

figure, axes = plt.subplots(1, 3, figsize=(9, 3))

for ax, channel in zip(axes, image[0, :, 0]):
    ax.imshow(channel, cmap="gray")
    ax.axis("off")

plt.tight_layout()

# %% [markdown]
"""
Configure the data loaders for training and validation.
"""

# %%
data_module = HCSDataModule(
    data_path,
    source_channel="Phase",
    target_channel=["Nuclei", "Membrane"],
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=8,
    architecture="2D",
    yx_patch_size=(256, 256),
)

data_module.setup("fit")

print(len(data_module.train_dataset), len(data_module.val_dataset))

# %% [markdown]
"""
<div class="alert alert-info">
Task 1.2
Validate that the data can be loaded in batches correctly.
</div>
"""

# %%
train_dataloader = data_module.train_dataloader()

for i, batch in enumerate(train_dataloader):
    ...
    # plot one image from each batch, then break
    break

# %% tags=["solution"]
train_dataloader = data_module.train_dataloader()


fig, axs = plt.subplots(3, 8, figsize=(20, 6))

# Draw 8 batches, each with 32 images. Show the first image in each batch.

for i, batch in enumerate(train_dataloader):
    # The batch is a dictionary with three keys: 'index', 'source', and 'target'.
    # 'index' is a tuple of (image name, time, z-slice)
    # 'source' is a tensor of size 1x1x256x256
    # 'target' is a tensor of size 2x1x256x256

    if i >= 8:
        break
    FOV = batch["index"][0][0]
    input_tensor = batch["source"][0, 0, :, :].squeeze()
    target_nuclei_tensor = batch["target"][0, 0, :, :].squeeze()
    target_membrane_tensor = batch["target"][0, 1, :, :].squeeze()

    axs[0, i].imshow(input_tensor, cmap="gray")
    axs[1, i].imshow(target_nuclei_tensor, cmap="gray")
    axs[2, i].imshow(target_membrane_tensor, cmap="gray")
    axs[0, i].set_title(f"input@{FOV}")
    axs[1, i].set_title("target-nuclei")
    axs[2, i].set_title("target-membrane")
    axs[0, i].axis("off")
    axs[1, i].axis("off")
    axs[2, i].axis("off")

plt.tight_layout()
plt.show()


# %% [markdown]
"""
Construct a 2D U-Net for image translation.
See ``viscy.unet.networks.Unet2D.Unet2d`` for configuration details.
Increase the ``depth`` in ``draw_graph`` to zoom in.
"""

# %%
model_config = {
    "architecture": "2D",
    "in_channels": 1,
    "out_channels": 2,
    "residual": True,
    "dropout": 0.1,
    "task": "reg",
}

model = VSUNet(
    model_config=model_config.copy(),
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.mse_loss,
    schedule="WarmupCosine",
    log_num_samples=10,
)

# visualize graph
model_graph = draw_graph(model, model.example_input_array, depth=2, device="cpu")
graph = model_graph.visual_graph
graph

# %% [markdown]
"""
Configure trainer class.
Here we use the ``fast_dev_run`` flag to run a sanity check first.
"""

# %%
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)

trainer.fit(model, datamodule=data_module)

# %% [markdown]
"""
<div class="alert alert-info">
Task 1.3
Modify the trainer to train the model for 20 epochs.
</div>
"""

# %% [markdown]
"""
Tips:
- See ``VSTrainer?`` for all the available parameters.
- Set ``default_root_dir`` to store the logs and checkpoints
in a specific directory.
"""

# %% [markdown]
"""
Bonus:
- Tweak model hyperparameters
- Adjust batch size to fully utilize the VRAM
"""

# %% tags=["solution"]
wider_config = model_config | {"num_filters": [24, 48, 96, 192, 384]}

model = VSUNet(
    model_config=wider_config.copy(),
    batch_size=BATCH_SIZE,
    loss_function=torch.nn.functional.mse_loss,
    schedule="WarmupCosine",
    log_num_samples=10,
)


trainer = VSTrainer(
    accelerator="gpu",
    max_epochs=20,
    log_every_n_steps=8,
    default_root_dir=os.path.expanduser("~"),
)

trainer.fit(model, datamodule=data_module)

# %% [markdown]
"""
Launch TensorBoard with:
```
%load_ext tensorboard
%tensorboard --logdir /path/to/lightning_logs
```
"""

# %%
notebook.list()

# %%
notebook.display(port=6006, height=800)

# %% [markdown]
"""
<div class="alert alert-success">
Checkpoint 1
Now that the training has started,
we can come back after a while and evaluate the performance!
</div>
"""

# %%
