Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/enforce model trained once #519

Merged
merged 9 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/allencell_ml_segmenter/main/main_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,6 @@ def set_selected_channels(
):
self._selected_channels = new_channels
self.signals.selected_channels_changed.emit()

def training_complete(self) -> None:
self.dispatch(Event.PROCESS_TRAINING_COMPLETE)
15 changes: 15 additions & 0 deletions src/allencell_ml_segmenter/main/main_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def __init__(
layout.addStretch(100)
self.setStyleSheet(Style.get_stylesheet("core.qss"))

# events for auto window switching
self._model.subscribe(
Event.PROCESS_TRAINING_COMPLETE, self, self._disable_non_prediction_tabs
)

def _handle_experiment_applied(self, _: Event) -> None:
"""
Handle the experiment applied event.
Expand All @@ -169,6 +174,16 @@ def _handle_new_model(self, _: Event) -> None:
else self._prediction_view
)

def _disable_non_prediction_tabs(self, _: Event) -> None:
"""
Handle existing model selection (disable tabs).

inputs:
is_new_model - bool
"""
self._window_container.setTabEnabled(0, False)
hughes036 marked this conversation as resolved.
Show resolved Hide resolved
self._window_container.setTabEnabled(1, False)

def _handle_change_view(self, event: Event) -> None:
"""
Handle event function for the main widget, which handles MainEvents.
Expand Down
49 changes: 39 additions & 10 deletions src/allencell_ml_segmenter/training/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from allencell_ml_segmenter.core.info_dialog_box import InfoDialogBox
from allencell_ml_segmenter.utils.file_utils import FileUtils
from allencell_ml_segmenter.utils.experiment_utils import ExperimentUtils


class TrainingView(View, MainWindow):
Expand Down Expand Up @@ -296,18 +297,46 @@ def getTypeOfWork(self) -> str:
return "Training"

def showResults(self) -> None:
hughes036 marked this conversation as resolved.
Show resolved Hide resolved
csv_path: Optional[Path] = (
self._experiments_model.get_latest_metrics_csv_path()
# double check to see if a ckpt was generated
exp_path: Optional[Path] = (
self._experiments_model.get_user_experiments_path()
)
if csv_path is None:
raise RuntimeError("Cannot get min loss from undefined csv")
min_loss: Optional[float] = FileUtils.get_min_loss_from_csv(csv_path)
if min_loss is None:
raise RuntimeError("Cannot compute min loss")
dialog_box = InfoDialogBox(
"Training finished -- Final loss: {:.3f}".format(min_loss)
if exp_path is None:
raise ValueError(
"Experiments path should not be None after training complete."
)
exp_name: Optional[str] = self._experiments_model.get_experiment_name()
hughes036 marked this conversation as resolved.
Show resolved Hide resolved
if exp_name is None:
raise ValueError(
"Experiment name should not be None after training complete."
)
ckpt_generated: Optional[Path] = ExperimentUtils.get_best_ckpt(
exp_path,
exp_name,
)
dialog_box.exec()
if ckpt_generated is not None:
# if model was successfully trained, get metrics to display
csv_path: Optional[Path] = (
self._experiments_model.get_latest_metrics_csv_path()
)
if csv_path is None:
raise RuntimeError("Cannot get min loss from undefined csv")
min_loss: Optional[float] = FileUtils.get_min_loss_from_csv(
csv_path
)
if min_loss is None:
raise RuntimeError("Cannot compute min loss")

dialog_box = InfoDialogBox(
"Training finished -- Final loss: {:.3f}".format(min_loss)
)
dialog_box.exec()
self._main_model.training_complete()
hughes036 marked this conversation as resolved.
Show resolved Hide resolved
else:
dialog_box = InfoDialogBox(
"Training failed- no model was saved from this run."
)
dialog_box.exec()

def _num_epochs_field_handler(self, num_epochs: str) -> None:
self._training_model.set_num_epochs(int(num_epochs))
Expand Down
Loading