Skip to content

Commit

Permalink
Merge pull request #31 from angelolab/interactive_viewer
Browse files Browse the repository at this point in the history
Interactive viewer
  • Loading branch information
JLrumberger authored Nov 29, 2024
2 parents 089f9a2 + 829e943 commit 3f380f2
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"zarr",
"lmdb",
"kornia",
"mpl_interactions",
]

[[project.source]]
Expand Down
49 changes: 48 additions & 1 deletion src/nimbus_inference/example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import Union
import datasets
from alpineer.misc_utils import verify_in_list
import zipfile
import os
import requests

EXAMPLE_DATASET_REVISION: str = "main"

Expand Down Expand Up @@ -214,4 +217,48 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
example_dataset.download_example_dataset()

# Move the dataset over to the save_dir from the user.
example_dataset.move_example_dataset(move_dir=save_dir)
example_dataset.move_example_dataset(move_dir=save_dir)


def download_and_unpack_gold_standard(save_dir: Union[str, pathlib.Path], overwrite_existing: bool = True):
"""
Downloads 'gold_standard_labelled.zip' from the Hugging Face dataset and unpacks it in the given folder
if the dataset is not already present there.
Args:
save_dir (Union[str, Path]): The path to save the dataset files in.
overwrite_existing (bool): The option to overwrite existing files. Defaults to True.
"""
url = "https://huggingface.co/datasets/JLrumberger/Pan-Multiplex-Gold-Standard/resolve/main/gold_standard_labelled.zip"
save_dir = pathlib.Path(save_dir)
zip_path = save_dir / "gold_standard_labelled.zip"

# Create the save directory if it doesn't exist
save_dir.mkdir(parents=True, exist_ok=True)

# Check if the dataset is already present
if zip_path.exists() and not overwrite_existing:
print(f"{zip_path} already exists. Skipping download.")
return

# Download the zip file
print(f"Downloading {url} to {zip_path}...")
response = requests.get(url, stream=True)
response.raise_for_status()

