-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into attention-unet
- Loading branch information
Showing
21 changed files
with
96 additions
and
323 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
name: black-action | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
linter_name: | ||
name: runner / black | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2024, Howard Hughes Medical Institute | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
3. Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,200 +1,13 @@ | ||
import logging | ||
from typing import Optional | ||
from funlib.geometry import Roi, Coordinate | ||
import numpy as np | ||
from dacapo.experiments.datasplits.datasets.arrays.array import Array | ||
from dacapo.experiments.datasplits.datasets.dataset import Dataset | ||
from dacapo.experiments.run import Run | ||
|
||
from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( | ||
PostProcessorParameters, | ||
) | ||
import dacapo.experiments.tasks.post_processors as post_processors | ||
from dacapo.store.array_store import LocalArrayIdentifier | ||
from dacapo.predict import predict | ||
from dacapo.compute_context import LocalTorch, ComputeContext | ||
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray | ||
from dacapo.store import ( | ||
create_config_store, | ||
create_weights_store, | ||
) | ||
|
||
from pathlib import Path | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def apply( | ||
run_name: str, | ||
input_container: Path or str, | ||
input_dataset: str, | ||
output_path: Path or str, | ||
validation_dataset: Optional[Dataset or str] = None, | ||
criterion: Optional[str] = "voi", | ||
iteration: Optional[int] = None, | ||
parameters: Optional[PostProcessorParameters or str] = None, | ||
roi: Optional[Roi or str] = None, | ||
num_cpu_workers: int = 30, | ||
output_dtype: Optional[np.dtype or str] = np.uint8, | ||
compute_context: ComputeContext = LocalTorch(), | ||
overwrite: bool = True, | ||
file_format: str = "zarr", | ||
): | ||
"""Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" | ||
if isinstance(output_dtype, str): | ||
output_dtype = np.dtype(output_dtype) | ||
|
||
if isinstance(roi, str): | ||
start, end = zip( | ||
*[ | ||
tuple(int(coord) for coord in axis.split(":")) | ||
for axis in roi.strip("[]").split(",") | ||
] | ||
) | ||
roi = Roi( | ||
Coordinate(start), | ||
Coordinate(end) - Coordinate(start), | ||
) | ||
|
||
assert (validation_dataset is not None and isinstance(criterion, str)) or ( | ||
isinstance(iteration, int) | ||
), "Either validation_dataset and criterion, or iteration must be provided." | ||
|
||
# retrieving run | ||
logger.info("Loading run %s", run_name) | ||
config_store = create_config_store() | ||
run_config = config_store.retrieve_run_config(run_name) | ||
run = Run(run_config) | ||
|
||
# create weights store | ||
weights_store = create_weights_store() | ||
|
||
# load weights | ||
if iteration is None: | ||
# weights_store._load_best(run, criterion) | ||
iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion) | ||
logger.info("Loading weights for iteration %i", iteration) | ||
weights_store.retrieve_weights(run, iteration) # shouldn't this be load_weights? | ||
|
||
# find the best parameters | ||
if isinstance(validation_dataset, str): | ||
val_ds_name = validation_dataset | ||
validation_dataset = [ | ||
dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name | ||
][0] | ||
logger.info("Finding best parameters for validation dataset %s", validation_dataset) | ||
if parameters is None: | ||
parameters = run.task.evaluator.get_overall_best_parameters( | ||
validation_dataset, criterion | ||
) | ||
assert ( | ||
parameters is not None | ||
), "Unable to retieve parameters. Parameters must be provided explicitly." | ||
|
||
elif isinstance(parameters, str): | ||
try: | ||
post_processor_name = parameters.split("(")[0] | ||
post_processor_kwargs = parameters.split("(")[1].strip(")").split(",") | ||
post_processor_kwargs = { | ||
key.strip(): value.strip() | ||
for key, value in [arg.split("=") for arg in post_processor_kwargs] | ||
} | ||
for key, value in post_processor_kwargs.items(): | ||
if value.isdigit(): | ||
post_processor_kwargs[key] = int(value) | ||
elif value.replace(".", "", 1).isdigit(): | ||
post_processor_kwargs[key] = float(value) | ||
except: | ||
raise ValueError( | ||
f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'" | ||
) | ||
try: | ||
parameters = getattr(post_processors, post_processor_name)( | ||
**post_processor_kwargs | ||
) | ||
except Exception as e: | ||
logger.error( | ||
f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.", | ||
exc_info=True, | ||
) | ||
raise e | ||
|
||
assert isinstance( | ||
parameters, PostProcessorParameters | ||
), "Parameters must be parsable to a PostProcessorParameters object." | ||
|
||
# make array identifiers for input, predictions and outputs | ||
input_array_identifier = LocalArrayIdentifier(input_container, input_dataset) | ||
input_array = ZarrArray.open_from_array_identifier(input_array_identifier) | ||
roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect( | ||
input_array.roi | ||
) | ||
output_container = Path( | ||
output_path, | ||
"".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}", | ||
) | ||
prediction_array_identifier = LocalArrayIdentifier( | ||
output_container, f"prediction_{run_name}_{iteration}" | ||
) | ||
output_array_identifier = LocalArrayIdentifier( | ||
output_container, f"output_{run_name}_{iteration}_{parameters}" | ||
) | ||
|
||
def apply(run_name: str, iteration: int, dataset_name: str): | ||
logger.info( | ||
"Applying best results from run %s at iteration %i to dataset %s", | ||
run.name, | ||
"Applying results from run %s at iteration %d to dataset %s", | ||
run_name, | ||
iteration, | ||
Path(input_container, input_dataset), | ||
) | ||
return apply_run( | ||
run, | ||
parameters, | ||
input_array, | ||
prediction_array_identifier, | ||
output_array_identifier, | ||
roi, | ||
num_cpu_workers, | ||
output_dtype, | ||
compute_context, | ||
overwrite, | ||
) | ||
|
||
|
||
def apply_run( | ||
run: Run, | ||
parameters: PostProcessorParameters, | ||
input_array: Array, | ||
prediction_array_identifier: LocalArrayIdentifier, | ||
output_array_identifier: LocalArrayIdentifier, | ||
roi: Optional[Roi] = None, | ||
num_cpu_workers: int = 30, | ||
output_dtype: Optional[np.dtype] = np.uint8, | ||
compute_context: ComputeContext = LocalTorch(), | ||
overwrite: bool = True, | ||
): | ||
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" | ||
run.model.eval() | ||
|
||
# render prediction dataset | ||
logger.info("Predicting on dataset %s", prediction_array_identifier) | ||
predict( | ||
run.model, | ||
input_array, | ||
prediction_array_identifier, | ||
output_roi=roi, | ||
num_cpu_workers=num_cpu_workers, | ||
output_dtype=output_dtype, | ||
compute_context=compute_context, | ||
overwrite=overwrite, | ||
dataset_name, | ||
) | ||
|
||
# post-process the output | ||
logger.info("Post-processing output to dataset %s", output_array_identifier) | ||
post_processor = run.task.post_processor | ||
post_processor.set_prediction(prediction_array_identifier) | ||
post_processor.process( | ||
parameters, output_array_identifier, overwrite=overwrite, blockwise=True | ||
) | ||
|
||
logger.info("Done") | ||
return | ||
raise NotImplementedError("This function is not yet implemented.") |
Oops, something went wrong.