Skip to content

Commit

Permalink
Merge pull request #573 from AllenCell/feature/thresholding-live_images
Browse files Browse the repository at this point in the history
Feature/thresholding live images
  • Loading branch information
yrkim98 authored Dec 6, 2024
2 parents e8837e1 + 213086d commit a0cb9f6
Show file tree
Hide file tree
Showing 29 changed files with 1,445 additions and 329 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ dependencies = [
"npe2>=0.6.2",
"numpy",
"hydra-core==1.3.2",
"bioio",
"bioio==1.0.1",
"tifffile>=2023.4.12",
"watchdog",
"cyto-dl>=0.1.8",
"cyto-dl>=0.4.4",
"scikit-image!=0.23.0",
]

Expand Down
87 changes: 86 additions & 1 deletion src/allencell_ml_segmenter/_tests/core/test_file_input_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from pathlib import Path

import allencell_ml_segmenter
from allencell_ml_segmenter._tests.fakes.fake_subscriber import FakeSubscriber
from allencell_ml_segmenter.core.event import Event
from allencell_ml_segmenter.core.file_input_model import FileInputModel
from allencell_ml_segmenter.core.file_input_model import (
FileInputModel,
InputMode,
)


def test_set_selected_paths_no_extract_channels() -> None:
Expand Down Expand Up @@ -98,3 +102,84 @@ def test_set_max_channels_dispatch() -> None:

# Assert nothing happened
dummy_subscriber.was_handled(Event.ACTION_FILEINPUT_MAX_CHANNELS_SET)


def test_get_input_files_as_list_from_path() -> None:
"""
Test to see if all paths from a directory are returned as a list
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()
file_input_model.set_input_mode(InputMode.FROM_PATH)
file_input_model.set_input_image_path(
Path(allencell_ml_segmenter.__file__).parent
/ "_tests"
/ "test_files"
/ "img_folder"
)

# Act
files: list[Path] = file_input_model.get_input_files_as_list()

# Assert
assert len(files) == 5


def test_get_input_files_as_list_from_viewer() -> None:
"""
Test to see if all paths from viewer displayed images are returned as a list
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()
file_input_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS)
fake_selected_paths: list[Path] = [Path("fake_path1"), Path("fake_path2")]
file_input_model.set_selected_paths(fake_selected_paths)

# Act
files: list[Path] = file_input_model.get_input_files_as_list()

# Assert
assert len(files) == 2
assert files == fake_selected_paths


def test_get_input_files_as_list_from_no_directory_selected() -> None:
"""
Test to see if an empty list is returned when no directory is selected
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()

# Act
files: list[Path] = file_input_model.get_input_files_as_list()

# Assert
assert len(files) == 0


def test_get_input_files_as_list_from_no_selected_paths() -> None:
"""
Test to see if an empty list is returned when no layers were selected
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()

# Act
files: list[Path] = file_input_model.get_input_files_as_list()

# Assert
assert len(files) == 0


