Skip to content

Commit

Permalink
Added interactive viewer widget
Browse files Browse the repository at this point in the history
  • Loading branch information
JLrumberger committed Nov 29, 2024
1 parent 77acadc commit 99c75e4
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 25 deletions.
49 changes: 48 additions & 1 deletion src/nimbus_inference/example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import Union
import datasets
from alpineer.misc_utils import verify_in_list
import zipfile
import os
import requests

EXAMPLE_DATASET_REVISION: str = "main"

Expand Down Expand Up @@ -214,4 +217,48 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
example_dataset.download_example_dataset()

# Move the dataset over to the save_dir from the user.
example_dataset.move_example_dataset(move_dir=save_dir)
example_dataset.move_example_dataset(move_dir=save_dir)


def download_and_unpack_gold_standard(save_dir: Union[str, pathlib.Path], overwrite_existing: bool = True):
"""
Downloads 'gold_standard_labelled.zip' from the Hugging Face dataset and unpacks it in the given folder
if the dataset is not already present there.
Args:
save_dir (Union[str, Path]): The path to save the dataset files in.
overwrite_existing (bool): The option to overwrite existing files. Defaults to True.
"""
url = "https://huggingface.co/datasets/JLrumberger/Pan-Multiplex-Gold-Standard/resolve/main/gold_standard_labelled.zip"
save_dir = pathlib.Path(save_dir)
zip_path = save_dir / "gold_standard_labelled.zip"

# Create the save directory if it doesn't exist
save_dir.mkdir(parents=True, exist_ok=True)

# Check if the dataset is already present
if zip_path.exists() and not overwrite_existing:
print(f"{zip_path} already exists. Skipping download.")
return

# Download the zip file
print(f"Downloading {url} to {zip_path}...")
response = requests.get(url, stream=True)
response.raise_for_status()

with open(zip_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

print(f"Downloaded {zip_path}")

# Unpack the zip file
print(f"Unpacking {zip_path}...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(save_dir)

print(f"Unpacked to {save_dir}")

# Optionally, remove the zip file after unpacking
os.remove(zip_path)
print(f"Removed {zip_path}")
15 changes: 12 additions & 3 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ def segmentation_naming_convention(fov_path):
Returns:
str: paths to segmentation fovs
"""
fov_name = os.path.basename(fov_path).replace(".ome.tiff", "")
return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")

fov_name = os.path.basename(fov_path)
# remove suffix
fov_name = Path(fov_name).stem
# find all fnames which contain a superset of the fov_name
fnames = os.listdir(deepcell_output_dir)
# use re instead of glob
fnames = [os.path.join(deepcell_output_dir, f) for f in fnames if fov_name in f]
if len(fnames) == 0:
raise ValueError(f"No segmentation data found for fov {fov_name}")
if len(fnames) > 1:
raise ValueError(f"Multiple segmentation data found for fov {fov_name}")
return fnames[0]
return segmentation_naming_convention


Expand Down
57 changes: 56 additions & 1 deletion src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,4 +717,59 @@ def __getitem__(self, idx):
input_data = sample[:2]
groundtruth = sample[2:3]
inst_mask = sample[3:]
return input_data, groundtruth, inst_mask, self.keys[idx]
return input_data, groundtruth, inst_mask, self.keys[idx]


class InteractiveDataset(object):
"""Dataset for the InteractiveViewer class. This dataset class stores multiple objects of type
MultiplexedDataset, and allows to select a dataset and use its method for reading fovs and
channels from it.
Args:
datasets (dict): dictionary with dataset names as keys and dataset objects as values
"""
def __init__(self, datasets: dict):
self.datasets = datasets
self.dataset_names = list(datasets.keys())
self.dataset = None

def set_dataset(self, dataset_name: str):
"""Set the active dataset
Args:
dataset_name (str): name of the dataset
"""
self.dataset = self.datasets[dataset_name]
return self.dataset

def get_channel(self, fov: str, channel: str):
"""Get a channel from a fov
Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: channel image
"""
return self.dataset.get_channel(fov, channel)

def get_segmentation(self, fov: str):
"""Get the instance mask for a fov
Args:
fov (str): name of a fov
Returns:
np.array: instance mask
"""
return self.dataset.get_segmentation(fov)

def get_groundtruth(self, fov: str, channel: str):
"""Get the groundtruth for a fov / channel combination
Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: groundtruth activity mask (0: negative, 1: positive, 2: ambiguous)
"""
return self.dataset.get_groundtruth(fov, channel)
77 changes: 57 additions & 20 deletions src/nimbus_inference/viewer_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from natsort import natsorted
from skimage.segmentation import find_boundaries
from skimage.transform import rescale
from nimbus_inference.utils import MultiplexDataset
from nimbus_inference.utils import MultiplexDataset, InteractiveDataset
from mpl_interactions import panhandler
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -289,7 +289,7 @@ class InteractiveImageDuo(widgets.Image):
title_left (str): Title of left image.
title_right (str): Title of right image.
"""
def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Prediction'):
def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Groundtruth'):
super().__init__()
self.title_left = title_left
self.title_right = title_right
Expand Down Expand Up @@ -359,30 +359,37 @@ def update_right_image(self, image):
Args:
image (np.array): Image to display.
"""
self.ax[1].imshow(image)
self.ax[1].imshow(image, vmin=0, vmax=255)
self.ax[1].title.set_text(self.title_right)
self.ax[1].set_xticks([])
self.ax[1].set_yticks([])
self.fig.canvas.draw_idle()


class NimbusInteractiveViewer(NimbusViewer):
"""Interactive viewer for Nimbus application.
class NimbusInteractiveGTViewer(NimbusViewer):
"""Interactive viewer for Nimbus application that shows input data and ground truth
side by side.
Args:
dataset (MultiplexDataset): dataset object
output_dir (str): Path to directory containing output of Nimbus application.
segmentation_naming_convention (fn): Function that maps input path to segmentation path
img_width (str): Width of images in viewer.
suffix (str): Suffix of images in dataset.
max_resolution (tuple): Maximum resolution of images in viewer.
figsize (tuple): Size of figure.
"""
def __init__(
self, dataset: MultiplexDataset, output_dir: str, img_width='600px', suffix=".tiff",
max_resolution=(2048, 2048)
self, datasets: InteractiveDataset, output_dir, figsize=(20, 10)
):
super().__init__(dataset, output_dir, img_width, suffix, max_resolution)
self.image = InteractiveImageDuo()
super().__init__(
datasets.datasets[datasets.dataset_names[0]], output_dir
)
self.image = InteractiveImageDuo(figsize=figsize)
self.dataset = datasets.datasets[datasets.dataset_names[0]]
self.datasets = datasets
self.dataset_select = widgets.Select(
options=datasets.dataset_names,
description='Dataset:',
disabled=False
)
self.dataset_select.observe(self.select_dataset, names='value')

def layout(self):
"""Creates layout for viewer."""
Expand All @@ -392,15 +399,28 @@ def layout(self):
self.blue_select
])
layout = widgets.HBox([
widgets.HBox([
# widgets.HBox([
self.dataset_select,
self.fov_select,
channel_selectors,
self.overlay_checkbox,
self.update_button
]),
# ]),
])
display(layout)

