Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

no more validate_run #224

Merged
merged 3 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
create_weights_store,
)
from dacapo.experiments import Run
from dacapo.validate import validate_run
from dacapo.validate import validate

import torch
from tqdm import tqdm
Expand Down Expand Up @@ -187,13 +187,13 @@ def train_run(run: Run):
try:
# launch validation in a separate thread to avoid blocking training
validate_thread = threading.Thread(
target=validate_run,
target=validate,
args=(run, iteration_stats.iteration + 1),
name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
daemon=True,
)
validate_thread.start()
# validate_run(
# validate(
# run,
# iteration_stats.iteration + 1,
# )
Expand Down
75 changes: 32 additions & 43 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,36 @@

from pathlib import Path
import logging
from warnings import warn

logger = logging.getLogger(__name__)


def validate_run(
    run: Run,
    iteration: int,
    num_workers: int = 1,
    output_dtype: str = "uint8",
    overwrite: bool = True,
):
    """
    validate_run is deprecated and will be removed in a future version. Please use validate instead.

    Thin backward-compatibility shim: emits a ``DeprecationWarning`` and
    forwards every argument to ``validate``, passing ``run`` as its
    ``run_name`` argument.

    Args:
        run: The run to validate; forwarded to ``validate`` as ``run_name``.
        iteration: The iteration to validate.
        num_workers: Forwarded unchanged to ``validate``.
        output_dtype: Forwarded unchanged to ``validate``.
        overwrite: Forwarded unchanged to ``validate``.
    Returns:
        Whatever ``validate`` returns for the same arguments.
    """
    warn(
        "validate_run is deprecated and will be removed in a future version. Please use validate instead.",
        DeprecationWarning,
        # Attribute the warning to the caller of validate_run rather than to
        # this shim, so the reported file/line points at the code that needs
        # updating and default warning filters can surface it.
        stacklevel=2,
    )
    return validate(
        run_name=run,
        iteration=iteration,
        num_workers=num_workers,
        output_dtype=output_dtype,
        overwrite=overwrite,
    )


def validate(
run_name: str,
run_name: str | Run,
iteration: int,
num_workers: int = 1,
output_dtype: str = "uint8",
Expand All @@ -42,11 +66,13 @@ def validate(

print(f"Validating run {run_name} at iteration {iteration}...")

# create run

config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

# read in previous training/validation stats
stats_store = create_stats_store()
Expand All @@ -55,43 +81,6 @@ def validate(
run_name
)

return validate_run(
run,
iteration,
num_workers=num_workers,
output_dtype=output_dtype,
overwrite=overwrite,
)


# @reloading # allows us to fix validation bugs without interrupting training
def validate_run(
run: Run,
iteration: int,
num_workers: int = 1,
output_dtype: str = "uint8",
overwrite: bool = True,
):
"""
Validate an already loaded run at the given iteration. This does not
load the weights of that iteration, it is assumed that the model is already
loaded correctly. Returns the best parameters and scores for this
iteration.

Args:
run: The run to validate.
iteration: The iteration to validate.
num_workers: The number of workers to use for validation.
output_dtype: The dtype to use for the output arrays.
overwrite: Whether to overwrite existing output arrays
Returns:
The best parameters and scores for this iteration
Raises:
ValueError: If the run does not have a validation dataset or the dataset does not have ground truth.
Example:
validate_run(run, 1000)
"""

if (
run.datasplit.validate is None
or len(run.datasplit.validate) == 0
Expand Down
Loading