Skip to content

Commit

Permalink
Gradio example (#158)
Browse files Browse the repository at this point in the history
* initial demo

* using the predict_step

* modifying paths to chkpt and example pngs

* updating gradio as the one on Huggingface
  • Loading branch information
edyoshikun authored Dec 2, 2024
1 parent 987874f commit e88a174
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions examples/gradio/demo_gradio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import gradio as gr
import torch
from viscy.light.engine import VSUNet
from huggingface_hub import hf_hub_download
from numpy.typing import ArrayLike
import numpy as np
from skimage import exposure


class VSGradio:
def __init__(self, model_config, model_ckpt_path):
self.model_config = model_config
self.model_ckpt_path = model_ckpt_path
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
self.load_model()

def load_model(self):
# Load the model checkpoint and move it to the correct device (GPU or CPU)
self.model = VSUNet.load_from_checkpoint(
self.model_ckpt_path,
architecture="UNeXt2_2D",
model_config=self.model_config,
)
self.model.to(self.device) # Move the model to the correct device (GPU/CPU)
self.model.eval()

def normalize_fov(self, input: ArrayLike):
"Normalizing the fov with zero mean and unit variance"
mean = np.mean(input)
std = np.std(input)
return (input - mean) / std

def predict(self, inp):
# Normalize the input and convert to tensor
inp = self.normalize_fov(inp)
inp = torch.from_numpy(np.array(inp).astype(np.float32))

# Prepare the input dictionary and move input to the correct device (GPU or CPU)
test_dict = dict(
index=None,
source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
)

# Run model inference
with torch.inference_mode():
self.model.on_predict_start() # Necessary preprocessing for the model
pred = (
self.model.predict_step(test_dict, 0, 0).cpu().numpy()
) # Move output back to CPU for post-processing

# Post-process the model output and rescale intensity
nuc_pred = pred[0, 0, 0]
mem_pred = pred[0, 1, 0]
nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1))
mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1))

return nuc_pred, mem_pred


# Load the custom CSS from the file
def load_css(file_path):
with open(file_path, "r") as file:
return file.read()


# %%
if __name__ == "__main__":
# Download the model checkpoint from Hugging Face
model_ckpt_path = hf_hub_download(
repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
)

# Model configuration
model_config = {
"in_channels": 1,
"out_channels": 2,
"encoder_blocks": [3, 3, 9, 3],
"dims": [96, 192, 384, 768],
"decoder_conv_blocks": 2,
"stem_kernel_size": [1, 2, 2],
"in_stack_depth": 1,
"pretraining": False,
}

# Initialize the Gradio app using Blocks
with gr.Blocks(css=load_css("style.css")) as demo:
# Title and description
gr.HTML(
"<div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>"
)
# Improved description block with better formatting
gr.HTML(
"""
<div class='description-block'>
<p><b>Model:</b> VSCyto2D</p>
<p>
<b>Input:</b> label-free image (e.g., QPI or phase contrast) <br>
<b>Output:</b> two virtually stained channels: one for the <b>nucleus</b> and one for the <b>cell membrane</b>.
</p>
<p>
Check out our preprint:
<a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al.,Robust virtual staining of landmark organelles</i></a>
</p>
</div>
"""
)

vsgradio = VSGradio(model_config, model_ckpt_path)

# Layout for input and output images
with gr.Row():
input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
with gr.Column():
output_nucleus = gr.Image(type="numpy", label="VS Nucleus")
output_membrane = gr.Image(type="numpy", label="VS Membrane")

# Button to trigger prediction
submit_button = gr.Button("Submit")

# Define what happens when the button is clicked
submit_button.click(
vsgradio.predict,
inputs=input_image,
outputs=[output_nucleus, output_membrane],
)

# Example images and article
gr.Examples(
examples=["examples/a549.png", "examples/hek.png"], inputs=input_image
)

# Article or footer information
gr.HTML(
"""
<div class='article-block'>
<p> Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI) or Zernike phase contrast.</p>
<p> For training, inference and evaluation of the model refer to the <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
</div>
"""
)

# Launch the Gradio app
demo.launch()

0 comments on commit e88a174

Please sign in to comment.