From 6dd8b30a021e65017293c23bd7e211e5827c4f01 Mon Sep 17 00:00:00 2001 From: chrishu Date: Tue, 3 Oct 2023 15:37:58 -0400 Subject: [PATCH] experiment name for all views refactor --- src/allencell_ml_segmenter/constants.py | 12 +- .../main/experiments_model.py | 35 ++++ .../main/i_experiments_model.py | 23 ++- .../main/main_widget.py | 18 ++- .../services/training_service.py | 31 +--- .../training/experiment_info_widget.py | 6 +- .../training/model_selection_widget.py | 150 ++---------------- .../training/training_model.py | 30 ---- src/allencell_ml_segmenter/training/view.py | 139 ++++++++++++++-- 9 files changed, 224 insertions(+), 220 deletions(-) diff --git a/src/allencell_ml_segmenter/constants.py b/src/allencell_ml_segmenter/constants.py index 6bca325b..495e1956 100644 --- a/src/allencell_ml_segmenter/constants.py +++ b/src/allencell_ml_segmenter/constants.py @@ -1,7 +1,7 @@ -CYTO_DL_HOME_PATH = "/Users/brian.kim/Desktop/work/cyto-dl" -USER_EXPERIMENTS_PATH = ( - "/Users/brian.kim/Desktop/work/cyto-dl/cyto_dl/logs/train/runs" -) +CYTO_DL_HOME_PATH = "/Users/chrishu/dev/code/test2/cyto-dl" +USER_EXPERIMENTS_PATH = "/Users/chrishu/dev/code/test2/allencell-ml-segmenter/logs/train/runs" -# CYTO_DL_HOME_PATH = "/Users/chrishu/dev/code/test2/cyto-dl" -# USER_EXPERIMENTS_PATH = "/Users/chrishu/dev/code/test2/logs/train/runs" +# CYTO_DL_HOME_PATH = "/Users/brian.kim/Desktop/work/cyto-dl" +# USER_EXPERIMENTS_PATH = ( +# "/Users/brian.kim/Desktop/work/cyto-dl/cyto_dl/logs/train/runs" +# ) diff --git a/src/allencell_ml_segmenter/main/experiments_model.py b/src/allencell_ml_segmenter/main/experiments_model.py index 4256a734..9d1d241c 100644 --- a/src/allencell_ml_segmenter/main/experiments_model.py +++ b/src/allencell_ml_segmenter/main/experiments_model.py @@ -8,10 +8,45 @@ class ExperimentsModel(IExperimentsModel): def __init__(self, config: CytoDlConfig) -> None: + super().__init__() self.config = config + + # options self.experiments = {} self.refresh_experiments() + # state + self._experiment_name: str = None + self._checkpoint: str = None + + def get_experiment_name(self) -> str: + """ + Gets experiment name + """ + return self._experiment_name + + def set_experiment_name(self, name: str) -> None: + """ + Sets experiment name + + name (str): name of cyto-dl experiment + """ + self._experiment_name = name + + def get_checkpoint(self) -> str: + """ + Gets checkpoint + """ + return self._checkpoint + + def set_checkpoint(self, checkpoint: str) -> None: + """ + Sets checkpoint + + checkpoint (str): name of checkpoint to use + """ + self._checkpoint = checkpoint + def refresh_experiments(self) -> None: for experiment in Path( self.config.get_user_experiments_path() diff --git a/src/allencell_ml_segmenter/main/i_experiments_model.py b/src/allencell_ml_segmenter/main/i_experiments_model.py index 532368d7..834c141b 100644 --- a/src/allencell_ml_segmenter/main/i_experiments_model.py +++ b/src/allencell_ml_segmenter/main/i_experiments_model.py @@ -1,13 +1,34 @@ from abc import ABC, abstractmethod from pathlib import Path +from allencell_ml_segmenter.core.publisher import Publisher -class IExperimentsModel(ABC): + +class IExperimentsModel(Publisher): """ Interface for implementing and testing ExperimentsModel """ + def __init__(self): + super().__init__() + + @abstractmethod + def get_experiment_name(self) -> str: + pass + + @abstractmethod + def set_experiment_name(self, name: str) -> None: + pass + + @abstractmethod + def get_checkpoint(self) -> str: + pass + + @abstractmethod + def set_checkpoint(self, checkpoint: str): + pass + @abstractmethod def get_experiments(self): pass diff --git a/src/allencell_ml_segmenter/main/main_widget.py b/src/allencell_ml_segmenter/main/main_widget.py index 8c6bfde2..c5c9ce5b 100644 --- a/src/allencell_ml_segmenter/main/main_widget.py +++ b/src/allencell_ml_segmenter/main/main_widget.py @@ -21,6 +21,7 @@ from allencell_ml_segmenter.main.experiments_model import ExperimentsModel from allencell_ml_segmenter.main.main_model import MainModel from allencell_ml_segmenter.prediction.view import PredictionView +from allencell_ml_segmenter.training.model_selection_widget import ModelSelectionWidget from allencell_ml_segmenter.training.view import TrainingView @@ -43,9 +44,20 @@ def __init__(self, viewer: napari.Viewer, config: CytoDlConfig = None): Event.ACTION_CHANGE_VIEW, self, self.handle_change_view ) + if config is None: + config = CytoDlConfig(CYTO_DL_HOME_PATH, USER_EXPERIMENTS_PATH) + experiment_model = ExperimentsModel(config) + + # Model selection which applies to all views + model_selection_widget: ModelSelectionWidget = ModelSelectionWidget( + experiment_model + ) + model_selection_widget.setObjectName("modelSelection") + self.layout().addWidget(model_selection_widget, Qt.AlignTop) + # keep track of views self._view_container: QTabWidget = QTabWidget() - self.layout().addWidget(self._view_container, Qt.AlignTop) + self.layout().addWidget(self._view_container, Qt.AlignCenter ) self.layout().addStretch(100) self._view_to_index: Dict[View, int] = dict() @@ -54,10 +66,6 @@ def __init__(self, viewer: napari.Viewer, config: CytoDlConfig = None): self._prediction_view: PredictionView = PredictionView(self._model) self._initialize_view(self._prediction_view, "Prediction") - if config is None: - config = CytoDlConfig(CYTO_DL_HOME_PATH, USER_EXPERIMENTS_PATH) - experiment_model = ExperimentsModel(config) - training_view: TrainingView = TrainingView( main_model=self._model, viewer=self.viewer, diff --git a/src/allencell_ml_segmenter/services/training_service.py b/src/allencell_ml_segmenter/services/training_service.py index 001b8fbd..e876a620 100644 --- a/src/allencell_ml_segmenter/services/training_service.py +++ b/src/allencell_ml_segmenter/services/training_service.py @@ -30,31 +30,6 @@ def _list_to_string(list_to_convert: List[Any]) -> str: return f"[{ints_to_strings}]" -# class MyPrintingCallback(Callback): -# def __init__(self): -# super().__init__() - -# def on_train_start(self, trainer, pl_module): -# print( -# "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training is starting@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" -# ) - -# def on_train_epoch_start(self, trainer, pl_module): -# print( -# f"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Epoch {trainer.current_epoch} is starting@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" -# ) - -# def on_train_epoch_end(self, trainer, pl_module): -# print( -# f"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training {trainer.current_epoch} is ending@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" -# ) - -# def on_train_end(self, trainer, pl_module): -# print( -# "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training is ending@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" -# ) - - class TrainingService(Subscriber): """ Interface for training a model. Uses cyto-dl to train model according to spec @@ -101,9 +76,9 @@ def train_model_handler(self, _: Event) -> None: # This is meant to be a string as is - not a string template. In cyto-dl, it will be treated as a string template "hydra.run.dir=${paths.log_dir}/${task_name}/runs/${experiment_name}" ) - if self._training_model.get_checkpoint() is not None: + if self._experiments_model.get_checkpoint() is not None: sys.argv.append( - f"ckpt_path={self._experiments_model.get_model_checkpoints_path(self._training_model.get_experiment_name(), self._training_model.get_checkpoint())}" + f"ckpt_path={self._experiments_model.get_model_checkpoints_path(self._experiments_model.get_experiment_name(), self._experiments_model.get_checkpoint())}" ) # sys.argv.append( # "+callbacks.print_progress._target_=allencell_ml_segmenter.services.training_service.MyPrintingCallback" @@ -149,7 +124,7 @@ def _set_experiment_name(self) -> None: """ Sets the experiment_name argument variable for hydra override using sys.argv """ - experiment_name: str = self._training_model.get_experiment_name() + experiment_name: str = self._experiments_model.get_experiment_name() sys.argv.append(f"++experiment_name={experiment_name}") def _set_max_epoch(self) -> None: diff --git a/src/allencell_ml_segmenter/training/experiment_info_widget.py b/src/allencell_ml_segmenter/training/experiment_info_widget.py index e0b7cc8b..63e08f8a 100644 --- a/src/allencell_ml_segmenter/training/experiment_info_widget.py +++ b/src/allencell_ml_segmenter/training/experiment_info_widget.py @@ -1,4 +1,4 @@ -from allencell_ml_segmenter.training.training_model import TrainingModel +from allencell_ml_segmenter.main.i_experiments_model import IExperimentsModel from qtpy.QtWidgets import ( QWidget, QFrame, @@ -18,10 +18,10 @@ class ExperimentInfoWidget(QWidget): TITLE_TEXT: str = "Experiment information" - def __init__(self, model: TrainingModel): + def __init__(self, model: IExperimentsModel): super().__init__() - self._model: TrainingModel = model + self._model: IExperimentsModel = model # widget skeleton self.setLayout(QVBoxLayout()) diff --git a/src/allencell_ml_segmenter/training/model_selection_widget.py b/src/allencell_ml_segmenter/training/model_selection_widget.py index 18409e9c..47565414 100644 --- a/src/allencell_ml_segmenter/training/model_selection_widget.py +++ b/src/allencell_ml_segmenter/training/model_selection_widget.py @@ -4,13 +4,9 @@ QVBoxLayout, QSizePolicy, QFrame, - QLabel, QGridLayout, QComboBox, - QHBoxLayout, QRadioButton, - QLineEdit, - QCheckBox, ) from allencell_ml_segmenter.core.event import Event from allencell_ml_segmenter.main.i_experiments_model import IExperimentsModel @@ -19,9 +15,7 @@ ) from allencell_ml_segmenter.training.training_model import TrainingModel -from allencell_ml_segmenter.training.training_model import PatchSize from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint -from PyQt5.QtGui import QIntValidator class ModelSelectionWidget(QWidget): @@ -33,12 +27,10 @@ class ModelSelectionWidget(QWidget): def __init__( self, - training_model: TrainingModel, experiments_model: IExperimentsModel, ): super().__init__() - self._training_model: TrainingModel = training_model self._experiments_model: IExperimentsModel = experiments_model # widget skeleton @@ -66,7 +58,7 @@ def __init__( top_grid_layout.addWidget(self._radio_new_model, 0, 0) self.experiment_info_widget = ExperimentInfoWidget( - self._training_model + self._experiments_model ) label_new_model: LabelWithHint = LabelWithHint("Start a new model") top_grid_layout.addWidget(label_new_model, 0, 1) @@ -93,7 +85,7 @@ def __init__( self._combo_box_existing_models.currentTextChanged.connect( self._model_combo_handler ) - self._training_model.subscribe( + self._experiments_model.subscribe( Event.PROCESS_TRAINING, self, self._process_event_handler ) @@ -107,7 +99,7 @@ def __init__( self._combo_box_existing_models_checkpoint.setEnabled(False) self._combo_box_existing_models_checkpoint.setMinimumWidth(306) self._combo_box_existing_models_checkpoint.currentTextChanged.connect( - lambda path_text: self._training_model.set_checkpoint(path_text) + lambda path_text: self._experiments_model.set_checkpoint(path_text) ) self._combo_box_existing_models.setEnabled(False) self._combo_box_existing_models_checkpoint.setEnabled(False) @@ -119,123 +111,15 @@ def __init__( frame.layout().addLayout(top_grid_layout) - # bottom half - bottom_grid_layout = QGridLayout() - - patch_size_label: LabelWithHint = LabelWithHint("Structure size") - bottom_grid_layout.addWidget(patch_size_label, 0, 0) - - self._patch_size_combo_box: QComboBox = QComboBox() - self._patch_size_combo_box.setObjectName("structureSizeComboBox") - self._patch_size_combo_box.setCurrentIndex(-1) - self._patch_size_combo_box.setPlaceholderText("Select an option") - self._patch_size_combo_box.addItems( - [patch.name.lower() for patch in PatchSize] - ) - self._patch_size_combo_box.currentTextChanged.connect( - lambda size: self._training_model.set_patch_size(size) - ) - bottom_grid_layout.addWidget(self._patch_size_combo_box, 0, 1) - - image_dimensions_label: LabelWithHint = LabelWithHint( - "Image dimension" - ) - bottom_grid_layout.addWidget(image_dimensions_label, 1, 0) - - dimension_choice_layout: QHBoxLayout = QHBoxLayout() - dimension_choice_layout.setSpacing(0) - - self._radio_3d: QRadioButton = QRadioButton() - self._radio_3d.setObjectName("3DRadio") - self._radio_3d.toggled.connect( - lambda: self._training_model.set_image_dims(3) - ) - label_3d: LabelWithHint = LabelWithHint("3D") - - self._radio_2d: QRadioButton = QRadioButton() - self._radio_2d.toggled.connect( - lambda: self._training_model.set_image_dims(2) - ) - label_2d: LabelWithHint = LabelWithHint("2D") - - dimension_choice_layout.addWidget(self._radio_3d) - dimension_choice_layout.addWidget(label_3d) - dimension_choice_layout.addWidget( - self._radio_2d, alignment=Qt.AlignLeft - ) - dimension_choice_layout.addWidget(label_2d, alignment=Qt.AlignLeft) - dimension_choice_layout.addStretch(10) - dimension_choice_layout.setContentsMargins(0, 0, 0, 0) - - dimension_choice_dummy: QWidget = ( - QWidget() - ) # stops interference with other radio buttons - dimension_choice_dummy.setLayout(dimension_choice_layout) - - bottom_grid_layout.addWidget(dimension_choice_dummy, 1, 1) - - max_epoch_label: LabelWithHint = LabelWithHint("Training steps") - bottom_grid_layout.addWidget(max_epoch_label, 2, 0) - - self._max_epoch_input: QLineEdit = QLineEdit() - # allow only integers TODO [needs test coverage] - self._max_epoch_input.setValidator(QIntValidator()) - self._max_epoch_input.setPlaceholderText("1000") - self._max_epoch_input.setObjectName("trainingStepInput") - self._max_epoch_input.textChanged.connect( - self._max_epochtext_field_handler - ) - bottom_grid_layout.addWidget(self._max_epoch_input, 2, 1) - - max_time_layout: QHBoxLayout = QHBoxLayout() - max_time_layout.setSpacing(0) - - self._max_time_checkbox: QCheckBox = QCheckBox() - self._max_time_checkbox.setObjectName("timeoutCheckbox") - self._max_time_checkbox.stateChanged.connect( - self._max_time_checkbox_slot - ) - max_time_layout.addWidget(self._max_time_checkbox) - - max_time_left_text: QLabel = QLabel("Time out after") - max_time_layout.addWidget(max_time_left_text) - - self._max_time_in_hours_input: QLineEdit = QLineEdit() - self._max_time_in_hours_input.setObjectName("timeoutHourInput") - self._max_time_in_hours_input.setEnabled(False) - self._max_time_in_hours_input.setMaximumWidth(30) - self._max_time_in_hours_input.setPlaceholderText("0") - # TODO: decide between converting as int(text) or float(text) -> will users want to use decimals? is there a better way to convert from hours to seconds? - # TODO: how to handle invalid (not convertible to a number) input? - self._max_time_in_hours_input.textChanged.connect( - lambda text: self._training_model.set_max_time( - round(float(text) * 3600) - ) - ) - max_time_layout.addWidget(self._max_time_in_hours_input) - - max_time_right_text: LabelWithHint = LabelWithHint("hours") - max_time_layout.addWidget(max_time_right_text, alignment=Qt.AlignLeft) - max_time_layout.addStretch() - - bottom_grid_layout.addLayout(max_time_layout, 3, 1) - bottom_grid_layout.setColumnStretch(1, 8) - bottom_grid_layout.setColumnStretch(0, 3) - - frame.layout().addLayout(bottom_grid_layout) - - def _max_epochtext_field_handler(self, max_epochs: str) -> None: - self._training_model.set_max_epoch(int(max_epochs)) - def _model_combo_handler(self, experiment_name: str) -> None: """ Triggered when the user selects a model from the _combo_box_existing_models. Sets the model path in the model. """ if experiment_name == "": - self._training_model.set_experiment_name(None) + self._experiments_model.set_experiment_name(None) else: - self._training_model.set_experiment_name(experiment_name) + self._experiments_model.set_experiment_name(experiment_name) self._refresh_checkpoint_options() def _model_radio_handler(self) -> None: @@ -251,29 +135,19 @@ def _model_radio_handler(self) -> None: self._combo_box_existing_models_checkpoint.setEnabled(False) self._combo_box_existing_models_checkpoint.clear() - self._training_model.set_experiment_name(None) - self._training_model.set_checkpoint(None) + self._experiments_model.set_experiment_name(None) + self._experiments_model.set_checkpoint(None) if self._radio_existing_model.isChecked(): """ Triggered when the user selects the "existing model" radio button. Enables and disables relevent controls. """ - self._training_model.set_experiment_name(None) + self._experiments_model.set_experiment_name(None) self._combo_box_existing_models.setEnabled(True) self.experiment_info_widget.set_enabled(False) self.experiment_info_widget.clear() - def _max_time_checkbox_slot(self, checked: Qt.CheckState) -> None: - """ - Triggered when the user selects the "time out after" _timeout_checkbox. - Enables/disables interaction with the neighboring hour input based on checkstate. - """ - if checked == Qt.Checked: - self._max_time_in_hours_input.setEnabled(True) - else: - self._max_time_in_hours_input.setEnabled(False) - def _process_event_handler(self, _: Event = None) -> None: """ Refreshes the experiments in the _combo_box_existing_models. @@ -282,8 +156,8 @@ def _process_event_handler(self, _: Event = None) -> None: self._refresh_experiment_options() if ( self._radio_existing_model.isChecked() - and self._training_model.get_experiment_name() is not None - and not self._training_model.is_training_running() + and self._experiments_model.get_experiment_name() is not None + and not self._experiments_model.is_training_running() ): self._refresh_checkpoint_options() @@ -297,12 +171,12 @@ def _refresh_experiment_options(self): def _refresh_checkpoint_options(self): # update and enable checkpoint combo box self._experiments_model.refresh_checkpoints( - self._training_model.get_experiment_name() + self._experiments_model.get_experiment_name() ) self._combo_box_existing_models_checkpoint.clear() self._combo_box_existing_models_checkpoint.addItems( self._experiments_model.get_experiments()[ - self._training_model.get_experiment_name() + self._experiments_model.get_experiment_name() ] ) self._combo_box_existing_models_checkpoint.setCurrentIndex(-1) diff --git a/src/allencell_ml_segmenter/training/training_model.py b/src/allencell_ml_segmenter/training/training_model.py index 0c3d19e2..37f730a8 100644 --- a/src/allencell_ml_segmenter/training/training_model.py +++ b/src/allencell_ml_segmenter/training/training_model.py @@ -46,8 +46,6 @@ class TrainingModel(Publisher): def __init__(self, main_model: MainModel): super().__init__() self._main_model = main_model - self._experiment_name: str = None - self._checkpoint: str = None self._experiment_type: TrainingType = None self._hardware_type: Hardware = None self._images_directory: Path = None @@ -63,34 +61,6 @@ def __init__(self, main_model: MainModel): self._config_dir: Path = None self._is_training_running: bool = False - def get_experiment_name(self) -> str: - """ - Gets experiment name - """ - return self._experiment_name - - def set_experiment_name(self, name: str) -> None: - """ - Sets experiment name - - name (str): name of cyto-dl experiment - """ - self._experiment_name = name - - def get_checkpoint(self) -> str: - """ - Gets checkpoint - """ - return self._checkpoint - - def set_checkpoint(self, checkpoint: str) -> None: - """ - Sets checkpoint - - checkpoint (str): name of checkpoint to use - """ - self._checkpoint = checkpoint - def get_experiment_type(self) -> TrainingType: """ Gets experiment type diff --git a/src/allencell_ml_segmenter/training/view.py b/src/allencell_ml_segmenter/training/view.py index 16085f93..6db3f1e5 100644 --- a/src/allencell_ml_segmenter/training/view.py +++ b/src/allencell_ml_segmenter/training/view.py @@ -8,6 +8,13 @@ QFrame, QVBoxLayout, QSizePolicy, + QWidget, + QGridLayout, + QComboBox, + QHBoxLayout, + QRadioButton, + QLineEdit, + QCheckBox, ) from allencell_ml_segmenter._style import Style from allencell_ml_segmenter.core.event import Event @@ -26,6 +33,10 @@ from aicsimageio import AICSImage from aicsimageio.readers import TiffReader +from allencell_ml_segmenter.widgets.label_with_hint_widget import LabelWithHint +from PyQt5.QtGui import QIntValidator +from allencell_ml_segmenter.training.training_model import PatchSize + class TrainingView(View): """ @@ -68,23 +79,120 @@ def __init__( ) image_selection_widget.setObjectName("imageSelection") - model_selection_widget: ModelSelectionWidget = ModelSelectionWidget( - self._training_model, self._experiments_model - ) - model_selection_widget.setObjectName("modelSelection") - # Dummy divs allow for easy alignment top_container: QVBoxLayout = QVBoxLayout() top_dummy: QFrame = QFrame() - bottom_container: QVBoxLayout = QVBoxLayout() bottom_dummy: QFrame = QFrame() top_container.addWidget(image_selection_widget) top_dummy.setLayout(top_container) self.layout().addWidget(top_dummy) - bottom_container.addWidget(model_selection_widget) - bottom_dummy.setLayout(bottom_container) + + # bottom half + bottom_grid_layout = QGridLayout() + + patch_size_label: LabelWithHint = LabelWithHint("Structure size") + bottom_grid_layout.addWidget(patch_size_label, 0, 0) + + self._patch_size_combo_box: QComboBox = QComboBox() + self._patch_size_combo_box.setObjectName("structureSizeComboBox") + self._patch_size_combo_box.setCurrentIndex(-1) + self._patch_size_combo_box.setPlaceholderText("Select an option") + self._patch_size_combo_box.addItems( + [patch.name.lower() for patch in PatchSize] + ) + self._patch_size_combo_box.currentTextChanged.connect( + lambda size: self._training_model.set_patch_size(size) + ) + bottom_grid_layout.addWidget(self._patch_size_combo_box, 0, 1) + + image_dimensions_label: LabelWithHint = LabelWithHint( + "Image dimension" + ) + bottom_grid_layout.addWidget(image_dimensions_label, 1, 0) + + dimension_choice_layout: QHBoxLayout = QHBoxLayout() + dimension_choice_layout.setSpacing(0) + + self._radio_3d: QRadioButton = QRadioButton() + self._radio_3d.setObjectName("3DRadio") + self._radio_3d.toggled.connect( + lambda: self._training_model.set_image_dims(3) + ) + label_3d: LabelWithHint = LabelWithHint("3D") + + self._radio_2d: QRadioButton = QRadioButton() + self._radio_2d.toggled.connect( + lambda: self._training_model.set_image_dims(2) + ) + label_2d: LabelWithHint = LabelWithHint("2D") + + dimension_choice_layout.addWidget(self._radio_3d) + dimension_choice_layout.addWidget(label_3d) + dimension_choice_layout.addWidget( + self._radio_2d, alignment=Qt.AlignLeft + ) + dimension_choice_layout.addWidget(label_2d, alignment=Qt.AlignLeft) + dimension_choice_layout.addStretch(10) + dimension_choice_layout.setContentsMargins(0, 0, 0, 0) + + dimension_choice_dummy: QWidget = ( + QWidget() + ) # stops interference with other radio buttons + dimension_choice_dummy.setLayout(dimension_choice_layout) + + bottom_grid_layout.addWidget(dimension_choice_dummy, 1, 1) + + max_epoch_label: LabelWithHint = LabelWithHint("Training steps") + bottom_grid_layout.addWidget(max_epoch_label, 2, 0) + + self._max_epoch_input: QLineEdit = QLineEdit() + # allow only integers TODO [needs test coverage] + self._max_epoch_input.setValidator(QIntValidator()) + self._max_epoch_input.setPlaceholderText("1000") + self._max_epoch_input.setObjectName("trainingStepInput") + self._max_epoch_input.textChanged.connect( + self._max_epochtext_field_handler + ) + bottom_grid_layout.addWidget(self._max_epoch_input, 2, 1) + + max_time_layout: QHBoxLayout = QHBoxLayout() + max_time_layout.setSpacing(0) + + self._max_time_checkbox: QCheckBox = QCheckBox() + self._max_time_checkbox.setObjectName("timeoutCheckbox") + self._max_time_checkbox.stateChanged.connect( + self._max_time_checkbox_slot + ) + max_time_layout.addWidget(self._max_time_checkbox) + + max_time_left_text: QLabel = QLabel("Time out after") + max_time_layout.addWidget(max_time_left_text) + + self._max_time_in_hours_input: QLineEdit = QLineEdit() + self._max_time_in_hours_input.setObjectName("timeoutHourInput") + self._max_time_in_hours_input.setEnabled(False) + self._max_time_in_hours_input.setMaximumWidth(30) + self._max_time_in_hours_input.setPlaceholderText("0") + # TODO: decide between converting as int(text) or float(text) -> will users want to use decimals? is there a better way to convert from hours to seconds? + # TODO: how to handle invalid (not convertible to a number) input? + self._max_time_in_hours_input.textChanged.connect( + lambda text: self._training_model.set_max_time( + round(float(text) * 3600) + ) + ) + max_time_layout.addWidget(self._max_time_in_hours_input) + + max_time_right_text: LabelWithHint = LabelWithHint("hours") + max_time_layout.addWidget(max_time_right_text, alignment=Qt.AlignLeft) + max_time_layout.addStretch() + + bottom_grid_layout.addLayout(max_time_layout, 3, 1) + bottom_grid_layout.setColumnStretch(1, 8) + bottom_grid_layout.setColumnStretch(0, 3) + + bottom_dummy.setLayout(bottom_grid_layout) self.layout().addWidget(bottom_dummy) self._train_btn: QPushButton = QPushButton("Start training") @@ -147,7 +255,7 @@ def doWork(self): self._training_model.set_training_running(True) result_images = self.read_result_images( self._experiments_model.get_model_test_images_path( - self._training_model.get_experiment_name() + self._experiments_model.get_experiment_name() ) ) print("doWork - setting result images") @@ -164,3 +272,16 @@ def getTypeOfWork(self) -> str: def showResults(self): for idx, image in enumerate(self._training_model.get_result_images()): self.add_image_to_viewer(image.data, f"Segmentation {str(idx)}") + + def _max_epochtext_field_handler(self, max_epochs: str) -> None: + self._training_model.set_max_epoch(int(max_epochs)) + + def _max_time_checkbox_slot(self, checked: Qt.CheckState) -> None: + """ + Triggered when the user selects the "time out after" _timeout_checkbox. + Enables/disables interaction with the neighboring hour input based on checkstate. + """ + if checked == Qt.Checked: + self._max_time_in_hours_input.setEnabled(True) + else: + self._max_time_in_hours_input.setEnabled(False) \ No newline at end of file