Skip to content

Commit

Permalink
experiment name for all views refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hughes036 committed Oct 3, 2023
1 parent d93e518 commit 6dd8b30
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 220 deletions.
12 changes: 6 additions & 6 deletions src/allencell_ml_segmenter/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CYTO_DL_HOME_PATH = "/Users/brian.kim/Desktop/work/cyto-dl"
USER_EXPERIMENTS_PATH = (
"/Users/brian.kim/Desktop/work/cyto-dl/cyto_dl/logs/train/runs"
)
CYTO_DL_HOME_PATH = "/Users/chrishu/dev/code/test2/cyto-dl"
USER_EXPERIMENTS_PATH = "/Users/chrishu/dev/code/test2/allencell-ml-segmenter/logs/train/runs"

# CYTO_DL_HOME_PATH = "/Users/chrishu/dev/code/test2/cyto-dl"
# USER_EXPERIMENTS_PATH = "/Users/chrishu/dev/code/test2/logs/train/runs"
# CYTO_DL_HOME_PATH = "/Users/brian.kim/Desktop/work/cyto-dl"
# USER_EXPERIMENTS_PATH = (
# "/Users/brian.kim/Desktop/work/cyto-dl/cyto_dl/logs/train/runs"
# )
35 changes: 35 additions & 0 deletions src/allencell_ml_segmenter/main/experiments_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,45 @@

class ExperimentsModel(IExperimentsModel):
def __init__(self, config: CytoDlConfig) -> None:
super().__init__()
self.config = config

# options
self.experiments = {}
self.refresh_experiments()

# state
self._experiment_name: str = None
self._checkpoint: str = None

def get_experiment_name(self) -> str:
"""
Gets experiment name
"""
return self._experiment_name

def set_experiment_name(self, name: str) -> None:
"""
Sets experiment name
name (str): name of cyto-dl experiment
"""
self._experiment_name = name

def get_checkpoint(self) -> str:
"""
Gets checkpoint
"""
return self._checkpoint

def set_checkpoint(self, checkpoint: str) -> None:
"""
Sets checkpoint
checkpoint (str): name of checkpoint to use
"""
self._checkpoint = checkpoint

def refresh_experiments(self) -> None:
for experiment in Path(
self.config.get_user_experiments_path()
Expand Down
23 changes: 22 additions & 1 deletion src/allencell_ml_segmenter/main/i_experiments_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
from abc import ABC, abstractmethod
from pathlib import Path

from allencell_ml_segmenter.core.publisher import Publisher

class IExperimentsModel(ABC):

class IExperimentsModel(Publisher):

"""
Interface for implementing and testing ExperimentsModel
"""

def __init__(self):
super().__init__()

@abstractmethod
def get_experiment_name(self) -> str:
pass

@abstractmethod
def set_experiment_name(self, name: str) -> None:
pass

@abstractmethod
def get_checkpoint(self) -> str:
pass

@abstractmethod
def set_checkpoint(self, checkpoint: str):
pass

@abstractmethod
def get_experiments(self):
pass
Expand Down
18 changes: 13 additions & 5 deletions src/allencell_ml_segmenter/main/main_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from allencell_ml_segmenter.main.experiments_model import ExperimentsModel
from allencell_ml_segmenter.main.main_model import MainModel
from allencell_ml_segmenter.prediction.view import PredictionView
from allencell_ml_segmenter.training.model_selection_widget import ModelSelectionWidget
from allencell_ml_segmenter.training.view import TrainingView


Expand All @@ -43,9 +44,20 @@ def __init__(self, viewer: napari.Viewer, config: CytoDlConfig = None):
Event.ACTION_CHANGE_VIEW, self, self.handle_change_view
)

if config is None:
config = CytoDlConfig(CYTO_DL_HOME_PATH, USER_EXPERIMENTS_PATH)
experiment_model = ExperimentsModel(config)

