diff --git a/dacapo/train.py b/dacapo/train.py
index c8d700510..f218b3251 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -6,7 +6,7 @@
     create_weights_store,
 )
 from dacapo.experiments import Run
-from dacapo.validate import validate_run
+from dacapo.validate import validate
 import torch
 
 from tqdm import tqdm
@@ -187,13 +187,13 @@ def train_run(run: Run):
            try:
                # launch validation in a separate thread to avoid blocking training
                validate_thread = threading.Thread(
-                    target=validate_run,
+                    target=validate,
                    args=(run, iteration_stats.iteration + 1),
                    name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
                    daemon=True,
                )
                validate_thread.start()
-                # validate_run(
+                # validate(
                #     run,
                #     iteration_stats.iteration + 1,
                # )
diff --git a/dacapo/validate.py b/dacapo/validate.py
index 79da393e2..83a2e5a8e 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -10,12 +10,36 @@
 
 from pathlib import Path
 import logging
+from warnings import warn
 
 logger = logging.getLogger(__name__)
 
 
+def validate_run(
+    run: Run,
+    iteration: int,
+    num_workers: int = 1,
+    output_dtype: str = "uint8",
+    overwrite: bool = True,
+):
+    """
+    validate_run is deprecated and will be removed in a future version. Please use validate instead.
+    """
+    warn(
+        "validate_run is deprecated and will be removed in a future version. Please use validate instead.",
+        DeprecationWarning,
+    )
+    return validate(
+        run_name=run,
+        iteration=iteration,
+        num_workers=num_workers,
+        output_dtype=output_dtype,
+        overwrite=overwrite,
+    )
+
+
 def validate(
-    run_name: str,
+    run_name: str | Run,
     iteration: int,
     num_workers: int = 1,
     output_dtype: str = "uint8",
@@ -42,11 +66,13 @@ def validate(
 
     print(f"Validating run {run_name} at iteration {iteration}...")
 
-    # create run
-
-    config_store = create_config_store()
-    run_config = config_store.retrieve_run_config(run_name)
-    run = Run(run_config)
+    if isinstance(run_name, Run):
+        run = run_name
+        run_name = run.name
+    else:
+        config_store = create_config_store()
+        run_config = config_store.retrieve_run_config(run_name)
+        run = Run(run_config)
 
     # read in previous training/validation stats
     stats_store = create_stats_store()
@@ -55,43 +81,6 @@ def validate(
         run_name
     )
 
-    return validate_run(
-        run,
-        iteration,
-        num_workers=num_workers,
-        output_dtype=output_dtype,
-        overwrite=overwrite,
-    )
-
-
-# @reloading # allows us to fix validation bugs without interrupting training
-def validate_run(
-    run: Run,
-    iteration: int,
-    num_workers: int = 1,
-    output_dtype: str = "uint8",
-    overwrite: bool = True,
-):
-    """
-    Validate an already loaded run at the given iteration. This does not
-    load the weights of that iteration, it is assumed that the model is already
-    loaded correctly. Returns the best parameters and scores for this
-    iteration.
-
-    Args:
-        run: The run to validate.
-        iteration: The iteration to validate.
-        num_workers: The number of workers to use for validation.
-        output_dtype: The dtype to use for the output arrays.
-        overwrite: Whether to overwrite existing output arrays
-    Returns:
-        The best parameters and scores for this iteration
-    Raises:
-        ValueError: If the run does not have a validation dataset or the dataset does not have ground truth.
-    Example:
-        validate_run(run, 1000)
-    """
-
     if (
         run.datasplit.validate is None
         or len(run.datasplit.validate) == 0