From e88a17413894650606f419eebef760c03e8ebb8c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 2 Dec 2024 15:29:58 -0800 Subject: [PATCH] Gradio example (#158) * initial demo * using the predict_step * modifying paths to chkpt and example pngs * updating gradio as the one on Huggingface --- examples/gradio/demo_gradio.py | 144 +++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 examples/gradio/demo_gradio.py diff --git a/examples/gradio/demo_gradio.py b/examples/gradio/demo_gradio.py new file mode 100644 index 00000000..bcee613c --- /dev/null +++ b/examples/gradio/demo_gradio.py @@ -0,0 +1,144 @@ +import gradio as gr +import torch +from viscy.light.engine import VSUNet +from huggingface_hub import hf_hub_download +from numpy.typing import ArrayLike +import numpy as np +from skimage import exposure + + +class VSGradio: + def __init__(self, model_config, model_ckpt_path): + self.model_config = model_config + self.model_ckpt_path = model_ckpt_path + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = None + self.load_model() + + def load_model(self): + # Load the model checkpoint and move it to the correct device (GPU or CPU) + self.model = VSUNet.load_from_checkpoint( + self.model_ckpt_path, + architecture="UNeXt2_2D", + model_config=self.model_config, + ) + self.model.to(self.device) # Move the model to the correct device (GPU/CPU) + self.model.eval() + + def normalize_fov(self, input: ArrayLike): + "Normalizing the fov with zero mean and unit variance" + mean = np.mean(input) + std = np.std(input) + return (input - mean) / std + + def predict(self, inp): + # Normalize the input and convert to tensor + inp = self.normalize_fov(inp) + inp = torch.from_numpy(np.array(inp).astype(np.float32)) + + # Prepare the input dictionary and move input to the correct device (GPU or CPU) + test_dict = dict( + index=None, + source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device), + ) + + # Run model inference + with torch.inference_mode(): + self.model.on_predict_start() # Necessary preprocessing for the model + pred = ( + self.model.predict_step(test_dict, 0, 0).cpu().numpy() + ) # Move output back to CPU for post-processing + + # Post-process the model output and rescale intensity + nuc_pred = pred[0, 0, 0] + mem_pred = pred[0, 1, 0] + nuc_pred = exposure.rescale_intensity(nuc_pred, out_range=(0, 1)) + mem_pred = exposure.rescale_intensity(mem_pred, out_range=(0, 1)) + + return nuc_pred, mem_pred + + +# Load the custom CSS from the file +def load_css(file_path): + with open(file_path, "r") as file: + return file.read() + + +# %% +if __name__ == "__main__": + # Download the model checkpoint from Hugging Face + model_ckpt_path = hf_hub_download( + repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt" + ) + + # Model configuration + model_config = { + "in_channels": 1, + "out_channels": 2, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [1, 2, 2], + "in_stack_depth": 1, + "pretraining": False, + } + + # Initialize the Gradio app using Blocks + with gr.Blocks(css=load_css("style.css")) as demo: + # Title and description + gr.HTML( + "
Image Translation (Virtual Staining) of cellular landmark organelles
" + ) + # Improved description block with better formatting + gr.HTML( + """ +
+

Model: VSCyto2D

+

+ Input: label-free image (e.g., QPI or phase contrast)
+ Output: two virtually stained channels: one for the nucleus and one for the cell membrane. +

+

+ Check out our preprint: + Liu et al.,Robust virtual staining of landmark organelles +

+
+ """ + ) + + vsgradio = VSGradio(model_config, model_ckpt_path) + + # Layout for input and output images + with gr.Row(): + input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image") + with gr.Column(): + output_nucleus = gr.Image(type="numpy", label="VS Nucleus") + output_membrane = gr.Image(type="numpy", label="VS Membrane") + + # Button to trigger prediction + submit_button = gr.Button("Submit") + + # Define what happens when the button is clicked + submit_button.click( + vsgradio.predict, + inputs=input_image, + outputs=[output_nucleus, output_membrane], + ) + + # Example images and article + gr.Examples( + examples=["examples/a549.png", "examples/hek.png"], inputs=input_image + ) + + # Article or footer information + gr.HTML( + """ +
+

Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI) or Zernike phase contrast.

+

For training, inference and evaluation of the model refer to the GitHub repository.

+
+ """ + ) + + # Launch the Gradio app + demo.launch()