# Model selection which applies to all views
model_selection_widget: ModelSelectionWidget = ModelSelectionWidget(
experiment_model
)
model_selection_widget.setObjectName("modelSelection")
self.layout().addWidget(model_selection_widget, Qt.AlignTop)

# keep track of views
self._view_container: QTabWidget = QTabWidget()
self.layout().addWidget(self._view_container, Qt.AlignTop)
self.layout().addWidget(self._view_container, Qt.AlignCenter )
self.layout().addStretch(100)

self._view_to_index: Dict[View, int] = dict()
Expand All @@ -54,10 +66,6 @@ def __init__(self, viewer: napari.Viewer, config: CytoDlConfig = None):
self._prediction_view: PredictionView = PredictionView(self._model)
self._initialize_view(self._prediction_view, "Prediction")

if config is None:
config = CytoDlConfig(CYTO_DL_HOME_PATH, USER_EXPERIMENTS_PATH)
experiment_model = ExperimentsModel(config)

training_view: TrainingView = TrainingView(
main_model=self._model,
viewer=self.viewer,
Expand Down
31 changes: 3 additions & 28 deletions src/allencell_ml_segmenter/services/training_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,6 @@ def _list_to_string(list_to_convert: List[Any]) -> str:
return f"[{ints_to_strings}]"


# class MyPrintingCallback(Callback):
# def __init__(self):
# super().__init__()

# def on_train_start(self, trainer, pl_module):
# print(
# "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training is starting@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
# )

# def on_train_epoch_start(self, trainer, pl_module):
# print(
# f"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Epoch {trainer.current_epoch} is starting@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
# )

# def on_train_epoch_end(self, trainer, pl_module):
# print(
# f"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training {trainer.current_epoch} is ending@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
# )

# def on_train_end(self, trainer, pl_module):
# print(
# "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@Training is ending@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
# )


class TrainingService(Subscriber):
"""
Interface for training a model. Uses cyto-dl to train model according to spec
Expand Down Expand Up @@ -101,9 +76,9 @@ def train_model_handler(self, _: Event) -> None:
# This is meant to be a string as is - not a string template. In cyto-dl, it will be treated as a string template
"hydra.run.dir=${paths.log_dir}/${task_name}/runs/${experiment_name}"
)
if self._training_model.get_checkpoint() is not None:
if self._experiments_model.get_checkpoint() is not None:
sys.argv.append(
f"ckpt_path={self._experiments_model.get_model_checkpoints_path(self._training_model.get_experiment_name(), self._training_model.get_checkpoint())}"
f"ckpt_path={self._experiments_model.get_model_checkpoints_path(self._experiments_model.get_experiment_name(), self._experiments_model.get_checkpoint())}"
)
# sys.argv.append(
# "+callbacks.print_progress._target_=allencell_ml_segmenter.services.training_service.MyPrintingCallback"
Expand Down Expand Up @@ -149,7 +124,7 @@ def _set_experiment_name(self) -> None:
"""
Sets the experiment_name argument variable for hydra override using sys.argv
"""
experiment_name: str = self._training_model.get_experiment_name()
experiment_name: str = self._experiments_model.get_experiment_name()
sys.argv.append(f"++experiment_name={experiment_name}")

def _set_max_epoch(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/allencell_ml_segmenter/training/experiment_info_widget.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from allencell_ml_segmenter.training.training_model import TrainingModel
from allencell_ml_segmenter.main.i_experiments_model import IExperimentsModel
from qtpy.QtWidgets import (
QWidget,
QFrame,
Expand All @@ -18,10 +18,10 @@ class ExperimentInfoWidget(QWidget):

TITLE_TEXT: str = "Experiment information"

def __init__(self, model: TrainingModel):
def __init__(self, model: IExperimentsModel):
super().__init__()

self._model: TrainingModel = model
self._model: IExperimentsModel = model

# widget skeleton
self.setLayout(QVBoxLayout())
Expand Down
Loading

0 comments on commit 6dd8b30

Please sign in to comment.