def select_dataset(self, change):
"""Selects dataset to display.
Args:
change (dict): Change dictionary from ipywidgets.
"""
self.dataset = self.datasets.set_dataset(change['new'])
self.fov_names = natsorted(copy(self.dataset.fovs))
self.fov_select.options = self.fov_names
self.select_fov(None)


def update_img(self, image_fn, composite_image):
"""Updates image in viewer by saving it as png and loading it with the viewer widget.
Expand Down Expand Up @@ -444,10 +464,6 @@ def update_composite(self):
non_none = [p for p in path_dict.values() if p]
if not non_none:
return
composite_image = self.create_composite_image(path_dict)
composite_image, _ = self.overlay(
composite_image, add_overlay=True
)

in_composite_image = self.create_composite_from_dataset(in_path_dict)
in_composite_image, seg_boundaries = self.overlay(
Expand All @@ -459,6 +475,27 @@ def update_composite(self):
in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8)
if seg_boundaries is not None:
in_composite_image[seg_boundaries] = [127, 127, 127]

img = in_composite_image[...,0].astype(np.float32) * 0
right_images = []
for c, s in {'red': self.red_select.value,
'green': self.green_select.value,
'blue': self.blue_select.value}.items():
if s:
composite_image = self.dataset.get_groundtruth(
self.fov_select.value, s
)
else:
composite_image = img
composite_image = np.squeeze(composite_image).astype(np.float32)
right_images.append(composite_image)
right_images = np.stack(right_images, axis=-1)
right_images = np.clip(right_images, 0, 2)
right_images[right_images == 2] = 0.3
right_images[seg_boundaries] = 0.0
right_images *= 255.0
right_images = right_images.astype(np.uint8)

# update image viewers
self.update_img(self.image.update_left_image, in_composite_image)
self.update_img(self.image.update_right_image, composite_image)
self.update_img(self.image.update_right_image, right_images)
22 changes: 22 additions & 0 deletions tests/test_viewer_widget.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from nimbus_inference.viewer_widget import InteractiveImageDuo, NimbusInteractiveGTViewer
from nimbus_inference.viewer_widget import NimbusViewer
from nimbus_inference.nimbus import Nimbus, prep_naming_convention
from nimbus_inference.utils import MultiplexDataset
from tests.test_utils import prepare_ome_tif_data, prepare_tif_data
from natsort import natsorted
from copy import copy
import numpy as np
import tempfile
import os
Expand Down Expand Up @@ -73,3 +76,22 @@ def test_overlay():
assert composite_image.shape == (256, 256, 3)
assert seg_boundaries.shape == (256, 256)
assert np.unique(seg_boundaries).tolist() == [0, 1]


def test_InteractiveImageDuo():
image_duo = InteractiveImageDuo(
figsize=(10, 5), title_left='Left Image', title_right='Right Image'
)
assert isinstance(image_duo, InteractiveImageDuo)

# Create dummy images
left_image = np.random.randint(0, 255, (256, 256), dtype=np.uint8)
right_image = np.random.randint(0, 255, (256, 256), dtype=np.uint8)

# Update images
image_duo.update_left_image(left_image)
image_duo.update_right_image(right_image)

# Check if images are updated
assert image_duo.ax[0].images[0].get_array().shape == (256, 256)
assert image_duo.ax[1].images[0].get_array().shape == (256, 256)

0 comments on commit 99c75e4

Please sign in to comment.