Skip to content

Commit

Permalink
Merge pull request #547 from AllenCell/pr-fix/separate_file_input_model
Browse files Browse the repository at this point in the history
Separate file input model from prediction model
  • Loading branch information
yrkim98 authored Nov 7, 2024
2 parents 5c97ae4 + 40a989c commit e8837e1
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 119 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import patch, Mock
from pathlib import Path

import pytest
from qtpy.QtWidgets import QFileDialog
Expand All @@ -11,33 +12,37 @@
from allencell_ml_segmenter.prediction.model import (
PredictionModel,
)
from allencell_ml_segmenter.core.file_input_model import InputMode

from allencell_ml_segmenter.core.file_input_model import (
InputMode,
FileInputModel,
)

MOCK_PATH: str = "/path/to/file"


@pytest.fixture
def prediction_model(qtbot: QtBot) -> PredictionModel:
return PredictionModel()
def file_input_model(qtbot: QtBot) -> FileInputModel:
return FileInputModel()


@pytest.fixture
def file_input_widget(
qtbot: QtBot, prediction_model: PredictionModel
qtbot: QtBot, file_input_model: FileInputModel
) -> FileInputWidget:
"""
Fixture that creates an instance of ModelInputWidget for testing.
"""
return FileInputWidget(prediction_model, viewer=FakeViewer(), service=None)
return FileInputWidget(file_input_model, viewer=FakeViewer(), service=None)


