diff --git a/pyproject.toml b/pyproject.toml index e8b2df5d..375e1c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,10 +38,10 @@ dependencies = [ "npe2>=0.6.2", "numpy", "hydra-core==1.3.2", - "bioio", + "bioio==1.0.1", "tifffile>=2023.4.12", "watchdog", - "cyto-dl>=0.1.8", + "cyto-dl>=0.4.4", "scikit-image!=0.23.0", ] diff --git a/src/allencell_ml_segmenter/_tests/core/test_file_input_model.py b/src/allencell_ml_segmenter/_tests/core/test_file_input_model.py index 6d4441d1..898d9a63 100644 --- a/src/allencell_ml_segmenter/_tests/core/test_file_input_model.py +++ b/src/allencell_ml_segmenter/_tests/core/test_file_input_model.py @@ -1,8 +1,12 @@ from pathlib import Path +import allencell_ml_segmenter from allencell_ml_segmenter._tests.fakes.fake_subscriber import FakeSubscriber from allencell_ml_segmenter.core.event import Event -from allencell_ml_segmenter.core.file_input_model import FileInputModel +from allencell_ml_segmenter.core.file_input_model import ( + FileInputModel, + InputMode, +) def test_set_selected_paths_no_extract_channels() -> None: @@ -98,3 +102,84 @@ def test_set_max_channels_dispatch() -> None: # Assert nothing happened dummy_subscriber.was_handled(Event.ACTION_FILEINPUT_MAX_CHANNELS_SET) + + +def test_get_input_files_as_list_from_path() -> None: + """ + Test to see if all paths from a directory are returned as a list + """ + # ARRANGE + file_input_model: FileInputModel = FileInputModel() + file_input_model.set_input_mode(InputMode.FROM_PATH) + file_input_model.set_input_image_path( + Path(allencell_ml_segmenter.__file__).parent + / "_tests" + / "test_files" + / "img_folder" + ) + + # Act + files: list[Path] = file_input_model.get_input_files_as_list() + + # Assert + assert len(files) == 5 + + +def test_get_input_files_as_list_from_viewer() -> None: + """ + Test to see if all paths from viewer displayed images are returned as a list + """ + # ARRANGE + file_input_model: FileInputModel = FileInputModel() + file_input_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS) + fake_selected_paths: list[Path] = [Path("fake_path1"), Path("fake_path2")] + file_input_model.set_selected_paths(fake_selected_paths) + + # Act + files: list[Path] = file_input_model.get_input_files_as_list() + + # Assert + assert len(files) == 2 + assert files == fake_selected_paths + + +def test_get_input_files_as_list_from_no_directory_selected() -> None: + """ + Test to see if an empty list is returned when no directory is selected + """ + # ARRANGE + file_input_model: FileInputModel = FileInputModel() + + # Act + files: list[Path] = file_input_model.get_input_files_as_list() + + # Assert + assert len(files) == 0 + + +def test_get_input_files_as_list_from_no_selected_paths() -> None: + """ + Test to see if an empty list is returned when no layers were selected + """ + # ARRANGE + file_input_model: FileInputModel = FileInputModel() + + # Act + files: list[Path] = file_input_model.get_input_files_as_list() + + # Assert + assert len(files) == 0 + + +def test_get_input_files_as_list_from_no_selected_paths() -> None: + """ + Test to see if an empty list is returned when no input mode is selected + """ + # ARRANGE + file_input_model: FileInputModel = FileInputModel() + + # Act + files: list[Path] = file_input_model.get_input_files_as_list() + + # Assert + assert len(files) == 0 diff --git a/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py b/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py index 0160c664..6b6591e1 100644 --- a/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py +++ b/src/allencell_ml_segmenter/_tests/fakes/fake_viewer.py @@ -24,9 +24,10 @@ def __init__(self, viewer: Optional[napari.Viewer] = None): self._shapes_layers: Dict[str, ShapesLayer] = {} self._labels_layers: Dict[str, LabelsLayer] = {} self._on_layers_change_fns: List[Callable] = [] + self.threshold_inserted: Dict[str, np.ndarray] = {} def add_image(self, image: np.ndarray, name: str): - self._image_layers[name] = ImageLayer(name, None) + self._image_layers[name] = ImageLayer(name, path=None, data=image) self._on_layers_change() def get_image(self, name: str) -> Optional[ImageLayer]: @@ -93,7 +94,7 @@ def contains_layer(self, name: str) -> bool: # not supporting in the fake because we will move away from this fn in the near future def get_layers(self) -> List[Layer]: - return [] + return list(self._image_layers.values()) def subscribe_layers_change_event( self, function: Callable[[NapariEvent], None] @@ -103,3 +104,22 @@ def subscribe_layers_change_event( def _on_layers_change(self): for fn in self._on_layers_change_fns: fn(FakeNapariEvent()) + + def get_seg_layers(self) -> list[Layer]: + return [ + layer + for layer in self._image_layers.values() + if layer.name.startswith("[seg]") + ] + + def insert_threshold( + self, layer_name: str, img: np.ndarray, seg_layers: bool = False + ) -> None: + self.threshold_inserted[f"[threshold] {layer_name}"] = img + + def get_layers_nonthreshold(self) -> list[Layer]: + return [ + layer + for layer in self._image_layers.values() + if not layer.name.startswith("[threshold]") + ] diff --git a/src/allencell_ml_segmenter/_tests/services/test_prediction_service.py b/src/allencell_ml_segmenter/_tests/services/test_prediction_service.py index 848ca2f5..d6f88818 100644 --- a/src/allencell_ml_segmenter/_tests/services/test_prediction_service.py +++ b/src/allencell_ml_segmenter/_tests/services/test_prediction_service.py @@ -136,7 +136,7 @@ def test_build_overrides() -> None: assert overrides["train"] == False assert overrides["mode"] == "predict" assert overrides["task_name"] == "predict_task_from_app" - assert overrides["ckpt_path"] == str( + assert overrides["checkpoint.ckpt_path"] == str( Path(__file__).parent.parent / "main" / "experiments_home" diff --git a/src/allencell_ml_segmenter/postprocess/__init__.py b/src/allencell_ml_segmenter/_tests/thresholding/__init__.py similarity index 100% rename from src/allencell_ml_segmenter/postprocess/__init__.py rename to src/allencell_ml_segmenter/_tests/thresholding/__init__.py diff --git a/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_model.py b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_model.py new file mode 100644 index 00000000..8d4d63ae --- /dev/null +++ b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_model.py @@ -0,0 +1,54 @@ +import pytest + +from allencell_ml_segmenter._tests.fakes.fake_subscriber import FakeSubscriber +from allencell_ml_segmenter.core.event import Event +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, +) + + +@pytest.fixture +def thresholding_model() -> ThresholdingModel: + model = ThresholdingModel() + return model + + +def test_set_thresholding_value_dispatches_event(thresholding_model): + fake_subscriber: FakeSubscriber = FakeSubscriber() + thresholding_model.subscribe( + Event.ACTION_THRESHOLDING_VALUE_CHANGED, + fake_subscriber, + fake_subscriber.handle, + ) + + thresholding_model.set_thresholding_value(2) + + assert fake_subscriber.was_handled(Event.ACTION_THRESHOLDING_VALUE_CHANGED) + + +def test_set_autothresholding_enabled_dispatches_event(thresholding_model): + fake_subscriber: FakeSubscriber = FakeSubscriber() + thresholding_model.subscribe( + Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED, + fake_subscriber, + fake_subscriber.handle, + ) + + thresholding_model.set_autothresholding_enabled(True) + + assert fake_subscriber.was_handled( + Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED + ) + + +def test_dispatch_save_thresholded_images(thresholding_model): + fake_subscriber: FakeSubscriber = FakeSubscriber() + thresholding_model.subscribe( + Event.ACTION_SAVE_THRESHOLDING_IMAGES, + fake_subscriber, + fake_subscriber.handle, + ) + + thresholding_model.dispatch_save_thresholded_images() + + assert fake_subscriber.was_handled(Event.ACTION_SAVE_THRESHOLDING_IMAGES) diff --git a/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_service.py b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_service.py new file mode 100644 index 00000000..4ff6df6e --- /dev/null +++ b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_service.py @@ -0,0 +1,101 @@ +import pytest +import numpy as np + +from allencell_ml_segmenter.core.file_input_model import FileInputModel +from allencell_ml_segmenter._tests.fakes.fake_experiments_model import ( + FakeExperimentsModel, +) +from allencell_ml_segmenter.main.main_model import MainModel +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, +) +from allencell_ml_segmenter.thresholding.thresholding_service import ( + ThresholdingService, +) +from allencell_ml_segmenter.core.task_executor import SynchroTaskExecutor +from allencell_ml_segmenter._tests.fakes.fake_viewer import FakeViewer + + +@pytest.fixture +def test_image(): + """Create a small test image for thresholding.""" + return np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) + + +def test_on_threshold_changed_non_prediction(test_image): + # ARRANGE + thresholding_model: ThresholdingModel = ThresholdingModel() + viewer: FakeViewer = FakeViewer() + thresholding_service: ThresholdingService = ThresholdingService( + thresholding_model, + FakeExperimentsModel(), + FileInputModel(), + MainModel(), + viewer, + task_executor=SynchroTaskExecutor.global_instance(), + ) + viewer.add_image(test_image, name="test_layer") + + # ACT set a threshold to trigger + thresholding_model.set_thresholding_value(50) + + # Verify a segmentation layer is added + assert "[threshold] test_layer" in viewer.threshold_inserted + seg_data = viewer.threshold_inserted["[threshold] test_layer"] + assert np.array_equal(seg_data, (test_image > 50).astype(int)) + + # check if existing thresholds get updated + thresholding_model.set_thresholding_value(100) + assert len(viewer.get_layers()) == 1 + seg_data = viewer.threshold_inserted["[threshold] test_layer"] + assert np.array_equal(seg_data, (test_image > 100).astype(int)) + + +def test_on_threshold_changed_non_prediction(test_image): + """ + Test that the thresholding service does not add a threshold layer for a layer that is not a probability map + """ + # ARRANGE + thresholding_model: ThresholdingModel = ThresholdingModel() + viewer: FakeViewer = FakeViewer() + main_model: MainModel = MainModel() + main_model.set_predictions_in_viewer(True) + thresholding_service: ThresholdingService = ThresholdingService( + thresholding_model, + FakeExperimentsModel(), + FileInputModel(), + main_model, + viewer, + task_executor=SynchroTaskExecutor.global_instance(), + ) + # Only the [seg] layers below should produce a threshold layer + viewer.add_image(test_image, name="[raw] test_layer 1") + viewer.add_image(test_image, name="[seg] test_layer 1") + viewer.add_image(test_image, name="[raw] test_layer 2") + viewer.add_image(test_image, name="[seg] test_layer 2") + viewer.add_image(test_image, name="donotthreshold") + + # ACT set a threshold to trigger + thresholding_model.set_thresholding_value(50) + + # Verify a threshold layer is added for each seg layer + assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted + seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"] + assert np.array_equal(seg_data, (test_image > 50).astype(int)) + assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted + seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"] + assert np.array_equal(seg_data, (test_image > 50).astype(int)) + # verify that raw layers do not get thresholded + assert len(viewer.threshold_inserted) == 2 + + # verify existing threshold layers get updated correctly + thresholding_model.set_thresholding_value(100) + # Verify a threshold layer is added for each seg layer + assert "[threshold] [seg] test_layer 1" in viewer.threshold_inserted + seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 1"] + assert np.array_equal(seg_data, (test_image > 100).astype(int)) + assert "[threshold] [seg] test_layer 2" in viewer.threshold_inserted + seg_data = viewer.threshold_inserted["[threshold] [seg] test_layer 2"] + assert np.array_equal(seg_data, (test_image > 100).astype(int)) + # verify that raw layers do not get thresholded + assert len(viewer.threshold_inserted) == 2 diff --git a/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_view.py b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_view.py new file mode 100644 index 00000000..1c1d7caa --- /dev/null +++ b/src/allencell_ml_segmenter/_tests/thresholding/test_thresholding_view.py @@ -0,0 +1,270 @@ +import pytest +from pathlib import Path + +from qtpy.QtCore import Qt + +from allencell_ml_segmenter._tests.fakes.fake_experiments_model import ( + FakeExperimentsModel, +) +from allencell_ml_segmenter._tests.fakes.fake_subscriber import FakeSubscriber +from allencell_ml_segmenter.core.event import Event +from allencell_ml_segmenter.main.main_model import MainModel +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, +) +from allencell_ml_segmenter.core.file_input_model import ( + FileInputModel, + InputMode, +) +from allencell_ml_segmenter._tests.fakes.fake_viewer import FakeViewer +from allencell_ml_segmenter.thresholding.thresholding_view import ( + ThresholdingView, +) + + +@pytest.fixture +def main_model() -> MainModel: + return MainModel() + + +@pytest.fixture +def thresholding_model() -> ThresholdingModel: + model = ThresholdingModel() + model.set_thresholding_value(128) + return model + + +@pytest.fixture +# tmp_path is a builtin pytest fixture for a faked path +def file_input_model(tmp_path: Path) -> FileInputModel: + model = FileInputModel() + model.set_output_directory(tmp_path / "output") + model.set_input_image_path(tmp_path / "input") + model.set_input_mode(InputMode.FROM_PATH) + return model + + +@pytest.fixture +def experiments_model() -> FakeExperimentsModel: + return FakeExperimentsModel() + + +@pytest.fixture +def viewer() -> FakeViewer: + return FakeViewer() + + +@pytest.fixture +def thresholding_view( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + qtbot, +): + view = ThresholdingView( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + ) + qtbot.addWidget(view) + return view + + +def test_model_updates_on_slider_release( + thresholding_view, thresholding_model +): + # this tests to see if the model updates when the slider releases + # Arrange + initial_value: int = thresholding_model.get_thresholding_value() + + # Act + thresholding_view._threshold_value_slider.setValue(111) + thresholding_view._threshold_value_slider.sliderReleased.emit() + + # Assert + assert thresholding_model.get_thresholding_value() == 111 + assert thresholding_model.get_thresholding_value() != initial_value + + +def test_model_updates_on_spinbox_editing_finished( + thresholding_view, thresholding_model +): + # this tests to see if the model updates when the spinbox is edited + # Arrange + initial_value = thresholding_model.get_thresholding_value() + + # Act + new_value = 122 + thresholding_view._threshold_value_spinbox.setValue(new_value) + thresholding_view._threshold_value_spinbox.editingFinished.emit() + + # Assert: Model value should update to match the spinbox's new value + assert thresholding_model.get_thresholding_value() == new_value + assert thresholding_model.get_thresholding_value() != initial_value + + +def test_update_spinbox_from_slider(thresholding_view, qtbot): + # This tests to see if the spinbox updates live along with the slider + # Act: dragging + for new_value in range( + 60, 101, 10 + ): # Simulate dragging from 60 to 100 in steps + thresholding_view._threshold_value_slider.setValue(new_value) + qtbot.wait(10) # small delay + + # Assert that spinbox always stays updated + assert thresholding_view._threshold_value_spinbox.value() == new_value + + # Act: clicking a value + qtbot.mouseClick(thresholding_view._threshold_value_slider, Qt.LeftButton) + clicked_value = 120 + thresholding_view._threshold_value_slider.setValue(clicked_value) + + # Assert that spinbox updated correctly + assert thresholding_view._threshold_value_spinbox.value() == clicked_value + + +def test_update_slider_from_spinbox(thresholding_view, qtbot): + # This tests to see if the slider updates live when the spinbox is changed + # Act + new_value = 100 + thresholding_view._threshold_value_spinbox.setValue(new_value) + qtbot.keyPress( + thresholding_view._threshold_value_spinbox, Qt.Key_Enter + ) # Simulate user pressing "Enter" + + # Assert + assert thresholding_view._threshold_value_slider.value() == new_value + + +def test_update_state_from_radios(thresholding_view, thresholding_model): + # This tests if ui/model updates correctly when the radio buttons are toggled + # Arrange + assert not thresholding_view._none_radio_button.isChecked() + assert not thresholding_view._specific_value_radio_button.isChecked() + assert not thresholding_view._autothreshold_radio_button.isChecked() + assert not thresholding_view._apply_save_button.isEnabled() + assert not thresholding_view._threshold_value_slider.isEnabled() + assert not thresholding_view._threshold_value_spinbox.isEnabled() + + # Act + thresholding_view._specific_value_radio_button.setChecked(True) + thresholding_view._update_state_from_radios() + + # Assert + assert thresholding_model.is_threshold_enabled() + assert not thresholding_model.is_autothresholding_enabled() + assert thresholding_view._apply_save_button.isEnabled() + assert thresholding_view._threshold_value_slider.isEnabled() + assert thresholding_view._threshold_value_spinbox.isEnabled() + + thresholding_view._specific_value_radio_button.setChecked(False) + thresholding_view._autothreshold_radio_button.setChecked(True) + thresholding_view._update_state_from_radios() + + # Assert + assert not thresholding_model.is_threshold_enabled() + assert thresholding_model.is_autothresholding_enabled() + assert thresholding_view._apply_save_button.isEnabled() + assert thresholding_view._autothreshold_method_combo.isEnabled() + assert not thresholding_view._threshold_value_slider.isEnabled() + assert not thresholding_view._threshold_value_spinbox.isEnabled() + + +def test_check_able_to_threshold_valid( + main_model, file_input_model, experiments_model, viewer +): + thresholding_model: ThresholdingModel = ThresholdingModel() + thresholding_model.set_threshold_enabled(True) + thresholding_model.set_thresholding_value(100) + thresholding_view: ThresholdingView = ThresholdingView( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + ) + + assert thresholding_view._check_able_to_threshold() + + +def test_check_able_to_threshold_no_output_dir( + main_model, experiments_model, viewer +): + thresholding_model: ThresholdingModel = ThresholdingModel() + thresholding_model.set_threshold_enabled(True) + thresholding_model.set_thresholding_value(100) + file_input_model: FileInputModel = FileInputModel() + file_input_model.set_input_mode(InputMode.FROM_PATH) + file_input_model.set_input_image_path(Path("fake_path")) + thresholding_view: ThresholdingView = ThresholdingView( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + ) + + assert not thresholding_view._check_able_to_threshold() + + +def test_check_able_to_threshold_no_input_dir( + main_model, experiments_model, viewer +): + thresholding_model: ThresholdingModel = ThresholdingModel() + thresholding_model.set_threshold_enabled(True) + thresholding_model.set_thresholding_value(100) + file_input_model: FileInputModel = FileInputModel() + file_input_model.set_input_mode(InputMode.FROM_PATH) + file_input_model.set_output_directory(Path("fake_path")) + thresholding_view: ThresholdingView = ThresholdingView( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + ) + + assert not thresholding_view._check_able_to_threshold() + + +def test_check_able_to_threshold_no_input_method( + main_model, experiments_model, viewer +): + thresholding_model: ThresholdingModel = ThresholdingModel() + thresholding_model.set_threshold_enabled(True) + thresholding_model.set_thresholding_value(100) + file_input_model: FileInputModel = FileInputModel() + file_input_model.set_input_image_path(Path("fake_path")) + file_input_model.set_output_directory(Path("fake_path")) + thresholding_view: ThresholdingView = ThresholdingView( + main_model, + thresholding_model, + file_input_model, + experiments_model, + viewer, + ) + + assert not thresholding_view._check_able_to_threshold() + + +def check_button_press_dispatches_event( + thresholding_view, thresholding_model, qtbot +): + # arrange + fake_subscriber: FakeSubscriber = FakeSubscriber() + thresholding_model.subscribe( + Event.ACTION_SAVE_THRESHOLDING_IMAGES, + fake_subscriber, + fake_subscriber.handle, + ) + + # act + qtbot.mouseClick(thresholding_view._apply_save_button, Qt.LeftButton) + + # assert that event was dispatched + assert fake_subscriber.handled[Event.ACTION_SAVE_THRESHOLDING_IMAGES] diff --git a/src/allencell_ml_segmenter/core/event.py b/src/allencell_ml_segmenter/core/event.py index d72c0c8f..b5134de1 100644 --- a/src/allencell_ml_segmenter/core/event.py +++ b/src/allencell_ml_segmenter/core/event.py @@ -64,6 +64,11 @@ class Event(Enum): ACTION_CURATION_SEG1_THREAD_ERROR = "curation_seg1_thread_error" ACTION_CURATION_RAW_THREAD_ERROR = "curation_raw_thread_error" + # Thresholding events + ACTION_THRESHOLDING_VALUE_CHANGED = "thresholding_value_changed" + ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED = "autothresholding_selected" + ACTION_SAVE_THRESHOLDING_IMAGES = "save_thresholding_images" + # View selection events. These can stem from a user action, or from a process (i.e. prediction process ends, and a new view is shown automatically). VIEW_SELECTION_TRAINING = "training_selected" VIEW_SELECTION_PREDICTION = "prediction_selected" diff --git a/src/allencell_ml_segmenter/core/file_input_model.py b/src/allencell_ml_segmenter/core/file_input_model.py index ce3cd885..b0543ca0 100644 --- a/src/allencell_ml_segmenter/core/file_input_model.py +++ b/src/allencell_ml_segmenter/core/file_input_model.py @@ -79,3 +79,16 @@ def set_max_channels(self, max: Optional[int]) -> None: def get_selected_paths(self) -> Optional[list[Path]]: return self._selected_paths + + def get_input_files_as_list(self) -> List[Path]: + input_image_path: Optional[Path] = self.get_input_image_path() + if ( + self.get_input_mode() == InputMode.FROM_PATH + and input_image_path is not None + ): + return list(input_image_path.glob("*")) + elif self.get_input_mode() == InputMode.FROM_NAPARI_LAYERS: + selected_paths: Optional[list[Path]] = self.get_selected_paths() + if selected_paths is not None: + return selected_paths + return [] diff --git a/src/allencell_ml_segmenter/core/file_input_widget.py b/src/allencell_ml_segmenter/core/file_input_widget.py index ea004fcb..71481208 100644 --- a/src/allencell_ml_segmenter/core/file_input_widget.py +++ b/src/allencell_ml_segmenter/core/file_input_widget.py @@ -51,9 +51,10 @@ def __init__( model: FileInputModel, viewer: IViewer, service: ModelFileService, + include_channel_selection: bool = True, ): super().__init__() - + self._include_channel_selection: bool = include_channel_selection self._model: FileInputModel = model self._viewer: IViewer = viewer self._service: ModelFileService = service @@ -134,7 +135,7 @@ def __init__( self._browse_dir_edit: InputButton = InputButton( self._model, lambda dir: self._model.set_input_image_path( - Path(dir), extract_channels=True + Path(dir), extract_channels=include_channel_selection ), "Select directory...", FileInputMode.DIRECTORY, @@ -147,33 +148,35 @@ def __init__( frame_layout.addLayout(horiz_layout) grid_layout: QGridLayout = QGridLayout() + if include_channel_selection: + image_input_label: LabelWithHint = LabelWithHint( + "Input image's channel" + ) + image_input_label.set_hint( + "Select which channel of the input image(s) to apply the trained model on" + ) - image_input_label: LabelWithHint = LabelWithHint( - "Input image's channel" - ) - image_input_label.set_hint( - "Select which channel of the input image(s) to apply the trained model on" - ) - - self._channel_select_dropdown: QComboBox = QComboBox() + self._channel_select_dropdown: QComboBox = QComboBox() - self._channel_select_dropdown.setCurrentIndex(-1) - self._channel_select_dropdown.currentIndexChanged.connect( - self._model.set_image_input_channel_index - ) - self._channel_select_dropdown.setEnabled(False) - # Event to trigger combobox populate when we know the number of channels - self._model.subscribe( - Event.ACTION_FILEINPUT_MAX_CHANNELS_SET, - self, - self._populate_input_channel_combobox, - ) - # Event to set combobox text to 'loading' when we begin extracting channels - self._model.subscribe( - Event.ACTION_FILEINPUT_EXTRACT_CHANNELS, - self, - self._set_input_channel_combobox_to_loading, - ) + self._channel_select_dropdown.setCurrentIndex(-1) + self._channel_select_dropdown.currentIndexChanged.connect( + self._model.set_image_input_channel_index + ) + self._channel_select_dropdown.setEnabled(False) + # Event to trigger combobox populate when we know the number of channels + self._model.subscribe( + Event.ACTION_FILEINPUT_MAX_CHANNELS_SET, + self, + self._populate_input_channel_combobox, + ) + # Event to set combobox text to 'loading' when we begin extracting channels + self._model.subscribe( + Event.ACTION_FILEINPUT_EXTRACT_CHANNELS, + self, + self._set_input_channel_combobox_to_loading, + ) + grid_layout.addWidget(image_input_label, 0, 0) + grid_layout.addWidget(self._channel_select_dropdown, 0, 1) output_dir_label: LabelWithHint = LabelWithHint("Output directory") output_dir_label.set_hint( @@ -187,9 +190,6 @@ def __init__( FileInputMode.DIRECTORY, ) - grid_layout.addWidget(image_input_label, 0, 0) - grid_layout.addWidget(self._channel_select_dropdown, 0, 1) - grid_layout.addWidget(output_dir_label, 1, 0) grid_layout.addWidget(self._browse_output_edit, 1, 1) @@ -214,19 +214,24 @@ def _from_directory_slot(self) -> None: self._model.set_input_mode(InputMode.FROM_PATH) def _update_layer_list(self, event: Optional[NapariEvent] = None) -> None: + existing_selection = [ + self._viewer.get_layers_nonthreshold()[i].name + for i in self._image_list.get_checked_rows() + ] self._image_list.clear() - self._model.set_selected_paths([], extract_channels=False) self._reset_channel_combobox() for layer in self._viewer.get_layers(): path_of_layer_image: str = layer.source.path if path_of_layer_image: - self._image_list.add_item(layer.name) + self._image_list.add_item( + layer.name, set_checked=layer.name in existing_selection + ) def _process_checked_signal(self, row: int, state: Qt.CheckState) -> None: if self._model.get_input_mode() == InputMode.FROM_NAPARI_LAYERS: selected_indices: List[int] = self._image_list.get_checked_rows() selected_paths: List[Path] = [ - Path(self._viewer.get_layers()[i].source.path) + Path(self._viewer.get_layers_nonthreshold()[i].source.path) for i in selected_indices ] @@ -261,10 +266,11 @@ def _process_checked_signal(self, row: int, state: Qt.CheckState) -> None: ) def _reset_channel_combobox(self) -> None: - self._channel_select_dropdown.clear() - self._channel_select_dropdown.setPlaceholderText("") - self._channel_select_dropdown.setCurrentIndex(-1) - self._channel_select_dropdown.setEnabled(False) + if self._include_channel_selection: + self._channel_select_dropdown.clear() + self._channel_select_dropdown.setPlaceholderText("") + self._channel_select_dropdown.setCurrentIndex(-1) + self._channel_select_dropdown.setEnabled(False) def _set_input_channel_combobox_to_loading( self, event: Optional[Event] = None diff --git a/src/allencell_ml_segmenter/main/experiments_model.py b/src/allencell_ml_segmenter/main/experiments_model.py index 79f26c59..4d5f7bae 100644 --- a/src/allencell_ml_segmenter/main/experiments_model.py +++ b/src/allencell_ml_segmenter/main/experiments_model.py @@ -76,21 +76,30 @@ def get_model_checkpoints_path( return user_exp_path / experiment_name / "checkpoints" / checkpoint - def _get_exp_path(self) -> Path: + def _get_exp_path(self) -> Optional[Path]: user_exp_path: Optional[Path] = self.get_user_experiments_path() exp_name: Optional[str] = self.get_experiment_name() - if user_exp_path is None or exp_name is None: - raise ValueError("Experiment path or name undefined") - return user_exp_path / exp_name - - def get_csv_path(self) -> Path: - return self._get_exp_path() / "data" - - def get_metrics_csv_path(self) -> Path: - return self._get_exp_path() / "csv" - - def get_cache_dir(self) -> Path: - return self._get_exp_path() / "cache" + if exp_name is not None and user_exp_path is not None: + return user_exp_path / exp_name + return None + + def get_csv_path(self) -> Optional[Path]: + exp_path: Optional[Path] = self._get_exp_path() + if exp_path is not None: + return exp_path / "data" + return None + + def get_metrics_csv_path(self) -> Optional[Path]: + exp_path: Optional[Path] = self._get_exp_path() + if exp_path is not None: + return exp_path / "csv" + return None + + def get_cache_dir(self) -> Optional[Path]: + exp_path: Optional[Path] = self._get_exp_path() + if exp_path is not None: + return exp_path / "cache" + return None def get_latest_metrics_csv_version(self) -> int: """ @@ -99,8 +108,9 @@ def get_latest_metrics_csv_version(self) -> int: exist """ last_version: int = -1 - if self.get_metrics_csv_path().exists(): - for child in self.get_metrics_csv_path().glob("version_*"): + csv_path: Optional[Path] = self.get_metrics_csv_path() + if csv_path is not None and csv_path.exists(): + for child in csv_path.glob("version_*"): if child.is_dir(): version_str: str = child.name.split("_")[-1] try: @@ -115,9 +125,10 @@ def get_latest_metrics_csv_version(self) -> int: def get_latest_metrics_csv_path(self) -> Optional[Path]: version: int = self.get_latest_metrics_csv_version() + csv_path: Optional[Path] = self.get_metrics_csv_path() return ( - self.get_metrics_csv_path() / f"version_{version}" / "metrics.csv" - if version >= 0 + csv_path / f"version_{version}" / "metrics.csv" + if version >= 0 and csv_path is not None else None ) @@ -133,8 +144,13 @@ def get_train_config_path( ) return user_exp_path / experiment_name / "train_config.yaml" else: + user_exp_path = self._get_exp_path() # get config for currently selected experiment - return self._get_exp_path() / "train_config.yaml" + if user_exp_path is None: + raise ValueError( + "user_exp_path cannot be None if experiment_name is also None in get_train_config_path" + ) + return user_exp_path / "train_config.yaml" def get_current_epoch(self) -> Optional[int]: ckpt: Optional[Path] = self.get_best_ckpt() @@ -160,8 +176,7 @@ def get_best_ckpt(self) -> Optional[Path]: ) def get_channel_selection_path(self) -> Optional[Path]: - return ( - self.get_csv_path() / "selected_channels.json" - if self.get_csv_path() is not None - else None - ) + csv_path: Optional[Path] = self.get_csv_path() + if csv_path is None: + return None + return csv_path / "selected_channels.json" diff --git a/src/allencell_ml_segmenter/main/i_experiments_model.py b/src/allencell_ml_segmenter/main/i_experiments_model.py index 59ba5aae..1a106a73 100644 --- a/src/allencell_ml_segmenter/main/i_experiments_model.py +++ b/src/allencell_ml_segmenter/main/i_experiments_model.py @@ -55,7 +55,7 @@ def get_model_checkpoints_path( pass @abstractmethod - def get_metrics_csv_path(self) -> Path: + def get_metrics_csv_path(self) -> Optional[Path]: pass @abstractmethod @@ -67,11 +67,11 @@ def get_latest_metrics_csv_path(self) -> Optional[Path]: pass @abstractmethod - def get_csv_path(self) -> Path: + def get_csv_path(self) -> Optional[Path]: pass @abstractmethod - def get_cache_dir(self) -> Path: + def get_cache_dir(self) -> Optional[Path]: pass @abstractmethod diff --git a/src/allencell_ml_segmenter/main/i_viewer.py b/src/allencell_ml_segmenter/main/i_viewer.py index dc3886d3..f73f931a 100644 --- a/src/allencell_ml_segmenter/main/i_viewer.py +++ b/src/allencell_ml_segmenter/main/i_viewer.py @@ -72,3 +72,31 @@ def subscribe_layers_change_event( self, function: Callable[[NapariEvent], None] ) -> None: pass + + @abstractmethod + def get_seg_layers(self) -> list[Layer]: + """ + Get only segmentation layers (which should be probability mappings) from the viewer. + These are the layers that start with [seg]. + """ + pass + + @abstractmethod + def insert_threshold( + self, layer_name: str, img: np.ndarray, seg_layers: bool = False + ) -> None: + """ + Insert a thresholded image into the viewer. + If a layer for this thresholded image already exists, the new image will replace the old one and refresh the viewer. + If the layer does not exist, it will be added to the viewer in the correct place (on top of the original segmentation image: + index_of_segmentation + 1 in the LayerList) + """ + pass + + @abstractmethod + def get_layers_nonthreshold(self) -> list[Layer]: + """ + Get only layers which are not segmentation layers from the viewer. + These are the layers that do not start with [threshold]. + """ + pass diff --git a/src/allencell_ml_segmenter/main/main_model.py b/src/allencell_ml_segmenter/main/main_model.py index 590aaf62..830107f7 100644 --- a/src/allencell_ml_segmenter/main/main_model.py +++ b/src/allencell_ml_segmenter/main/main_model.py @@ -32,6 +32,9 @@ def __init__(self) -> None: self._current_view: Optional[MainWindow] = None self._is_new_model: bool = False self.signals: MainModelSignals = MainModelSignals() + self._predictions_in_viewer: bool = ( + False # Tracks whether predictions are displayed in the viewer + ) self._selected_channels: dict[ImageType, Optional[int]] = { ImageType.RAW: None, @@ -85,3 +88,15 @@ def set_selected_channels( def training_complete(self) -> None: self.dispatch(Event.PROCESS_TRAINING_COMPLETE) + + def set_predictions_in_viewer(self, predictions_in_viewer: bool) -> None: + """ + Set if predicted images (probability mappings) are displayed in the viewer. + """ + self._predictions_in_viewer = predictions_in_viewer + + def are_predictions_in_viewer(self) -> bool: + """ + Check if predicted images (probability mappings) are displayed in the viewer. + """ + return self._predictions_in_viewer diff --git a/src/allencell_ml_segmenter/main/main_widget.py b/src/allencell_ml_segmenter/main/main_widget.py index bd49fbce..e79cf359 100644 --- a/src/allencell_ml_segmenter/main/main_widget.py +++ b/src/allencell_ml_segmenter/main/main_widget.py @@ -28,6 +28,12 @@ PredictionService, ) from allencell_ml_segmenter.services.training_service import TrainingService +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, +) +from allencell_ml_segmenter.thresholding.thresholding_service import ( + ThresholdingService, +) from allencell_ml_segmenter.training.model_selection_widget import ( ModelSelectionWidget, ) @@ -36,7 +42,7 @@ from allencell_ml_segmenter.curation.curation_model import CurationModel from allencell_ml_segmenter._style import Style from allencell_ml_segmenter.curation.curation_service import CurationService -from allencell_ml_segmenter.postprocess.postprocess_view import ( +from allencell_ml_segmenter.thresholding.thresholding_view import ( ThresholdingView, ) from allencell_ml_segmenter.core.file_input_model import FileInputModel @@ -84,13 +90,16 @@ def __init__( self._training_model: TrainingModel = TrainingModel( main_model=self._model, experiments_model=self._experiments_model ) - self._file_input_model: FileInputModel = FileInputModel() + self._prediction_file_input_model: FileInputModel = FileInputModel() self._prediction_model: PredictionModel = PredictionModel() self._curation_model: CurationModel = CurationModel( self._experiments_model, self._model, ) + self._thresholding_file_input_model: FileInputModel = FileInputModel() + self._thresholding_model: ThresholdingModel = ThresholdingModel() + # init services self._main_service: MainService = MainService( self._model, self._experiments_model @@ -105,8 +114,16 @@ def __init__( ) self._prediction_service: PredictionService = PredictionService( prediction_model=self._prediction_model, - file_input_model=self._file_input_model, + file_input_model=self._prediction_file_input_model, + experiments_model=self._experiments_model, + ) + + self._thresholding_service: ThresholdingService = ThresholdingService( + thresholding_model=self._thresholding_model, experiments_model=self._experiments_model, + file_input_model=self._thresholding_file_input_model, + main_model=self._model, + viewer=self.viewer, ) # keep track of windows @@ -135,7 +152,7 @@ def __init__( self._prediction_view: PredictionView = PredictionView( main_model=self._model, prediction_model=self._prediction_model, - file_input_model=self._file_input_model, + file_input_model=self._prediction_file_input_model, viewer=self.viewer, ) self._initialize_window(self._prediction_view, "Prediction") @@ -143,6 +160,8 @@ def __init__( self._thresholding_view = ThresholdingView( main_model=self._model, + thresholding_model=self._thresholding_model, + file_input_model=self._thresholding_file_input_model, experiments_model=self._experiments_model, viewer=self.viewer, ) diff --git a/src/allencell_ml_segmenter/main/segmenter_layer.py b/src/allencell_ml_segmenter/main/segmenter_layer.py index b3860015..88d3b26c 100644 --- a/src/allencell_ml_segmenter/main/segmenter_layer.py +++ b/src/allencell_ml_segmenter/main/segmenter_layer.py @@ -14,9 +14,16 @@ class ShapesLayer(SegmenterLayer): data: np.ndarray +@dataclass +class Source(SegmenterLayer): + path: Optional[Path] = None + + @dataclass class ImageLayer(SegmenterLayer): - path: Optional[Path] + path: Optional[Path] = None + data: Optional[np.ndarray] = None + source: Source = Source(name="sourcename") @dataclass diff --git a/src/allencell_ml_segmenter/main/viewer.py b/src/allencell_ml_segmenter/main/viewer.py index 22ec5615..cd7e6662 100644 --- a/src/allencell_ml_segmenter/main/viewer.py +++ b/src/allencell_ml_segmenter/main/viewer.py @@ -94,6 +94,17 @@ def contains_layer(self, name: str) -> bool: def get_layers(self) -> list[Layer]: return [l for l in self.viewer.layers] + def get_layers_nonthreshold(self) -> list[Layer]: + """ + Get only layers which are not segmentation layers from the viewer. + These are the layers that do not start with [threshold]. + """ + return [ + l + for l in self.viewer.layers + if not l.name.startswith("[threshold]") + ] + def subscribe_layers_change_event( self, function: Callable[[NapariEvent], None] ) -> None: @@ -105,3 +116,54 @@ def _get_layer_by_name(self, name: str) -> Optional[Layer]: if l.name == name: return l return None + + def get_seg_layers(self) -> list[Layer]: + """ + Get only segmentation layers (which should be probability mappings) from the viewer. + These are the layers that start with [seg]. + """ + return [ + layer + for layer in self.get_layers() + if layer.name.startswith("[seg]") + ] + + def insert_threshold( + self, + layer_name: str, + image: np.ndarray, + remove_seg_layers: bool = False, + ) -> None: + """ + Insert a thresholded image into the viewer. + If a layer for this thresholded image already exists, the new image will replace the old one and refresh the viewer. + If the layer does not exist, it will be added to the viewer in the correct place (on top of the original segmentation image: + index_of_segmentation + 1 in the LayerList) + + :param layer_name: name of layer to insert. Will replace if one exists, will create one in a new position if needed. + :param image: image to insert + :param remove_seg_layers: boolean indicating if the layer that is being thresholded is a segmentation layer, and should be removed from the layer once it is updated with the threshold. + """ + layer_to_insert = self._get_layer_by_name(f"[threshold] {layer_name}") + if layer_to_insert is None: + # No thresholding exists, so we add it to the correct place in the viewer + layerlist = self.viewer.layers + + # check if the original segementation layer is currently in the viewer, if so, remove later after + # thresholding is applied + seg_layer_og: Optional[Layer] = None + if remove_seg_layers: + seg_layer_og = self._get_layer_by_name(layer_name) + + # figure out where to insert the new thresholded layer (on top of the original segmentation image) + layerlist_pos = layerlist.index(layer_name) + labels_created = Labels(image, name=f"[threshold] {layer_name}") + layerlist.insert(layerlist_pos + 1, labels_created) + + # remove the original segmentation layer if it exists + if seg_layer_og: + layerlist.remove(seg_layer_og) + else: + # Thresholding already exists so just update the existing one in the viewer. + layer_to_insert.data = image + layer_to_insert.refresh() diff --git a/src/allencell_ml_segmenter/postprocess/postprocess_view.py b/src/allencell_ml_segmenter/postprocess/postprocess_view.py deleted file mode 100644 index 96a3f6aa..00000000 --- a/src/allencell_ml_segmenter/postprocess/postprocess_view.py +++ /dev/null @@ -1,221 +0,0 @@ -from allencell_ml_segmenter.main.i_experiments_model import IExperimentsModel -from allencell_ml_segmenter.main.i_viewer import IViewer -from allencell_ml_segmenter._style import Style -from allencell_ml_segmenter.core.view import View, MainWindow -from allencell_ml_segmenter.main.main_model import MainModel -from allencell_ml_segmenter.prediction.service import ModelFileService - -from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint -from allencell_ml_segmenter.core.file_input_widget import ( - FileInputWidget, -) -from allencell_ml_segmenter.core.file_input_model import FileInputModel - -from qtpy.QtWidgets import ( - QLabel, - QVBoxLayout, - QHBoxLayout, - QSizePolicy, - QComboBox, - QGroupBox, - QRadioButton, - QPushButton, - QSlider, - QDoubleSpinBox, - QFileDialog, -) -from qtpy.QtCore import Qt - - -class ThresholdingView(View, MainWindow): - """ - View for thresholding - """ - - def __init__( - self, - main_model: MainModel, - experiments_model: IExperimentsModel, - viewer: IViewer, - ): - super().__init__() - - self._main_model: MainModel = main_model - self._experiments_model: IExperimentsModel = experiments_model - self._viewer: IViewer = viewer - self._thresholding_model: FileInputModel = FileInputModel() - self._service: ModelFileService = ModelFileService( - self._thresholding_model - ) - - layout: QVBoxLayout = QVBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - layout.setSpacing(20) - self.setLayout(layout) - layout.setAlignment(Qt.AlignmentFlag.AlignTop) - self.setSizePolicy( - QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum - ) - - # title - self._title: QLabel = QLabel("THRESHOLD", self) - self._title.setObjectName("title") - layout.addWidget(self._title, alignment=Qt.AlignmentFlag.AlignHCenter) - - # selecting input image - self._file_input_widget: FileInputWidget = FileInputWidget( - self._thresholding_model, self._viewer, self._service - ) - self._file_input_widget.setObjectName("fileInput") - layout.addWidget(self._file_input_widget) - - # thresholding values - self._threshold_label: LabelWithHint = LabelWithHint("Threshold") - self._threshold_label.set_hint("Values to threshold with.") - self._threshold_label.setObjectName("title") - layout.addWidget(self._threshold_label) - - threshold_group_box = QGroupBox() - threshold_group_layout = QVBoxLayout() - - # none thresholding selection - none_radio_layout: QHBoxLayout = QHBoxLayout() - self._none_radio_button: QRadioButton = QRadioButton() - none_radio_layout.addWidget(self._none_radio_button) - - none_radio_label: LabelWithHint = LabelWithHint("None") - none_radio_label.set_hint("No thresholding applied.") - none_radio_layout.addWidget(none_radio_label) - threshold_group_layout.addLayout(none_radio_layout) - - # specific value thresholding selection - specific_value_layout = QHBoxLayout() - - self._specific_value_radio_button: QRadioButton = QRadioButton() - specific_value_layout.addWidget(self._specific_value_radio_button) - specific_radio_label: LabelWithHint = LabelWithHint("Specific Value") - specific_radio_label.set_hint( - "Set thresholding value you'd like to apply." - ) - specific_value_layout.addWidget(specific_radio_label) - - self._threshold_value_slider: QSlider = QSlider( - Qt.Orientation.Horizontal - ) - self._threshold_value_slider.setRange( - 0, 100 - ) # slider values from 0 to 100 (representing 0.0 to 1.0) - self._threshold_value_slider.setValue(50) # Default value at 0.5 - - self._threshold_value_spinbox: QDoubleSpinBox = QDoubleSpinBox() - self._threshold_value_spinbox.setRange(0.0, 1.0) - self._threshold_value_spinbox.setSingleStep(0.01) - self._threshold_value_spinbox.setValue(0.5) - - # sync slider and spinbox - self._threshold_value_slider.valueChanged.connect( - self._update_spinbox_from_slider - ) - self._threshold_value_spinbox.valueChanged.connect( - self._update_slider_from_spinbox - ) - - # add slider and spinbox - specific_value_layout.addWidget(self._specific_value_radio_button) - specific_value_layout.addWidget(self._threshold_value_slider) - specific_value_layout.addWidget(self._threshold_value_spinbox) - threshold_group_layout.addLayout(specific_value_layout) - - # autothresholding - autothreshold_layout = QHBoxLayout() - self._autothreshold_radio_button: QRadioButton = QRadioButton() - auto_thresh_label: LabelWithHint = LabelWithHint("Autothreshold") - auto_thresh_label.set_hint("Apply an autothresholding method.") - - self._autothreshold_method_combo: QComboBox = QComboBox() - self._autothreshold_method_combo.addItems(["Otsu"]) - self._autothreshold_method_combo.setEnabled(False) - - autothreshold_layout.addWidget(self._autothreshold_radio_button) - autothreshold_layout.addWidget(auto_thresh_label) - autothreshold_layout.addWidget(self._autothreshold_method_combo) - threshold_group_layout.addLayout(autothreshold_layout) - - threshold_group_box.setLayout(threshold_group_layout) - layout.addWidget(threshold_group_box) - - # apply and save - self._apply_save_button: QPushButton = QPushButton("Apply & Save") - self._apply_save_button.setEnabled(False) - layout.addWidget(self._apply_save_button) - - # need styling - self.setStyleSheet(Style.get_stylesheet("thresholding_view.qss")) - - # configure widget behavior - self._configure_slots() - - def _configure_slots(self) -> None: - """ - Connects behavior for widgets - """ - - # enable selections when corresponding radio button is selected - self._specific_value_radio_button.toggled.connect( - lambda checked: self._enable_specific_threshold_widgets(checked) - ) - self._autothreshold_radio_button.toggled.connect( - lambda checked: self._autothreshold_method_combo.setEnabled( - checked - ) - ) - - # enable apply button only when thresholding method is selected - self._none_radio_button.toggled.connect(self._enable_apply_button) - self._specific_value_radio_button.toggled.connect( - self._enable_apply_button - ) - self._autothreshold_radio_button.toggled.connect( - self._enable_apply_button - ) - - def _update_spinbox_from_slider(self, value: int) -> None: - """ - Update the spinbox value when slider is changed - """ - self._threshold_value_spinbox.setValue(value / 100.0) - - def _update_slider_from_spinbox(self, value: float) -> None: - """ - Update the slider value when spinbox is changed - """ - self._threshold_value_slider.setValue(int(value * 100)) - - def _enable_specific_threshold_widgets(self, enabled: bool) -> None: - """ - enable or disable specific value thresholding widgets - """ - self._threshold_value_slider.setEnabled(enabled) - self._threshold_value_spinbox.setEnabled(enabled) - - def _enable_apply_button(self) -> None: - """ - enable or disable apply button - """ - self._apply_save_button.setEnabled( - self._none_radio_button.isChecked() - or self._specific_value_radio_button.isChecked() - or self._autothreshold_radio_button.isChecked() - ) - - def doWork(self) -> None: - return - - def focus_changed(self) -> None: - return - - def getTypeOfWork(self) -> str: - return "" - - def showResults(self) -> None: - return diff --git a/src/allencell_ml_segmenter/prediction/view.py b/src/allencell_ml_segmenter/prediction/view.py index a99b5cab..f7307216 100644 --- a/src/allencell_ml_segmenter/prediction/view.py +++ b/src/allencell_ml_segmenter/prediction/view.py @@ -170,33 +170,35 @@ def showResults(self) -> None: stem_to_data: dict[str, dict[str, Path]] = { raw_img.stem: {"raw": raw_img} for raw_img in raw_imgs } - for seg in segmentations: - # ignore files in the folder that aren't from most recent predictions - if seg.stem in stem_to_data: - stem_to_data[seg.stem]["seg"] = seg - - self._viewer.clear_layers() - for data in stem_to_data.values(): - raw_np_data: Optional[np.ndarray] = ( - self._img_data_extractor.extract_image_data( - data["raw"], channel=channel - ).np_data - ) - seg_np_data: Optional[np.ndarray] = ( - self._img_data_extractor.extract_image_data( - data["seg"], seg=1 - ).np_data - ) - if raw_np_data is not None: - self._viewer.add_image( - raw_np_data, - f"[raw] {data['raw'].name}", + if segmentations: + for seg in segmentations: + # ignore files in the folder that aren't from most recent predictions + if seg.stem in stem_to_data: + stem_to_data[seg.stem]["seg"] = seg + + self._viewer.clear_layers() + for data in stem_to_data.values(): + raw_np_data: Optional[np.ndarray] = ( + self._img_data_extractor.extract_image_data( + data["raw"], channel=channel + ).np_data ) - if seg_np_data is not None: - self._viewer.add_labels( - seg_np_data, - name=f"[seg] {data['seg'].name}", + seg_np_data: Optional[np.ndarray] = ( + self._img_data_extractor.extract_image_data( + data["seg"], seg=1 + ).np_data ) + if raw_np_data is not None: + self._viewer.add_image( + raw_np_data, + f"[raw] {data['raw'].name}", + ) + if seg_np_data is not None: + self._viewer.add_labels( + seg_np_data, + name=f"[seg] {data['seg'].name}", + ) + self._main_model.set_predictions_in_viewer(True) # Display popup with saved images path if prediction inputs are from a directory else: dialog_box = DialogBox( diff --git a/src/allencell_ml_segmenter/services/prediction_service.py b/src/allencell_ml_segmenter/services/prediction_service.py index 83160578..bc6a5481 100644 --- a/src/allencell_ml_segmenter/services/prediction_service.py +++ b/src/allencell_ml_segmenter/services/prediction_service.py @@ -139,7 +139,9 @@ def build_overrides(self, checkpoint: Path) -> dict[str, Any]: # Need these overrides to load in csv's overrides["data.columns"] = ["raw", "split"] overrides["data.split_column"] = "split" - overrides["ckpt_path"] = str(checkpoint) + overrides["checkpoint.ckpt_path"] = str(checkpoint) + overrides["checkpoint.strict"] = False + overrides["checkpoint.weights_only"] = True input_path: Optional[Path] = ( self._file_input_model.get_input_image_path() diff --git a/src/allencell_ml_segmenter/styles/thresholding_view.qss b/src/allencell_ml_segmenter/styles/thresholding_view.qss index 2cadba41..845be4ab 100644 --- a/src/allencell_ml_segmenter/styles/thresholding_view.qss +++ b/src/allencell_ml_segmenter/styles/thresholding_view.qss @@ -37,4 +37,14 @@ Examples: https://doc.qt.io/qt-5/stylesheet-examples.html#customizing-specific-w #fileInput #onScreen, #imageList, #radioDirectory { margin-left: 40px; +} + +/* Styling for disabled sliders */ +QSlider:disabled { + color: #868E93; +} + +/* Styling for disabled spinboxes */ +QSpinBox:disabled { + color: #868E93; } \ No newline at end of file diff --git a/src/allencell_ml_segmenter/thresholding/__init__.py b/src/allencell_ml_segmenter/thresholding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/allencell_ml_segmenter/thresholding/thresholding_model.py b/src/allencell_ml_segmenter/thresholding/thresholding_model.py new file mode 100644 index 00000000..93983767 --- /dev/null +++ b/src/allencell_ml_segmenter/thresholding/thresholding_model.py @@ -0,0 +1,112 @@ +from typing import Optional +from collections import OrderedDict + +import numpy as np +from napari.layers import Layer # type: ignore + +from allencell_ml_segmenter.core.event import Event +from allencell_ml_segmenter.core.publisher import Publisher + +# Some thresholding constants # +AVAILABLE_AUTOTHRESHOLD_METHODS: list[str] = ["threshold_otsu"] +THRESHOLD_DEFAULT = 120 +THRESHOLD_RANGE = (0, 255) + + +class ThresholdingModel(Publisher): + """ + Stores state relevant to thresholding processes. + """ + + def __init__(self) -> None: + super().__init__() + + # cyto-dl segmentations should have values between 0 and 255 + self._is_threshold_enabled: bool = False + self._thresholding_value_selected: int = THRESHOLD_DEFAULT + self._is_autothresholding_enabled: bool = False + self._autothresholding_method: str = AVAILABLE_AUTOTHRESHOLD_METHODS[0] + self._original_layers_in_viewer: Optional[ + OrderedDict[str, np.ndarray] + ] = None # Orderedict of layers when thresholding starts, in original order. + + def set_thresholding_value(self, value: int) -> None: + """ + Set the thresholding value. + """ + self._thresholding_value_selected = value + self.dispatch(Event.ACTION_THRESHOLDING_VALUE_CHANGED) + + def get_thresholding_value(self) -> int: + """ + Get the thresholding value. + """ + return self._thresholding_value_selected + + def set_autothresholding_enabled(self, enable: bool) -> None: + """ + Set autothresholding enabled. + """ + self._is_autothresholding_enabled = enable + if enable: + self.dispatch(Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED) + + def is_autothresholding_enabled(self) -> bool: + """ + Get autothresholding enabled. + """ + return self._is_autothresholding_enabled + + def set_autothresholding_method(self, method: str) -> None: + """ + Set autothresholding method. + """ + self._autothresholding_method = method + self.dispatch(Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED) + + def get_autothresholding_method(self) -> str: + """ + Get autothresholding method. + """ + return self._autothresholding_method + + def set_threshold_enabled(self, enabled: bool) -> None: + """ + Set threshold specific value. + """ + self._is_threshold_enabled = enabled + + def is_threshold_enabled(self) -> bool: + """ + Get threshold specific value. + """ + return self._is_threshold_enabled + + def dispatch_save_thresholded_images(self) -> None: + self.dispatch(Event.ACTION_SAVE_THRESHOLDING_IMAGES) + + def set_original_layers(self, layer_list: list[Layer]) -> None: + ordered_layers: OrderedDict[str, np.ndarray] = OrderedDict() + for layer in layer_list: + ordered_layers[layer.name] = layer.data + self._original_layers_in_viewer = ordered_layers + + def get_original_layers(self) -> Optional[OrderedDict[str, np.ndarray]]: + return self._original_layers_in_viewer + + def get_layers_to_threshold( + self, only_seg_layers: bool + ) -> OrderedDict[str, np.ndarray]: + if self._original_layers_in_viewer is None: + raise ValueError( + "Check original layers in model for None before calling get_layers_to_threshold" + ) + + if only_seg_layers: + return OrderedDict( + (key, value) + for key, value in self._original_layers_in_viewer.items() + if key.startswith("[seg]") + ) + + return self._original_layers_in_viewer diff --git a/src/allencell_ml_segmenter/thresholding/thresholding_service.py b/src/allencell_ml_segmenter/thresholding/thresholding_service.py new file mode 100644 index 00000000..39db09e5 --- /dev/null +++ b/src/allencell_ml_segmenter/thresholding/thresholding_service.py @@ -0,0 +1,153 @@ +from collections import OrderedDict +from pathlib import Path +from typing import Callable, Optional +from bioio import BioImage +from bioio.writers import OmeTiffWriter +from napari.layers import Layer # type: ignore +import numpy as np +from napari.utils.notifications import show_info # type: ignore + +from allencell_ml_segmenter.core.event import Event +from allencell_ml_segmenter.core.file_input_model import FileInputModel +from allencell_ml_segmenter.core.subscriber import Subscriber +from allencell_ml_segmenter.main.experiments_model import ExperimentsModel +from allencell_ml_segmenter.main.main_model import MainModel +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, +) +from allencell_ml_segmenter.core.task_executor import ( + NapariThreadTaskExecutor, + ITaskExecutor, +) +from allencell_ml_segmenter.main.viewer import IViewer +from cyto_dl.models.im2im.utils.postprocessing.auto_thresh import AutoThreshold # type: ignore + + +class ThresholdingService(Subscriber): + def __init__( + self, + thresholding_model: ThresholdingModel, + experiments_model: ExperimentsModel, + file_input_model: FileInputModel, + main_model: MainModel, + viewer: IViewer, + task_executor: ITaskExecutor = NapariThreadTaskExecutor.global_instance(), + ): + super().__init__() + # Models + self._thresholding_model: ThresholdingModel = thresholding_model + self._experiments_model: ExperimentsModel = experiments_model + self._file_input_model: FileInputModel = file_input_model + self._main_model: MainModel = main_model + + # napari viewer + self._viewer: IViewer = viewer + + # Task Executor + self._task_executor: ITaskExecutor = task_executor + + self._thresholding_model.subscribe( + Event.ACTION_THRESHOLDING_VALUE_CHANGED, + self, + self._on_threshold_changed, + ) + + self._thresholding_model.subscribe( + Event.ACTION_THRESHOLDING_AUTOTHRESHOLDING_SELECTED, + self, + self._on_threshold_changed, + ) + + self._thresholding_model.subscribe( + Event.ACTION_SAVE_THRESHOLDING_IMAGES, + self, + self._save_thresholded_images, + ) + + def _handle_thresholding_error(self, error: Exception) -> None: + show_info("Thresholding failed: " + str(error)) + + def _on_threshold_changed(self, _: Event) -> None: + # if we havent thresholded yet, keep track of original layers. + # need to check this on first threshold change, since user can add images + # between finishing prediction and starting thresholding + # if they are using images from a directory. + original_layers: Optional[OrderedDict[str, np.ndarray]] = ( + self._thresholding_model.get_original_layers() + ) + if original_layers is None: + self._thresholding_model.set_original_layers( + self._viewer.get_layers() + ) + + # Get layers to threshold. + # if there are segmentations displayed in the viewer, only threshold those images. + layers_to_threshold: OrderedDict[str, np.ndarray] = ( + self._thresholding_model.get_layers_to_threshold( + self._main_model.are_predictions_in_viewer() + ) + ) + + # determine thresholding function to use + if self._thresholding_model.is_autothresholding_enabled(): + thresh_function: Callable = AutoThreshold( + self._thresholding_model.get_autothresholding_method() + ) + else: + thresh_function = self._threshold_image + for layer_name, image in layers_to_threshold.items(): + # Creating helper functions for mypy strict typing + def thresholding_task() -> np.ndarray: + return thresh_function(image) + + def on_return( + thresholded_image: np.ndarray, + layer_name_instance: str = layer_name, + ) -> None: + self._viewer.insert_threshold( + layer_name_instance, + thresholded_image, + self._main_model.are_predictions_in_viewer(), + ) + + self._task_executor.exec( + task=thresholding_task, + # lambda functions capture variables by reference so need to pass layer as a default argument + on_return=on_return, + on_error=self._handle_thresholding_error, + ) + + def _save_thresholded_images(self, _: Event) -> None: + images_to_threshold: list[Path] = ( + self._file_input_model.get_input_files_as_list() + ) + if self._thresholding_model.is_autothresholding_enabled(): + thresh_function: Callable = AutoThreshold( + self._thresholding_model.get_autothresholding_method() + ) + else: + thresh_function = self._threshold_image + for path in images_to_threshold: + image = BioImage(path) + try: + self._save_thresh_image(thresh_function(image.data), path.name) + except Exception as e: + self._handle_thresholding_error(e) + + def _save_thresh_image( + self, image: np.ndarray, original_image_name: str + ) -> None: + output_directory: Optional[Path] = ( + self._file_input_model.get_output_directory() + ) + if output_directory is not None: + new_image_path: Path = ( + output_directory / f"threshold_{original_image_name}" + ) + OmeTiffWriter.save(image, str(new_image_path)) + + def _threshold_image(self, image: np.ndarray) -> np.ndarray: + threshold_value: float = ( + self._thresholding_model.get_thresholding_value() + ) + return (image > threshold_value).astype(int) diff --git a/src/allencell_ml_segmenter/thresholding/thresholding_view.py b/src/allencell_ml_segmenter/thresholding/thresholding_view.py new file mode 100644 index 00000000..286f543d --- /dev/null +++ b/src/allencell_ml_segmenter/thresholding/thresholding_view.py @@ -0,0 +1,344 @@ +from pathlib import Path +from typing import Optional + +from napari.utils.notifications import show_info # type: ignore + +from allencell_ml_segmenter.core.dialog_box import DialogBox +from allencell_ml_segmenter.main.i_experiments_model import IExperimentsModel +from allencell_ml_segmenter.main.i_viewer import IViewer +from allencell_ml_segmenter._style import Style +from allencell_ml_segmenter.core.view import View, MainWindow +from allencell_ml_segmenter.main.main_model import MainModel +from allencell_ml_segmenter.prediction.prediction_folder_progress_tracker import ( + PredictionFolderProgressTracker, +) +from allencell_ml_segmenter.prediction.service import ModelFileService +from allencell_ml_segmenter.thresholding.thresholding_model import ( + ThresholdingModel, + AVAILABLE_AUTOTHRESHOLD_METHODS, + THRESHOLD_RANGE, +) +from allencell_ml_segmenter.utils.file_utils import FileUtils + +from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint +from allencell_ml_segmenter.core.file_input_widget import ( + FileInputWidget, +) +from allencell_ml_segmenter.core.file_input_model import ( + FileInputModel, + InputMode, +) + +from qtpy.QtWidgets import ( + QLabel, + QVBoxLayout, + QHBoxLayout, + QSizePolicy, + QComboBox, + QGroupBox, + QRadioButton, + QPushButton, + QSlider, + QSpinBox, +) +from qtpy.QtCore import Qt + + +class ThresholdingView(View, MainWindow): + """ + View for thresholding + """ + + def __init__( + self, + main_model: MainModel, + thresholding_model: ThresholdingModel, + file_input_model: FileInputModel, + experiments_model: IExperimentsModel, + viewer: IViewer, + ): + super().__init__() + + self._main_model: MainModel = main_model + self._experiments_model: IExperimentsModel = experiments_model + self._viewer: IViewer = viewer + self._thresholding_model: ThresholdingModel = thresholding_model + + # To manage input files: + self._file_input_model: FileInputModel = file_input_model + self._input_files_service: ModelFileService = ModelFileService( + self._file_input_model + ) + + layout: QVBoxLayout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(20) + self.setLayout(layout) + layout.setAlignment(Qt.AlignmentFlag.AlignTop) + self.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum + ) + + # title + self._title: QLabel = QLabel("THRESHOLD", self) + self._title.setObjectName("title") + layout.addWidget(self._title, alignment=Qt.AlignmentFlag.AlignHCenter) + + # selecting input image + self._file_input_widget: FileInputWidget = FileInputWidget( + self._file_input_model, + self._viewer, + self._input_files_service, + include_channel_selection=False, + ) + self._file_input_widget.setObjectName("fileInput") + layout.addWidget(self._file_input_widget) + + # thresholding values + self._threshold_label: LabelWithHint = LabelWithHint("Threshold") + self._threshold_label.set_hint("Values to threshold with.") + self._threshold_label.setObjectName("title") + layout.addWidget(self._threshold_label) + + threshold_group_box = QGroupBox() + threshold_group_layout = QVBoxLayout() + + # none thresholding selection + none_radio_layout: QHBoxLayout = QHBoxLayout() + self._none_radio_button: QRadioButton = QRadioButton() + none_radio_layout.addWidget(self._none_radio_button) + + none_radio_label: LabelWithHint = LabelWithHint("None") + none_radio_label.set_hint("No thresholding applied.") + none_radio_layout.addWidget(none_radio_label) + threshold_group_layout.addLayout(none_radio_layout) + + # specific value thresholding selection + specific_value_layout = QHBoxLayout() + + self._specific_value_radio_button: QRadioButton = QRadioButton() + specific_value_layout.addWidget(self._specific_value_radio_button) + specific_radio_label: LabelWithHint = LabelWithHint("Specific Value") + specific_radio_label.set_hint( + "Set thresholding value you'd like to apply." + ) + specific_value_layout.addWidget(specific_radio_label) + + self._threshold_value_slider: QSlider = QSlider( + Qt.Orientation.Horizontal + ) + + self._threshold_value_slider.setRange( + THRESHOLD_RANGE[0], THRESHOLD_RANGE[1] + ) # slider values from 0 to 100 (representing 0.0 to 1.0) + + self._threshold_value_spinbox: QSpinBox = QSpinBox() + self._threshold_value_spinbox.setRange( + THRESHOLD_RANGE[0], THRESHOLD_RANGE[1] + ) + self._threshold_value_spinbox.setSingleStep(1) + + # set default value + self._threshold_value_slider.setValue( + self._thresholding_model.get_thresholding_value() + ) + self._threshold_value_spinbox.setValue( + self._thresholding_model.get_thresholding_value() + ) + + self._threshold_value_slider.setEnabled(False) + self._threshold_value_spinbox.setEnabled(False) + self._specific_value_radio_button.setChecked(False) + + # add slider and spinbox + specific_value_layout.addWidget(self._threshold_value_slider) + specific_value_layout.addWidget(self._threshold_value_spinbox) + threshold_group_layout.addLayout(specific_value_layout) + + # autothresholding + autothreshold_layout = QHBoxLayout() + self._autothreshold_radio_button: QRadioButton = QRadioButton() + auto_thresh_label: LabelWithHint = LabelWithHint("Autothreshold") + auto_thresh_label.set_hint("Apply an autothresholding method.") + + self._autothreshold_method_combo: QComboBox = QComboBox() + self._autothreshold_method_combo.addItems( + AVAILABLE_AUTOTHRESHOLD_METHODS + ) + self._autothreshold_method_combo.setEnabled(False) + + autothreshold_layout.addWidget(self._autothreshold_radio_button) + autothreshold_layout.addWidget(auto_thresh_label) + autothreshold_layout.addWidget(self._autothreshold_method_combo) + threshold_group_layout.addLayout(autothreshold_layout) + + threshold_group_box.setLayout(threshold_group_layout) + layout.addWidget(threshold_group_box) + + # apply and save + self._apply_save_button: QPushButton = QPushButton("Apply & Save") + self._apply_save_button.setEnabled(False) + self._apply_save_button.clicked.connect(self._save_thresholded_images) + layout.addWidget(self._apply_save_button) + + # need styling + self.setStyleSheet(Style.get_stylesheet("thresholding_view.qss")) + + # configure widget behavior + self._configure_slots() + + def _configure_slots(self) -> None: + """ + Connects behavior for widgets + """ + + # sync slider and spinbox + self._threshold_value_slider.valueChanged.connect( + self._update_spinbox_from_slider + ) + self._threshold_value_spinbox.valueChanged.connect( + self._update_slider_from_spinbox + ) + + # update state and ui based on radio button selections + self._none_radio_button.toggled.connect(self._update_state_from_radios) + self._specific_value_radio_button.toggled.connect( + self._update_state_from_radios + ) + self._autothreshold_radio_button.toggled.connect( + self._update_state_from_radios + ) + + # update autothresholding method when one is selected, and update viewer if able + self._autothreshold_method_combo.currentIndexChanged.connect( + lambda: self._thresholding_model.set_autothresholding_method( + self._autothreshold_method_combo.currentText() + ) + ) + + # update thresholding value when the user is finished making a selection + self._threshold_value_slider.sliderReleased.connect( + lambda: self._thresholding_model.set_thresholding_value( + self._threshold_value_slider.value() + ) + ) + + self._threshold_value_spinbox.editingFinished.connect( + lambda: self._thresholding_model.set_thresholding_value( + self._threshold_value_spinbox.value() + ) + ) + + def _update_spinbox_from_slider(self, value: int) -> None: + """ + Update the spinbox value when slider is changed + """ + self._threshold_value_spinbox.setValue(value) + + def _update_slider_from_spinbox(self, value: int) -> None: + """ + Update the slider value when spinbox is changed + """ + self._threshold_value_slider.setValue(value) + + def _enable_specific_threshold_widgets(self, enabled: bool) -> None: + """ + enable or disable specific value thresholding widgets + """ + self._threshold_value_slider.setEnabled(enabled) + self._threshold_value_spinbox.setEnabled(enabled) + + def _update_state_from_radios(self) -> None: + """ + update state based on thresholding radio button selection + """ + self._thresholding_model.set_autothresholding_enabled( + self._autothreshold_radio_button.isChecked() + ) + self._autothreshold_method_combo.setEnabled( + self._autothreshold_radio_button.isChecked() + ) + + self._thresholding_model.set_threshold_enabled( + self._specific_value_radio_button.isChecked() + ) + self._enable_specific_threshold_widgets( + self._specific_value_radio_button.isChecked() + ) + + self._apply_save_button.setEnabled( + self._specific_value_radio_button.isChecked() + or self._autothreshold_radio_button.isChecked() + ) + + def _check_able_to_threshold(self) -> bool: + able_to_threshold: bool = True + # Check to see if output directory is selected + if self._file_input_model.get_output_directory() is None: + show_info("Please select an output directory first.") + able_to_threshold = False + + # Check to see if input images / directory of images are selected + if self._file_input_model.get_input_mode() is None: + show_info("Please select an input mode first.") + able_to_threshold = False + else: + if ( + self._file_input_model.get_input_mode() + == InputMode.FROM_NAPARI_LAYERS + and self._file_input_model.get_selected_paths() is None + ): + show_info("Please select on screen images to threshold.") + able_to_threshold = False + elif ( + self._file_input_model.get_input_mode() == InputMode.FROM_PATH + and self._file_input_model.get_input_image_path() is None + ): + show_info("Please select a directory to threshold.") + able_to_threshold = False + + # check to see if thresholding method is selected + if ( + not self._thresholding_model.is_threshold_enabled() + and not self._thresholding_model.is_autothresholding_enabled() + ): + show_info("Please select a thresholding method first.") + able_to_threshold = False + + return able_to_threshold + + def _save_thresholded_images(self) -> None: + output_dir: Optional[Path] = ( + self._file_input_model.get_output_directory() + ) + if output_dir is not None and self._check_able_to_threshold(): + # progress tracker is tracking number of images saved to the thresholding folder + progress_tracker: PredictionFolderProgressTracker = ( + PredictionFolderProgressTracker( + output_dir, + len(self._file_input_model.get_input_files_as_list()), + ) + ) + + self.startLongTaskWithProgressBar(progress_tracker) + + def doWork(self) -> None: + self._thresholding_model.dispatch_save_thresholded_images() + + def focus_changed(self) -> None: + return + + def getTypeOfWork(self) -> str: + return "" + + def showResults(self) -> None: + dialog_box = DialogBox( + f"Predicted images saved to {str(self._file_input_model.get_output_directory())}. \nWould you like to open this folder?" + ) + dialog_box.exec() + output_dir: Optional[Path] = ( + self._file_input_model.get_output_directory() + ) + + if output_dir and dialog_box.get_selection(): + FileUtils.open_directory_in_window(output_dir) diff --git a/src/allencell_ml_segmenter/training/image_selection_widget.py b/src/allencell_ml_segmenter/training/image_selection_widget.py index 62b560ab..21fb17f3 100644 --- a/src/allencell_ml_segmenter/training/image_selection_widget.py +++ b/src/allencell_ml_segmenter/training/image_selection_widget.py @@ -134,8 +134,9 @@ def __init__( self._model.signals.num_channels_set.connect(self._update_channels) def set_inputs_csv(self, event: Optional[Event] = None) -> None: - if self._experiments_model.get_csv_path() is not None: - csv_path = self._experiments_model.get_csv_path() / "train.csv" + csv_path: Optional[Path] = self._experiments_model.get_csv_path() + if csv_path is not None: + csv_path = csv_path / "train.csv" if csv_path.is_file(): # if the csv exists self._images_directory_input_button._text_display.setText( diff --git a/src/allencell_ml_segmenter/training/view.py b/src/allencell_ml_segmenter/training/view.py index 2ec443a6..ad9561c2 100644 --- a/src/allencell_ml_segmenter/training/view.py +++ b/src/allencell_ml_segmenter/training/view.py @@ -270,10 +270,18 @@ def train_btn_handler(self) -> None: if self._patch_size_ok(): self.set_patch_size() num_epochs: Optional[int] = self._training_model.get_num_epochs() + metrics_csv_path: Optional[Path] = ( + self._experiments_model.get_metrics_csv_path() + ) + cache_dir: Optional[Path] = self._experiments_model.get_cache_dir() + if metrics_csv_path is None or cache_dir is None: + raise ValueError( + "Metrics CSV path and cache_dir cannot be None after training has started!" + ) progress_tracker: TrainingProgressTracker = ( TrainingProgressTracker( - self._experiments_model.get_metrics_csv_path(), - self._experiments_model.get_cache_dir(), + metrics_csv_path, + cache_dir, num_epochs if num_epochs is not None else 0, self._training_model.get_total_num_images(), self._experiments_model.get_latest_metrics_csv_version() diff --git a/src/allencell_ml_segmenter/widgets/check_box_list_widget.py b/src/allencell_ml_segmenter/widgets/check_box_list_widget.py index c4daea8b..eea9fc7a 100644 --- a/src/allencell_ml_segmenter/widgets/check_box_list_widget.py +++ b/src/allencell_ml_segmenter/widgets/check_box_list_widget.py @@ -26,7 +26,9 @@ def _send_checked_signal(self, item: QListWidgetItem) -> None: idx: int = self.row(item) self.checkedSignal.emit(idx, item.checkState()) - def add_item(self, item: Union[str, QListWidgetItem]) -> None: + def add_item( + self, item: Union[str, QListWidgetItem], set_checked: bool = False + ) -> None: """ Adds an item to the list. """ @@ -39,10 +41,13 @@ def add_item(self, item: Union[str, QListWidgetItem]) -> None: raise TypeError( f"Item added to CheckBoxListWidget must be a string or QListWidgetItem, but got {type(item)} instead" ) - - # set checkable and unchecked by default + # set checkable item_add.setFlags(item_add.flags() | Qt.ItemFlag.ItemIsUserCheckable) - item_add.setCheckState(Qt.CheckState.Unchecked) + if set_checked: + item_add.setCheckState(Qt.CheckState.Checked) + else: + item_add.setCheckState(Qt.CheckState.Unchecked) + super().addItem(item_add) def set_all_state(self, state: Qt.CheckState) -> None: