Skip to content

Commit

Permalink
Merge branch 'dev/main' into fix_starter
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Apr 5, 2024
2 parents 6b7076a + 679b61c commit e34aa5d
Show file tree
Hide file tree
Showing 100 changed files with 10,942 additions and 427 deletions.
48 changes: 47 additions & 1 deletion dacapo/blockwise/argmax_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@
default="INFO",
)
def cli(log_level):
    """
    Command-line entry point (click group) for the blockwise argmax worker.

    Configures the root logger to the requested verbosity; the actual work
    is performed by the subcommands registered on this group (e.g.
    ``start-worker``).

    Args:
        log_level (str): Name of a standard ``logging`` level such as
            "DEBUG" or "INFO". Case-insensitive — the value is upper-cased
            before lookup on the ``logging`` module.

    Raises:
        AttributeError: If ``log_level`` does not name a valid ``logging``
            level attribute.

    Examples:
        >>> cli(log_level="INFO")
    """
    logging.basicConfig(level=getattr(logging, log_level.upper()))


Expand All @@ -48,6 +60,21 @@ def start_worker(
output_container: Path | str,
output_dataset: str,
):
"""
Start the threshold worker.
Args:
input_container (Path | str): The input container.
input_dataset (str): The input dataset.
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset")
Note:
The method is implemented in the class.
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
Expand Down Expand Up @@ -77,12 +104,21 @@ def spawn_worker(
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
):
"""Spawn a worker to predict on a given dataset.
"""
Spawn a worker to predict on a given dataset.
Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
Returns:
The worker to run.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(model, raw_array, prediction_array_identifier)
Note:
The method is implemented in the class.
"""
compute_context = create_compute_context()

Expand All @@ -103,6 +139,16 @@ def spawn_worker(
]

def run_worker():
    """
    Execute the prepared worker command in the surrounding compute context.

    Closure over ``compute_context`` and ``command`` from the enclosing
    ``spawn_worker`` scope; delegates entirely to
    ``compute_context.execute`` and returns nothing.
    """
    # Run the worker in the given compute context
    compute_context.execute(command)

Expand Down
46 changes: 46 additions & 0 deletions dacapo/blockwise/blockwise_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@


class DaCapoBlockwiseTask(Task):
"""
A task to run a blockwise worker function. This task is used to run a
blockwise worker function on a given ROI.
Attributes:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if failed
(either due to failed post check or application crashes or network
failure)
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Methods:
__init__:
Initialize the task.
Note:
The method is implemented in the class.
"""
def __init__(
self,
worker_file: str | Path,
Expand All @@ -18,6 +41,29 @@ def __init__(
*args,
**kwargs,
):
"""
Initialize the task.
Args:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if failed
(either due to failed post check or application crashes or network
failure)
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Raises:
ValueError: If the worker file is not a valid path.
Examples:
>>> DaCapoBlockwiseTask(worker_file="worker_file", total_roi=Roi, read_roi=Roi, write_roi=Roi, num_workers=16, max_retries=2, timeout=None, upstream_tasks=None)
Note:
The method is implemented in the class.
"""
# Load worker functions
worker_name = Path(worker_file).stem
worker = SourceFileLoader(worker_name, str(worker_file)).load_module()
Expand Down
Loading

0 comments on commit e34aa5d

Please sign in to comment.