Skip to content

Commit

Permalink
type, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
yrkim98 committed Sep 27, 2024
1 parent b92e0c5 commit be9156e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
1 change: 0 additions & 1 deletion src/allencell_ml_segmenter/main/main_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,3 @@ def set_selected_channels(

def training_complete(self) -> None:
self.dispatch(Event.PROCESS_TRAINING_COMPLETE)

4 changes: 1 addition & 3 deletions src/allencell_ml_segmenter/main/main_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def __init__(

# events for auto window switching
self._model.subscribe(
Event.PROCESS_TRAINING_COMPLETE,
self,
self._handle_existing_model
Event.PROCESS_TRAINING_COMPLETE, self, self._handle_existing_model
)

def _handle_experiment_applied(self, _: Event) -> None:
Expand Down
16 changes: 12 additions & 4 deletions src/allencell_ml_segmenter/training/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,26 @@ def getTypeOfWork(self) -> str:

def showResults(self) -> None:
# double check to see if a ckpt was generated
exp_path: Optional[Path] = self._experiments_model.get_user_experiments_path()
if exp_path is None:
raise ValueError("Experiments path should not be None after training complete.")
exp_name: Optional[str] = self._experiments_model.get_experiment_name()
if exp_name is None:
raise ValueError("Experiment name should not be None after training complete.")
ckpt_generated: Optional[Path] = ExperimentUtils.get_best_ckpt(
self._experiments_model.get_user_experiments_path(),
self._experiments_model.get_experiment_name())
exp_path,
exp_name,
)
if ckpt_generated is not None:
# if model was successfully trained, get metrics to display
csv_path: Optional[Path] = (
self._experiments_model.get_latest_metrics_csv_path()
)
if csv_path is None:
raise RuntimeError("Cannot get min loss from undefined csv")
min_loss: Optional[float] = FileUtils.get_min_loss_from_csv(csv_path)
min_loss: Optional[float] = FileUtils.get_min_loss_from_csv(
csv_path
)
if min_loss is None:
raise RuntimeError("Cannot compute min loss")

Expand All @@ -323,7 +332,6 @@ def showResults(self) -> None:
)
dialog_box.exec()


def _num_epochs_field_handler(self, num_epochs: str) -> None:
self._training_model.set_num_epochs(int(num_epochs))

Expand Down

0 comments on commit be9156e

Please sign in to comment.