diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 24b1968c3..bade6d083 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -97,6 +97,7 @@ nav: - webknossos-py/examples/download_tiff_stack.md - webknossos-py/examples/remote_datasets.md - webknossos-py/examples/zarr_and_dask.md + - webknossos-py/examples/convert_4d_tiff.md - Annotation Examples: - webknossos-py/examples/apply_merger_mode.md - webknossos-py/examples/learned_segmenter.md @@ -112,8 +113,10 @@ nav: - Overview: api/webknossos.md - Geometry: - BoundingBox: api/webknossos/geometry/bounding_box.md + - NDBoundingBox: api/webknossos/geometry/nd_bounding_box.md - Mag: api/webknossos/geometry/mag.md - Vec3Int: api/webknossos/geometry/vec3_int.md + - VecInt: api/webknossos/geometry/vec_int.md - Dataset: - Dataset: api/webknossos/dataset/dataset.md - Layer: api/webknossos/dataset/layer.md diff --git a/docs/src/webknossos-py/examples/convert_4d_tiff.md b/docs/src/webknossos-py/examples/convert_4d_tiff.md new file mode 100644 index 000000000..10edc2298 --- /dev/null +++ b/docs/src/webknossos-py/examples/convert_4d_tiff.md @@ -0,0 +1,13 @@ +# Convert 4D Tiff + +This example demonstrates the basic interactions with Datasets that have more than three dimensions. + +In order to manipulate 4D data in WEBKNOSSOS, we first convert the 4D Tiff dataset into a Zarr3 dataset. This conversion is achieved using the [from_images method](../../api/webknossos/dataset/dataset.md#Dataset.from_images). + +Once the dataset is converted, we can access specific layers and views, [read data](../../api/webknossos/dataset/mag_view.md#MagView.read) from a defined bounding box, and [write data](../../api/webknossos/dataset/mag_view.md#MagView.write) to a different position within the dataset. The [NDBoundingBox](../../api/webknossos/geometry/nd_bounding_box.md#NDBoundingBox) is utilized to select a 4D region of the dataset. + +```python +--8<-- +webknossos/examples/convert_4d_tiff.py +--8<-- +``` diff --git a/webknossos/Changelog.md b/webknossos/Changelog.md index bb493f1e9..758a139ad 100644 --- a/webknossos/Changelog.md +++ b/webknossos/Changelog.md @@ -21,6 +21,7 @@ For upgrade instructions, please check the respective _Breaking Changes_ section - The rules for naming the layers have been tightened to match the allowed layer names on webknossos. [#1016](https://github.com/scalableminds/webknossos-libs/pull/1016) - Replaced PyLint linter + black formatter with Ruff for development. [#1013](https://github.com/scalableminds/webknossos-libs/pull/1013) - The remote operations now use the WEBKNOSSOS API version 6. [#1018](https://github.com/scalableminds/webknossos-libs/pull/1018) +- The conversion of 4D Tiff files to a Zarr3 Dataset is possible. NDBoundingBoxes and VecInt classes are introduced to support working with more than 3 dimensions. [#966](https://github.com/scalableminds/webknossos-libs/pull/966) ### Fixed diff --git a/webknossos/examples/convert_4d_tiff.py b/webknossos/examples/convert_4d_tiff.py new file mode 100644 index 000000000..184bee12c --- /dev/null +++ b/webknossos/examples/convert_4d_tiff.py @@ -0,0 +1,48 @@ +from pathlib import Path + +import webknossos as wk + + +def main() -> None: + # Create a WEBKNOSSOS dataset from a 4D tiff image + dataset = wk.Dataset.from_images( + Path(__file__).parent.parent / "testdata" / "4D" / "4D_series", + "testoutput/4D_series", + voxel_size=(10, 10, 10), + data_format="zarr3", + use_bioformats=True, + ) + + # Access the first color layer and the Mag 1 view of this layer + layer = dataset.get_color_layers()[0] + mag_view = layer.get_finest_mag() + + # To get the bounding box of the dataset use layer.bounding_box + # -> NDBoundingBox(topleft=(0, 0, 0, 0), size=(7, 5, 167, 439), axes=('t', 'z', 'y', 'x')) + + # Read all data of the dataset + data = mag_view.read() + # data.shape -> (1, 7, 5, 167, 439) # first value is the channel dimension + + # Read data for a specific time point (t=3) of the dataset + data = mag_view.read( + absolute_bounding_box=layer.bounding_box.with_bounds("t", 3, 1) + ) + # data.shape -> (1, 1, 5, 167, 439) + + # Create a NDBoundingBox to read data from a specific region of the dataset + read_bbox = wk.NDBoundingBox( + topleft=(2, 0, 67, 39), + size=(2, 5, 100, 400), + axes=("t", "z", "y", "x"), + index=(1, 2, 3, 4), + ) + data = mag_view.read(absolute_bounding_box=read_bbox) + # data.shape -> (1, 2, 5, 100, 400) # first value is the channel dimension + + # Write some data to a given position + mag_view.write(data, absolute_bounding_box=read_bbox.offset((2, 0, 0, 0))) + + +if __name__ == "__main__": + main() diff --git a/webknossos/testdata/4D/4D_series/4D-series.ome.tif b/webknossos/testdata/4D/4D_series/4D-series.ome.tif new file mode 100644 index 000000000..f957ee380 Binary files /dev/null and b/webknossos/testdata/4D/4D_series/4D-series.ome.tif differ diff --git a/webknossos/tests/dataset/test_add_layer_from_images.py b/webknossos/tests/dataset/test_add_layer_from_images.py index 8750923b0..82babf807 100644 --- a/webknossos/tests/dataset/test_add_layer_from_images.py +++ b/webknossos/tests/dataset/test_add_layer_from_images.py @@ -43,6 +43,31 @@ def test_compare_tifffile(tmp_path: Path) -> None: assert np.array_equal(data[:, :, z_index], comparison_slice) +def test_compare_nd_tifffile(tmp_path: Path) -> None: + ds = wk.Dataset(tmp_path, (1, 1, 1)) + layer = ds.add_layer_from_images( + "testdata/4D/4D_series/4D-series.ome.tif", + layer_name="color", + category="color", + topleft=(100, 100, 55), + use_bioformats=True, + data_format="zarr3", + chunk_shape=(8, 8, 8), + chunks_per_shard=(8, 8, 8), + ) + assert layer.bounding_box.topleft == wk.VecInt( + 0, 55, 100, 100, axes=("t", "z", "y", "x") + ) + assert layer.bounding_box.size == wk.VecInt( + 7, 5, 167, 439, axes=("t", "z", "y", "x") + ) + read_with_tifffile_reader = TiffFile( + "testdata/4D/4D_series/4D-series.ome.tif" + ).asarray() + read_first_channel_from_dataset = layer.get_finest_mag().read()[0] + assert np.array_equal(read_with_tifffile_reader, read_first_channel_from_dataset) + + REPO_IMAGES_ARGS: List[ Tuple[Union[str, List[Path]], Dict[str, Any], str, int, Tuple[int, int, int]] ] = [ @@ -205,15 +230,6 @@ def download_and_unpack( (192, 128, 9), 1, ), - ( - "https://samples.scif.io/sdub.zip", - "sdub*.pic", - {"allow_multiple_layers": True}, - "uint8", - 1, - (192, 128, 9), - 12, - ), ( "https://samples.scif.io/test-avi.zip", "t1-rendering.avi", diff --git a/webknossos/webknossos/_nml/parameters.py b/webknossos/webknossos/_nml/parameters.py index 1e5c4c3d8..b9424f128 100644 --- a/webknossos/webknossos/_nml/parameters.py +++ b/webknossos/webknossos/_nml/parameters.py @@ -3,7 +3,7 @@ from loxun import XmlWriter -from ..geometry import BoundingBox +from ..geometry import BoundingBox, NDBoundingBox from ..geometry.bounding_box import _DEFAULT_BBOX_NAME from .utils import Vector3, enforce_not_null, filter_none_values @@ -22,22 +22,24 @@ class Parameters(NamedTuple): editPosition: Optional[Vector3] = None editRotation: Optional[Vector3] = None zoomLevel: Optional[float] = None - taskBoundingBox: Optional[BoundingBox] = None - userBoundingBoxes: Optional[List[BoundingBox]] = None + taskBoundingBox: Optional[NDBoundingBox] = None + userBoundingBoxes: Optional[List[NDBoundingBox]] = None def _dump_bounding_box( self, xf: XmlWriter, - bounding_box: BoundingBox, + bounding_box: NDBoundingBox, tag_name: str, bbox_id: Optional[int], # user bounding boxes need an id ) -> None: color = bounding_box.color or DEFAULT_BOUNDING_BOX_COLOR attributes = { - "name": _DEFAULT_BBOX_NAME - if bounding_box.name is None - else str(bounding_box.name), + "name": ( + _DEFAULT_BBOX_NAME + if bounding_box.name is None + else str(bounding_box.name) + ), "isVisible": "true" if bounding_box.is_visible else "false", "color.r": str(color[0]), "color.g": str(color[1]), @@ -136,7 +138,7 @@ def _dump(self, xf: XmlWriter) -> None: xf.endTag() # parameters @classmethod - def _parse_bounding_box(cls, bounding_box_element: Element) -> BoundingBox: + def _parse_bounding_box(cls, bounding_box_element: Element) -> NDBoundingBox: topleft = ( int(bounding_box_element.get("topLeftX", 0)), int(bounding_box_element.get("topLeftY", 0)), @@ -165,14 +167,16 @@ def _parse_bounding_box(cls, bounding_box_element: Element) -> BoundingBox: ) @classmethod - def _parse_user_bounding_boxes(cls, nml_parameters: Element) -> List[BoundingBox]: + def _parse_user_bounding_boxes(cls, nml_parameters: Element) -> List[NDBoundingBox]: if nml_parameters.find("userBoundingBox") is None: return [] bb_elements = nml_parameters.findall("userBoundingBox") return [cls._parse_bounding_box(bb_element) for bb_element in bb_elements] @classmethod - def _parse_task_bounding_box(cls, nml_parameters: Element) -> Optional[BoundingBox]: + def _parse_task_bounding_box( + cls, nml_parameters: Element + ) -> Optional[NDBoundingBox]: bb_element = nml_parameters.find("taskBoundingBox") if bb_element is not None: return cls._parse_bounding_box(bb_element) diff --git a/webknossos/webknossos/annotation/annotation.py b/webknossos/webknossos/annotation/annotation.py index bb2cb4f15..9d93aa94a 100644 --- a/webknossos/webknossos/annotation/annotation.py +++ b/webknossos/webknossos/annotation/annotation.py @@ -60,7 +60,7 @@ ) from ..dataset.defaults import PROPERTIES_FILE_NAME from ..dataset.properties import DatasetProperties, dataset_converter -from ..geometry import BoundingBox, Vec3Int +from ..geometry import NDBoundingBox, Vec3Int from ..skeleton import Skeleton from ..utils import time_since_epoch_in_ms, warn_deprecated from ._nml_conversion import annotation_to_nml, nml_to_skeleton @@ -124,8 +124,8 @@ class Annotation: edit_rotation: Optional[Vector3] = None zoom_level: Optional[float] = None metadata: Dict[str, str] = attr.Factory(dict) - task_bounding_box: Optional[BoundingBox] = None - user_bounding_boxes: List[BoundingBox] = attr.Factory(list) + task_bounding_box: Optional[NDBoundingBox] = None + user_bounding_boxes: List[NDBoundingBox] = attr.Factory(list) _volume_layers: List[_VolumeLayer] = attr.field(factory=list, init=False) @classmethod @@ -474,7 +474,7 @@ def _load_from_zip(cls, content: Union[str, PathLike, BinaryIO]) -> "Annotation" assert len(nml_paths) > 0, "Couldn't find an nml file in the supplied zip-file." assert ( len(nml_paths) == 1 - ), f"There must be exactly one nml file in the zip-file, buf found {len(nml_paths)}." + ), f"There must be exactly one nml file in the zip-file, but found {len(nml_paths)}." with nml_paths[0].open(mode="rb") as f: return cls._load_from_nml(nml_paths[0].stem, f, possible_volume_paths=paths) diff --git a/webknossos/webknossos/cli/convert_knossos.py b/webknossos/webknossos/cli/convert_knossos.py index 51184c342..44ad18faf 100644 --- a/webknossos/webknossos/cli/convert_knossos.py +++ b/webknossos/webknossos/cli/convert_knossos.py @@ -151,15 +151,15 @@ def convert_cube_job( time_start(f"Converting of {target_view.bounding_box}") cube_size = cast(Tuple[int, int, int], (KNOSSOS_CUBE_EDGE_LEN,) * 3) - offset = target_view.bounding_box.in_mag(target_view.mag).topleft - size = target_view.bounding_box.in_mag(target_view.mag).size + offset = target_view.bounding_box.in_mag(target_view.mag).topleft_xyz + size = target_view.bounding_box.in_mag(target_view.mag).size_xyz buffer = np.zeros(size.to_tuple(), dtype=target_view.get_dtype()) with open_knossos(source_knossos_info) as source_knossos: for x in range(0, size.x, KNOSSOS_CUBE_EDGE_LEN): for y in range(0, size.y, KNOSSOS_CUBE_EDGE_LEN): for z in range(0, size.z, KNOSSOS_CUBE_EDGE_LEN): cube_data = source_knossos.read( - (offset + Vec3Int(x, y, z)).to_tuple(), cube_size + Vec3Int(offset + (x, y, z)).to_tuple(), cube_size ) buffer[ x : (x + KNOSSOS_CUBE_EDGE_LEN), diff --git a/webknossos/webknossos/cli/export_wkw_as_tiff.py b/webknossos/webknossos/cli/export_wkw_as_tiff.py index 56afcaf2e..b4bb4aa14 100644 --- a/webknossos/webknossos/cli/export_wkw_as_tiff.py +++ b/webknossos/webknossos/cli/export_wkw_as_tiff.py @@ -262,7 +262,7 @@ def main( mag_view = Dataset.open(source).get_layer(layer_name).get_mag(mag) - bbox = mag_view.bounding_box if bbox is None else bbox + bbox = BoundingBox.from_ndbbox(mag_view.bounding_box) if bbox is None else bbox logging.info("Starting tiff export for bounding box: %s", bbox) executor_args = Namespace( diff --git a/webknossos/webknossos/dataset/_array.py b/webknossos/webknossos/dataset/_array.py index 547ca08a0..eab5a4286 100644 --- a/webknossos/webknossos/dataset/_array.py +++ b/webknossos/webknossos/dataset/_array.py @@ -5,7 +5,17 @@ from dataclasses import dataclass from os.path import relpath from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + List, + Optional, + Tuple, + Type, + Union, +) import numcodecs import numpy as np @@ -14,7 +24,7 @@ from upath import UPath from zarr.storage import FSStore -from ..geometry import BoundingBox, Vec3Int, Vec3IntLike +from ..geometry import BoundingBox, NDBoundingBox, Vec3Int, VecInt from ..utils import is_fs_path, warn_deprecated from .data_format import DataFormat @@ -60,6 +70,9 @@ class ArrayInfo: voxel_type: np.dtype chunk_shape: Vec3Int chunks_per_shard: Vec3Int + shape: VecInt = VecInt(c=1, x=1, y=1, z=1) + dimension_names: Tuple[str, ...] = ("c", "x", "y", "z") + axis_order: VecInt = VecInt(c=3, x=2, y=1, z=0) compression_mode: bool = False @property @@ -103,21 +116,24 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "BaseArray": pass @abstractmethod - def read(self, offset: Vec3IntLike, shape: Vec3IntLike) -> np.ndarray: + def read(self, bbox: NDBoundingBox) -> np.ndarray: pass @abstractmethod - def write(self, offset: Vec3IntLike, data: np.ndarray) -> None: + def write(self, bbox: NDBoundingBox, data: np.ndarray) -> None: pass @abstractmethod def ensure_size( - self, new_shape: Vec3IntLike, align_with_shards: bool = True, warn: bool = False + self, + new_bbox: NDBoundingBox, + align_with_shards: bool = True, + warn: bool = False, ) -> None: pass @abstractmethod - def list_bounding_boxes(self) -> Iterator[BoundingBox]: + def list_bounding_boxes(self) -> Iterator[NDBoundingBox]: "The bounding boxes are measured in voxels of the current mag." @abstractmethod @@ -208,15 +224,15 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "WKWArray": raise ArrayException(f"Exception while creating array {path}") from e return WKWArray(path) - def read(self, offset: Vec3IntLike, shape: Vec3IntLike) -> np.ndarray: - return self._wkw_dataset.read(Vec3Int(offset), Vec3Int(shape)) + def read(self, bbox: NDBoundingBox) -> np.ndarray: + return self._wkw_dataset.read(Vec3Int(bbox.topleft), Vec3Int(bbox.size)) - def write(self, offset: Vec3IntLike, data: np.ndarray) -> None: - self._wkw_dataset.write(Vec3Int(offset), data) + def write(self, bbox: NDBoundingBox, data: np.ndarray) -> None: + self._wkw_dataset.write(Vec3Int(bbox.topleft), data) def ensure_size( self, - new_shape: Vec3IntLike, + new_bbox: NDBoundingBox, align_with_shards: bool = True, warn: bool = False, ) -> None: @@ -228,7 +244,7 @@ def _list_files(self) -> Iterator[Path]: for filename in self._wkw_dataset.list_files() ) - def list_bounding_boxes(self) -> Iterator[BoundingBox]: + def list_bounding_boxes(self) -> Iterator[NDBoundingBox]: def _extract_num(s: str) -> int: match = re.search("[0-9]+", s) assert match is not None @@ -331,19 +347,14 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarrArray": ) return ZarrArray(path) - def read(self, offset: Vec3IntLike, shape: Vec3IntLike) -> np.ndarray: - offset = Vec3Int(offset) - shape = Vec3Int(shape) + def read(self, bbox: NDBoundingBox) -> np.ndarray: + shape = bbox.size zarray = self._zarray with _blosc_disable_threading(): - data = zarray[ - :, - offset.x : (offset.x + shape.x), - offset.y : (offset.y + shape.y), - offset.z : (offset.z + shape.z), - ] + data = zarray[(slice(None),) + bbox.to_slices()] + shape_with_channels = (self.info.num_channels,) + shape.to_tuple() - if data.shape != shape and data.shape != shape_with_channels: + if data.shape not in (shape, shape_with_channels): padded_data = np.zeros(shape_with_channels, dtype=data.dtype) padded_data[ :, @@ -355,16 +366,21 @@ def read(self, offset: Vec3IntLike, shape: Vec3IntLike) -> np.ndarray: return data def ensure_size( - self, new_shape: Vec3IntLike, align_with_shards: bool = True, warn: bool = False + self, + new_bbox: NDBoundingBox, + align_with_shards: bool = True, + warn: bool = False, ) -> None: - new_shape = Vec3Int(new_shape) + new_shape = VecInt(new_bbox.size, axes=new_bbox.axes) zarray = self._zarray - new_shape_tuple = ( - zarray.shape[0], - max(zarray.shape[1], new_shape.x), - max(zarray.shape[2], new_shape.y), - max(zarray.shape[3], new_shape.z), + new_shape_tuple = (zarray.shape[0],) + tuple( + ( + max(zarray.shape[i + 1], new_shape[i]) + if len(zarray.shape) > i + else new_shape[i] + ) + for i in range(len(new_shape)) ) if new_shape_tuple != zarray.shape: if align_with_shards: @@ -388,24 +404,22 @@ def ensure_size( ) zarray.resize(new_shape_tuple) - def write(self, offset: Vec3IntLike, data: np.ndarray) -> None: - offset = Vec3Int(offset) + def write(self, bbox: NDBoundingBox, data: np.ndarray) -> None: + """Writes a ZarrArray. If offset and bbox are given, the bbox is preferred to enable writing of n-dimensional data.""" + # If data is 3-dimensional, it is assumed that num_channels=1. if data.ndim == 3: data = data.reshape((1,) + data.shape) assert data.ndim == 4 with _blosc_disable_threading(): - self.ensure_size(offset + Vec3Int(data.shape[1:4]), warn=True) + self.ensure_size(bbox, warn=True) zarray = self._zarray - zarray[ - :, - offset.x : (offset.x + data.shape[1]), - offset.y : (offset.y + data.shape[2]), - offset.z : (offset.z + data.shape[3]), - ] = data + index_tuple = (slice(None),) + bbox.to_slices() + + zarray[index_tuple] = data - def list_bounding_boxes(self) -> Iterator[BoundingBox]: + def list_bounding_boxes(self) -> Iterator[NDBoundingBox]: zarray = self._zarray chunk_shape = Vec3Int(*zarray.chunks[1:4]) for key in zarray.store.keys(): @@ -484,6 +498,10 @@ def info(self) -> ArrayInfo: from zarrita.sharding import ShardingCodec zarray = self._zarray + if (names := getattr(zarray.metadata, "dimension_names", None)) is None: + dimension_names = ("c", "x", "y", "z") + else: + dimension_names = names if isinstance(zarray, Array): if len(zarray.codec_pipeline.codecs) == 1 and isinstance( zarray.codec_pipeline.codecs[0], ShardingCodec @@ -499,8 +517,10 @@ def info(self) -> ArrayInfo: sharding_codec.codec_pipeline.codecs ), chunk_shape=Vec3Int(chunk_shape[1:4]), - chunks_per_shard=Vec3Int(shard_shape[1:4]) - // Vec3Int(chunk_shape[1:4]), + chunks_per_shard=Vec3Int( + Vec3Int(shard_shape[1:4]) // Vec3Int(chunk_shape[1:4]) + ), + dimension_names=dimension_names, ) return ArrayInfo( data_format=DataFormat.Zarr3, @@ -514,6 +534,7 @@ def info(self) -> ArrayInfo: ) or Vec3Int.full(1), chunks_per_shard=Vec3Int.full(1), + dimension_names=dimension_names, ) else: return ArrayInfo( @@ -523,6 +544,7 @@ def info(self) -> ArrayInfo: compression_mode=zarray.metadata.compressor is not None, chunk_shape=Vec3Int(*zarray.metadata.chunks[1:4]) or Vec3Int.full(1), chunks_per_shard=Vec3Int.full(1), + dimension_names=dimension_names, ) @classmethod @@ -532,37 +554,45 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarritaArray": assert array_info.data_format in (DataFormat.Zarr, DataFormat.Zarr3) if array_info.data_format == DataFormat.Zarr3: + chunk_shape = (array_info.num_channels,) + tuple( + getattr(array_info.chunk_shape, axis, 1) + for axis in array_info.dimension_names[1:] + ) + shard_shape = (array_info.num_channels,) + tuple( + getattr(array_info.shard_shape, axis, 1) + for axis in array_info.dimension_names[1:] + ) Array.create( store=path, - shape=(array_info.num_channels, 1, 1, 1), - chunk_shape=(array_info.num_channels,) - + array_info.shard_shape.to_tuple(), + shape=array_info.shape, + chunk_shape=shard_shape, chunk_key_encoding=("default", "/"), dtype=array_info.voxel_type, - dimension_names=["c", "x", "y", "z"], + dimension_names=array_info.dimension_names, codecs=[ zarrita.codecs.sharding_codec( - chunk_shape=(array_info.num_channels,) - + array_info.chunk_shape.to_tuple(), - codecs=[ - zarrita.codecs.transpose_codec([3, 2, 1, 0]), - zarrita.codecs.bytes_codec(), - zarrita.codecs.blosc_codec( - typesize=array_info.voxel_type.itemsize - ), - ] - if array_info.compression_mode - else [ - zarrita.codecs.transpose_codec([3, 2, 1, 0]), - zarrita.codecs.bytes_codec(), - ], + chunk_shape=chunk_shape, + codecs=( + [ + zarrita.codecs.transpose_codec(array_info.axis_order), + zarrita.codecs.bytes_codec(), + zarrita.codecs.blosc_codec( + typesize=array_info.voxel_type.itemsize + ), + ] + if array_info.compression_mode + else [ + zarrita.codecs.transpose_codec(array_info.axis_order), + zarrita.codecs.bytes_codec(), + ] + ), ) ], ) else: ArrayV2.create( store=path, - shape=(array_info.num_channels, 1, 1, 1), + shape=(array_info.shape), chunks=(array_info.num_channels,) + array_info.chunk_shape.to_tuple(), dtype=array_info.voxel_type, compressor=( @@ -575,46 +605,45 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarritaArray": ) return ZarritaArray(path) - def read(self, offset: Vec3IntLike, shape: Vec3IntLike) -> np.ndarray: - offset = Vec3Int(offset) - shape = Vec3Int(shape) + def read(self, bbox: NDBoundingBox) -> np.ndarray: + shape = bbox.size.to_tuple() zarray = self._zarray + slice_tuple = (slice(None),) + bbox.to_slices() with _blosc_disable_threading(): - data = zarray[ - :, - offset.x : (offset.x + shape.x), - offset.y : (offset.y + shape.y), - offset.z : (offset.z + shape.z), - ] - shape_with_channels = (self.info.num_channels,) + shape.to_tuple() - if data.shape != shape and data.shape != shape_with_channels: - padded_data = np.zeros(shape_with_channels, dtype=data.dtype) - padded_data[ - :, - 0 : data.shape[1], - 0 : data.shape[2], - 0 : data.shape[3], - ] = data + data = zarray[slice_tuple] + + shape_with_channels = (self.info.num_channels,) + shape + if data.shape != shape_with_channels: + data_slice_tuple = tuple(slice(0, size) for size in data.shape) + padded_data = np.zeros(shape_with_channels, dtype=zarray.metadata.dtype) + padded_data[data_slice_tuple] = data data = padded_data return data def ensure_size( - self, new_shape: Vec3IntLike, align_with_shards: bool = True, warn: bool = False + self, + new_bbox: NDBoundingBox, + align_with_shards: bool = True, + warn: bool = False, ) -> None: - new_shape = Vec3Int(new_shape) zarray = self._zarray - new_shape_tuple = ( - zarray.metadata.shape[0], - max(zarray.metadata.shape[1], new_shape.x), - max(zarray.metadata.shape[2], new_shape.y), - max(zarray.metadata.shape[3], new_shape.z), + new_bbox = new_bbox.with_bottomright( + ( + max(zarray.metadata.shape[i + 1], new_bbox.bottomright[i]) + for i in range(len(new_bbox)) + ) ) + new_shape_tuple = (zarray.metadata.shape[0],) + tuple(new_bbox.bottomright) if new_shape_tuple != zarray.metadata.shape: if align_with_shards: shard_shape = self.info.shard_shape - new_shape = new_shape.ceildiv(shard_shape) * shard_shape - new_shape_tuple = (zarray.metadata.shape[0],) + new_shape.to_tuple() + new_aligned_bbox = new_bbox.with_bottomright_xyz( + new_bbox.bottomright_xyz.ceildiv(shard_shape) * shard_shape + ) + new_shape_tuple = ( + zarray.metadata.shape[0], + ) + new_aligned_bbox.bottomright.to_tuple() # Check on-disk for changes to shape current_zarray = zarray.open(self._path) @@ -630,24 +659,21 @@ def ensure_size( ) self._cached_zarray = zarray.resize(new_shape_tuple) - def write(self, offset: Vec3IntLike, data: np.ndarray) -> None: - offset = Vec3Int(offset) - - if data.ndim == 3: + def write(self, bbox: NDBoundingBox, data: np.ndarray) -> None: + if data.ndim == len(bbox): + # the bbox does not include the channels, if data and bbox have the same size there is only 1 channel data = data.reshape((1,) + data.shape) - assert data.ndim == 4 + + assert data.ndim == len(bbox) + 1 with _blosc_disable_threading(): - self.ensure_size(offset + Vec3Int(data.shape[1:4]), warn=True) + self.ensure_size(bbox, warn=True) zarray = self._zarray - zarray[ - :, - offset.x : (offset.x + data.shape[1]), - offset.y : (offset.y + data.shape[2]), - offset.z : (offset.z + data.shape[3]), - ] = data + index_tuple = (slice(None),) + bbox.to_slices() + + zarray[index_tuple] = data - def list_bounding_boxes(self) -> Iterator[BoundingBox]: + def list_bounding_boxes(self) -> Iterator[NDBoundingBox]: raise NotImplementedError def close(self) -> None: diff --git a/webknossos/webknossos/dataset/_utils/buffered_slice_reader.py b/webknossos/webknossos/dataset/_utils/buffered_slice_reader.py index de6969899..69b73031e 100644 --- a/webknossos/webknossos/dataset/_utils/buffered_slice_reader.py +++ b/webknossos/webknossos/dataset/_utils/buffered_slice_reader.py @@ -9,8 +9,7 @@ if TYPE_CHECKING: from ..view import View -from ...geometry import BoundingBox, Vec3IntLike -from ...utils import get_chunks +from ...geometry import BoundingBox, NDBoundingBox, Vec3IntLike class BufferedSliceReader: @@ -23,8 +22,8 @@ def __init__( buffer_size: int = 32, dimension: int = 2, # z *, - relative_bounding_box: Optional[BoundingBox] = None, # in mag1 - absolute_bounding_box: Optional[BoundingBox] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 use_logging: bool = False, ) -> None: """see `View.get_buffered_slice_reader()`""" @@ -61,49 +60,17 @@ def __init__( self.bbox_current_mag = absolute_bounding_box.in_mag(view.mag) def _get_slice_generator(self) -> Generator[np.ndarray, None, None]: - for batch in get_chunks( - list( - range( - self.bbox_current_mag.topleft[self.dimension], - self.bbox_current_mag.bottomright[self.dimension], - ) - ), - self.buffer_size, - ): - n_slices = len(batch) - batch_start_idx = batch[0] - - assert ( - n_slices <= self.buffer_size - ), f"n_slices should at most be batch_size, but {n_slices} > {self.buffer_size}" - - bbox_offset = self.bbox_current_mag.topleft - bbox_size = self.bbox_current_mag.size - - buffer_bounding_box = BoundingBox.from_tuple2( - ( - bbox_offset[: self.dimension] - + (batch_start_idx,) - + bbox_offset[self.dimension + 1 :], - bbox_size[: self.dimension] - + (n_slices,) - + bbox_size[self.dimension + 1 :], - ) - ) + chunk_size = self.bbox_current_mag.size_xyz.to_list() + chunk_size[self.dimension] = self.buffer_size + for chunk in self.bbox_current_mag.chunk(chunk_size): if self.use_logging: - info( - f"({getpid()}) Reading {n_slices} slices at position {batch_start_idx}." - ) + info(f"({getpid()}) Reading data from bbox {chunk}.") data = self.view.read( - absolute_bounding_box=buffer_bounding_box.from_mag_to_mag1( - self.view.mag - ) + absolute_bounding_box=chunk.from_mag_to_mag1(self.view.mag) ) - for current_slice in np.rollaxis( - data, self.dimension + 1 - ): # The '+1' is important because the first dimension is the channel + for current_slice in np.rollaxis(data, chunk.index_xyz[self.dimension]): yield current_slice def __enter__(self) -> Generator[np.ndarray, None, None]: diff --git a/webknossos/webknossos/dataset/_utils/buffered_slice_writer.py b/webknossos/webknossos/dataset/_utils/buffered_slice_writer.py index 95aec77e1..296a02453 100644 --- a/webknossos/webknossos/dataset/_utils/buffered_slice_writer.py +++ b/webknossos/webknossos/dataset/_utils/buffered_slice_writer.py @@ -9,6 +9,8 @@ import numpy as np import psutil +from webknossos.geometry.nd_bounding_box import NDBoundingBox + from ...geometry import BoundingBox, Vec3Int, Vec3IntLike if TYPE_CHECKING: @@ -43,6 +45,8 @@ def __init__( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 use_logging: bool = False, ) -> None: """see `View.get_buffered_slice_writer()`""" @@ -52,7 +56,15 @@ def __init__( self.dtype = self.view.get_dtype() self.use_logging = use_logging self.json_update_allowed = json_update_allowed - if offset is None and relative_offset is None and absolute_offset is None: + self.bbox: NDBoundingBox + + if ( + offset is None + and relative_offset is None + and absolute_offset is None + and relative_bounding_box is None + and absolute_bounding_box is None + ): relative_offset = Vec3Int.zeros() if offset is not None: warnings.warn( @@ -61,25 +73,28 @@ def __init__( DeprecationWarning, ) self.offset = None if offset is None else Vec3Int(offset) - self.relative_offset = ( - None if relative_offset is None else Vec3Int(relative_offset) - ) - self.absolute_offset = ( - None if absolute_offset is None else Vec3Int(absolute_offset) - ) - self.dimension = dimension - effective_offset = Vec3Int.full(0) - if self.relative_offset is not None: - effective_offset = self.view.bounding_box.topleft + self.relative_offset + if relative_offset is not None: + self.bbox = BoundingBox( + self.view.bounding_box.topleft + relative_offset, Vec3Int.zeros() + ) + + if absolute_offset is not None: + self.bbox = BoundingBox(absolute_offset, Vec3Int.zeros()) - if self.absolute_offset is not None: - effective_offset = self.absolute_offset + if relative_bounding_box is not None: + self.bbox = relative_bounding_box.offset(self.view.bounding_box.topleft) - view_chunk_depth = self.view.info.chunk_shape[self.dimension] + if absolute_bounding_box is not None: + self.bbox = absolute_bounding_box + + assert 0 <= dimension <= 2 # either x (0), y (1) or z (2) + self.dimension = dimension + + view_chunk_depth = self.view.info.chunk_shape[dimension] if ( - effective_offset is not None - and effective_offset[self.dimension] % view_chunk_depth != 0 + self.bbox is not None + and self.bbox.topleft_xyz[self.dimension] % view_chunk_depth != 0 ): warnings.warn( "[WARNING] Using an offset that doesn't align with the datataset's chunk size, " @@ -91,8 +106,6 @@ def __init__( + "will slow down the buffered slice writer.", ) - assert 0 <= dimension <= 2 - self.slices_to_write: List[np.ndarray] = [] self.current_slice: Optional[int] = None self.buffer_start_slice: Optional[int] = None @@ -129,52 +142,83 @@ def _flush_buffer(self) -> None: max_width = max(section.shape[-2] for section in self.slices_to_write) max_height = max(section.shape[-1] for section in self.slices_to_write) channel_count = self.slices_to_write[0].shape[0] - buffer_depth = min(self.buffer_size, len(self.slices_to_write)) - buffer_bbox = BoundingBox((0, 0, 0), (max_width, max_height, buffer_depth)) - - shard_dimensions = self.view._get_file_dimensions().moveaxis( - -1, self.dimension + buffer_start = Vec3Int.zeros().with_replaced( + self.dimension, self.buffer_start_slice ) + + bbox = self.bbox.with_size_xyz( + Vec3Int(max_width, max_height, buffer_depth).moveaxis( + -1, self.dimension + ) + ).offset(buffer_start) + + shard_dimensions = self.view._get_file_dimensions() chunk_size = Vec3Int( min(shard_dimensions[0], max_width), min(shard_dimensions[1], max_height), buffer_depth, - ) - for chunk_bbox in buffer_bbox.chunk(chunk_size): - info(f"Writing chunk {chunk_bbox}") - width, height, depth = chunk_bbox.size + ).moveaxis(-1, self.dimension) + for chunk_bbox in bbox.chunk(chunk_size): + info(f"Writing chunk {chunk_bbox}.") + data = np.zeros( - (channel_count, width, height, depth), + (channel_count, *chunk_bbox.size), dtype=self.slices_to_write[0].dtype, ) + section_topleft = Vec3Int( + (chunk_bbox.topleft_xyz - bbox.topleft_xyz).moveaxis( + self.dimension, -1 + ) + ) + section_bottomright = Vec3Int( + (chunk_bbox.bottomright_xyz - bbox.topleft_xyz).moveaxis( + self.dimension, -1 + ) + ) + + z_index = chunk_bbox.index_xyz[self.dimension] z = 0 for section in self.slices_to_write: section_chunk = section[ :, - chunk_bbox.topleft.x : chunk_bbox.bottomright.x, - chunk_bbox.topleft.y : chunk_bbox.bottomright.y, + section_topleft.x : section_bottomright.x, + section_topleft.y : section_bottomright.y, ] + # Section chunk includes the axes c, x, y. The remaining axes are added by considering + # the length of the bbox. Since the bbox does not contain the channel, we subtract 2 + # instead of 3. + section_chunk = section_chunk[ + (slice(None), slice(None), slice(None)) + + tuple(np.newaxis for _ in range(len(bbox) - 2)) + ] + section_chunk = np.moveaxis( + section_chunk, + [1, 2], + bbox.index_xyz[: self.dimension] + + bbox.index_xyz[self.dimension + 1 :], + ) + + slice_tuple = (slice(None),) + tuple( + slice(0, min(size1, size2)) + for size1, size2 in zip( + chunk_bbox.size, section_chunk.shape[1:] + ) + ) + data[ - :, 0 : section_chunk.shape[-2], 0 : section_chunk.shape[-1], z + slice_tuple[:z_index] + + (slice(z, z + 1),) + + slice_tuple[z_index + 1 :] ] = section_chunk z += 1 - buffer_start = Vec3Int( - chunk_bbox.topleft.x, chunk_bbox.topleft.y, self.buffer_start_slice - ).moveaxis(-1, self.dimension) - buffer_start_mag1 = buffer_start * self.view.mag.to_vec3_int() - - data = np.moveaxis(data, -1, self.dimension + 1) - self.view.write( data, - offset=buffer_start.add_or_none(self.offset), - relative_offset=buffer_start_mag1.add_or_none(self.relative_offset), - absolute_offset=buffer_start_mag1.add_or_none(self.absolute_offset), json_update_allowed=self.json_update_allowed, + absolute_bounding_box=chunk_bbox.from_mag_to_mag1(self.view._mag), ) del data diff --git a/webknossos/webknossos/dataset/_utils/infer_bounding_box_existing_files.py b/webknossos/webknossos/dataset/_utils/infer_bounding_box_existing_files.py index 177042e01..47c4fcb6c 100644 --- a/webknossos/webknossos/dataset/_utils/infer_bounding_box_existing_files.py +++ b/webknossos/webknossos/dataset/_utils/infer_bounding_box_existing_files.py @@ -10,7 +10,7 @@ def infer_bounding_box_existing_files(mag_view: MagView) -> BoundingBox: The returned bounding box is measured in Mag(1) voxels.""" return reduce( - lambda acc, bbox: acc.extended_by(bbox), + lambda acc, bbox: acc.extended_by(BoundingBox.from_ndbbox(bbox)), mag_view.get_bounding_boxes_on_disk(), BoundingBox.empty(), ) diff --git a/webknossos/webknossos/dataset/_utils/pims_images.py b/webknossos/webknossos/dataset/_utils/pims_images.py index 2587e8d03..bb4942742 100644 --- a/webknossos/webknossos/dataset/_utils/pims_images.py +++ b/webknossos/webknossos/dataset/_utils/pims_images.py @@ -24,6 +24,10 @@ from natsort import natsorted from numpy.typing import DTypeLike +from webknossos.geometry.bounding_box import BoundingBox +from webknossos.geometry.nd_bounding_box import NDBoundingBox + +# pylint: disable=unused-import try: from .pims_czi_reader import PimsCziReader except ImportError: @@ -45,7 +49,7 @@ pass -from ...geometry.vec3_int import Vec3Int +from ...geometry.vec_int import VecInt from ..mag_view import MagView try: @@ -67,7 +71,6 @@ def _assume_color_channel(dim_size: int, dtype: np.dtype) -> bool: class PimsImages: dtype: DTypeLike - expected_shape: Vec3Int num_channels: int def __init__( @@ -86,11 +89,11 @@ def __init__( """ During initialization the pims objects are examined and configured to produce ndarrays that follow the following form: - (self._iter_dim, *self._img_dims) - self._iter_dim can be either "z", "t" or "" if the image is 2D. + (self._iter_axes, *self._bundle_axis) + self._iter_axes can be a list of different axes or an empty list if the image is 2D. In the latter case, the inner 2D image is still wrapped in a single-element list by _open_images() to be consistent with 3D images. - self._img_dims can consist of "x", "y" and "c", where "c" is optional and must be + self._bundle_axis can consist of "x", "y" and "c", where "c" is optional and must be at the start or the end, so one of "xy", "yx", "xyc", "yxc", "cxy", "cyx". The part "IDENTIFY AXIS ORDER" figures out (self._iter_dim, *self._img_dims) @@ -116,9 +119,10 @@ def __init__( self._use_bioformats = use_bioformats ## attributes that will be set in __init__() - self._iter_dim = None + # _bundle_axes + self._iter_axes = None + self._iter_loop_size = None self._possible_layers = {} - # _img_dims ## attributes only for pims.FramesSequenceND instances: # _default_coords @@ -126,7 +130,6 @@ def __init__( ## attributes that will also be set in __init__() # dtype - # expected_shape # num_channels # _first_n_channels @@ -141,10 +144,6 @@ def __init__( self.dtype = images.dtype if isinstance(images, pims.FramesSequenceND): - assert all( - axis in "xyzct" for axis in images.axes - ), f"Found unknown axes {set(images.axes) - set('xyzct')}" - self._default_coords = {} self._init_c_axis = False if isinstance(images, pims.imageio_reader.ImageIOReader): @@ -163,34 +162,49 @@ def __init__( if len(available_czi_channels) > 1: self._possible_layers["czi_channel"] = available_czi_channels + # An image slice should always consist of a 2D image. If there are multiple channels + # the data of each channel is part of the image slices. Possible shapes of an image + # slice are (#y_shape, #x_shape), (1, #y_shape, #x_shape) or (3, #y_shape, #x_shape). if images.sizes.get("c", 1) > 1: - self._img_dims = "cyx" + self._bundle_axes = ["c", "y", "x"] else: if "c" in images.axes: + # When c-axis is not in _bundle_axes and _iter_axes its value at coordinate 0 + # should be returned self._default_coords["c"] = 0 - self._img_dims = "yx" - - self._iter_dim = "" - - if images.sizes.get("z", 1) > 1: - self._iter_dim = "z" - elif "z" in images.axes: - self._default_coords["z"] = 0 - - if timepoint is None: - if images.sizes.get("t", 1) > 1: - if self._iter_dim == "": - self._iter_dim = "t" - else: - self._default_coords["t"] = 0 - self._possible_layers["timepoint"] = list( - range(0, images.sizes["t"]) - ) - elif "t" in images.axes: - self._default_coords["t"] = 0 - else: - assert "t" in images.axes - self._default_coords["t"] = timepoint + self._bundle_axes = ["y", "x"] + + # All other axes are used to iterate over them. The last one is iterated the fastest. + self._iter_axes = list( + set(images.axes).difference({*self._bundle_axes, "c", "z"}) + ) + if "z" in images.axes: + self._iter_axes.append("z") + + if self._timepoint is not None: + # if a timepoint is given, PimsImages should only generate image slices for that timepoint + if "t" in self._iter_axes: + self._iter_axes.remove("t") + self._default_coords["t"] = self._timepoint + + if len(self._iter_axes) > 1: + iter_size = 1 + self._iter_loop_size = dict() + for axis, other_axis in zip( + self._iter_axes[-1:0:-1], self._iter_axes[-2::-1] + ): + # Creates a dict that contains the size of the loop for each axis + # the axes are identified by their index in the _iter_axes list + # the last axis is the fastest iterating axis, therfore the size of the loop + # for the last axis is 1. For all other axes it is the product of all previous axes sizes. + # self._iter_axes[-1:0:-1] is a reversed copy of self._iter_axes without the last element + # e.g. [1,2,3,4] -> [4,3,2] + # self._iter_axes[-2::-1] is a reversed copy of self._iter_axes without the first element + # e.g. [1,2,3,4] -> [3,2,1] + self._iter_loop_size[other_axis] = ( + iter_size := iter_size * images.sizes[axis] + ) + else: # Fallback for generic pims classes that do not name their # dimensions as pims.FramesSequenceND does: @@ -201,31 +215,31 @@ def __init__( if len(images.shape) == 2: # Assume yx - self._img_dims = "yx" - self._iter_dim = "" + self._bundle_axes = ["y", "x"] + self._iter_axes = [] elif len(images.shape) == 3: # Assume yxc, cyx or zyx if _assume_color_channel(images.shape[2], images.dtype): - self._img_dims = "yxc" - self._iter_dim = "" + self._bundle_axes = ["y", "x", "c"] + self._iter_axes = [] elif images.shape[0] == 1 or ( _allow_channels_first and _assume_color_channel(images.shape[0], images.dtype) ): - self._img_dims = "cyx" - self._iter_dim = "" + self._bundle_axes = ["c", "y", "x"] + self._iter_axes = [] else: - self._img_dims = "yx" - self._iter_dim = "z" + self._bundle_axes = ["y", "x"] + self._iter_axes = ["z"] elif len(images.shape) == 4: # Assume zcyx or zyxc if images.shape[1] == 1 or _assume_color_channel( images.shape[1], images.dtype ): - self._img_dims = "cyx" + self._bundle_axes = ["c", "y", "x"] else: - self._img_dims = "yxc" - self._iter_dim = "z" + self._bundle_axes = ["y", "x", "c"] + self._iter_axes = ["z"] elif len(images.shape) == 5: # Assume tzcyx or tzyxc # t has to be constant for this reader to obtain 4D image @@ -236,13 +250,15 @@ def __init__( raise RuntimeError( f"Got {len(images.shape)} axes for the images after " + "removing time dimension, can only map to 3D+channels." + + "To import image with more dimensions use dataformat" + + "Zarr3 and set use_bioformats=True." ) if _assume_color_channel(images.shape[2], images.dtype): - self._img_dims = "cyx" + self._bundle_axes = ["c", "y", "x"] else: - self._img_dims = "yxc" - self._iter_dim = "z" + self._bundle_axes = ["y", "x", "c"] + self._iter_axes = ["z"] self._timepoint = 0 if images.shape[0] > 1: self._possible_layers["timepoint"] = list( @@ -251,36 +267,29 @@ def __init__( else: raise RuntimeError( f"Got {len(images.shape)} axes for the images, " - + "cannot map to 3D+channels+timepoints." + + "but don't have axes information. Try to open " + + "an N-dimensional image file with use_bioformats=" + + "True." ) - ############################# - # IDENTIFY SHAPE & CHANNELS # - ############################# + ######################### + # IDENTIFY NUM_CHANNELS # + ######################### with self._open_images() as images: - if isinstance(images, list): - images_shape = (len(images),) + cast( - pims.FramesSequence, images[0] - ).shape - else: - images_shape = images.shape - c_index = self._img_dims.find("c") - if c_index == -1: - self.num_channels = 1 - else: - # Since images_shape contains the first dimension iter_dim, - # we need to offset the index by one before accessing the images_shape. - # images_shape corresponds to (z, *_img_dims) + try: + c_index = self._bundle_axes.index("c") + if isinstance(images, list): + images_shape = (len(images),) + cast( + pims.FramesSequence, images[0] + ).shape + else: + images_shape = images.shape # pylint: disable=no-member + self.num_channels = images_shape[c_index + 1] - x_index = self._img_dims.find("x") + 1 - y_index = self._img_dims.find("y") + 1 - if swap_xy: - x_index, y_index = y_index, x_index - self.expected_shape = Vec3Int( - images_shape[x_index], images_shape[y_index], images_shape[0] - ) + except ValueError: + self.num_channels = 1 self._first_n_channels = None if self._channel is not None: @@ -418,13 +427,13 @@ def _open_images( self, ) -> Iterator[Union[pims.FramesSequence, List[pims.FramesSequence]]]: """ - This yields well-defined images of the form (self._iter_dim, *self._img_dims), + This yields well-defined images of the form (self._iter_axes, *self._bundle_axes), after IDENTIFY AXIS ORDER of __init__() has run. For a 2D image this is achieved by wrapping it in a list. """ images_context_manager: Optional[ContextManager] with warnings.catch_warnings(): - if isinstance(self._original_images, pims.FramesSequence): + if isinstance(self._original_images, pims.FramesSequenceND): images_context_manager = nullcontext(enter_result=self._original_images) else: exceptions: List[Exception] = [] @@ -454,7 +463,7 @@ def _open_images( with images_context_manager as images: if isinstance(images, pims.FramesSequenceND): - if hasattr(self, "_img_dims"): + if hasattr(self, "_bundle_axes"): # first part of __init__() has happened images.default_coords.update(self._default_coords) if self._init_c_axis and "c" not in images.sizes: @@ -466,29 +475,41 @@ def _open_images( images._get_frame_dict[key + ("c",)] = ( images._get_frame_dict.pop(key) ) - images.bundle_axes = self._img_dims - images.iter_axes = self._iter_dim or "" + self._bundle_axes.remove("c") + self._bundle_axes.append("c") + images.bundle_axes = self._bundle_axes + images.iter_axes = self._iter_axes else: if self._timepoint is not None: images = images[self._timepoint] - if self._iter_dim == "": + if self._iter_axes and "t" in self._iter_axes: + self._iter_axes.remove("t") + if self._iter_axes == []: # add outer list to wrap 2D images as 3D-like structure images = [images] yield images def copy_to_view( self, - args: Tuple[int, int], + args: Union[BoundingBox, NDBoundingBox], mag_view: MagView, is_segmentation: bool, dtype: Optional[DTypeLike] = None, ) -> Tuple[Tuple[int, int], Optional[int]]: """Copies the images according to the passed arguments to the given mag_view. - args is expected to be the start and end of the z-range, meant for usage with an executor. + args is expected to be a (ND)BoundingBox the start and end of the z-range, meant for usage with an executor. copy_to_view returns an iterable of image shapes and largest segment ids. When using this method a manual update of the bounding box and the largest segment id might be necessary. """ - z_start, z_end = args + relative_bbox = args + + assert all( + size == 1 + for size, axis in zip(relative_bbox.size, relative_bbox.axes) + if axis not in ("x", "y", "z") + ), "The delivered BoundingBox has to be flat except for x,y and z dimension." + + z_start, z_end = relative_bbox.get_bounds("z") shapes = [] max_id: Optional[int] if is_segmentation: @@ -497,10 +518,22 @@ def copy_to_view( max_id = None with self._open_images() as images: + if self._iter_axes is not None and self._iter_loop_size is not None: + # select the range of images that represents one xyz combination + lower_bounds = sum( + self._iter_loop_size[axis_name] + * relative_bbox.get_bounds(axis_name)[0] + for axis_name in self._iter_axes[:-1] + ) + upper_bounds = lower_bounds + relative_bbox.get_shape("z") + images = images[lower_bounds:upper_bounds] if self._flip_z: - images = images[::-1] + images = images[::-1] # pylint: disable=unsubscriptable-object + with mag_view.get_buffered_slice_writer( - relative_offset=(0, 0, z_start * mag_view.mag.z), + # Previously only z_start and its end were important, now the slice writer needs to know + # which axis is currently written. + relative_bounding_box=relative_bbox, buffer_size=mag_view.info.chunk_shape.z, # copy_to_view is typically used in a multiprocessing-context. Therefore the # buffered slice writer should not update the json file to avoid race conditions. @@ -509,15 +542,17 @@ def copy_to_view( for image_slice in images[z_start:z_end]: image_slice = np.array(image_slice) # place channels first - if self._img_dims.endswith("c"): - image_slice = np.moveaxis(image_slice, source=-1, destination=0) - # ensure the last two axes are xy: - if ("yx" in self._img_dims and not self._swap_xy) or ( - "xy" in self._img_dims and self._swap_xy - ): - image_slice = image_slice.swapaxes(-1, -2) - - if "c" in self._img_dims: + if "c" in self._bundle_axes: + if hasattr(self, "_init_c_axis") and self._init_c_axis: + # Bugfix for ImageIOReader which misses channel axis sometimes, + # assuming channels come last. _init_c_axis is set in __init__(). + # This might get fixed via + image_slice = image_slice[0] + image_slice = np.moveaxis( + image_slice, + source=self._bundle_axes.index("c"), + destination=0, + ) if self._channel is not None: image_slice = image_slice[self._channel : self._channel + 1] elif self._first_n_channels is not None: @@ -537,6 +572,10 @@ def copy_to_view( if max_id is not None: max_id = max(max_id, image_slice.max()) + + if self._swap_xy is False: + image_slice = np.moveaxis(image_slice, -1, -2) + shapes.append(image_slice.shape[-2:]) writer.send(image_slice) @@ -548,6 +587,78 @@ def get_possible_layers(self) -> Optional[Dict["str", List[int]]]: else: return self._possible_layers + @property + def expected_bbox(self) -> NDBoundingBox: + # replaces the previous expected_shape to enable n-dimensional input files + with self._open_images() as images: + if isinstance(images, pims.FramesSequenceND): + axes = images.axes + images_shape = tuple(images.sizes[axis] for axis in axes) + else: + if isinstance(images, list): + images_shape = (len(images),) + cast( + pims.FramesSequence, images[0] + ).shape + + else: + images_shape = images.shape # pylint: disable=no-member + if len(images_shape) == 3: + axes = ("z", "y", "x") + else: + axes = ("z", "c", "y", "x") + + if self._iter_loop_size is None: + # There is no or only one element in self._iter_axes, so a 3D bounding box is sufficient. + x_index, y_index = ( + axes.index("x"), + axes.index("y"), + ) + if self._iter_axes: + try: + # In case the naming of the third axis is not "z", + # it is still considered as the z-axis. + z_index = axes.index(self._iter_axes[0]) + except ValueError: + z_index = axes.index("z") + z_shape = images_shape[z_index] + else: + z_shape = 1 + if self._swap_xy: + x_index, y_index = y_index, x_index + return BoundingBox( + (0, 0, 0), + (images_shape[x_index], images_shape[y_index], z_shape), + ) + else: + if isinstance(images, pims.FramesSequenceND): + axes_names = (self._iter_axes or []) + [ + axis for axis in self._bundle_axes if axis != "c" + ] + axes_sizes = [ + images.sizes[axis] # pylint: disable=no-member + for axis in axes_names + ] + axes_index = list(range(1, len(axes_names) + 1)) + topleft = VecInt.zeros(tuple(axes_names)) + + if self._swap_xy: + x_index, y_index = axes_names.index("x"), axes_names.index("y") + axes_sizes[x_index], axes_sizes[y_index] = ( + axes_sizes[y_index], + axes_sizes[x_index], + ) + + return NDBoundingBox( + topleft, + VecInt(axes_sizes, axes=axes_names), + axes_names, + VecInt(axes_index, axes=axes_names), + ) + + raise ValueError( + "It seems as if you try to load an N-dimensional image from 2D images. This is currently not supported." + ) + T = TypeVar("T", bound=Tuple[int, ...]) @@ -605,4 +716,4 @@ def has_image_z_dimension( flip_z=False, ) - return pims_images.expected_shape.z > 1 + return pims_images.expected_bbox.get_shape("z") > 1 diff --git a/webknossos/webknossos/dataset/dataset.py b/webknossos/webknossos/dataset/dataset.py index ddf74051b..46b7e96c6 100644 --- a/webknossos/webknossos/dataset/dataset.py +++ b/webknossos/webknossos/dataset/dataset.py @@ -35,6 +35,8 @@ from numpy.typing import DTypeLike from upath import UPath +from webknossos.geometry.vec_int import VecInt, VecIntLike + from ..client.api_client.models import ApiDataset from ..geometry.vec3_int import Vec3Int, Vec3IntLike from ._array import ArrayException, ArrayInfo, BaseArray @@ -60,7 +62,7 @@ from ..administration.user import Team from ..client._upload_dataset import LayerToLink -from ..geometry import BoundingBox, Mag +from ..geometry import BoundingBox, Mag, NDBoundingBox from ..utils import ( copy_directory_with_symlinks, copytree, @@ -778,7 +780,7 @@ def add_layer( dtype_per_channel: Optional[DTypeLike] = None, num_channels: Optional[int] = None, data_format: Union[str, DataFormat] = DEFAULT_DATA_FORMAT, - bounding_box: Optional[BoundingBox] = None, + bounding_box: Optional[NDBoundingBox] = None, **kwargs: Any, ) -> Layer: """ @@ -1043,7 +1045,7 @@ def add_layer_from_images( compress: bool = False, *, ## other arguments - topleft: Vec3IntLike = Vec3Int.zeros(), # in Mag(1) + topleft: VecIntLike = Vec3Int.zeros(), # in Mag(1) swap_xy: bool = False, flip_x: bool = False, flip_y: bool = False, @@ -1210,8 +1212,12 @@ def add_layer_from_images( num_channels=pims_images.num_channels, **add_layer_kwargs, # type: ignore[arg-type] ) + + expected_bbox = pims_images.expected_bbox + + # When the expected bbox is 2D the chunk_shape is set to 2D too. if ( - pims_images.expected_shape.z == 1 + expected_bbox.get_shape("z") == 1 and layer.data_format == DataFormat.Zarr ): if chunk_shape is None: @@ -1222,18 +1228,14 @@ def add_layer_from_images( if chunks_per_shard is None and layer.data_format == DataFormat.Zarr3: chunks_per_shard = DEFAULT_CHUNKS_PER_SHARD_FROM_IMAGES + mag = Mag(mag) + layer.bounding_box = expected_bbox.from_mag_to_mag1(mag).offset(topleft) mag_view = layer.add_mag( mag=mag, chunk_shape=chunk_shape, chunks_per_shard=chunks_per_shard, compress=compress, ) - mag = mag_view.mag - layer.bounding_box = ( - BoundingBox((0, 0, 0), pims_images.expected_shape) - .from_mag_to_mag1(mag) - .offset(topleft) - ) if batch_size is None: if compress: @@ -1257,10 +1259,44 @@ def add_layer_from_images( ) args = [] - for z_start in range(0, pims_images.expected_shape.z, batch_size): - z_end = min(z_start + batch_size, pims_images.expected_shape.z) - # return shapes and set to union when using --pad - args.append((z_start, z_end)) + bbox = layer.bounding_box + additional_axes = [ + axis_name for axis_name in bbox.axes if axis_name not in ("x", "y", "z") + ] + additional_axes_shapes = tuple( + product( + *[range(bbox.get_shape(axis_name)) for axis_name in additional_axes] + ) + ) + if additional_axes and layer.data_format != DataFormat.Zarr3: + assert ( + len(additional_axes_shapes) == 1 + ), "The data stores additional axes with shape bigger than 1. These are only supported by data format Zarr3." + + # Convert NDBoundingBox to 3D BoundingBox + bbox = BoundingBox( + bbox.topleft_xyz, + bbox.size_xyz, + ) + expected_bbox = bbox + additional_axes = [] + + z_shape = bbox.get_shape("z") + bbox = bbox.with_topleft(VecInt.zeros(bbox.axes)) + for z_start in range(0, z_shape, batch_size): + z_size = min(batch_size, z_shape - z_start) + z_bbox = bbox.with_bounds("z", z_start, z_size) + if not additional_axes: + args.append(z_bbox) + else: + for shape in additional_axes_shapes: + reduced_bbox = z_bbox + for index, axis in enumerate(additional_axes): + reduced_bbox = reduced_bbox.with_bounds( + axis, shape[index], 1 + ) + args.append(reduced_bbox) + with warnings.catch_warnings(): # Block alignmnent within the dataset should not be a problem, since shard-wise chunking is enforced. # However, dataset borders might change between different parallelized writes, when sizes differ. @@ -1289,18 +1325,14 @@ def add_layer_from_images( if category == "segmentation": max_id = max(max_ids) cast(SegmentationLayer, layer).largest_segment_id = max_id - actual_size = Vec3Int( - dimwise_max(shapes) + (pims_images.expected_shape.z,) - ) - layer.bounding_box = ( - BoundingBox((0, 0, 0), actual_size) - .from_mag_to_mag1(mag) - .offset(topleft) + layer.bounding_box = layer.bounding_box.with_size_xyz( + Vec3Int(dimwise_max(shapes) + (layer.bounding_box.get_shape("z"),)) + * mag.to_vec3_int().with_z(1) ) - if pims_images.expected_shape != actual_size: + if expected_bbox != layer.bounding_box: warnings.warn( "[WARNING] Some images are larger than expected, smaller slices are padded with zeros now. " - + f"New size is {actual_size}, expected {pims_images.expected_shape}." + + f"New bbox is {layer.bounding_box}, expected {expected_bbox}." ) if first_layer is None: first_layer = layer @@ -1510,7 +1542,7 @@ def add_fs_copy_layer( self._export_as_json() return self.layers[new_layer_name] - def calculate_bounding_box(self) -> BoundingBox: + def calculate_bounding_box(self) -> NDBoundingBox: """ Calculates and returns the enclosing bounding box of all data layers of the dataset. """ @@ -1731,12 +1763,19 @@ def _export_as_json(self) -> None: self._ensure_writable() properties_on_disk = self._load_properties() - - if properties_on_disk != self._last_read_properties: + try: + if properties_on_disk != self._last_read_properties: + warnings.warn( + "[WARNING] While exporting the dataset's properties, properties were found on disk which are " + + "newer than the ones that were seen last time. The properties will be overwritten. This is " + + "likely happening because multiple processes changed the metadata of this dataset." + ) + except ValueError: + # the __eq__ operator raises a ValueError when two bboxes are not comparable. This is the case when the + # axes are not the same. During initialization axes are added or moved sometimes. warnings.warn( - "[WARNING] While exporting the dataset's properties, properties were found on disk which are " - + "newer than the ones that were seen last time. The properties will be overwritten. This is " - + "likely happening because multiple processes changed the metadata of this dataset." + "[WARNING] Properties changed in a way that they are not comparable anymore. Most likely " + + "the bounding box naming or axis order changed." ) with (self.path / PROPERTIES_FILE_NAME).open("w", encoding="utf-8") as outfile: diff --git a/webknossos/webknossos/dataset/layer.py b/webknossos/webknossos/dataset/layer.py index 7d37308f3..80de7443c 100644 --- a/webknossos/webknossos/dataset/layer.py +++ b/webknossos/webknossos/dataset/layer.py @@ -13,7 +13,7 @@ from numpy.typing import DTypeLike from upath import UPath -from ..geometry import BoundingBox, Mag, Vec3Int, Vec3IntLike +from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike from ._array import ArrayException, BaseArray, DataFormat from ._downsampling_utils import ( calculate_default_coarsest_mag, @@ -192,7 +192,7 @@ def __init__(self, dataset: "Dataset", properties: LayerProperties) -> None: self.path.mkdir(parents=True, exist_ok=True) for mag in properties.mags: - self._setup_mag(Mag(mag.mag)) + self._setup_mag(Mag(mag.mag), mag.path) # Only keep the properties of mags that were initialized. # Sometimes the directory of a mag is removed from disk manually, but the properties are not updated. self._properties.mags = [ @@ -251,11 +251,11 @@ def dataset(self) -> "Dataset": return self._dataset @property - def bounding_box(self) -> BoundingBox: + def bounding_box(self) -> NDBoundingBox: return self._properties.bounding_box @bounding_box.setter - def bounding_box(self, bbox: BoundingBox) -> None: + def bounding_box(self, bbox: NDBoundingBox) -> None: """ Updates the offset and size of the bounding box of this layer in the properties. """ @@ -264,9 +264,7 @@ def bounding_box(self, bbox: BoundingBox) -> None: self._properties.bounding_box = bbox self.dataset._export_as_json() for mag in self.mags.values(): - mag._array.ensure_size( - bbox.align_with_mag(mag.mag).in_mag(mag.mag).bottomright - ) + mag._array.ensure_size(bbox.align_with_mag(mag.mag).in_mag(mag.mag)) @property def category(self) -> LayerCategoryType: @@ -399,9 +397,7 @@ def add_mag( create=True, ) - mag_view._array.ensure_size( - self.bounding_box.align_with_mag(mag).in_mag(mag).bottomright - ) + mag_view._array.ensure_size(self.bounding_box.align_with_mag(mag).in_mag(mag)) self._mags[mag] = mag_view mag_array_info = mag_view.info @@ -414,7 +410,12 @@ def add_mag( else None ), axis_order=( - {"x": 1, "y": 2, "z": 3, "c": 0} + dict( + zip( + ("c", "x", "y", "z"), + (0, *self.bounding_box.index_xyz), + ) + ) if mag_array_info.data_format in (DataFormat.Zarr, DataFormat.Zarr3) else None ), @@ -451,7 +452,13 @@ def add_mag_for_existing_files( else None ), axis_order=( - {"x": 1, "y": 2, "z": 3, "c": 0} + { + key: value + for key, value in zip( + ("c", *self.bounding_box.axes), + (0, *self.bounding_box.index), + ) + } if mag_array_info.data_format in (DataFormat.Zarr, DataFormat.Zarr3) else None ), @@ -473,7 +480,7 @@ def get_or_add_mag( file_len: Optional[int] = None, # deprecated ) -> MagView: """ - Creates a new mag called and adds it to the dataset, in case it did not exist before. + Creates a new mag and adds it to the dataset, in case it did not exist before. Then, returns the mag. See `add_mag` for more information. @@ -559,9 +566,11 @@ def add_copy_mag( chunk_shape=chunk_shape or foreign_mag_view._array_info.chunk_shape, chunks_per_shard=chunks_per_shard or foreign_mag_view._array_info.chunks_per_shard, - compress=compress - if compress is not None - else foreign_mag_view._array_info.compression_mode, + compress=( + compress + if compress is not None + else foreign_mag_view._array_info.compression_mode + ), ) if extend_layer_bounding_box: @@ -1075,7 +1084,7 @@ def upsample( # Restoring the original layer bbox self.bounding_box = old_layer_bbox - def _setup_mag(self, mag: Mag) -> None: + def _setup_mag(self, mag: Mag, path: Optional[str] = None) -> None: # This method is used to initialize the mag when opening the Dataset. This does not create e.g. the wk_header. mag_name = mag.to_layer_name() @@ -1085,7 +1094,7 @@ def _setup_mag(self, mag: Mag) -> None: try: cls_array = BaseArray.get_class(self._properties.data_format) info = cls_array.open( - _find_mag_path_on_disk(self.dataset.path, self.name, mag_name) + _find_mag_path_on_disk(self.dataset.path, self.name, mag_name, path) ).info self._mags[mag] = MagView( self, diff --git a/webknossos/webknossos/dataset/mag_view.py b/webknossos/webknossos/dataset/mag_view.py index c3e9f7e98..e7bf52c41 100644 --- a/webknossos/webknossos/dataset/mag_view.py +++ b/webknossos/webknossos/dataset/mag_view.py @@ -10,7 +10,7 @@ from cluster_tools import Executor from upath import UPath -from ..geometry import BoundingBox, Mag, Vec3Int, Vec3IntLike +from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike, VecInt from ..utils import ( NDArrayLike, get_executor_for_args, @@ -30,7 +30,12 @@ from .view import View -def _find_mag_path_on_disk(dataset_path: Path, layer_name: str, mag_name: str) -> Path: +def _find_mag_path_on_disk( + dataset_path: Path, layer_name: str, mag_name: str, path: Optional[str] = None +) -> Path: + if path is not None: + return dataset_path / path + mag = Mag(mag_name) short_mag_file_path = dataset_path / layer_name / mag.to_layer_name() long_mag_file_path = dataset_path / layer_name / mag.to_long_layer_name() @@ -73,6 +78,15 @@ def __init__( chunk_shape=chunk_shape, chunks_per_shard=chunks_per_shard, compression_mode=compression_mode, + axis_order=VecInt( + 0, *layer.bounding_box.index, axes=("c",) + layer.bounding_box.axes + ), + shape=VecInt( + layer.num_channels, + *VecInt.ones(layer.bounding_box.axes), + axes=("c",) + layer.bounding_box.axes, + ), + dimension_names=("c",) + layer.bounding_box.axes, ) if create: self_path = layer.dataset.path / layer.name / mag.to_layer_name() @@ -81,14 +95,14 @@ def __init__( super().__init__( _find_mag_path_on_disk(layer.dataset.path, layer.name, mag.to_layer_name()), array_info, - bounding_box=None, + bounding_box=layer.bounding_box, mag=mag, ) self._layer = layer # Overwrites of View methods: @property - def bounding_box(self) -> BoundingBox: + def bounding_box(self) -> NDBoundingBox: # Overwrites View's method since no extra bbox is stored for a MagView, # but the Layer's bbox is used: return self.layer.bounding_box.align_with_mag(self._mag, ceil=True) @@ -105,7 +119,7 @@ def global_offset(self) -> Vec3Int: return Vec3Int.zeros() @property - def size(self) -> Vec3Int: + def size(self) -> VecInt: """⚠️ Deprecated, use `mag_view.bounding_box.in_mag(mag_view.mag).bottomright` instead.""" warnings.warn( "[DEPRECATION] mag_view.size is deprecated. " @@ -149,6 +163,8 @@ def write( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 ) -> None: if offset is not None: if self._mag == Mag(1): @@ -166,14 +182,30 @@ def write( DeprecationWarning, ) - if all(i is None for i in [offset, absolute_offset, relative_offset]): + if all( + i is None + for i in [ + offset, + absolute_offset, + relative_offset, + absolute_bounding_box, + relative_bounding_box, + ] + ): relative_offset = Vec3Int.zeros() + if (absolute_bounding_box or relative_bounding_box) is not None: + data_shape = None + else: + data_shape = Vec3Int(data.shape[-3:]) + mag1_bbox = self._get_mag1_bbox( abs_current_mag_offset=offset, rel_mag1_offset=relative_offset, abs_mag1_offset=absolute_offset, - current_mag_size=Vec3Int(data.shape[-3:]), + abs_mag1_bbox=absolute_bounding_box, + rel_mag1_bbox=relative_bounding_box, + current_mag_size=data_shape, ) # Only update the layer's bbox if we are actually larger @@ -183,8 +215,8 @@ def write( super().write( data, - absolute_offset=mag1_bbox.topleft, json_update_allowed=json_update_allowed, + absolute_bounding_box=mag1_bbox, ) def read( @@ -196,8 +228,8 @@ def read( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 - relative_bounding_box: Optional[BoundingBox] = None, # in mag1 - absolute_bounding_box: Optional[BoundingBox] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 ) -> np.ndarray: # THIS METHOD CAN BE REMOVED WHEN THE DEPRECATED OFFSET IS REMOVED @@ -237,12 +269,14 @@ def get_view( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bbox: Optional[NDBoundingBox] = None, # in mag1 + absolute_bbox: Optional[NDBoundingBox] = None, # in mag1 read_only: Optional[bool] = None, ) -> View: # THIS METHOD CAN BE REMOVED WHEN THE DEPRECATED OFFSET IS REMOVED # This has other defaults than the View implementation - # (all deprecations are handled in the subsclass) + # (all deprecations are handled in the superclass) bb = self.bounding_box.in_mag(self._mag) if offset is not None and size is None: offset = Vec3Int(offset) @@ -253,12 +287,14 @@ def get_view( size, relative_offset=relative_offset, absolute_offset=absolute_offset, + relative_bbox=relative_bbox, + absolute_bbox=absolute_bbox, read_only=read_only, ) def get_bounding_boxes_on_disk( self, - ) -> Iterator[BoundingBox]: + ) -> Iterator[NDBoundingBox]: """ Returns a Mag(1) bounding box for each file on disk. diff --git a/webknossos/webknossos/dataset/properties.py b/webknossos/webknossos/dataset/properties.py index 1cf285b48..72686bfdf 100644 --- a/webknossos/webknossos/dataset/properties.py +++ b/webknossos/webknossos/dataset/properties.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -6,7 +7,7 @@ import numpy as np from cattr.gen import make_dict_structure_fn, make_dict_unstructure_fn, override -from ..geometry import BoundingBox, Mag, Vec3Int +from ..geometry import Mag, NDBoundingBox, Vec3Int from ..utils import snake_to_camel_case, warn_deprecated from ._array import ArrayException, BaseArray, DataFormat from .layer_categories import LayerCategoryType @@ -119,6 +120,7 @@ class LayerViewConfiguration: @attr.define class MagViewProperties: mag: Mag + path: Optional[str] = None cube_length: Optional[int] = None axis_order: Optional[Dict[str, int]] = None @@ -128,11 +130,18 @@ def resolution(self) -> Mag: return self.mag +@attr.define +class AxisProperties: + name: str + bounds: Tuple[int, int] + index: int + + @attr.define class LayerProperties: name: str category: LayerCategoryType - bounding_box: BoundingBox + bounding_box: NDBoundingBox element_class: str data_format: DataFormat mags: List[MagViewProperties] @@ -169,10 +178,10 @@ class DatasetProperties: dataset_converter = cattr.Converter() # register (un-)structure hooks for non-attr-classes -bbox_to_wkw: Callable[[BoundingBox], dict] = lambda o: o.to_wkw_dict() # noqa: E731 -dataset_converter.register_unstructure_hook(BoundingBox, bbox_to_wkw) +bbox_to_wkw: Callable[[NDBoundingBox], dict] = lambda o: o.to_wkw_dict() # noqa: E731 +dataset_converter.register_unstructure_hook(NDBoundingBox, bbox_to_wkw) dataset_converter.register_structure_hook( - BoundingBox, lambda d, _: BoundingBox.from_wkw_dict(d) + NDBoundingBox, lambda d, _: NDBoundingBox.from_wkw_dict(d) ) @@ -234,13 +243,13 @@ def mag_unstructure(mag: Mag) -> List[int]: # The serialization of `LayerProperties` differs slightly based on whether it is a `wkw` or `zarr` layer. # These post-unstructure and pre-structure functions perform the conditional field renames. -def mag_view_properties_post_structure(d: Dict[str, Any]) -> Dict[str, Any]: +def mag_view_properties_post_unstructure(d: Dict[str, Any]) -> Dict[str, Any]: d["resolution"] = d["mag"] del d["mag"] return d -def mag_view_properties_pre_unstructure(d: Dict[str, Any]) -> Dict[str, Any]: +def mag_view_properties_pre_structure(d: Dict[str, Any]) -> Dict[str, Any]: d["mag"] = d["resolution"] del d["resolution"] return d @@ -257,9 +266,14 @@ def __layer_properties_post_unstructure( d = converter_fn(obj) if d["dataFormat"] == "wkw": d["wkwResolutions"] = [ - mag_view_properties_post_structure(m) for m in d["mags"] + mag_view_properties_post_unstructure(m) for m in d["mags"] ] del d["mags"] + + # json expects nd_bounding_box to be represented as bounding_box and additional_axes + if "additionalAxes" in d["boundingBox"]: + d["additionalAxes"] = d["boundingBox"]["additionalAxes"] + del d["boundingBox"]["additionalAxes"] return d return __layer_properties_post_unstructure @@ -280,9 +294,25 @@ def __layer_properties_pre_structure( ) -> Union[LayerProperties, SegmentationLayerProperties]: if d["dataFormat"] == "wkw": d["mags"] = [ - mag_view_properties_pre_unstructure(m) for m in d["wkwResolutions"] + mag_view_properties_pre_structure(m) for m in d["wkwResolutions"] ] del d["wkwResolutions"] + # bounding_box and additional_axes are internally handled as nd_bounding_box + if "additionalAxes" in d: + d["boundingBox"]["additionalAxes"] = copy.deepcopy(d["additionalAxes"]) + del d["additionalAxes"] + if len(d["mags"]) > 0: + first_mag = d["mags"][0] + if "axisOrder" in first_mag: + assert ( + first_mag["axisOrder"]["c"] == 0 + ), "The channels c must have index 0 in axis order." + assert all( + first_mag["axisOrder"] == mag["axisOrder"] for mag in d["mags"] + ) + d["boundingBox"]["axisOrder"] = copy.deepcopy(first_mag["axisOrder"]) + del d["boundingBox"]["axisOrder"]["c"] + obj = converter_fn(d, type_value) return obj diff --git a/webknossos/webknossos/dataset/view.py b/webknossos/webknossos/dataset/view.py index d6c77d6e3..2b9171b36 100644 --- a/webknossos/webknossos/dataset/view.py +++ b/webknossos/webknossos/dataset/view.py @@ -19,7 +19,9 @@ import wkw from cluster_tools import Executor -from ..geometry import BoundingBox, Mag, Vec3Int, Vec3IntLike +from webknossos.geometry.vec_int import VecInt + +from ..geometry import BoundingBox, Mag, NDBoundingBox, Vec3Int, Vec3IntLike from ..utils import ( get_executor_for_args, get_rich_progress, @@ -54,7 +56,7 @@ class View: _path: Path _array_info: ArrayInfo - _bounding_box: Optional[BoundingBox] + _bounding_box: Optional[NDBoundingBox] _read_only: bool _cached_array: Optional[BaseArray] _mag: Mag @@ -64,7 +66,7 @@ def __init__( path_to_mag_view: Path, array_info: ArrayInfo, bounding_box: Optional[ - BoundingBox + NDBoundingBox ], # in mag 1, absolute coordinates, optional only for mag_view since it overwrites the bounding_box property mag: Mag, read_only: bool = False, @@ -93,7 +95,7 @@ def header(self) -> wkw.Header: return self._array._wkw_dataset.header @property - def bounding_box(self) -> BoundingBox: + def bounding_box(self) -> NDBoundingBox: assert self._bounding_box is not None return self._bounding_box @@ -106,7 +108,7 @@ def read_only(self) -> bool: return self._read_only @property - def global_offset(self) -> Vec3Int: + def global_offset(self) -> VecInt: """⚠️ Deprecated, use `view.bounding_box.in_mag(view.mag).topleft` instead.""" warnings.warn( "[DEPRECATION] view.global_offset is deprecated. " @@ -117,7 +119,7 @@ def global_offset(self) -> Vec3Int: return self.bounding_box.in_mag(self._mag).topleft @property - def size(self) -> Vec3Int: + def size(self) -> VecInt: """⚠️ Deprecated, use `view.bounding_box.in_mag(view.mag).size` instead.""" warnings.warn( "[DEPRECATION] view.size is deprecated. " @@ -129,15 +131,15 @@ def size(self) -> Vec3Int: def _get_mag1_bbox( self, - abs_mag1_bbox: Optional[BoundingBox] = None, - rel_mag1_bbox: Optional[BoundingBox] = None, + abs_mag1_bbox: Optional[NDBoundingBox] = None, + rel_mag1_bbox: Optional[NDBoundingBox] = None, abs_mag1_offset: Optional[Vec3IntLike] = None, rel_mag1_offset: Optional[Vec3IntLike] = None, mag1_size: Optional[Vec3IntLike] = None, abs_current_mag_offset: Optional[Vec3IntLike] = None, rel_current_mag_offset: Optional[Vec3IntLike] = None, current_mag_size: Optional[Vec3IntLike] = None, - ) -> BoundingBox: + ) -> NDBoundingBox: num_bboxes = _count_defined_values([abs_mag1_bbox, rel_mag1_bbox]) num_offsets = _count_defined_values( [ @@ -167,7 +169,7 @@ def _get_mag1_bbox( if abs_mag1_bbox is not None: return abs_mag1_bbox - elif rel_mag1_bbox is not None: + if rel_mag1_bbox is not None: return rel_mag1_bbox.offset(self.bounding_box.topleft) else: @@ -187,7 +189,12 @@ def _get_mag1_bbox( assert abs_mag1_offset is not None, "No offset was supplied." assert mag1_size is not None, "No size was supplied." - return BoundingBox(Vec3Int(abs_mag1_offset), Vec3Int(mag1_size)) + + assert ( + len(self.bounding_box) == 3 + ), "The delivered offset and size are only usable for 3D views." + + return self.bounding_box.with_topleft(abs_mag1_offset).with_size(mag1_size) def write( self, @@ -197,9 +204,17 @@ def write( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 ) -> None: """ - Writes the `data` at the specified `relative_offset` or `absolute_offset`, both specified in Mag(1). + The user can specify where the data should be written. + The default is to write the data to the view's bounding box. + Alternatively, one can supply one of the following keywords: + * `relative_offset` in Mag(1) -> only usable for 3D datasets + * `absolute_offset` in Mag(1) -> only usable for 3D datasets + * `relative_bounding_box` in Mag(1) + * `absolute_bounding_box` in Mag(1) ⚠️ The `offset` parameter is deprecated. This parameter used to be relative for `View` and absolute for `MagView`, @@ -230,9 +245,23 @@ def write( """ assert not self.read_only, "Cannot write data to an read_only View" - if all(i is None for i in [offset, absolute_offset, relative_offset]): + if all( + i is None + for i in [ + offset, + absolute_offset, + relative_offset, + absolute_bounding_box, + relative_bounding_box, + ] + ): relative_offset = Vec3Int.zeros() + if (absolute_bounding_box or relative_bounding_box) is not None: + data_shape = None + else: + data_shape = Vec3Int(data.shape[-3:]) + if offset is not None: if self._mag == Mag(1): alternative = "Since this is a View in Mag(1), please use view.write(relative_offset=my_vec)" @@ -263,32 +292,31 @@ def write( rel_current_mag_offset=offset, rel_mag1_offset=relative_offset, abs_mag1_offset=absolute_offset, - current_mag_size=Vec3Int(data.shape[-3:]), + rel_mag1_bbox=relative_bounding_box, + abs_mag1_bbox=absolute_bounding_box, + current_mag_size=data_shape, ) if json_update_allowed: assert self.bounding_box.contains_bbox( mag1_bbox ), f"The bounding box to write {mag1_bbox} is larger than the view's bounding box {self.bounding_box}" - if len(data.shape) == 4 and data.shape[0] == 1: - data = data[0] # remove channel dimension for single-channel data - current_mag_bbox = mag1_bbox.in_mag(self._mag) if self._is_compressed(): for current_mag_bbox, chunked_data in self._prepare_compressed_write( current_mag_bbox, data, json_update_allowed ): - self._array.write(current_mag_bbox.topleft, chunked_data) + self._array.write(current_mag_bbox, chunked_data) else: - self._array.write(current_mag_bbox.topleft, data) + self._array.write(current_mag_bbox, data) def _prepare_compressed_write( self, - current_mag_bbox: BoundingBox, + current_mag_bbox: NDBoundingBox, data: np.ndarray, json_update_allowed: bool = True, - ) -> Iterator[Tuple[BoundingBox, np.ndarray]]: + ) -> Iterator[Tuple[NDBoundingBox, np.ndarray]]: """This method takes an arbitrary sized chunk of data with an accompanying bbox, divides these into chunks of shard_shape size and delegates the preparation to _prepare_compressed_write_chunk.""" @@ -314,10 +342,10 @@ def _prepare_compressed_write( def _prepare_compressed_write_chunk( self, - current_mag_bbox: BoundingBox, + current_mag_bbox: NDBoundingBox, data: np.ndarray, json_update_allowed: bool = True, - ) -> Tuple[BoundingBox, np.ndarray]: + ) -> Tuple[NDBoundingBox, np.ndarray]: """This method takes an arbitrary sized chunk of data with an accompanying bbox (ideally not larger than a shard) and enlarges that chunk to fit the shard it resides in (by reading the entire shard data and writing the passed data ndarray @@ -358,8 +386,8 @@ def read( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 - relative_bounding_box: Optional[BoundingBox] = None, # in mag1 - absolute_bounding_box: Optional[BoundingBox] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 ) -> np.ndarray: """ The user can specify which data should be read. @@ -400,8 +428,9 @@ def read( assert ( relative_offset is None and absolute_offset is None ), "You must supply size, when reading with an offset." + absolute_bounding_box = self.bounding_box current_mag_size = None - mag1_size = self.bounding_box.size + mag1_size = None else: if relative_offset is None and absolute_offset is None: if type(self) == View: @@ -440,14 +469,26 @@ def read( ) if size is None: + absolute_bounding_box = self.bounding_box.offset( + self._mag.to_vec3_int() * offset + ) + offset = None current_mag_size = None - mag1_size = self.bounding_box.size + mag1_size = None else: # (deprecated) offset and size are given current_mag_size = size mag1_size = None - if all(i is None for i in [offset, absolute_offset, relative_offset]): + if all( + i is None + for i in [ + offset, + absolute_offset, + relative_offset, + absolute_bounding_box, + ] + ): relative_offset = Vec3Int.zeros() else: assert ( @@ -507,11 +548,9 @@ def read_bbox(self, bounding_box: Optional[BoundingBox] = None) -> np.ndarray: def _read_without_checks( self, - current_mag_bbox: BoundingBox, + current_mag_bbox: NDBoundingBox, ) -> np.ndarray: - data = self._array.read( - current_mag_bbox.topleft.to_np(), current_mag_bbox.size.to_np() - ) + data = self._array.read(current_mag_bbox) return data def get_view( @@ -521,6 +560,8 @@ def get_view( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bbox: Optional[NDBoundingBox] = None, # in mag1 + absolute_bbox: Optional[NDBoundingBox] = None, # in mag1 read_only: Optional[bool] = None, ) -> "View": """ @@ -617,6 +658,8 @@ def get_view( relative_offset = Vec3Int.zeros() mag1_bbox = self._get_mag1_bbox( + abs_mag1_bbox=absolute_bbox, + rel_mag1_bbox=relative_bbox, rel_current_mag_offset=offset, rel_mag1_offset=relative_offset, abs_mag1_offset=absolute_offset, @@ -670,6 +713,8 @@ def get_buffered_slice_writer( *, relative_offset: Optional[Vec3IntLike] = None, # in mag1 absolute_offset: Optional[Vec3IntLike] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 use_logging: bool = False, ) -> "BufferedSliceWriter": """ @@ -680,6 +725,8 @@ def get_buffered_slice_writer( * The user can specify where the writer should start: * `relative_offset` in Mag(1) * `absolute_offset` in Mag(1) + * `relative_bounding_box` in Mag(1) + * `absolute_bounding_box` in Mag(1) * ⚠️ deprecated: `offset` in the current Mag, used to be relative for `View` and absolute for `MagView` * `buffer_size`: amount of slices that get buffered @@ -713,6 +760,8 @@ def get_buffered_slice_writer( dimension=dimension, relative_offset=relative_offset, absolute_offset=absolute_offset, + relative_bounding_box=relative_bounding_box, + absolute_bounding_box=absolute_bounding_box, use_logging=use_logging, ) @@ -723,8 +772,8 @@ def get_buffered_slice_reader( buffer_size: int = 32, dimension: int = 2, # z *, - relative_bounding_box: Optional[BoundingBox] = None, # in mag1 - absolute_bounding_box: Optional[BoundingBox] = None, # in mag1 + relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1 + absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1 use_logging: bool = False, ) -> "BufferedSliceReader": """ @@ -1088,7 +1137,7 @@ def _get_file_dimensions(self) -> Vec3Int: return self.info.shard_shape def _get_file_dimensions_mag1(self) -> Vec3Int: - return self._get_file_dimensions() * self.mag.to_vec3_int() + return Vec3Int(self._get_file_dimensions() * self.mag.to_vec3_int()) @property def _array(self) -> BaseArray: diff --git a/webknossos/webknossos/geometry/__init__.py b/webknossos/webknossos/geometry/__init__.py index 328823d9e..301d074b1 100644 --- a/webknossos/webknossos/geometry/__init__.py +++ b/webknossos/webknossos/geometry/__init__.py @@ -2,4 +2,6 @@ from .bounding_box import BoundingBox from .mag import Mag +from .nd_bounding_box import NDBoundingBox from .vec3_int import Vec3Int, Vec3IntLike +from .vec_int import VecInt, VecIntLike diff --git a/webknossos/webknossos/geometry/bounding_box.py b/webknossos/webknossos/geometry/bounding_box.py index a9e0010f1..75038a40e 100644 --- a/webknossos/webknossos/geometry/bounding_box.py +++ b/webknossos/webknossos/geometry/bounding_box.py @@ -1,6 +1,5 @@ import json import re -from collections import defaultdict from typing import ( Callable, Dict, @@ -17,13 +16,14 @@ import numpy as np from .mag import Mag +from .nd_bounding_box import NDBoundingBox from .vec3_int import Vec3Int, Vec3IntLike _DEFAULT_BBOX_NAME = "Unnamed Bounding Box" @attr.frozen -class BoundingBox: +class BoundingBox(NDBoundingBox): """ This class is used to represent an axis-aligned cuboid in 3D. The top-left coordinate is inclusive and the bottom-right coordinate is exclusive. @@ -42,6 +42,8 @@ class BoundingBox: topleft: Vec3Int = attr.field(converter=Vec3Int) size: Vec3Int = attr.field(converter=Vec3Int) + axes: Tuple[str, str, str] = attr.field(default=("x", "y", "z")) + index: Vec3Int = attr.field(default=Vec3Int(1, 2, 3)) bottomright: Vec3Int = attr.field(init=False) name: Optional[str] = _DEFAULT_BBOX_NAME is_visible: bool = True @@ -61,61 +63,26 @@ def __attrs_post_init__(self) -> None: # it is needed. object.__setattr__(self, "bottomright", self.topleft + self.size) - def with_topleft(self, new_topleft: Vec3IntLike) -> "BoundingBox": - return attr.evolve(self, topleft=new_topleft) - - def with_size(self, new_size: Vec3IntLike) -> "BoundingBox": - return attr.evolve(self, size=new_size) - - def with_name(self, name: Optional[str]) -> "BoundingBox": - return attr.evolve(self, name=name) - - def with_is_visible(self, is_visible: bool) -> "BoundingBox": - return attr.evolve(self, is_visible=is_visible) - - def with_color( - self, color: Optional[Tuple[float, float, float, float]] - ) -> "BoundingBox": - return attr.evolve(self, color=color) - def with_bounds_x( self, new_topleft_x: Optional[int] = None, new_size_x: Optional[int] = None ) -> "BoundingBox": """Returns a copy of the bounding box with topleft.x optionally replaced and size.x optionally replaced.""" - new_topleft = ( - self.topleft.with_x(new_topleft_x) - if new_topleft_x is not None - else self.topleft - ) - new_size = self.size.with_x(new_size_x) if new_size_x is not None else self.size - return attr.evolve(self, topleft=new_topleft, size=new_size) + return cast(BoundingBox, self.with_bounds("x", new_topleft_x, new_size_x)) def with_bounds_y( self, new_topleft_y: Optional[int] = None, new_size_y: Optional[int] = None ) -> "BoundingBox": """Returns a copy of the bounding box with topleft.y optionally replaced and size.y optionally replaced.""" - new_topleft = ( - self.topleft.with_y(new_topleft_y) - if new_topleft_y is not None - else self.topleft - ) - new_size = self.size.with_y(new_size_y) if new_size_y is not None else self.size - return attr.evolve(self, topleft=new_topleft, size=new_size) + return cast(BoundingBox, self.with_bounds("y", new_topleft_y, new_size_y)) def with_bounds_z( self, new_topleft_z: Optional[int] = None, new_size_z: Optional[int] = None ) -> "BoundingBox": """Returns a copy of the bounding box with topleft.z optionally replaced and size.z optionally replaced.""" - new_topleft = ( - self.topleft.with_z(new_topleft_z) - if new_topleft_z is not None - else self.topleft - ) - new_size = self.size.with_z(new_size_z) if new_size_z is not None else self.size - return attr.evolve(self, topleft=new_topleft, size=new_size) + return cast(BoundingBox, self.with_bounds("z", new_topleft_z, new_size_z)) @classmethod def from_wkw_dict(cls, bbox: Dict) -> "BoundingBox": @@ -162,6 +129,10 @@ def from_csv(cls, csv_bbox: str) -> "BoundingBox": bbox_tuple = tuple(int(x) for x in csv_bbox.split(",")) return cls.from_tuple6(cast(Tuple[int, int, int, int, int, int], bbox_tuple)) + @classmethod + def from_ndbbox(cls, bbox: NDBoundingBox) -> "BoundingBox": + return cls(bbox.topleft_xyz, bbox.size_xyz) + @classmethod def from_auto( cls, obj: Union["BoundingBox", str, Dict, List, Tuple] @@ -185,25 +156,6 @@ def from_auto( raise Exception("Unknown bounding box format.") - @classmethod - def group_boxes_with_aligned_mag( - cls, bounding_boxes: Iterable["BoundingBox"], aligning_mag: Mag - ) -> Dict["BoundingBox", List["BoundingBox"]]: - """ - Groups the given BoundingBox instances by aligning each - bbox to the given mag and using that as the key. - For example, bounding boxes of size 256**3 could be grouped - into the corresponding 1024**3 chunks to which they belong - by using aligning_mag = Mag(1024). - """ - - chunks_with_bboxes = defaultdict(list) - for bbox in bounding_boxes: - chunk_key = bbox.align_with_mag(aligning_mag, ceil=True) - chunks_with_bboxes[chunk_key].append(bbox) - - return chunks_with_bboxes - @classmethod def empty( cls, @@ -240,19 +192,15 @@ def to_tuple6(self) -> Tuple[int, int, int, int, int, int]: def to_csv(self) -> str: return ",".join(map(str, self.to_tuple6())) - def __repr__(self) -> str: - return "BoundingBox(topleft={}, size={})".format( - str(tuple(self.topleft)), str(tuple(self.size)) - ) - - def __str__(self) -> str: - return self.__repr__() - def __eq__(self, other: object) -> bool: - if isinstance(other, BoundingBox): + if isinstance(other, NDBoundingBox): + self._check_compatibility(other) return self.topleft == other.topleft and self.size == other.size - else: - raise NotImplementedError() + + raise NotImplementedError() + + def __repr__(self) -> str: + return f"BoundingBox(topleft={self.topleft.to_tuple()}, size={self.size.to_tuple()})" def padded_with_margins( self, margins_left: Vec3IntLike, margins_right: Optional[Vec3IntLike] = None @@ -269,36 +217,6 @@ def padded_with_margins( size=self.size + (margins_left + margins_right), ) - def intersected_with( - self, other: "BoundingBox", dont_assert: bool = False - ) -> "BoundingBox": - """If dont_assert is set to False, this method may return empty bounding boxes (size == (0, 0, 0))""" - - topleft = self.topleft.pairmax(other.topleft) - bottomright = self.bottomright.pairmin(other.bottomright) - size = (bottomright - topleft).pairmax(Vec3Int.zeros()) - - intersection = attr.evolve(self, topleft=topleft, size=size) - - if not dont_assert: - assert ( - not intersection.is_empty() - ), f"No intersection between bounding boxes {self} and {other}." - - return intersection - - def extended_by(self, other: "BoundingBox") -> "BoundingBox": - if self.is_empty(): - return other - if other.is_empty(): - return self - - topleft = self.topleft.pairmin(other.topleft) - bottomright = self.bottomright.pairmax(other.bottomright) - size = bottomright - topleft - - return attr.evolve(self, topleft=topleft, size=size) - def is_empty(self) -> bool: return not self.size.is_positive(strictly_positive=True) @@ -318,14 +236,6 @@ def in_mag(self, mag: Mag) -> "BoundingBox": size=(self.size // mag_vec), ) - def from_mag_to_mag1(self, from_mag: Mag) -> "BoundingBox": - mag_vec = from_mag.to_vec3_int() - return attr.evolve( - self, - topleft=(self.topleft * mag_vec), - size=(self.size * mag_vec), - ) - def _align_with_mag_slow(self, mag: Mag, ceil: bool = False) -> "BoundingBox": """Rounds the bounding box, so that both topleft and bottomright are divisible by mag. @@ -394,9 +304,6 @@ def contains(self, coord: Union[Vec3IntLike, np.ndarray]) -> bool: and self.topleft[2] <= coord[2] < self.bottomright[2] ) - def contains_bbox(self, inner_bbox: "BoundingBox") -> bool: - return inner_bbox.intersected_with(self, dont_assert=True) == inner_bbox - def chunk( self, chunk_shape: Vec3IntLike, @@ -433,24 +340,10 @@ def chunk( for z in range( start[2] - start_adjust[2], start[2] + self.size[2], chunk_shape[2] ): - yield BoundingBox([x, y, z], chunk_shape).intersected_with(self) - - def volume(self) -> int: - return self.size.prod() - - def slice_array(self, array: np.ndarray) -> np.ndarray: - return array[ - self.topleft.x : self.bottomright.x, - self.topleft.y : self.bottomright.y, - self.topleft.z : self.bottomright.z, - ] - - def to_slices(self) -> Tuple[slice, slice, slice]: - return np.index_exp[ - self.topleft.x : self.bottomright.x, - self.topleft.y : self.bottomright.y, - self.topleft.z : self.bottomright.z, - ] + yield cast( + BoundingBox, + BoundingBox([x, y, z], chunk_shape).intersected_with(self), + ) def offset(self, vector: Vec3IntLike) -> "BoundingBox": return attr.evolve(self, topleft=self.topleft + Vec3Int(vector)) diff --git a/webknossos/webknossos/geometry/nd_bounding_box.py b/webknossos/webknossos/geometry/nd_bounding_box.py new file mode 100644 index 000000000..def386881 --- /dev/null +++ b/webknossos/webknossos/geometry/nd_bounding_box.py @@ -0,0 +1,849 @@ +from collections import defaultdict +from itertools import product +from typing import ( + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) + +import attr +import numpy as np + +from .mag import Mag +from .vec3_int import Vec3Int, Vec3IntLike +from .vec_int import VecInt, VecIntLike + +_DEFAULT_BBOX_NAME = "Unnamed Bounding Box" + +_T = TypeVar("_T", bound="NDBoundingBox") + + +def str_tpl(str_list: Iterable[str]) -> Tuple[str, ...]: + # Fix for mypy bug https://github.com/python/mypy/issues/5313. + # Solution based on other issue for the same bug: https://github.com/python/mypy/issues/8389. + return tuple(str_list) + + +def int_tpl(vec_int_like: VecIntLike) -> VecInt: + return VecInt( + vec_int_like, axes=(f"unset_{i}" for i in range(len(list(vec_int_like)))) + ) + + +@attr.frozen +class NDBoundingBox: + """ + The NDBoundingBox class is a generalized version of the 3-dimensional BoundingBox class. It is designed to represent bounding boxes in any number of dimensions. + + The bounding box is characterized by its top-left corner, the size of the box, the names of the axes for each dimension, and the index (or order) of the axes. Each axis must have a unique index, starting from 1 (index 0 is reserved for channel information). + + The top-left coordinate is inclusive, while the bottom-right coordinate is exclusive. + + Here's a brief example of how to use it: + + ```python + + # Create a 2D bounding box + bbox_1 = NDBoundingBox( + top_left=(0, 0), + size=(100, 100), + axes=("x", "y"), + index=(1,2) + ) + + # Create a 4D bounding box + bbox_2 = NDBoundingBox( + top_left=(75, 75, 75, 0), + size=(100, 100, 100, 20), + axes=("x", "y", "z", "t"), + index=(2,3,4,1) + ) + ``` + """ + + topleft: VecInt = attr.field(converter=int_tpl) + size: VecInt = attr.field(converter=int_tpl) + axes: Tuple[str, ...] = attr.field(converter=str_tpl) + index: VecInt = attr.field(converter=int_tpl) + bottomright: VecInt = attr.field(init=False) + name: Optional[str] = _DEFAULT_BBOX_NAME + is_visible: bool = True + color: Optional[Tuple[float, float, float, float]] = None + + def __attrs_post_init__(self) -> None: + assert ( + len(self.topleft) == len(self.size) == len(self.axes) == len(self.index) + ), ( + f"The dimensions of topleft, size, axes and index ({len(self.topleft)}, " + + f"{len(self.size)}, {len(self.axes)} and {len(self.index)}) do not match." + ) + assert 0 not in self.index, "Index 0 is reserved for channels." + + # Convert the delivered tuples to VecInts + object.__setattr__(self, "topleft", VecInt(self.topleft, axes=self.axes)) + object.__setattr__(self, "size", VecInt(self.size, axes=self.axes)) + object.__setattr__(self, "index", VecInt(self.index, axes=self.axes)) + + if not self._is_sorted(): + self._sort_positions_of_axes() + + if not self.size.is_positive(): + # Flip the size in negative dimensions, so that the topleft is smaller than bottomright. + # E.g. BoundingBox((10, 10, 10), (-5, 5, 5)) -> BoundingBox((5, 10, 10), (5, 5, 5)). + negative_size = tuple(min(0, value) for value in self.size) + new_topleft = tuple( + val1 + val2 for val1, val2 in zip(self.topleft, negative_size) + ) + new_size = (abs(value) for value in self.size) + object.__setattr__(self, "topleft", VecInt(new_topleft, axes=self.axes)) + object.__setattr__(self, "size", VecInt(new_size, axes=self.axes)) + + # Compute bottomright to avoid that it's recomputed every time + # it is needed. + object.__setattr__( + self, + "bottomright", + self.topleft + self.size, + ) + + def _sort_positions_of_axes(self) -> None: + # Bring topleft and size in required order + # defined in axisOrder and index of additionalAxes + + size, topleft, axes, index = zip( + *sorted( + zip(self.size, self.topleft, self.axes, self.index), key=lambda x: x[3] + ) + ) + object.__setattr__(self, "size", VecInt(size, axes=axes)) + object.__setattr__(self, "topleft", VecInt(topleft, axes=axes)) + object.__setattr__(self, "axes", axes) + object.__setattr__(self, "index", VecInt(index, axes=axes)) + + def _is_sorted(self) -> bool: + return all(self.index[i - 1] < self.index[i] for i in range(1, len(self.index))) + + def with_name(self: _T, name: Optional[str]) -> _T: + """ + Returns a new instance of `NDBoundingBox` with the specified name. + + Args: + - name (Optional[str]): The name to assign to the new `NDBoundingBox` instance. + + Returns: + - NDBoundingBox: A new instance of `NDBoundingBox` with the specified name. + """ + return attr.evolve(self, name=name) + + def with_topleft(self: _T, new_topleft: VecIntLike) -> _T: + """ + Returns a new NDBoundingBox object with the specified top left coordinates. + + Args: + - new_topleft (VecIntLike): The new top left coordinates for the bounding box. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated top left coordinates. + """ + return attr.evolve(self, topleft=VecInt(new_topleft, axes=self.axes)) + + def with_size(self: _T, new_size: VecIntLike) -> _T: + """ + Returns a new NDBoundingBox object with the specified size. + + Args: + - new_size (VecIntLike): The new size of the bounding box. Can be a VecInt or any object that can be converted to a VecInt. + + Returns: + - A new NDBoundingBox object with the specified size. + """ + return attr.evolve(self, size=VecInt(new_size, axes=self.axes)) + + def with_index(self: _T, new_index: VecIntLike) -> _T: + """ + Returns a new NDBoundingBox object with the specified index. + + Args: + - new_index (VecIntLike): The new axis order for the bounding box. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated index. + """ + axes, _ = zip(*sorted(zip(self.axes, new_index), key=lambda x: x[1])) + return attr.evolve(self, index=VecInt(new_index, axes=axes)) + + def with_bottomright(self: _T, new_bottomright: VecIntLike) -> _T: + """ + Returns a new NDBoundingBox with an updated bottomright value. + + Args: + - new_bottomright (VecIntLike): The new bottom right corner coordinates. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated bottom right corner. + """ + new_size = VecInt(new_bottomright, axes=self.axes) - self.topleft + + return self.with_size(new_size) + + def with_is_visible(self: _T, is_visible: bool) -> _T: + """ + Returns a new NDBoundingBox object with the specified visibility. + + Args: + - is_visible (bool): The visibility value to set. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated visibility value. + """ + return attr.evolve(self, is_visible=is_visible) + + def with_color(self: _T, color: Optional[Tuple[float, float, float, float]]) -> _T: + """ + Returns a new instance of NDBoundingBox with the specified color. + + Args: + - color (Optional[Tuple[float, float, float, float]]): The color to set for the bounding box. + The color should be specified as a tuple of four floats representing RGBA values. + + Returns: + - NDBoundingBox: A new instance of NDBoundingBox with the specified color. + """ + return attr.evolve(self, color=color) + + def with_bounds( + self: _T, axis: str, new_topleft: Optional[int], new_size: Optional[int] + ) -> _T: + """ + Returns a new NDBoundingBox object with updated bounds along the specified axis. + + Args: + - axis (str): The name of the axis to update. + - new_topleft (Optional[int]): The new value for the top-left coordinate along the specified axis. + - new_size (Optional[int]): The new size along the specified axis. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with updated bounds. + + Raises: + - ValueError: If the given axis name does not exist. + + """ + try: + index = self.axes.index(axis) + except ValueError as err: + raise ValueError("The given axis name does not exist.") from err + + _new_topleft = ( + self.topleft.with_replaced(index, new_topleft) + if new_topleft is not None + else self.topleft + ) + _new_size = ( + self.size.with_replaced(index, new_size) + if new_size is not None + else self.size + ) + + return attr.evolve(self, topleft=_new_topleft, size=_new_size) + + def get_bounds(self, axis: str) -> Tuple[int, int]: + """ + Returns the bounds of the given axis. + + Args: + - axis (str): The name of the axis to get the bounds for. + + Returns: + - Tuple[int, int]: A tuple containing the top-left and bottom-right coordinates along the specified axis. + """ + try: + index = self.axes.index(axis) + except ValueError as err: + raise ValueError("The given axis name does not exist.") from err + + return (self.topleft[index], self.topleft[index] + self.size[index]) + + @classmethod + def group_boxes_with_aligned_mag( + cls, bounding_boxes: Iterable["NDBoundingBox"], aligning_mag: Mag + ) -> Dict["NDBoundingBox", List["NDBoundingBox"]]: + """ + Groups the given BoundingBox instances by aligning each + bbox to the given mag and using that as the key. + For example, bounding boxes of size 256**3 could be grouped + into the corresponding 1024**3 chunks to which they belong + by using aligning_mag = Mag(1024). + """ + + chunks_with_bboxes = defaultdict(list) + for bbox in bounding_boxes: + chunk_key = bbox.align_with_mag(aligning_mag, ceil=True) + chunks_with_bboxes[chunk_key].append(bbox) + + return chunks_with_bboxes + + @classmethod + def from_wkw_dict(cls, bbox: Dict) -> "NDBoundingBox": + """ + Create an instance of NDBoundingBox from a dictionary representation. + + Args: + - bbox (Dict): The dictionary representation of the bounding box. + + Returns: + - NDBoundingBox: An instance of NDBoundingBox. + + Raises: + - AssertionError: If additionalAxes are present but axisOrder is not provided. + """ + + topleft: Tuple[int, ...] = bbox["topLeft"] + size: Tuple[int, ...] = (bbox["width"], bbox["height"], bbox["depth"]) + axes: Tuple[str, ...] = ("x", "y", "z") + index: Tuple[int, ...] = (1, 2, 3) + + if "axisOrder" in bbox: + axes = tuple(bbox["axisOrder"].keys()) + index = tuple(bbox["axisOrder"][axis] for axis in axes) + + if "additionalAxes" in bbox: + assert ( + "axisOrder" in bbox + ), "If there are additionalAxes an axisOrder needs to be provided." + for axis in bbox["additionalAxes"]: + topleft += (axis["bounds"][0],) + size += (axis["bounds"][1] - axis["bounds"][0],) + axes += (axis["name"],) + index += (axis["index"],) + + return cls( + topleft=VecInt(topleft, axes=axes), + size=VecInt(size, axes=axes), + axes=axes, + index=VecInt(index, axes=axes), + ) + + def to_wkw_dict(self) -> dict: + """ + Converts the bounding box object to a json dictionary. + + Returns: + - dict: A json dictionary representing the bounding box. + """ + topleft = [None, None, None] + width, height, depth = None, None, None + additional_axes = [] + for i, axis in enumerate(self.axes): + if axis == "x": + topleft[0] = self.topleft[i] + width = self.size[i] + elif axis == "y": + topleft[1] = self.topleft[i] + height = self.size[i] + elif axis == "z": + topleft[2] = self.topleft[i] + depth = self.size[i] + else: + additional_axes.append( + { + "name": axis, + "bounds": [self.topleft[i], self.bottomright[i]], + "index": self.index[i], + } + ) + if additional_axes: + return { + "topLeft": topleft, + "width": width, + "height": height, + "depth": depth, + "additionalAxes": additional_axes, + } + return { + "topLeft": topleft, + "width": width, + "height": height, + "depth": depth, + } + + def to_config_dict(self) -> dict: + """ + Returns a dictionary representation of the bounding box. + + Returns: + - dict: A dictionary representation of the bounding box. + """ + return { + "topleft": self.topleft.to_list(), + "size": self.size.to_list(), + "axes": self.axes, + } + + def to_checkpoint_name(self) -> str: + """ + Returns a string representation of the bounding box that can be used as a checkpoint name. + + Returns: + - str: A string representation of the bounding box. + """ + return f"{'_'.join(str(element) for element in self.topleft)}_{'_'.join(str(element) for element in self.size)}" + + def __repr__(self) -> str: + return f"NDBoundingBox(topleft={self.topleft.to_tuple()}, size={self.size.to_tuple()}, axes={self.axes})" + + def __str__(self) -> str: + return self.__repr__() + + def __eq__(self, other: object) -> bool: + if isinstance(other, NDBoundingBox): + self._check_compatibility(other) + return self.topleft == other.topleft and self.size == other.size + + raise NotImplementedError() + + def __len__(self) -> int: + return len(self.axes) + + def get_shape(self, axis_name: str) -> int: + """ + Returns the size of the bounding box along the specified axis. + + Args: + - axis_name (str): The name of the axis to get the size for. + + Returns: + - int: The size of the bounding box along the specified axis. + """ + try: + index = self.axes.index(axis_name) + return self.size[index] + except ValueError as err: + raise ValueError( + f"Axis {axis_name} doesn't exist in NDBoundingBox." + ) from err + + def _get_attr_xyz(self, attr_name: str) -> Vec3Int: + axes = ("x", "y", "z") + attr_3d = [] + + for axis in axes: + index = self.axes.index(axis) + attr_3d.append(getattr(self, attr_name)[index]) + + return Vec3Int(attr_3d) + + def _get_attr_with_replaced_xyz(self, attr_name: str, xyz: Vec3IntLike) -> VecInt: + value = Vec3Int(xyz) + axes = ("x", "y", "z") + modified_attr = getattr(self, attr_name).to_list() + + for i, axis in enumerate(axes): + index = self.axes.index(axis) + modified_attr[index] = value[i] + + return VecInt(modified_attr, axes=self.axes) + + @property + def topleft_xyz(self) -> Vec3Int: + """The topleft corner of the bounding box regarding only x, y and z axis.""" + + return self._get_attr_xyz("topleft") + + @property + def size_xyz(self) -> Vec3Int: + """The size of the bounding box regarding only x, y and z axis.""" + + return self._get_attr_xyz("size") + + @property + def bottomright_xyz(self) -> Vec3Int: + """The bottomright corner of the bounding box regarding only x, y and z axis.""" + + return self._get_attr_xyz("bottomright") + + @property + def index_xyz(self) -> Vec3Int: + """The index of x, y and z axis within the bounding box.""" + + return self._get_attr_xyz("index") + + def with_topleft_xyz(self: _T, new_xyz: Vec3IntLike) -> _T: + """ + Returns a new NDBoundingBox object with changed x, y and z coordinates of the topleft corner. + + Args: + - new_xyz (Vec3IntLike): The new x, y and z coordinates for the topleft corner. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated x, y and z coordinates of the topleft corner. + """ + new_topleft = self._get_attr_with_replaced_xyz("topleft", new_xyz) + + return self.with_topleft(new_topleft) + + def with_size_xyz(self: _T, new_xyz: Vec3IntLike) -> _T: + """ + Returns a new NDBoundingBox object with changed x, y and z size. + + Args: + - new_xyz (Vec3IntLike): The new x, y and z size for the bounding box. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated x, y and z size. + """ + new_size = self._get_attr_with_replaced_xyz("size", new_xyz) + + return self.with_size(new_size) + + def with_bottomright_xyz(self: _T, new_xyz: Vec3IntLike) -> _T: + """ + Returns a new NDBoundingBox object with changed x, y and z coordinates of the bottomright corner. + + Args: + - new_xyz (Vec3IntLike): The new x, y and z coordinates for the bottomright corner. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated x, y and z coordinates of the bottomright corner. + """ + new_bottomright = self._get_attr_with_replaced_xyz("bottomright", new_xyz) + + return self.with_bottomright(new_bottomright) + + def with_index_xyz(self: _T, new_xyz: Vec3IntLike) -> _T: + """ + Returns a new NDBoundingBox object with changed x, y and z index. + + Args: + - new_xyz (Vec3IntLike): The new x, y and z index for the bounding box. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the updated x, y and z index. + """ + new_index = self._get_attr_with_replaced_xyz("index", new_xyz) + + return self.with_index(new_index) + + def _check_compatibility(self, other: "NDBoundingBox") -> None: + """Checks if two bounding boxes are comparable. To be comparable they need the same number of axes, with same names and same order.""" + + if self.axes != other.axes: + raise ValueError( + f"Operation with two bboxes is only possible if they have the same axes and axes order. {self.axes} != {other.axes}" + ) + + def padded_with_margins( + self, margins_left: VecIntLike, margins_right: Optional[VecIntLike] = None + ) -> "NDBoundingBox": + raise NotImplementedError() + + def intersected_with(self: _T, other: _T, dont_assert: bool = False) -> _T: + """ + Returns the intersection of two bounding boxes. + + If dont_assert is set to False, this method may return empty bounding boxes (size == (0, 0, 0)) + + Args: + - other (NDBoundingBox): The other bounding box to intersect with. + - dont_assert (bool): If True, the method may return empty bounding boxes. + + Returns: + - NDBoundingBox: The intersection of the two bounding boxes. + """ + + self._check_compatibility(other) + topleft = self.topleft.pairmax(other.topleft) + bottomright = self.bottomright.pairmin(other.bottomright) + size = (bottomright - topleft).pairmax(VecInt.zeros(self.axes)) + + intersection = attr.evolve(self, topleft=topleft, size=size) + + if not dont_assert: + assert ( + not intersection.is_empty() + ), f"No intersection between bounding boxes {self} and {other}." + + return intersection + + def extended_by(self: _T, other: _T) -> _T: + """ + Returns the smallest bounding box that contains both bounding boxes. + + Args: + - other (NDBoundingBox): The other bounding box to extend with. + + Returns: + - NDBoundingBox: The smallest bounding box that contains both bounding boxes. + """ + self._check_compatibility(other) + if self.is_empty(): + return other + if other.is_empty(): + return self + + topleft = self.topleft.pairmin(other.topleft) + bottomright = self.bottomright.pairmax(other.bottomright) + size = bottomright - topleft + + return attr.evolve(self, topleft=topleft, size=size) + + def is_empty(self) -> bool: + """ + Boolean check whether the boundung box is empty. + + Returns: + - bool: True if the bounding box is empty, False otherwise. + """ + return not self.size.is_positive(strictly_positive=True) + + def in_mag(self: _T, mag: Mag) -> _T: + """ + Returns the bounding box in the given mag. + + Args: + - mag (Mag): The magnification to convert the bounding box to. + + Returns: + - NDBoundingBox: The bounding box in the given magnification. + """ + mag_vec = mag.to_vec3_int() + + assert ( + self.topleft_xyz % mag_vec == Vec3Int.zeros() + ), f"topleft {self.topleft} is not aligned with the mag {mag}. Use BoundingBox.align_with_mag()." + assert ( + self.bottomright_xyz % mag_vec == Vec3Int.zeros() + ), f"bottomright {self.bottomright} is not aligned with the mag {mag}. Use BoundingBox.align_with_mag()." + + return self.with_topleft_xyz(self.topleft_xyz // mag_vec).with_size_xyz( + self.size_xyz // mag_vec + ) + + def from_mag_to_mag1(self: _T, from_mag: Mag) -> _T: + """ + Returns the bounging box in the finest magnification (Mag(1)). + + Args: + - from_mag (Mag): The current magnification of the bounding box. + + Returns: + - NDBoundingBox: The bounding box in the given magnification. + """ + mag_vec = from_mag.to_vec3_int() + + return self.with_topleft_xyz(self.topleft_xyz * mag_vec).with_size_xyz( + self.size_xyz * mag_vec + ) + + def _align_with_mag_slow(self: _T, mag: Mag, ceil: bool = False) -> _T: + """Rounds the bounding box, so that both topleft and bottomright are divisible by mag. + + :argument ceil: If true, the bounding box is enlarged when necessary. If false, it's shrinked when necessary. + """ + np_mag = mag.to_np() + + align = ( # noqa E731 + lambda point, round_fn: round_fn(point.to_np() / np_mag).astype(int) + * np_mag + ) + + if ceil: + topleft = align(self.topleft, np.floor) + bottomright = align(self.bottomright, np.ceil) + else: + topleft = align(self.topleft, np.ceil) + bottomright = align(self.bottomright, np.floor) + return attr.evolve(self, topleft=topleft, size=bottomright - topleft) + + def align_with_mag(self: _T, mag: Union[Mag, Vec3Int], ceil: bool = False) -> _T: + """ + Rounds the bounding box, so that both topleft and bottomright are divisible by mag. + + Args: + - mag (Union[Mag, Vec3Int]): The magnification to align the bounding box to. + - ceil (bool): If True, the bounding box is enlarged when necessary. If False, it's shrinked when necessary. + + Returns: + - NDBoundingBox: The aligned bounding box. + """ + # This does the same as _align_with_mag_slow, which is more readable. + # Same behavior is asserted in test_align_with_mag_against_numpy_implementation + mag_vec = mag.to_vec3_int() if isinstance(mag, Mag) else mag + topleft = self.topleft_xyz + bottomright = self.bottomright_xyz + roundup = topleft if ceil else bottomright + rounddown = bottomright if ceil else topleft + margin_to_roundup = roundup % mag_vec + aligned_roundup = roundup - margin_to_roundup + margin_to_rounddown = (mag_vec - (rounddown % mag_vec)) % mag_vec + aligned_rounddown = rounddown + margin_to_rounddown + if ceil: + return self.with_topleft_xyz(aligned_roundup).with_size_xyz( + aligned_rounddown - aligned_roundup + ) + else: + return self.with_topleft_xyz(aligned_rounddown).with_size_xyz( + aligned_roundup - aligned_rounddown + ) + + def contains(self, coord: VecIntLike) -> bool: + """ + Check whether a point is inside of the bounding box. + Note that the point may have float coordinates in the ndarray case + + Args: + - coord (VecIntLike): The coordinates to check. + + Returns: + - bool: True if the point is inside of the bounding box, False otherwise. + """ + + if isinstance(coord, np.ndarray): + assert ( + coord.shape == (len(self.size),) + ), f"Numpy array BoundingBox.contains must have shape ({len(self.size)},), got {coord.shape}." + return cast( + bool, + np.all(coord >= self.topleft) and np.all(coord < self.bottomright), + ) + else: + # In earlier versions, we simply converted to ndarray to have + # a unified calculation here, but this turned out to be a performance bottleneck. + # Therefore, the contains-check is performed on the tuple here. + coord = VecInt(coord, axes=self.axes) + return all( + self.topleft[i] <= coord[i] < self.bottomright[i] + for i in range(len(self.axes)) + ) + + def contains_bbox(self: _T, inner_bbox: _T) -> bool: + """ + Check whether a bounding box is completely inside of the bounding box. + + Args: + - inner_bbox (NDBoundingBox): The bounding box to check. + + Returns: + - bool: True if the bounding box is completely inside of the bounding box, False otherwise. + """ + self._check_compatibility(inner_bbox) + return inner_bbox.intersected_with(self, dont_assert=True) == inner_bbox + + def chunk( + self: _T, + chunk_shape: VecIntLike, + chunk_border_alignments: Optional[VecIntLike] = None, + ) -> Generator[_T, None, None]: + """ + Decompose the bounding box into smaller chunks of size `chunk_shape`. + + Chunks at the border of the bounding box might be smaller than chunk_shape. + If `chunk_border_alignment` is set, all border coordinates + *between two chunks* will be divisible by that value. + + Args: + - chunk_shape (VecIntLike): The size of the chunks to generate. + - chunk_border_alignments (Optional[VecIntLike]): The alignment of the chunk borders. + + + Yields: + - Generator[NDBoundingBox]: A generator of the chunks. + """ + + start = self.topleft.to_np() + try: + # If a 3D chunk_shape is given it is assumed that iteration over xyz is + # intended. Therefore NDBoundingBoxes are generated that have a shape of + # x: chunk_shape.x, y: chunk_shape.y, z: chunk_shape.z and 1 for all other + # axes. + chunk_shape = Vec3Int(chunk_shape) + + chunk_shape = ( + self.with_size(VecInt.ones(self.axes)) + .with_size_xyz(chunk_shape) + .size.to_np() + ) + except AssertionError: + chunk_shape = VecInt(chunk_shape, axes=self.axes).to_np() + + start_adjust = VecInt.zeros(self.axes).to_np() + if chunk_border_alignments is not None: + try: + chunk_border_alignments = Vec3Int(chunk_border_alignments) + + chunk_border_alignments = ( + self.with_size(VecInt.ones(self.axes)) + .with_size_xyz(chunk_border_alignments) + .size.to_np() + ) + except AssertionError: + chunk_border_alignments = VecInt( + chunk_border_alignments, axes=self.axes + ).to_np() + + assert np.all( + chunk_shape % chunk_border_alignments == 0 + ), f"{chunk_shape} not divisible by {chunk_border_alignments}" + + # Move the start to be aligned correctly. This doesn't actually change + # the start of the first chunk, because we'll intersect with `self`, + # but it'll lead to all chunk borders being aligned correctly. + start_adjust = start % chunk_border_alignments + for coordinates in product( + *[ + range( + start[i] - start_adjust[i], start[i] + self.size[i], chunk_shape[i] + ) + for i in range(len(self.axes)) + ] + ): + yield self.intersected_with( + self.__class__( + topleft=VecInt(coordinates, axes=self.axes), + size=VecInt(chunk_shape, axes=self.axes), + axes=self.axes, + index=self.index, + ) + ) + + def volume(self) -> int: + """ + Returns the volume of the bounding box. + """ + return self.size.prod() + + def slice_array(self, array: np.ndarray) -> np.ndarray: + """ + Returns a slice of the given array that corresponds to the bounding box. + """ + return array[self.to_slices()] + + def to_slices(self) -> Tuple[slice, ...]: + """ + Returns a tuple of slices that corresponds to the bounding box. + """ + return tuple( + slice(topleft, topleft + size) + for topleft, size in zip(self.topleft, self.size) + ) + + def offset(self: _T, vector: VecIntLike) -> _T: + """ + Returns a new NDBoundingBox object with the specified offset. + + Args: + - vector (VecIntLike): The offset to apply to the bounding box. + + Returns: + - NDBoundingBox: A new NDBoundingBox object with the specified offset. + """ + try: + return self.with_topleft_xyz(self.topleft_xyz + Vec3Int(vector)) + except AssertionError: + return self.with_topleft(self.topleft + VecInt(vector, axes=self.axes)) diff --git a/webknossos/webknossos/geometry/vec3_int.py b/webknossos/webknossos/geometry/vec3_int.py index a38d149e6..ed1f7b6be 100644 --- a/webknossos/webknossos/geometry/vec3_int.py +++ b/webknossos/webknossos/geometry/vec3_int.py @@ -1,18 +1,19 @@ import re -from operator import add, floordiv, mod, mul, sub -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast +from typing import Iterable, Optional, Tuple, Union, cast import numpy as np -value_error = "Vector components must be three integers or a Vec3IntLike object." +from .vec_int import VecInt +_VALUE_ERROR = "Vector components must be three integers or a Vec3IntLike object." -class Vec3Int(tuple): + +class Vec3Int(VecInt): def __new__( cls, - vec: Union[int, "Vec3IntLike"], - y: Optional[int] = None, - z: Optional[int] = None, + *args: Union["Vec3IntLike", Iterable[str], int], + axes: Optional[Iterable[str]] = ("x", "y", "z"), + **kwargs: int, ) -> "Vec3Int": """ Class to represent a 3D vector. Inherits from tuple and provides useful @@ -31,54 +32,36 @@ def __new__( ``` """ - if isinstance(vec, Vec3Int): - return vec + if args: + if isinstance(args[0], Vec3Int): + return args[0] - as_tuple: Optional[Tuple[int, int, int]] = None + assert axes is not None, _VALUE_ERROR - if isinstance(vec, int): - assert y is not None and z is not None, value_error - assert isinstance(y, int) and isinstance(z, int), value_error - as_tuple = vec, y, z - else: - assert y is None and z is None, value_error - if isinstance(vec, np.ndarray): - assert np.count_nonzero(vec % 1) == 0, value_error - assert vec.shape == ( - 3, - ), "Numpy array for Vec3Int must have shape (3,)." - if isinstance(vec, Iterable): - as_tuple = cast(Tuple[int, int, int], tuple(int(item) for item in vec)) - assert len(as_tuple) == 3, value_error - assert as_tuple is not None and len(as_tuple) == 3, value_error - - return super().__new__(cls, cast(Iterable, as_tuple)) + if isinstance(args[0], Iterable): + self = super().__new__(cls, *args[0], axes=("x", "y", "z")) + assert self is not None and len(self) == 3, _VALUE_ERROR - @staticmethod - def from_xyz(x: int, y: int, z: int) -> "Vec3Int": - """Use Vec3Int.from_xyz for fast construction.""" + return cast(Vec3Int, self) - # By calling __new__ of tuple directly, we circumvent - # the tolerant (and potentially) slow Vec3Int.__new__ method. - return tuple.__new__(Vec3Int, (x, y, z)) + assert len(args) == 3 and len(tuple(axes)) == 3, _VALUE_ERROR + assert kwargs is None or len(kwargs) == 0, _VALUE_ERROR + assert "x" in axes and "y" in axes and "z" in axes, _VALUE_ERROR + values, _ = zip(*sorted(zip(args, axes), key=lambda x: x[1])) + else: + assert "x" in kwargs and "y" in kwargs and "z" in kwargs, _VALUE_ERROR + assert len(kwargs) == 3, _VALUE_ERROR + values = kwargs["x"], kwargs["y"], kwargs["z"] - @staticmethod - def from_vec3_float(vec: Tuple[float, float, float]) -> "Vec3Int": - return Vec3Int(int(vec[0]), int(vec[1]), int(vec[2])) + self = super().__new__(cls, *values, axes=("x", "y", "z")) + self.axes = ("x", "y", "z") - @staticmethod - def from_vec_or_int(vec_or_int: Union["Vec3IntLike", int]) -> "Vec3Int": - if isinstance(vec_or_int, int): - return Vec3Int.full(vec_or_int) - else: - return Vec3Int(vec_or_int) + assert self is not None and len(self) == 3, _VALUE_ERROR - @staticmethod - def from_str(string: str) -> "Vec3Int": - if re.match(r"\(\d+,\d+,\d+\)", string): - return Vec3Int(tuple(map(int, re.findall(r"\d+", string)))) - else: - return Vec3Int.full(int(string)) + return cast(Vec3Int, self) + + def __getnewargs__(self) -> Tuple[Tuple[int, ...], Tuple[str, ...]]: + return (self.to_tuple(), self.axes) @property def x(self) -> int: @@ -101,104 +84,47 @@ def with_y(self, new_y: int) -> "Vec3Int": def with_z(self, new_z: int) -> "Vec3Int": return Vec3Int.from_xyz(self.x, self.y, new_z) - def to_np(self) -> np.ndarray: - return np.array((self.x, self.y, self.z)) - - def to_list(self) -> List[int]: - return [self.x, self.y, self.z] - def to_tuple(self) -> Tuple[int, int, int]: - return self.x, self.y, self.z + return (self.x, self.y, self.z) - def contains(self, needle: int) -> bool: - return self.x == needle or self.y == needle or self.z == needle - - def is_positive(self, strictly_positive: bool = False) -> bool: - if strictly_positive: - return all(i > 0 for i in self) - else: - return all(i >= 0 for i in self) - - def is_uniform(self) -> bool: - return self.x == self.y == self.z - - def _element_wise( - self, other: Union[int, "Vec3IntLike"], fn: Callable[[int, Any], int] - ) -> "Vec3Int": - if isinstance(other, int): - other_imported = Vec3Int.from_xyz(other, other, other) - else: - other_imported = Vec3Int(other) - return Vec3Int.from_xyz( - fn(self.x, other_imported.x), - fn(self.y, other_imported.y), - fn(self.z, other_imported.z), - ) - - # note: (arguments incompatible with superclass, do not add Vec3Int to plain tuple! Hence the type:ignore) - def __add__(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": # type: ignore[override] - return self._element_wise(other, add) - - def __sub__(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return self._element_wise(other, sub) - - # Note: When multiplying regular tuples with an int those are repeated, - # which is a different behavior in the superclass! Hence the type:ignore. - def __mul__(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": # type: ignore[override] - return self._element_wise(other, mul) - - def __floordiv__(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return self._element_wise(other, floordiv) - - def __mod__(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return self._element_wise(other, mod) - - def __neg__(self) -> "Vec3Int": - return Vec3Int.from_xyz(-self.x, -self.y, -self.z) - - def ceildiv(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return (self + other - 1) // other - - def pairmax(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return self._element_wise(other, max) + @staticmethod + def from_xyz(x: int, y: int, z: int) -> "Vec3Int": + """Use Vec3Int.from_xyz for fast construction.""" - def pairmin(self, other: Union[int, "Vec3IntLike"]) -> "Vec3Int": - return self._element_wise(other, min) + # By calling __new__ of tuple directly, we circumvent + # the tolerant (and potentially) slow Vec3Int.__new__ method. + vec3int = tuple.__new__(Vec3Int, (x, y, z)) + vec3int.axes = ("x", "y", "z") + return vec3int - def prod(self) -> int: - return self.x * self.y * self.z + @staticmethod + def from_vec3_float(vec: Tuple[float, float, float]) -> "Vec3Int": + return Vec3Int(int(vec[0]), int(vec[1]), int(vec[2])) - def __repr__(self) -> str: - return f"Vec3Int({self.x},{self.y},{self.z})" + @staticmethod + def from_vec_or_int(vec_or_int: Union["Vec3IntLike", int]) -> "Vec3Int": + if isinstance(vec_or_int, int): + return Vec3Int.full(vec_or_int) - def add_or_none(self, other: Optional["Vec3Int"]) -> Optional["Vec3Int"]: - return None if other is None else self + other + return Vec3Int(vec_or_int) - def moveaxis( - self, source: Union[int, List[int]], target: Union[int, List[int]] - ) -> "Vec3Int": - """ - Allows to move one element at index `source` to another index `target`. Similar to - np.moveaxis, this is *not* a swap operation but instead it moves the specified - source so that the other elements move when necessary. - """ + @staticmethod + def from_str(string: str) -> "Vec3Int": + if re.match(r"\(\d+,\d+,\d+\)", string): + return Vec3Int(tuple(map(int, re.findall(r"\d+", string)))) - # Piggy-back on np.moveaxis by creating an auxiliary array where the indices 0, 1 and - # 2 appear in the shape. - indices = np.moveaxis(np.zeros((0, 1, 2)), source, target).shape - arr = self.to_np()[np.array(indices)] - return Vec3Int(arr) + return Vec3Int.full(int(string)) @classmethod - def zeros(cls) -> "Vec3Int": + def zeros(cls, _axes: Tuple[str, ...] = ("x", "y", "z")) -> "Vec3Int": return cls(0, 0, 0) @classmethod - def ones(cls) -> "Vec3Int": + def ones(cls, _axes: Tuple[str, ...] = ("x", "y", "z")) -> "Vec3Int": return cls(1, 1, 1) @classmethod - def full(cls, an_int: int) -> "Vec3Int": + def full(cls, an_int: int, _axes: Tuple[str, ...] = ("x", "y", "z")) -> "Vec3Int": return cls(an_int, an_int, an_int) diff --git a/webknossos/webknossos/geometry/vec_int.py b/webknossos/webknossos/geometry/vec_int.py new file mode 100644 index 000000000..0c64bc5c0 --- /dev/null +++ b/webknossos/webknossos/geometry/vec_int.py @@ -0,0 +1,336 @@ +import re +from operator import add, floordiv, mod, mul, sub +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast + +import numpy as np + +_VALUE_ERROR = "VecInt can be instantiated with int values `VecInt(1,2,3,4) or with `VecIntLike` object `VecInt([1,2,3,4])." + +_T = TypeVar("_T", bound="VecInt") + + +class VecInt(tuple): + """ + The VecInt class is designed to represent a vector of integers. This class is a subclass of the built-in tuple class, and it extends the functionality of tuples by providing additional methods and operations. + + One of the key features of the VecInt class is that it allows for the storage of axis names along with their corresponding values. + + Here is a brief example demonstrating how to use the VecInt class: + + ```python + from webknossos import VecInt + + # Creating a VecInt instance with 4 elements and axes x, y, z, t: + vector_1 = VecInt(1, 2, 3, 4, axes=("x", "y", "z", "t")) + # Alternative ways to create the same VecInt instance: + vector_1 = VecInt([1, 2, 3, 4], axes=("x", "y", "z", "t")) + vector_1 = VecInt(x=1, y=2, z=3, t=4) + + # Creating a VecInt instance with all elements set to 1 and axes x, y, z, t: + vector_2 = VecInt.full(1, axes=("x", "y", "z", "t")) + # Asserting that all elements in vector_2 are equal to 1: + assert vector_2[0] == vector_2[1] == vector_2[2] == vector_2[3] + + # Demonstrating the addition operation between two VecInt instances: + assert vector_1 + vector_2 == VecInt(2, 3, 4, 5) + ``` + """ + + axes: Tuple[str, ...] + _x_pos: Optional[int] + _y_pos: Optional[int] + _z_pos: Optional[int] + + def __new__( + cls, + *args: Union["VecIntLike", Iterable[str], int], + axes: Optional[Iterable[str]] = None, + **kwargs: int, + ) -> "VecInt": + as_tuple: Optional[Tuple[int, ...]] = None + + if args: + if isinstance(args[0], VecInt): + return args[0] + if isinstance(args[0], np.ndarray): + assert np.count_nonzero(args[0] % 1) == 0, _VALUE_ERROR + if isinstance(args[0], str): + return cls.from_str(args[0]) + if isinstance(args[0], Iterable): + as_tuple = tuple(int(item) for item in args[0]) + if args[1:] and isinstance(args[1], Iterable): + assert all(isinstance(arg, str) for arg in args[1]), _VALUE_ERROR + axes = tuple(args[1]) # type: ignore + elif isinstance(args, Iterable): + as_tuple = tuple(int(arg) for arg in args) # type: ignore + else: + raise ValueError(_VALUE_ERROR) + assert axes is not None, _VALUE_ERROR + else: + assert kwargs, _VALUE_ERROR + assert axes is None, _VALUE_ERROR + as_tuple = tuple(kwargs.values()) + + assert as_tuple is not None, _VALUE_ERROR + + self = super().__new__(cls, cast(Iterable, as_tuple)) + # self.axes is set in __new__ instead of __init__ so that pickling/unpickling + # works without problems. As long as the deserialization of a tree instance + # is not finished, the object is only half-initialized. Since self.axes + # is needed after deepcopy, an error would be raised otherwise. + # Also see: + # https://stackoverflow.com/questions/46283738/attributeerror-when-using-python-deepcopy + self.axes = tuple(axes or kwargs.keys()) + self._x_pos = self.axes.index("x") if "x" in self.axes else None + self._y_pos = self.axes.index("y") if "y" in self.axes else None + self._z_pos = self.axes.index("z") if "z" in self.axes else None + + return self + + def __getnewargs__(self) -> Tuple[Tuple[int, ...], Tuple[str, ...]]: + return (self.to_tuple(), self.axes) + + @property + def x(self) -> int: + """ + Returns the x component of the vector. + """ + if self._x_pos is not None: + return self[self._x_pos] + + raise ValueError("The vector does not have an x component.") + + @property + def y(self) -> int: + """ + Returns the y component of the vector. + """ + if self._y_pos is not None: + return self[self._y_pos] + + raise ValueError("The vector does not have an y component.") + + @property + def z(self) -> int: + """ + Returns the z component of the vector. + """ + if self._z_pos is not None: + return self[self._z_pos] + + raise ValueError("The vector does not have an z component.") + + @staticmethod + def from_str(string: str) -> "VecInt": + """ + Returns a new ND Vector from a string representation. + + Args: + - string (str): The string representation of the vector. + + Returns: + - VecInt: The new vector. + """ + return VecInt(tuple(map(int, re.findall(r"\d+", string)))) + + def with_replaced(self: _T, index: int, new_element: int) -> _T: + """Returns a new ND Vector with a replaced element at a given index.""" + + return self.__class__( + *self[:index], new_element, *self[index + 1 :], axes=self.axes + ) + + def to_np(self) -> np.ndarray: + """ + Returns the vector as a numpy array. + """ + return np.array(self) + + def to_list(self) -> List[int]: + """ + Returns the vector as a list. + """ + return list(self) + + def to_tuple(self) -> Tuple[int, ...]: + """ + Returns the vector as a tuple. + """ + return tuple(self) + + def contains(self, needle: int) -> bool: + """ + Checks if the vector contains a given element. + """ + return any(element == needle for element in self) + + def is_positive(self, strictly_positive: bool = False) -> bool: + """ + Checks if all elements in the vector are positive. + + Args: + - strictly_positive (bool): If True, checks if all elements are strictly positive. + + Returns: + - bool: True if all elements are positive, False otherwise. + """ + if strictly_positive: + return all(i > 0 for i in self) + + return all(i >= 0 for i in self) + + def is_uniform(self) -> bool: + """ + Checks if all elements in the vector are the same. + """ + first = self[0] + return all(element == first for element in self) + + def _element_wise( + self: _T, other: Union[int, "VecIntLike"], fn: Callable[[int, Any], int] + ) -> _T: + if isinstance(other, int): + other_imported = VecInt.full(other, axes=self.axes) + else: + other_imported = VecInt(other, axes=self.axes) + assert len(other_imported) == len( + self + ), f"{other} and {self} are not equally shaped." + return self.__class__( + **{ + axis: fn(self[i], other_imported[i]) for i, axis in enumerate(self.axes) + }, + axes=None, + ) + + # Note: When adding regular tuples the first tuple is extended with the second tuple. + # For VecInt we want to add the elements at the same index. + # Do not add VecInt to plain tuple! Hence the type:ignore) + def __add__(self: _T, other: Union[int, "VecIntLike"]) -> _T: # type: ignore[override] + return self._element_wise(other, add) + + def __sub__(self: _T, other: Union[int, "VecIntLike"]) -> _T: + return self._element_wise(other, sub) + + # Note: When multiplying regular tuples with an int those are repeated, + # which is a different behavior in the superclass! Hence the type:ignore. + def __mul__(self: _T, other: Union[int, "VecIntLike"]) -> _T: # type: ignore[override] + return self._element_wise(other, mul) + + def __floordiv__(self: _T, other: Union[int, "VecIntLike"]) -> _T: + return self._element_wise(other, floordiv) + + def __mod__(self: _T, other: Union[int, "VecIntLike"]) -> _T: + return self._element_wise(other, mod) + + def __neg__(self: _T) -> _T: + return self.__class__((-elem for elem in self), axes=self.axes) + + def ceildiv(self: _T, other: Union[int, "VecIntLike"]) -> _T: + """ + Returns a new VecInt with the ceil division of each element by the other. + """ + return (self + other - 1) // other + + def pairmax(self: _T, other: Union[int, "VecIntLike"]) -> _T: + """ + Returns a new VecInt with the maximum of each pair of elements from the two vectors. + """ + return self._element_wise(other, max) + + def pairmin(self: _T, other: Union[int, "VecIntLike"]) -> _T: + """ + Returns a new VecInt with the minimum of each pair of elements from the two vectors. + """ + return self._element_wise(other, min) + + def prod(self) -> int: + """ + Returns the product of all elements in the vector. + """ + return int(np.prod(self.to_np())) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({','.join((str(element) for element in self))})" + ) + + def add_or_none(self: _T, other: Optional["VecInt"]) -> Optional[_T]: + """ + Adds two VecInts or returns None if the other is None. + + Args: + - other (Optional[VecInt]): The other vector to add. + + Returns: + - Optional[VecInt]: The sum of the two vectors or None if the other is None. + """ + return None if other is None else self + other + + def moveaxis( + self: _T, source: Union[int, List[int]], target: Union[int, List[int]] + ) -> _T: + """ + Allows to move one element at index `source` to another index `target`. Similar to + np.moveaxis, this is *not* a swap operation but instead it moves the specified + source so that the other elements move when necessary. + + Args: + - source (Union[int, List[int]]): The index of the element to move. + - target (Union[int, List[int]]): The index where the element should be moved to. + + Returns: + - VecInt: A new vector with the moved element. + """ + + # Piggy-back on np.moveaxis by creating an auxiliary array where the indices 0, 1 and + # 2 appear in the shape. + indices = np.moveaxis( + np.zeros(tuple(i for i in range(len(self)))), source, target + ).shape + arr = self.to_np()[np.array(indices)] + axes = np.array(self.axes)[np.array(indices)] + return self.__class__(arr, axes=axes) + + @classmethod + def zeros(cls, axes: Tuple[str, ...]) -> "VecInt": + """ + Returns a new ND Vector with all elements set to 0. + + Args: + - axes (Tuple[str, ...]): The axes of the vector. + + Returns: + - VecInt: The new vector. + """ + return cls((0 for _ in range(len(axes))), axes=axes) + + @classmethod + def ones(cls, axes: Tuple[str, ...]) -> "VecInt": + """ + Returns a new ND Vector with all elements set to 1. + + Args: + - axes (Tuple[str, ...]): The axes of the vector. + + Returns: + - VecInt: The new vector. + """ + return cls((1 for _ in range(len(axes))), axes=axes) + + @classmethod + def full(cls, an_int: int, axes: Tuple[str, ...]) -> "VecInt": + """ + Returns a new ND Vector with all elements set to the same value. + + Args: + - an_int (int): The value to set all elements to. + - axes (Tuple[str, ...]): The axes of the vector. + + Returns: + - VecInt: The new vector. + """ + return cls((an_int for _ in range(len(axes))), axes=axes) + + +VecIntLike = Union[VecInt, Tuple[int, ...], np.ndarray, Iterable[int]]