diff --git a/ome_zarr/data.py b/ome_zarr/data.py index 0b92a21a..8ae48324 100644 --- a/ome_zarr/data.py +++ b/ome_zarr/data.py @@ -1,6 +1,6 @@ """Functions for generating synthetic data.""" from random import randrange -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Union import numpy as np import zarr @@ -32,8 +32,6 @@ def coins() -> Tuple[List, List]: pyramid = list(reversed([zoom(image, 2 ** i, order=3) for i in range(4)])) labels = list(reversed([zoom(label_image, 2 ** i, order=0) for i in range(4)])) - pyramid = [rgb_to_5d(layer) for layer in pyramid] - labels = [rgb_to_5d(layer) for layer in labels] return pyramid, labels @@ -41,14 +39,19 @@ def astronaut() -> Tuple[List, List]: """Sample data from skimage.""" scaler = Scaler() - pixels = rgb_to_5d(np.tile(data.astronaut(), (2, 2, 1))) + astro = data.astronaut() + red = astro[:, :, 0] + green = astro[:, :, 1] + blue = astro[:, :, 2] + astro = np.array([red, green, blue]) + pixels = np.tile(astro, (1, 2, 2)) pyramid = scaler.nearest(pixels) shape = list(pyramid[0].shape) - shape[1] = 1 - label = np.zeros(shape) - make_circle(100, 100, 1, label[0, 0, 0, 200:300, 200:300]) - make_circle(150, 150, 2, label[0, 0, 0, 250:400, 250:400]) + c, y, x = shape + label = np.zeros((y, x), dtype=np.int8) + make_circle(100, 100, 1, label[200:300, 200:300]) + make_circle(150, 150, 2, label[250:400, 250:400]) labels = scaler.nearest(label) return pyramid, labels @@ -100,6 +103,7 @@ def create_zarr( method: Callable[..., Tuple[List, List]] = coins, label_name: str = "coins", fmt: Format = CurrentFormat(), + chunks: Union[Tuple, List] = None, ) -> None: """Generate a synthetic image pyramid with labels.""" pyramid, labels = method() @@ -107,11 +111,35 @@ def create_zarr( loc = parse_url(zarr_directory, mode="w") assert loc grp = zarr.group(loc.store) - write_multiscale(pyramid, grp) - - if pyramid[0].shape[CHANNEL_DIMENSION] == 1: + axes = None + size_c = 1 + if fmt.version not in ("0.1", "0.2"): + if pyramid[0].ndim == 3: + axes = "cyx" + size_c = 3 + else: + axes = "tczyx"[-pyramid[0].ndim :] + size_c = 1 + else: + # v0.1 and v0.2 must be 5D + pyramid = [rgb_to_5d(layer) for layer in pyramid] + if labels: + labels = [rgb_to_5d(layer) for layer in labels] + size_c = pyramid[0].shape[CHANNEL_DIMENSION] + + if chunks is None: + # Use smallest pyramid as chunk size... + chunks = list(pyramid[-1].shape) + # setting any z, c, t sizes to 1 + for zct in range(3): + if zct + 2 < len(chunks): + chunks[zct] = 1 + + write_multiscale(pyramid, grp, chunks=tuple(chunks), axes=axes) + + if size_c == 1: image_data = { - "channels": [{"window": {"start": 0, "end": 1}}], + "channels": [{"window": {"start": 0, "end": 255}, "color": "FF0000"}], "rdefs": {"model": "greyscale"}, } else: @@ -119,19 +147,19 @@ def create_zarr( "channels": [ { "color": "FF0000", - "window": {"start": 0, "end": 1}, + "window": {"start": 0, "end": 255}, "label": "Red", "active": True, }, { "color": "00FF00", - "window": {"start": 0, "end": 1}, + "window": {"start": 0, "end": 255}, "label": "Green", "active": True, }, { "color": "0000FF", - "window": {"start": 0, "end": 1}, + "window": {"start": 0, "end": 255}, "label": "Blue", "active": True, }, @@ -146,7 +174,10 @@ def create_zarr( labels_grp.attrs["labels"] = [label_name] label_grp = labels_grp.create_group(label_name) - write_multiscale(labels, label_grp) + if axes is not None: + # remove channel axis for masks + axes = axes.replace("c", "") + write_multiscale(labels, label_grp, axes=axes) colors = [] properties = [] diff --git a/ome_zarr/scale.py b/ome_zarr/scale.py index 4509a2a3..8734a497 100644 --- a/ome_zarr/scale.py +++ b/ome_zarr/scale.py @@ -168,14 +168,10 @@ def laplacian(self, base: np.ndarray) -> List[np.ndarray]: def local_mean(self, base: np.ndarray) -> List[np.ndarray]: """Downsample using :func:`skimage.transform.downscale_local_mean`.""" rv = [base] - # FIXME: fix hard-coding - rv = [base] + stack_dims = base.ndim - 2 + factors = (*(1,) * stack_dims, *(self.downscale, self.downscale)) for i in range(self.max_layer): - rv.append( - downscale_local_mean( - rv[-1], factors=(1, 1, 1, self.downscale, self.downscale) - ) - ) + rv.append(downscale_local_mean(rv[-1], factors=factors)) return rv def zoom(self, base: np.ndarray) -> List[np.ndarray]: @@ -198,23 +194,37 @@ def _by_plane( func: Callable[[np.ndarray, int, int], np.ndarray], ) -> np.ndarray: """Loop over 3 of the 5 dimensions and apply the func transform.""" - assert 5 == len(base.shape) rv = [base] for i in range(self.max_layer): - fiveD = rv[-1] - # FIXME: fix hard-coding of dimensions - T, C, Z, Y, X = fiveD.shape + stack_to_scale = rv[-1] + shape_5d = (*(1,) * (5 - stack_to_scale.ndim), *stack_to_scale.shape) + T, C, Z, Y, X = shape_5d + + # If our data is already 2D, simply resize and add to pyramid + if stack_to_scale.ndim == 2: + rv.append(func(stack_to_scale, Y, X)) + continue - smaller = None + # stack_dims is any dims over 2D + stack_dims = stack_to_scale.ndim - 2 + new_stack = None for t in range(T): for c in range(C): for z in range(Z): - out = func(fiveD[t][c][z][:], Y, X) - if smaller is None: - smaller = np.zeros( - (T, C, Z, out.shape[0], out.shape[1]), dtype=base.dtype + dims_to_slice = (t, c, z)[-stack_dims:] + # slice nd down to 2D + plane = stack_to_scale[(dims_to_slice)][:] + out = func(plane, Y, X) + # first iteration of loop creates the new nd stack + if new_stack is None: + zct_dims = shape_5d[:-2] + shape_dims = zct_dims[-stack_dims:] + new_stack = np.zeros( + (*shape_dims, out.shape[0], out.shape[1]), + dtype=base.dtype, ) - smaller[t][c][z] = out - rv.append(smaller) + # insert resized plane into the stack at correct indices + new_stack[(dims_to_slice)] = out + rv.append(new_stack) return rv diff --git a/ome_zarr/writer.py b/ome_zarr/writer.py index 8aee269d..d4e80d91 100644 --- a/ome_zarr/writer.py +++ b/ome_zarr/writer.py @@ -19,15 +19,48 @@ def write_multiscale( group: zarr.Group, chunks: Union[Tuple[Any, ...], int] = None, fmt: Format = CurrentFormat(), + axes: Union[str, List[str]] = None, ) -> None: """ Write a pyramid with multiscale metadata to disk. Parameters ---------- - TODO: + pyramid: List of np.ndarray + the image data to save. Largest level first + group: zarr.Group + the group within the zarr store to store the data in + chunks: int or tuple of ints, + size of the saved chunks to store the image + fmt: Format + The format of the ome_zarr data which should be used. + Defaults to the most current. + axes: str or list of str + the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2 + or for v0.3 if 2D or 5D. Otherwise this must be provided """ + dims = len(pyramid[0].shape) + if fmt.version not in ("0.1", "0.2"): + if axes is None: + if dims == 2: + axes = ["y", "x"] + elif dims == 5: + axes = ["t", "c", "z", "y", "x"] + else: + raise ValueError( + "axes must be provided. Can't be guessed for 3D or 4D data" + ) + if len(axes) != dims: + raise ValueError("axes length must match number of dimensions") + + if isinstance(axes, str): + axes = list(axes) + + for dim in axes: + if dim not in ("t", "c", "z", "y", "x"): + raise ValueError("axes must each be one of 'x', 'y', 'z', 'c' or 't'") + paths = [] for path, dataset in enumerate(pyramid): # TODO: chunks here could be different per layer @@ -35,6 +68,8 @@ def write_multiscale( paths.append({"path": str(path)}) multiscales = [{"version": fmt.version, "datasets": paths}] + if axes is not None: + multiscales[0]["axes"] = axes group.attrs["multiscales"] = multiscales @@ -45,6 +80,7 @@ def write_image( byte_order: Union[str, List[str]] = "tczyx", scaler: Scaler = Scaler(), fmt: Format = CurrentFormat(), + axes: Union[str, List[str]] = None, **metadata: JSONDict, ) -> None: """Writes an image to the zarr store according to ome-zarr specification @@ -67,16 +103,21 @@ def write_image( fmt: Format The format of the ome_zarr data which should be used. Defaults to the most current. + axes: str or list of str + the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2 + or for v0.3 if 2D or 5D. Otherwise this must be provided """ if image.ndim > 5: raise ValueError("Only images of 5D or less are supported") - shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape) - image = image.reshape(shape_5d) + if fmt.version in ("0.1", "0.2"): + # v0.1 and v0.2 are strictly 5D + shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape) + image = image.reshape(shape_5d) if chunks is not None: - chunks = _retuple(chunks, shape_5d) + chunks = _retuple(chunks, image.shape) if scaler is not None: image = scaler.nearest(image) @@ -84,7 +125,7 @@ def write_image( LOGGER.debug("disabling pyramid") image = [image] - write_multiscale(image, group, chunks=chunks, fmt=fmt) + write_multiscale(image, group, chunks=chunks, fmt=fmt, axes=axes) group.attrs.update(metadata) @@ -98,4 +139,6 @@ def _retuple( else: _chunks = chunks - return (*shape[: (5 - len(_chunks))], *_chunks) + dims_to_add = len(shape) - len(_chunks) + + return (*shape[:dims_to_add], *_chunks) diff --git a/tests/test_ome_zarr.py b/tests/test_ome_zarr.py index ed34d34b..9691a466 100644 --- a/tests/test_ome_zarr.py +++ b/tests/test_ome_zarr.py @@ -7,10 +7,10 @@ from ome_zarr.utils import download, info -def log_strings(idx, t, c, z, y, x, ct, cc, cz, cy, cx, dtype): +def log_strings(idx, c, y, x, cc, cy, cx, dtype): yield f"resolution: {idx}" - yield f" - shape ('t', 'c', 'z', 'y', 'x') = ({t}, {c}, {z}, {y}, {x})" - yield f" - chunks = ['{ct}', '{cc}', '{cz}', '{cx}', '{cy}']" + yield f" - shape ('c', 'y', 'x') = ({c}, {y}, {x})" + yield f" - chunks = ['{cc}', '{cx}', '{cy}']" yield f" - dtype = {dtype}" @@ -21,9 +21,9 @@ def initdir(self, tmpdir): create_zarr(str(self.path), method=astronaut) def check_info_stdout(self, out): - for log in log_strings(0, 1, 3, 1, 1024, 1024, 1, 1, 1, 256, 256, "float64"): + for log in log_strings(0, 3, 1024, 1024, 1, 64, 64, "uint8"): assert log in out - for log in log_strings(1, 1, 3, 1, 512, 512, 1, 1, 1, 256, 256, "float64"): + for log in log_strings(1, 3, 512, 512, 1, 64, 64, "int8"): assert log in out # from info's print of omero metadata diff --git a/tests/test_scaler.py b/tests/test_scaler.py index aaf280ab..3fa8a3af 100644 --- a/tests/test_scaler.py +++ b/tests/test_scaler.py @@ -5,7 +5,14 @@ class TestScaler: - @pytest.fixture(params=((1, 2, 1, 256, 256),)) + @pytest.fixture( + params=( + (1, 2, 1, 256, 256), + (3, 512, 512), + (256, 256), + ), + ids=["5D", "3D", "2D"], + ) def shape(self, request): return request.param diff --git a/tests/test_upgrade.py b/tests/test_upgrade.py index 7243e12f..cda9bfb7 100644 --- a/tests/test_upgrade.py +++ b/tests/test_upgrade.py @@ -54,7 +54,10 @@ def test_pre_created(self, request, path, version): def test_newly_created(self, version): shape = (1, 1, 1, 8, 8) data = self.create_data(shape, version) - write_image(image=data, group=self.group, scaler=None, fmt=version) + axes = None + if version not in ("0.1", "0.2"): + axes = "tczyx" + write_image(image=data, group=self.group, scaler=None, fmt=version, axes=axes) self.assert_data(f"{self.path}/test", shape=shape, fmt=version) def test_requested_no_upgrade(self): diff --git a/tests/test_writer.py b/tests/test_writer.py index b65a5e25..8d132259 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -23,7 +23,14 @@ def create_data(self, shape, dtype=np.uint8, mean_val=10): rng = np.random.default_rng(0) return rng.poisson(mean_val, size=shape).astype(dtype) - @pytest.fixture(params=((1, 2, 1, 256, 256),)) + @pytest.fixture( + params=( + (1, 2, 1, 256, 256), + (3, 512, 512), + (256, 256), + ), + ids=["5D", "3D", "2D"], + ) def shape(self, request): return request.param @@ -49,18 +56,24 @@ def scaler(self, request): def test_writer(self, shape, scaler, format_version): data = self.create_data(shape) + version = format_version() + axes = "tczyx"[-len(shape) :] write_image( image=data, group=self.group, chunks=(128, 128), scaler=scaler, - fmt=format_version(), + fmt=version, + axes=axes, ) # Verify reader = Reader(parse_url(f"{self.path}/test")) node = list(reader())[0] assert Multiscales.matches(node.zarr) - assert node.data[0].shape == shape - assert node.data[0].chunks == ((1,), (2,), (1,), (128, 128), (128, 128)) + if version.version not in ("0.1", "0.2"): + # v0.1 and v0.2 MUST be 5D + assert node.data[0].shape == shape + else: + assert node.data[0].ndim == 5 assert np.allclose(data, node.data[0][...].compute())