def test_get_input_files_as_list_from_no_selected_paths() -> None:
"""
Test to see if an empty list is returned when no input mode is selected
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()

# Act
files: list[Path] = file_input_model.get_input_files_as_list()

# Assert
assert len(files) == 0
24 changes: 22 additions & 2 deletions src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def __init__(self, viewer: Optional[napari.Viewer] = None):
self._shapes_layers: Dict[str, ShapesLayer] = {}
self._labels_layers: Dict[str, LabelsLayer] = {}
self._on_layers_change_fns: List[Callable] = []
self.threshold_inserted: Dict[str, np.ndarray] = {}

def add_image(self, image: np.ndarray, name: str):
self._image_layers[name] = ImageLayer(name, None)
self._image_layers[name] = ImageLayer(name, path=None, data=image)
self._on_layers_change()

def get_image(self, name: str) -> Optional[ImageLayer]:
Expand Down Expand Up @@ -93,7 +94,7 @@ def contains_layer(self, name: str) -> bool:

# not supporting in the fake because we will move away from this fn in the near future
def get_layers(self) -> List[Layer]:
return []
return list(self._image_layers.values())

def subscribe_layers_change_event(
self, function: Callable[[NapariEvent], None]
Expand All @@ -103,3 +104,22 @@ def subscribe_layers_change_event(
def _on_layers_change(self):
for fn in self._on_layers_change_fns:
fn(FakeNapariEvent())

def get_seg_layers(self) -> list[Layer]:
return [
layer
for layer in self._image_layers.values()
if layer.name.startswith("[seg]")
]

def insert_threshold(
self, layer_name: str, img: np.ndarray, seg_layers: bool = False
) -> None:
self.threshold_inserted[f"[threshold] {layer_name}"] = img

def get_layers_nonthreshold(self) -> list[Layer]:
return [
layer
for layer in self._image_layers.values()
if not layer.name.startswith("[threshold]")
]
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_build_overrides() -> None:
assert overrides["train"] == False
assert overrides["mode"] == "predict"
assert overrides["task_name"] == "predict_task_from_app"
assert overrides["ckpt_path"] == str(
assert overrides["checkpoint.ckpt_path"] == str(
Path(__file__).parent.parent
/ "main"
/ "experiments_home"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest

from allencell_ml_segmenter._tests.fakes.fake_subscriber import FakeSubscriber
from allencell_ml_segmenter.core.event import Event
from allencell_ml_segmenter.thresholding.thresholding_model import (
ThresholdingModel,
)


@pytest.fixture
def thresholding_model() -> ThresholdingModel:
model = ThresholdingModel()
return model


def test_set_thresholding_value_dispatches_event(thresholding_model):
fake_subscriber: FakeSubscriber = FakeSubscriber()
thresholding_model.subscribe(
Event.ACTION_THRESHOLDING_VALUE_CHANGED,
fake_subscriber,
fake_subscriber.handle,
)

thresholding_model.set_thresholding_value(2)

assert fake_subscriber.was_handled(Event.ACTION_THRESHOLDING_VALUE_CHANGED)


def test_set_autothresholding_enabled_dispatches_event(thresholding_model):
fake_subscriber: FakeSubscriber = FakeSubscriber()
thresholding_model.subscribe(
Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED,
fake_subscriber,
fake_subscriber.handle,
)

thresholding_model.set_autothresholding_enabled(True)

assert fake_subscriber.was_handled(
Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED
)


def test_dispatch_save_thresholded_images(thresholding_model):
fake_subscriber: FakeSubscriber = FakeSubscriber()
thresholding_model.subscribe(
Event.ACTION_SAVE_THRESHOLDING_IMAGES,
fake_subscriber,
fake_subscriber.handle,
)

thresholding_model.dispatch_save_thresholded_images()

assert fake_subscriber.was_handled(Event.ACTION_SAVE_THRESHOLDING_IMAGES)
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import pytest
import numpy as np

from allencell_ml_segmenter.core.file_input_model import FileInputModel
from allencell_ml_segmenter._tests.fakes.fake_experiments_model import (
FakeExperimentsModel,
)
from allencell_ml_segmenter.main.main_model import MainModel
from allencell_ml_segmenter.thresholding.thresholding_model import (
ThresholdingModel,
)
from allencell_ml_segmenter.thresholding.thresholding_service import (
ThresholdingService,
)
from allencell_ml_segmenter.core.task_executor import SynchroTaskExecutor
from allencell_ml_segmenter._tests.fakes.fake_viewer import FakeViewer


@pytest.fixture
def test_image():
"""Create a small test image for thresholding."""
return np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])


def test_on_threshold_changed_non_prediction(test_image):
# ARRANGE
thresholding_model: ThresholdingModel = ThresholdingModel()
viewer: FakeViewer = FakeViewer()
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
FileInputModel(),
MainModel(),
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
)
viewer.add_image(test_image, name="test_layer")

# ACT set a threshold to trigger
thresholding_model.set_thresholding_value(50)

# Verify a segmentation layer is added
assert "[threshold] test_layer" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] test_layer"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))

# check if existing thresholds get updated
thresholding_model.set_thresholding_value(100)
assert len(viewer.get_layers()) == 1
seg_data = viewer.threshold_inserted["[threshold] test_layer"]
assert np.array_equal(seg_data, (test_image > 100).astype(int))


def test_on_threshold_changed_non_prediction(test_image):
"""
Test that the thresholding service does not add a threshold layer for a layer that is not a probability map
"""
# ARRANGE
thresholding_model: ThresholdingModel = ThresholdingModel()
viewer: FakeViewer = FakeViewer()
main_model: MainModel = MainModel()
main_model.set_predictions_in_viewer(True)
thresholding_service: ThresholdingService = ThresholdingService(
thresholding_model,
FakeExperimentsModel(),
FileInputModel(),
main_model,
viewer,
task_executor=SynchroTaskExecutor.global_instance(),
)
# Only the [seg] layers below should produce a threshold layer
viewer.add_image(test_image, name="[raw] test_layer 1")
viewer.add_image(test_image, name="[seg] test_layer 1")
viewer.add_image(test_image, name="[raw] test_layer 2")
viewer.add_image(test_image, name="[seg] test_layer 2")
viewer.add_image(test_image, name="donotthreshold")

# ACT set a threshold to trigger
thresholding_model.set_thresholding_value(50)

# Verify a threshold layer is added for each seg layer
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"]
assert np.array_equal(seg_data, (test_image > 50).astype(int))
# verify that raw layers do not get thresholded
assert len(viewer.threshold_inserted) == 2

# verify existing threshold layers get updated correctly
thresholding_model.set_thresholding_value(100)
# Verify a threshold layer is added for each seg layer
assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"]
assert np.array_equal(seg_data, (test_image > 100).astype(int))
assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted
seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"]
assert np.array_equal(seg_data, (test_image > 100).astype(int))
# verify that raw layers do not get thresholded
assert len(viewer.threshold_inserted) == 2
Loading

0 comments on commit a0cb9f6

Please sign in to comment.