Format Python code with psf/black push #45

Merged — merged 6 commits on Feb 9, 2024
2 changes: 2 additions & 0 deletions .github/workflows/black.yaml

@@ -1,5 +1,7 @@
 name: black-action
 
+on: [push, pull_request]
+
 jobs:
   linter_name:
     name: runner / black
9 changes: 4 additions & 5 deletions .github/workflows/docs.yaml

@@ -1,8 +1,7 @@
-name: Pages
-on:
-  push:
-    branches:
-      - master
+name: Generate Pages
+
+on: [push, pull_request]
+
 jobs:
   docs:
     runs-on: ubuntu-latest
38 changes: 0 additions & 38 deletions .github/workflows/publish.yaml

This file was deleted.

3 changes: 1 addition & 2 deletions .github/workflows/tests.yaml

@@ -1,7 +1,6 @@
 name: Test
 
-on:
-  push:
+on: [push, pull_request]
 
 jobs:
   test:
197 changes: 5 additions & 192 deletions dacapo/apply.py

@@ -1,200 +1,13 @@
 import logging
-from typing import Optional
-from funlib.geometry import Roi, Coordinate
-import numpy as np
-from dacapo.experiments.datasplits.datasets.arrays.array import Array
-from dacapo.experiments.datasplits.datasets.dataset import Dataset
-from dacapo.experiments.run import Run
-
-from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
-    PostProcessorParameters,
-)
-import dacapo.experiments.tasks.post_processors as post_processors
-from dacapo.store.array_store import LocalArrayIdentifier
-from dacapo.predict import predict
-from dacapo.compute_context import LocalTorch, ComputeContext
-from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
-from dacapo.store import (
-    create_config_store,
-    create_weights_store,
-)
-
-from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
 
-def apply(
-    run_name: str,
-    input_container: Path or str,
-    input_dataset: str,
-    output_path: Path or str,
-    validation_dataset: Optional[Dataset or str] = None,
-    criterion: Optional[str] = "voi",
-    iteration: Optional[int] = None,
-    parameters: Optional[PostProcessorParameters or str] = None,
-    roi: Optional[Roi or str] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[np.dtype or str] = np.uint8,
-    compute_context: ComputeContext = LocalTorch(),
-    overwrite: bool = True,
-    file_format: str = "zarr",
-):
-    """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
-    if isinstance(output_dtype, str):
-        output_dtype = np.dtype(output_dtype)
-
-    if isinstance(roi, str):
-        start, end = zip(
-            *[
-                tuple(int(coord) for coord in axis.split(":"))
-                for axis in roi.strip("[]").split(",")
-            ]
-        )
-        roi = Roi(
-            Coordinate(start),
-            Coordinate(end) - Coordinate(start),
-        )
-
-    assert (validation_dataset is not None and isinstance(criterion, str)) or (
-        isinstance(iteration, int)
-    ), "Either validation_dataset and criterion, or iteration must be provided."
-
-    # retrieving run
-    logger.info("Loading run %s", run_name)
-    config_store = create_config_store()
-    run_config = config_store.retrieve_run_config(run_name)
-    run = Run(run_config)
-
-    # create weights store
-    weights_store = create_weights_store()
-
-    # load weights
-    if iteration is None:
-        # weights_store._load_best(run, criterion)
-        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
-    logger.info("Loading weights for iteration %i", iteration)
-    weights_store.retrieve_weights(run, iteration)  # shouldn't this be load_weights?
-
-    # find the best parameters
-    if isinstance(validation_dataset, str):
-        val_ds_name = validation_dataset
-        validation_dataset = [
-            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
-        ][0]
-    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
-    if parameters is None:
-        parameters = run.task.evaluator.get_overall_best_parameters(
-            validation_dataset, criterion
-        )
-        assert (
-            parameters is not None
-        ), "Unable to retieve parameters. Parameters must be provided explicitly."
-
-    elif isinstance(parameters, str):
-        try:
-            post_processor_name = parameters.split("(")[0]
-            post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
-            post_processor_kwargs = {
-                key.strip(): value.strip()
-                for key, value in [arg.split("=") for arg in post_processor_kwargs]
-            }
-            for key, value in post_processor_kwargs.items():
-                if value.isdigit():
-                    post_processor_kwargs[key] = int(value)
-                elif value.replace(".", "", 1).isdigit():
-                    post_processor_kwargs[key] = float(value)
-        except:
-            raise ValueError(
-                f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
-            )
-        try:
-            parameters = getattr(post_processors, post_processor_name)(
-                **post_processor_kwargs
-            )
-        except Exception as e:
-            logger.error(
-                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
-                exc_info=True,
-            )
-            raise e
-
-    assert isinstance(
-        parameters, PostProcessorParameters
-    ), "Parameters must be parsable to a PostProcessorParameters object."
-
-    # make array identifiers for input, predictions and outputs
-    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
-    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
-    roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
-        input_array.roi
-    )
-    output_container = Path(
-        output_path,
-        "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
-    )
-    prediction_array_identifier = LocalArrayIdentifier(
-        output_container, f"prediction_{run_name}_{iteration}"
-    )
-    output_array_identifier = LocalArrayIdentifier(
-        output_container, f"output_{run_name}_{iteration}_{parameters}"
-    )
-
+def apply(run_name: str, iteration: int, dataset_name: str):
     logger.info(
-        "Applying best results from run %s at iteration %i to dataset %s",
-        run.name,
+        "Applying results from run %s at iteration %d to dataset %s",
+        run_name,
         iteration,
-        Path(input_container, input_dataset),
+        dataset_name,
     )
-    return apply_run(
-        run,
-        parameters,
-        input_array,
-        prediction_array_identifier,
-        output_array_identifier,
-        roi,
-        num_cpu_workers,
-        output_dtype,
-        compute_context,
-        overwrite,
-    )
-
-
-def apply_run(
-    run: Run,
-    parameters: PostProcessorParameters,
-    input_array: Array,
-    prediction_array_identifier: LocalArrayIdentifier,
-    output_array_identifier: LocalArrayIdentifier,
-    roi: Optional[Roi] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[np.dtype] = np.uint8,
-    compute_context: ComputeContext = LocalTorch(),
-    overwrite: bool = True,
-):
-    """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
-    run.model.eval()
-
-    # render prediction dataset
-    logger.info("Predicting on dataset %s", prediction_array_identifier)
-    predict(
-        run.model,
-        input_array,
-        prediction_array_identifier,
-        output_roi=roi,
-        num_cpu_workers=num_cpu_workers,
-        output_dtype=output_dtype,
-        compute_context=compute_context,
-        overwrite=overwrite,
-    )
-
-    # post-process the output
-    logger.info("Post-processing output to dataset %s", output_array_identifier)
-    post_processor = run.task.post_processor
-    post_processor.set_prediction(prediction_array_identifier)
-    post_processor.process(
-        parameters, output_array_identifier, overwrite=overwrite, blockwise=True
-    )
-
-    logger.info("Done")
-    return
+    raise NotImplementedError("This function is not yet implemented.")
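Note: the deleted `apply` accepted `parameters` either as a `PostProcessorParameters` object or as a string of the form `post_processor_name(arg1=val1, arg2=val2, ...)`. A standalone sketch of that deleted parsing step, using a made-up `watershed(...)` string (not necessarily a real dacapo post-processor):

def parse_parameters(parameters: str) -> tuple[str, dict]:
    # Split "name(k1=v1, k2=v2)" into the name and a kwargs dict.
    name = parameters.split("(")[0]
    raw_kwargs = parameters.split("(")[1].strip(")").split(",")
    kwargs = {
        key.strip(): value.strip()
        for key, value in (arg.split("=") for arg in raw_kwargs)
    }
    # Coerce numeric strings to int or float, as the deleted code did.
    for key, value in kwargs.items():
        if value.isdigit():
            kwargs[key] = int(value)
        elif value.replace(".", "", 1).isdigit():
            kwargs[key] = float(value)
    return name, kwargs


print(parse_parameters("watershed(bias=0.5, min_size=100)"))
# -> ('watershed', {'bias': 0.5, 'min_size': 100})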
55 changes: 11 additions & 44 deletions dacapo/cli.py

@@ -1,5 +1,3 @@
-from typing import Optional
-
 import dacapo
 import click
 import logging

@@ -42,52 +40,21 @@ def validate(run_name, iteration):
 
 @cli.command()
 @click.option(
-    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
+    "-r", "--run-name", required=True, type=str, help="The name of the run to use."
 )
 @click.option(
-    "-ic",
-    "--input_container",
+    "-i",
+    "--iteration",
     required=True,
-    type=click.Path(exists=True, file_okay=False),
+    type=int,
+    help="The iteration weights and parameters to use.",
 )
-@click.option("-id", "--input_dataset", required=True, type=str)
-@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
-@click.option("-vd", "--validation_dataset", type=str, default=None)
-@click.option("-c", "--criterion", default="voi")
-@click.option("-i", "--iteration", type=int, default=None)
-@click.option("-p", "--parameters", type=str, default=None)
 @click.option(
-    "-roi",
-    "--roi",
+    "-r",
+    "--dataset",
+    required=True,
     type=str,
-    required=False,
-    help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
+    help="The name of the dataset to apply the run to.",
 )
-@click.option("-w", "--num_cpu_workers", type=int, default=30)
-@click.option("-dt", "--output_dtype", type=str, default="uint8")
-def apply(
-    run_name: str,
-    input_container: str,
-    input_dataset: str,
-    output_path: str,
-    validation_dataset: Optional[str] = None,
-    criterion: Optional[str] = "voi",
-    iteration: Optional[int] = None,
-    parameters: Optional[str] = None,
-    roi: Optional[str] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[str] = "uint8",
-):
-    dacapo.apply(
-        run_name,
-        input_container,
-        input_dataset,
-        output_path,
-        validation_dataset,
-        criterion,
-        iteration,
-        parameters,
-        roi,
-        num_cpu_workers,
-        output_dtype,
-    )
+def apply(run_name, iteration, dataset_name):
+    dacapo.apply(run_name, iteration, dataset_name)
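For reference, a self-contained sketch of the click pattern the new command relies on, where option flags map onto function parameters. The `-d` short flag and the explicit `"dataset_name"` parameter name below are illustrative choices, not taken from this diff:

import click


@click.command()
@click.option("-r", "--run-name", required=True, type=str)
@click.option("-i", "--iteration", required=True, type=int)
@click.option("-d", "--dataset", "dataset_name", required=True, type=str)
def apply(run_name: str, iteration: int, dataset_name: str):
    # click derives the parameter name from the long flag unless, as with
    # "dataset_name" here, an explicit name is supplied.
    click.echo(f"run={run_name} iteration={iteration} dataset={dataset_name}")


if __name__ == "__main__":
    apply()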
@@ -56,7 +56,7 @@ def voxel_size(self) -> Coordinate:
 
     @lazy_property.LazyProperty
     def roi(self) -> Roi:
-        return Roi(self._offset * self.shape)
+        return Roi(self._offset, self.shape)
 
     @property
     def writable(self) -> bool:
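The one-character fix above matters because funlib.geometry's `Roi` takes an offset and a shape as two separate arguments, so the old call collapsed both into one. A quick sketch, assuming funlib's `Coordinate`/`Roi` API:

from funlib.geometry import Coordinate, Roi

offset = Coordinate((10, 20, 30))
shape = Coordinate((100, 100, 100))

print(Roi(offset, shape))  # the intended ROI
# The old Roi(offset * shape) would have passed the element-wise product
# (1000, 2000, 3000) as the offset and supplied no shape at all.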
2 changes: 1 addition & 1 deletion dacapo/experiments/model.py

@@ -46,7 +46,7 @@ def forward(self, x):
         result = self.eval_activation(result)
         return result
 
-    def compute_output_shape(self, input_shape: Coordinate) -> Coordinate:
+    def compute_output_shape(self, input_shape: Coordinate) -> Tuple[int, Coordinate]:
         """Compute the spatial shape (i.e., not accounting for channels and
         batch dimensions) of this model, when fed a tensor of the given spatial
         shape as input."""
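The corrected annotation reflects that the method returns a channel count together with the spatial shape. A generic sketch (not dacapo's actual implementation) of measuring both for any torch module with a dry run:

import torch


def compute_output_shape(module: torch.nn.Module, input_shape: tuple) -> tuple:
    # Add singleton batch and channel dims, run a dummy forward pass,
    # and read the channel count and spatial shape off the output.
    with torch.no_grad():
        out = module(torch.zeros((1, 1) + tuple(input_shape)))
    return out.shape[1], tuple(out.shape[2:])


# A valid (unpadded) 3x3 convolution shrinks each spatial dim by 2:
conv = torch.nn.Conv2d(1, 8, kernel_size=3)
print(compute_output_shape(conv, (64, 64)))  # (8, (62, 62))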
3 changes: 1 addition & 2 deletions dacapo/predict.py

@@ -24,7 +24,7 @@ def predict(
     num_cpu_workers: int = 4,
     compute_context: ComputeContext = LocalTorch(),
     output_roi: Optional[Roi] = None,
-    output_dtype: Optional[np.dtype] = np.float32,  # add necessary type conversions
+    output_dtype: np.dtype = np.float32,  # type: ignore
     overwrite: bool = False,
 ):
     # get the model's input and output size

@@ -59,7 +59,6 @@
         model.num_out_channels,
         output_voxel_size,
         output_dtype,
-        overwrite=overwrite,
     )
 
     # create gunpowder keys
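The added `# type: ignore` is presumably needed because `np.float32` is a scalar type rather than an `np.dtype` instance, even though NumPy converts freely between the two:

import numpy as np

print(isinstance(np.float32, np.dtype))  # False: np.float32 is a scalar type
print(np.dtype(np.float32))              # float32: the conversion is lossless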
5 changes: 4 additions & 1 deletion mypy.ini

@@ -68,4 +68,7 @@
 ignore_missing_imports = True
 
 [mypy-mwatershed.*]
-ignore_missing_imports = True
+ignore_missing_imports = True
+
+[mypy-numpy_indexed.*]
+ignore_missing_imports = True
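Each `[mypy-<package>.*]` section scopes `ignore_missing_imports` to that package alone, so untyped third-party imports (here `numpy_indexed`) stop raising import errors without loosening checks globally.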
2 changes: 1 addition & 1 deletion requirements-dev.txt

@@ -1,5 +1,5 @@
 black
 mypy
-pytest
+pytest==7.4.4
 pytest-cov
 pytest-lazy-fixture
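The `pytest==7.4.4` pin is presumably for compatibility with `pytest-lazy-fixture`, which broke on the pytest 8.0 release shortly before this PR.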