diff --git a/src/allencell_ml_segmenter/main/main_model.py b/src/allencell_ml_segmenter/main/main_model.py index 1e0ee805..590aaf62 100644 --- a/src/allencell_ml_segmenter/main/main_model.py +++ b/src/allencell_ml_segmenter/main/main_model.py @@ -82,3 +82,6 @@ def set_selected_channels( ): self._selected_channels = new_channels self.signals.selected_channels_changed.emit() + + def training_complete(self) -> None: + self.dispatch(Event.PROCESS_TRAINING_COMPLETE) diff --git a/src/allencell_ml_segmenter/main/main_widget.py b/src/allencell_ml_segmenter/main/main_widget.py index b02a9f93..81522664 100644 --- a/src/allencell_ml_segmenter/main/main_widget.py +++ b/src/allencell_ml_segmenter/main/main_widget.py @@ -146,6 +146,13 @@ def __init__( layout.addStretch(100) self.setStyleSheet(Style.get_stylesheet("core.qss")) + # events for auto window switching + self._model.subscribe( + Event.PROCESS_TRAINING_COMPLETE, + self, + self._disable_non_prediction_tabs, + ) + def _handle_experiment_applied(self, _: Event) -> None: """ Handle the experiment applied event. @@ -169,6 +176,16 @@ def _handle_new_model(self, _: Event) -> None: else self._prediction_view ) + def _disable_non_prediction_tabs(self, _: Event) -> None: + """ + Handle existing model selection (disable tabs). + + inputs: + is_new_model - bool + """ + self._window_container.setTabEnabled(0, False) + self._window_container.setTabEnabled(1, False) + def _handle_change_view(self, event: Event) -> None: """ Handle event function for the main widget, which handles MainEvents. diff --git a/src/allencell_ml_segmenter/training/view.py b/src/allencell_ml_segmenter/training/view.py index 1fd65942..2ec443a6 100644 --- a/src/allencell_ml_segmenter/training/view.py +++ b/src/allencell_ml_segmenter/training/view.py @@ -40,6 +40,7 @@ ) from allencell_ml_segmenter.core.info_dialog_box import InfoDialogBox from allencell_ml_segmenter.utils.file_utils import FileUtils +from allencell_ml_segmenter.utils.experiment_utils import ExperimentUtils class TrainingView(View, MainWindow): @@ -296,18 +297,46 @@ def getTypeOfWork(self) -> str: return "Training" def showResults(self) -> None: - csv_path: Optional[Path] = ( - self._experiments_model.get_latest_metrics_csv_path() + # double check to see if a ckpt was generated + exp_path: Optional[Path] = ( + self._experiments_model.get_user_experiments_path() ) - if csv_path is None: - raise RuntimeError("Cannot get min loss from undefined csv") - min_loss: Optional[float] = FileUtils.get_min_loss_from_csv(csv_path) - if min_loss is None: - raise RuntimeError("Cannot compute min loss") - dialog_box = InfoDialogBox( - "Training finished -- Final loss: {:.3f}".format(min_loss) + if exp_path is None: + raise ValueError( + "Experiments path should not be None after training complete." + ) + exp_name: Optional[str] = self._experiments_model.get_experiment_name() + if exp_name is None: + raise ValueError( + "Experiment name should not be None after training complete." + ) + ckpt_generated: Optional[Path] = ExperimentUtils.get_best_ckpt( + exp_path, + exp_name, ) - dialog_box.exec() + if ckpt_generated is not None: + # if model was successfully trained, get metrics to display + csv_path: Optional[Path] = ( + self._experiments_model.get_latest_metrics_csv_path() + ) + if csv_path is None: + raise RuntimeError("Cannot get min loss from undefined csv") + min_loss: Optional[float] = FileUtils.get_min_loss_from_csv( + csv_path + ) + if min_loss is None: + raise RuntimeError("Cannot compute min loss") + + dialog_box = InfoDialogBox( + "Training finished -- Final loss: {:.3f}".format(min_loss) + ) + dialog_box.exec() # this shows the dialog box + self._main_model.training_complete() # this dispatches the event that changes to prediction tab + else: + dialog_box = InfoDialogBox( + "Training failed- no model was saved from this run." + ) + dialog_box.exec() def _num_epochs_field_handler(self, num_epochs: str) -> None: self._training_model.set_num_epochs(int(num_epochs))