Skip to content

Commit

Permalink
Format Python code with psf/black push (#328)
Browse files Browse the repository at this point in the history
There appear to be some python formatting errors in
1813190. This pull request
uses the [psf/black](https://github.com/psf/black) formatter to fix
these issues.
  • Loading branch information
rhoadesScholar authored Nov 12, 2024
2 parents a56dca3 + 2d63df5 commit 75e2980
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

logger = logging.getLogger(__name__)


class ThresholdPostProcessor(PostProcessor):
"""
A post-processor that applies a threshold to the prediction.
Expand Down
6 changes: 2 additions & 4 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,8 @@ def process(
sampling = tuple(float(v) / 2 for v in voxel_size)
# fixing the sampling for 2D images
if len(boundaries.shape) < len(sampling):
sampling = sampling[-len(boundaries.shape):]
distances = distance_transform_edt(
boundaries, sampling=sampling
)
sampling = sampling[-len(boundaries.shape) :]
distances = distance_transform_edt(boundaries, sampling=sampling)
distances = distances.astype(np.float32)

# restore original shape
Expand Down
4 changes: 1 addition & 3 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# criterion,
# )
dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
[getattr(scores, criterion) for criterion in scores.criteria]
)
except:
logger.error(
Expand All @@ -260,8 +260,6 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# the evaluator
# array_store.remove(output_array_identifier)



iteration_scores.append(dataset_iteration_scores)
# array_store.remove(prediction_array_identifier)

Expand Down
53 changes: 24 additions & 29 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,35 @@

logging.basicConfig(level=logging.INFO)

from dacapo.experiments.architectures import DummyArchitectureConfig, CNNectomeUNetConfig
from dacapo.experiments.architectures import (
DummyArchitectureConfig,
CNNectomeUNetConfig,
)

import pytest


def unet_architecture(batch_norm, upsample,use_attention, three_d):
def unet_architecture(batch_norm, upsample, use_attention, three_d):
name = "3d_unet" if three_d else "2d_unet"
name = f"{name}_bn" if batch_norm else name
name = f"{name}_up" if upsample else name
name = f"{name}_att" if use_attention else name

if three_d:
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
else:
return CNNectomeUNetConfig(
name=name,
Expand All @@ -61,7 +64,6 @@ def unet_architecture(batch_norm, upsample,use_attention, three_d):
)



# skip the test for the Apple Paravirtual device
# that does not support Metal 2.0
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
Expand Down Expand Up @@ -117,19 +119,15 @@ def test_train(
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_train_unet(
datasplit,
task,
trainer,
batch_norm,
upsample,
use_attention,
three_d):

datasplit, task, trainer, batch_norm, upsample, use_attention, three_d
):
store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

architecture_config = unet_architecture(batch_norm, upsample,use_attention, three_d)
architecture_config = unet_architecture(
batch_norm, upsample, use_attention, three_d
)

run_config = RunConfig(
name=f"{architecture_config.name}_run",
Expand Down Expand Up @@ -167,6 +165,3 @@ def test_train_unet(
training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations



0 comments on commit 75e2980

Please sign in to comment.