with open(zip_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

print(f"Downloaded {zip_path}")

# Unpack the zip file
print(f"Unpacking {zip_path}...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(save_dir)

print(f"Unpacked to {save_dir}")

# Optionally, remove the zip file after unpacking
os.remove(zip_path)
print(f"Removed {zip_path}")
15 changes: 12 additions & 3 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ def segmentation_naming_convention(fov_path):
Returns:
str: paths to segmentation fovs
"""
fov_name = os.path.basename(fov_path).replace(".ome.tiff", "")
return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")

fov_name = os.path.basename(fov_path)
# remove suffix
fov_name = Path(fov_name).stem
# find all fnames which contain a superset of the fov_name
fnames = os.listdir(deepcell_output_dir)
# use re instead of glob
fnames = [os.path.join(deepcell_output_dir, f) for f in fnames if fov_name in f]
if len(fnames) == 0:
raise ValueError(f"No segmentation data found for fov {fov_name}")
if len(fnames) > 1:
raise ValueError(f"Multiple segmentation data found for fov {fov_name}")
return fnames[0]
return segmentation_naming_convention


Expand Down
57 changes: 56 additions & 1 deletion src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,4 +717,59 @@ def __getitem__(self, idx):
input_data = sample[:2]
groundtruth = sample[2:3]
inst_mask = sample[3:]
return input_data, groundtruth, inst_mask, self.keys[idx]
return input_data, groundtruth, inst_mask, self.keys[idx]


class InteractiveDataset(object):
"""Dataset for the InteractiveViewer class. This dataset class stores multiple objects of type
MultiplexedDataset, and allows to select a dataset and use its method for reading fovs and
channels from it.
Args:
datasets (dict): dictionary with dataset names as keys and dataset objects as values
"""
def __init__(self, datasets: dict):
self.datasets = datasets
self.dataset_names = list(datasets.keys())
self.dataset = None

def set_dataset(self, dataset_name: str):
"""Set the active dataset
Args:
dataset_name (str): name of the dataset
"""
self.dataset = self.datasets[dataset_name]
return self.dataset

def get_channel(self, fov: str, channel: str):
"""Get a channel from a fov
Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: channel image
"""
return self.dataset.get_channel(fov, channel)

def get_segmentation(self, fov: str):
"""Get the instance mask for a fov
Args:
fov (str): name of a fov
Returns:
np.array: instance mask
"""
return self.dataset.get_segmentation(fov)

def get_groundtruth(self, fov: str, channel: str):
"""Get the groundtruth for a fov / channel combination
Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: groundtruth activity mask (0: negative, 1: positive, 2: ambiguous)
"""
return self.dataset.get_groundtruth(fov, channel)
224 changes: 223 additions & 1 deletion src/nimbus_inference/viewer_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from natsort import natsorted
from skimage.segmentation import find_boundaries
from skimage.transform import rescale
from nimbus_inference.utils import MultiplexDataset
from nimbus_inference.utils import MultiplexDataset, InteractiveDataset
from mpl_interactions import panhandler
import matplotlib.pyplot as plt

class NimbusViewer(object):
"""Viewer for Nimbus application.
Expand Down Expand Up @@ -277,3 +279,223 @@ def display(self):
self.select_fov(None)
self.layout()
self.update_composite()


class InteractiveImageDuo(widgets.Image):
"""Interactive image viewer for Nimbus application.
Args:
figsize (tuple): Size of figure.
title_left (str): Title of left image.
title_right (str): Title of right image.
"""
def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Groundtruth'):
super().__init__()
self.title_left = title_left
self.title_right = title_right

# Initialize matplotlib figure
with plt.ioff():
self.fig, self.ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=figsize)

# uncomment the following lines to enable zooming via scroll wheel
# self.zoom_handler = self.custom_zoom_factory(self.ax[0])
# self.pan_handler = panhandler(self.fig)

# Display the figure canvas
display(self.fig.canvas)

def custom_zoom_factory(self, ax, base_scale=1.1):
"""Enable zooming via scroll wheel on matplotlib axes.
Args:
ax (matplotlib ax): ax to enable zooming on.
base_scale (float): Scale factor for zooming.
"""
def zoom(event):
cur_xlim = ax.get_xlim()
cur_ylim = ax.get_ylim()
xdata = event.xdata # get event x location
ydata = event.ydata # get event y location

if event.button == 'up':
scale_factor = 1 / base_scale
elif event.button == 'down':
scale_factor = base_scale
else:
scale_factor = 1
print(event.button)

new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

ax.set_xlim([xdata - new_width * (1 - relx), xdata + new_width * (relx)])
ax.set_ylim([ydata - new_height * (1 - rely), ydata + new_height * (rely)])
ax.figure.canvas.draw_idle()

fig = ax.get_figure() # get the figure of interest
fig.canvas.mpl_connect('scroll_event', zoom)

return zoom

def update_left_image(self, image):
"""Update the left image displayed in the viewer.
Args:
image (np.array): Image to display.
"""
self.ax[0].imshow(image)
self.ax[0].title.set_text(self.title_left)
self.ax[0].set_xticks([])
self.ax[0].set_yticks([])
self.fig.canvas.draw_idle()

def update_right_image(self, image):
"""Update the right image displayed in the viewer.
Args:
image (np.array): Image to display.
"""
self.ax[1].imshow(image, vmin=0, vmax=255)
self.ax[1].title.set_text(self.title_right)
self.ax[1].set_xticks([])
self.ax[1].set_yticks([])
self.fig.canvas.draw_idle()


class NimbusInteractiveGTViewer(NimbusViewer):
"""Interactive viewer for Nimbus application that shows input data and ground truth
side by side.
Args:
dataset (MultiplexDataset): dataset object
output_dir (str): Path to directory containing output of Nimbus application.
figsize (tuple): Size of figure.
"""
def __init__(
self, datasets: InteractiveDataset, output_dir, figsize=(20, 10)
):
super().__init__(
datasets.datasets[datasets.dataset_names[0]], output_dir
)
self.image = InteractiveImageDuo(figsize=figsize)
self.dataset = datasets.datasets[datasets.dataset_names[0]]
self.datasets = datasets
self.dataset_select = widgets.Select(
options=datasets.dataset_names,
description='Dataset:',
disabled=False
)
self.dataset_select.observe(self.select_dataset, names='value')

def layout(self):
"""Creates layout for viewer."""
channel_selectors = widgets.HBox([
self.red_select,
self.green_select,
self.blue_select
])
layout = widgets.HBox([
# widgets.HBox([
self.dataset_select,
self.fov_select,
channel_selectors,
self.overlay_checkbox,
self.update_button
# ]),
])
display(layout)

def select_dataset(self, change):
"""Selects dataset to display.
Args:
change (dict): Change dictionary from ipywidgets.
"""
self.dataset = self.datasets.set_dataset(change['new'])
self.fov_names = natsorted(copy(self.dataset.fovs))
self.fov_select.options = self.fov_names
self.select_fov(None)


def update_img(self, image_fn, composite_image):
"""Updates image in viewer by saving it as png and loading it with the viewer widget.
Args:
ax (matplotlib ax): ax to update.
composite_image (np.array): Composite image to display.
"""
if composite_image.shape[0] > self.max_resolution[0] or composite_image.shape[1] > self.max_resolution[1]:
scale = float(np.max(self.max_resolution)/np.max(composite_image.shape))
composite_image = rescale(composite_image, (scale, scale, 1), preserve_range=True)
composite_image = composite_image.astype(np.uint8)
image_fn(composite_image)

def update_composite(self):
"""Updates composite image in viewer."""
path_dict = {
"red": None,
"green": None,
"blue": None
}
in_path_dict = copy(path_dict)
if self.red_select.value:
path_dict["red"] = os.path.join(
self.output_dir, self.fov_select.value, self.red_select.value + self.suffix
)
in_path_dict["red"] = {"fov": self.fov_select.value, "channel": self.red_select.value}
if self.green_select.value:
path_dict["green"] = os.path.join(
self.output_dir, self.fov_select.value, self.green_select.value + self.suffix
)
in_path_dict["green"] = {
"fov": self.fov_select.value, "channel": self.green_select.value
}
if self.blue_select.value:
path_dict["blue"] = os.path.join(
self.output_dir, self.fov_select.value, self.blue_select.value + self.suffix
)
in_path_dict["blue"] = {
"fov": self.fov_select.value, "channel": self.blue_select.value
}
non_none = [p for p in path_dict.values() if p]
if not non_none:
return

in_composite_image = self.create_composite_from_dataset(in_path_dict)
in_composite_image, seg_boundaries = self.overlay(
in_composite_image, add_boundaries=self.overlay_checkbox.value
)
in_composite_image = in_composite_image / np.quantile(
in_composite_image, 0.999, axis=(0,1)
)
in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8)
if seg_boundaries is not None:
in_composite_image[seg_boundaries] = [127, 127, 127]

img = in_composite_image[...,0].astype(np.float32) * 0
right_images = []
for c, s in {'red': self.red_select.value,
'green': self.green_select.value,
'blue': self.blue_select.value}.items():
if s:
composite_image = self.dataset.get_groundtruth(
self.fov_select.value, s
)
else:
composite_image = img
composite_image = np.squeeze(composite_image).astype(np.float32)
right_images.append(composite_image)
right_images = np.stack(right_images, axis=-1)
right_images = np.clip(right_images, 0, 2)
right_images[right_images == 2] = 0.3
right_images[seg_boundaries] = 0.0
right_images *= 255.0
right_images = right_images.astype(np.uint8)

# update image viewers
self.update_img(self.image.update_left_image, in_composite_image)
self.update_img(self.image.update_right_image, right_images)
Loading

0 comments on commit 3f380f2

Please sign in to comment.