Commit acaa300

Merge 9117256 into 313f69c
rhoadesScholar authored Apr 16, 2024
2 parents 313f69c + 9117256
Showing 172 changed files with 17,942 additions and 908 deletions.
67 changes: 64 additions & 3 deletions dacapo/apply.py
@@ -18,7 +18,7 @@
create_weights_store,
)

from pathlib import Path
from upath import UPath as Path

logger = logging.getLogger(__name__)
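The recurring change across the files in this commit is the swap from pathlib to upath. A minimal sketch of what that buys (assuming universal_pathlib's UPath; the s3 URI is illustrative and requires the matching fsspec backend installed):

from upath import UPath as Path

# UPath keeps pathlib semantics for local paths...
local = Path("data.zarr") / "raw"
# ...and accepts remote URIs through fsspec with the same API.
remote = Path("s3://bucket/data.zarr")
print(local.name, str(remote))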

@@ -38,7 +38,40 @@ def apply(
overwrite: bool = True,
file_format: str = "zarr",
):
"""Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
"""
Load weights and apply a trained model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.
Args:
run_name (str): Name of the run to apply.
input_container (Path | str): Path to the input container.
input_dataset (str): Name of the input dataset.
output_path (Path | str): Path to the output container.
validation_dataset (Optional[Dataset | str], optional): Validation dataset to use for finding the best parameters. Defaults to None.
criterion (str, optional): Criterion to use for finding the best parameters. Defaults to "voi".
iteration (Optional[int], optional): Iteration to use. If None, the best iteration is used. Defaults to None.
parameters (Optional[PostProcessorParameters | str], optional): Post-processor parameters to use. If None, the best parameters are found. Defaults to None.
roi (Optional[Roi | str], optional): Region of interest to use. If None, the whole input dataset is used. Defaults to None.
num_workers (int, optional): Number of workers to use. Defaults to 12.
output_dtype (np.dtype | str, optional): Output dtype. Defaults to np.uint8.
overwrite (bool, optional): Overwrite existing output. Defaults to True.
file_format (str, optional): File format to use. Defaults to "zarr".
Raises:
ValueError: If validation_dataset is None and criterion is not None.
ValueError: If parameters is a string that cannot be parsed to PostProcessorParameters.
ValueError: If parameters is not a PostProcessorParameters object.
Examples:
>>> apply(
... run_name="run_1",
... input_container="data.zarr",
... input_dataset="raw",
... output_path="output.zarr",
... validation_dataset="validate",
... criterion="voi",
... num_workers=12,
... output_dtype=np.uint8,
... overwrite=True,
... )
"""
if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)

@@ -178,8 +211,36 @@ def apply_run(
output_dtype: np.dtype | str = np.uint8, # type: ignore
overwrite: bool = True,
):
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
"""
Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.
Args:
run (Run): The run object containing the task and post-processor.
iteration (int): The iteration number.
parameters (PostProcessorParameters): The post-processor parameters.
input_array_identifier (LocalArrayIdentifier): The identifier for the input array.
prediction_array_identifier (LocalArrayIdentifier): The identifier for the prediction array.
output_array_identifier (LocalArrayIdentifier): The identifier for the output array.
roi (Optional[Roi], optional): The region of interest. Defaults to None.
num_workers (int, optional): The number of workers for parallel processing. Defaults to 12.
output_dtype (np.dtype | str, optional): The output data type. Defaults to np.uint8.
overwrite (bool, optional): Whether to overwrite existing output. Defaults to True.
Raises:
ValueError: If the input array is not a ZarrArray.
Examples:
>>> apply_run(
... run=run,
... iteration=1,
... parameters=parameters,
... input_array_identifier=LocalArrayIdentifier(Path("data.zarr"), "raw"),
... prediction_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "prediction_run_1_1"),
... output_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "output_run_1_1"),
... roi=None,
... num_workers=12,
... output_dtype=np.uint8,
... overwrite=True,
... )
"""
# render prediction dataset
print(f"Predicting on dataset {prediction_array_identifier}")
predict(
89 changes: 75 additions & 14 deletions dacapo/blockwise/argmax_worker.py
@@ -1,4 +1,4 @@
from pathlib import Path
from upath import UPath as Path
import sys
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
@@ -27,6 +27,18 @@
default="INFO",
)
def cli(log_level):
"""
CLI for running the threshold worker.
Args:
log_level (str): The log level to use.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))


@@ -47,7 +59,23 @@ def start_worker(
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: bool = False,
):
"""
Start the threshold worker.
Args:
input_container (Path | str): The input container.
input_dataset (str): The input dataset.
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset")
Note:
The method is implemented in the class.
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
@@ -57,34 +85,57 @@ def start_worker(
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# wait for blocks to run pipeline
client = daisy.Client()
def io_loop():
# wait for blocks to run pipeline
client = daisy.Client()

while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break
while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break

# write to output array
output_array[block.write_roi] = np.argmax(
input_array[block.write_roi],
axis=input_array.axes.index("c"),
)
# write to output array
output_array[block.write_roi] = np.argmax(
input_array[block.write_roi],
axis=input_array.axes.index("c"),
)

if return_io_loop:
return io_loop
else:
io_loop()
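Wrapping the daisy client loop in io_loop lets the same worker either run immediately or hand the loop back to an in-process caller. A minimal usage sketch (argument values are placeholders, mirroring the call spawn_worker makes below):

loop = start_worker(
    input_container="output.zarr",
    input_dataset="prediction",
    output_container="output.zarr",
    output_dataset="argmax",
    return_io_loop=True,
)
loop()  # processes daisy blocks until the scheduler hands out none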


def spawn_worker(
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
):
"""Spawn a worker to predict on a given dataset.
"""
Spawn a worker to apply argmax over a given dataset.
Args:
    input_array_identifier (LocalArrayIdentifier): The identifier of the input array.
    output_array_identifier (LocalArrayIdentifier): The identifier of the output array.
Returns:
    The worker function to run.
Examples:
    >>> spawn_worker(input_array_identifier, output_array_identifier)
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
output_array_identifier.dataset,
return_io_loop=True,
)

# Make the command for the worker to run
command = [
@@ -103,6 +154,16 @@ def spawn_worker(
]

def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
compute_context.execute(command)

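One design note on the hunk above (my reading; the diff cuts off before spawn_worker's return): both branches appear to yield a zero-argument callable, so callers can treat local and distributed execution uniformly.

# Hedged sketch: worker_fn is either the in-process io_loop or
# run_worker, which executes the CLI command via the compute context.
worker_fn = spawn_worker(input_array_identifier, output_array_identifier)
worker_fn()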
49 changes: 48 additions & 1 deletion dacapo/blockwise/blockwise_task.py
@@ -1,10 +1,34 @@
from datetime import datetime
from importlib.machinery import SourceFileLoader
from pathlib import Path
from upath import UPath as Path
from daisy import Task, Roi


class DaCapoBlockwiseTask(Task):
"""
A task to run a blockwise worker function. This task is used to run a
blockwise worker function on a given ROI.
Attributes:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if it fails
    (whether due to a failed post-check, an application crash, or a
    network failure).
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Methods:
__init__:
Initialize the task.
"""

def __init__(
self,
worker_file: str | Path,
@@ -18,6 +42,29 @@ def __init__(
*args,
**kwargs,
):
"""
Initialize the task.
Args:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if it fails
    (whether due to a failed post-check, an application crash, or a
    network failure).
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Raises:
ValueError: If the worker file is not a valid path.
Examples:
>>> DaCapoBlockwiseTask(worker_file="worker_file", total_roi=total_roi, read_roi=read_roi, write_roi=write_roi, num_workers=16, max_retries=2, timeout=None, upstream_tasks=None)
"""
# Load worker functions
worker_name = Path(worker_file).stem
worker = SourceFileLoader(worker_name, str(worker_file)).load_module()
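To make the class concrete, here is a minimal driving sketch (the ROIs are placeholders and any extra worker arguments are elided; daisy.run_blockwise is the standard entry point for running daisy tasks):

import daisy
from daisy import Roi

# Placeholder geometry; real runs derive ROIs from the dataset.
total_roi = Roi((0, 0, 0), (100, 100, 100))
block_roi = Roi((0, 0, 0), (20, 20, 20))

task = DaCapoBlockwiseTask(
    worker_file="dacapo/blockwise/argmax_worker.py",
    total_roi=total_roi,
    read_roi=block_roi,
    write_roi=block_roi,
    num_workers=4,
    max_retries=2,
    timeout=None,
    upstream_tasks=None,
)
daisy.run_blockwise([task])  # schedules blocks across the workers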
