Skip to content

Commit

Permalink
more architecture tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 13, 2024
1 parent bc70727 commit 0d3b3d2
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .db import options
from .architectures import dummy_architecture
from .architectures import dummy_architecture, unet_architecture
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .evaluators import binary_3_channel_evaluator
Expand Down
23 changes: 22 additions & 1 deletion tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from dacapo.experiments.architectures import DummyArchitectureConfig
from dacapo.experiments.architectures import (
DummyArchitectureConfig,
CNNectomeUNetConfig,
)

import pytest

Expand All @@ -8,3 +11,21 @@ def dummy_architecture():
yield DummyArchitectureConfig(
name="dummy_architecture", num_in_channels=1, num_out_channels=12
)


@pytest.fixture()
def unet_architecture():
yield CNNectomeUNetConfig(
name="tmp_unet_architecture",
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
fmaps_in=2,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(1, 4, 4), (1, 4, 4)],
kernel_size_down=[[(1, 3, 3)] * 2] * 3,
kernel_size_up=[[(1, 3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
)
26 changes: 26 additions & 0 deletions tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ..fixtures import *

import pytest
from pytest_lazy_fixtures import lf

import logging

logging.basicConfig(level=logging.INFO)


@pytest.mark.parametrize(
"architecture_config",
[
lf("dummy_architecture"),
lf("unet_architecture"),
],
)
def test_architecture(
architecture_config,
):

architecture_type = architecture_config.architecture_type

architecture = architecture_type(architecture_config)

assert architecture.dims is not None, f"Architecture dims are None {architecture}"
14 changes: 11 additions & 3 deletions tests/operations/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
import pytest


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", [""])
def test_create_compute_context(device):
compute_context = create_compute_context()
assert compute_context is not None
assert compute_context.device is not None
if torch.cuda.is_available():
assert compute_context.device == torch.device('cuda'), "Model is not on CUDA when CUDA is available {}".format(compute_context.device)
assert compute_context.device == torch.device(
"cuda"
), "Model is not on CUDA when CUDA is available {}".format(
compute_context.device
)
else:
assert compute_context.device == torch.device('cpu'), "Model is not on CPU when CUDA is not available {}".format(compute_context.device)
assert compute_context.device == torch.device(
"cpu"
), "Model is not on CPU when CUDA is not available {}".format(
compute_context.device
)

0 comments on commit 0d3b3d2

Please sign in to comment.