def test_top_radio_button_slot(
qtbot: QtBot,
file_input_widget: FileInputWidget,
prediction_model: PredictionModel,
file_input_model: FileInputModel,
) -> None:
"""
Test the _top_radio_button_slot method of PredictionFileInput.
Test the _top_radio_button_slot method of FileInputWidget.
"""
# ARRANGE - explicitly disable file_input_widget._image_list and enable file_input_widget._browse_dir_edit
file_input_widget._image_list.setEnabled(False)
Expand All @@ -50,16 +55,16 @@ def test_top_radio_button_slot(
# ASSERT - states should have flipped
assert file_input_widget._image_list.isEnabled()
assert not file_input_widget._browse_dir_edit.isEnabled()
assert prediction_model.get_input_mode() == InputMode.FROM_NAPARI_LAYERS
assert file_input_model.get_input_mode() == InputMode.FROM_NAPARI_LAYERS


def test_bottom_radio_button_slot(
qtbot: QtBot,
file_input_widget: FileInputWidget,
prediction_model: PredictionModel,
file_input_model: FileInputModel,
) -> None:
"""
Test the _bottom_radio_button_slot method of PredictionFileInput.
Test the _bottom_radio_button_slot method of FileInputWidget.
"""
# ARRANGE - explicitly enable file_input_widget._image_list and disable file_input_widget._browse_dir_edit
file_input_widget._image_list.setEnabled(True)
Expand All @@ -72,7 +77,7 @@ def test_bottom_radio_button_slot(
# ASSERT - states should have flipped
assert not file_input_widget._image_list.isEnabled()
assert file_input_widget._browse_dir_edit.isEnabled()
assert prediction_model.get_input_mode() == InputMode.FROM_PATH
assert file_input_model.get_input_mode() == InputMode.FROM_PATH


# decorator used to stub QFileDialog and avoid nested context managers
Expand All @@ -84,14 +89,56 @@ def test_bottom_radio_button_slot(
)
def test_populate_input_channel_combobox(qtbot: QtBot) -> None:
# Arrange
prediction_model: PredictionModel = PredictionModel()
file_input_model: FileInputModel = FileInputModel()
prediction_file_input: FileInputWidget = FileInputWidget(
prediction_model, viewer=FakeViewer(), service=None
file_input_model, viewer=FakeViewer(), service=None
)
prediction_model.set_max_channels(6)
file_input_model.set_max_channels(6)

# Act
prediction_file_input._populate_input_channel_combobox()

# Assert
assert prediction_file_input._channel_select_dropdown.isEnabled()


def test_input_image_paths(file_input_model: FileInputModel) -> None:
"""
Tests that the input image paths are set and retrieved properly.
"""
# ARRANGE
dummy_paths: List[Path] = [
Path("example path " + str(i)) for i in range(10)
]

# ACT
file_input_model.set_input_image_path(dummy_paths)

# ASSERT
assert file_input_model.get_input_image_path() == dummy_paths


def test_image_input_channel_index(file_input_model: FileInputModel) -> None:
"""
Tests that the channel index is set and retrieved properly.
"""
for i in range(10):
# ACT
file_input_model.set_image_input_channel_index(i)

# ASSERT
assert file_input_model.get_image_input_channel_index() == i


def test_output_directory(file_input_model: FileInputModel) -> None:
"""
Tests that the output directory is set and retrieved properly.
"""
# ARRANGE
dummy_path: Path = Path("example path")

# ACT
file_input_model.set_output_directory(dummy_path)

# ASSERT
assert file_input_model.get_output_directory() == dummy_path
58 changes: 0 additions & 58 deletions src/allencell_ml_segmenter/_tests/prediction/test_model.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ModelInputWidget,
)
from allencell_ml_segmenter.prediction.model import PredictionModel
from allencell_ml_segmenter.core.file_input_model import FileInputModel


@pytest.fixture
Expand Down
29 changes: 19 additions & 10 deletions src/allencell_ml_segmenter/_tests/prediction/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from allencell_ml_segmenter.core.image_data_extractor import (
FakeImageDataExtractor,
)
from allencell_ml_segmenter.core.file_input_model import InputMode
from allencell_ml_segmenter.core.file_input_model import (
InputMode,
FileInputModel,
)
from allencell_ml_segmenter.prediction.view import PredictionView


Expand All @@ -31,7 +34,9 @@ def prediction_view(main_model: MainModel, qtbot: QtBot) -> PredictionView:
Returns a PredictionView instance for testing.
"""
prediction_model: PredictionModel = PredictionModel()
return PredictionView(main_model, prediction_model, FakeViewer())
return PredictionView(
main_model, prediction_model, FileInputModel(), FakeViewer()
)


def test_prediction_view(
Expand All @@ -52,23 +57,25 @@ def test_show_results(main_model: MainModel) -> None:
Testing the showresults that runs after a prediction run
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()
prediction_model: PredictionModel = PredictionModel()
prediction_model.set_output_directory(
file_input_model.set_output_directory(
Path(allencell_ml_segmenter.__file__).parent
/ "_tests"
/ "test_files"
/ "output_test_folder"
)
prediction_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS)
prediction_model.set_selected_paths(
file_input_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS)
file_input_model.set_selected_paths(
[Path("output_1.tiff"), Path("output_2.tiff")]
)
prediction_model.set_image_input_channel_index(0)
file_input_model.set_image_input_channel_index(0)
fake_viewer: FakeViewer = FakeViewer()

prediction_view: PredictionView = PredictionView(
main_model,
prediction_model,
file_input_model,
fake_viewer,
img_data_extractor=FakeImageDataExtractor.global_instance(),
)
Expand All @@ -90,23 +97,25 @@ def test_show_results_non_empty_folder(main_model: MainModel) -> None:
Testing that only the new images in a folder will be shown after prediction.
"""
# ARRANGE
file_input_model: FileInputModel = FileInputModel()
prediction_model: PredictionModel = PredictionModel()
prediction_model.set_output_directory(
file_input_model.set_output_directory(
Path(allencell_ml_segmenter.__file__).parent
/ "_tests"
/ "test_files"
/ "output_test_folder_extra"
)
prediction_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS)
prediction_model.set_selected_paths(
file_input_model.set_input_mode(InputMode.FROM_NAPARI_LAYERS)
file_input_model.set_selected_paths(
[Path("output_3.tiff"), Path("output_4.tiff")]
)
prediction_model.set_image_input_channel_index(0)
file_input_model.set_image_input_channel_index(0)
fake_viewer: FakeViewer = FakeViewer()

prediction_view: PredictionView = PredictionView(
main_model,
prediction_model,
file_input_model,
fake_viewer,
img_data_extractor=FakeImageDataExtractor.global_instance(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from allencell_ml_segmenter.services.prediction_service import (
PredictionService,
)
from allencell_ml_segmenter.core.file_input_model import FileInputModel


@pytest.fixture
Expand Down Expand Up @@ -101,6 +102,7 @@ def test_predict_model_no_checkpoint_selected() -> None:

def test_build_overrides() -> None:
# Arrange
file_input_model: FileInputModel = FileInputModel()
prediction_model: PredictionModel = PredictionModel()
experiments_model: ExperimentsModel = ExperimentsModel(
FakeUserSettings(
Expand All @@ -112,16 +114,16 @@ def test_build_overrides() -> None:
)
experiments_model.apply_experiment_name("one_ckpt_exp")
prediction_service: PredictionService = PredictionService(
prediction_model, experiments_model
prediction_model, file_input_model, experiments_model
)
prediction_model.set_output_directory(
file_input_model.set_output_directory(
Path(__file__).parent.parent
/ "main"
/ "0_exp"
/ "prediction_output_test"
)
prediction_model.set_input_image_path(Path("fake_img_path"))
prediction_model.set_image_input_channel_index(3)
file_input_model.set_input_image_path(Path("fake_img_path"))
file_input_model.set_image_input_channel_index(3)

# act
overrides: Dict[str, Union[str, int, float, bool]] = (
Expand Down Expand Up @@ -168,7 +170,7 @@ def test_write_csv_for_inputs() -> None:
experiments_model.apply_experiment_name("0_exp")
prediction_model: PredictionModel = PredictionModel()
prediction_service: PredictionService = PredictionService(
prediction_model, experiments_model
prediction_model, FileInputModel(), experiments_model
)
mock_csv_write = MagicMock(spec=csv.writer)

Expand Down
14 changes: 11 additions & 3 deletions src/allencell_ml_segmenter/core/file_input_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ class InputMode(Enum):
FROM_NAPARI_LAYERS = "from_napari_layers"


class FileInputModel(Publisher, ABC):
class FileInputModel(Publisher):
"""
Abstract base class for PredictionModel and PostprocessModel.
Defines the common methods and attributes both models should implement.
Model for FileInputWidget
"""

def __init__(self) -> None:
Expand All @@ -26,6 +25,15 @@ def __init__(self) -> None:
self._selected_paths: Optional[list[Path]] = None
self._max_channels: Optional[int] = None

def get_output_seg_directory(self) -> Optional[Path]:
"""
Gets path to where segmentation predictions are stored.
"""
output_dir: Optional[Path] = self.get_output_directory()
if output_dir is None:
return None
return output_dir / "target"

def set_input_image_path(
self, path: Optional[Path], extract_channels: bool = False
) -> None:
Expand Down
5 changes: 4 additions & 1 deletion src/allencell_ml_segmenter/main/main_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from allencell_ml_segmenter.postprocess.postprocess_view import (
ThresholdingView,
)
from allencell_ml_segmenter.core.file_input_model import FileInputModel


class MainWidget(AicsWidget):
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
self._training_model: TrainingModel = TrainingModel(
main_model=self._model, experiments_model=self._experiments_model
)

self._file_input_model: FileInputModel = FileInputModel()
self._prediction_model: PredictionModel = PredictionModel()
self._curation_model: CurationModel = CurationModel(
self._experiments_model,
Expand All @@ -104,6 +105,7 @@ def __init__(
)
self._prediction_service: PredictionService = PredictionService(
prediction_model=self._prediction_model,
file_input_model=self._file_input_model,
experiments_model=self._experiments_model,
)

Expand Down Expand Up @@ -133,6 +135,7 @@ def __init__(
self._prediction_view: PredictionView = PredictionView(
main_model=self._model,
prediction_model=self._prediction_model,
file_input_model=self._file_input_model,
viewer=self.viewer,
)
self._initialize_window(self._prediction_view, "Prediction")
Expand Down
Loading

0 comments on commit e8837e1

Please sign in to comment.