diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index c8b666734..87584c0e5 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -70,7 +70,7 @@ def cli(log_level): ) @click.option("-od", "--output_dataset", required=True, type=str) def start_worker( - run_name: str, + run_name: str| Run, iteration: int | None, input_container: Path | str, input_dataset: str, @@ -78,6 +78,28 @@ def start_worker( output_dataset: str, return_io_loop: Optional[bool] = False, ): + return start_worker_fn( + run_name=run_name, + iteration=iteration, + input_container=input_container, + input_dataset=input_dataset, + output_container=output_container, + output_dataset=output_dataset, + return_io_loop=return_io_loop, + ) + + + +def start_worker_fn( + run_name: str| Run, + iteration: int | None, + input_container: Path | str, + input_dataset: str, + output_container: Path | str, + output_dataset: str, + return_io_loop: bool, +): + """ Start a worker to apply a trained model to a dataset. @@ -93,9 +115,14 @@ def start_worker( device = compute_context.device # retrieving run - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) + logger.error(f"run_name: {run_name} {type(run_name)}" ) + if isinstance(run_name, Run): + run = run_name + run_name = run.name + else: + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) if iteration is not None: # create weights store @@ -207,7 +234,7 @@ def io_loop(): def spawn_worker( - run_name: str, + run_name: str| Run, iteration: int | None, input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", @@ -225,13 +252,13 @@ def spawn_worker( """ compute_context = create_compute_context() if not compute_context.distribute_workers: - return start_worker( - run_name, - iteration, - input_array_identifier.container, - input_array_identifier.dataset, - output_array_identifier.container, - output_array_identifier.dataset, + return start_worker_fn( + run_name= run_name, + iteration=iteration, + input_container= input_array_identifier.container, + input_dataset= input_array_identifier.dataset, + output_container=output_array_identifier.container, + output_dataset=output_array_identifier.dataset, return_io_loop=True, ) diff --git a/dacapo/predict.py b/dacapo/predict.py index 0c09a9f7a..8197b2e63 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -94,7 +94,7 @@ def predict( input_size = input_voxel_size * input_shape output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] num_out_channels = model.num_out_channels - del model + # del model # calculate input and output rois @@ -149,7 +149,7 @@ def predict( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - run_name=run_name, + run_name=run, iteration=iteration, input_array_identifier=input_array_identifier, output_array_identifier=output_array_identifier,