Skip to content

Commit

Permalink
Data conventions (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Nov 13, 2024
2 parents 975b8b8 + d22eee5 commit e69bba2
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 53 deletions.
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .train_validate_datasplit import TrainValidateDataSplit
from .train_validate_datasplit_config import TrainValidateDataSplitConfig
from .datasplit_generator import DataSplitGenerator, DatasetSpec
from .simple_config import SimpleDataSplitConfig
1 change: 1 addition & 0 deletions dacapo/experiments/datasplits/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .dummy_dataset_config import DummyDatasetConfig
from .raw_gt_dataset import RawGTDataset
from .raw_gt_dataset_config import RawGTDatasetConfig
from .simple import SimpleDataset
9 changes: 9 additions & 0 deletions dacapo/experiments/datasplits/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .dataset import Dataset
from funlib.persistence import Array

import warnings


class DummyDataset(Dataset):
"""
Expand All @@ -15,6 +17,7 @@ class DummyDataset(Dataset):
Notes:
This class is used to create a dataset with raw data.
"""


raw: Array

Expand All @@ -34,5 +37,11 @@ def __init__(self, dataset_config):
This method is used to initialize the dataset.
"""
super().__init__()

warnings.warn(
"DummyDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
7 changes: 7 additions & 0 deletions dacapo/experiments/datasplits/datasets/raw_gt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from funlib.geometry import Coordinate

from typing import Optional, List
import warnings


class RawGTDataset(Dataset):
Expand Down Expand Up @@ -48,6 +49,12 @@ def __init__(self, dataset_config):
Notes:
This method is used to initialize the dataset.
"""

warnings.warn(
"RawGTDataset is deprecated. Use SimpleDataset instead.",
DeprecationWarning,
)

self.name = dataset_config.name
self.raw = dataset_config.raw_config.array()
self.gt = dataset_config.gt_config.array()
Expand Down
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/datasets/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .dataset_config import DatasetConfig

from funlib.persistence import Array, open_ds


import attr

from pathlib import Path
import numpy as np

@attr.s
class SimpleDataset(DatasetConfig):

path: Path = attr.ib()
weight: int = attr.ib(default=1)
raw_name: str = attr.ib(default="raw")
gt_name: str = attr.ib(default="labels")
mask_name: str = attr.ib(default="mask")

@staticmethod
def dataset_type(dataset_config):
return dataset_config

@property
def raw(self) -> Array:
raw_array = open_ds(self.path / self.raw_name)
dtype = raw_array.dtype
if dtype == np.uint8:
raw_array.lazy_op(lambda data: data.astype(np.float32) / 255)
elif dtype == np.uint16:
raw_array.lazy_op(lambda data: data.astype(np.float32) / 65535)
elif np.issubdtype(dtype, np.floating):
pass
elif np.issubdtype(dtype, np.integer):
raise Exception(
f"Not sure how to normalize intensity data with dtype {dtype}"
)
return raw_array

@property
def gt(self) -> Array:
return open_ds(self.path / self.gt_name)

@property
def mask(self) -> Array | None:
mask_path = self.path / self.mask_name
if mask_path.exists():
mask = open_ds(mask_path)
assert np.issubdtype(mask.dtype, np.integer), "Mask must be integer type"
mask.lazy_op(lambda data: data > 0)
return mask
return None

@property
def sample_points(self) -> None:
return None


def __eq__(self, other) -> bool:
return isinstance(other, type(self)) and self.name == other.name

def __hash__(self) -> int:
return hash(self.name)

def __repr__(self) -> str:
return self.name

def __str__(self) -> str:
return self.name
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/dummy_datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class DummyDataSplit(DataSplit):
Expand Down Expand Up @@ -41,6 +42,10 @@ def __init__(self, datasplit_config):
This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
datasplit_config.train_config.dataset_type(datasplit_config.train_config)
Expand Down
69 changes: 69 additions & 0 deletions dacapo/experiments/datasplits/simple_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .datasets.simple import SimpleDataset
from .datasplit_config import DataSplitConfig

import attr

from pathlib import Path

import glob

@attr.s
class SimpleDataSplitConfig(DataSplitConfig):
"""
A convention over configuration datasplit that can handle many of the most
basic cases.
"""

path: Path = attr.ib()
name: str = attr.ib()
train_group_name: str = attr.ib(default="train")
validate_group_name: str = attr.ib(default="test")
raw_name: str = attr.ib(default="raw")
gt_name: str = attr.ib(default="labels")
mask_name: str = attr.ib(default="mask")

@staticmethod
def datasplit_type(datasplit_config):
return datasplit_config

def get_paths(self, group_name: str) -> list[Path]:
level_0 = f"{self.path}/{self.raw_name}"
level_1 = f"{self.path}/{group_name}/{self.raw_name}"
level_2 = f"{self.path}/{group_name}/**/{self.raw_name}"
level_0_matches = glob.glob(level_0)
level_1_matches = glob.glob(level_1)
level_2_matches = glob.glob(level_2)
if len(level_0_matches) > 0:
assert (
len(level_1_matches) == len(level_2_matches) == 0
), f"Found raw data at {level_0} and {level_1} and {level_2}"
return [Path(x).parent for x in level_0_matches]
elif len(level_1_matches) > 0:
assert (
len(level_2_matches) == 0
), f"Found raw data at {level_1} and {level_2}"
return [Path(x).parent for x in level_1_matches]
elif len(level_2_matches).parent > 0:
return [Path(x) for x in level_2_matches]

raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")

@property
def train(self) -> list[SimpleDataset]:
return [
SimpleDataset(
name=x.stem,
path=x,
)
for x in self.get_paths(self.train_group_name)
]

@property
def validate(self) -> list[SimpleDataset]:
return [
SimpleDataset(
name=x.stem,
path=x,
)
for x in self.get_paths(self.validate_group_name)
]
5 changes: 5 additions & 0 deletions dacapo/experiments/datasplits/train_validate_datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .datasets import Dataset

from typing import List
import warnings


class TrainValidateDataSplit(DataSplit):
Expand Down Expand Up @@ -47,6 +48,10 @@ def __init__(self, datasplit_config):
into training and validation datasets.
"""
super().__init__()
warnings.warn(
"TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
DeprecationWarning,
)

self.train = [
train_config.dataset_type(train_config)
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/validation_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def to_xarray(self) -> xr.DataArray:
"iterations": [
iteration_score.iteration for iteration_score in self.scores
],
"datasets": self.datasets,
"datasets": [d.name for d in self.datasets],
"parameters": self.parameters,
"criteria": self.criteria,
},
Expand Down
2 changes: 1 addition & 1 deletion dacapo/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def plot_runs(
)
colors_val = itertools.cycle(plt.cm.tab20.colors)
for dataset in run.validation_scores.datasets:
dataset_data = validation_score_data.sel(datasets=dataset)
dataset_data = validation_score_data.sel(datasets=dataset.name)
include_validation_figure = True
x = [score.iteration for score in run.validation_scores.scores]
for i, cc in zip(range(dataset_data.data.shape[1]), colors_val):
Expand Down
111 changes: 111 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
.. _sec_data:

Data Formatting
===============

Overview
--------

We support any data format that can be opened with the `zarr.open` convenience function from
`zarr <https://zarr.readthedocs.io/en/stable/api/convenience.html#zarr.convenience.open>`_. We also expect some specific metadata to come
with the data.

Metadata
--------

- `voxel_size`: The size of each voxel in the dataset. This is expected to be a tuple of ints
with the same length as the number of spatial dimensions in the dataset.
- `offset`: The offset of the dataset. This is expected to be a tuple of ints with the same length
as the number of spatial dimensions in the dataset.
- `axis_names`: The name of each axis. This is expected to be a tuple of strings with the same length
as the total number of dimensions in the dataset. For example a 3D dataset with channels would have
`axis_names=('c^', 'z', 'y', 'x')`. Note we expect non-spatial dimensions to include a "^" character.
See [1]_ for expected future changes
- `units`: The units of each axis. This is expected to be a tuple of strings with the same length
as the number of spatial dimensions in the dataset. For example a 3D dataset with channels would have
`units=('nanometers', 'nanometers', 'nanometers')`.

Orgnaization
------------

Ideally all of your data will be contained in a single zarr container.
The simplest possible dataset would look like this:
::

data.zarr
├── raw
└── labels

If this is what your data looks like, then your data configuration will look like this:

.. code-block::
:caption: A simple data configuration
data_config = DataConfig(
path="/path/to/data.zarr"
)
Note that a lot of assumptions will be made.

1. We assume your raw data is normalized based on the `dtype`. I.e. if your data is
stored as an unsigned int (we recommend uint8) we will assume a range and normalize
it to [0,1] by dividing by the appropriate value (255 for `uint8` or 65535 for `uint16`).
If your data is stored as any `float` we will assume it is already in the range [0, 1].
2. We assume your labels are stored as unsigned integers. If you want to generate instance segmentations, you will need
to assign a unique id to every object of the class you are interested in. If you want semantic segmentations you
can simply assign a unique id to each class. 0 is reserved for the background class.
3. We assume that the labels are provided densely. The entire volume will be used for training.
4. We will be training and validating on the same data. This is not ideal, but it is an ok starting point for testing
and debugging.

Next we can add a little bit of complexity by seperating train and test data. This can also be handled
by the same data configuration as above since it will detect the presence of the `train` and `test` groups.

::

data.zarr
├── train
│ ├── raw
│ └── labels
└── test
├── raw
└── labels

We can go further with our basic data configuration since this will often not be enough to describe your data. You may have multiple crops and often your data may be
sparsely annotated. The same data configuration from above will also work for the slightly more complicated
dataset below:

::

data.zarr
├── train
│ ├── crop_01
│ │ ├── raw
│ │ ├── labels
│ │ └── mask
│ └── crop_02
│ ├── raw
│ └── labels
└── test
└─ crop_03
│ ├── raw
│ ├── labels
│ └── mask
└─ crop_04
├── raw
└── labels

Note that `crop_01` and `crop_03` have masks associated with them. We assume a value of `0` in the mask indicates
unknown data. We will never use this data for supervised training, regardless of the corresponding label value.
If multiple test datasets are provided, this will increase the amount of information to review after training.
You will have e.g. `crop_03_voi` and `crop_04_voi` stored in the validation scores. Since we also take care to
save the "best" model checkpoint, you may now double the number of checkpoints saved since the checkpoint that
achieves optimal `voi` on `crop_03` may not be the same as the checkpoint that achieves optimal `voi` on `crop_04`.

Footnotes
---------

.. [1] The specification of axis names is expected to change in the future since we expect to support a `type` field in the future which
can be one of ["time", "space", "{anything-else}"]. Which would allow you to specify dimensions as "channel"
or "batch" or whatever else you want. This will bring us more in line with OME-Zarr and allow us to more easily
handle a larger variety of common data specification formats.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
overview
install
notebooks/minimal_tutorial
data
unet_architectures
tutorial
docker
Expand Down
Loading

0 comments on commit e69bba2

Please sign in to comment.