diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index 59a17d752..d685161af 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -27,6 +27,18 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the threshold worker. + + Args: + log_level (str): The log level to use. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> cli(log_level="INFO") + Note: + The method is implemented in the class. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -48,6 +60,21 @@ def start_worker( output_container: Path | str, output_dataset: str, ): + """ + Start the threshold worker. + + Args: + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset") + Note: + The method is implemented in the class. + """ # get arrays input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) @@ -77,12 +104,21 @@ def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Returns: + The worker to run. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> spawn_worker(model, raw_array, prediction_array_identifier) + Note: + The method is implemented in the class. """ compute_context = create_compute_context() @@ -103,6 +139,16 @@ def spawn_worker( ] def run_worker(): + """ + Run the worker in the given compute context. + + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_worker() + Note: + The method is implemented in the class. + """ # Run the worker in the given compute context compute_context.execute(command) diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index cbae73b8b..cfdc4f9ce 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -5,6 +5,29 @@ class DaCapoBlockwiseTask(Task): + """ + A task to run a blockwise worker function. This task is used to run a + blockwise worker function on a given ROI. + + Attributes: + worker_file (str | Path): The path to the worker file. + total_roi (Roi): The ROI to process. + read_roi (Roi): The ROI to read from for a block. + write_roi (Roi): The ROI to write to for a block. + num_workers (int): The number of workers to use. + max_retries (int): The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout: The timeout for the task. + upstream_tasks: The upstream tasks. + *args: Additional positional arguments to pass to ``worker_function``. + **kwargs: Additional keyword arguments to pass to ``worker_function``. + Methods: + __init__: + Initialize the task. + Note: + The method is implemented in the class. + """ def __init__( self, worker_file: str | Path, @@ -18,6 +41,29 @@ def __init__( *args, **kwargs, ): + """ + Initialize the task. + + Args: + worker_file (str | Path): The path to the worker file. + total_roi (Roi): The ROI to process. + read_roi (Roi): The ROI to read from for a block. + write_roi (Roi): The ROI to write to for a block. + num_workers (int): The number of workers to use. + max_retries (int): The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout: The timeout for the task. + upstream_tasks: The upstream tasks. + *args: Additional positional arguments to pass to ``worker_function``. + **kwargs: Additional keyword arguments to pass to ``worker_function``. + Raises: + ValueError: If the worker file is not a valid path. + Examples: + >>> DaCapoBlockwiseTask(worker_file="worker_file", total_roi=Roi, read_roi=Roi, write_roi=Roi, num_workers=16, max_retries=2, timeout=None, upstream_tasks=None) + Note: + The method is implemented in the class. + """ # Load worker functions worker_name = Path(worker_file).stem worker = SourceFileLoader(worker_name, str(worker_file)).load_module() diff --git a/dacapo/blockwise/empanada_function.py b/dacapo/blockwise/empanada_function.py index 4175f8577..4e94e748d 100644 --- a/dacapo/blockwise/empanada_function.py +++ b/dacapo/blockwise/empanada_function.py @@ -45,6 +45,35 @@ def segment_function(input_array, block, **parameters): + """ + Segment a 3D block using the empanada-napari library. + + Args: + input_array (np.ndarray): The 3D array to segment. + block (dask.array.core.Block): The block object. + **parameters: Parameters for the empanada-napari segmenter. + Returns: + np.ndarray: The segmented 3D array. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from dask import array as da + >>> from dacapo.blockwise.empanada_function import segment_function + >>> input_array = np.random.rand(64, 64, 64) + >>> block = da.from_array(input_array, chunks=(32, 32, 32)) + >>> segmented_array = segment_function(block, model_config="MitoNet_v1") + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 + """ vols, class_names = [], [] for vol, class_name, _ in empanada_segmenter( input_array[block.read_roi], **parameters @@ -60,12 +89,66 @@ def segment_function(input_array, block, **parameters): def stack_inference(engine, volume, axis_name): + """ + Perform inference on a single axis of a 3D volume. + + Args: + engine (Engine3d): The engine object. + volume (np.ndarray): The 3D volume to segment. + axis_name (str): The axis name to segment. + Returns: + tuple: The stack, axis name, and trackers dictionary. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import stack_inference + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> engine = Engine3d(model_config, use_gpu=use_gpu, use_quantized=use_quantized) + >>> volume = np.random.rand(64, 64, 64) + >>> axis_name = "xy" + >>> stack, axis_name, trackers_dict = stack_inference(engine, volume, axis_name) + Note: + The `axis_name` parameter should be one of the following: + """ stack, trackers = engine.infer_on_axis(volume, axis_name) trackers_dict = {axis_name: trackers} return stack, axis_name, trackers_dict def orthoplane_inference(engine, volume): + """ + Perform inference on the orthogonal planes of a 3D volume. + + Args: + engine (Engine3d): The engine object. + volume (np.ndarray): The 3D volume to segment. + Returns: + dict: The trackers dictionary. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import orthoplane_inference + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> engine = Engine3d(model_config, use_gpu=use_gpu, use_quantized=use_quantized) + >>> volume = np.random.rand(64, 64, 64) + >>> trackers_dict = orthoplane_inference(engine, volume) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + """ trackers_dict = {} for axis_name in ["xy", "xz", "yz"]: stack, trackers = engine.infer_on_axis(volume, axis_name) @@ -103,6 +186,93 @@ def empanada_segmenter( pixel_vote_thr=1, allow_one_view=False, ): + """ + Segment a 3D volume using the empanada-napari library. + + Args: + image (np.ndarray): The 3D volume to segment. + model_config (str): The model configuration to use. + use_gpu (bool): Whether to use the GPU. + use_quantized (bool): Whether to use quantized inference. + multigpu (bool): Whether to use multiple GPUs. + downsampling (int): The downsampling factor. + confidence_thr (float): The confidence threshold. + center_confidence_thr (float): The center confidence threshold. + min_distance_object_centers (int): The minimum distance between object centers. + fine_boundaries (bool): Whether to use fine boundaries. + semantic_only (bool): Whether to use semantic segmentation only. + median_slices (int): The number of median slices. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent. + maximum_objects_per_class (int): The maximum number of objects per class. + inference_plane (str): The inference plane. + orthoplane (bool): Whether to use orthoplane inference. + return_panoptic (bool): Whether to return the panoptic segmentation. + pixel_vote_thr (int): The pixel vote threshold. + allow_one_view (bool): Whether to allow one view. + Returns: + tuple: The volume, class name, and tracker. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import empanada_segmenter + >>> image = np.random.rand(64, 64, 64) + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> multigpu = False + >>> downsampling = 1 + >>> confidence_thr = 0.5 + >>> center_confidence_thr = 0.1 + >>> min_distance_object_centers = 21 + >>> fine_boundaries = True + >>> semantic_only = False + >>> median_slices = 11 + >>> min_size = 10000 + >>> min_extent = 50 + >>> maximum_objects_per_class = 1000000 + >>> inference_plane = "xy" + >>> orthoplane = True + >>> return_panoptic = False + >>> pixel_vote_thr = 1 + >>> allow_one_view = False + >>> for vol, class_name, tracker in empanada_segmenter( + ... image, + ... model_config=model_config, + ... use_gpu=use_gpu, + ... use_quantized=use_quantized, + ... multigpu=multigpu, + ... downsampling=downsampling, + ... confidence_thr=confidence_thr, + ... center_confidence_thr=center_confidence_thr, + ... min_distance_object_centers=min_distance_object_centers, + ... fine_boundaries=fine_boundaries, + ... semantic_only=semantic_only, + ... median_slices=median_slices, + ... min_size=min_size, + ... min_extent=min_extent, + ... maximum_objects_per_class=maximum_objects_per_class, + ... inference_plane=inference_plane, + ... orthoplane=orthoplane, + ... return_panoptic=return_panoptic, + ... pixel_vote_thr=pixel_vote_thr, + ... allow_one_view=allow_one_view + ... ): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 + + """ # load the model config model_config = read_yaml(model_configs[model_config]) min_size = int(min_size) @@ -144,6 +314,22 @@ def empanada_segmenter( ) def start_postprocess_worker(*args): + """ + Start the postprocessing worker. + + Args: + *args: The arguments to pass to the worker. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in start_postprocess_worker(*args): + ... print(vol.shape, class_name, tracker) + Note: + The `args` parameter should be a tuple of arguments. + + """ trackers_dict = args[0][2] for vol, class_name, tracker in stack_postprocessing( trackers_dict, @@ -157,6 +343,21 @@ def start_postprocess_worker(*args): yield vol, class_name, tracker def start_consensus_worker(trackers_dict): + """ + Start the consensus worker. + + Args: + trackers_dict (dict): The trackers dictionary. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in start_consensus_worker(trackers_dict): + ... print(vol.shape, class_name, tracker) + Note: + The `trackers_dict` parameter should be a dictionary of trackers. + """ for vol, class_name, tracker in tracker_consensus( trackers_dict, model_config, @@ -202,8 +403,34 @@ def stack_postprocessing( min_extent=4, dtype=np.uint32, ): - r"""Relabels and filters each class defined in trackers. Yields a numpy + """ + Relabels and filters each class defined in trackers. Yields a numpy or zarr volume along with the name of the class that is segmented. + + Args: + trackers (dict): The trackers dictionary. + model_config (str): The model configuration to use. + label_divisor (int): The label divisor. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent of objects. + dtype (type): The data type. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in stack_postprocessing(trackers, model_config): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 """ thing_list = model_config["thing_list"] class_names = model_config["class_names"] @@ -244,8 +471,36 @@ def tracker_consensus( min_extent=4, dtype=np.uint32, ): - r"""Calculate the orthoplane consensus from trackers. Yields a numpy + """ + Calculate the orthoplane consensus from trackers. Yields a numpy or zarr volume along with the name of the class that is segmented. + + Args: + trackers (dict): The trackers dictionary. + model_config (str): The model configuration to use. + pixel_vote_thr (int): The pixel vote threshold. + cluster_iou_thr (float): The cluster IoU threshold. + allow_one_view (bool): Whether to allow one view. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent of objects. + dtype (type): The data type. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in tracker_consensus(trackers, model_config): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 """ labels = model_config["labels"] thing_list = model_config["thing_list"] diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 739e8699a..070ddc4b3 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -36,6 +36,20 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the predict worker. + + The predict worker is used to apply a trained model to a dataset. + + Args: + log_level (str): The log level to use for logging. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> cli(log_level="INFO") + Note: + The method is implemented in the class. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -69,6 +83,24 @@ def start_worker( output_container: Path | str, output_dataset: str, ): + """ + Start a worker to apply a trained model to a dataset. + + Args: + run_name (str): The name of the run to apply. + iteration (int or None): The training iteration of the model to use for prediction. + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> start_worker(run_name="run", iteration=0, input_container="input", input_dataset="input", output_container="output", output_dataset="output") + Note: + The method is implemented in the class. + + """ compute_context = create_compute_context() device = compute_context.device @@ -186,13 +218,20 @@ def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: run_name (str): The name of the run to apply. iteration (int or None): The training iteration of the model to use for prediction. input_array_identifier (LocalArrayIdentifier): The raw data to predict on. output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> spawn_worker(run_name="run", iteration=0, input_array_identifier=LocalArrayIdentifier(Path("input"), "input"), output_array_identifier=LocalArrayIdentifier(Path("output"), "output")) + Note: + The method is implemented in the class. """ compute_context = create_compute_context() @@ -219,6 +258,16 @@ def spawn_worker( print("Defining worker with command: ", compute_context.wrap_command(command)) def run_worker(): + """ + Run the worker in the given compute context. + + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_worker() + Note: + The method is implemented in the class. + """ # Run the worker in the given compute context print("Running worker with command: ", command) compute_context.execute(command) diff --git a/dacapo/blockwise/relabel_worker.py b/dacapo/blockwise/relabel_worker.py index b374f7120..7f98892aa 100644 --- a/dacapo/blockwise/relabel_worker.py +++ b/dacapo/blockwise/relabel_worker.py @@ -24,6 +24,18 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the relabel worker. + + Args: + log_level (str): The log level to use. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> cli(log_level="INFO") + Note: + The method is implemented in the class. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -43,6 +55,20 @@ def start_worker( *args, **kwargs, ): + """ + Start the relabel worker. + + Args: + output_container (str): The output container + output_dataset (str): The output dataset + tmpdir (str): The temporary directory + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> start_worker(output_container="output_container", output_dataset="output_dataset", tmpdir="tmpdir") + Note: + The method is implemented in the class. + """ client = daisy.Client() array_out = open_ds(output_container, output_dataset, mode="a") @@ -66,6 +92,21 @@ def start_worker( def relabel_in_block(array_out, old_values, new_values, block): + """ + Relabel the array in the given block. + + Args: + array_out (np.ndarray): The output array + old_values (np.ndarray): The old values + new_values (np.ndarray): The new values + block (daisy.Block): The block + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> relabel_in_block(array_out, old_values, new_values, block) + Note: + The method is implemented in the class. + """ a = array_out.to_ndarray(block.write_roi) # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input if old_values.size > 0: @@ -74,6 +115,21 @@ def relabel_in_block(array_out, old_values, new_values, block): def find_components(nodes, edges): + """ + Find the components. + + Args: + nodes (np.ndarray): The nodes + edges (np.ndarray): The edges + Returns: + List[int]: The components + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> find_components(nodes, edges) + Note: + The method is implemented in the class. + """ # scipy disjoint_set = DisjointSet(nodes) for edge in edges: @@ -82,6 +138,20 @@ def find_components(nodes, edges): def read_cross_block_merges(tmpdir): + """ + Read the cross block merges. + + Args: + tmpdir (str): The temporary directory + Returns: + Tuple[np.ndarray, np.ndarray]: The nodes and edges + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> read_cross_block_merges(tmpdir) + Note: + The method is implemented in the class. + """ block_files = glob(os.path.join(tmpdir, "block_*.npz")) nodes = [] @@ -100,11 +170,20 @@ def spawn_worker( *args, **kwargs, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: output_array_identifier (LocalArrayIdentifier): The output array identifier tmpdir (str): The temporary directory + Returns: + Callable: The function to run the worker + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> spawn_worker(output_array_identifier, tmpdir) + Note: + The method is implemented in the class. """ compute_context = create_compute_context() @@ -123,6 +202,16 @@ def spawn_worker( ] def run_worker(): + """ + Run the worker in the given compute context. + + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_worker() + Note: + The method is implemented in the class. + """ # Run the worker in the given compute context compute_context.execute(command) diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index ddea38280..1d321a8a7 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -28,46 +28,38 @@ def run_blockwise( *args, **kwargs, ): - """Run a function in parallel over a large volume. + """ + Run a function in parallel over a large volume. Args: - worker_file (``str`` or ``Path``): - The path to the file containing the necessary worker functions: ``spawn_worker`` and ``start_worker``. Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. - total_roi (``Roi``): The ROI to process. - read_roi (``Roi``): The ROI to read from for a block. - write_roi (``Roi``): The ROI to write to for a block. - num_workers (``int``): - - The number of workers to use. - + The number of workers to use. max_retries (``int``): - - The maximum number of times a task will be retried if failed - (either due to failed post check or application crashes or network - failure) - + The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) *args: - Additional positional arguments to pass to ``worker_function``. - **kwargs: - Additional keyword arguments to pass to ``worker_function``. - Returns: - - ``Bool``. + ``Bool``. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_blockwise(worker_file, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks) + Note: + The method is implemented in the class. """ @@ -103,58 +95,44 @@ def segment_blockwise( *args, **kwargs, ): - """Run a segmentation function in parallel over a large volume. + """ + Run a segmentation function in parallel over a large volume. Args: - - segment_function_file (``str`` or ``Path``): - - The path to the file containing the necessary worker functions: - ``spawn_worker`` and ``start_worker``. - Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. - - context (``Coordinate``): - - The context to add to the read and write ROI. - - total_roi (``Roi``): - The ROI to process. - - read_roi (``Roi``): - The ROI to read from for a block. - - write_roi (``Roi``): - The ROI to write to for a block. - - num_workers (``int``): - - The number of workers to use. - - max_retries (``int``): - - The maximum number of times a task will be retried if failed - (either due to failed post check or application crashes or network - failure) - - timeout (``int``): - - The maximum time in seconds to wait for a worker to complete a task. - - upstream_tasks (``List``): - - List of upstream tasks. - - *args: - - Additional positional arguments to pass to ``worker_function``. - - **kwargs: - - Additional keyword arguments to pass to ``worker_function``. - + segment_function_file (``str`` or ``Path``): + The path to the file containing the necessary worker functions: + ``spawn_worker`` and ``start_worker``. + Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. + context (``Coordinate``): + The context to add to the read and write ROI. + total_roi (``Roi``): + The ROI to process. + read_roi (``Roi``): + The ROI to read from for a block. + write_roi (``Roi``): + The ROI to write to for a block. + num_workers (``int``): + The number of workers to use. + max_retries (``int``): + The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout (``int``): + The maximum time in seconds to wait for a worker to complete a task. + upstream_tasks (``List``): + List of upstream tasks. + *args: + Additional positional arguments to pass to ``worker_function``. + **kwargs: + Additional keyword arguments to pass to ``worker_function``. Returns: - ``Bool``. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> segment_blockwise(segment_function_file, context, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks) + Note: + The method is implemented in the class. """ options = Options.instance() if not options.runs_base_dir.exists(): diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index da1e0c098..9bc913197 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -26,6 +26,18 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the segment worker. + + Args: + log_level (str): The log level to use. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> cli(log_level="INFO") + Note: + The method is implemented in the class. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -49,7 +61,8 @@ def start_worker( tmpdir: str, function_path: str, ): - """Start a worker to run a segment function on a given dataset. + """ + Start a worker to run a segment function on a given dataset. Args: input_container (str): The input container. @@ -58,6 +71,12 @@ def start_worker( output_dataset (str): The output dataset. tmpdir (str): The temporary directory. function_path (str): The path to the segment function. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset", tmpdir="tmpdir", function_path="function_path") + Note: + The method is implemented in the class. """ print("Starting worker") @@ -184,12 +203,19 @@ def spawn_worker( tmpdir: str, function_path: str, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> spawn_worker(input_array_identifier="input_array_identifier", output_array_identifier="output_array_identifier", tmpdir="tmpdir", function_path="function_path") + Note: + The method is implemented in the class. """ compute_context = create_compute_context() @@ -214,6 +240,16 @@ def spawn_worker( ] def run_worker(): + """ + Run the worker in the given compute context. + + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_worker() + Note: + The method is implemented in the class. + """ # Run the worker in the given compute context compute_context.execute(command) diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index 3ff08c1e6..09a284f8e 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -27,6 +27,18 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the threshold worker. + + Args: + log_level (str): The log level to use. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> cli(log_level="INFO") + Note: + The method is implemented in the class. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -50,6 +62,23 @@ def start_worker( output_dataset: str, threshold: float = 0.0, ): + """ + Start the threshold worker. + + Args: + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + threshold (float): The threshold. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> start_worker(input_container="input", input_dataset="input", output_container="output", output_dataset="output", threshold=0.0) + Note: + The method is implemented in the class. + + """ # get arrays input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) @@ -79,12 +108,23 @@ def spawn_worker( output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + block_shape (Tuple[int]): The shape of the blocks. + halo (Tuple[int]): The halo to use. + Returns: + Callable: The function to run the worker. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> spawn_worker(model, raw_array, prediction_array_identifier, block_shape, halo) + Note: + The method is implemented in the class. """ compute_context = create_compute_context() @@ -107,6 +147,16 @@ def spawn_worker( ] def run_worker(): + """ + Run the worker in the given compute context. + + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> run_worker() + Note: + The method is implemented in the class. + """ # Run the worker in the given compute context compute_context.execute(command) diff --git a/dacapo/blockwise/watershed_function.py b/dacapo/blockwise/watershed_function.py index 0c5deae6f..18cefc025 100644 --- a/dacapo/blockwise/watershed_function.py +++ b/dacapo/blockwise/watershed_function.py @@ -5,6 +5,29 @@ def segment_function(input_array, block, offsets, bias): + """ + Segment the input array using the multicut watershed algorithm. + + Args: + input_array (np.ndarray): The input array. + block (daisy.Block): The block to be processed. + offsets (List[Tuple[int]]): The offsets. + bias (float): The bias. + Returns: + np.ndarray: The segmented array. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_array = np.random.rand(128, 128, 128) + >>> block = daisy.Block((0, 0, 0), (128, 128, 128)) + >>> offsets = [(0, 1, 0), (1, 0, 0), (0, 0, 1)] + >>> bias = 0.1 + >>> segmentation = segment_function(input_array, block, offsets, bias) + Note: + The method is implemented in the class. + If a previous segmentation is provided, it must have a "grid graph" in its metadata. + DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + """ # if a previous segmentation is provided, it must have a "grid graph" # in its metadata. pred_data = input_array[block.read_roi] diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index a3fb6aac5..f464841d2 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -10,6 +10,23 @@ @attr.s class Bsub(ComputeContext): + """ + The Bsub class is a subclass of the ComputeContext class. It is used to specify the + context in which computations are to be done. Bsub is used to specify that + computations are to be done on a cluster using LSF. + + Attributes: + queue (str): The queue to run on. + num_gpus (int): The number of gpus to train on. Currently only 1 gpu can be used. + num_cpus (int): The number of cpus to use to generate training data. + billing (Optional[str]): Project name that will be paying for this Job. + Methods: + device(): Returns the device on which computations are to be done. + _wrap_command(command): Wraps a command in the context specific command. + Note: + The class is a subclass of the ComputeContext class. + + """ queue: str = attr.ib(default="local", metadata={"help_text": "The queue to run on"}) num_gpus: int = attr.ib( default=1, @@ -33,12 +50,43 @@ class Bsub(ComputeContext): @property def device(self): + """ + A property method that returns the device on which computations are to be done. + + A device can be a CPU, GPU, TPU, etc. It is used to specify the context in which computations are to be done. + + Returns: + str: The device on which computations are to be done. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = Bsub() + >>> device = context.device + Note: + The method is implemented in the class. + """ if self.num_gpus > 0: return "cuda" else: return "cpu" def _wrap_command(self, command): + """ + A helper method to wrap a command in the context specific command. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = Bsub() + >>> command = ["python", "script.py"] + >>> wrapped_command = context._wrap_command(command) + Note: + The method is implemented in the class. + """ try: client = daisy.Client() basename = str( @@ -72,3 +120,4 @@ def _wrap_command(self, command): ) + command ) + diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 57b4c4064..dd48065da 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -7,20 +7,96 @@ class ComputeContext(ABC): + """ + The ComputeContext class is an abstract base class for defining the context in which computations are to be done. + It is inherited from the built-in class `ABC` (Abstract Base Classes). Other classes can inherit this class to define + their own specific variations of the context. It requires to implement several property methods, and also includes + additional methods related to the context design. + + Attributes: + device: The device on which computations are to be done. + Methods: + _wrap_command(command): Wraps a command in the context specific command. + wrap_command(command): Wraps a command in the context specific command and returns it. + execute(command): Runs a command in the context specific way. + Note: + The class is abstract and requires to implement the abstract methods. + """ @property @abstractmethod def device(self): + """ + Abstract property method to define the device on which computations are to be done. + + A device can be a CPU, GPU, TPU, etc. It is used to specify the context in which computations are to be done. + + Returns: + str: The device on which computations are to be done. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> device = context.device + Note: + The method should be implemented in the derived class. + """ pass def _wrap_command(self, command): + """ + A helper method to wrap a command in the context specific command. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> wrapped_command = context._wrap_command(command) + Note: + The method should be implemented in the derived class. + """ # A helper method to wrap a command in the context specific command. return command def wrap_command(self, command): + """ + A method to wrap a command in the context specific command and return it. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> wrapped_command = context.wrap_command(command) + Note: + The method should be implemented in the derived class. + """ command = [str(com) for com in self._wrap_command(command)] return command def execute(self, command): + """ + A method to run a command in the context specific way. + + Args: + command (List[str]): The command to be executed. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> context.execute(command) + Note: + The method should be implemented in the derived class. + """ # A helper method to run a command in the context specific way. # add pythonpath to the environment @@ -31,7 +107,18 @@ def execute(self, command): def create_compute_context() -> ComputeContext: - """Create a compute context based on the global DaCapo options.""" + """ + Create a compute context based on the global DaCapo options. + + Returns: + ComputeContext: The compute context object. + Raises: + ValueError: If the store type is unknown. + Examples: + >>> context = create_compute_context() + Note: + The method is implemented in the module. + """ options = Options.instance() diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py index 330e1899a..e70e8ae7d 100644 --- a/dacapo/compute_context/local_torch.py +++ b/dacapo/compute_context/local_torch.py @@ -15,6 +15,12 @@ class LocalTorch(ComputeContext): Attributes: _device (Optional[str]): This stores the type of device on which torch computations are to be done. It can take "cuda" for GPU or "cpu" for CPU. None value results in automatic detection of device type. + oom_limit (Optional[float | int]): The out of GPU memory to leave free in GB. If the free memory is below + this limit, we will fall back on CPU. + Methods: + device(): Returns the torch device object. + Note: + The class is a subclass of the ComputeContext class. """ _device: Optional[str] = attr.ib( @@ -37,6 +43,14 @@ def device(self): """ A property method that returns the torch device object. It automatically detects and uses "cuda" (GPU) if available, else it falls back on using "cpu". + + Returns: + torch.device: The torch device object. + Examples: + >>> context = LocalTorch() + >>> device = context.device + Note: + The method is implemented in the class. """ if self._device is None: if torch.cuda.is_available(): diff --git a/dacapo/experiments/architectures/architecture.py b/dacapo/experiments/architectures/architecture.py index 888030adb..0f188560e 100644 --- a/dacapo/experiments/architectures/architecture.py +++ b/dacapo/experiments/architectures/architecture.py @@ -11,6 +11,17 @@ class Architecture(torch.nn.Module, ABC): It is inherited from PyTorch's Module and built-in class `ABC` (Abstract Base Classes). Other classes can inherit this class to define their own specific variations of architecture. It requires to implement several property methods, and also includes additional methods related to the architecture design. + + Attributes: + input_shape (Coordinate): The spatial input shape for the neural network architecture. + eval_shape_increase (Coordinate): The amount to increase the input shape during prediction. + num_in_channels (int): The number of input channels required by the architecture. + num_out_channels (int): The number of output channels provided by the architecture. + Methods: + dims: Returns the number of dimensions of the input shape. + scale: Scales the input voxel size as required by the architecture. + Note: + The class is abstract and requires to implement the abstract methods. """ @property @@ -22,6 +33,14 @@ def input_shape(self) -> Coordinate: Returns: Coordinate: The spatial input shape. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_shape = Coordinate((128, 128, 128)) + >>> model = MyModel(input_shape) + Note: + The method should be implemented in the derived class. + """ pass @@ -32,6 +51,13 @@ def eval_shape_increase(self) -> Coordinate: Returns: Coordinate: An instance representing the amount to increase in each dimension of the input shape. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> eval_shape_increase = Coordinate((0, 0, 0)) + >>> model = MyModel(input_shape, eval_shape_increase) + Note: + The method is optional and can be overridden in the derived class. """ return Coordinate((0,) * self.input_shape.dims) @@ -43,6 +69,13 @@ def num_in_channels(self) -> int: Returns: int: Required number of input channels. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> num_in_channels = 1 + >>> model = MyModel(input_shape, num_in_channels) + Note: + The method should be implemented in the derived class. """ pass @@ -54,6 +87,14 @@ def num_out_channels(self) -> int: Returns: int: Number of output channels. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> num_out_channels = 1 + >>> model = MyModel(input_shape, num_out_channels) + Note: + The method should be implemented in the derived class. + """ pass @@ -64,6 +105,15 @@ def dims(self) -> int: Returns: int: The number of dimensions. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_shape = Coordinate((128, 128, 128)) + >>> model = MyModel(input_shape) + >>> model.dims + 3 + Note: + The method is optional and can be overridden in the derived class. """ return self.input_shape.dims @@ -73,8 +123,16 @@ def scale(self, input_voxel_size: Coordinate) -> Coordinate: Args: input_voxel_size (Coordinate): The original size of the input voxel. - Returns: Coordinate: The scaled voxel size. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_voxel_size = Coordinate((1, 1, 1)) + >>> model = MyModel(input_shape) + >>> model.scale(input_voxel_size) + Coordinate((1, 1, 1)) + Note: + The method is optional and can be overridden in the derived class. """ return input_voxel_size diff --git a/dacapo/experiments/architectures/architecture_config.py b/dacapo/experiments/architectures/architecture_config.py index 09455ce55..67ea080a2 100644 --- a/dacapo/experiments/architectures/architecture_config.py +++ b/dacapo/experiments/architectures/architecture_config.py @@ -5,18 +5,16 @@ @attr.s class ArchitectureConfig: """ - A class to represent the base configurations of any architecture. - - Attributes - ---------- - name : str - a unique name for the architecture. - - Methods - ------- - verify() - validates the given architecture. - + A class to represent the base configurations of any architecture. It is used to define the architecture of a neural network model. + + Attributes: + name : str + a unique name for the architecture. + Methods: + verify() + validates the given architecture. + Note: + The class is abstract and requires to implement the abstract methods. """ name: str = attr.ib( @@ -31,11 +29,15 @@ def verify(self) -> Tuple[bool, str]: """ A method to validate an architecture configuration. - Returns - ------- - bool - A flag indicating whether the config is valid or not. - str - A description of the architecture. + Returns: + Tuple[bool, str]: A tuple of a boolean indicating if the architecture is valid and a message. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> config = ArchitectureConfig("MyModel") + >>> is_valid, message = config.verify() + >>> print(is_valid, message) + Note: + The method should be implemented in the derived class. """ return True, "No validation for this Architecture" diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index bb2be3586..e28947479 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -7,7 +7,150 @@ class CNNectomeUNet(Architecture): + """ + A U-Net architecture for 3D or 4D data. The U-Net expects 3D or 4D tensors + shaped like:: + + ``(batch=1, channels, [length,] depth, height, width)``. + + This U-Net performs only "valid" convolutions, i.e., sizes of the feature + maps decrease after each convolution. It will perfrom 4D convolutions as + long as ``length`` is greater than 1. As soon as ``length`` is 1 due to a + valid convolution, the time dimension will be dropped and tensors with + ``(b, c, z, y, x)`` will be use (and returned) from there on. + + Attributes: + fmaps_in: + The number of input channels. + fmaps_out: + The number of feature maps in the output layer. This is also the + number of output feature maps. Stored in the ``channels`` dimension. + num_fmaps: + The number of feature maps in the first layer. This is also the + number of output feature maps. Stored in the ``channels`` dimension. + fmap_inc_factor: + By how much to multiply the number of feature maps between layers. + If layer 0 has ``k`` feature maps, layer ``l`` will have + ``k*fmap_inc_factor**l``. + downsample_factors: + List of tuples ``(z, y, x)`` to use to down- and up-sample the + feature maps between layers. + kernel_size_down (optional): + List of lists of kernel sizes. The number of sizes in a list + determines the number of convolutional layers in the corresponding + level of the build on the left side. Kernel sizes can be given as + tuples or integer. If not given, each convolutional pass will + consist of two 3x3x3 convolutions. + kernel_size_up (optional): + List of lists of kernel sizes. The number of sizes in a list + determines the number of convolutional layers in the corresponding + level of the build on the right side. Within one of the lists going + from left to right. Kernel sizes can be given as tuples or integer. + If not given, each convolutional pass will consist of two 3x3x3 + convolutions. + activation + Which activation to use after a convolution. Accepts the name of + any tensorflow activation function (e.g., ``ReLU`` for + ``torch.nn.ReLU``). + fov (optional): + Initial field of view in physical units + voxel_size (optional): + Size of a voxel in the input data, in physical units + num_heads (optional): + Number of decoders. The resulting U-Net has one single encoder + path and num_heads decoder paths. This is useful in a multi-task + learning context. + constant_upsample (optional): + If set to true, perform a constant upsampling instead of a + transposed convolution in the upsampling layers. + padding (optional): + How to pad convolutions. Either 'same' or 'valid' (default). + upsample_channel_contraction: + When performing the ConvTranspose, whether to reduce the number + of channels by the fmap_increment_factor. can be either bool or + list of bools to apply independently per layer. + activation_on_upsample: + Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + Methods: + forward(x): + Forward pass of the U-Net. + scale(voxel_size): + Scale the voxel size according to the upsampling factors. + input_shape: + Return the input shape of the U-Net. + num_in_channels: + Return the number of input channels. + num_out_channels: + Return the number of output channels. + eval_shape_increase: + Return the increase in shape due to the U-Net. + Note: + This class is a wrapper around the ``CNNectomeUNetModule`` class. + The ``CNNectomeUNetModule`` class is the actual implementation of the + U-Net architecture. + """ def __init__(self, architecture_config): + """ + Initialize the U-Net architecture. + + Args: + architecture_config (dict): A dictionary containing the configuration + of the U-Net architecture. The dictionary should contain the following + keys: + - input_shape: The shape of the input data. + - fmaps_out: The number of output feature maps. + - fmaps_in: The number of input feature maps. + - num_fmaps: The number of feature maps in the first layer. + - fmap_inc_factor: The factor by which the number of feature maps + increases between layers. + - downsample_factors: List of tuples ``(z, y, x)`` to use to down- + and up-sample the feature maps between layers. + - kernel_size_down (optional): List of lists of kernel sizes. The + number of sizes in a list determines the number of convolutional + layers in the corresponding level of the build on the left side. + Kernel sizes can be given as tuples or integer. If not given, each + convolutional pass will consist of two 3x3x3 convolutions. + - kernel_size_up (optional): List of lists of kernel sizes. The + number of sizes in a list determines the number of convolutional + layers in the corresponding level of the build on the right side. + Within one of the lists going from left to right. Kernel sizes can + be given as tuples or integer. If not given, each convolutional + pass will consist of two 3x3x3 convolutions. + - constant_upsample (optional): If set to true, perform a constant + upsampling instead of a transposed convolution in the upsampling + layers. + - padding (optional): How to pad convolutions. Either 'same' or + 'valid' (default). + - upsample_factors (optional): List of tuples ``(z, y, x)`` to use + to upsample the feature maps between layers. + - activation_on_upsample (optional): Whether or not to add an + activation after the upsample operation. + - use_attention (optional): Whether or not to use an attention block + in the U-Net. + Raises: + ValueError: If the input shape is not given. + Examples: + >>> architecture_config = { + ... "input_shape": (1, 1, 128, 128, 128), + ... "fmaps_out": 1, + ... "fmaps_in": 1, + ... "num_fmaps": 24, + ... "fmap_inc_factor": 2, + ... "downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... "kernel_size_down": [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... "kernel_size_up": [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... "constant_upsample": False, + ... "padding": "valid", + ... "upsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... "activation_on_upsample": True, + ... "use_attention": False + ... } + >>> unet = CNNectomeUNet(architecture_config) + Note: + The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ super().__init__() self._input_shape = architecture_config.input_shape @@ -31,11 +174,62 @@ def __init__(self, architecture_config): @property def eval_shape_increase(self): + """ + The increase in shape due to the U-Net. + + Returns: + The increase in shape due to the U-Net. + Raises: + AttributeError: If the increase in shape is not given. + Examples: + >>> unet.eval_shape_increase + (1, 1, 128, 128, 128) + Note: + The increase in shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ if self._eval_shape_increase is None: return super().eval_shape_increase return self._eval_shape_increase def module(self): + """ + Create the U-Net module. + + Returns: + The U-Net module. + Raises: + AttributeError: If the number of input channels is not given. + AttributeError: If the number of output channels is not given. + AttributeError: If the number of feature maps in the first layer is not given. + AttributeError: If the factor by which the number of feature maps increases between layers is not given. + AttributeError: If the downsample factors are not given. + AttributeError: If the kernel sizes for the down pass are not given. + AttributeError: If the kernel sizes for the up pass are not given. + AttributeError: If the constant upsample flag is not given. + AttributeError: If the padding is not given. + AttributeError: If the upsample factors are not given. + AttributeError: If the activation on upsample flag is not given. + AttributeError: If the use attention flag is not given. + Examples: + >>> unet.module() + CNNectomeUNetModule( + in_channels=1, + num_fmaps=24, + num_fmaps_out=1, + fmap_inc_factor=2, + kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + constant_upsample=False, + padding='valid', + activation_on_upsample=True, + upsample_channel_contraction=[False, True, True], + use_attention=False + ) + Note: + The U-Net module is an instance of the ``CNNectomeUNetModule`` class. + + """ fmaps_in = self.fmaps_in levels = len(self.downsample_factors) + 1 dims = len(self.downsample_factors[0]) @@ -91,27 +285,158 @@ def module(self): return unet def scale(self, voxel_size): + """ + Scale the voxel size according to the upsampling factors. + + Args: + voxel_size (tuple): The size of a voxel in the input data. + Returns: + The scaled voxel size. + Raises: + ValueError: If the voxel size is not given. + Examples: + >>> unet.scale((1, 1, 1)) + (1, 1, 1) + Note: + The voxel size should be given as a tuple ``(z, y, x)``. + """ for upsample_factor in self.upsample_factors: voxel_size = voxel_size / upsample_factor return voxel_size @property def input_shape(self): + """ + Return the input shape of the U-Net. + + Returns: + The input shape of the U-Net. + Raises: + AttributeError: If the input shape is not given. + Examples: + >>> unet.input_shape + (1, 1, 128, 128, 128) + Note: + The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ return self._input_shape @property def num_in_channels(self) -> int: + """ + Return the number of input channels. + + Returns: + The number of input channels. + Raises: + AttributeError: If the number of input channels is not given. + Examples: + >>> unet.num_in_channels + 1 + Note: + The number of input channels should be given as an integer. + """ return self.fmaps_in @property def num_out_channels(self) -> int: + """ + Return the number of output channels. + + Returns: + The number of output channels. + Raises: + AttributeError: If the number of output channels is not given. + Examples: + >>> unet.num_out_channels + 1 + Note: + The number of output channels should be given as an integer. + """ return self.fmaps_out def forward(self, x): + """ + Forward pass of the U-Net. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNet(architecture_config) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet(x) + Note: + The input tensor should be given as a 5D tensor. + """ return self.unet(x) class CNNectomeUNetModule(torch.nn.Module): + """ + A U-Net module for 3D or 4D data. The U-Net expects 3D or 4D tensors shaped + like:: + + ``(batch=1, channels, [length,] depth, height, width)``. + + This U-Net performs only "valid" convolutions, i.e., sizes of the feature maps + decrease after each convolution. It will perfrom 4D convolutions as long as + ``length`` is greater than 1. As soon as ``length`` is 1 due to a valid + convolution, the time dimension will be dropped and tensors with ``(b, c, z, y, x)`` + will be use (and returned) from there on. + + Attributes: + num_levels: + The number of levels in the U-Net. + num_heads: + The number of decoders. + in_channels: + The number of input channels. + out_channels: + The number of output channels. + dims: + The number of dimensions. + use_attention: + Whether or not to use an attention block in the U-Net. + l_conv: + The left convolutional passes. + l_down: + The left downsample layers. + r_up: + The right up/crop/concatenate layers. + r_conv: + The right convolutional passes. + kernel_size_down: + The kernel sizes for the down pass. + kernel_size_up: + The kernel sizes for the up pass. + fmap_inc_factor: + The factor by which the number of feature maps increases between layers. + downsample_factors: + The downsample factors. + constant_upsample: + Whether to perform a constant upsampling instead of a transposed convolution. + padding: + How to pad convolutions. + upsample_channel_contraction: + Whether to reduce the number of channels by the fmap_increment_factor. + activation_on_upsample: + Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + attention: + The attention blocks. + Methods: + rec_forward(level, f_in): + Recursive forward pass of the U-Net. + forward(x): + Forward pass of the U-Net. + Note: + The input tensor should be given as a 5D tensor. + """ def __init__( self, in_channels, @@ -129,7 +454,8 @@ def __init__( activation_on_upsample=False, use_attention=False, ): - """Create a U-Net:: + """ + Create a U-Net:: f_in --> f_left --------------------------->> f_right--> f_out | ^ @@ -155,83 +481,80 @@ def __init__( from there on. Args: - in_channels: - The number of input channels. - num_fmaps: - The number of feature maps in the first layer. This is also the number of output feature maps. Stored in the ``channels`` dimension. - fmap_inc_factor: - By how much to multiply the number of feature maps between layers. If layer 0 has ``k`` feature maps, layer ``l`` will have ``k*fmap_inc_factor**l``. - downsample_factors: - List of tuples ``(z, y, x)`` to use to down- and up-sample the feature maps between layers. - kernel_size_down (optional): - List of lists of kernel sizes. The number of sizes in a list determines the number of convolutional layers in the corresponding level of the build on the left side. Kernel sizes can be given as tuples or integer. If not given, each convolutional pass will consist of two 3x3x3 convolutions. - kernel_size_up (optional): - List of lists of kernel sizes. The number of sizes in a list determines the number of convolutional layers in the corresponding level of the build on the right side. Within one of the lists going from left to right. Kernel sizes can be given as tuples or integer. If not given, each convolutional pass will consist of two 3x3x3 convolutions. - activation: - Which activation to use after a convolution. Accepts the name of any tensorflow activation function (e.g., ``ReLU`` for ``torch.nn.ReLU``). - fov (optional): - Initial field of view in physical units - voxel_size (optional): - Size of a voxel in the input data, in physical units - num_heads (optional): - Number of decoders. The resulting U-Net has one single encoder path and num_heads decoder paths. This is useful in a multi-task learning context. - constant_upsample (optional): - If set to true, perform a constant upsampling instead of a transposed convolution in the upsampling layers. - padding (optional): - How to pad convolutions. Either 'same' or 'valid' (default). - upsample_channel_contraction: - When performing the ConvTranspose, whether to reduce the number of channels by the fmap_increment_factor. can be either bool or list of bools to apply independently per layer. - activation_on_upsample: - Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + attention: + The attention blocks. + Returns: + The U-Net module. + Raises: + ValueError: If the number of input channels is not given. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + Note: + The input tensor should be given as a 5D tensor. """ super().__init__() @@ -378,6 +701,36 @@ def __init__( ) def rec_forward(self, level, f_in): + """ + Recursive forward pass of the U-Net. + + Args: + level (int): The level of the U-Net. + f_in (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet.rec_forward(2, x) + Note: + The input tensor should be given as a 5D tensor. + """ # index of level in layer arrays i = self.num_levels - level - 1 @@ -415,6 +768,35 @@ def rec_forward(self, level, f_in): return fs_out def forward(self, x): + """ + Forward pass of the U-Net. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet(x) + Note: + The input tensor should be given as a 5D tensor. + """ y = self.rec_forward(self.num_levels - 1, x) if self.num_heads == 1: @@ -424,9 +806,44 @@ def forward(self, x): class ConvPass(torch.nn.Module): + """ + Convolutional pass module. This module performs a series of convolutional + layers followed by an activation function. The module can also pad the + feature maps to ensure translation equivariance. The module can perform + 2D or 3D convolutions. + + Attributes: + dims: + The number of dimensions. + conv_pass: + The convolutional pass module. + Methods: + forward(x): + Forward pass of the Conv + Note: + The input tensor should be given as a 5D tensor. + """ def __init__( self, in_channels, out_channels, kernel_sizes, activation, padding="valid" ): + """ + Convolutional pass module. This module performs a series of + convolutional layers followed by an activation function. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_sizes (list): The kernel sizes for the convolutional layers. + activation (str): The activation function to use. + padding (optional): How to pad convolutions. Either 'same' or 'valid'. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> conv_pass = ConvPass(1, 1, [(3, 3, 3), (3, 3, 3)], "ReLU") + Note: + The input tensor should be given as a 5D tensor. + + """ super(ConvPass, self).__init__() if activation is not None: @@ -460,11 +877,60 @@ def __init__( self.conv_pass = torch.nn.Sequential(*layers) def forward(self, x): + """ + Forward pass of the ConvPass module. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> conv_pass = ConvPass(1, 1, [(3, 3, 3), (3, 3, 3)], "ReLU") + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> conv_pass(x) + Note: + The input tensor should be given as a 5D tensor. + """ return self.conv_pass(x) class Downsample(torch.nn.Module): + """ + Downsample module. This module performs downsampling of the input tensor + using either max-pooling or average pooling. The module can also crop the + feature maps to ensure translation equivariance with a stride of the + downsampling factor. + + Attributes: + dims: + The number of dimensions. + downsample_factor: + The downsampling factor. + down: + The downsampling layer. + Methods: + forward(x): + Downsample the input tensor. + Note: + The input tensor should be given as a 5D tensor. + + """ def __init__(self, downsample_factor): + """ + Downsample module. This module performs downsampling of the input tensor + using either max-pooling or average pooling. + + Args: + downsample_factor (tuple): The downsampling factor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> downsample = Downsample((2, 2, 2)) + Note: + The input tensor should be given as a 5D tensor. + """ super(Downsample, self).__init__() self.dims = len(downsample_factor) @@ -479,6 +945,22 @@ def __init__(self, downsample_factor): self.down = pool(downsample_factor, stride=downsample_factor) def forward(self, x): + """ + Downsample the input tensor. + + Args: + x (Tensor): The input tensor. + Returns: + The downsampled tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> downsample = Downsample((2, 2, 2)) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> downsample(x) + Note: + The input tensor should be given as a 5D tensor. + """ for d in range(1, self.dims + 1): if x.size()[-d] % self.downsample_factor[-d] != 0: raise RuntimeError( @@ -491,6 +973,33 @@ def forward(self, x): class Upsample(torch.nn.Module): + """ + Upsample module. This module performs upsampling of the input tensor using + either transposed convolutions or nearest neighbor interpolation. The + module can also crop the feature maps to ensure translation equivariance + with a stride of the upsampling factor. + + Attributes: + crop_factor: + The crop factor. + next_conv_kernel_sizes: + The kernel sizes for the convolutional layers. + dims: + The number of dimensions. + up: + The upsampling layer. + Methods: + crop_to_factor(x, factor, kernel_sizes): + Crop feature maps to ensure translation equivariance with stride of + upsampling factor. + crop(x, shape): + Center-crop x to match spatial dimensions given by shape. + forward(g_out, f_left=None): + Forward pass of the Upsample module. + Note: + The input tensor should be given as a 5D tensor. + + """ def __init__( self, scale_factor, @@ -501,6 +1010,27 @@ def __init__( next_conv_kernel_sizes=None, activation=None, ): + """ + Upsample module. This module performs upsampling of the input tensor + + Args: + scale_factor (tuple): The upsampling factor. + mode (optional): The upsampling mode. Either 'transposed_conv' or + 'nearest + in_channels (optional): The number of input channels. + out_channels (optional): The number of output channels. + crop_factor (optional): The crop factor. + next_conv_kernel_sizes (optional): The kernel sizes for the convolutional layers. + activation (optional): The activation function to use. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1, activation="ReLU") + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1, crop_factor=(2, 2, 2), next_conv_kernel_sizes=[(3, 3, 3), (3, 3, 3)]) + Note: + The input tensor should be given as a 5D tensor. + """ super(Upsample, self).__init__() if activation is not None: @@ -548,12 +1078,50 @@ def __init__( self.up = layers[0] def crop_to_factor(self, x, factor, kernel_sizes): - """Crop feature maps to ensure translation equivariance with stride of + """ + Crop feature maps to ensure translation equivariance with stride of upsampling factor. This should be done right after upsampling, before application of the convolutions with the given kernel sizes. The crop could be done after the convolutions, but it is more efficient to do that before (feature maps will be smaller). + + We need to ensure that the feature map is large enough to ensure that + the translation equivariance is maintained. This is done by cropping + the feature map to the largest size that is a multiple of the factor + and that is large enough to ensure that the translation equivariance + is maintained. + + We need (spatial_shape - convolution_crop) to be a multiple of factor, + i.e.: + (s - c) = n*k + + where s is the spatial size of the feature map, c is the crop due to + the convolutions, n is the number of strides of the upsampling factor, + and k is the upsampling factor. + + We want to find the largest n for which s' = n*k + c <= s + + n = floor((s - c)/k) + + This gives us the target shape s' + + s' = n*k + c + + Args: + x (Tensor): The input tensor. + factor (tuple): The upsampling factor. + kernel_sizes (list): The kernel sizes for the convolutional layers. + Returns: + The cropped tensor. + Raises: + RuntimeError: If the feature map is too small to ensure translation equivariance. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> upsample.crop_to_factor(x, (2, 2, 2), [(3, 3, 3), (3, 3, 3)]) + Note: + The input tensor should be given as a 5D tensor. """ shape = x.size() @@ -599,7 +1167,23 @@ def crop_to_factor(self, x, factor, kernel_sizes): return x def crop(self, x, shape): - """Center-crop x to match spatial dimensions given by shape.""" + """ + Center-crop x to match spatial dimensions given by shape. + + Args: + x (Tensor): The input tensor. + shape (tuple): The target shape. + Returns: + The center-cropped tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> upsample.crop(x, (32, 32, 32)) + Note: + The input tensor should be given as a 5D tensor. + """ x_target_size = x.size()[: -self.dims] + shape @@ -610,6 +1194,24 @@ def crop(self, x, shape): return x[slices] def forward(self, g_out, f_left=None): + """ + Forward pass of the Upsample module. + + Args: + g_out (Tensor): The gating signal tensor. + f_left (Tensor): The input feature tensor. + Returns: + The output feature tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> g_out = torch.randn(1, 1, 64, 64, 64) + >>> f_left = torch.randn(1, 1, 32, 32, 32) + >>> upsample(g_out, f_left) + Note: + The gating signal and input feature tensors should be given as 5D tensors. + """ g_up = self.up(g_out) if self.next_conv_kernel_sizes is not None: @@ -628,41 +1230,71 @@ def forward(self, g_out, f_left=None): class AttentionBlockModule(nn.Module): - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): - """Attention Block Module:: + """ + Attention Block Module: + + The AttentionBlock uses two separate pathways to process 'g' and 'x', + combines them, and applies a sigmoid activation to generate an attention map. + This map is then used to scale the input features 'x', resulting in an output + that focuses on important features as dictated by the gating signal 'g'. - The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). [g] --> W_g --\ /--> psi --> * --> [output] \ / [x] --> W_x --> [+] --> relu -- Where: - - W_g and W_x are 1x1 Convolution followed by Batch Normalization - - [+] indicates element-wise addition - - relu is the Rectified Linear Unit activation function - - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation - - * indicates element-wise multiplication between the output of psi and input feature 'x' - - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + + Attributes: + dims: + The number of dimensions of the input tensors. + kernel_sizes: + The kernel sizes for the convolutional layers. + upsample_factor: + The factor by which to upsample the attention map. + W_g: + The 1x1 Convolutional layer for the gating signal. + W_x: + The 1x1 Convolutional layer for the input features. + psi: + The 1x1 Convolutional layer followed by Sigmoid activation. + up: + The upsampling layer to match the dimensions of the input features. + relu: + The Rectified Linear Unit activation function. + Methods: + calculate_and_apply_padding(smaller_tensor, larger_tensor): + Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor. + forward(g, x): + Forward pass of the Attention Block. + Note: + The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. + """ + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): + """ + Initialize the Attention Block Module. Args: - F_g (int): The number of feature channels in the gating signal (g). - This is the input channel dimension for the W_g convolutional layer. - - F_l (int): The number of feature channels in the input features (x). - This is the input channel dimension for the W_x convolutional layer. - - F_int (int): The number of intermediate feature channels. - This represents the output channel dimension of the W_g and W_x convolutional layers - and the input channel dimension for the psi layer. Typically, F_int is smaller - than F_g and F_l, as it serves to compress the feature representations before - applying the attention mechanism. - - The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, - and applies a sigmoid activation to generate an attention map. This map is then used - to scale the input features 'x', resulting in an output that focuses on important - features as dictated by the gating signal 'g'. - + F_g (int): The number of feature maps in the gating signal tensor. + F_l (int): The number of feature maps in the input feature tensor. + F_int (int): The number of feature maps in the intermediate tensor. + dims (int): The number of dimensions of the input tensors. + upsample_factor (optional): The factor by which to upsample the attention map. + Returns: + The Attention Block Module. + Raises: + RuntimeError: If the gating signal and input feature tensors have different dimensions. + Examples: + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + Note: + The number of feature maps should be given as an integer. """ super(AttentionBlockModule, self).__init__() @@ -709,11 +1341,19 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor. Args: - smaller_tensor (Tensor): The tensor to be padded. - larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match. - + smaller_tensor (Tensor): The tensor to be padded. + larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match. Returns: - Tensor: The padded smaller tensor with the same dimensions as the larger tensor. + Tensor: The padded smaller tensor with the same dimensions as the larger tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> larger_tensor = torch.randn(1, 1, 128, 128, 128) + >>> smaller_tensor = torch.randn(1, 1, 64, 64, 64) + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + >>> padded_tensor = attention_block.calculate_and_apply_padding(smaller_tensor, larger_tensor) + Note: + The tensors should have the same dimensions. """ padding = [] for i in range(2, 2 + self.dims): @@ -727,6 +1367,24 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0) def forward(self, g, x): + """ + Forward pass of the Attention Block. + + Args: + g (Tensor): The gating signal tensor. + x (Tensor): The input feature tensor. + Returns: + Tensor: The output tensor with the same dimensions as the input feature tensor. + Raises: + RuntimeError: If the gating signal and input feature tensors have different dimensions. + Examples: + >>> g = torch.randn(1, 1, 128, 128, 128) + >>> x = torch.randn(1, 1, 128, 128, 128) + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + >>> output = attention_block(g, x) + Note: + The gating signal and input feature tensors should have the same dimensions. + """ g1 = self.W_g(g) x1 = self.W_x(x) g1 = self.calculate_and_apply_padding(g1, x1) @@ -734,3 +1392,4 @@ def forward(self, g, x): psi = self.psi(psi) psi = self.up(psi) return x * psi + diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index c0e9e5b9d..be59e7069 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -10,10 +10,49 @@ @attr.s class CNNectomeUNetConfig(ArchitectureConfig): - """This class configures the CNNectomeUNet based on + """ + This class configures the CNNectomeUNet based on https://github.com/saalfeldlab/CNNectome/blob/master/CNNectome/networks/unet_class.py Includes support for super resolution via the upsampling factors. + + Attributes: + input_shape: Coordinate + The shape of the data passed into the network during training. + fmaps_out: int + The number of channels produced by your architecture. + fmaps_in: int + The number of channels expected from the raw data. + num_fmaps: int + The number of feature maps in the top level of the UNet. + fmap_inc_factor: int + The multiplication factor for the number of feature maps for each level of the UNet. + downsample_factors: List[Coordinate] + The factors to downsample the feature maps along each axis per layer. + kernel_size_down: Optional[List[Coordinate]] + The size of the convolutional kernels used before downsampling in each layer. + kernel_size_up: Optional[List[Coordinate]] + The size of the convolutional kernels used before upsampling in each layer. + _eval_shape_increase: Optional[Coordinate] + The amount by which to increase the input size when just prediction rather than training. + It is generally possible to significantly increase the input size since we don't have the memory + constraints of the gradients, the optimizer and the batch size. + upsample_factors: Optional[List[Coordinate]] + The amount by which to upsample the output of the UNet. + constant_upsample: bool + Whether to use a transpose convolution or simply copy voxels to upsample. + padding: str + The padding to use in convolution operations. + use_attention: bool + Whether to use attention blocks in the UNet. This is supported for 2D and 3D. + Methods: + architecture_type() + Returns the architecture type. + Note: + The architecture_type attribute is set to CNNectomeUNet. + References: + Saalfeld, S., Fetter, R., Cardona, A., & Tomancak, P. (2012). + """ architecture_type = CNNectomeUNet diff --git a/dacapo/experiments/architectures/dummy_architecture.py b/dacapo/experiments/architectures/dummy_architecture.py index 70a0d5d3e..fa5a889e7 100644 --- a/dacapo/experiments/architectures/dummy_architecture.py +++ b/dacapo/experiments/architectures/dummy_architecture.py @@ -12,15 +12,27 @@ class DummyArchitecture(Architecture): channels_out: An integer representing the number of output channels. conv: A 3D convolution object. input_shape: A coordinate object representing the shape of the input. - Methods: forward(x): Performs the forward pass of the network. + num_in_channels(): Returns the number of input channels for this architecture. + num_out_channels(): Returns the number of output channels for this architecture. + Note: + This class is used to represent a dummy architecture layer for a 3D CNN. """ def __init__(self, architecture_config): """ + Constructor for the DummyArchitecture class. Initializes the 3D convolution object. + Args: - architecture_config: An object containing the configuration settings for the architecture. + architecture_config: An architecture configuration object. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> architecture_config = ArchitectureConfig(num_in_channels=1, num_out_channels=1) + >>> dummy_architecture = DummyArchitecture(architecture_config) + Note: + This method is used to initialize the DummyArchitecture class. """ super().__init__() @@ -36,6 +48,13 @@ def input_shape(self): Returns: Coordinate: Input shape of the architecture. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.input_shape + Coordinate(x=40, y=20, z=20) + Note: + This method is used to return the input shape for this architecture. """ return Coordinate(40, 20, 20) @@ -46,6 +65,13 @@ def num_in_channels(self): Returns: int: Number of input channels. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.num_in_channels + 1 + Note: + This method is used to return the number of input channels for this architecture. """ return self.channels_in @@ -56,6 +82,13 @@ def num_out_channels(self): Returns: int: Number of output channels. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.num_out_channels + 1 + Note: + This method is used to return the number of output channels for this architecture. """ return self.channels_out @@ -65,8 +98,15 @@ def forward(self, x): Args: x: Input tensor. - Returns: Tensor: Output tensor after the forward pass. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture = DummyArchitecture(architecture_config) + >>> x = torch.randn(1, 1, 40, 20, 20) + >>> dummy_architecture.forward(x) + Note: + This method is used to perform the forward pass of the network. """ return self.conv(x) diff --git a/dacapo/experiments/architectures/dummy_architecture_config.py b/dacapo/experiments/architectures/dummy_architecture_config.py index eaf9b7027..695d8bc41 100644 --- a/dacapo/experiments/architectures/dummy_architecture_config.py +++ b/dacapo/experiments/architectures/dummy_architecture_config.py @@ -8,7 +8,8 @@ @attr.s class DummyArchitectureConfig(ArchitectureConfig): - """A dummy architecture configuration class used for testing purposes. + """ + A dummy architecture configuration class used for testing purposes. It extends the base class "ArchitectureConfig". This class contains dummy attributes and always returns that the configuration is invalid when verified. @@ -20,6 +21,10 @@ class DummyArchitectureConfig(ArchitectureConfig): functionality or meaning. num_out_channels (int): The number of output channels. This is also a dummy attribute and has no real functionality or meaning. + Methods: + verify(self) -> Tuple[bool, str]: This method is used to check whether this is a valid architecture configuration. + Note: + This class is used to represent a DummyArchitectureConfig object in the system. """ architecture_type = DummyArchitecture @@ -29,13 +34,22 @@ class DummyArchitectureConfig(ArchitectureConfig): num_out_channels: int = attr.ib(metadata={"help_text": "Dummy attribute."}) def verify(self) -> Tuple[bool, str]: - """Verifies the configuration validity. + """ + Verifies the configuration validity. Since this is a dummy configuration for testing purposes, this method always returns False indicating that the configuration is invalid. Returns: tuple: A tuple containing a boolean validity flag and a reason message string. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture_config = DummyArchitectureConfig(num_in_channels=1, num_out_channels=1) + >>> dummy_architecture_config.verify() + (False, "This is a DummyArchitectureConfig and is never valid") + Note: + This method is used to check whether this is a valid architecture configuration. """ return False, "This is a DummyArchitectureConfig and is never valid" diff --git a/dacapo/experiments/arraytypes/annotations.py b/dacapo/experiments/arraytypes/annotations.py index f7fc2f9b1..f90d9dd09 100644 --- a/dacapo/experiments/arraytypes/annotations.py +++ b/dacapo/experiments/arraytypes/annotations.py @@ -8,7 +8,15 @@ class AnnotationArray(ArrayType): """ An AnnotationArray is a uint8, uint16, uint32 or uint64 Array where each - voxel has a value associated with its class. + voxel has a value associated with its class. The class of each voxel can be + determined by simply taking the value. + + Attributes: + classes (Dict[int, str]): A mapping from class label to class name. + Methods: + interpolatable(self) -> bool: It is a method that returns False. + Note: + This class is used to create an AnnotationArray object which is used to represent an array of class labels. """ classes: Dict[int, str] = attr.ib( @@ -20,4 +28,20 @@ class AnnotationArray(ArrayType): @property def interpolatable(self): + """ + Method to return False. + + Returns: + bool + Returns a boolean value of False representing that the values are not interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> annotation_array = AnnotationArray(classes={1: "mitochondria", 2: "membrane"}) + >>> annotation_array.interpolatable + False + Note: + This method is used to check if the array is interpolatable. + """ return False diff --git a/dacapo/experiments/arraytypes/arraytype.py b/dacapo/experiments/arraytypes/arraytype.py index 0dce23ec0..c4ec2f050 100644 --- a/dacapo/experiments/arraytypes/arraytype.py +++ b/dacapo/experiments/arraytypes/arraytype.py @@ -8,7 +8,16 @@ class ArrayType(ABC): track of the semantic meaning of an Array. Additionally the ArrayType keeps track of metadata that is specific to this datatype such as num_classes for an annotated volume or channel names for intensity - arrays. + arrays. The ArrayType class is an abstract class and should be subclassed + to represent different types of arrays. + + Attributes: + num_classes (int): The number of classes in the array. + channel_names (List[str]): The names of the channels in the array. + Methods: + interpolatable: This is an abstract method which should be overridden in each of the subclasses to determine if an array is interpolatable or not. + Note: + This class is used to create an ArrayType object which is used to represent the type of data provided by an array. """ @property @@ -20,5 +29,13 @@ def interpolatable(self) -> bool: Returns: bool: True if the array is interpolatable, False otherwise. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> array_type = ArrayType() + >>> array_type.interpolatable + NotImplementedError + Note: + This method is used to check if the array is interpolatable. """ pass diff --git a/dacapo/experiments/arraytypes/binary.py b/dacapo/experiments/arraytypes/binary.py index e6c57faeb..dcc95b109 100644 --- a/dacapo/experiments/arraytypes/binary.py +++ b/dacapo/experiments/arraytypes/binary.py @@ -9,16 +9,14 @@ class BinaryArray(ArrayType): """ A subclass of ArrayType representing BinaryArray. The BinaryArray object is created with two attributes; channels. - Each voxel in this array is either 1 or 0. + Each voxel in this array is either 1 or 0. The class of each voxel can be determined by simply taking the argmax. Attributes: channels (Dict[int, str]): A dictionary attribute representing channel mapping with its binary classification. - - Args: - channels (Dict[int, str]): A dictionary input where keys are channel numbers and values are their corresponding class for binary classification. - Methods: interpolatable: Returns False as binary array type is not interpolatable. + Note: + This class is used to represent a BinaryArray object in the system. """ channels: Dict[int, str] = attr.ib( @@ -34,5 +32,13 @@ def interpolatable(self) -> bool: Returns: bool: Always returns False because interpolation is not possible. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> binary_array = BinaryArray(channels={1: "class1"}) + >>> binary_array.interpolatable + False + Note: + This method is used to check if the array is interpolatable. """ return False diff --git a/dacapo/experiments/arraytypes/distances.py b/dacapo/experiments/arraytypes/distances.py index 057f8f1b2..589e60d38 100644 --- a/dacapo/experiments/arraytypes/distances.py +++ b/dacapo/experiments/arraytypes/distances.py @@ -9,7 +9,17 @@ class DistanceArray(ArrayType): """ An array containing signed distances to the nearest boundary voxel for a particular label class. - Distances should be positive outside an object and negative inside an object. + Distances should be positive outside an object and negative inside an object. The distance should be 0 on the boundary. + The class of each voxel can be determined by simply taking the argmin. The distance should be in the range [-max, max]. + + Attributes: + classes (Dict[int, str]): A mapping from channel to class on which distances were calculated. + max (float): The maximum possible distance value of your distances. + Methods: + interpolatable(self) -> bool: It is a method that returns True. + Note: + This class is used to create a DistanceArray object which is used to represent an array containing signed distances to the nearest boundary voxel for a particular label class. + The class of each voxel can be determined by simply taking the argmin. """ classes: Dict[int, str] = attr.ib( @@ -20,4 +30,18 @@ class DistanceArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Checks if the array is interpolatable. Returns True for this class. + + Returns: + bool: True indicating that the data can be interpolated. + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> distance_array = DistanceArray(classes={1: "class1"}) + >>> distance_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/embedding.py b/dacapo/experiments/arraytypes/embedding.py index 81fcadce3..ed751ca59 100644 --- a/dacapo/experiments/arraytypes/embedding.py +++ b/dacapo/experiments/arraytypes/embedding.py @@ -7,7 +7,16 @@ class EmbeddingArray(ArrayType): """ A generic output of a model that could represent almost anything. Assumed to be - float, interpolatable, and have sum number of channels. + float, interpolatable, and have sum number of channels. The channels are not + specified, and the array can be of any shape. + + Attributes: + embedding_dims (int): The dimension of your embedding. + Methods: + interpolatable(): + It is a method that returns True. + Note: + This class is used to represent an EmbeddingArray object in the system. """ embedding_dims: int = attr.ib( @@ -16,4 +25,20 @@ class EmbeddingArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Method to return True. + + Returns: + bool + Returns a boolean value of True representing that the values are interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> embedding_array = EmbeddingArray(embedding_dims=10) + >>> embedding_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/intensities.py b/dacapo/experiments/arraytypes/intensities.py index 84cf9227d..6cc74e96c 100644 --- a/dacapo/experiments/arraytypes/intensities.py +++ b/dacapo/experiments/arraytypes/intensities.py @@ -9,7 +9,17 @@ @attr.s class IntensitiesArray(ArrayType): """ - An IntensitiesArray is an Array of measured intensities. + An IntensitiesArray is an Array of measured intensities. Each voxel has a value in the range [min, max]. + + Attributes: + channels (Dict[int, str]): A mapping from channel to a name describing that channel. + min (float): The minimum possible value of your intensities. + max (float): The maximum possible value of your intensities. + Methods: + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. + interpolatable(self) -> bool: It is a method that returns True. + Note: + This class is used to create an IntensitiesArray object which is used to represent an array of measured intensities. """ channels: Dict[int, str] = attr.ib( @@ -26,4 +36,20 @@ class IntensitiesArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Method to return True. + + Returns: + bool + Returns a boolean value of True representing that the values are interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> intensities_array = IntensitiesArray(channels={1: "channel1"}, min=0, max=1) + >>> intensities_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/mask.py b/dacapo/experiments/arraytypes/mask.py index 7f188ca73..cf2a04eaf 100644 --- a/dacapo/experiments/arraytypes/mask.py +++ b/dacapo/experiments/arraytypes/mask.py @@ -8,10 +8,11 @@ class Mask(ArrayType): """ A class that inherits the ArrayType class. This is a representation of a Mask in the system. - Methods - ------- - interpolatable(): - It is a method that returns False. + Methods: + interpolatable(): + It is a method that returns False. + Note: + This class is used to represent a Mask object in the system. """ @property @@ -19,9 +20,17 @@ def interpolatable(self) -> bool: """ Method to return False. - Returns - ------ - bool - Returns a boolean value of False representing that the values are not interpolatable. + Returns: + bool + Returns a boolean value of False representing that the values are not interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> mask = Mask() + >>> mask.interpolatable + False + Note: + This method is used to check if the array is interpolatable. """ return False diff --git a/dacapo/experiments/arraytypes/probabilities.py b/dacapo/experiments/arraytypes/probabilities.py index e6510190f..d237aa601 100644 --- a/dacapo/experiments/arraytypes/probabilities.py +++ b/dacapo/experiments/arraytypes/probabilities.py @@ -14,6 +14,9 @@ class ProbabilityArray(ArrayType): Attributes: classes (List[str]): A mapping from channel to class on which distances were calculated. + Note: + This class is used to create a ProbabilityArray object which is used to represent an array containing probability distributions for each voxel pointed by its coordinate. + The class of each voxel can be determined by simply taking the argmax. """ classes: List[str] = attr.ib( @@ -29,5 +32,13 @@ def interpolatable(self) -> bool: Returns: bool: True indicating that the data can be interpolated. + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> probability_array = ProbabilityArray(classes=["class1", "class2"]) + >>> probability_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. """ return True diff --git a/dacapo/experiments/datasplits/datasets/arrays/array.py b/dacapo/experiments/datasplits/datasets/arrays/array.py index 37479e6af..f9d8322c8 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/array.py @@ -7,48 +7,168 @@ class Array(ABC): + """ + An Array is a multi-dimensional array of data that can be read from and written to. It is + defined by a region of interest (ROI) in world units, a voxel size, and a number of spatial + dimensions. The data is stored in a numpy array, and can be accessed using numpy-like slicing + syntax. + + The Array class is an abstract base class that defines the interface for all Array + implementations. It provides a number of properties that must be implemented by subclasses, + such as the ROI, voxel size, and data type of the array. It also provides a method for fetching + data from the array, which is implemented by slicing the numpy array. + + The Array class also provides a method for checking if the array can be visualized in + Neuroglancer, and a method for generating a Neuroglancer layer for the array. These methods are + implemented by subclasses that support visualization in Neuroglancer. + + Attributes: + attrs (Dict[str, Any]): A dictionary of metadata attributes stored on this array. + axes (List[str]): The axes of this dataset as a string of characters, as they are indexed. + Permitted characters are: + * ``zyx`` for spatial dimensions + * ``c`` for channels + * ``s`` for samples + dims (int): The number of spatial dimensions. + voxel_size (Coordinate): The size of a voxel in physical units. + roi (Roi): The total ROI of this array, in world units. + dtype (Any): The dtype of this array, in numpy dtypes + num_channels (Optional[int]): The number of channels provided by this dataset. Should return + None if the channel dimension doesn't exist. + data (np.ndarray): A numpy-like readable and writable view into this array. + writable (bool): Can we write to this Array? + Methods: + __getitem__(self, roi: Roi) -> np.ndarray: Get a numpy like readable and writable view into + this array. + _can_neuroglance(self) -> bool: Check if this array can be visualized in Neuroglancer. + _neuroglancer_layer(self): Generate a Neuroglancer layer for this array. + _slices(self, roi: Roi) -> Iterable[slice]: Generate a list of slices for the given ROI. + Note: + This class is used to define the interface for all Array implementations. It provides a + number of properties that must be implemented by subclasses, such as the ROI, voxel size, and + data type of the array. It also provides a method for fetching data from the array, which is + implemented by slicing the numpy array. The Array class also provides a method for checking + if the array can be visualized in Neuroglancer, and a method for generating a Neuroglancer + layer for the array. These methods are implemented by subclasses that support visualization + in Neuroglancer. + """ @property @abstractmethod def attrs(self) -> Dict[str, Any]: """ Return a dictionary of metadata attributes stored on this array. + + Returns: + Dict[str, Any]: A dictionary of metadata attributes stored on this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.attrs + {} + Note: + This method must be implemented by the subclass. """ pass @property @abstractmethod def axes(self) -> List[str]: - """Returns the axes of this dataset as a string of charactes, as they + """ + Returns the axes of this dataset as a string of charactes, as they are indexed. Permitted characters are: * ``zyx`` for spatial dimensions * ``c`` for channels * ``s`` for samples + + Returns: + List[str]: The axes of this dataset as a string of characters, as they are indexed. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.axes + ['z', 'y', 'x'] + Note: + This method must be implemented by the subclass. """ pass @property @abstractmethod def dims(self) -> int: - """Returns the number of spatial dimensions.""" + """ + Returns the number of spatial dimensions. + + Returns: + int: The number of spatial dimensions. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.dims + 3 + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def voxel_size(self) -> Coordinate: - """The size of a voxel in physical units.""" + """ + The size of a voxel in physical units. + + Returns: + Coordinate: The size of a voxel in physical units. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.voxel_size + Coordinate((1, 1, 1)) + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def roi(self) -> Roi: - """The total ROI of this array, in world units.""" + """ + The total ROI of this array, in world units. + + Returns: + Roi: The total ROI of this array, in world units. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.roi + Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def dtype(self) -> Any: - """The dtype of this array, in numpy dtypes""" + """ + The dtype of this array, in numpy dtypes + + Returns: + Any: The dtype of this array, in numpy dtypes. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.dtype + np.dtype('uint8') + Note: + This method must be implemented by the subclass. + """ pass @property @@ -57,6 +177,17 @@ def num_channels(self) -> Optional[int]: """ The number of channels provided by this dataset. Should return None if the channel dimension doesn't exist. + + Returns: + Optional[int]: The number of channels provided by this dataset. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.num_channels + 1 + Note: + This method must be implemented by the subclass. """ pass @@ -65,6 +196,17 @@ def num_channels(self) -> Optional[int]: def data(self) -> np.ndarray: """ Get a numpy like readable and writable view into this array. + + Returns: + np.ndarray: A numpy like readable and writable view into this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.data + np.ndarray + Note: + This method must be implemented by the subclass. """ pass @@ -73,10 +215,38 @@ def data(self) -> np.ndarray: def writable(self) -> bool: """ Can we write to this Array? + + Returns: + bool: Can we write to this Array? + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.writable + False + Note: + This method must be implemented by the subclass. """ pass def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get a numpy like readable and writable view into this array. + + Args: + roi (Roi): The region of interest to fetch data from. + Returns: + np.ndarray: A numpy like readable and writable view into this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> roi = Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + >>> array[roi] + np.ndarray + Note: + This method must be implemented by the subclass. + """ if not self.roi.contains(roi): raise ValueError(f"Cannot fetch data from outside my roi: {self.roi}!") @@ -92,12 +262,53 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return self.data[slices] def _can_neuroglance(self) -> bool: + """ + Check if this array can be visualized in Neuroglancer. + + Returns: + bool: Whether this array can be visualized in Neuroglancer. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array._can_neuroglance() + False + Note: + This method must be implemented by the subclass. + """ return False def _neuroglancer_layer(self): + """ + Generate a Neuroglancer layer for this array. + + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array._neuroglancer_layer() + NotImplementedError + Note: + This method must be implemented by the subclass. + """ pass def _slices(self, roi: Roi) -> Iterable[slice]: + """ + Generate a list of slices for the given ROI. + + Args: + roi (Roi): The region of interest to generate slices for. + Returns: + Iterable[slice]: A list of slices for the given ROI. + Examples: + >>> array = Array() + >>> roi = Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + >>> array._slices(roi) + [slice(None, None, None), slice(None, None, None), slice(None, None, None)] + Note: + This method must be implemented by the subclass. + """ offset = (roi.offset - self.roi.offset) / self.voxel_size shape = roi.shape / self.voxel_size spatial_slices: Dict[str, slice] = { diff --git a/dacapo/experiments/datasplits/datasets/arrays/array_config.py b/dacapo/experiments/datasplits/datasets/arrays/array_config.py index 0642cbb52..a8e51dfd2 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/array_config.py @@ -5,9 +5,22 @@ @attr.s class ArrayConfig: - """Base class for array configurations. Each subclass of an + """ + Base class for array configurations. Each subclass of an `Array` should have a corresponding config class derived from - `ArrayConfig`. + `ArrayConfig`. This class should be used to store the configuration + of the array. + + Attributes: + name (str): A unique name for this array. This will be saved so you + and others can find and reuse this array. Keep it short + and avoid special characters. + Methods: + verify(self) -> Tuple[bool, str]: This method is used to check whether this is a valid Array. + Note: + This class is used to create a base class for array configurations. Each subclass of an + `Array` should have a corresponding config class derived from `ArrayConfig`. + This class should be used to store the configuration of the array. """ name: str = attr.ib( @@ -21,5 +34,18 @@ class ArrayConfig: def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: A tuple with the first element being a boolean + indicating whether the array is valid and the second element being + a string with a message explaining why the array is invalid + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> array_config = ArrayConfig(name="array_config") + >>> array_config.verify() + (True, "No validation for this Array") + Note: + This method is used to check whether this is a valid Array. """ return True, "No validation for this Array" diff --git a/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py b/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py index 791c1051c..4307bdd50 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py @@ -13,19 +13,50 @@ class BinarizeArray(Array): Because we often want to predict classes that are a combination of a set of labels we wrap a ZarrArray with the BinarizeArray and provide something like `groupings=[("mito", [3,4,5])]` - where 4 corresponds to mito_membrane, 5 is mito_ribos, and - 3 is everything else that is part of a mitochondria. The BinarizeArray - will simply combine labels 3,4,5 into a single binary channel for th - class of "mito". + where 4 corresponds to mito_mem (mitochondria membrane), 5 is mito_ribo + (mitochondria ribosomes), and 3 is everything else that is part of a + mitochondria. The BinarizeArray will simply combine labels 3,4,5 into + a single binary channel for the class of "mito". + We use a single channel per class because some classes may overlap. For example if you had `groupings=[("mito", [3,4,5]), ("membrane", [4, 8, 1])]` - where 4 is mito_membrane, 8 is er_membrane, and 1 is plasma_membrane. + where 4 is mito_mem, 8 is er_mem (ER membrane), and 1 is pm (plasma membrane). Now you can have a binary classification for membrane or not which in some cases overlaps with the channel for mitochondria which includes the mito membrane. + + Attributes: + name (str): The name of the array. + source_array (Array): The source array to binarize. + background (int): The label to treat as background. + groupings (List[Tuple[str, List[int]]]): A list of tuples where the first + element is the name of the class and the second element is a list of + labels that should be combined into a single binary channel. + Methods: + __init__(self, array_config): This method initializes the BinarizeArray object. + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. It is used to set the default_config to an instance of ArrayConfig if it is None. + __getitem__(self, roi: Roi) -> np.ndarray: This method returns the binary channels for the given region of interest. + _can_neuroglance(self): This method returns True if the source array can be visualized in neuroglance. + _neuroglancer_source(self): This method returns the source array for neuroglancer. + _neuroglancer_layer(self): This method returns the neuroglancer layer for the source array. + _source_name(self): This method returns the name of the source array. + Note: + This class is used to create a BinarizeArray object which is a wrapper around a ZarrArray containing uint annotations. """ def __init__(self, array_config): + """ + This method initializes the BinarizeArray object. + + Args: + array_config (ArrayConfig): The array configuration. + Raises: + AssertionError: If the source array has channels. + Examples: + >>> binarize_array = BinarizeArray(array_config) + Note: + This method is used to initialize the BinarizeArray object. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -40,38 +71,147 @@ def __init__(self, array_config): @property def attrs(self): + """ + This method returns the attributes of the source array. + + Returns: + Dict: The attributes of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.attrs + Note: + This method is used to return the attributes of the source array. + """ return self._source_array.attrs @property def axes(self): + """ + This method returns the axes of the source array. + + Returns: + List[str]: The axes of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.axes + Note: + This method is used to return the axes of the source array. + """ return ["c"] + self._source_array.axes @property def dims(self) -> int: + """ + This method returns the dimensions of the source array. + + Returns: + int: The dimensions of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.dims + Note: + This method is used to return the dimensions of the source array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + This method returns the voxel size of the source array. + + Returns: + Coordinate: The voxel size of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.voxel_size + Note: + This method is used to return the voxel size of the source array. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + This method returns the region of interest of the source array. + + Returns: + Roi: The region of interest of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.roi + Note: + This method is used to return the region of interest of the source array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + This method returns True if the source array is writable. + + Returns: + bool: True if the source array is writable. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.writable + Note: + This method is used to return True if the source array is writable. + """ return False @property def dtype(self): + """ + This method returns the data type of the source array. + + Returns: + np.dtype: The data type of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.dtype + Note: + This method is used to return the data type of the source array. + """ return np.uint8 @property def num_channels(self) -> int: + """ + This method returns the number of channels in the source array. + + Returns: + int: The number of channels in the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.num_channels + Note: + This method is used to return the number of channels in the source array. + + """ return len(self._groupings) @property def data(self): + """ + This method returns the data of the source array. + + Returns: + np.ndarray: The data of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.data + Note: + This method is used to return the data of the source array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -79,9 +219,35 @@ def data(self): @property def channels(self): + """ + This method returns the channel names of the source array. + + Returns: + Iterator[str]: The channel names of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.channels + Note: + This method is used to return the channel names of the source array. + """ return (name for name, _ in self._groupings) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + This method returns the binary channels for the given region of interest. + + Args: + roi (Roi): The region of interest. + Returns: + np.ndarray: The binary channels for the given region of interest. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array[roi] + Note: + This method is used to return the binary channels for the given region of interest. + """ labels = self._source_array[roi] grouped = np.zeros((len(self._groupings), *labels.shape), dtype=np.uint8) for i, (_, ids) in enumerate(self._groupings): @@ -92,14 +258,62 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return grouped def _can_neuroglance(self): + """ + This method returns True if the source array can be visualized in neuroglance. + + Returns: + bool: True if the source array can be visualized in neuroglance. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._can_neuroglance() + Note: + This method is used to return True if the source array can be visualized in neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + This method returns the source array for neuroglancer. + + Returns: + neuroglancer.LocalVolume: The source array for neuroglancer. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._neuroglancer_source() + Note: + This method is used to return the source array for neuroglancer. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + This method returns the neuroglancer layer for the source array. + + Returns: + neuroglancer.SegmentationLayer: The neuroglancer layer for the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._neuroglancer_layer() + Note: + This method is used to return the neuroglancer layer for the source array. + """ layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) return layer def _source_name(self): + """ + This method returns the name of the source array. + + Returns: + str: The name of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._source_name() + Note: + This method is used to return the name of the source array. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py index 62f4c4da6..bea576667 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py @@ -8,8 +8,21 @@ @attr.s class BinarizeArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem. Each class will be binarized into a separate channel. + + Attributes: + source_array_config (ArrayConfig): The Array from which to pull annotated data. Is expected to contain a volume with uint64 voxels and no channel dimension + groupings (List[Tuple[str, List[int]]]): List of id groups with a symantic name. Each id group is a List of ids. + Group i found in groupings[i] will be binarized and placed in channel i. + An empty group will binarize all non background labels. + background (int): The id considered background. Will never be binarized to 1, defaults to 0. + Note: + This class is used to create a BinarizeArray object which is used to turn an Annotated dataset into a multi class binary classification problem. + Each class will be binarized into a separate channel. + + """ array_type = BinarizeArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 37cf650f6..0e44d38d2 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -11,10 +11,53 @@ class ConcatArray(Array): - """This is a wrapper around other `source_arrays` that concatenates - them along the channel dimension.""" + """ + This is a wrapper around other `source_arrays` that concatenates + them along the channel dimension. The `source_arrays` are expected + to have the same shape and ROI, but can have different data types. + + Attributes: + name: The name of the array. + channels: The list of channel names. + source_arrays: A dictionary mapping channel names to source arrays. + default_array: An optional default array to use for channels that are + not present in `source_arrays`. + Methods: + from_toml(cls, toml_path: str) -> ConcatArrayConfig: + Load the ConcatArrayConfig from a TOML file + to_toml(self, toml_path: str) -> None: + Save the ConcatArrayConfig to a TOML file + create_array(self) -> ConcatArray: + Create the ConcatArray from the config + Note: + This class is a subclass of Array and inherits all its attributes + and methods. The only difference is that the array_type is ConcatArray. + + """ def __init__(self, array_config): + """ + Initialize the ConcatArray from a ConcatArrayConfig. + + Args: + array_config (ConcatArrayConfig): The config to create the ConcatArray from. + Raises: + AssertionError: If the source arrays have different shapes or ROIs. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + Note: + The `source_arrays` are expected to have the same shape and ROI, + but can have different data types. + """ self.name = array_config.name self.channels = array_config.channels self.source_arrays = { @@ -29,14 +72,82 @@ def __init__(self, array_config): @property def attrs(self): + """ + Return the attributes of the ConcatArray as a dictionary. + + Returns: + Dict[str, Any]: The attributes of the ConcatArray. + Raises: + AssertionError: If the source arrays have different attributes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.attrs + {'axes': 'cxyz', 'roi': Roi(...), 'voxel_size': (1, 1, 1)} + Note: + The `source_arrays` are expected to have the same attributes. + """ return dict() @property def source_arrays(self) -> Dict[str, Array]: + """ + Return the source arrays of the ConcatArray. + + Returns: + Dict[str, Array]: The source arrays of the ConcatArray. + Raises: + AssertionError: If the source arrays are empty. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_arrays + {'A': Array(...), 'B': Array(...)} + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ return self._source_arrays @source_arrays.setter def source_arrays(self, value: Dict[str, Array]): + """ + Set the source arrays of the ConcatArray. + + Args: + value (Dict[str, Array]): The source arrays to set. + Raises: + AssertionError: If the source arrays are empty. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_arrays = {'A': Array(...), 'B': Array(...)} + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ assert len(value) > 0, "Source arrays is empty!" self._source_arrays = value attrs: Dict[str, Any] = {} @@ -58,10 +169,56 @@ def source_arrays(self, value: Dict[str, Array]): @property def source_array(self) -> Array: + """ + Return the source array of the ConcatArray. + + Returns: + Array: The source array of the ConcatArray. + Raises: + AssertionError: If the source array is None. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_array + Array(...) + Note: + The `source_array` is expected to have the same shape and ROI. + """ return self._source_array @property def axes(self): + """ + Return the axes of the ConcatArray. + + Returns: + str: The axes of the ConcatArray. + Raises: + AssertionError: If the source arrays have different axes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.axes + 'cxyz' + Note: + The `source_arrays` are expected to have the same axes. + """ source_axes = self.source_array.axes if "c" not in source_axes: source_axes = ["c"] + source_axes @@ -69,33 +226,210 @@ def axes(self): @property def dims(self): + """ + Return the dimensions of the ConcatArray. + + Returns: + Tuple[int]: The dimensions of the ConcatArray. + Raises: + AssertionError: If the source arrays have different dimensions. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.dims + (2, 100, 100, 100) + Note: + The `source_arrays` are expected to have the same dimensions. + """ return self.source_array.dims @property def voxel_size(self): + """ + Return the voxel size of the ConcatArray. + + Returns: + Tuple[float]: The voxel size of the ConcatArray. + Raises: + AssertionError: If the source arrays have different voxel sizes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.voxel_size + (1, 1, 1) + Note: + The `source_arrays` are expected to have the same voxel size. + """ return self.source_array.voxel_size @property def roi(self): + """ + Return the ROI of the ConcatArray. + + Returns: + Roi: The ROI of the ConcatArray. + Raises: + AssertionError: If the source arrays have different ROIs. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.roi + Roi(...) + Note: + The `source_arrays` are expected to have the same ROI. + """ return self.source_array.roi @property def writable(self) -> bool: + """ + Return whether the ConcatArray is writable. + + Returns: + bool: Whether the ConcatArray is writable. + Raises: + AssertionError: If the ConcatArray is writable. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.writable + False + Note: + The ConcatArray is not writable. + """ return False @property def data(self): + """ + Return the data of the ConcatArray. + + Returns: + np.ndarray: The data of the ConcatArray. + Raises: + RuntimeError: If the ConcatArray is not writable. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.data + np.ndarray(...) + Note: + The ConcatArray is not writable. + """ raise RuntimeError("Cannot get writable version of this data!") @property def dtype(self): + """ + Return the data type of the ConcatArray. + + Returns: + np.dtype: The data type of the ConcatArray. + Raises: + AssertionError: If the source arrays have different data types. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.dtype + np.float32 + Note: + The `source_arrays` are expected to have the same data type. + """ return self.source_array.dtype @property def num_channels(self): + """ + Return the number of channels of the ConcatArray. + + Returns: + int: The number of channels of the ConcatArray. + Raises: + AssertionError: If the source arrays have different numbers of channels. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.num_channels + 2 + Note: + The `source_arrays` are expected to have the same number of channels. + """ return len(self.channels) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Return the data of the ConcatArray for a given ROI. + + Args: + roi (Roi): The ROI to get the data for. + Returns: + np.ndarray: The data of the ConcatArray for the given ROI. + Raises: + AssertionError: If the source arrays have different shapes or ROIs. + Examples: + >>> roi = Roi(...) + >>> array[roi] + np.ndarray(...) + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ default = ( np.zeros_like(self.source_array[roi]) if self.default_array is None diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py index ca76c167b..21b5cb76c 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py @@ -8,7 +8,20 @@ @attr.s class ConcatArrayConfig(ArrayConfig): - """This array read data from the source array and then return a np.ones_like() version.""" + """ + This array read data from the source array and then return a np.ones_like() version of the data. + + Attributes: + channels (List[str]): An ordering for the source_arrays. + source_array_configs (Dict[str, ArrayConfig]): A mapping from channels to array_configs. If a channel has no ArrayConfig it will be filled with zeros + default_config (Optional[ArrayConfig]): An optional array providing the default array per channel. If not provided, missing channels will simply be filled with 0s + Methods: + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. It is used to set the default_config to an instance of ArrayConfig if it is None. + get_array(self, source_arrays: Dict[str, np.ndarray]) -> np.ndarray: This method reads data from the source array and then return a np.ones_like() version of the data. + Note: + This class is used to create a ConcatArray object which is used to read data from the source array and then return a np.ones_like() version of the data. + The source array is a dictionary with the key being the channel and the value being the array. + """ array_type = ConcatArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py index 04b163513..5553058be 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py @@ -7,10 +7,64 @@ class CropArray(Array): """ - Used to crop a larger array to a smaller array. + Used to crop a larger array to a smaller array. This is useful when you + want to work with a subset of a larger array, but don't want to copy the + data. The crop is done on demand, so the data is not copied until you + actually access it. + + Attributes: + name: The name of the array. + source_array: The array to crop. + crop_roi: The region of interest to crop to. + Methods: + attrs: Returns the attributes of the source array. + axes: Returns the axes of the source array. + dims: Returns the number of dimensions of the source array. + voxel_size: Returns the voxel size of the source array. + roi: Returns the region of interest of the source array. + writable: Returns whether the array is writable. + dtype: Returns the data type of the source array. + num_channels: Returns the number of channels of the source array. + data: Returns the data of the source array. + channels: Returns the channels of the source array. + __getitem__(roi): Returns the data of the source array within the + region of interest. + _can_neuroglance(): Returns whether the source array can be viewed in + Neuroglancer. + _neuroglancer_source(): Returns the source of the source array for + Neuroglancer. + _neuroglancer_layer(): Returns the layer of the source array for + Neuroglancer. + _source_name(): Returns the name of the source array. + Note: + This class is a subclass of Array. + + """ def __init__(self, array_config): + """ + Initializes the CropArray. + + Args: + array_config: The configuration of the array to crop. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + Note: + The source array configuration must be an instance of ArrayConfig. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -19,38 +73,265 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array. + + Returns: + The attributes of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.attrs + {} + Note: + The attributes are empty because the source array is not modified. + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array. + + Returns: + The axes of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.axes + 'zyx' + Note: + The axes are 'zyx' because the source array is not modified. + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array. + + Returns: + The number of dimensions of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.dims + 3 + Note: + The number of dimensions is 3 because the source array is not + modified. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array. + + Returns: + The voxel size of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Note: + The voxel size is (1.0, 1.0, 1.0) because the source array is not + modified. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array. + + Returns: + The region of interest of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.roi + Roi(offset=(0, 0, 0), shape=(10, 10, 10)) + Note: + The region of interest is (0, 0, 0) with shape (10, 10, 10) + because the source array is not modified. + """ return self.crop_roi.intersect(self._source_array.roi) @property def writable(self) -> bool: + """ + Returns whether the array is writable. + + Returns: + False + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.writable + False + Note: + The array is not writable because it is a virtual array created by + modifying another array on demand. + """ return False @property def dtype(self): + """ + Returns the data type of the source array. + + Returns: + The data type of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.dtype + numpy.dtype('uint8') + Note: + The data type is uint8 because the source array is not modified. + """ return self._source_array.dtype @property def num_channels(self) -> int: + """ + Returns the number of channels of the source array. + + Returns: + The number of channels of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.num_channels + 1 + Note: + The number of channels is 1 because the source array is not + modified. + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the source array. + + Returns: + The data of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.data + array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0, 0, 0 + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -58,20 +339,170 @@ def data(self): @property def channels(self): + """ + Returns the channels of the source array. + + Returns: + The channels of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.channels + 1 + Note: + The channels is 1 because the source array is not modified. + """ return self._source_array.channels def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the source array within the region of interest. + + Args: + roi: The region of interest. + Returns: + The data of the source array within the region of interest. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array[Roi((0, 0, 0), (5, 5, 5))] + array([[[ + Note: + The data is the same as the source array because the source array + is not modified. + """ assert self.roi.contains(roi) return self._source_array[roi] def _can_neuroglance(self): + """ + Returns whether the source array can be viewed in Neuroglancer. + + Returns: + Whether the source array can be viewed in Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._can_neuroglance() + False + Note: + The source array cannot be viewed in Neuroglancer because the + source array is not modified. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns the source of the source array for Neuroglancer. + + Returns: + The source of the source array for Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._neuroglancer_source() + {'source': 'source_array'} + Note: + The source is the source array because the source array is not + modified. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns the layer of the source array for Neuroglancer. + + Returns: + The layer of the source array for Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._neuroglancer_layer() + {'source': 'source_array', 'type': 'image'} + Note: + The layer is an image because the source array is not modified. + """ return self._source_array._neuroglancer_layer() def _source_name(self): + """ + Returns the name of the source array. + + Returns: + The name of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._source_name() + 'source_array' + Note: + The name is the source array because the source array is not + modified. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py index 0a8d885fd..899120e90 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py @@ -8,9 +8,26 @@ @attr.s class CropArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for cropping an + """ + This config class provides the necessary configuration for cropping an Array to a smaller ROI. Especially useful for validation volumes that may - be too large for quick evaluation""" + be too large for quick evaluation. The ROI is specified in the config. The + cropped Array will have the same dtype as the source Array. + + Attributes: + source_array_config (ArrayConfig): The Array to crop + roi (Roi): The ROI for cropping + Methods: + from_toml(cls, toml_path: str) -> CropArrayConfig: + Load the CropArrayConfig from a TOML file + to_toml(self, toml_path: str) -> None: + Save the CropArrayConfig to a TOML file + create_array(self) -> CropArray: + Create the CropArray from the config + Note: + This class is a subclass of ArrayConfig and inherits all its attributes + and methods. The only difference is that the array_type is CropArray. + """ array_type = CropArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py b/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py index 8e3ce3daa..31cab35f1 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py @@ -6,44 +6,188 @@ class DummyArray(Array): - """This is just a dummy array for testing.""" + """ + This is just a dummy array for testing. It has a shape of (100, 50, 50) and is filled with zeros. + + Attributes: + array_config (ArrayConfig): The config object for the array + Methods: + __getitem__: Returns the intensities normalized to the range (0, 1) + Notes: + The array_config must be an ArrayConfig object. + The min and max values are used to normalize the intensities. + All intensities are converted to float32. + + """ def __init__(self, array_config): + """ + Initializes the IntensitiesArray object + + Args: + array_config (ArrayConfig): The config object for the array + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> intensities_array = IntensitiesArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + """ super().__init__() self._data = np.zeros((100, 50, 50)) @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: The attributes of the source array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> intensities_array.attrs + {'resolution': (1.0, 1.0, 1.0), 'unit': 'micrometer'} + """ return dict() @property def axes(self): + """ + Returns the axes of the source array + + Returns: + str: The axes of the source array + Raises: + ValueError: If the axes is not a string + Examples: + >>> intensities_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return ["z", "y", "x"] @property def dims(self): + """ + Returns the number of dimensions of the source array + + Returns: + int: The number of dimensions of the source array + Raises: + ValueError: If the dims is not an integer + Examples: + >>> intensities_array.dims + 3 + Notes: + The dims are the same as the source array + """ return 3 @property def voxel_size(self): + """ + Returns the voxel size of the source array + + Returns: + Coordinate: The voxel size of the source array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> intensities_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return Coordinate(1, 2, 2) @property def roi(self): + """ + Returns the region of interest of the source array + + Returns: + Roi: The region of interest of the source array + Raises: + ValueError: If the roi is not a Roi object + Examples: + >>> intensities_array.roi + Roi(offset=(0, 0, 0), shape=(100, 100, 100)) + Notes: + The roi is the same as the source array + """ return Roi((0, 0, 0), (100, 100, 100)) @property def writable(self) -> bool: + """ + Returns whether the array is writable + + Returns: + bool: Whether the array is writable + Examples: + >>> intensities_array.writable + True + Notes: + The array is always writable + """ return True @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: The data of the source array + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> intensities_array.data + array([[[0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + ..., + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.]], + Notes: + The data is the same as the source array + """ return self._data @property def dtype(self): + """ + Returns the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the data type is not a type + Examples: + >>> intensities_array.dtype + numpy.float32 + Notes: + The data type is the same as the source array + """ return self._data.dtype @property def num_channels(self): + """ + Returns the number of channels in the source array + + Returns: + int: The number of channels in the source array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> intensities_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ return None diff --git a/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py index fba67ec51..58fcab517 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py @@ -8,10 +8,34 @@ @attr.s class DummyArrayConfig(ArrayConfig): - """This is just a dummy array config used for testing. None of the - attributes have any particular meaning.""" + """ + This is just a dummy array config used for testing. None of the + attributes have any particular meaning. It is used to test the + ArrayConfig class. + + Methods: + to_array: Returns the DummyArray object + verify: Returns whether the DummyArrayConfig is valid + Notes: + The source_array_config must be an ArrayConfig object. + + """ array_type = DummyArray def verify(self) -> Tuple[bool, str]: + """ + Check whether this is a valid Array + + Returns: + Tuple[bool, str]: Whether the Array is valid and a message + Raises: + ValueError: If the source is not a tuple of strings + Examples: + >>> dummy_array_config = DummyArrayConfig(...) + >>> dummy_array_config.verify() + (False, "This is a DummyArrayConfig and is never valid") + Notes: + The source must be a tuple of strings. + """ return False, "This is a DummyArrayConfig and is never valid" diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py index e08ffe562..548821b41 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py @@ -22,41 +22,180 @@ class DVIDArray(Array): - """This is a DVID array""" + """ + This is a DVID array. It is a wrapper around a DVID array that provides + the necessary methods to interact with the array. It is used to fetch data + from a DVID server. The source is a tuple of three strings: the server, the UUID, + and the data name. + + DVID: data management system for terabyte-sized 3D images + + Attributes: + name (str): The name of the array + source (tuple[str, str, str]): The source of the array + Methods: + __getitem__: Returns the data from the array for a given region of interest + Notes: + The source is a tuple of three strings: the server, the UUID, and the data name. + """ def __init__(self, array_config): + """ + Initializes the DVIDArray object + + Args: + array_config (ArrayConfig): The config object for the array + Returns: + DVIDArray: The DVIDArray object + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> dvid_array = DVIDArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + + """ super().__init__() self.name: str = array_config.name self.source: tuple[str, str, str] = array_config.source def __str__(self): + """ + Returns the string representation of the DVIDArray object + + Returns: + str: The string representation of the DVIDArray object + Raises: + ValueError: If the source is not a tuple of three strings + Examples: + >>> str(dvid_array) + DVIDArray(('server', 'UUID', 'data_name')) + Notes: + The string representation is the source of the array + """ return f"DVIDArray({self.source})" def __repr__(self): + """ + Returns the string representation of the DVIDArray object + + Returns: + str: The string representation of the DVIDArray object + Raises: + ValueError: If the source is not a tuple of three strings + Examples: + >>> repr(dvid_array) + DVIDArray(('server', 'UUID', 'data_name')) + Notes: + The string representation is the source of the array + """ return f"DVIDArray({self.source})" @lazy_property.LazyProperty def attrs(self): + """ + Returns the attributes of the DVID array + + Returns: + dict: The attributes of the DVID array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> dvid_array.attrs + {'Extended': {'VoxelSize': (1.0, 1.0, 1.0), 'Values': [{'DataType': 'uint64'}]}, 'Extents': {'MinPoint': (0, 0, 0), 'MaxPoint': (100, 100, 100)}} + Notes: + The attributes are the same as the source array + """ return fetch_info(*self.source) @property def axes(self): + """ + Returns the axes of the DVID array + + Returns: + str: The axes of the DVID array + Raises: + ValueError: If the axes is not a string + Examples: + >>> dvid_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return ["c", "z", "y", "x"][-self.dims :] @property def dims(self) -> int: + """ + Returns the dimensions of the DVID array + + Returns: + int: The dimensions of the DVID array + Raises: + ValueError: If the dimensions is not an integer + Examples: + >>> dvid_array.dims + 3 + Notes: + The dimensions are the same as the source array + """ return self.voxel_size.dims @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: + """ + Returns the DVID array as a Daisy array + + Returns: + funlib.persistence.Array: The DVID array as a Daisy array + Raises: + ValueError: If the DVID array is not a Daisy array + Examples: + >>> dvid_array._daisy_array + Array(...) + Notes: + The DVID array is a Daisy array + """ raise NotImplementedError() @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the DVID array + + Returns: + Coordinate: The voxel size of the DVID array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> dvid_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return Coordinate(self.attrs["Extended"]["VoxelSize"]) @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Returns the region of interest of the DVID array + + Returns: + Roi: The region of interest of the DVID array + Raises: + ValueError: If the region of interest is not a Roi object + Examples: + >>> dvid_array.roi + Roi(...) + Notes: + The region of interest is the same as the source array + """ + return Roi( + Coordinate(self.attrs["Extents"]["MinPoint"]) * self.voxel_size, + Coordinate(self.attrs["Extents"]["MaxPoint"]) * self.voxel_size, + ) return Roi( Coordinate(self.attrs["Extents"]["MinPoint"]) * self.voxel_size, Coordinate(self.attrs["Extents"]["MaxPoint"]) * self.voxel_size, @@ -64,25 +203,105 @@ def roi(self) -> Roi: @property def writable(self) -> bool: + """ + Returns whether the DVID array is writable + + Returns: + bool: Whether the DVID array is writable + Raises: + ValueError: If the writable is not a boolean + Examples: + >>> dvid_array.writable + False + Notes: + The writable is the same as the source array + """ return False @property def dtype(self) -> Any: + """ + Returns the data type of the DVID array + + Returns: + type: The data type of the DVID array + Raises: + ValueError: If the data type is not a type + Examples: + >>> dvid_array.dtype + numpy.uint64 + Notes: + The data type is the same as the source array + """ return np.dtype(self.attrs["Extended"]["Values"][0]["DataType"]) @property def num_channels(self) -> Optional[int]: + """ + Returns the number of channels of the DVID array + + Returns: + int: The number of channels of the DVID array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> dvid_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ return None @property def spatial_axes(self) -> List[str]: + """ + Returns the spatial axes of the DVID array + + Returns: + List[str]: The spatial axes of the DVID array + Raises: + ValueError: If the spatial axes is not a list + Examples: + >>> dvid_array.spatial_axes + ['z', 'y', 'x'] + Notes: + The spatial axes are the same as the source array + """ return [ax for ax in self.axes if ax not in set(["c", "b"])] @property def data(self) -> Any: + """ + Returns the number of channels of the DVID array + + Returns: + int: The number of channels of the DVID array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> dvid_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ raise NotImplementedError() def __getitem__(self, roi: Roi) -> np.ndarray[Any, Any]: + """ + Returns the data of the DVID array for a given region of interest + + Args: + roi (Roi): The region of interest for which to get the data + Returns: + np.ndarray: The data of the DVID array for the region of interest + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> dvid_array[roi] + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The data is the same as the source array + """ box = np.array( (roi.offset / self.voxel_size, (roi.offset + roi.shape) / self.voxel_size) ) @@ -95,22 +314,114 @@ def __getitem__(self, roi: Roi) -> np.ndarray[Any, Any]: return data def _can_neuroglance(self) -> bool: + """ + Returns whether the DVID array can be used with neuroglance + + Returns: + bool: Whether the DVID array can be used with neuroglance + Raises: + ValueError: If the DVID array cannot be used with neuroglance + Examples: + >>> dvid_array._can_neuroglance() + True + Notes: + The DVID array can be used with neuroglance + """ return True def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the DVID array + + Returns: + Tuple[str, str, str]: The neuroglancer source of the DVID array + Raises: + ValueError: If the neuroglancer source is not a tuple of three strings + Examples: + >>> dvid_array._neuroglancer_source() + ('server', 'UUID', 'data_name') + Notes: + The neuroglancer source is the same as the source array + """ raise NotImplementedError() def _neuroglancer_layer(self) -> Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: + """ + Returns the neuroglancer layer of the DVID array + + Returns: + Tuple[neuroglancer.ImageLayer, dict]: The neuroglancer layer of the DVID array + Raises: + ValueError: If the neuroglancer layer is not a tuple of an ImageLayer and a dictionary + Examples: + >>> dvid_array._neuroglancer_layer() + (ImageLayer(...), {}) + Notes: + The neuroglancer layer is the same as the source array + """ raise NotImplementedError() def _transform_matrix(self): + """ + Returns the transformation matrix of the DVID array + + Returns: + np.ndarray: The transformation matrix of the DVID array + Raises: + ValueError: If the transformation matrix is not a numpy array + Examples: + >>> dvid_array._transform_matrix() + array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + Notes: + The transformation matrix is the same as the source array + """ raise NotImplementedError() def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: + """ + Returns the output dimensions of the DVID array + + Returns: + dict: The output dimensions of the DVID array + Raises: + ValueError: If the output dimensions is not a dictionary + Examples: + >>> dvid_array._output_dimensions() + {'z': (100, 'nm'), 'y': (100, 'nm'), 'x': (100, 'nm')} + Notes: + The output dimensions are the same as the source array + """ raise NotImplementedError() def _source_name(self) -> str: + """ + Returns the source name of the DVID array + + Returns: + str: The source name of the DVID array + Raises: + ValueError: If the source name is not a string + Examples: + >>> dvid_array._source_name() + 'data_name' + Notes: + The source name is the same as the source array + """ raise NotImplementedError() def add_metadata(self, metadata: Dict[str, Any]) -> None: + """ + Adds metadata to the DVID array + + Args: + metadata (dict): The metadata to add to the DVID array + Returns: + None + Raises: + ValueError: If the metadata is not a dictionary + Examples: + >>> dvid_array.add_metadata({'description': 'This is a DVID array'}) + Notes: + The metadata is added to the source array + """ raise NotImplementedError() diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py index d9c5071c0..bc25300dc 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py @@ -9,7 +9,17 @@ @attr.s class DVIDArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a DVID array""" + """ + This config class provides the necessary configuration for a DVID array. It takes a source string and returns the DVIDArray object. + + Attributes: + source (Tuple[str, str, str]): The source strings + Methods: + to_array: Returns the DVIDArray object + Notes: + The source must be a tuple of strings. + + """ array_type = DVIDArray @@ -20,5 +30,16 @@ class DVIDArrayConfig(ArrayConfig): def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: Whether the Array is valid and a message + Raises: + ValueError: If the source is not a tuple of strings + Examples: + >>> dvid_array_config = DVIDArrayConfig(...) + >>> dvid_array_config.verify() + (True, "No validation for this Array") + Notes: + The source must be a tuple of strings. """ return True, "No validation for this Array" diff --git a/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py b/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py index 9840cddd9..f53ff1e1b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py @@ -11,9 +11,33 @@ class IntensitiesArray(Array): the range (0, 1) and convert to float32. Use this if you have your intensities stored as uint8 or similar and want your model to have floats as input. + + Attributes: + array_config (ArrayConfig): The config object for the array + min (float): The minimum intensity value in the array + max (float): The maximum intensity value in the array + Methods: + __getitem__: Returns the intensities normalized to the range (0, 1) + Notes: + The array_config must be an ArrayConfig object. + The min and max values are used to normalize the intensities. + All intensities are converted to float32. """ def __init__(self, array_config): + """ + Initializes the IntensitiesArray object + + Args: + array_config (ArrayConfig): The config object for the array + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> intensities_array = IntensitiesArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -24,44 +48,176 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: The attributes of the source array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> intensities_array.attrs + {'resolution': (1.0, 1.0, 1.0), 'unit': 'micrometer'} + Notes: + The attributes are the same as the source array + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array + + Returns: + str: The axes of the source array + Raises: + ValueError: If the axes is not a string + Examples: + >>> intensities_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the dimensions of the source array + + Returns: + int: The dimensions of the source array + Raises: + ValueError: If the dimensions is not an integer + Examples: + >>> intensities_array.dims + 3 + Notes: + The dimensions are the same as the source array + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array + + Returns: + Coordinate: The voxel size of the source array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> intensities_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array + + Returns: + Roi: The region of interest of the source array + Raises: + ValueError: If the region of interest is not a Roi object + Examples: + >>> intensities_array.roi + Roi(offset=(0, 0, 0), shape=(10, 20, 30)) + Notes: + The region of interest is the same as the source array + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns whether the array is writable + + Returns: + bool: Whether the array is writable + Raises: + ValueError: If the array is not writable + Examples: + >>> intensities_array.writable + False + Notes: + The array is not writable because it is a virtual array created by modifying another array on demand. + """ return False @property def dtype(self): + """ + Returns the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the data type is not a type + Examples: + >>> intensities_array.dtype + numpy.float32 + Notes: + The data type is always float32 + """ return np.float32 @property def num_channels(self) -> int: + """ + Returns the number of channels in the source array + + Returns: + int: The number of channels in the source array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> intensities_array.num_channels + 3 + Notes: + The number of channels is the same as the source array + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: The data of the source array + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> intensities_array.data + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The data is the same as the source array + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." ) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the intensities normalized to the range (0, 1) + + Args: + roi (Roi): The region of interest to get the intensities from + Returns: + np.ndarray: The intensities normalized to the range (0, 1) + Raises: + ValueError: If the intensities are not in the range (0, 1) + Examples: + >>> intensities_array[roi] + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The intensities are normalized to the range (0, 1) + """ intensities = self._source_array[roi] normalized = (intensities.astype(np.float32) - self._min) / ( self._max - self._min @@ -69,13 +225,66 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return normalized def _can_neuroglance(self): + """ + Returns whether the array can be visualized with Neuroglancer + + Returns: + bool: Whether the array can be visualized with Neuroglancer + Raises: + ValueError: If the array cannot be visualized with Neuroglancer + Examples: + >>> intensities_array._can_neuroglance() + True + Notes: + The array can be visualized with Neuroglancer if the source array can be visualized with Neuroglancer + + """ return self._source_array._can_neuroglance() def _neuroglancer_layer(self): + """ + Returns the Neuroglancer layer of the source array + + Returns: + dict: The Neuroglancer layer of the source array + Raises: + ValueError: If the Neuroglancer layer is not a dictionary + Examples: + >>> intensities_array._neuroglancer_layer() + {'type': 'image', 'source': 'precomputed://https://mybucket.s3.amazonaws.com/mydata'} + Notes: + The Neuroglancer layer is the same as the source array + """ return self._source_array._neuroglancer_layer() def _source_name(self): + """ + Returns the name of the source array + + Returns: + str: The name of the source array + Raises: + ValueError: If the name is not a string + Examples: + >>> intensities_array._source_name() + 'mydata' + Notes: + The name is the same as the source array + """ return self._source_array._source_name() def _neuroglancer_source(self): + """ + Returns the Neuroglancer source of the source array + + Returns: + str: The Neuroglancer source of the source array + Raises: + ValueError: If the Neuroglancer source is not a string + Examples: + >>> intensities_array._neuroglancer_source() + 'precomputed://https://mybucket.s3.amazonaws.com/mydata' + Notes: + The Neuroglancer source is the same as the source array + """ return self._source_array._neuroglancer_source() diff --git a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py index 87281f69f..5273639b0 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py @@ -6,8 +6,20 @@ @attr.s class IntensitiesArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem. It takes a source array and normalizes the intensities + between 0 and 1. The source array is expected to contain a volume with uint64 voxels and no channel dimension. + + Attributes: + source_array_config (ArrayConfig): The Array from which to pull annotated data + min (float): The minimum intensity in your data + max (float): The maximum intensity in your data + Methods: + to_array: Returns the IntensitiesArray object + Notes: + The source_array_config must be an ArrayConfig object. + """ array_type = IntensitiesArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py index 995f27d05..a06fda32a 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py @@ -9,9 +9,93 @@ class LogicalOrArray(Array): - """ """ + """ + Array that computes the logical OR of the instances in a list of source arrays. + + Attributes: + name: str + The name of the array + source_array: Array + The source array from which to take the logical OR + Methods: + axes: () -> List[str] + Get the axes of the array + dims: () -> int + Get the number of dimensions of the array + voxel_size: () -> Coordinate + Get the voxel size of the array + roi: () -> Roi + Get the region of interest of the array + writable: () -> bool + Get whether the array is writable + dtype: () -> type + Get the data type of the array + num_channels: () -> int + Get the number of channels in the array + data: () -> np.ndarray + Get the data of the array + attrs: () -> dict + Get the attributes of the array + __getitem__: (roi: Roi) -> np.ndarray + Get the data of the array in the region of interest + _can_neuroglance: () -> bool + Get whether the array can be visualized in neuroglance + _neuroglancer_source: () -> dict + Get the neuroglancer source of the array + _neuroglancer_layer: () -> Tuple[neuroglancer.Layer, dict] + Get the neuroglancer layer of the array + _source_name: () -> str + Get the name of the source array + Notes: + The LogicalOrArray class is used to create a LogicalOrArray. The LogicalOrArray + class is a subclass of the Array class. + """ def __init__(self, array_config): + """ + Create a LogicalOrArray instance from a configuration + Args: + array_config: MergeInstancesArrayConfig + The configuration for the array + Returns: + LogicalOrArray + The LogicalOrArray instance created from the configuration + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.name + 'logical_or' + >>> array.source_array.name + 'mask1' + >>> array.source_array.mask_id + 1 + Notes: + The create_array method is used to create a LogicalOrArray instance from a + configuration. The LogicalOrArray instance is created by taking the logical OR + of the instances in the source arrays. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -19,34 +103,330 @@ def __init__(self, array_config): @property def axes(self): + """ + Get the axes of the array + + Returns: + List[str]: The axes of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.axes + ['x', 'y', 'z'] + Notes: + The axes method is used to get the axes of the array. The axes are the dimensions + of the array. + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + Get the number of dimensions of the array + + Returns: + int: The number of dimensions of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.dims + 3 + Notes: + The dims method is used to get the number of dimensions of the array. The number + of dimensions is the number of axes of the array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Get the voxel size of the array + + Returns: + Coordinate: The voxel size of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel_size method is used to get the voxel size of the array. The voxel size + is the size of a voxel in the array. + + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Get the region of interest of the array + + Returns: + Roi: The region of interest of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.roi + Roi(offset=(0, 0, 0), shape=(10, 10, 10)) + Notes: + The roi method is used to get the region of interest of the array. The region of + interest is the shape and offset of the array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Get whether the array is writable + + Returns: + bool: Whether the array is writable + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.writable + False + Notes: + The writable method is used to get whether the array is writable. An array is + writable if it can be modified. + """ return False @property def dtype(self): + """ + Get the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.dtype + + Notes: + The dtype method is used to get the data type of the array. The data type is the + type of the data in the array. + """ return np.uint8 @property def num_channels(self): + """ + Get the number of channels in the array + + Returns: + int: The number of channels in the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.num_channels + 1 + Notes: + The num_channels method is used to get the number of channels in the array. The + number of channels is the number of channels in the array. + """ return None @property def data(self): + """ + Get the data of the array + + Returns: + np.ndarray: The data of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.data + array([[[1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + ..., + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8) + Notes: + The data method is used to get the data of the array. The data is the content of + the array. + + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -54,21 +434,211 @@ def data(self): @property def attrs(self): + """ + Get the attributes of the array + + Returns: + dict: The attributes of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.attrs + {'name': 'logical_or'} + Notes: + The attrs method is used to get the attributes of the array. The attributes are + the metadata of the array. + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get the data of the array in the region of interest + + Args: + roi: Roi + The region of interest of the array + Returns: + np.ndarray: The data of the array in the region of interest + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> roi = Roi((0, 0, 0), (10, 10, 10)) + >>> array[roi] + array([[[1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + ..., + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8) + Notes: + The __getitem__ method is used to get the data of the array in the region of interest. + The data is the content of the array in the region of interest. + """ mask = self._source_array[roi] if "c" in self._source_array.axes: mask = np.max(mask, axis=self._source_array.axes.index("c")) return mask def _can_neuroglance(self): + """ + Get whether the array can be visualized in neuroglance + + Returns: + bool: Whether the array can be visualized in neuroglance + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._can_neuroglance() + True + Notes: + The _can_neuroglance method is used to get whether the array can be visualized + in neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Get the neuroglancer source of the array + + Returns: + dict: The neuroglancer source of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._neuroglancer_source() + {'source': 'precomputed://https://mybucket.storage.googleapis.com/path/to/logical_or'} + Notes: + The _neuroglancer_source method is used to get the neuroglancer source of the array. + The neuroglancer source is the source that is displayed in the neuroglancer viewer. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Get the neuroglancer layer of the array + + Returns: + Tuple[neuroglancer.Layer, dict]: The neuroglancer layer of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._neuroglancer_layer() + (SegmentationLayer(source='precomputed://https://mybucket.storage.googleapis.com/path/to/logical_or'), {'visible': False}) + Notes: + The _neuroglancer_layer method is used to get the neuroglancer layer of the array. + The neuroglancer layer is the layer that is displayed in the neuroglancer viewer. + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -78,4 +648,40 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Get the name of the source array + + Returns: + str: The name of the source array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._source_name() + 'mask1' + Notes: + The _source_name method is used to get the name of the source array. The name + of the source array is the name of the array that is being modified. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py index d0a211a8a..a22591405 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py @@ -6,8 +6,17 @@ @attr.s class LogicalOrArrayConfig(ArrayConfig): - """This config class takes a source array and performs a logical or over the channels. - Good for union multiple masks.""" + """ + This config class takes a source array and performs a logical or over the channels. + Good for union multiple masks. + + Attributes: + source_array_config (ArrayConfig): The Array of masks from which to take the union + Methods: + to_array: Returns the LogicalOrArray object + Notes: + The source_array_config must be an ArrayConfig object. + """ array_type = LogicalOrArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py index 944c69b69..59cd344cb 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py @@ -9,9 +9,68 @@ class MergeInstancesArray(Array): - """ """ + """ + This array merges multiple source arrays into a single array by summing them. This is useful for merging + instance segmentation arrays into a single array. NeuoGlancer will display each instance as a different color. + + Attributes: + name : str + The name of the array + source_array_configs : List[ArrayConfig] + A list of source arrays to merge + Methods: + __getitem__(roi: Roi) -> np.ndarray + Returns a numpy array with the requested region of interest + _can_neuroglance() -> bool + Returns True if the array can be visualized in neuroglancer + _neuroglancer_source() -> str + Returns the source name for the array in neuroglancer + _neuroglancer_layer() -> Tuple[neuroglancer.SegmentationLayer, Dict[str, Any]] + Returns a neuroglancer layer and its configuration + _source_name() -> str + Returns the source name for the array + Note: + This array is not writable + Source arrays must have the same shape. + + """ def __init__(self, array_config): + """ + Constructor for MergeInstancesArray + + Args: + array_config : MergeInstancesArrayConfig + The configuration for the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + ``` + Note: + This example shows how to create a MergeInstancesArray object + """ self.name = array_config.name self._source_arrays = [ source_config.array_type(source_config) @@ -21,34 +80,317 @@ def __init__(self, array_config): @property def axes(self): + """ + Returns the axes of the array + + Returns: + List[str]: The axes of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + axes = array.axes + ``` + Note: + This example shows how to get the axes of the array + + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + Returns the number of dimensions of the array + + Returns: + int: The number of dimensions of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + dims = array.dims + ``` + Note: + This example shows how to get the number of dimensions of the array + + + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the array + + Returns: + Coordinate: The voxel size of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + voxel_size = array.voxel_size + ``` + Note: + This example shows how to get the voxel size of the array + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the array + + Returns: + Roi: The region of interest of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + roi = array.roi + ``` + Note: + This example shows how to get the region of interest of the array + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns True if the array is writable, False otherwise + + Returns: + bool: True if the array is writable, False otherwise + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + writable = array.writable + ``` + Note: + This example shows how to check if the array is writable + """ return False @property def dtype(self): + """ + Returns the data type of the array + + Returns: + np.dtype: The data type of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + dtype = array.dtype + ``` + Note: + This example shows how to get the data type of the array + """ return np.uint8 @property def num_channels(self): + """ + Returns the number of channels of the array + + Returns: + int: The number of channels of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + num_channels = array.num_channels + ``` + Note: + This example shows how to get the number of channels of the array + """ return None @property def data(self): + """ + Returns the data of the array + + Returns: + np.ndarray: The data of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + data = array.data + ``` + Note: + This example shows how to get the data of the array + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -56,9 +398,83 @@ def data(self): @property def attrs(self): + """ + Returns the attributes of the array + + Returns: + Dict[str, Any]: The attributes of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + attributes = array.attrs + ``` + Note: + This example shows how to get the attributes of the array + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns a numpy array with the requested region of interest + + Args: + roi : Roi + The region of interest to get + Returns: + np.ndarray: A numpy array with the requested region of interest + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + roi = Roi((0, 0, 0), (100, 100, 100)) + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + array_data = array[roi] + ``` + Note: + This example shows how to get a numpy array with the requested region of interest + """ arrays = [source_array[roi] for source_array in self._source_arrays] offset = 0 for array in arrays: @@ -67,12 +483,117 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return np.sum(arrays, axis=0) def _can_neuroglance(self): + """ + Returns True if the array can be visualized in neuroglancer, False otherwise + + Returns: + bool: True if the array can be visualized in neuroglancer, False otherwise + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + can_neuroglance = array._can_neuroglance() + ``` + Note: + This example shows how to check if the array can be visualized in neuroglancer + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns the source name for the array in neuroglancer + + Returns: + str: The source name for the array in neuroglancer + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + source = array._neuroglancer_source() + ``` + Note: + This example shows how to get the source name for the array in neuroglancer + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns a neuroglancer layer and its configuration + + Returns: + Tuple[neuroglancer.SegmentationLayer, Dict[str, Any]]: A neuroglancer layer and its configuration + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + layer, kwargs = array._neuroglancer_layer() + ``` + Note: + This example shows how to get a neuroglancer layer and its configuration + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -82,4 +603,39 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Returns the source name for the array + + Returns: + str: The source name for the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + source_name = array._source_name() + ``` + Note: + This example shows how to get the source name for the array + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py index 31c6e5acd..b2befad9c 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py @@ -8,6 +8,20 @@ @attr.s class MergeInstancesArrayConfig(ArrayConfig): + """ + Configuration for an array that merges instances from multiple arrays + into a single array. The instances are merged by taking the union of the + instances in the source arrays. + + Attributes: + source_array_configs: List[ArrayConfig] + The Array of masks from which to take the union + Methods: + create_array: () -> MergeInstancesArray + Create a MergeInstancesArray instance from the configuration + Notes: + The MergeInstancesArrayConfig class is used to create a MergeInstancesArray + """ array_type = MergeInstancesArray source_array_configs: List[ArrayConfig] = attr.ib( diff --git a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py index 3d1a86b93..6d1750089 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py +++ b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py @@ -20,9 +20,32 @@ class MissingAnnotationsMask(Array): See package fibsem_tools for appropriate metadata format for indicating presence of labels in your ground truth. "https://github.com/janelia-cosem/fibsem-tools" + + Attributes: + array_config: A BinarizeArrayConfig object + Methods: + __getitem__(roi: Roi) -> np.ndarray: Returns a binary mask of the + annotations that are present but not annotated. + Note: + This class is not meant to be used directly. It is used by the + BinarizeArray class to mask out annotations that are present but + not annotated. """ def __init__(self, array_config): + """ + Initializes the MissingAnnotationsMask class + + Args: + array_config (BinarizeArrayConfig): A BinarizeArrayConfig object + Raises: + AssertionError: If the source array has channels + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> missing_annotations_mask = MissingAnnotationsMask(MissingAnnotationsMaskConfig(source_array, groupings)) + Notes: + This is a helper function for the BinarizeArray class + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -36,34 +59,152 @@ def __init__(self, array_config): @property def axes(self): + """ + Returns the axes of the source array + + Returns: + list: Axes of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.axes + ['x', 'y', 'z'] + Notes: + This is a helper function for the BinarizeArray class + """ return ["c"] + self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array + + Returns: + int: Number of dimensions of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.dims + 3 + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array + + Returns: + Coordinate: Voxel size of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.voxel_size + Coordinate(x=4, y=4, z=40) + Notes: + This is a helper function for the BinarizeArray class + + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array + + Returns: + Roi: Region of interest of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.roi + Roi(offset=(0, 0, 0), shape=(100, 100, 100)) + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns whether the source array is writable + + Returns: + bool: Whether the source array is writable + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.writable + False + Notes: + This is a helper function for the BinarizeArray class + + """ return False @property def dtype(self): + """ + Returns the data type of the source array + + Returns: + np.dtype: Data type of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.dtype + np.uint8 + Notes: + This is a helper function for the BinarizeArray class + + """ return np.uint8 @property def num_channels(self) -> int: + """ + Returns the number of channels + + Returns: + int: Number of channels + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.num_channels + 2 + Notes: + This is a helper function for the BinarizeArray class + + + """ return len(self._groupings) @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: Data of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.data + np.ndarray(...) + Notes: + This is a helper function for the BinarizeArray class + + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -71,13 +212,62 @@ def data(self): @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: Attributes of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.attrs + {'name': 'source_array', 'resolution': [4, 4, 40]} + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.attrs @property def channels(self): + """ + Returns the names of the channels + + Returns: + Generator[str]: Names of the channels + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.channels + Generator['channel1', 'channel2', ...] + Notes: + This is a helper function for the BinarizeArray class + """ return (name for name, _ in self._groupings) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns a binary mask of the annotations that are present but not annotated. + + Args: + roi (Roi): Region of interest to get the mask for + Returns: + np.ndarray: Binary mask of the annotations that are present but not annotated + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> missing_annotations_mask = MissingAnnotationsMask(MissingAnnotationsMaskConfig(source_array, groupings)) + >>> roi = Roi(...) + >>> missing_annotations_mask[roi] + np.ndarray(...) + Notes: + - This is a helper function for the BinarizeArray class + - Number of channels in the mask is equal to the number of groupings + - Nuclues is a special case where we mask out the whole channel if any of the + sub-organelles are present but not annotated + """ labels = self._source_array[roi] grouped = np.ones((len(self._groupings), *labels.shape), dtype=bool) grouped[:] = labels > 0 @@ -93,32 +283,63 @@ def __getitem__(self, roi: Roi) -> np.ndarray: ) for i, (_, ids) in enumerate(self._groupings): if any([id in present_not_annotated for id in ids]): - # specially handle id 37 - # TODO: find more general solution - if 37 in ids and 37 not in present_not_annotated: - # 37 marks any kind of nucleus voxel. There many be nucleus sub - # organelles marked as "present not annotated", but we can safely - # train any channel that includes those organelles as long as - # 37 is annotated. - pass - else: - # mask out this whole channel - grouped[i] = 0 - - # for id in ids: - # grouped[i][labels == id] = 0 + grouped[i] = 0 except KeyError: pass return grouped def _can_neuroglance(self): + """ + Returns whether the array can be visualized in neuroglancer + + Returns: + bool: Whether the array can be visualized in neuroglancer + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._can_neuroglance() + True + Notes: + This is a helper function for the neuroglancer layer + + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns a neuroglancer source for the array + + Returns: + neuroglancer.LocalVolume: Neuroglancer source for the array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._neuroglancer_source() + neuroglancer.LocalVolume(...) + Notes: + This is a helper function for the neuroglancer layer + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns a neuroglancer Segmentation layer for the array + + Returns: + neuroglancer.SegmentationLayer: Segmentation layer for the array + dict: Keyword arguments for the layer + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._neuroglancer_layer() + (neuroglancer.SegmentationLayer, dict) + Notes: + This is a helper function for the neuroglancer layer + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -128,4 +349,18 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Returns the name of the source array + + Returns: + str: Name of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._source_name() + 'source_array' + Notes: + This is a helper function for the neuroglancer layer name + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py index 6fae4d51d..2785df4b3 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py @@ -8,8 +8,20 @@ @attr.s class MissingAnnotationsMaskConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem + + Attributes: + source_array_config : ArrayConfig + The Array from which to pull annotated data. Is expected to contain a volume with uint64 voxels and no channel dimension + groupings : List[Tuple[str, List[int]]] + List of id groups with a symantic name. Each id group is a List of ids. + Group i found in groupings[i] will be binarized and placed in channel i. + Note: + The output array will have a channel dimension equal to the number of groups. + Each channel will be a binary mask of the ids in the groupings list. + """ array_type = MissingAnnotationsMask diff --git a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py index 5f2bc0483..f4c3c2ef7 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py @@ -9,7 +9,21 @@ class NumpyArray(Array): - """This is just a wrapper for a numpy array to make it fit the DaCapo Array interface.""" + """ + This is just a wrapper for a numpy array to make it fit the DaCapo Array interface. + + Attributes: + data: The numpy array. + dtype: The data type of the numpy array. + roi: The region of interest of the numpy array. + voxel_size: The voxel size of the numpy array. + axes: The axes of the numpy array. + Methods: + from_gp_array: Create a NumpyArray from a Gunpowder Array. + from_np_array: Create a NumpyArray from a numpy array. + Note: + This class is a subclass of Array. + """ _data: np.ndarray _dtype: np.dtype @@ -18,14 +32,73 @@ class NumpyArray(Array): _axes: List[str] def __init__(self, array_config): + """ + Create a NumpyArray from an array config. + + Args: + array_config: The array config. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray(OnesArrayConfig(source_array_config=ArrayConfig())) + >>> array.data + array([[[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]], + + [[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]]]) + Note: + This method creates a NumpyArray from an array config. + """ raise RuntimeError("Numpy Array cannot be built from a config file") @property def attrs(self): + """ + Returns the attributes of the array. + + Returns: + dict: The attributes of the array. + Raises: + ValueError: If the array does not have attributes. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.attrs + {} + Note: + This method is a property. It returns the attributes of the array. + """ return dict() @classmethod def from_gp_array(cls, array: gp.Array): + """ + Create a NumpyArray from a Gunpowder Array. + + Args: + array (gp.Array): The Gunpowder Array. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = gp.Array(data=np.zeros((2, 3, 4)), spec=gp.ArraySpec(roi=Roi((0, 0, 0), (2, 3, 4)), voxel_size=Coordinate((1, 1, 1)))) + >>> array = NumpyArray.from_gp_array(array) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method creates a NumpyArray from a Gunpowder Array. + """ instance = cls.__new__(cls) instance._data = array.data instance._dtype = array.data.dtype @@ -45,6 +118,32 @@ def from_gp_array(cls, array: gp.Array): @classmethod def from_np_array(cls, array: np.ndarray, roi, voxel_size, axes): + """ + Create a NumpyArray from a numpy array. + + Args: + array (np.ndarray): The numpy array. + roi (Roi): The region of interest of the array. + voxel_size (Coordinate): The voxel size of the array. + axes (List[str]): The axes of the array. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method creates a NumpyArray from a numpy array. + + """ instance = cls.__new__(cls) instance._data = array instance._dtype = array.dtype @@ -55,34 +154,151 @@ def from_np_array(cls, array: np.ndarray, roi, voxel_size, axes): @property def axes(self): + """ + Returns the axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + ValueError: If the array does not have axes. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.axes + ['z', 'y', 'x'] + Note: + This method is a property. It returns the axes of the array. + """ return self._axes @property def dims(self): + """ + Returns the number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + ValueError: If the array does not have a dimension. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.dims + 3 + Note: + This method is a property. It returns the number of dimensions of the array. + """ return self._roi.dims @property def voxel_size(self): + """ + Returns the voxel size of the array. + + Returns: + Coordinate: The voxel size of the array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.voxel_size + Coordinate((1, 1, 1)) + Note: + This method is a property. It returns the voxel size of the array. + """ return self._voxel_size @property def roi(self): + """ + Returns the region of interest of the array. + + Returns: + Roi: The region of interest of the array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.roi + Roi((0, 0, 0), (2, 3, 4)) + Note: + This method is a property. It returns the region of interest of the array. + """ return self._roi @property def writable(self) -> bool: + """ + Returns whether the array is writable. + + Returns: + bool: Whether the array is writable. + Raises: + ValueError: If the array is not writable. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.writable + True + Note: + This method is a property. It returns whether the array is writable. + """ return True @property def data(self): + """ + Returns the numpy array. + + Returns: + np.ndarray: The numpy array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method is a property. It returns the numpy array. + """ return self._data @property def dtype(self): + """ + Returns the data type of the array. + + Returns: + np.dtype: The data type of the array. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.dtype + dtype('float64') + Note: + This method is a property. It returns the data type of the array. + """ return self.data.dtype @property def num_channels(self): + """ + Returns the number of channels in the array. + + Returns: + int: The number of channels in the array. + Raises: + ValueError: If the array does not have a channel dimension. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((1, 2, 3, 4)), Roi((0, 0, 0), (1, 2, 3)), Coordinate((1, 1, 1)), ["b", "c", "z", "y", "x"]) + >>> array.num_channels + 1 + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.num_channels + Traceback (most recent call last): + ... + ValueError: Array does not have a channel dimension. + Note: + This method is a property. It returns the number of channels in the array. + """ try: channel_dim = self.axes.index("c") return self.data.shape[channel_dim] diff --git a/dacapo/experiments/datasplits/datasets/arrays/ones_array.py b/dacapo/experiments/datasplits/datasets/arrays/ones_array.py index 4fe0aaca1..2cc7f7518 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/ones_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/ones_array.py @@ -6,59 +6,400 @@ class OnesArray(Array): - """This is a wrapper around another `source_array` that simply provides ones - with the same metadata as the `source_array`.""" + """ + This is a wrapper around another `source_array` that simply provides ones + with the same metadata as the `source_array`. + + This is useful for creating a mask array that is the same size as the + original array, but with all values set to 1. + + Attributes: + source_array: The source array that this array is based on. + Methods: + like: Create a new OnesArray with the same metadata as another array. + attrs: Get the attributes of the array. + axes: Get the axes of the array. + dims: Get the dimensions of the array. + voxel_size: Get the voxel size of the array. + roi: Get the region of interest of the array. + writable: Check if the array is writable. + data: Get the data of the array. + dtype: Get the data type of the array. + num_channels: Get the number of channels of the array. + __getitem__: Get a subarray of the array. + Note: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + """ def __init__(self, array_config): + """ + Initialize the OnesArray with the given array configuration. + + Args: + array_config: The configuration of the source array. + Raises: + RuntimeError: If the source array is not specified in the + configuration. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> source_array_config = ArrayConfig(source_array) + >>> ones_array = OnesArray(source_array_config) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + """ self._source_array = array_config.source_array_config.array_type( array_config.source_array_config ) @classmethod def like(cls, array: Array): + """ + Create a new OnesArray with the same metadata as another array. + + Args: + array: The source array. + Returns: + The new OnesArray with the same metadata as the source array. + Raises: + RuntimeError: If the source array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray.like(source_array) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + + """ instance = cls.__new__(cls) instance._source_array = array return instance @property def attrs(self): + """ + Get the attributes of the array. + + Returns: + An empty dictionary. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.attrs + {} + Notes: + This method is used to get the attributes of the array. The attributes + are stored as key-value pairs in a dictionary. This method returns an + empty dictionary because the OnesArray does not have any attributes. + """ return dict() @property def source_array(self) -> Array: + """ + Get the source array that this array is based on. + + Returns: + The source array. + Raises: + RuntimeError: If the source array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This method is used to get the source array that this array is based on. + The source array is the array that the OnesArray is created from. This + method returns the source array that was specified when the OnesArray + was created. + """ return self._source_array @property def axes(self): + """ + Get the axes of the array. + + Returns: + The axes of the array. + Raises: + RuntimeError: If the axes are not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.axes + 'zyx' + Notes: + This method is used to get the axes of the array. The axes are the + order of the dimensions of the array. This method returns the axes of + the array that was specified when the OnesArray was created. + """ return self.source_array.axes @property def dims(self): + """ + Get the dimensions of the array. + + Returns: + The dimensions of the array. + Raises: + RuntimeError: If the dimensions are not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.dims + (10, 10, 10) + Notes: + This method is used to get the dimensions of the array. The dimensions + are the size of the array along each axis. This method returns the + dimensions of the array that was specified when the OnesArray was created. + """ return self.source_array.dims @property def voxel_size(self): + """ + Get the voxel size of the array. + + Returns: + The voxel size of the array. + Raises: + RuntimeError: If the voxel size is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.voxel_size + (1.0, 1.0, 1.0) + Notes: + This method is used to get the voxel size of the array. The voxel size + is the size of each voxel in the array. This method returns the voxel + size of the array that was specified when the OnesArray was created. + """ return self.source_array.voxel_size @property def roi(self): + """ + Get the region of interest of the array. + + Returns: + The region of interest of the array. + Raises: + RuntimeError: If the region of interest is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.roi + Roi((0, 0, 0), (10, 10, 10)) + Notes: + This method is used to get the region of interest of the array. The + region of interest is the region of the array that contains the data. + This method returns the region of interest of the array that was specified + when the OnesArray was created. + """ return self.source_array.roi @property def writable(self) -> bool: + """ + Check if the array is writable. + + Returns: + False. + Raises: + RuntimeError: If the writability of the array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.writable + False + Notes: + This method is used to check if the array is writable. An array is + writable if it can be modified in place. This method returns False + because the OnesArray is read-only and cannot be modified. + """ return False @property def data(self): + """ + Get the data of the array. + + Returns: + The data of the array. + Raises: + RuntimeError: If the data is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.data + array([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]) + Notes: + This method is used to get the data of the array. The data is the + values that are stored in the array. This method returns a subarray + of the array with all values set to 1. + """ raise RuntimeError("Cannot get writable version of this data!") @property def dtype(self): + """ + Get the data type of the array. + + Returns: + The data type of the array. + Raises: + RuntimeError: If the data type is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.dtype + + Notes: + This method is used to get the data type of the array. The data type + is the type of the values that are stored in the array. This method + returns the data type of the array that was specified when the OnesArray + was created. + """ return bool @property def num_channels(self): + """ + Get the number of channels of the array. + + Returns: + The number of channels of the array. + Raises: + RuntimeError: If the number of channels is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.num_channels + 1 + Notes: + This method is used to get the number of channels of the array. The + number of channels is the number of values that are stored at each + voxel in the array. This method returns the number of channels of the + array that was specified when the OnesArray was created. + """ return self.source_array.num_channels def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get a subarray of the array. + + Args: + roi: The region of interest. + Returns: + A subarray of the array with all values set to 1. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> roi = Roi((0, 0, 0), (10, 10, 10)) + >>> ones_array[roi] + array([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]) + Notes: + This method is used to get a subarray of the array. The subarray is + specified by the region of interest. This method returns a subarray + of the array with all values set to 1. + """ return np.ones_like(self.source_array.__getitem__(roi), dtype=bool) diff --git a/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py index 649aaa390..152b357c2 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py @@ -6,7 +6,20 @@ @attr.s class OnesArrayConfig(ArrayConfig): - """This array read data from the source array and then return a np.ones_like() version.""" + """ + This array read data from the source array and then return a np.ones_like() version. + + This is useful for creating a mask array from a source array. For example, if you have a + 2D array of data and you want to create a mask array that is the same shape as the data + array, you can use this class to create the mask array. + + Attributes: + source_array_config: The source array that you want to copy and fill with ones. + Methods: + create_array: Create the array. + Note: + This class is a subclass of ArrayConfig. + """ array_type = OnesArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index f74d5bf1d..d406da5ac 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -8,9 +8,76 @@ class ResampledArray(Array): - """This is a zarr array""" + """ + This is a zarr array that is a resampled version of another array. + + Resampling is done by rescaling the source array with the given + upsample and downsample factors. The voxel size of the resampled array + is the voxel size of the source array divided by the downsample factor + and multiplied by the upsample factor. + + Attributes: + name: str + The name of the array + source_array: Array + The source array + upsample: Coordinate + The upsample factor for each dimension + downsample: Coordinate + The downsample factor for each dimension + interp_order: int + The order of the interpolation used for resampling + Methods: + attrs: Dict + Returns the attributes of the source array + axes: str + Returns the axes of the source array + dims: int + Returns the number of dimensions of the source array + voxel_size: Coordinate + Returns the voxel size of the resampled array + roi: Roi + Returns the region of interest of the resampled array + writable: bool + Returns whether the resampled array is writable + dtype: np.dtype + Returns the data type of the resampled array + num_channels: int + Returns the number of channels of the resampled array + data: np.ndarray + Returns the data of the resampled array + scale: Tuple[float] + Returns the scale of the resampled array + __getitem__(roi: Roi) -> np.ndarray + Returns the data of the resampled array within the given region of interest + _can_neuroglance() -> bool + Returns whether the source array can be visualized with neuroglance + _neuroglancer_layer() -> Dict + Returns the neuroglancer layer of the source array + _neuroglancer_source() -> Dict + Returns the neuroglancer source of the source array + _source_name() -> str + Returns the name of the source array + Note: + This class is a subclass of Array. + + + """ def __init__(self, array_config): + """ + Constructor of the ResampledArray class. + + Args: + array_config: ArrayConfig + The configuration of the array + Raises: + AssertionError: If the voxel size of the resampled array is not equal to the voxel size of the source array divided by the downsample factor and multiplied by the upsample factor + Examples: + >>> resampled_array = ResampledArray(array_config) + Note: + This constructor resamples the source array with the given upsample and downsample factors. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -26,38 +93,149 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array. + + Returns: + Dict: The attributes of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.attrs + Note: + This method returns the attributes of the source array. + + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array. + + Returns: + str: The axes of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.axes + Note: + This method returns the axes of the source array. + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array. + + Returns: + int: The number of dimensions of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.dims + Note: + This method returns the number of dimensions of the source array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the resampled array. + + Returns: + Coordinate: The voxel size of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.voxel_size + Note: + This method returns the voxel size of the resampled array. + """ return (self._source_array.voxel_size * self.downsample) / self.upsample @property def roi(self) -> Roi: + """ + Returns the region of interest of the resampled array. + + Returns: + Roi: The region of interest of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.roi + Note: + This method returns the region of interest of the resampled array. + + """ return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink") @property def writable(self) -> bool: + """ + Returns whether the resampled array is writable. + + Returns: + bool: True if the resampled array is writable, False otherwise + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.writable + Note: + This method returns whether the resampled array is writable. + + """ return False @property def dtype(self): + """ + Returns the data type of the resampled array. + + Returns: + np.dtype: The data type of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.dtype + Note: + This method returns the data type of the resampled array. + """ return self._source_array.dtype @property def num_channels(self) -> int: + """ + Returns the number of channels of the resampled array. + + Returns: + int: The number of channels of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.num_channels + Note: + This method returns the number of channels of the resampled array. + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the resampled array. + + Returns: + np.ndarray: The data of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.data + Note: + This method returns the data of the resampled array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -65,6 +243,19 @@ def data(self): @property def scale(self): + """ + Returns the scale of the resampled array. + + Returns: + Tuple[float]: The scale of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.scale + Note: + This method returns the scale of the resampled array. + + """ spatial_scales = tuple(u / d for d, u in zip(self.downsample, self.upsample)) if "c" in self.axes: scales = list(spatial_scales) @@ -74,6 +265,21 @@ def scale(self): return spatial_scales def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the resampled array within the given region of interest. + + Args: + roi: Roi + The region of interest + Returns: + np.ndarray: The data of the resampled array within the given region of interest + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array[roi] + Note: + This method returns the data of the resampled array within the given region of interest. + """ snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow") resampled_array = funlib.persistence.Array( rescale( @@ -88,13 +294,61 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return resampled_array.to_ndarray(roi) def _can_neuroglance(self): + """ + Returns whether the source array can be visualized with neuroglance. + + Returns: + bool: True if the source array can be visualized with neuroglance, False otherwise + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._can_neuroglance() + Note: + This method returns whether the source array can be visualized with neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_layer(self): + """ + Returns the neuroglancer layer of the source array. + + Returns: + Dict: The neuroglancer layer of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._neuroglancer_layer() + Note: + This method returns the neuroglancer layer of the source array. + """ return self._source_array._neuroglancer_layer() def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the source array. + + Returns: + Dict: The neuroglancer source of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._neuroglancer_source() + Note: + This method returns the neuroglancer source of the source array. + """ return self._source_array._neuroglancer_source() def _source_name(self): + """ + Returns the name of the source array. + + Returns: + str: The name of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._source_name() + Note: + This method returns the name of the source array. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py index e080b8304..63ca41aef 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py @@ -8,7 +8,20 @@ @attr.s class ResampledArrayConfig(ArrayConfig): - """This array will up or down sample an array into the desired voxel size.""" + """ + A configuration for a ResampledArray. This array will up or down sample an array into the desired voxel size. + + Attributes: + source_array_config (ArrayConfig): The Array that you want to upsample or downsample. + upsample (Coordinate): The amount by which to upsample! + downsample (Coordinate): The amount by which to downsample! + interp_order (bool): The order of the interpolation! + Methods: + create_array: Creates a ResampledArray from the configuration. + Note: + This class is meant to be used with the ArrayDataset class. + + """ array_type = ResampledArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/sum_array.py b/dacapo/experiments/datasplits/datasets/arrays/sum_array.py index 845b69810..fe5d6e470 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/sum_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/sum_array.py @@ -9,9 +9,56 @@ class SumArray(Array): - """ """ + """ + This class provides a sum array. This array is a virtual array that is created by summing + multiple source arrays. The source arrays must have the same shape and ROI. + + Attributes: + name: str + The name of the array. + _source_arrays: List[Array] + The source arrays to sum. + _source_array: Array + The first source array. + Methods: + __getitem__(roi: Roi) -> np.ndarray + Get the data for the given region of interest. + _can_neuroglance() -> bool + Check if neuroglance can be used. + _neuroglancer_source() -> Dict + Return the source for neuroglance. + _neuroglancer_layer() -> Tuple[neuroglancer.SegmentationLayer, Dict] + Return the neuroglancer layer. + _source_name() -> str + Return the source name. + Note: + This class is a subclass of Array. + """ def __init__(self, array_config): + """ + Initialize the SumArray. + + Args: + array_config: SumArrayConfig + The configuration for the sum array. + Returns: + SumArray: The sum array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays.sum_array import SumArray + >>> from dacapo.experiments.datasplits.datasets.arrays.sum_array_config import SumArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array import TiffArray + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array_config import TiffArrayConfig + >>> from funlib.geometry import Coordinate + >>> from pathlib import Path + >>> sum_array = SumArray(SumArrayConfig(name="sum", source_array_configs=[TiffArrayConfig(file_name=Path("data.tiff"), offset=Coordinate([0, 0, 0]), voxel_size=Coordinate([1, 1, 1]), axes=["x", "y", "z"])])) + Note: + This class is a subclass of Array. + + """ self.name = array_config.name self._source_arrays = [ source_config.array_type(source_config) @@ -21,34 +68,163 @@ def __init__(self, array_config): @property def axes(self): + """ + The axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.axes + ['x', 'y', 'z'] + Note: + This class is a subclass of Array. + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + The number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.dims + 3 + Note: + This class is a subclass of Array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + The size of each voxel in each dimension. + + Returns: + Coordinate: The size of each voxel in each dimension. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.voxel_size + Coordinate([1, 1, 1]) + Note: + This class is a subclass of Array. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + The region of interest of the array. + + Args: + roi: Roi + The region of interest. + Returns: + Roi: The region of interest. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.roi + Roi(Coordinate([0, 0, 0]), Coordinate([100, 100, 100])) + Note: + This class is a subclass of Array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Check if the array is writable. + + Args: + writable: bool + Check if the array is writable. + Returns: + bool: True if the array is writable, otherwise False. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.writable + False + Note: + This class is a subclass of Array. + """ return False @property def dtype(self): + """ + The data type of the array. + + Args: + dtype: np.uint8 + The data type of the array. + Returns: + np.uint8: The data type of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.dtype + np.uint8 + Note: + This class is a subclass of Array. + + """ return np.uint8 @property def num_channels(self): + """ + The number of channels in the array. + + Args: + num_channels: Optional[int] + The number of channels in the array. + Returns: + Optional[int]: The number of channels in the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.num_channels + None + Note: + This class is a subclass of Array. + + """ return None @property def data(self): + """ + Get the data of the array. + + Args: + data: np.ndarray + The data of the array. + Returns: + np.ndarray: The data of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.data + np.array([[[0, 0], [0, 0]], [[0, 0], [0, 0]]]) + Note: + This class is a subclass of Array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -56,20 +232,107 @@ def data(self): @property def attrs(self): + """ + Return the attributes of the array. + + Args: + attrs: Dict + The attributes of the array. + Returns: + Dict: The attributes of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.attrs + {} + Note: + This class is a subclass of Array. + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get the data for the given region of interest. + + Args: + roi: Roi + The region of interest. + Returns: + np.ndarray: The data for the given region of interest. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array[roi] + np.array([[[0, 0], [0, 0]], [[0, 0], [0, 0]]]) + Note: + This class is a subclass of Array. + """ return np.sum( [source_array[roi] for source_array in self._source_arrays], axis=0 ) def _can_neuroglance(self): + """ + Check if neuroglance can be used. + + Args: + can_neuroglance: bool + Check if neuroglance can be used. + Returns: + bool: True if neuroglance can be used, otherwise False. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._can_neuroglance() + False + Note: + This class is a subclass of Array. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Return the source for neuroglance. + + Args: + source: Dict + The source for neuroglance. + Returns: + Dict: The source for neuroglance. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._neuroglancer_source() + {'source': 'precomputed://https://mybucket/segmentation', 'type': 'segmentation', 'voxel_size': [1, 1, 1]} + Note: + This class is a subclass of Array. + + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Return the neuroglancer layer. + + Args: + layer: Tuple[neuroglancer.SegmentationLayer, Dict] + The neuroglancer layer. + Returns: + Tuple[neuroglancer.SegmentationLayer, Dict]: The neuroglancer layer. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._neuroglancer_layer() + (SegmentationLayer(source={'source': 'precomputed://https://mybucket/segmentation', 'type': 'segmentation', 'voxel_size': [1, 1, 1]}, visible=False), {}) + Note: + This class is a subclass of Array. + + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -79,4 +342,22 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Return the source name. + + Args: + source_name: str + The source name. + Returns: + str: The source name. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._source_name() + 'data.tiff' + Note: + This class is a subclass of Array. + + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py index 4cc12ddd7..d235af715 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py @@ -8,6 +8,16 @@ @attr.s class SumArrayConfig(ArrayConfig): + """ + This config class provides the necessary configuration for a sum + array. + + Attributes: + source_array_configs: List[ArrayConfig] + The Array of masks from which to take the union + Note: + This class is a subclass of ArrayConfig. + """ array_type = SumArray source_array_configs: List[ArrayConfig] = attr.ib( diff --git a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py index ccdf50376..ed3ca776b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py @@ -13,7 +13,25 @@ class TiffArray(Array): - """This is a tiff array""" + """ + This class provides the necessary configuration for a tiff array. + + Attributes: + _offset: Coordinate + The offset of the array. + _file_name: Path + The file name of the tiff. + _voxel_size: Coordinate + The voxel size of the array. + _axes: List[str] + The axes of the array. + Methods: + attrs() -> Dict + Return the attributes of the tiff. + Note: + This class is a subclass of Array. + + """ _offset: Coordinate _file_name: Path @@ -21,6 +39,24 @@ class TiffArray(Array): _axes: List[str] def __init__(self, array_config): + """ + Initialize the TiffArray. + + Args: + array_config: TiffArrayConfig + The configuration for the tiff array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array import TiffArray + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array_config import TiffArrayConfig + >>> from funlib.geometry import Coordinate + >>> from pathlib import Path + >>> tiff_array = TiffArray(TiffArrayConfig(file_name=Path("data.tiff"), offset=Coordinate([0, 0, 0]), voxel_size=Coordinate([1, 1, 1]), axes=["x", "y", "z"])) + Note: + This class is a subclass of Array. + """ super().__init__() self._file_name = array_config.file_name @@ -30,20 +66,76 @@ def __init__(self, array_config): @property def attrs(self): + """ + Return the attributes of the tiff. + + Returns: + Dict: The attributes of the tiff. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.attrs + {'axes': ['x', 'y', 'z'], 'offset': [0, 0, 0], 'voxel_size': [1, 1, 1]} + Note: + Tiffs have tons of different locations for metadata. + """ raise NotImplementedError( "Tiffs have tons of different locations for metadata." ) @property def axes(self) -> List[str]: + """ + Return the axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.axes + ['x', 'y', 'z'] + Note: + Tiffs have tons of different locations for metadata. + """ return self._axes @property def dims(self) -> int: + """ + Return the number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.dims + 3 + Note: + Tiffs have tons of different locations for metadata. + """ return self.voxel_size.dims @lazy_property.LazyProperty def shape(self) -> Coordinate: + """ + Return the shape of the array. + + Returns: + Coordinate: The shape of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.shape + Coordinate([100, 100, 100]) + Note: + Tiffs have tons of different locations for metadata. + """ data_shape = self.data.shape spatial_shape = Coordinate( [data_shape[self.axes.index(axis)] for axis in self.spatial_axes] @@ -52,22 +144,94 @@ def shape(self) -> Coordinate: @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Return the voxel size of the array. + + Returns: + Coordinate: The voxel size of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.voxel_size + Coordinate([1, 1, 1]) + Note: + Tiffs have tons of different locations for metadata. + """ return self._voxel_size @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Return the region of interest of the array. + + Returns: + Roi: The region of interest of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.roi + Roi([0, 0, 0], [100, 100, 100]) + Note: + Tiffs have tons of different locations for metadata. + """ return Roi(self._offset, self.shape) @property def writable(self) -> bool: + """ + Return whether the array is writable. + + Returns: + bool: Whether the array is writable. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.writable + False + Note: + Tiffs have tons of different locations for metadata. + """ return False @property def dtype(self): + """ + Return the data type of the array. + + Returns: + np.dtype: The data type of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.dtype + np.float32 + Note: + Tiffs have tons of different locations for metadata. + + """ return self.data.dtype @property def num_channels(self) -> Optional[int]: + """ + Return the number of channels of the array. + + Returns: + Optional[int]: The number of channels of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.num_channels + 1 + Note: + Tiffs have tons of different locations for metadata. + + """ if "c" in self.axes: return self.data.shape[self.axes.index("c")] else: @@ -75,8 +239,36 @@ def num_channels(self) -> Optional[int]: @property def spatial_axes(self) -> List[str]: + """ + Return the spatial axes of the array. + + Returns: + List[str]: The spatial axes of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.spatial_axes + ['x', 'y', 'z'] + Note: + Tiffs have tons of different locations for metadata. + """ return [c for c in self.axes if c != "c"] @lazy_property.LazyProperty def data(self): + """ + Return the data of the tiff. + + Returns: + np.ndarray: The data of the tiff. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.data + np.ndarray + Note: + Tiffs have tons of different locations for metadata. + """ return tifffile.TiffFile(self._file_name).values diff --git a/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py index d1930e55a..b6c6c0a3f 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py @@ -11,7 +11,21 @@ @attr.s class ZarrArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a tiff array""" + """ + This config class provides the necessary configuration for a tiff array + + Attributes: + file_name: Path + The file name of the tiff. + offset: Coordinate + The offset of the array. + voxel_size: Coordinate + The voxel size of the array. + axes: List[str] + The axes of the array. + Note: + This class is a subclass of ArrayConfig. + """ array_type = TiffArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 51046fd2e..2288a142c 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -20,9 +20,85 @@ class ZarrArray(Array): - """This is a zarr array""" + """ + This is a zarr array. + + Attributes: + name (str): The name of the array. + file_name (Path): The file name of the array. + dataset (str): The dataset name. + _axes (Optional[List[str]]): The axes of the array. + snap_to_grid (Optional[Coordinate]): The snap to grid. + Methods: + __init__(array_config): + Initializes the array type 'raw' and name for the DummyDataset instance. + __str__(): + Returns the string representation of the ZarrArray. + __repr__(): + Returns the string representation of the ZarrArray. + attrs(): + Returns the attributes of the array. + axes(): + Returns the axes of the array. + dims(): + Returns the dimensions of the array. + _daisy_array(): + Returns the daisy array. + voxel_size(): + Returns the voxel size of the array. + roi(): + Returns the region of interest of the array. + writable(): + Returns the boolean value of the array. + dtype(): + Returns the data type of the array. + num_channels(): + Returns the number of channels of the array. + spatial_axes(): + Returns the spatial axes of the array. + data(): + Returns the data of the array. + __getitem__(roi): + Returns the data of the array for the given region of interest. + __setitem__(roi, value): + Sets the data of the array for the given region of interest. + create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False): + Creates a new ZarrArray given an array identifier. + open_from_array_identifier(array_identifier, name=""): + Opens a new ZarrArray given an array identifier. + _can_neuroglance(): + Returns the boolean value of the array. + _neuroglancer_source(): + Returns the neuroglancer source of the array. + _neuroglancer_layer(): + Returns the neuroglancer layer of the array. + _transform_matrix(): + Returns the transform matrix of the array. + _output_dimensions(): + Returns the output dimensions of the array. + _source_name(): + Returns the source name of the array. + add_metadata(metadata): + Adds metadata to the array. + Notes: + This class is used to create a zarr array. + """ def __init__(self, array_config): + """ + Initializes the array type 'raw' and name for the DummyDataset instance. + + Args: + array_config (object): an instance of a configuration class that includes the name and + raw configuration of the data. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = DummyDataset(dataset_config) + Notes: + This method is used to initialize the dataset. + """ super().__init__() self.name = array_config.name self.file_name = array_config.file_name @@ -33,17 +109,75 @@ def __init__(self, array_config): self.snap_to_grid = array_config.snap_to_grid def __str__(self): + """ + Returns the string representation of the ZarrArray. + + Args: + ZarrArray (str): The string representation of the ZarrArray. + Returns: + str: The string representation of the ZarrArray. + Raises: + NotImplementedError + Examples: + >>> print(ZarrArray) + Notes: + This method is used to return the string representation of the ZarrArray. + """ return f"ZarrArray({self.file_name}, {self.dataset})" def __repr__(self): + """ + Returns the string representation of the ZarrArray. + + Args: + ZarrArray (str): The string representation of the ZarrArray. + Returns: + str: The string representation of the ZarrArray. + Raises: + NotImplementedError + Examples: + >>> print(ZarrArray) + Notes: + This method is used to return the string representation of the ZarrArray. + + """ return f"ZarrArray({self.file_name}, {self.dataset})" @property def attrs(self): + """ + Returns the attributes of the array. + + Args: + attrs (Any): The attributes of the array. + Returns: + Any: The attributes of the array. + Raises: + NotImplementedError + Examples: + >>> attrs() + Notes: + This method is used to return the attributes of the array. + + """ return self.data.attrs @property def axes(self): + """ + Returns the axes of the array. + + Args: + axes (List[str]): The axes of the array. + Returns: + List[str]: The axes of the array. + Raises: + NotImplementedError + Examples: + >>> axes() + Notes: + This method is used to return the axes of the array. + """ if self._axes is not None: return self._axes try: @@ -58,18 +192,77 @@ def axes(self): @property def dims(self) -> int: + """ + Returns the dimensions of the array. + + Args: + dims (int): The dimensions of the array. + Returns: + int: The dimensions of the array. + Raises: + NotImplementedError + Examples: + >>> dims() + Notes: + This method is used to return the dimensions of the array. + + """ return self.voxel_size.dims @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: + """ + Returns the daisy array. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + funlib.persistence.Array: The daisy array. + Raises: + NotImplementedError + Examples: + >>> _daisy_array() + Notes: + This method is used to return the daisy array. + + """ return funlib.persistence.open_ds(f"{self.file_name}", self.dataset) @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the array. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + Coordinate: The voxel size of the array. + Raises: + NotImplementedError + Examples: + >>> voxel_size() + Notes: + This method is used to return the voxel size of the array. + + """ return self._daisy_array.voxel_size @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Returns the region of interest of the array. + + Args: + roi (Roi): The region of interest. + Returns: + Roi: The region of interest of the array. + Raises: + NotImplementedError + Examples: + >>> roi() + Notes: + This method is used to return the region of interest of the array. + """ if self.snap_to_grid is not None: return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink") else: @@ -77,32 +270,131 @@ def roi(self) -> Roi: @property def writable(self) -> bool: + """ + Returns the boolean value of the array. + + Args: + writable (bool): The boolean value of the array. + Returns: + bool: The boolean value of the array. + Raises: + NotImplementedError + Examples: + >>> writable() + Notes: + This method is used to return the boolean value of the array. + """ return True @property def dtype(self) -> Any: + """ + Returns the data type of the array. + + Args: + dtype (Any): The data type of the array. + Returns: + Any: The data type of the array. + Raises: + NotImplementedError + Examples: + >>> dtype() + Notes: + This method is used to return the data type of the array. + """ return self.data.dtype @property def num_channels(self) -> Optional[int]: + """ + Returns the number of channels of the array. + + Args: + num_channels (Optional[int]): The number of channels of the array. + Returns: + Optional[int]: The number of channels of the array. + Raises: + NotImplementedError + Examples: + >>> num_channels() + Notes: + This method is used to return the number of channels of the array. + + """ return None if "c" not in self.axes else self.data.shape[self.axes.index("c")] @property def spatial_axes(self) -> List[str]: + """ + Returns the spatial axes of the array. + + Args: + spatial_axes (List[str]): The spatial axes of the array. + Returns: + List[str]: The spatial axes of the array. + Raises: + NotImplementedError + Examples: + >>> spatial_axes() + Notes: + This method is used to return the spatial axes of the array. + + """ return [ax for ax in self.axes if ax not in set(["c", "b"])] @property def data(self) -> Any: + """ + Returns the data of the array. + + Args: + data (Any): The data of the array. + Returns: + Any: The data of the array. + Raises: + NotImplementedError + Examples: + >>> data() + Notes: + This method is used to return the data of the array. + """ zarr_container = zarr.open(str(self.file_name)) return zarr_container[self.dataset] def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the array for the given region of interest. + + Args: + roi (Roi): The region of interest. + Returns: + np.ndarray: The data of the array for the given region of interest. + Raises: + NotImplementedError + Examples: + >>> __getitem__(roi) + Notes: + This method is used to return the data of the array for the given region of interest. + """ data: np.ndarray = funlib.persistence.Array( self.data, self.roi, self.voxel_size ).to_ndarray(roi=roi) return data def __setitem__(self, roi: Roi, value: np.ndarray): + """ + Sets the data of the array for the given region of interest. + + Args: + roi (Roi): The region of interest. + value (np.ndarray): The value to set. + Raises: + NotImplementedError + Examples: + >>> __setitem__(roi, value) + Notes: + This method is used to set the data of the array for the given region of interest. + """ funlib.persistence.Array(self.data, self.roi, self.voxel_size)[roi] = value @classmethod @@ -120,7 +412,26 @@ def create_from_array_identifier( ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist + this array_identifier points to a dataset that does not yet exist. + + Args: + array_identifier (ArrayIdentifier): The array identifier. + axes (List[str]): The axes of the array. + roi (Roi): The region of interest. + num_channels (int): The number of channels. + voxel_size (Coordinate): The voxel size. + dtype (Any): The data type. + write_size (Optional[Coordinate]): The write size. + name (Optional[str]): The name of the array. + overwrite (bool): The boolean value to overwrite the array. + Returns: + ZarrArray: The ZarrArray. + Raises: + NotImplementedError + Examples: + >>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False) + Notes: + This method is used to create a new ZarrArray given an array identifier. """ if write_size is None: # total storage per block is approx c*x*y*z*dtype_size @@ -214,6 +525,21 @@ def create_from_array_identifier( @classmethod def open_from_array_identifier(cls, array_identifier, name=""): + """ + Opens a new ZarrArray given an array identifier. + + Args: + array_identifier (ArrayIdentifier): The array identifier. + name (str): The name of the array. + Returns: + ZarrArray: The ZarrArray. + Raises: + NotImplementedError + Examples: + >>> open_from_array_identifier(array_identifier, name="") + Notes: + This method is used to open a new ZarrArray given an array identifier. + """ zarr_array = cls.__new__(cls) zarr_array.name = name zarr_array.file_name = array_identifier.container @@ -224,9 +550,38 @@ def open_from_array_identifier(cls, array_identifier, name=""): return zarr_array def _can_neuroglance(self) -> bool: + """ + Returns the boolean value of the array. + + Args: + can_neuroglance (bool): The boolean value of the array. + Returns: + bool: The boolean value of the array. + Raises: + NotImplementedError + Examples: + >>> can_neuroglance() + Notes: + This method is used to return the boolean value of the array. + """ return True def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the array. + + Args: + neuroglancer.LocalVolume: The neuroglancer source of the array. + Returns: + neuroglancer.LocalVolume: The neuroglancer source of the array. + Raises: + NotImplementedError + Examples: + >>> neuroglancer_source() + Notes: + This method is used to return the neuroglancer source of the array. + + """ d = open_ds(str(self.file_name), self.dataset) return neuroglancer.LocalVolume( data=d.data, @@ -239,10 +594,38 @@ def _neuroglancer_source(self): ) def _neuroglancer_layer(self) -> Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: + """ + Returns the neuroglancer layer of the array. + + Args: + layer (neuroglancer.ImageLayer): The neuroglancer layer of the array. + Returns: + Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: The neuroglancer layer of the array. + Raises: + NotImplementedError + Examples: + >>> neuroglancer_layer() + Notes: + This method is used to return the neuroglancer layer of the array. + """ layer = neuroglancer.ImageLayer(source=self._neuroglancer_source()) return layer def _transform_matrix(self): + """ + Returns the transform matrix of the array. + + Args: + transform_matrix (List[List[float]]): The transform matrix of the array. + Returns: + List[List[float]]: The transform matrix of the array. + Raises: + NotImplementedError + Examples: + >>> transform_matrix() + Notes: + This method is used to return the transform matrix of the array. + """ is_zarr = self.file_name.name.endswith(".zarr") if is_zarr: offset = self.roi.offset @@ -267,6 +650,20 @@ def _transform_matrix(self): return [[0] * i + [1] + [0] * (self.dims - i) for i in range(self.dims)] def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: + """ + Returns the output dimensions of the array. + + Args: + output_dimensions (Dict[str, Tuple[float, str]]): The output dimensions of the array. + Returns: + Dict[str, Tuple[float, str]]: The output dimensions of the array. + Raises: + NotImplementedError + Examples: + >>> output_dimensions() + Notes: + This method is used to return the output dimensions of the array. + """ is_zarr = self.file_name.name.endswith(".zarr") if is_zarr: spatial_dimensions = OrderedDict() @@ -282,9 +679,37 @@ def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: } def _source_name(self) -> str: + """ + Returns the source name of the array. + + Args: + source_name (str): The source name of the array. + Returns: + str: The source name of the array. + Raises: + NotImplementedError + Examples: + >>> source_name() + Notes: + This method is used to return the source name of the array. + + """ return self.name def add_metadata(self, metadata: Dict[str, Any]) -> None: + """ + Adds metadata to the array. + + Args: + metadata (Dict[str, Any]): The metadata to add to the array. + Raises: + NotImplementedError + Examples: + >>> add_metadata(metadata) + Notes: + This method is used to add metadata to the array. + + """ dataset = zarr.open(self.file_name, mode="a")[self.dataset] for k, v in metadata.items(): dataset.attrs[k] = v diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py index 69bce2378..b667e3768 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py @@ -12,7 +12,29 @@ @attr.s class ZarrArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a zarr array""" + """ + This config class provides the necessary configuration for a zarr array. + + A zarr array is a container for large, multi-dimensional arrays. It is similar to HDF5, but is designed to work + with large arrays that do not fit into memory. Zarr arrays can be stored on disk or in the cloud + and can be accessed concurrently by multiple processes. Zarr arrays can be compressed and + support chunked, N-dimensional arrays. + + Attributes: + file_name: Path + The file name of the zarr container. + dataset: str + The name of your dataset. May include '/' characters for nested heirarchies + snap_to_grid: Optional[Coordinate] + If you need to make sure your ROI's align with a specific voxel_size + _axes: Optional[List[str]] + The axes of your data! + Methods: + verify() -> Tuple[bool, str] + Check whether this is a valid Array + Note: + This class is a subclass of ArrayConfig. + """ array_type = ZarrArray @@ -37,6 +59,23 @@ class ZarrArrayConfig(ArrayConfig): def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: A tuple of a boolean and a string. The boolean indicates whether the Array is valid or not. + The string provides a reason why the Array is not valid. + Raises: + NotImplementedError: This method is not implemented for this Array + Examples: + >>> zarr_array_config = ZarrArrayConfig( + ... file_name=Path("data.zarr"), + ... dataset="data", + ... snap_to_grid=Coordinate(1, 1, 1), + ... _axes=["x", "y", "z"] + ... ) + >>> zarr_array_config.verify() + (True, 'No validation for this Array') + Note: + This method is not implemented for this Array """ if not self.file_name.exists(): return False, f"{self.file_name} does not exist!" diff --git a/dacapo/experiments/datasplits/datasets/dataset.py b/dacapo/experiments/datasplits/datasets/dataset.py index 663805227..ced4f58d6 100644 --- a/dacapo/experiments/datasplits/datasets/dataset.py +++ b/dacapo/experiments/datasplits/datasets/dataset.py @@ -15,6 +15,20 @@ class Dataset(ABC): mask (Array, optional): The mask for the data. weight (int, optional): The weight of the dataset. sample_points (list[Coordinate], optional): The list of sample points in the dataset. + Methods: + __eq__(other): + Overloaded equality operator for dataset objects. + __hash__(): + Calculates a hash for the dataset. + __repr__(): + Returns the official string representation of the dataset object. + __str__(): + Returns the string representation of the dataset object. + _neuroglancer_layers(prefix="", exclude_layers=None): + Generates neuroglancer layers for raw, gt and mask if they can be viewed by neuroglance, excluding those in + the exclude_layers. + Notes: + This class is a base class and should not be instantiated. """ name: str @@ -30,9 +44,17 @@ def __eq__(self, other: Any) -> bool: Args: other (Any): The object to compare with the dataset. - Returns: bool: True if the object is also a dataset and they have the same name, False otherwise. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset1 = Dataset("dataset1") + >>> dataset2 = Dataset("dataset2") + >>> dataset1 == dataset2 + False + Notes: + This method is used to compare two dataset objects. """ return isinstance(other, type(self)) and self.name == other.name @@ -42,6 +64,14 @@ def __hash__(self) -> int: Returns: int: The hash of the dataset name. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> hash(dataset) + 123456 + Notes: + This method is used to calculate a hash for the dataset. """ return hash(self.name) @@ -51,6 +81,14 @@ def __repr__(self) -> str: Returns: str: String representation of the dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> dataset + Dataset(dataset) + Notes: + This method is used to return the official string representation of the dataset object. """ return f"Dataset({self.name})" @@ -58,8 +96,18 @@ def __str__(self) -> str: """ Returns the string representation of the dataset object. + Args: + self (Dataset): The dataset object. Returns: str: String representation of the dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> print(dataset) + Dataset(dataset) + Notes: + This method is used to return the string representation of the dataset object. """ return f"Dataset({self.name})" @@ -71,9 +119,16 @@ def _neuroglancer_layers(self, prefix="", exclude_layers=None): Args: prefix (str, optional): A prefix to be added to the layer names. exclude_layers (set, optional): A set of layer names to exclude. - Returns: dict: A dictionary containing layer names as keys and corresponding neuroglancer layer as values. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> dataset._neuroglancer_layers() + {"raw": neuroglancer_layer} + Notes: + This method is used to generate neuroglancer layers for raw, gt and mask if they can be viewed by neuroglance. """ layers = {} exclude_layers = exclude_layers if exclude_layers is not None else set() diff --git a/dacapo/experiments/datasplits/datasets/dataset_config.py b/dacapo/experiments/datasplits/datasets/dataset_config.py index c860d600e..4217eb00e 100644 --- a/dacapo/experiments/datasplits/datasets/dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/dataset_config.py @@ -5,7 +5,8 @@ @attr.s class DatasetConfig: - """A class used to define configuration for datasets. This provides the + """ + A class used to define configuration for datasets. This provides the framework to create a Dataset instance. Attributes: @@ -18,11 +19,12 @@ class DatasetConfig: A numeric value that indicates how frequently this dataset should be sampled in comparison to others. Higher the weight, more frequently it gets sampled. - Methods: verify: Checks and validates the dataset configuration. The specific rules for validation need to be defined by the user. + Notes: + This class is used to create a configuration object for datasets. """ name: str = attr.ib( @@ -51,5 +53,13 @@ def verify(self) -> Tuple[bool, str]: Returns: tuple: A tuple of boolean value indicating the check (True or False) and message specifying result of validation. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset_config = DatasetConfig(name="sample_dataset") + >>> dataset_config.verify() + (True, "No validation for this DataSet") + Notes: + This method is used to validate the configuration of the dataset. """ return True, "No validation for this DataSet" diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset.py b/dacapo/experiments/datasplits/datasets/dummy_dataset.py index cec9e05b4..4fc34e84b 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset.py @@ -3,20 +3,35 @@ class DummyDataset(Dataset): - """DummyDataset is a child class of the Dataset. This class has property 'raw' of Array type and a name. + """ + DummyDataset is a child class of the Dataset. This class has property 'raw' of Array type and a name. - Args: - dataset_config (object): an instance of a configuration class. + Attributes: + raw: Array + The raw data. + Methods: + __init__(dataset_config): + Initializes the array type 'raw' and name for the DummyDataset instance. + Notes: + This class is used to create a dataset with raw data. """ raw: Array def __init__(self, dataset_config): - """Initializes the array type 'raw' and name for the DummyDataset instance. + """ + Initializes the array type 'raw' and name for the DummyDataset instance. Args: dataset_config (object): an instance of a configuration class that includes the name and raw configuration of the data. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = DummyDataset(dataset_config) + Notes: + This method is used to initialize the dataset. """ super().__init__() self.name = dataset_config.name diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py b/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py index ecdf3e36e..6aaefc98a 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py @@ -15,9 +15,10 @@ class DummyDatasetConfig(DatasetConfig): Attributes: dataset_type : Clearly mentions the type of dataset raw_config : This attribute holds the configurations related to dataset arrays. - Methods: verify: A dummy verification method for testing purposes, always returns False and a message. + Notes: + This class is used to create a configuration object for the dummy dataset. """ dataset_type = DummyDataset @@ -25,10 +26,20 @@ class DummyDatasetConfig(DatasetConfig): raw_config: ArrayConfig = attr.ib(DummyArrayConfig(name="dummy_array")) def verify(self) -> Tuple[bool, str]: - """A dummy method that always indicates the dataset config is not valid. + """ + A dummy method that always indicates the dataset config is not valid. Returns: A tuple of False and a message indicating the invalidity. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset_config = DummyDatasetConfig(raw_config=DummyArrayConfig(name="dummy_array")) + >>> dataset_config.verify() + (False, "This is a DummyDatasetConfig and is never valid") + Notes: + This method is used to validate the configuration of the dataset. """ return False, "This is a DummyDatasetConfig and is never valid" diff --git a/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py b/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py index d7d587d78..1a2a7745f 100644 --- a/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py +++ b/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py @@ -3,9 +3,17 @@ @attr.s class GraphStoreConfig: - """Base class for graph store configurations. Each subclass of a + """ + Base class for graph store configurations. Each subclass of a `GraphStore` should have a corresponding config class derived from `GraphStoreConfig`. + + Attributes: + store_type (class): The type of graph store that is being configured. + Methods: + verify: A method to verify the validity of the configuration. + Notes: + This class is used to create a configuration object for the graph store. """ pass diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py index 040c5baa3..1b81e1c1f 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py @@ -7,12 +7,46 @@ class RawGTDataset(Dataset): + """ + A dataset that contains raw and ground truth data. Optionally, it can also contain a mask. + + Attributes: + raw: Array + The raw data. + gt: Array + The ground truth data. + mask: Optional[Array] + The mask data. + sample_points: Optional[List[Coordinate]] + The sample points in the graph. + weight: Optional[float] + The weight of the dataset. + Methods: + __init__(dataset_config): + Initialize the dataset. + Notes: + This class is a base class and should not be instantiated. + """ raw: Array gt: Array mask: Optional[Array] sample_points: Optional[List[Coordinate]] def __init__(self, dataset_config): + """ + Initialize the dataset. + + Args: + dataset_config: DataSplitConfig + The configuration of the dataset. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = RawGTDataset(dataset_config) + Notes: + This method is used to initialize the dataset. + """ self.name = dataset_config.name self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config) self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config) diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py index 705bcb467..e967b83d6 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py @@ -27,6 +27,10 @@ class RawGTDatasetConfig(DatasetConfig): equal to zero on voxels where the mask is 1. sample_points (Optional[List[Coordinate]]): An optional list of points around which training samples will be extracted. + Methods: + verify: A method to verify the validity of the configuration. + Notes: + This class is used to create a configuration object for the standard dataset with both raw and GT Array. """ dataset_type = RawGTDataset diff --git a/dacapo/experiments/datasplits/datasplit.py b/dacapo/experiments/datasplits/datasplit.py index 62eaa4b27..ddf9d6ee2 100644 --- a/dacapo/experiments/datasplits/datasplit.py +++ b/dacapo/experiments/datasplits/datasplit.py @@ -7,10 +7,45 @@ class DataSplit(ABC): + """ + A class for creating a simple train dataset and no validation dataset. It is derived from `DataSplit` class. + It is used to split the data into training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train : list + The list containing training datasets. In this class, it contains only one dataset for training. + validate : list + The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. + Methods: + __init__(self, datasplit_config): + The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + Notes: + This class is used to split the data into training and validation datasets. + """ train: List[Dataset] validate: Optional[List[Dataset]] def _neuroglancer(self, embedded=False): + """ + A method to visualize the data in Neuroglancer. It creates a Neuroglancer viewer and adds the layers of the training and validation datasets to it. + + Args: + embedded : bool + A boolean flag to indicate if the Neuroglancer viewer is to be embedded in the notebook. + Returns: + viewer : obj + The Neuroglancer viewer object. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which is logged and handled by training the model without head matching. + Examples: + >>> viewer = datasplit._neuroglancer(embedded=True) + Notes: + This function is called by the DataSplit class to visualize the data in Neuroglancer. + It creates a Neuroglancer viewer and adds the layers of the training and validation datasets to it. + Neuroglancer is a powerful tool for visualizing large-scale volumetric data. + """ neuroglancer.set_server_bind_address("0.0.0.0") viewer = neuroglancer.Viewer() with viewer.txn() as s: diff --git a/dacapo/experiments/datasplits/datasplit_config.py b/dacapo/experiments/datasplits/datasplit_config.py index f00069960..992113d47 100644 --- a/dacapo/experiments/datasplits/datasplit_config.py +++ b/dacapo/experiments/datasplits/datasplit_config.py @@ -8,17 +8,16 @@ class DataSplitConfig: """ A class used to create a DataSplit configuration object. - Attributes - ---------- - name : str - A name for the datasplit. This name will be saved so it can be found - and reused easily. It is recommended to keep it short and avoid special - characters. - - Methods - ------- - verify() -> Tuple[bool, str]: - Validates if it is a valid data split configuration. + Attributes: + name : str + A name for the datasplit. This name will be saved so it can be found + and reused easily. It is recommended to keep it short and avoid special + characters. + Methods: + verify() -> Tuple[bool, str]: + Validates if it is a valid data split configuration. + Notes: + This class is used to create a DataSplit configuration object. """ name: str = attr.ib( @@ -33,10 +32,18 @@ def verify(self) -> Tuple[bool, str]: """ Validates if the current configuration is a valid data split configuration. - Returns - ------- - Tuple[bool, str] - True if the configuration is valid, - False otherwise along with respective validation error message. + Returns: + Tuple[bool, str] + True if the configuration is valid, + False otherwise along with respective validation error message. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> datasplit_config = DataSplitConfig(name="datasplit") + >>> datasplit_config.verify() + (True, "No validation for this DataSplit") + Notes: + This method is used to validate the configuration of DataSplit. """ return True, "No validation for this DataSplit" diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 8f177e187..d9a91d5db 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -22,6 +22,24 @@ def is_zarr_group(file_name: str, dataset: str): + """ + Check if the dataset is a Zarr group. If the dataset is a Zarr group, it will return True, otherwise False. + + Args: + file_name : str + The name of the file. + dataset : str + The name of the dataset. + Returns: + bool : True if the dataset is a Zarr group, otherwise False. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> is_zarr_group(file_name, dataset) + Notes: + This function is used to check if the dataset is a Zarr group. + """ zarr_file = zarr.open(str(file_name)) return isinstance(zarr_file[dataset], zarr.hierarchy.Group) @@ -29,6 +47,26 @@ def is_zarr_group(file_name: str, dataset: str): def resize_if_needed( array_config: ZarrArrayConfig, target_resolution: Coordinate, extra_str="" ): + """ + Resize the array if needed. If the array needs to be resized, it will return the resized array, otherwise it will return the original array. + + Args: + array_config : obj + The configuration of the array. + target_resolution : obj + The target resolution. + extra_str : str + An extra string. + Returns: + obj : The resized array if needed, otherwise the original array. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> resize_if_needed(array_config, target_resolution, extra_str) + Notes: + This function is used to resize the array if needed. + """ zarr_array = ZarrArray(array_config) raw_voxel_size = zarr_array.voxel_size @@ -49,6 +87,28 @@ def resize_if_needed( def get_right_resolution_array_config( container: Path, dataset, target_resolution, extra_str="" ): + """ + Get the right resolution array configuration. It will return the right resolution array configuration. + + Args: + container : obj + The container. + dataset : str + The dataset. + target_resolution : obj + The target resolution. + extra_str : str + An extra string. + Returns: + obj : The right resolution array configuration. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> get_right_resolution_array_config(container, dataset, target_resolution, extra_str) + Notes: + This function is used to get the right resolution array configuration. + """ level = 0 current_dataset_path = Path(dataset, f"s{level}") if not (container / current_dataset_path).exists(): @@ -80,7 +140,35 @@ def get_right_resolution_array_config( class CustomEnumMeta(EnumMeta): + """ + Custom Enum Meta class to raise KeyError when an invalid option is passed. + + Attributes: + _member_names_ : list + The list of member names. + Methods: + __getitem__(self, item) + A method to get the item. + Notes: + This class is used to raise KeyError when an invalid option is passed. + """ def __getitem__(self, item): + """ + Get the item. + + Args: + item : obj + The item. + Returns: + obj : The item. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __getitem__(item) + Notes: + This function is used to get the item. + """ if item not in self._member_names_: raise KeyError( f"{item} is not a valid option of {self.__name__}, the valid options are {self._member_names_}" @@ -89,21 +177,99 @@ def __getitem__(self, item): class CustomEnum(Enum, metaclass=CustomEnumMeta): + """ + A custom Enum class to raise KeyError when an invalid option is passed. + + Attributes: + __str__ : str + The string representation of the class. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to raise KeyError when an invalid option is passed. + """ def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return self.name class DatasetType(CustomEnum): + """ + An Enum class to represent the dataset type. It is derived from `CustomEnum` class. + + Attributes: + val : int + The value of the dataset type. + train : int + The training dataset type. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to represent the dataset type. + """ val = 1 train = 2 class SegmentationType(CustomEnum): + """ + An Enum class to represent the segmentation type. It is derived from `CustomEnum` class. + + Attributes: + semantic : int + The semantic segmentation type. + instance : int + The instance segmentation type. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to represent the segmentation type. + """ semantic = 1 instance = 2 class DatasetSpec: + """ + A class for dataset specification. It is used to specify the dataset. + + Attributes: + dataset_type : obj + The dataset type. + raw_container : obj + The raw container. + raw_dataset : str + The raw dataset. + gt_container : obj + The ground truth container. + gt_dataset : str + The ground truth dataset. + Methods: + __init__(dataset_type, raw_container, raw_dataset, gt_container, gt_dataset) + Initializes the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to specify the dataset. + """ def __init__( self, dataset_type: Union[str, DatasetType], @@ -112,6 +278,28 @@ def __init__( gt_container: Union[str, Path], gt_dataset: str, ): + """ + Initializes the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + + Args: + dataset_type : obj + The dataset type. + raw_container : obj + The raw container. + raw_dataset : str + The raw dataset. + gt_container : obj + The ground truth container. + gt_dataset : str + The ground truth dataset. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Methods: + __init__(dataset_type, raw_container, raw_dataset, gt_container, gt_dataset) + Notes: + This function is used to initialize the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + """ if isinstance(dataset_type, str): dataset_type = DatasetType[dataset_type.lower()] @@ -128,10 +316,42 @@ def __init__( self.gt_dataset = gt_dataset def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return f"{self.raw_container.stem}_{self.gt_dataset}" def generate_dataspec_from_csv(csv_path: Path): + """ + Generate the dataset specification from the CSV file. It will return the dataset specification. + + Args: + csv_path : obj + The CSV file path. + Returns: + list : The dataset specification. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> generate_dataspec_from_csv(csv_path) + Notes: + This function is used to generate the dataset specification from the CSV file. + """ datasets = [] if not csv_path.exists(): raise FileNotFoundError(f"CSV file {csv_path} does not exist.") @@ -158,14 +378,73 @@ def generate_dataspec_from_csv(csv_path: Path): class DataSplitGenerator: - """Generates DataSplitConfig for a given task config and datasets. - class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class - Currently only supports: - - semantic segmentation. - Supports: + """ + Generates DataSplitConfig for a given task config and datasets. A csv file can be generated + from the DataSplitConfig and used to generate the DataSplitConfig again. + + Currently only supports semantic segmentation. + Supports: - 2D and 3D datasets. - Zarr, N5 and OME-Zarr datasets. - Multi class targets. + - Different resolutions for raw and ground truth datasets. + - Different resolutions for training and validation datasets. + + Attributes: + name : str + The name of the data split generator. + datasets : list + The list of dataset specifications. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + targets : list + The list of targets. + segmentation_type : obj + The segmentation type. + max_gt_downsample : int + The maximum ground truth downsample. + max_gt_upsample : int + The maximum ground truth upsample. + max_raw_training_downsample : int + The maximum raw training downsample. + max_raw_training_upsample : int + The maximum raw training upsample. + max_raw_validation_downsample : int + The maximum raw validation downsample. + max_raw_validation_upsample : int + The maximum raw validation upsample. + min_training_volume_size : int + The minimum training volume size. + raw_min : int + The minimum raw value. + raw_max : int + The maximum raw value. + classes_separator_caracter : str + The classes separator character. + Methods: + __init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter) + Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character. + __str__(self) + A method to get the string representation of the class. + class_name(self) + A method to get the class name. + check_class_name(self, class_name) + A method to check the class name. + compute(self) + A method to compute the data split. + __generate_semantic_seg_datasplit(self) + A method to generate the semantic segmentation data split. + __generate_semantic_seg_dataset_crop(self, dataset) + A method to generate the semantic segmentation dataset crop. + generate_csv(datasets, csv_path) + A method to generate the CSV file. + generate_from_csv(csv_path, input_resolution, output_resolution, name, **kwargs) + A method to generate the data split from the CSV file. + Notes: + - This class is used to generate the DataSplitConfig for a given task config and datasets. + - Class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class """ def __init__( @@ -187,6 +466,69 @@ def __init__( raw_max=255, classes_separator_caracter="&", ): + """ + Initializes the DataSplitGenerator class with the specified: + - name + - datasets + - input resolution + - output resolution + - targets + - segmentation type + - maximum ground truth downsample + - maximum ground truth upsample + - maximum raw training downsample + - maximum raw training upsample + - maximum raw validation downsample + - maximum raw validation upsample + - minimum training volume size + - minimum raw value + - maximum raw value + - classes separator character + + Args: + name : str + The name of the data split generator. + datasets : list + The list of dataset specifications. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + targets : list + The list of targets. + segmentation_type : obj + The segmentation type. + max_gt_downsample : int + The maximum ground truth downsample. + max_gt_upsample : int + The maximum ground truth upsample. + max_raw_training_downsample : int + The maximum raw training downsample. + max_raw_training_upsample : int + The maximum raw training upsample. + max_raw_validation_downsample : int + The maximum raw validation downsample. + max_raw_validation_upsample : int + The maximum raw validation upsample. + min_training_volume_size : int + The minimum training volume size. + raw_min : int + The minimum raw value. + raw_max : int + The maximum raw value. + classes_separator_caracter : str + The classes separator character. + Returns: + obj : The DataSplitGenerator class. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> DataSplitGenerator(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter) + Notes: + This function is used to initialize the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character. + + """ self.name = name self.datasets = datasets self.input_resolution = input_resolution @@ -210,15 +552,65 @@ def __init__( self.classes_separator_caracter = classes_separator_caracter def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return f"DataSplitGenerator:{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm" @property def class_name(self): + """ + Get the class name. + + Args: + self : obj + The object. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> class_name + Notes: + This function is used to get the class name. + """ return self._class_name # Goal is to force class_name to be set only once, so we have the same classes for all datasets @class_name.setter def class_name(self, class_name): + """ + Set the class name. + + Args: + self : obj + The object. + class_name : obj + The class name. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> class_name + Notes: + This function is used to set the class name. + """ if self._class_name is not None: raise ValueError( f"Class name already set. Current class name is {self.class_name} and new class name is {class_name}" @@ -226,6 +618,25 @@ def class_name(self, class_name): self._class_name = class_name def check_class_name(self, class_name): + """ + Check the class name. + + Args: + self : obj + The object. + class_name : obj + The class name. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> check_class_name(class_name) + Notes: + This function is used to check the class name. + + """ datasets, classes = format_class_name( class_name, self.classes_separator_caracter ) @@ -242,6 +653,22 @@ def check_class_name(self, class_name): return datasets, classes def compute(self): + """ + Compute the data split. + + Args: + self : obj + The object. + Returns: + obj : The data split. + Raises: + NotImplementedError + If the segmentation type is not implemented, a NotImplementedError is raised. + Examples: + >>> compute() + Notes: + This function is used to compute the data split. + """ if self.segmentation_type == SegmentationType.semantic: return self.__generate_semantic_seg_datasplit() else: @@ -250,6 +677,23 @@ def compute(self): ) def __generate_semantic_seg_datasplit(self): + """ + Generate the semantic segmentation data split. + + Args: + self : obj + The object. + Returns: + obj : The data split. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> __generate_semantic_seg_datasplit() + Notes: + This function is used to generate the semantic segmentation data split. + + """ train_dataset_configs = [] validation_dataset_configs = [] for dataset in self.datasets: @@ -281,6 +725,24 @@ def __generate_semantic_seg_datasplit(self): ) def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): + """ + Generate the semantic segmentation dataset crop. + + Args: + self : obj + The object. + dataset : obj + The dataset. + Returns: + obj : The dataset crop. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> __generate_semantic_seg_dataset_crop(dataset) + Notes: + This function is used to generate the semantic segmentation dataset crop. + """ raw_container = dataset.raw_container raw_dataset = dataset.raw_dataset gt_path = dataset.gt_container @@ -374,6 +836,31 @@ def generate_from_csv( name: Optional[str] = None, **kwargs, ): + """ + Generate the data split from the CSV file. + + Args: + csv_path : obj + The CSV file path. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + name : str + The name. + **kwargs : dict + The keyword arguments. + Returns: + obj : The data split. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> generate_from_csv(csv_path, input_resolution, output_resolution, name, **kwargs) + Notes: + This function is used to generate the data split from the CSV file. + + """ if isinstance(csv_path, str): csv_path = Path(csv_path) @@ -390,6 +877,24 @@ def generate_from_csv( def format_class_name(class_name, separator_character="&"): + """ + Format the class name. + + Args: + class_name : obj + The class name. + separator_character : str + The separator character. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is invalid, a ValueError is raised. + Examples: + >>> format_class_name(class_name, separator_character) + Notes: + This function is used to format the class name. + """ if "[" in class_name: if "]" not in class_name: raise ValueError(f"Invalid class name {class_name} missing ']'") diff --git a/dacapo/experiments/datasplits/dummy_datasplit.py b/dacapo/experiments/datasplits/dummy_datasplit.py index 6a5476ef0..05f996dd1 100644 --- a/dacapo/experiments/datasplits/dummy_datasplit.py +++ b/dacapo/experiments/datasplits/dummy_datasplit.py @@ -5,34 +5,40 @@ class DummyDataSplit(DataSplit): - """A class for creating a simple train dataset and no validation dataset. - - It is derived from `DataSplit` class. - - ... - Attributes - ---------- - train : list - The list containing training datasets. In this class, it contains only one dataset for training. - validate : list - The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. - - Methods - ------- - __init__(self, datasplit_config): - The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + """ + A class for creating a simple train dataset and no validation dataset. It is derived from `DataSplit` class. + It is used to split the data into training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train : list + The list containing training datasets. In this class, it contains only one dataset for training. + validate : list + The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. + Methods: + __init__(self, datasplit_config): + The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + Notes: + This class is used to split the data into training and validation datasets. """ train: List[Dataset] validate: List[Dataset] def __init__(self, datasplit_config): - """Constructor method for initializing the instance of `DummyDataSplit` class. It sets up the list of training datasets based on the passed configuration. - - Parameters - ---------- - datasplit_config : DatasplitConfig - The configuration setup for processing the datasets into the training sets. + """ + Constructor method for initializing the instance of `DummyDataSplit` class. It sets up the list of training datasets based on the passed configuration. + + Args: + datasplit_config : obj + The configuration to initialize the DummyDataSplit class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which is logged and handled by training the model without head matching. + Examples: + >>> dummy_datasplit = DummyDataSplit(datasplit_config) + Notes: + This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets. """ super().__init__() diff --git a/dacapo/experiments/datasplits/dummy_datasplit_config.py b/dacapo/experiments/datasplits/dummy_datasplit_config.py index d320df949..fc343909a 100644 --- a/dacapo/experiments/datasplits/dummy_datasplit_config.py +++ b/dacapo/experiments/datasplits/dummy_datasplit_config.py @@ -9,7 +9,8 @@ @attr.s class DummyDataSplitConfig(DataSplitConfig): - """A simple class representing config for Dummy DataSplit. + """ + A simple class representing config for Dummy DataSplit. This class is derived from 'DataSplitConfig' and is initialized with 'DatasetConfig' for training dataset. @@ -17,6 +18,12 @@ class DummyDataSplitConfig(DataSplitConfig): Attributes: datasplit_type: Class of dummy data split functionality. train_config: Config for the training dataset. Defaults to DummyDatasetConfig. + Methods: + verify() + A method for verification. This method always return 'False' plus + a string indicating the condition. + Notes: + This class is used to represent the configuration for Dummy DataSplit. """ @@ -25,10 +32,17 @@ class DummyDataSplitConfig(DataSplitConfig): train_config: DatasetConfig = attr.ib(DummyDatasetConfig(name="dummy_dataset")) def verify(self) -> Tuple[bool, str]: - """A method for verification. This method always return 'False' plus + """ + A method for verification. This method always return 'False' plus a string indicating the condition. Returns: Tuple[bool, str]: A tuple contains a boolean 'False' and a string. + Examples: + >>> dummy_datasplit_config = DummyDataSplitConfig(train_config) + >>> dummy_datasplit_config.verify() + (False, "This is a DummyDataSplit and is never valid") + Notes: + This method is used to verify the configuration of DummyDataSplit. """ return False, "This is a DummyDataSplit and is never valid" diff --git a/dacapo/experiments/datasplits/keys/keys.py b/dacapo/experiments/datasplits/keys/keys.py index 7da64dd78..531e43d49 100644 --- a/dacapo/experiments/datasplits/keys/keys.py +++ b/dacapo/experiments/datasplits/keys/keys.py @@ -2,7 +2,26 @@ class DataKey(Enum): - """Represent a base class for various types of keys in Dacapo library.""" + """ + Represent a base class for various types of keys in Dacapo library. + + Attributes: + RAW: str + The raw data key. + GT: str + The ground truth data key. + MASK: str + The data mask key. + NON_EMPTY: str + The data key for non-empty mask. + SPECIFIED_LOCATIONS: str + The key for specified locations in the graph. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. + """ pass @@ -12,16 +31,20 @@ class ArrayKey(DataKey): """ A unique enumeration representing different types of array keys - Attributes - ---------- - RAW: str - The raw data key. - GT: str - The ground truth data key. - MASK: str - The data mask key. - NON_EMPTY: str - The data key for non-empty mask. + Attributes: + RAW: str + The raw data key. + GT: str + The ground truth data key. + MASK: str + The data mask key. + NON_EMPTY: str + The data key for non-empty mask. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. """ RAW = "raw" @@ -35,10 +58,14 @@ class GraphKey(DataKey): """ A unique enumeration representing different types of graph keys - Attributes - ---------- - SPECIFIED_LOCATIONS: str - The key for specified locations in the graph. + Attributes: + SPECIFIED_LOCATIONS: str + The key for specified locations in the graph. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. """ SPECIFIED_LOCATIONS = "specified_locations" diff --git a/dacapo/experiments/datasplits/train_validate_datasplit.py b/dacapo/experiments/datasplits/train_validate_datasplit.py index 3fdfe6c41..8d456386c 100644 --- a/dacapo/experiments/datasplits/train_validate_datasplit.py +++ b/dacapo/experiments/datasplits/train_validate_datasplit.py @@ -5,10 +5,46 @@ class TrainValidateDataSplit(DataSplit): + """ + A DataSplit that contains a list of training and validation datasets. This + class is used to split the data into training and validation datasets. The + training and validation datasets are used to train and validate the model + respectively. + + Attributes: + train : list + The list of training datasets. + validate : list + The list of validation datasets. + Methods: + __init__(datasplit_config) + Initializes the TrainValidateDataSplit class with specified config to + split the data into training and validation datasets. + Notes: + This class is used to split the data into training and validation datasets. + """ train: List[Dataset] validate: List[Dataset] def __init__(self, datasplit_config): + """ + Initializes the TrainValidateDataSplit class with specified config to + split the data into training and validation datasets. + + Args: + datasplit_config : obj + The configuration to initialize the TrainValidateDataSplit class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> train_validate_datasplit = TrainValidateDataSplit(datasplit_config) + Notes: + This function is called by the TrainValidateDataSplit class to initialize + the TrainValidateDataSplit class with specified config to split the data + into training and validation datasets. + """ super().__init__() self.train = [ diff --git a/dacapo/experiments/datasplits/train_validate_datasplit_config.py b/dacapo/experiments/datasplits/train_validate_datasplit_config.py index 9970250a6..3cb7f9364 100644 --- a/dacapo/experiments/datasplits/train_validate_datasplit_config.py +++ b/dacapo/experiments/datasplits/train_validate_datasplit_config.py @@ -10,7 +10,23 @@ @attr.s class TrainValidateDataSplitConfig(DataSplitConfig): """ - This is the standard Train/Validate DataSplit config. + This is the standard Train/Validate DataSplit config. It contains a list of + training and validation datasets. This class is used to split the data into + training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train_configs : list + The list of training datasets. + validate_configs : list + The list of validation datasets. + Methods: + __init__(datasplit_config) + Initializes the TrainValidateDataSplitConfig class with specified config to + split the data into training and validation datasets. + Notes: + This class is used to split the data into training and validation datasets. + """ datasplit_type = TrainValidateDataSplit diff --git a/dacapo/experiments/starts/cosem_start.py b/dacapo/experiments/starts/cosem_start.py index fb943b45a..0db6d75ef 100644 --- a/dacapo/experiments/starts/cosem_start.py +++ b/dacapo/experiments/starts/cosem_start.py @@ -8,6 +8,31 @@ def get_model_setup(run): + """ + Loads the model setup from the dacapo store for the specified run. The + model setup includes the classes_channels, voxel_size_input and + voxel_size_output. + + Args: + run : str + The run for which the model setup is to be loaded. + Returns: + classes_channels : list + The classes_channels of the model. + voxel_size_input : list + The voxel_size_input of the model. + voxel_size_output : list + The voxel_size_output of the model. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> classes_channels, voxel_size_input, voxel_size_output = get_model_setup(run) + Notes: + This function is called by the CosemStart class to load the model setup + from the dacapo store for the specified run. + """ try: model = cosem.load_model(run) if hasattr(model, "classes_channels"): @@ -31,7 +56,55 @@ def get_model_setup(run): class CosemStart(Start): + """ + A class to represent the starting point for tasks. This class inherits + from the Start class and is used to load the weights of the starter model + used for finetuning. The weights are loaded from the dacapo store for the + specified run and criterion. + + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + name : str + The name of the run and criterion. + channels : list + The classes_channels of the model. + Methods: + __init__(start_config) + Initializes the CosemStart class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + check() + Checks if the checkpoint for the specified run and criterion exists. + initialize_weights(model, new_head=None) + Retrieves the weights from the dacapo store and load them into + the model. + Notes: + This class is used to represent the starting point for tasks. The weights + of the starter model used for finetuning are loaded from the dacapo store. + """ def __init__(self, start_config): + """ + Initializes the CosemStart class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + + Args: + start_config : obj + The configuration to initialize the CosemStart class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> start = CosemStart(start_config) + Notes: + This function is called by the CosemStart class to initialize the + CosemStart class with specified config to run the initialization of + weights for a model associated with a specific criterion. + """ self.run = start_config.run self.criterion = start_config.criterion self.name = f"{self.run}/{self.criterion}" @@ -43,6 +116,19 @@ def __init__(self, start_config): self.channels = channels def check(self): + """ + Checks if the checkpoint for the specified run and criterion exists. + + Raises: + Exception + If the checkpoint does not exist, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> check() + Notes: + This function is called by the CosemStart class to check if the + checkpoint for the specified run and criterion exists. + """ from dacapo.store.create_store import create_weights_store weights_store = create_weights_store() @@ -56,6 +142,29 @@ def check(self): logger.info(f"Checkpoint for {self.name} exists.") def initialize_weights(self, model, new_head=None): + """ + Retrieves the weights from the dacapo store and load them into + the model. + + Args: + model : obj + The model to which the weights are to be loaded. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = initialize_weights(model, new_head) + Notes: + This function is called by the CosemStart class to retrieve the weights + from the dacapo store and load them into the model. + """ self.check() from dacapo.store.create_store import create_weights_store diff --git a/dacapo/experiments/starts/cosem_start_config.py b/dacapo/experiments/starts/cosem_start_config.py index de16477b1..dcd3f150b 100644 --- a/dacapo/experiments/starts/cosem_start_config.py +++ b/dacapo/experiments/starts/cosem_start_config.py @@ -5,9 +5,30 @@ @attr.s class CosemStartConfig(StartConfig): - """Starter for COSEM pretained models. This is a subclass of `StartConfig` and + """ + Starter for COSEM pretained models. This is a subclass of `StartConfig` and should be used to initialize the model with pretrained weights from a previous run. + + The weights are loaded from the dacapo store for the specified run. The + configuration is used to initialize the weights for the model associated with + a specific criterion. + + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + Methods: + __init__(start_config) + Initializes the CosemStartConfig class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + Examples: + >>> start_config = CosemStartConfig(run="run_1", criterion="best") + Notes: + This class is used to represent the configuration for running tasks. + """ start_type = CosemStart diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index a673b7e56..0e9e8f7e4 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -12,6 +12,37 @@ def match_heads(model, head_weights, old_head, new_head): + """ + Matches the head of the model to the new head by copying the weights + of the old head to the new head. The weights of the old head are + copied to the new head by matching the labels of the old head to the + labels of the new head. + + Args: + model : obj + The model to which the weights are to be loaded. + head_weights : dict + The weights of the old head. + old_head : list + The labels of the old head. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights of the old head copied to the new + head. + Raises: + RuntimeError + If the old head is not found in the new head, a RuntimeError + exception is thrown which is logged and handled by loading + only the common layers from weights. + Examples: + >>> model = match_heads(model, head_weights, old_head, new_head) + Notes: + This function is called by the Start class to match the head of + the model to the new head by copying the weights of the old head + to the new head. + """ for label in new_head: if label in old_head: logger.warning(f"matching head for {label}.") @@ -25,6 +56,45 @@ def match_heads(model, head_weights, old_head, new_head): def _set_weights(model, weights, run, criterion, old_head=None, new_head=None): + """ + Loads the weights of the model from the dacapo store into the model. If + the old head and new head are provided, the weights of the old head are + copied to the new head by matching the labels of the old head to the labels + of the new head. If the old head is not found in the new head, a RuntimeError + exception is thrown which is logged and handled by loading only the common + layers from weights. + + Args: + model : obj + The model to which the weights are to be loaded. + weights : obj + The weights of the model retrieved from the dacapo store. + run : str + The specified run to retrieve weights for the model. + criterion : str + The policy that was used to decide when to store the weights. + old_head : list + The labels of the old head. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = _set_weights(model, weights, run, criterion, old_head, new_head) + Notes: + This function is called by the Start class to load the weights of the + model from the dacapo store into the model. If the old head and new head + are provided, the weights of the old head are copied to the new head by + matching the labels of the old head to the labels of the new head. If the + old head is not found in the new head, a RuntimeError exception is thrown + which is logged and handled by loading only the common layers from weights. + """ logger.warning( f"loading weights from run {run}, criterion: {criterion}, old_head {old_head}, new_head: {new_head}" ) @@ -79,12 +149,24 @@ class Start(ABC): This class interfaces with the dacapo store to retrieve and load the weights of the starter model used for finetuning. - Attributes - ---------- - run : str - The specified run to retrieve weights for the model. - criterion : str - The policy that was used to decide when to store the weights. + Attributes: + run : str + The specified run to retrieve weights for the model. + criterion : str + The policy that was used to decide when to store the weights. + channels : int + The number of channels in the input data. + Methods: + __init__(start_config) + Initializes the Start class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + initialize_weights(model, new_head=None) + Retrieves the weights from the dacapo store and load them into + the model. + Notes: + This class is used to retrieve and load the weights of the starter + model used for finetuning from the dacapo store. """ def __init__(self, start_config): @@ -93,11 +175,15 @@ def __init__(self, start_config): initialization of weights for a model associated with a specific criterion. - Parameters - ---------- - start_config : obj - An object containing configuration details for the model - initialization. + Args: + start_config : obj + The configuration to initialize the Start class. + Examples: + >>> start = Start(start_config) + Notes: + This function is called by the Start class to initialize the + Start class with specified config to run the initialization of + weights for a model associated with a specific criterion. """ # Old version return a dict, new version return an object, this line is to support both if isinstance(start_config, dict): @@ -117,16 +203,25 @@ def initialize_weights(self, model, new_head=None): """ Retrieves the weights from the dacapo store and load them into the model. - Parameters - ---------- - model : obj - The model to which the weights are to be loaded. - Raises - ------ - RuntimeError - If weights of a non-existing or mismatched layer are being - loaded, a RuntimeError exception is thrown which is logged - and handled by loading only the common layers from weights. + + Args: + model : obj + The model to which the weights are to be loaded. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = start.initialize_weights(model, new_head) + Notes: + This function is called by the Start class to retrieve the weights + from the dacapo store and load them into the model. """ from dacapo.store.create_store import create_weights_store diff --git a/dacapo/experiments/starts/start_config.py b/dacapo/experiments/starts/start_config.py index 60ae35ff9..0c961f250 100644 --- a/dacapo/experiments/starts/start_config.py +++ b/dacapo/experiments/starts/start_config.py @@ -5,16 +5,22 @@ @attr.s class StartConfig: """ - A class to represent the configuration for running tasks. - - Attributes - ---------- - run : str - The run to be used as a starting point for tasks. - - criterion : str - The criterion to be used for choosing weights from run. + A class to represent the configuration for running tasks. This class + interfaces with the dacapo store to retrieve and load the weights of the + starter model used for finetuning. + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + Methods: + __init__(start_config) + Initializes the StartConfig class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + Notes: + This class is used to represent the configuration for running tasks. """ start_type = Start diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py index a8eb68dce..8252c3db4 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py @@ -6,6 +6,60 @@ @attr.s class BinarySegmentationEvaluationScores(EvaluationScores): + """ + Class representing evaluation scores for binary segmentation tasks. + + The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + dice (float): The Dice coefficient. + jaccard (float): The Jaccard index. + hausdorff (float): The Hausdorff distance. + false_negative_rate (float): The false negative rate. + false_negative_rate_with_tolerance (float): The false negative rate with tolerance. + false_positive_rate (float): The false positive rate. + false_discovery_rate (float): The false discovery rate. + false_positive_rate_with_tolerance (float): The false positive rate with tolerance. + voi (float): The variation of information. + mean_false_distance (float): The mean false distance. + mean_false_negative_distance (float): The mean false negative distance. + mean_false_positive_distance (float): The mean false positive distance. + mean_false_distance_clipped (float): The mean false distance clipped. + mean_false_negative_distance_clipped (float): The mean false negative distance clipped. + mean_false_positive_distance_clipped (float): The mean false positive distance clipped. + precision_with_tolerance (float): The precision with tolerance. + recall_with_tolerance (float): The recall with tolerance. + f1_score_with_tolerance (float): The F1 score with tolerance. + precision (float): The precision. + recall (float): The recall. + f1_score (float): The F1 score. + Methods: + store_best(criterion: str) -> bool: Whether or not to store the best weights/validation blocks for this criterion. + higher_is_better(criterion: str) -> bool: Determines whether a higher value is better for a given criterion. + bounds(criterion: str) -> Tuple[Union[int, float, None], Union[int, float, None]]: Determines the bounds for a given criterion. + Notes: + The evaluation scores are stored as attributes of the class. The class also contains methods to determine whether a higher value is better for a given criterion, whether or not to store the best weights/validation blocks for a given criterion, and the bounds for a given criterion. + """ + dice: float = attr.ib(default=float("nan")) jaccard: float = attr.ib(default=float("nan")) hausdorff: float = attr.ib(default=float("nan")) @@ -54,6 +108,24 @@ class BinarySegmentationEvaluationScores(EvaluationScores): @staticmethod def store_best(criterion: str) -> bool: + """ + Determines whether or not to store the best weights/validation blocks for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if the best weights/validation blocks should be stored, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.store_best("dice") + False + >>> BinarySegmentationEvaluationScores.store_best("f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether or not to store the best weights/validation blocks for a given criterion is determined by the mapping dictionary. + + """ # Whether or not to store the best weights/validation blocks for this # criterion. mapping = { @@ -83,6 +155,23 @@ def store_best(criterion: str) -> bool: @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Determines whether a higher value is better for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if a higher value is better, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.higher_is_better("dice") + True + >>> BinarySegmentationEvaluationScores.higher_is_better("f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether a higher value is better for a given criterion is determined by the mapping dictionary. + """ mapping = { "dice": True, "jaccard": True, @@ -112,6 +201,23 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Determines the bounds for a given criterion. The bounds are used to determine the best value for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + Tuple[Union[int, float, None], Union[int, float, None]]: The lower and upper bounds for the criterion. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.bounds("dice") + (0, 1) + >>> BinarySegmentationEvaluationScores.bounds("hausdorff") + (0, nan) + Notes: + The method returns the lower and upper bounds for the criterion. The bounds are determined by the mapping dictionary. + """ mapping = { "dice": (0, 1), "jaccard": (0, 1), @@ -140,15 +246,53 @@ def bounds( @attr.s class MultiChannelBinarySegmentationEvaluationScores(EvaluationScores): + """ + Class representing evaluation scores for multi-channel binary segmentation tasks. + + Attributes: + channel_scores (List[Tuple[str, BinarySegmentationEvaluationScores]]): The list of channel scores. + Methods: + higher_is_better(criterion: str) -> bool: Determines whether a higher value is better for a given criterion. + store_best(criterion: str) -> bool: Whether or not to store the best weights/validation blocks for this criterion. + bounds(criterion: str) -> Tuple[Union[int, float, None], Union[int, float, None]]: Determines the bounds for a given criterion. + Notes: + The evaluation scores are stored as attributes of the class. The class also contains methods to determine whether a higher value is better for a given criterion, whether or not to store the best weights/validation blocks for a given criterion, and the bounds for a given criterion. + """ + channel_scores: List[Tuple[str, BinarySegmentationEvaluationScores]] = attr.ib() def __attrs_post_init__(self): + """ + Post-initialization method to set attributes for each channel and criterion. + + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> channel_scores = [("channel1", BinarySegmentationEvaluationScores()), ("channel2", BinarySegmentationEvaluationScores())] + >>> MultiChannelBinarySegmentationEvaluationScores(channel_scores) + Notes: + The method sets attributes for each channel and criterion. The attributes are stored as attributes of the class. + """ for channel, scores in self.channel_scores: for criteria in BinarySegmentationEvaluationScores.criteria: setattr(self, f"{channel}__{criteria}", getattr(scores, criteria)) @property def criteria(self): + """ + Returns a list of all criteria for all channels. + + Returns: + List[str]: The list of criteria. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> channel_scores = [("channel1", BinarySegmentationEvaluationScores()), ("channel2", BinarySegmentationEvaluationScores())] + >>> MultiChannelBinarySegmentationEvaluationScores(channel_scores).criteria + Notes: + The method returns a list of all criteria for all channels. The criteria are stored as attributes of the class. + """ + return [ f"{channel}__{criteria}" for channel, _ in self.channel_scores @@ -157,11 +301,45 @@ def criteria(self): @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Determines whether a higher value is better for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if a higher value is better, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.higher_is_better("channel1__dice") + True + >>> MultiChannelBinarySegmentationEvaluationScores.higher_is_better("channel1__f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether a higher value is better for a given criterion is determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.higher_is_better(criterion) @staticmethod def store_best(criterion: str) -> bool: + """ + Determines whether or not to store the best weights/validation blocks for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if the best weights/validation blocks should be stored, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.store_best("channel1__dice") + False + >>> MultiChannelBinarySegmentationEvaluationScores.store_best("channel1__f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether or not to store the best weights/validation blocks for a given criterion is determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.store_best(criterion) @@ -169,5 +347,22 @@ def store_best(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Determines the bounds for a given criterion. The bounds are used to determine the best value for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + Tuple[Union[int, float, None], Union[int, float, None]]: The lower and upper bounds for the criterion. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.bounds("channel1__dice") + (0, 1) + >>> MultiChannelBinarySegmentationEvaluationScores.bounds("channel1__hausdorff") + (0, nan) + Notes: + The method returns the lower and upper bounds for the criterion. The bounds are determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.bounds(criterion) diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py index 542083c4d..c74a7b320 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py @@ -23,12 +23,78 @@ class BinarySegmentationEvaluator(Evaluator): """ - Given a binary segmentation, compute various metrics to determine their similarity. + Given a binary segmentation, compute various metrics to determine their similarity. The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + clip_distance : float + the clip distance + tol_distance : float + the tolerance distance + channels : List[str] + the channels + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_array) + Evaluate the output array against the evaluation array. + score + Return the evaluation scores. + Note: + The BinarySegmentationEvaluator class is used to evaluate the performance of a binary segmentation task. + The class provides methods to evaluate the output array against the evaluation array and return the evaluation scores. + All evaluation scores should inherit from this class. + + Clip distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a false positive. + Tolerance distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a true positive. + Channels are the channels of the binary segmentation. + Criteria are the evaluation criteria. + """ criteria = ["jaccard", "voi"] def __init__(self, clip_distance: float, tol_distance: float, channels: List[str]): + """ + Initialize the binary segmentation evaluator. + + Args: + clip_distance : float + the clip distance + tol_distance : float + the tolerance distance + channels : List[str] + the channels + Raises: + ValueError: if the clip distance is not valid + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + Note: + This function is used to initialize the binary segmentation evaluator. + + Clip distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a false positive. + Tolerance distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a true positive. + Channels are the channels of the binary segmentation. + Criteria are the evaluation criteria. + """ self.clip_distance = clip_distance self.tol_distance = tol_distance self.channels = channels @@ -38,6 +104,28 @@ def __init__(self, clip_distance: float, tol_distance: float, channels: List[str ] def evaluate(self, output_array_identifier, evaluation_array): + """ + Evaluate the output array against the evaluation array. + + Args: + output_array_identifier : str + the identifier of the output array + evaluation_array : ZarrArray + the evaluation array + Returns: + BinarySegmentationEvaluationScores or MultiChannelBinarySegmentationEvaluationScores + the evaluation scores + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> output_array_identifier = "output_array" + >>> evaluation_array = ZarrArray.open_from_array_identifier("evaluation_array") + >>> binary_segmentation_evaluator.evaluate(output_array_identifier, evaluation_array) + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + """ output_array = ZarrArray.open_from_array_identifier(output_array_identifier) evaluation_data = evaluation_array[evaluation_array.roi].squeeze() output_data = output_array[output_array.roi].squeeze() @@ -135,12 +223,50 @@ def evaluate(self, output_array_identifier, evaluation_array): @property def score(self): + """ + Return the evaluation scores. + + Returns: + BinarySegmentationEvaluationScores or MultiChannelBinarySegmentationEvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> binary_segmentation_evaluator.score + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to return the evaluation scores. + """ channel_scores = [] for channel in self.channels: channel_scores.append((channel, BinarySegmentationEvaluationScores())) return MultiChannelBinarySegmentationEvaluationScores(channel_scores) def _evaluate(self, output_data, evaluation_data, voxel_size): + """ + Evaluate the output array against the evaluation array. + + Args: + output_data : np.ndarray + the output data + evaluation_data : np.ndarray + the evaluation data + voxel_size : Tuple[float, float, float] + the voxel size + Returns: + BinarySegmentationEvaluationScores + the evaluation scores + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> output_data = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> evaluation_data = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> voxel_size = (1, 1, 1) + >>> binary_segmentation_evaluator._evaluate(output_data, evaluation_data, voxel_size) + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + """ evaluator = ArrayEvaluator( evaluation_data, output_data, @@ -178,6 +304,89 @@ def _evaluate(self, output_data, evaluation_data, voxel_size): class ArrayEvaluator: + """ + Given a binary segmentation, compute various metrics to determine their similarity. The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + truth_empty : bool + whether the truth binary segmentation is empty + test_empty : bool + whether the test binary segmentation is empty + cremieval : CremiEvaluator + the cremi evaluator + resolution : Tuple[float, float, float] + the resolution + Methods: + dice + Return the Dice coefficient. + jaccard + Return the Jaccard coefficient. + hausdorff + Return the Hausdorff distance. + false_negative_rate + Return the false negative rate. + false_positive_rate + Return the false positive rate. + false_discovery_rate + Return the false discovery rate. + precision + Return the precision. + recall + Return the recall. + f1_score + Return the F1 score. + voi + Return the VOI. + mean_false_distance + Return the mean false distance. + mean_false_negative_distance + Return the mean false negative distance. + mean_false_positive_distance + Return the mean false positive distance. + mean_false_distance_clipped + Return the mean false distance clipped. + mean_false_negative_distance_clipped + Return the mean false negative distance clipped. + mean_false_positive_distance_clipped + Return the mean false positive distance clipped. + false_positive_rate_with_tolerance + Return the false positive rate with tolerance. + false_negative_rate_with_tolerance + Return the false negative rate with tolerance. + precision_with_tolerance + Return the precision with tolerance. + recall_with_tolerance + Return the recall with tolerance. + f1_score_with_tolerance + Return the F1 score with tolerance. + Note: + The ArrayEvaluator class is used to evaluate the performance of a binary segmentation task. + The class provides methods to evaluate the truth binary segmentation against the test binary segmentation. + All evaluation scores should inherit from this class. + """ def __init__( self, truth_binary, @@ -187,6 +396,38 @@ def __init__( metric_params, resolution, ): + """ + Initialize the array evaluator. + + Args: + truth_binary : np.ndarray + the truth binary segmentation + test_binary : np.ndarray + the test binary segmentation + truth_empty : bool + whether the truth binary segmentation is empty + test_empty : bool + whether the test binary segmentation is empty + metric_params : Dict[str, float] + the metric parameters + resolution : Tuple[float, float, float] + the resolution + Returns: + ArrayEvaluator + the array evaluator + Raises: + ValueError: if the truth binary segmentation is not valid + Examples: + >>> truth_binary = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> test_binary = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> truth_empty = False + >>> test_empty = False + >>> metric_params = {"clip_distance": 200, "tol_distance": 40} + >>> resolution = (1, 1, 1) + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + Note: + This function is used to initialize the array evaluator. + """ self.truth = truth_binary.astype(np.uint8) self.test = test_binary.astype(np.uint8) self.truth_empty = truth_empty @@ -202,35 +443,148 @@ def __init__( @lazy_property.LazyProperty def truth_itk(self): + """ + A SimpleITK image of the truth binary segmentation. + + Returns: + sitk.Image + the truth binary segmentation as a SimpleITK image + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.truth_itk + ::value_type *' at 0x7f8b1c0b3f30> > + Note: + This function is used to return the truth binary segmentation as a SimpleITK image. + """ res = sitk.GetImageFromArray(self.truth) res.SetSpacing(self.resolution) return res @lazy_property.LazyProperty def test_itk(self): + """ + A SimpleITK image of the test binary segmentation. + + Args: + test : np.ndarray + the test binary segmentation + resolution : Tuple[float, float, float] + the resolution + Returns: + sitk.Image + the test binary segmentation as a SimpleITK image + Raises: + ValueError: if the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.test_itk + ::value_type *' at 0x7f8b1c0b3f30> > + Note: + This function is used to return the test binary segmentation as a SimpleITK image. + """ res = sitk.GetImageFromArray(self.test) res.SetSpacing(self.resolution) return res @lazy_property.LazyProperty def overlap_measures_filter(self): + """ + A SimpleITK filter to compute overlap measures. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + sitk.LabelOverlapMeasuresImageFilter + the overlap measures filter + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.overlap_measures_filter + > + Note: + This function is used to return the overlap measures filter. + """ overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter() overlap_measures_filter.Execute(self.test_itk, self.truth_itk) return overlap_measures_filter def dice(self): + """ + The Dice coefficient. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the Dice coefficient + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.dice() + 0.0 + Note: + This function is used to return the Dice coefficient. + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetDiceCoefficient() else: return np.nan def jaccard(self): + """ + The Jaccard coefficient. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the Jaccard coefficient + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.jaccard() + 0.0 + Note: + This function is used to return the Jaccard coefficient. + + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetJaccardCoefficient() else: return np.nan def hausdorff(self): + """ + The Hausdorff distance. + + Args: + None + Returns: + float: the Hausdorff distance + Raises: + None + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.hausdorff() + 0.0 + Note: + This function is used to return the Hausdorff distance between the truth binary segmentation and the test binary segmentation. + + If either the truth or test binary segmentation is empty, the function returns 0. + Otherwise, it calculates the Hausdorff distance using the HausdorffDistanceImageFilter from the SimpleITK library. + """ if self.truth_empty and self.test_empty: return 0 elif not self.truth_empty and not self.test_empty: @@ -241,12 +595,47 @@ def hausdorff(self): return np.nan def false_negative_rate(self): + """ + The false negative rate. + + Returns: + float + the false negative rate + Returns: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.false_negative_rate() + 0.0 + Note: + This function is used to return the false negative rate. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.overlap_measures_filter.GetFalseNegativeError() def false_positive_rate(self): + """ + The false positive rate. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the false positive rate + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.false_positive_rate() + 0.0 + Note: + This function is used to return the false positive rate. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -255,12 +644,45 @@ def false_positive_rate(self): ) def false_discovery_rate(self): + """ + Calculate the false discovery rate (FDR) for the binary segmentation evaluation. + + Returns: + float: The false discovery rate. + Raises: + None + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_discovery_rate() + 0.25 + Note: + The false discovery rate is a measure of the proportion of false positives among the predicted positive samples. + It is calculated as the ratio of false positives to the sum of true positives and false positives. + If either the ground truth or the predicted segmentation is empty, the FDR is set to NaN. + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetFalsePositiveError() else: return np.nan def precision(self): + """ + Calculate the precision of the binary segmentation evaluation. + + Returns: + float: The precision value. + Raises: + None. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.precision() + 0.75 + Note: + Precision is a measure of the accuracy of the positive predictions made by the model. + It is calculated as the ratio of true positives to the total number of positive predictions. + If either the ground truth or the predicted values are empty, the precision value will be NaN. + """ + if self.truth_empty or self.test_empty: return np.nan else: @@ -269,6 +691,21 @@ def precision(self): return float(np.float32(tp) / np.float32(pred_pos)) def recall(self): + """ + Calculate the recall metric for binary segmentation evaluation. + + Returns: + float: The recall value. + Raises: + None + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall() + 0.75 + Note: + Recall is a measure of the ability of a binary classifier to identify all positive samples. + It is calculated as the ratio of true positives to the total number of actual positives. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -277,6 +714,23 @@ def recall(self): return float(np.float32(tp) / np.float32(cond_pos)) def f1_score(self): + """ + Calculate the F1 score for binary segmentation evaluation. + + Returns: + float: The F1 score value. + Raises: + None. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.f1_score() + 0.75 + Note: + The F1 score is the harmonic mean of precision and recall. + It is a measure of the balance between precision and recall, providing a single metric to evaluate the model's performance. + + If either the ground truth or the predicted values are empty, the F1 score will be NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -288,6 +742,25 @@ def f1_score(self): return 2 * (rec * prec) / (rec + prec) def voi(self): + """ + Calculate the Variation of Information (VOI) for binary segmentation evaluation. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The VOI value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.voi() + 0.75 + Note: + The VOI is a measure of the similarity between two segmentations. + It combines the split and merge errors into a single measure of segmentation quality. + If either the ground truth or the predicted values are empty, the VOI will be NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -297,66 +770,272 @@ def voi(self): return voi_split + voi_merge def mean_false_distance(self): + """ + Calculate the mean false distance between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance() + 0.25 + Note: + - This method returns np.nan if either the ground truth or the test results are empty. + - The mean false distance is a measure of the average distance between the false positive pixels in the test results and the nearest true positive pixels in the ground truth. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_distance def mean_false_negative_distance(self): + """ + Calculate the mean false negative distance between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false negative distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_negative_distance() + 0.25 + Note: + This method returns np.nan if either the ground truth or the test results are empty. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_negative_distance def mean_false_positive_distance(self): + """ + Calculate the mean false positive distance. + + This method calculates the mean false positive distance between the ground truth and the test results. + If either the ground truth or the test results are empty, the method returns NaN. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false positive distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_positive_distance() + 0.5 + Note: + The mean false positive distance is a measure of the average distance between false positive pixels in the + test results and the corresponding ground truth pixels. It is commonly used to evaluate the performance of + binary segmentation algorithms. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_positive_distance def mean_false_distance_clipped(self): + """ + Calculate the mean false distance (clipped) between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false distance (clipped) value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance_clipped() + 0.123 + Note: + This method returns np.nan if either the ground truth or the test results are empty. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_distance_clipped def mean_false_negative_distance_clipped(self): + """ + Calculate the mean false negative distance, with clipping. + + This method calculates the mean false negative distance between the ground truth and the test results. + The distance is clipped to avoid extreme values. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false negative distance with clipping. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_negative_distance_clipped() + 0.123 + Note: + - The mean false negative distance is a measure of the average distance between the false negative pixels in the ground truth and the test results. + - Clipping the distance helps to avoid extreme values that may skew the overall evaluation. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_negative_distances_clipped def mean_false_positive_distance_clipped(self): + """ + Calculate the mean false positive distance, with clipping. + + This method calculates the mean false positive distance between the ground truth and the test results, + taking into account any clipping that may have been applied. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false positive distance with clipping. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_positive_distance_clipped() + 0.25 + Note: + - The mean false positive distance is a measure of the average distance between false positive pixels + in the test results and the corresponding ground truth pixels. + - If either the ground truth or the test results are empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_positive_distances_clipped def false_positive_rate_with_tolerance(self): + """ + Calculate the false positive rate with tolerance. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The false positive rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_rate_with_tolerance() + 0.25 + Note: + This method calculates the false positive rate with tolerance by comparing the truth and test data. + If either the truth or test data is empty, it returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.false_positive_rate_with_tolerance def false_negative_rate_with_tolerance(self): + """ + Calculate the false negative rate with tolerance. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + The false negative rate with tolerance as a floating-point number. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_rate_with_tolerance() + 0.25 + Note: + This method calculates the false negative rate with tolerance, which is a measure of the proportion of false negatives in a binary segmentation evaluation. If either the ground truth or the test data is empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.false_negative_rate_with_tolerance def precision_with_tolerance(self): + """ + Calculate the precision with tolerance. + + This method calculates the precision with tolerance by comparing the truth and test data. + Precision is the ratio of true positives to the sum of true positives and false positives. + Tolerance is a distance threshold within which two pixels are considered to be a match. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The precision with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.precision_with_tolerance() + 0.75 + Note: + - Precision is a measure of the accuracy of the positive predictions. + - If either the ground truth or the test data is empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.precision_with_tolerance - def recall_with_tolerance(self): + """ + Calculate the recall with tolerance for the binary segmentation evaluator. + + Returns: + float: The recall with tolerance value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance() + 0.75 + Note: + This method calculates the recall with tolerance, which is a measure of how well the binary segmentation evaluator performs. It returns the recall with tolerance value as a float. If either the truth or test data is empty, it returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.recall_with_tolerance def f1_score_with_tolerance(self): + """ + Calculate the F1 score with tolerance. + + This method calculates the F1 score with tolerance between the ground truth and the test results. + If either the ground truth or the test results are empty, the function returns NaN. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The F1 score with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.f1_score_with_tolerance() + 0.85 + Note: + The F1 score is a measure of a test's accuracy. It considers both the precision and recall of the test to compute the score. + The tolerance parameter allows for a certain degree of variation between the ground truth and the test results. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -364,9 +1043,94 @@ def f1_score_with_tolerance(self): class CremiEvaluator: + """ + Evaluate the performance of a binary segmentation task using the CREMI score. + The CREMI score is a measure of the similarity between two binary segmentations. + + Attributes: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + sampling : Tuple[float, float, float] + the sampling resolution + clip_distance : float + the maximum distance to clip + tol_distance : float + the tolerance distance + Methods: + false_positive_distances + Return the false positive distances. + false_positives_with_tolerance + Return the false positives with tolerance. + false_positive_rate_with_tolerance + Return the false positive rate with tolerance. + false_negatives_with_tolerance + Return the false negatives with tolerance. + false_negative_rate_with_tolerance + Return the false negative rate with tolerance. + true_positives_with_tolerance + Return the true positives with tolerance. + precision_with_tolerance + Return the precision with tolerance. + recall_with_tolerance + Return the recall with tolerance. + f1_score_with_tolerance + Return the F1 score with tolerance. + mean_false_positive_distances_clipped + Return the mean false positive distances clipped. + mean_false_negative_distances_clipped + Return the mean false negative distances clipped. + mean_false_positive_distance + Return the mean false positive distance. + false_negative_distances + Return the false negative distances. + mean_false_negative_distance + Return the mean false negative distance. + mean_false_distance + Return the mean false distance. + mean_false_distance_clipped + Return the mean false distance clipped. + Note: + - The CremiEvaluator class is used to evaluate the performance of a binary segmentation task using the CREMI score. + - True and test binary segmentations are compared to calculate various evaluation metrics. + - The class provides methods to evaluate the performance of the binary segmentation task. + - Toleration distance is used to determine the tolerance level for the evaluation. + - Clip distance is used to clip the distance values to avoid extreme values. + - All evaluation scores should inherit from this class. + """ def __init__( self, truth, test, sampling=(1, 1, 1), clip_distance=200, tol_distance=40 ): + """ + Initialize the Cremi evaluator. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + sampling : Tuple[float, float, float] + the sampling resolution + clip_distance : float + the maximum distance to clip + tol_distance : float + the tolerance distance + Returns: + CremiEvaluator + the Cremi evaluator + Raises: + ValueError: if the truth binary segmentation is not valid + Examples: + >>> truth = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> test = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> sampling = (1, 1, 1) + >>> clip_distance = 200 + >>> tol_distance = 40 + >>> cremi_evaluator = CremiEvaluator(truth, test, sampling, clip_distance, tol_distance) + Note: + This function is used to initialize the Cremi evaluator. + """ self.test = test self.truth = truth self.sampling = sampling @@ -375,37 +1139,176 @@ def __init__( @lazy_property.LazyProperty def test_mask(self): + """ + Generate a binary mask for the test data. + + Args: + test : np.ndarray + the test binary segmentation + Returns: + test_mask (ndarray): A binary mask indicating the regions of interest in the test data. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.test = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + >>> evaluator.test_mask() + array([[False, True, False], + [ True, True, True], + [False, True, False]]) + Note: + This method assumes that the background class is represented by the constant `BG`. + """ # todo: more involved masking test_mask = self.test == BG return test_mask @lazy_property.LazyProperty def truth_mask(self): + """ + Returns a binary mask indicating the truth values. + + Args: + truth : np.ndarray + the truth binary segmentation + Returns: + truth_mask (ndarray): A binary mask where True indicates the truth values and False indicates other values. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> mask = evaluator.truth_mask() + >>> print(mask) + [[ True True False] + [False True False] + [ True False False]] + Note: + The truth mask is computed by comparing the truth values with a predefined background value (BG). + """ truth_mask = self.truth == BG return truth_mask @lazy_property.LazyProperty def test_edt(self): + """ + Calculate the Euclidean Distance Transform (EDT) of the test mask. + + Args: + self.test_mask (ndarray): The binary test mask. + self.sampling (float or sequence of floats): The pixel spacing or sampling along each dimension. + Returns: + ndarray: The Euclidean Distance Transform of the test mask. + Examples: + # Example 1: + test_mask = np.array([[0, 0, 1], + [1, 1, 1], + [0, 0, 0]]) + sampling = 1.0 + result = test_edt(test_mask, sampling) + # Output: array([[1. , 1. , 0. ], + # [0. , 0. , 0. ], + # [1. , 1. , 1.41421356]]) + + # Example 2: + test_mask = np.array([[0, 1, 0], + [1, 0, 1], + [0, 1, 0]]) + sampling = 0.5 + result = test_edt(test_mask, sampling) + # Output: array([[0.5 , 0. , 0.5 ], + # [0. , 0.70710678, 0. ], + # [0.5 , 0. , 0.5 ]]) + + Note: + The Euclidean Distance Transform (EDT) calculates the distance from each pixel in the binary mask to the nearest boundary pixel. It is commonly used in image processing and computer vision tasks, such as edge detection and shape analysis. + """ test_edt = scipy.ndimage.distance_transform_edt(self.test_mask, self.sampling) return test_edt @lazy_property.LazyProperty def truth_edt(self): + """ + Calculate the Euclidean Distance Transform (EDT) of the ground truth mask. + + Args: + self.truth_mask (ndarray): The binary ground truth mask. + self.sampling (float or sequence of floats): The pixel spacing or sampling along each dimension. + Returns: + ndarray: The Euclidean Distance Transform of the ground truth mask. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> edt = evaluator.truth_edt() + Note: + The Euclidean Distance Transform (EDT) calculates the distance from each pixel in the binary mask to the nearest boundary pixel. It is commonly used in image processing and computer vision tasks. + """ truth_edt = scipy.ndimage.distance_transform_edt(self.truth_mask, self.sampling) return truth_edt @lazy_property.LazyProperty def false_positive_distances(self): + """ + Calculate the distances of false positive pixels from the ground truth segmentation. + + Args: + self.test_mask (ndarray): The binary test mask. + self.truth_edt (ndarray): The Euclidean Distance Transform of the ground truth segmentation. + Returns: + numpy.ndarray: An array containing the distances of false positive pixels from the ground truth segmentation. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> distances = evaluator.false_positive_distances() + >>> print(distances) + [1.2, 0.8, 2.5, 1.0] + Note: + This method assumes that the ground truth segmentation and the test mask have been initialized. + The ground truth segmentation is stored in the `truth_edt` attribute, and the test mask is obtained by inverting the `test_mask` attribute. + """ test_bin = np.invert(self.test_mask) false_positive_distances = self.truth_edt[test_bin] return false_positive_distances @lazy_property.LazyProperty def false_positives_with_tolerance(self): + """ + Calculate the number of false positives with a given tolerance distance. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth segmentation. + self.tol_distance (float): The tolerance distance. + Returns: + int: The number of false positives with a distance greater than the tolerance distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1, 2, 3] + >>> evaluator.tol_distance = 2 + >>> false_positives = evaluator.false_positives_with_tolerance() + >>> print(false_positives) + 1 + Note: + The `false_positive_distances` attribute should be initialized before calling this method. + + """ return np.sum(self.false_positive_distances > self.tol_distance) @lazy_property.LazyProperty def false_positive_rate_with_tolerance(self): + """ + Calculate the false positive rate with tolerance. + + This method calculates the false positive rate by dividing the number of false positives with tolerance + by the number of condition negatives. + + Args: + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + self.truth_mask (ndarray): The binary ground truth mask. + Returns: + float: The false positive rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positives_with_tolerance = 10 + >>> evaluator.truth_mask = np.array([0, 1, 0, 1, 0]) + >>> evaluator.false_positive_rate_with_tolerance() + 0.5 + Note: + The false positive rate with tolerance is a measure of the proportion of false positive predictions + with respect to the total number of condition negatives. It is commonly used in binary segmentation tasks. + """ condition_negative = np.sum(self.truth_mask) return float( np.float32(self.false_positives_with_tolerance) @@ -414,10 +1317,51 @@ def false_positive_rate_with_tolerance(self): @lazy_property.LazyProperty def false_negatives_with_tolerance(self): + """ + Calculate the number of false negatives with tolerance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + self.tol_distance (float): The tolerance distance. + Returns: + int: The number of false negatives with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3] + >>> evaluator.tol_distance = 2 + >>> false_negatives = evaluator.false_negatives_with_tolerance() + >>> print(false_negatives) + 1 + Note: + False negatives are cases where the model incorrectly predicts the absence of a positive class. + The tolerance distance is used to determine whether a false negative is within an acceptable range. + + """ return np.sum(self.false_negative_distances > self.tol_distance) @lazy_property.LazyProperty def false_negative_rate_with_tolerance(self): + """ + Calculate the false negative rate with tolerance. + + This method calculates the false negative rate by dividing the number of false negatives + with tolerance by the number of condition positives. + + Args: + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + Returns: + float: The false negative rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3] + >>> evaluator.false_negatives_with_tolerance = 2 + >>> evaluator.false_negative_rate_with_tolerance() + 0.6666666666666666 + Note: + The false negative rate with tolerance is a measure of the proportion of condition positives + that are incorrectly classified as negatives, considering a certain tolerance level. + """ condition_positive = len(self.false_negative_distances) return float( np.float32(self.false_negatives_with_tolerance) @@ -426,6 +1370,29 @@ def false_negative_rate_with_tolerance(self): @lazy_property.LazyProperty def true_positives_with_tolerance(self): + """ + Calculate the number of true positives with tolerance. + + Args: + self.test_mask (ndarray): The test binary segmentation mask. + self.truth_mask (ndarray): The ground truth binary segmentation mask. + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + Returns: + int: The number of true positives with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.test_mask = np.array([[0, 1], [1, 0]]) + >>> evaluator.truth_mask = np.array([[0, 1], [1, 0]]) + >>> evaluator.false_negatives_with_tolerance = 1 + >>> evaluator.false_positives_with_tolerance = 1 + >>> true_positives = evaluator.true_positives_with_tolerance() + >>> print(true_positives) + 2 + Note: + True positives are cases where the model correctly predicts the presence of a positive class. + The tolerance distance is used to determine whether a true positive is within an acceptable range. + """ all_pos = np.sum(np.invert(self.test_mask & self.truth_mask)) return ( all_pos @@ -435,6 +1402,31 @@ def true_positives_with_tolerance(self): @lazy_property.LazyProperty def precision_with_tolerance(self): + """ + Calculate the precision with tolerance. + + This method calculates the precision with tolerance by dividing the number of true positives + with tolerance by the sum of true positives with tolerance and false positives with tolerance. + + Args: + self.true_positives_with_tolerance (int): The number of true positives with tolerance. + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + Returns: + float: The precision with tolerance. + Raises: + ZeroDivisionError: If the sum of true positives with tolerance and false positives with tolerance is zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.true_positives_with_tolerance = 10 + >>> evaluator.false_positives_with_tolerance = 5 + >>> evaluator.precision_with_tolerance() + 0.6666666666666666 + Note: + The precision with tolerance is a measure of the proportion of true positives with tolerance + out of the total number of predicted positives with tolerance. + It indicates how well the binary segmentation evaluator performs in terms of correctly identifying positive samples. + If the sum of true positives with tolerance and false positives with tolerance is zero, the precision with tolerance is undefined and a ZeroDivisionError is raised. + """ return float( np.float32(self.true_positives_with_tolerance) / np.float32( @@ -444,6 +1436,23 @@ def precision_with_tolerance(self): @lazy_property.LazyProperty def recall_with_tolerance(self): + """ + A measure of the ability of a binary classifier to identify all positive samples. + + Args: + self.true_positives_with_tolerance (int): The number of true positives with tolerance. + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + Returns: + float: The recall with tolerance value. + Raises: + ZeroDivisionError: If the sum of true positives with tolerance and false negatives with tolerance is zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance() + 0.75 + Note: + This method calculates the recall with tolerance, which is a measure of how well the binary segmentation evaluator performs. It returns the recall with tolerance value as a float. If either the truth or test data is empty, it returns NaN. + """ return float( np.float32(self.true_positives_with_tolerance) / np.float32( @@ -453,6 +1462,28 @@ def recall_with_tolerance(self): @lazy_property.LazyProperty def f1_score_with_tolerance(self): + """ + Calculate the F1 score with tolerance. + + Args: + self.recall_with_tolerance (float): The recall with tolerance value. + self.precision_with_tolerance (float): The precision with tolerance value. + Returns: + float: The F1 score with tolerance. + Raises: + ZeroDivisionError: If both the recall with tolerance and precision with tolerance are zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance = 0.8 + >>> evaluator.precision_with_tolerance = 0.9 + >>> evaluator.f1_score_with_tolerance() + 0.8571428571428571 + Note: + The F1 score is a measure of a test's accuracy. It considers both the precision and recall of the test to compute the score. + The F1 score with tolerance is calculated using the formula: + F1 = 2 * (recall_with_tolerance * precision_with_tolerance) / (recall_with_tolerance + precision_with_tolerance) + If both recall_with_tolerance and precision_with_tolerance are 0, the F1 score with tolerance will be NaN. + """ if self.recall_with_tolerance == 0 and self.precision_with_tolerance == 0: return np.nan else: @@ -464,6 +1495,26 @@ def f1_score_with_tolerance(self): @lazy_property.LazyProperty def mean_false_positive_distances_clipped(self): + """ + Calculate the mean of the false positive distances, clipped to a maximum distance. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth segmentation. + self.clip_distance (float): The maximum distance to clip. + Returns: + float: The mean of the false positive distances, clipped to a maximum distance. + Raises: + ValueError: If the clip distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1, 2, 3, 4, 5] + >>> evaluator.clip_distance = 3 + >>> evaluator.mean_false_positive_distances_clipped() + 2.5 + Note: + + This method calculates the mean of the false positive distances, where the distances are clipped to a maximum distance. The `false_positive_distances` attribute should be set before calling this method. The `clip_distance` attribute determines the maximum distance to which the distances are clipped. + """ mean_false_positive_distance_clipped = np.mean( np.clip(self.false_positive_distances, None, self.clip_distance) ) @@ -471,6 +1522,25 @@ def mean_false_positive_distances_clipped(self): @lazy_property.LazyProperty def mean_false_negative_distances_clipped(self): + """ + Calculate the mean of the false negative distances, clipped to a maximum distance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + self.clip_distance (float): The maximum distance to clip. + Returns: + float: The mean of the false negative distances, clipped to a maximum distance. + Raises: + ValueError: If the clip distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3, 4, 5] + >>> evaluator.clip_distance = 3 + >>> evaluator.mean_false_negative_distances_clipped() + 2.5 + Note: + This method calculates the mean of the false negative distances, where the distances are clipped to a maximum distance. The `false_negative_distances` attribute should be set before calling this method. The `clip_distance` attribute determines the maximum distance to which the distances are clipped. + """ mean_false_negative_distance_clipped = np.mean( np.clip(self.false_negative_distances, None, self.clip_distance) ) @@ -478,22 +1548,98 @@ def mean_false_negative_distances_clipped(self): @lazy_property.LazyProperty def mean_false_positive_distance(self): + """ + Calculate the mean false positive distance. + + This method calculates the mean distance between the false positive points and the ground truth points. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth mask. + Returns: + float: The mean false positive distance. + Raises: + ValueError: If the false positive distances are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1.2, 3.4, 2.1] + >>> evaluator.mean_false_positive_distance() + 2.2333333333333334 + Note: + The false positive distances should be set before calling this method using the `false_positive_distances` attribute. + """ mean_false_positive_distance = np.mean(self.false_positive_distances) return mean_false_positive_distance @lazy_property.LazyProperty def false_negative_distances(self): + """ + Calculate the distances of false negative pixels from the ground truth mask. + + Args: + self.truth_mask (ndarray): The binary ground truth mask. + Returns: + numpy.ndarray: An array containing the distances of false negative pixels. + Raises: + ValueError: If the truth mask is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> distances = evaluator.false_negative_distances() + >>> print(distances) + [0.5, 1.0, 1.5, 2.0] + Note: + This method assumes that the ground truth mask and the test mask have already been set. + """ truth_bin = np.invert(self.truth_mask) false_negative_distances = self.test_edt[truth_bin] return false_negative_distances @lazy_property.LazyProperty def mean_false_negative_distance(self): + """ + Calculate the mean false negative distance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth mask. + Returns: + float: The mean false negative distance. + Raises: + ValueError: If the false negative distances are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1.2, 3.4, 2.1] + >>> evaluator.mean_false_negative_distance() + 2.2333333333333334 + Note: + The mean false negative distance is calculated as the average of all false negative distances. + + """ mean_false_negative_distance = np.mean(self.false_negative_distances) return mean_false_negative_distance @lazy_property.LazyProperty def mean_false_distance(self): + """ + Calculate the mean false distance. + + This method calculates the mean false distance by taking the average of the mean false positive distance + and the mean false negative distance. + + Args: + self.mean_false_positive_distance (float): The mean false positive distance. + self.mean_false_negative_distance (float): The mean false negative distance. + Returns: + float: The calculated mean false distance. + Raises: + ValueError: If the mean false positive distance or the mean false negative distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance() + 5.0 + Note: + The mean false distance is a metric used to evaluate the performance of a binary segmentation model. + It provides a measure of the average distance between false positive and false negative predictions. + + """ mean_false_distance = 0.5 * ( self.mean_false_positive_distance + self.mean_false_negative_distance ) @@ -501,6 +1647,28 @@ def mean_false_distance(self): @lazy_property.LazyProperty def mean_false_distance_clipped(self): + """ + Calculates the mean false distance clipped. + + This method calculates the mean false distance clipped by taking the average of the mean false positive distances + clipped and the mean false negative distances clipped. + + Args: + self.mean_false_positive_distances_clipped (float): The mean false positive distances clipped. + self.mean_false_negative_distances_clipped (float): The mean false negative distances clipped. + Returns: + float: The calculated mean false distance clipped. + Raises: + ValueError: If the mean false positive distances clipped or the mean false negative distances clipped are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance_clipped() + 2.5 + Note: + The mean false distance clipped is calculated as 0.5 * (mean_false_positive_distances_clipped + + mean_false_negative_distances_clipped). + + """ mean_false_distance_clipped = 0.5 * ( self.mean_false_positive_distances_clipped + self.mean_false_negative_distances_clipped diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py index eb7879cbc..9891d5e8f 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py @@ -6,6 +6,24 @@ @attr.s class DummyEvaluationScores(EvaluationScores): + """ + The evaluation scores for the dummy task. The scores include the frizz level and blipp score. A higher frizz level indicates more frizz, while a higher blipp score indicates better performance. + + Attributes: + frizz_level : float + the frizz level + blipp_score : float + the blipp score + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The DummyEvaluationScores class is used to store the evaluation scores for the dummy task. The class also provides methods to determine whether higher is better for a given criterion, the bounds for a given criterion, and whether to store the best score for a given criterion. + """ criteria = ["frizz_level", "blipp_score"] frizz_level: float = attr.ib(default=float("nan")) @@ -13,6 +31,23 @@ class DummyEvaluationScores(EvaluationScores): @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Return whether higher is better for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.higher_is_better("frizz_level") + True + Note: + This function is used to determine whether higher is better for the given criterion. + """ mapping = { "frizz_level": True, "blipp_score": False, @@ -23,6 +58,23 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Return the bounds for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.bounds("frizz_level") + (0.0, 1.0) + Note: + This function is used to return the bounds for the given criterion. + """ mapping = { "frizz_level": (0.0, 1.0), "blipp_score": (0.0, 1.0), @@ -31,4 +83,21 @@ def bounds( @staticmethod def store_best(criterion: str) -> bool: + """ + Return whether to store the best score for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.store_best("frizz_level") + True + Note: + This function is used to determine whether to store the best score for the given criterion. + """ return True diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py index f9a4dc1ea..b6aad561c 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py @@ -5,6 +5,20 @@ class DummyEvaluator(Evaluator): + """ + A class representing a dummy evaluator. This evaluator is used for testing purposes. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_dataset) + Evaluate the output array against the evaluation dataset. + score + Return the evaluation scores. + Note: + The DummyEvaluator class is used to evaluate the performance of a dummy task. + """ criteria = ["frizz_level", "blipp_score"] def evaluate(self, output_array_identifier, evaluation_dataset): @@ -14,9 +28,18 @@ def evaluate(self, output_array_identifier, evaluation_dataset): Args: output_array_identifier : The output array to be evaluated. evaluation_dataset : The dataset to be used for evaluation. - Returns: DummyEvaluationScore: An object of DummyEvaluationScores class, with the evaluation scores. + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> dummy_evaluator = DummyEvaluator() + >>> output_array_identifier = "output_array" + >>> evaluation_dataset = "evaluation_dataset" + >>> dummy_evaluator.evaluate(output_array_identifier, evaluation_dataset) + DummyEvaluationScores(frizz_level=0.0, blipp_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation dataset. """ return DummyEvaluationScores( frizz_level=random.random(), blipp_score=random.random() @@ -24,4 +47,16 @@ def evaluate(self, output_array_identifier, evaluation_dataset): @property def score(self) -> DummyEvaluationScores: + """ + Return the evaluation scores. + + Returns: + DummyEvaluationScores: An object of DummyEvaluationScores class, with the evaluation scores. + Examples: + >>> dummy_evaluator = DummyEvaluator() + >>> dummy_evaluator.score + DummyEvaluationScores(frizz_level=0.0, blipp_score=0.0) + Note: + This function is used to return the evaluation scores. + """ return DummyEvaluationScores() diff --git a/dacapo/experiments/tasks/evaluators/evaluation_scores.py b/dacapo/experiments/tasks/evaluators/evaluation_scores.py index 3733b9133..ebf8924c5 100644 --- a/dacapo/experiments/tasks/evaluators/evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/evaluation_scores.py @@ -6,11 +6,44 @@ @attr.s class EvaluationScores: - """Base class for evaluation scores.""" + """ + Base class for evaluation scores. This class is used to store the evaluation scores for a task. + The scores include the evaluation criteria. The class also provides methods to determine whether higher is better for a given criterion, + the bounds for a given criterion, and whether to store the best score for a given criterion. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The EvaluationScores class is used to store the evaluation scores for a task. All evaluation scores should inherit from this class. + + """ @property @abstractmethod def criteria(self) -> List[str]: + """ + The evaluation criteria. + + Returns: + List[str] + the evaluation criteria + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> evaluation_scores.criteria + ["criterion1", "criterion2"] + Note: + This function is used to return the evaluation criteria. + """ pass @staticmethod @@ -18,6 +51,23 @@ def criteria(self) -> List[str]: def higher_is_better(criterion: str) -> bool: """ Wether or not higher is better for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.higher_is_better(criterion) + True + Note: + This function is used to determine whether higher is better for a given criterion. + """ pass @@ -27,7 +77,24 @@ def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: """ - The bounds for this criterion + The bounds for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.bounds(criterion) + (0, 1) + Note: + This function is used to return the bounds for the given criterion. + """ pass @@ -37,5 +104,21 @@ def store_best(criterion: str) -> bool: """ Whether or not to save the best validation block and model weights for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.store_best(criterion) + True + Note: + This function is used to return whether to store the best score for the given criterion. """ pass diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 83e4763b3..764fa93c4 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -21,10 +21,38 @@ class Evaluator(ABC): - """Base class of all evaluators. + """ + Base class of all evaluators: An abstract class representing an evaluator that compares and evaluates the output array against the evaluation array. An evaluator takes a post-processor's output and compares it against - ground-truth. + ground-truth. It then returns a set of scores that can be used to + determine the quality of the post-processor's output. + + Attributes: + best_scores : Dict[OutputIdentifier, BestScore] + the best scores for each dataset/post-processing parameter/criterion combination + Methods: + evaluate(output_array_identifier, evaluation_array) + Compare and evaluate the output array against the evaluation array. + is_best(dataset, parameter, criterion, score) + Check if the provided score is the best for this dataset/parameter/criterion combo. + get_overall_best(dataset, criterion) + Return the best score for the given dataset and criterion. + get_overall_best_parameters(dataset, criterion) + Return the best parameters for the given dataset and criterion. + compare(score_1, score_2, criterion) + Compare two scores for the given criterion. + set_best(validation_scores) + Find the best iteration for each dataset/post_processing_parameter/criterion. + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The Evaluator class is used to compare and evaluate the output array against the evaluation array. + """ @abstractmethod @@ -34,17 +62,24 @@ def evaluate( """ Compares and evaluates the output array against the evaluation array. - Parameters - ---------- - output_array_identifier : Array - The output data array to evaluate - evaluation_array : Array - The evaluation data array to compare with the output - - Returns - ------- - EvaluationScores - The detailed evaluation scores after the comparison. + Args: + output_array_identifier : LocalArrayIdentifier + The identifier of the output array. + evaluation_array : Array + The evaluation array. + Returns: + EvaluationScores + The evaluation scores. + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> output_array_identifier = LocalArrayIdentifier("output_array") + >>> evaluation_array = Array() + >>> evaluator.evaluate(output_array_identifier, evaluation_array) + EvaluationScores() + Note: + This function is used to compare and evaluate the output array against the evaluation array. """ pass @@ -52,6 +87,21 @@ def evaluate( def best_scores( self, ) -> Dict[OutputIdentifier, BestScore]: + """ + The best scores for each dataset/post-processing parameter/criterion combination. + + Returns: + Dict[OutputIdentifier, BestScore] + the best scores for each dataset/post-processing parameter/criterion combination + Raises: + AttributeError: if the best scores are not set + Examples: + >>> evaluator = Evaluator() + >>> evaluator.best_scores + {} + Note: + This function is used to return the best scores for each dataset/post-processing parameter/criterion combination. + """ if not hasattr(self, "_best_scores"): self._best_scores: Dict[OutputIdentifier, BestScore] = {} return self._best_scores @@ -64,7 +114,32 @@ def is_best( score: "EvaluationScores", ) -> bool: """ - Check if the provided score is the best for this dataset/parameter/criterion combo + Check if the provided score is the best for this dataset/parameter/criterion combo. + + Args: + dataset : Dataset + the dataset + parameter : PostProcessorParameters + the post-processor parameters + criterion : str + the criterion + score : EvaluationScores + the evaluation scores + Returns: + bool + whether the provided score is the best for this dataset/parameter/criterion combo + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> parameter = PostProcessorParameters() + >>> criterion = "criterion" + >>> score = EvaluationScores() + >>> evaluator.is_best(dataset, parameter, criterion, score) + False + Note: + This function is used to check if the provided score is the best for this dataset/parameter/criterion combo. """ if not self.store_best(criterion) or math.isnan(getattr(score, criterion)): return False @@ -78,6 +153,28 @@ def is_best( ) def get_overall_best(self, dataset: "Dataset", criterion: str): + """ + Return the best score for the given dataset and criterion. + + Args: + dataset : Dataset + the dataset + criterion : str + the criterion + Returns: + Optional[float] + the best score for the given dataset and criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> criterion = "criterion" + >>> evaluator.get_overall_best(dataset, criterion) + None + Note: + This function is used to return the best score for the given dataset and criterion. + """ overall_best = None if self.best_scores: for _, parameter, _ in self.best_scores.keys(): @@ -99,6 +196,28 @@ def get_overall_best(self, dataset: "Dataset", criterion: str): return overall_best def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): + """ + Return the best parameters for the given dataset and criterion. + + Args: + dataset : Dataset + the dataset + criterion : str + the criterion + Returns: + Optional[PostProcessorParameters] + the best parameters for the given dataset and criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> criterion = "criterion" + >>> evaluator.get_overall_best_parameters(dataset, criterion) + None + Note: + This function is used to return the best parameters for the given dataset and criterion. + """ overall_best = None overall_best_parameters = None if self.best_scores: @@ -121,6 +240,31 @@ def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): return overall_best_parameters def compare(self, score_1, score_2, criterion): + """ + Compare two scores for the given criterion. + + Args: + score_1 : float + the first score + score_2 : float + the second score + criterion : str + the criterion + Returns: + bool + whether the first score is better than the second score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> score_1 = 0.0 + >>> score_2 = 0.0 + >>> criterion = "criterion" + >>> evaluator.compare(score_1, score_2, criterion) + False + Note: + This function is used to compare two scores for the given criterion. + """ if self.higher_is_better(criterion): return score_1 > score_2 else: @@ -128,7 +272,21 @@ def compare(self, score_1, score_2, criterion): def set_best(self, validation_scores: "ValidationScores") -> None: """ - Find the best iteration for each dataset/post_processing_parameter/criterion + Find the best iteration for each dataset/post_processing_parameter/criterion. + + Args: + validation_scores : ValidationScores + the validation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> validation_scores = ValidationScores() + >>> evaluator.set_best(validation_scores) + None + Note: + This function is used to find the best iteration for each dataset/post_processing_parameter/criterion. + Typically, this function is called after the validation scores have been computed. """ scores = validation_scores.to_xarray() @@ -185,12 +343,40 @@ def criteria(self) -> List[str]: criteria might be "precision", "recall", and "jaccard". It is unlikely that the best iteration/post processing parameters will be the same for all 3 of these criteria + + Returns: + List[str] + the evaluation criteria + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> evaluator.criteria + [] + Note: + This function is used to return the evaluation criteria. """ pass def higher_is_better(self, criterion: str) -> bool: """ Wether or not higher is better for this criterion. + + Args: + criterion : str + the criterion + Returns: + bool + whether higher is better for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.higher_is_better(criterion) + False + Note: + This function is used to determine whether higher is better for the given criterion. """ return self.score.higher_is_better(criterion) @@ -199,16 +385,63 @@ def bounds( ) -> Tuple[Union[int, float, None], Union[int, float, None]]: """ The bounds for this criterion + + Args: + criterion : str + the criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.bounds(criterion) + (0, 1) + Note: + This function is used to return the bounds for the given criterion. """ return self.score.bounds(criterion) def store_best(self, criterion: str) -> bool: """ The bounds for this criterion + + Args: + criterion : str + the criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.store_best(criterion) + False + Note: + This function is used to return whether to store the best score for the given criterion. """ return self.score.store_best(criterion) @property @abstractmethod def score(self) -> "EvaluationScores": + """ + The evaluation scores. + + Returns: + EvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> evaluator.score + EvaluationScores() + Note: + This function is used to return the evaluation scores. + """ pass diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py index 4e4df9cca..a90945256 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py @@ -6,6 +6,26 @@ @attr.s class InstanceEvaluationScores(EvaluationScores): + """ + The evaluation scores for the instance segmentation task. The scores include the variation of information (VOI) split, VOI merge, and VOI. + + Attributes: + voi_split : float + the variation of information (VOI) split + voi_merge : float + the variation of information (VOI) merge + voi : float + the variation of information (VOI) + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The InstanceEvaluationScores class is used to store the evaluation scores for the instance segmentation task. + """ criteria = ["voi_split", "voi_merge", "voi"] voi_split: float = attr.ib(default=float("nan")) @@ -13,10 +33,42 @@ class InstanceEvaluationScores(EvaluationScores): @property def voi(self): + """ + Return the average of the VOI split and VOI merge. + + Returns: + float + the average of the VOI split and VOI merge + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> instance_evaluation_scores = InstanceEvaluationScores(voi_split=0.1, voi_merge=0.2) + >>> instance_evaluation_scores.voi + 0.15 + Note: + This function is used to calculate the average of the VOI split and VOI merge. + """ return (self.voi_split + self.voi_merge) / 2 @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Return whether higher is better for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.higher_is_better("voi_split") + False + Note: + This function is used to determine whether higher is better for the given criterion. + """ mapping = { "voi_split": False, "voi_merge": False, @@ -28,6 +80,24 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Return the bounds for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.bounds("voi_split") + (0, 1) + Note: + This function is used to return the bounds for the given criterion. + + """ mapping = { "voi_split": (0, 1), "voi_merge": (0, 1), @@ -37,4 +107,21 @@ def bounds( @staticmethod def store_best(criterion: str) -> bool: + """ + Return whether to store the best score for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.store_best("voi_split") + True + Note: + This function is used to determine whether to store the best score for the given criterion. + """ return True diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluator.py b/dacapo/experiments/tasks/evaluators/instance_evaluator.py index 30707b369..f16b25971 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluator.py @@ -14,30 +14,34 @@ def relabel(array, return_backwards_map=False, inplace=False): - """Relabel array, such that IDs are consecutive. Excludes 0. + """ + Relabel array, such that IDs are consecutive. Excludes 0. Args: - array (ndarray): - The array to relabel. - return_backwards_map (``bool``, optional): - If ``True``, return an ndarray that maps new labels (indices in the array) to old labels. - inplace (``bool``, optional): - Perform the replacement in-place on ``array``. - Returns: - A tuple ``(relabelled, n)``, where ``relabelled`` is the relabelled array and ``n`` the number of unique labels found. - If ``return_backwards_map`` is ``True``, returns ``(relabelled, n, backwards_map)``. + Raises: + ValueError: + If ``array`` is not of type ``np.ndarray``. + Examples: + >>> array = np.array([[1, 2, 0], [0, 2, 1]]) + >>> relabel(array) + (array([[1, 2, 0], [0, 2, 1]]), 2) + >>> relabel(array, return_backwards_map=True) + (array([[1, 2, 0], [0, 2, 1]]), 2, [0, 1, 2]) + Note: + This function is used to relabel an array, such that IDs are consecutive. Excludes 0. + """ if array.size == 0: @@ -71,9 +75,47 @@ def relabel(array, return_backwards_map=False, inplace=False): class InstanceEvaluator(Evaluator): + """ + A class representing an evaluator for instance segmentation tasks. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_array) + Evaluate the output array against the evaluation array. + score + Return the evaluation scores. + Note: + The InstanceEvaluator class is used to evaluate the performance of an instance segmentation task. + + """ criteria: List[str] = ["voi_merge", "voi_split", "voi"] def evaluate(self, output_array_identifier, evaluation_array): + """ + Evaluate the output array against the evaluation array. + + Args: + output_array_identifier : str + the identifier of the output array + evaluation_array : ZarrArray + the evaluation array + Returns: + InstanceEvaluationScores + the evaluation scores + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> instance_evaluator = InstanceEvaluator() + >>> output_array_identifier = "output_array" + >>> evaluation_array = ZarrArray.open_from_array_identifier("evaluation_array") + >>> instance_evaluator.evaluate(output_array_identifier, evaluation_array) + InstanceEvaluationScores(voi_merge=0.0, voi_split=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + + """ output_array = ZarrArray.open_from_array_identifier(output_array_identifier) evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) output_data = output_array[output_array.roi].astype(np.uint64) @@ -86,9 +128,47 @@ def evaluate(self, output_array_identifier, evaluation_array): @property def score(self) -> InstanceEvaluationScores: + """ + Return the evaluation scores. + + Returns: + InstanceEvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> instance_evaluator = InstanceEvaluator() + >>> instance_evaluator.score + InstanceEvaluationScores(voi_merge=0.0, voi_split=0.0) + Note: + This function is used to return the evaluation scores. + + """ return InstanceEvaluationScores() def voi(truth, test): + """ + Calculate the variation of information (VOI) between two segmentations. + + Args: + truth : ndarray + the ground truth segmentation + test : ndarray + the test segmentation + Returns: + dict + the variation of information (VOI) scores + Raises: + ValueError: if the truth and test arrays are not of type np.ndarray + Examples: + >>> truth = np.array([[1, 1, 0], [0, 2, 2]]) + >>> test = np.array([[1, 1, 0], [0, 2, 2]]) + >>> voi(truth, test) + {'voi_split': 0.0, 'voi_merge': 0.0} + Note: + This function is used to calculate the variation of information (VOI) between two segmentations. + + """ voi_split, voi_merge = _voi(test + 1, truth + 1, ignore_groundtruth=[]) return {"voi_split": voi_split, "voi_merge": voi_merge} diff --git a/dacapo/experiments/tasks/losses/affinities_loss.py b/dacapo/experiments/tasks/losses/affinities_loss.py index 74fc7fe67..2c8f42943 100644 --- a/dacapo/experiments/tasks/losses/affinities_loss.py +++ b/dacapo/experiments/tasks/losses/affinities_loss.py @@ -3,11 +3,70 @@ class AffinitiesLoss(Loss): + """ + A class representing a loss function that calculates the loss between affinities and local shape descriptors (LSDs). + + Attributes: + num_affinities : int + the number of affinities + lsds_to_affs_weight_ratio : float + the ratio of the weight of the loss between affinities and LSDs + Methods: + compute(prediction, target, weight=None) + Calculate the total loss between prediction and target. + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float): + """ + Initialize the AffinitiesLoss class with the number of affinities and the ratio of the weight of the loss between affinities and LSDs. + + Args: + num_affinities : int + the number of affinities + lsds_to_affs_weight_ratio : float + the ratio of the weight of the loss between affinities and LSDs + Examples: + >>> affinities_loss = AffinitiesLoss(3, 0.5) + >>> affinities_loss.num_affinities + 3 + >>> affinities_loss.lsds_to_affs_weight_ratio + 0.5 + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ self.num_affinities = num_affinities self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio def compute(self, prediction, target, weight): + """ + Method to calculate the total loss between affinities and LSDs. + + Args: + prediction : torch.Tensor + the model's prediction + target : torch.Tensor + the target values + weight : torch.Tensor + the weight to apply to the loss + Returns: + torch.Tensor + the total loss between affinities and LSDs + Raises: + ValueError: if the number of affinities in the prediction and target does not match + Examples: + >>> affinities_loss = AffinitiesLoss(3, 0.5) + >>> prediction = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + >>> target = torch.tensor([[9, 10, 11, 12], [13, 14, 15, 16]]) + >>> weight = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]) + >>> affinities_loss.compute(prediction, target, weight) + tensor(0.5) + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ affs, affs_target, affs_weight = ( prediction[:, 0 : self.num_affinities, ...], target[:, 0 : self.num_affinities, ...], diff --git a/dacapo/experiments/tasks/losses/dummy_loss.py b/dacapo/experiments/tasks/losses/dummy_loss.py index f68206d01..f2e077279 100644 --- a/dacapo/experiments/tasks/losses/dummy_loss.py +++ b/dacapo/experiments/tasks/losses/dummy_loss.py @@ -7,29 +7,41 @@ class DummyLoss(Loss): Inherits the Loss class. - Methods - ------- - compute(prediction, target, weight=None) - Calculate the total loss between prediction and target. + Attributes: + name : str + name of the loss function + Methods: + compute(prediction, target, weight=None) + Calculate the total loss between prediction and target. + Note: + The dummy loss is used to test the training loop and the loss calculation. It is not a real loss function. + It is used to test the training loop and the loss calculation. + """ def compute(self, prediction, target, weight=None): """ Method to calculate the total dummy loss. - Parameters - ---------- - prediction : float or int - predicted output - target : float or int - true output - weight : float or int, optional - weight parameter for the loss, by default None - - Returns - ------- - float or int - Total loss calculated as the sum of absolute differences between prediction and target. + Args: + prediction : torch.Tensor + the model's prediction + target : torch.Tensor + the target values + weight : torch.Tensor + the weight to apply to the loss + Returns: + torch.Tensor + the total loss between prediction and target + Examples: + >>> dummy_loss = DummyLoss() + >>> prediction = torch.tensor([1, 2, 3]) + >>> target = torch.tensor([4, 5, 6]) + >>> dummy_loss.compute(prediction, target) + tensor(9) + Note: + The dummy loss is used to test the training loop and the loss calculation. It is not a real loss function. + It is used to test the training loop and the loss calculation. """ return abs(prediction - target).sum() diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 784176bd0..0dc46f99f 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -7,7 +7,51 @@ # The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. # Model should predict twice the number of channels as the target. class HotDistanceLoss(Loss): + """ + A class used to represent the Hot Distance Loss function. This class inherits from the Loss class. The Hot Distance Loss + function is used for predicting hot and distance maps at the same time. The first half of the channels are the hot maps, + the second half are the distance maps. The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance + maps. The model should predict twice the number of channels as the target. + + Attributes: + hot_loss: The Binary Cross Entropy Loss function. + distance_loss: The Mean Square Error Loss function. + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the Hot Distance Loss for the provided prediction and target, with respect to the weight. + split(x) -> Tuple[torch.Tensor, torch.Tensor] + Function to split the input tensor into two tensors. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. + + """ def compute(self, prediction, target, weight): + """ + Function to compute the Hot Distance Loss for the provided prediction and target, with respect to the weight. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed Hot Distance Loss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed Hot Distance Loss tensor. + """ target_hot, target_distance = self.split(target) prediction_hot, prediction_distance = self.split(prediction) weight_hot, weight_distance = self.split(weight) @@ -16,14 +60,83 @@ def compute(self, prediction, target, weight): ) + self.distance_loss(prediction_distance, target_distance, weight_distance) def hot_loss(self, prediction, target, weight): + """ + The Binary Cross Entropy Loss function. This function computes the BCELoss for the hot maps. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed BCELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.hot_loss(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed BCELoss tensor. + """ loss = torch.nn.BCEWithLogitsLoss(reduction="none") return torch.mean(loss(prediction, target) * weight) def distance_loss(self, prediction, target, weight): + """ + The Mean Square Error Loss function. This function computes the MSELoss for the distance maps. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed MSELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.distance_loss(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed MSELoss tensor. + """ loss = torch.nn.MSELoss() return loss(prediction * weight, target * weight) def split(self, x): + """ + Function to split the input tensor into two tensors. + + Args: + x : torch.Tensor + The input tensor. + Returns: + Tuple[torch.Tensor, torch.Tensor] + The two split tensors. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> x = torch.tensor([1.0, 2.0, 3.0]) + >>> loss.split(x) + (tensor([1.0]), tensor([2.0])) + Note: + This method must be implemented in the subclass. It should return the two split tensors. + """ # Shape[0] is the batch size and Shape[1] is the number of channels. assert ( x.shape[1] % 2 == 0 diff --git a/dacapo/experiments/tasks/losses/loss.py b/dacapo/experiments/tasks/losses/loss.py index 20824d6ab..a51125dc5 100644 --- a/dacapo/experiments/tasks/losses/loss.py +++ b/dacapo/experiments/tasks/losses/loss.py @@ -5,6 +5,17 @@ class Loss(ABC): + """ + A class used to represent a loss function. This class is an abstract class + that should be inherited by any loss function class. + + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the loss for the provided prediction and target, with respect to the weight. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. + """ @abstractmethod def compute( self, @@ -12,10 +23,31 @@ def compute( target: torch.Tensor, weight: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Compute the loss for the given prediction and target. Optionally, if + """ + Compute the loss for the given prediction and target. Optionally, if given, a loss weight should be considered. All arguments are ``torch`` tensors. The return type should be a ``torch`` scalar that can be used with an optimizer, just as usual when - training with ``torch``.""" + training with ``torch``. + + Args: + prediction: The predicted tensor. + target: The target tensor. + weight: The weight tensor. + Returns: + The computed loss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = MSELoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the + computed loss tensor. + """ pass diff --git a/dacapo/experiments/tasks/losses/mse_loss.py b/dacapo/experiments/tasks/losses/mse_loss.py index 348042c11..e3b0dac0a 100644 --- a/dacapo/experiments/tasks/losses/mse_loss.py +++ b/dacapo/experiments/tasks/losses/mse_loss.py @@ -4,34 +4,40 @@ class MSELoss(Loss): """ - A class used to represent the Mean Square Error Loss function (MSELoss). + A class used to represent the Mean Square Error Loss function (MSELoss). This class inherits from the Loss class. - Attributes - ---------- - None - - Methods - ------- - compute(prediction, target, weight): - Computes the MSELoss with the given weight for the predictiom and target. + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the MSELoss for the provided prediction and target, with respect to the weight. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. """ def compute(self, prediction, target, weight): """ Function to compute the MSELoss for the provided prediction and target, with respect to the weight. - Parameters: - ---------- - prediction : torch.Tensor - The prediction tensor for which loss needs to be calculated. - target : torch.Tensor - The target tensor with respect to which loss is calculated. - weight : torch.Tensor - The weight tensor used to weigh the prediction in the loss calculation. - + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. Returns: - ------- - torch.Tensor - The computed MSELoss tensor. + torch.Tensor + The computed MSELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = MSELoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed MSELoss tensor. """ return torch.nn.MSELoss().forward(prediction * weight, target * weight) diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index bfd4584d9..732349b1c 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -10,16 +10,74 @@ class ArgmaxPostProcessor(PostProcessor): + """ + Post-processor that takes the argmax of the input array along the channel + axis. The output is a binary array where the value is 1 if the argmax is + greater than the threshold, and 0 otherwise. + + Attributes: + prediction_array: The array containing the model's prediction. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. + """ def __init__(self): + """ + Initialize the post-processor. + + Args: + detection_threshold: The detection threshold. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + Note: + This method must be implemented in the subclass. It should set the + `detection_threshold` attribute. + """ pass def enumerate_parameters(self): - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + ArgmaxPostProcessorParameters(id=0) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + """ yield ArgmaxPostProcessorParameters(id=1) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier @@ -32,6 +90,26 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ): + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the output array. + num_workers: The number of workers to use. + block_size: The size of the blocks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> post_processor.set_prediction("prediction") + >>> post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should process the + predictions and return the output array. + """ if self.prediction_array._daisy_array.chunk_shape is not None: block_size = Coordinate( self.prediction_array._daisy_array.chunk_shape[ diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py index 331faf5e6..d030ce3ff 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py @@ -4,4 +4,14 @@ @attr.s(frozen=True) class ArgmaxPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the argmax post-processor. The argmax post-processor will set + the output to the index of the maximum value in the input array. + + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ pass diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 4a992ced2..e70bd9ba1 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -8,20 +8,107 @@ class DummyPostProcessor(PostProcessor): + """ + Dummy post-processor that stores some dummy data. The dummy data is a 10x10x10 + array filled with the value of the min_size parameter. The min_size parameter + is specified in the parameters of the post-processor. The post-processor has + a detection threshold that is used to determine if an object is detected. + + Attributes: + detection_threshold: The detection threshold. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. + """ def __init__(self, detection_threshold: float): + """ + Initialize the post-processor. + + Args: + detection_threshold: The detection threshold. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor(0.5) + Note: + This method must be implemented in the subclass. It should set the + `detection_threshold` attribute. + """ self.detection_threshold = detection_threshold def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + DummyPostProcessorParameters(id=0, min_size=1) + DummyPostProcessorParameters(id=1, min_size=2) + DummyPostProcessorParameters(id=2, min_size=3) + DummyPostProcessorParameters(id=3, min_size=4) + DummyPostProcessorParameters(id=4, min_size=5) + DummyPostProcessorParameters(id=5, min_size=6) + DummyPostProcessorParameters(id=6, min_size=7) + DummyPostProcessorParameters(id=7, min_size=8) + DummyPostProcessorParameters(id=8, min_size=9) + DummyPostProcessorParameters(id=9, min_size=10) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + """ for i, min_size in enumerate(range(1, 11)): yield DummyPostProcessorParameters(id=i, min_size=min_size) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ pass def process(self, parameters, output_array_identifier, *args, **kwargs): + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the output array. + num_workers: The number of workers to use. + chunk_size: The size of the chunks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should process the + predictions and store the output in the output array. + """ # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py index bfa09e583..f0a17182b 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py @@ -4,4 +4,18 @@ @attr.s(frozen=True) class DummyPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the dummy post-processor. The dummy post-processor will set + the output to 1 if the input is greater than the minimum size, and 0 + otherwise. + + Attributes: + min_size: The minimum size. If the input is greater than this value, the + output will be set to 1. Otherwise, the output will be set to 0. + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ min_size: int = attr.ib() diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index f0a991c51..e15e52150 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -12,21 +12,71 @@ class PostProcessor(ABC): - """Base class of all post-processors. + """ + Base class of all post-processors. A post-processor takes a model's prediction and converts it into the final - output (e.g., per-voxel class probabilities into a semantic segmentation). + output (e.g., per-voxel class probabilities into a semantic segmentation). A + post-processor can have multiple parameters, which can be enumerated using + the `enumerate_parameters` method. The `process` method takes a set of + parameters and applies the post-processing to the prediction. + + Attributes: + prediction_array_identifier: The identifier of the array containing the + model's prediction. + Methods: + enumerate_parameters: Enumerate all possible parameters of this + post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. """ @abstractmethod def enumerate_parameters(self) -> Iterable["PostProcessorParameters"]: - """Enumerate all possible parameters of this post-processor.""" + """ + Enumerate all possible parameters of this post-processor. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + MyPostProcessorParameters(param1=0.0, param2=0.0) + MyPostProcessorParameters(param1=0.0, param2=1.0) + MyPostProcessorParameters(param1=1.0, param2=0.0) + MyPostProcessorParameters(param1=1.0, param2=1.0) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + + """ pass @abstractmethod def set_prediction( self, prediction_array_identifier: "LocalArrayIdentifier" ) -> None: + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ pass @abstractmethod @@ -37,5 +87,26 @@ def process( num_workers: int = 16, chunk_size: Coordinate = Coordinate((64, 64, 64)), ) -> "Array": - """Convert predictions into the final output.""" + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the array to store the + output. + num_workers: The number of workers to use. + chunk_size: The size of the chunks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> post_processor.set_prediction("prediction") + >>> parameters = MyPostProcessorParameters(param1=0.0, param2=0.0) + >>> output = post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should convert the + model's prediction into the final output. + """ pass diff --git a/dacapo/experiments/tasks/post_processors/post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/post_processor_parameters.py index dd08ab41c..b43e664b4 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/post_processor_parameters.py @@ -5,12 +5,40 @@ @attr.s(frozen=True) class PostProcessorParameters: - """Base class for post-processor parameters.""" + """ + Base class for post-processor parameters. Post-processor parameters are + immutable objects that define the parameters of a post-processor. The + parameters are used to configure the post-processor. + + Attributes: + id: The identifier of the post-processor parameter. + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + + """ id: int = attr.ib() @property def parameter_names(self) -> List[str]: + """ + Get the names of the parameters. + + Returns: + A list of parameter names. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> parameters = PostProcessorParameters(0) + >>> parameters.parameter_names + ["id"] + Note: + This method must be implemented in the subclass. It should return a + list of parameter names. + """ return ["id"] diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index c5fdd52d5..5d6f32450 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -12,15 +12,52 @@ class ThresholdPostProcessor(PostProcessor): + """ + A post-processor that applies a threshold to the prediction. + + Attributes: + prediction_array_identifier: The identifier of the prediction array. + prediction_array: The prediction array. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array. + process: Process the prediction with the given parameters. + Note: + This post-processor applies a threshold to the prediction. The threshold is used to define the segmentation. The prediction array is set using the `set_prediction` method. + """ def __init__(self): pass def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: - """Enumerate all possible parameters of this post-processor.""" + """ + Enumerate all possible parameters of this post-processor. + + Returns: + Generator[ThresholdPostProcessorParameters]: A generator of parameters. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + Note: + This method should return a generator of instances of ``ThresholdPostProcessorParameters``. + """ for i, threshold in enumerate([100, 127, 150]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array. + + Args: + prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.set_prediction(prediction_array_identifier) + Note: + This method should set the prediction array using the given identifier. + """ self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier @@ -33,6 +70,24 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ) -> ZarrArray: + """ + Process the prediction with the given parameters. + + Args: + parameters (ThresholdPostProcessorParameters): The parameters to use for processing. + output_array_identifier (LocalArrayIdentifier): The identifier of the output array. + num_workers (int): The number of workers to use for processing. + block_size (Coordinate): The block size to use for processing. + Returns: + ZarrArray: The output array. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.process(parameters, output_array_identifier) + Note: + This method should process the prediction with the given parameters and return the output array. The method uses the `run_blockwise` function from the `dacapo.blockwise.scheduler` module to run the blockwise post-processing. + The output array is created using the `ZarrArray.create_from_array_identifier` function from the `dacapo.experiments.datasplits.datasets.arrays` module. + """ # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses # meaning the input arguments to methods on the subclass can only be more loosely diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py index 9a28ba970..3fc121d96 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py @@ -4,4 +4,17 @@ @attr.s(frozen=True) class ThresholdPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the threshold post-processor. The threshold post-processor + will set the output to 1 if the input is greater than the threshold, and 0 + otherwise. + + Attributes: + threshold: The threshold value. If the input is greater than this + value, the output will be set to 1. Otherwise, the output will be + set to 0. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ threshold: float = attr.ib(default=0.0) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 7a3467daa..d508e9624 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -16,12 +16,51 @@ class WatershedPostProcessor(PostProcessor): + """ + A post-processor that applies a watershed transformation to the + prediction. + + Attributes: + offsets: List of offsets for the watershed transformation. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array. + process: Process the prediction with the given parameters. + Note: + This post-processor uses the `watershed_function.py` script to apply the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + + """ + def __init__(self, offsets: List[Coordinate]): + """ + A post-processor that applies a watershed transformation to the + prediction. + + Args: + offsets (List[Coordinate]): List of offsets for the watershed transformation. + Examples: + >>> WatershedPostProcessor(offsets=[(0, 0, 1), (0, 1, 0), (1, 0, 0)]) + Note: + This post-processor uses the `watershed_function.py` script to apply the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + """ self.offsets = offsets def enumerate_parameters(self): - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + Generator[WatershedPostProcessorParameters]: A generator of parameters. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + Note: + This method should be implemented by the subclass. It should return a generator of instances of ``WatershedPostProcessorParameters``. + + """ for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]): yield WatershedPostProcessorParameters(id=i, bias=bias) @@ -31,6 +70,18 @@ def set_prediction(self, prediction_array_identifier): self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) + """ + Set the prediction array. + + Args: + prediction_array_identifier (LocalArrayIdentifier): The prediction array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.set_prediction(prediction_array_identifier) + Note: + This method should be implemented by the subclass. To set the prediction array, the method uses the `ZarrArray.open_from_array_identifier` function from the `dacapo.experiments.datasplits.datasets.arrays` module. + """ def process( self, @@ -39,6 +90,23 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ): + """ + Process the prediction with the given parameters. + + Args: + parameters (WatershedPostProcessorParameters): The parameters to use for processing. + output_array_identifier (LocalArrayIdentifier): The output array identifier. + num_workers (int): The number of workers to use for processing. + block_size (Coordinate): The block size to use for processing. + Returns: + LocalArrayIdentifier: The output array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.process(parameters, output_array_identifier) + Note: + This method should be implemented by the subclass. To run the watershed transformation, the method uses the `segment_blockwise` function from the `dacapo.blockwise.scheduler` module. + """ if self.prediction_array._daisy_array.chunk_shape is not None: block_size = Coordinate( self.prediction_array._daisy_array.chunk_shape[ diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py index 6a3a1e271..b4a62d517 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py @@ -5,5 +5,20 @@ @attr.s(frozen=True) class WatershedPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the watershed post-processor. + + Attributes: + offsets: List of offsets for the watershed transformation. + threshold: Threshold for the watershed transformation. + sigma: Sigma for the watershed transformation. + min_size: Minimum size of the segments. + bias: Bias for the watershed transformation. + context: Context for the watershed transformation. + Examples: + >>> WatershedPostProcessorParameters(offsets=[(0, 0, 1), (0, 1, 0), (1, 0, 0)], threshold=0.5, sigma=1.0, min_size=100, bias=0.5, context=(32, 32, 32)) + Note: + This class is used by the ``WatershedPostProcessor`` to define the parameters for the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + """ bias: float = attr.ib(default=0.5) context: Coordinate = attr.ib(default=Coordinate((32, 32, 32))) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 44dbf4088..4916e557f 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -108,7 +108,11 @@ def create_optimizer(self, model): >>> optimizer = trainer.create_optimizer(model) """ - optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) + optimizer = torch.optim.RAdam( + lr=self.learning_rate, + params=model.parameters(), + decoupled_weight_decay=True, + ) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.01,