Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests v0 3 5 #338

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn as nn

from funlib.geometry import Coordinate

import math


Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(self, architecture_config):
self.unet = self.module()

@property
def eval_shape_increase(self):
def eval_shape_increase(self) -> Coordinate:
"""
The increase in shape due to the U-Net.

Expand All @@ -192,7 +194,7 @@ def eval_shape_increase(self):
"""
if self._eval_shape_increase is None:
return super().eval_shape_increase
return self._eval_shape_increase
return Coordinate(self._eval_shape_increase)

def module(self):
"""
Expand Down Expand Up @@ -235,16 +237,15 @@ def module(self):
"""
fmaps_in = self.fmaps_in
levels = len(self.downsample_factors) + 1
dims = len(self.downsample_factors[0])

if hasattr(self, "kernel_size_down"):
if self.kernel_size_down is not None:
kernel_size_down = self.kernel_size_down
else:
kernel_size_down = [[(3,) * dims, (3,) * dims]] * levels
if hasattr(self, "kernel_size_up"):
kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
if self.kernel_size_up is not None:
kernel_size_up = self.kernel_size_up
else:
kernel_size_up = [[(3,) * dims, (3,) * dims]] * (levels - 1)
kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)

# downsample factors has to be a list of tuples
downsample_factors = [tuple(x) for x in self.downsample_factors]
Expand Down Expand Up @@ -280,7 +281,7 @@ def module(self):
conv = ConvPass(
self.fmaps_out,
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
kernel_size_up[-1],
activation="ReLU",
batch_norm=self.batch_norm,
)
Expand All @@ -306,11 +307,11 @@ def scale(self, voxel_size):
The voxel size should be given as a tuple ``(z, y, x)``.
"""
for upsample_factor in self.upsample_factors:
voxel_size = voxel_size / upsample_factor
voxel_size = voxel_size / Coordinate(upsample_factor)
return voxel_size

@property
def input_shape(self):
def input_shape(self) -> Coordinate:
"""
Return the input shape of the U-Net.

Expand All @@ -324,7 +325,7 @@ def input_shape(self):
Note:
The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``.
"""
return self._input_shape
return Coordinate(self._input_shape)

@property
def num_in_channels(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
assert isinstance(dataset.weight, int), dataset

raw_source = gp.ArraySource(raw_key, dataset.raw)
if dataset.raw.channel_dims == 0:
raw_source += gp.Unsqueeze([raw_key], axis=0)
if self.clip_raw:
raw_source += gp.Crop(
raw_key, dataset.gt.roi.snap_to_grid(dataset.raw.voxel_size)
Expand Down
8 changes: 4 additions & 4 deletions tests/fixtures/datasplits.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def twelve_class_datasplit(tmp_path):
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
raw_dataset.attrs["resolution"] = (2, 2, 2)
raw_dataset.attrs["voxel_size"] = (2, 2, 2)
raw_dataset.attrs["axis_names"] = ("z", "y", "x")
gt_dataset.attrs["offset"] = (0, 0, 0)
gt_dataset.attrs["resolution"] = (2, 2, 2)
gt_dataset.attrs["voxel_size"] = (2, 2, 2)
gt_dataset.attrs["axis_names"] = ("z", "y", "x")

crop1 = RawGTDatasetConfig(name="crop1", raw_config=crop1_raw, gt_config=crop1_gt)
Expand Down Expand Up @@ -184,10 +184,10 @@ def six_class_datasplit(tmp_path):
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
raw_dataset.attrs["resolution"] = (2, 2, 2)
raw_dataset.attrs["voxel_size"] = (2, 2, 2)
raw_dataset.attrs["axis_names"] = ("z", "y", "x")
gt_dataset.attrs["offset"] = (0, 0, 0)
gt_dataset.attrs["resolution"] = (2, 2, 2)
gt_dataset.attrs["voxel_size"] = (2, 2, 2)
gt_dataset.attrs["axis_names"] = ("z", "y", "x")

crop1 = RawGTDatasetConfig(
Expand Down
2 changes: 1 addition & 1 deletion tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def unet_architecture(batch_norm, upsample, use_attention, three_d):
name=name,
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
fmaps_in=2,
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
Expand Down