Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write scale less than 5D #114

Merged
merged 10 commits into from
Oct 12, 2021
63 changes: 47 additions & 16 deletions ome_zarr/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for generating synthetic data."""
from random import randrange
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union

import numpy as np
import zarr
Expand Down Expand Up @@ -32,23 +32,26 @@ def coins() -> Tuple[List, List]:
pyramid = list(reversed([zoom(image, 2 ** i, order=3) for i in range(4)]))
labels = list(reversed([zoom(label_image, 2 ** i, order=0) for i in range(4)]))

pyramid = [rgb_to_5d(layer) for layer in pyramid]
labels = [rgb_to_5d(layer) for layer in labels]
return pyramid, labels


def astronaut() -> Tuple[List, List]:
"""Sample data from skimage."""
scaler = Scaler()

pixels = rgb_to_5d(np.tile(data.astronaut(), (2, 2, 1)))
astro = data.astronaut()
red = astro[:, :, 0]
green = astro[:, :, 1]
blue = astro[:, :, 2]
astro = np.array([red, green, blue])
pixels = np.tile(astro, (1, 2, 2))
pyramid = scaler.nearest(pixels)

shape = list(pyramid[0].shape)
shape[1] = 1
label = np.zeros(shape)
make_circle(100, 100, 1, label[0, 0, 0, 200:300, 200:300])
make_circle(150, 150, 2, label[0, 0, 0, 250:400, 250:400])
c, y, x = shape
label = np.zeros((y, x), dtype=np.int8)
make_circle(100, 100, 1, label[200:300, 200:300])
make_circle(150, 150, 2, label[250:400, 250:400])
labels = scaler.nearest(label)

return pyramid, labels
Expand Down Expand Up @@ -100,38 +103,63 @@ def create_zarr(
method: Callable[..., Tuple[List, List]] = coins,
label_name: str = "coins",
fmt: Format = CurrentFormat(),
chunks: Union[Tuple, List] = None,
) -> None:
"""Generate a synthetic image pyramid with labels."""
pyramid, labels = method()

loc = parse_url(zarr_directory, mode="w")
assert loc
grp = zarr.group(loc.store)
write_multiscale(pyramid, grp)

if pyramid[0].shape[CHANNEL_DIMENSION] == 1:
axes = None
size_c = 1
if fmt.version not in ("0.1", "0.2"):
if pyramid[0].ndim == 3:
axes = "cyx"
size_c = 3
else:
axes = "tczyx"[-pyramid[0].ndim :]
size_c = 1
else:
# v0.1 and v0.2 must be 5D
pyramid = [rgb_to_5d(layer) for layer in pyramid]
if labels:
labels = [rgb_to_5d(layer) for layer in labels]
size_c = pyramid[0].shape[CHANNEL_DIMENSION]

if chunks is None:
# Use smallest pyramid as chunk size...
chunks = list(pyramid[-1].shape)
# setting any z, c, t sizes to 1
for zct in range(3):
if zct + 2 < len(chunks):
chunks[zct] = 1

write_multiscale(pyramid, grp, chunks=tuple(chunks), axes=axes)

if size_c == 1:
image_data = {
"channels": [{"window": {"start": 0, "end": 1}}],
"channels": [{"window": {"start": 0, "end": 255}, "color": "FF0000"}],
"rdefs": {"model": "greyscale"},
}
else:
image_data = {
"channels": [
{
"color": "FF0000",
"window": {"start": 0, "end": 1},
"window": {"start": 0, "end": 255},
"label": "Red",
"active": True,
},
{
"color": "00FF00",
"window": {"start": 0, "end": 1},
"window": {"start": 0, "end": 255},
"label": "Green",
"active": True,
},
{
"color": "0000FF",
"window": {"start": 0, "end": 1},
"window": {"start": 0, "end": 255},
"label": "Blue",
"active": True,
},
Expand All @@ -146,7 +174,10 @@ def create_zarr(
labels_grp.attrs["labels"] = [label_name]

label_grp = labels_grp.create_group(label_name)
write_multiscale(labels, label_grp)
if axes is not None:
# remove channel axis for masks
axes = axes.replace("c", "")
write_multiscale(labels, label_grp, axes=axes)

colors = []
properties = []
Expand Down
46 changes: 28 additions & 18 deletions ome_zarr/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,10 @@ def laplacian(self, base: np.ndarray) -> List[np.ndarray]:
def local_mean(self, base: np.ndarray) -> List[np.ndarray]:
"""Downsample using :func:`skimage.transform.downscale_local_mean`."""
rv = [base]
# FIXME: fix hard-coding
rv = [base]
stack_dims = base.ndim - 2
factors = (*(1,) * stack_dims, *(self.downscale, self.downscale))
for i in range(self.max_layer):
rv.append(
downscale_local_mean(
rv[-1], factors=(1, 1, 1, self.downscale, self.downscale)
)
)
rv.append(downscale_local_mean(rv[-1], factors=factors))
return rv

def zoom(self, base: np.ndarray) -> List[np.ndarray]:
Expand All @@ -198,23 +194,37 @@ def _by_plane(
func: Callable[[np.ndarray, int, int], np.ndarray],
) -> np.ndarray:
"""Loop over 3 of the 5 dimensions and apply the func transform."""
assert 5 == len(base.shape)

rv = [base]
for i in range(self.max_layer):
fiveD = rv[-1]
# FIXME: fix hard-coding of dimensions
T, C, Z, Y, X = fiveD.shape
stack_to_scale = rv[-1]
shape_5d = (*(1,) * (5 - stack_to_scale.ndim), *stack_to_scale.shape)
T, C, Z, Y, X = shape_5d

# If our data is already 2D, simply resize and add to pyramid
if stack_to_scale.ndim == 2:
rv.append(func(stack_to_scale, Y, X))
continue

smaller = None
# stack_dims is any dims over 2D
stack_dims = stack_to_scale.ndim - 2
new_stack = None
for t in range(T):
for c in range(C):
for z in range(Z):
out = func(fiveD[t][c][z][:], Y, X)
if smaller is None:
smaller = np.zeros(
(T, C, Z, out.shape[0], out.shape[1]), dtype=base.dtype
dims_to_slice = (t, c, z)[-stack_dims:]
# slice nd down to 2D
plane = stack_to_scale[(dims_to_slice)][:]
out = func(plane, Y, X)
# first iteration of loop creates the new nd stack
if new_stack is None:
zct_dims = shape_5d[:-2]
shape_dims = zct_dims[-stack_dims:]
new_stack = np.zeros(
(*shape_dims, out.shape[0], out.shape[1]),
dtype=base.dtype,
)
smaller[t][c][z] = out
rv.append(smaller)
# insert resized plane into the stack at correct indices
new_stack[(dims_to_slice)] = out
rv.append(new_stack)
return rv
55 changes: 49 additions & 6 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,57 @@ def write_multiscale(
group: zarr.Group,
chunks: Union[Tuple[Any, ...], int] = None,
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly not an action for this PR but a general thought: as the specification gets refined and new concepts get introduced (thinking concretely of the ongoing transformation proposal), there might be a trade-off between adding every new key as an extra parameter vs. e.g. passing some form of dictionary of extra metadata which will be validated depending on the specification.

) -> None:
"""
Write a pyramid with multiscale metadata to disk.

Parameters
----------
TODO:
pyramid: List of np.ndarray
the image data to save. Largest level first
group: zarr.Group
the group within the zarr store to store the data in
chunks: int or tuple of ints,
size of the saved chunks to store the image
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
"""

dims = len(pyramid[0].shape)
if fmt.version not in ("0.1", "0.2"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the axes guessing below definitely applies to 0.3, if the assumptions get relaxed further, there will be an outstanding TODO of restricting this to the relevant versions of the specification.

if axes is None:
if dims == 2:
axes = ["y", "x"]
elif dims == 5:
axes = ["t", "c", "z", "y", "x"]
else:
raise ValueError(
"axes must be provided. Can't be guessed for 3D or 4D data"
)
if len(axes) != dims:
raise ValueError("axes length must match number of dimensions")

if isinstance(axes, str):
axes = list(axes)

for dim in axes:
if dim not in ("t", "c", "z", "y", "x"):
raise ValueError("axes must each be one of 'x', 'y', 'z', 'c' or 't'")

paths = []
for path, dataset in enumerate(pyramid):
# TODO: chunks here could be different per layer
group.create_dataset(str(path), data=dataset, chunks=chunks)
paths.append({"path": str(path)})

multiscales = [{"version": fmt.version, "datasets": paths}]
if axes is not None:
multiscales[0]["axes"] = axes
group.attrs["multiscales"] = multiscales


Expand All @@ -45,6 +80,7 @@ def write_image(
byte_order: Union[str, List[str]] = "tczyx",
scaler: Scaler = Scaler(),
fmt: Format = CurrentFormat(),
axes: Union[str, List[str]] = None,
**metadata: JSONDict,
) -> None:
"""Writes an image to the zarr store according to ome-zarr specification
Expand All @@ -67,24 +103,29 @@ def write_image(
fmt: Format
The format of the ome_zarr data which should be used.
Defaults to the most current.
axes: str or list of str
the names of the axes. e.g. "tczyx". Not needed for v0.1 or v0.2
or for v0.3 if 2D or 5D. Otherwise this must be provided
"""

if image.ndim > 5:
raise ValueError("Only images of 5D or less are supported")

shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape)
image = image.reshape(shape_5d)
if fmt.version in ("0.1", "0.2"):
# v0.1 and v0.2 are strictly 5D
shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape)
image = image.reshape(shape_5d)

if chunks is not None:
chunks = _retuple(chunks, shape_5d)
chunks = _retuple(chunks, image.shape)

if scaler is not None:
image = scaler.nearest(image)
else:
LOGGER.debug("disabling pyramid")
image = [image]

write_multiscale(image, group, chunks=chunks, fmt=fmt)
write_multiscale(image, group, chunks=chunks, fmt=fmt, axes=axes)
group.attrs.update(metadata)


Expand All @@ -98,4 +139,6 @@ def _retuple(
else:
_chunks = chunks

return (*shape[: (5 - len(_chunks))], *_chunks)
dims_to_add = len(shape) - len(_chunks)

return (*shape[:dims_to_add], *_chunks)
10 changes: 5 additions & 5 deletions tests/test_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from ome_zarr.utils import download, info


def log_strings(idx, t, c, z, y, x, ct, cc, cz, cy, cx, dtype):
def log_strings(idx, c, y, x, cc, cy, cx, dtype):
yield f"resolution: {idx}"
yield f" - shape ('t', 'c', 'z', 'y', 'x') = ({t}, {c}, {z}, {y}, {x})"
yield f" - chunks = ['{ct}', '{cc}', '{cz}', '{cx}', '{cy}']"
yield f" - shape ('c', 'y', 'x') = ({c}, {y}, {x})"
yield f" - chunks = ['{cc}', '{cx}', '{cy}']"
yield f" - dtype = {dtype}"


Expand All @@ -21,9 +21,9 @@ def initdir(self, tmpdir):
create_zarr(str(self.path), method=astronaut)

def check_info_stdout(self, out):
for log in log_strings(0, 1, 3, 1, 1024, 1024, 1, 1, 1, 256, 256, "float64"):
for log in log_strings(0, 3, 1024, 1024, 1, 64, 64, "uint8"):
assert log in out
for log in log_strings(1, 1, 3, 1, 512, 512, 1, 1, 1, 256, 256, "float64"):
for log in log_strings(1, 3, 512, 512, 1, 64, 64, "int8"):
assert log in out

# from info's print of omero metadata
Expand Down
9 changes: 8 additions & 1 deletion tests/test_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@


class TestScaler:
@pytest.fixture(params=((1, 2, 1, 256, 256),))
@pytest.fixture(
params=(
(1, 2, 1, 256, 256),
(3, 512, 512),
(256, 256),
),
ids=["5D", "3D", "2D"],
)
def shape(self, request):
return request.param

Expand Down
5 changes: 4 additions & 1 deletion tests/test_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def test_pre_created(self, request, path, version):
def test_newly_created(self, version):
shape = (1, 1, 1, 8, 8)
data = self.create_data(shape, version)
write_image(image=data, group=self.group, scaler=None, fmt=version)
axes = None
if version not in ("0.1", "0.2"):
axes = "tczyx"
write_image(image=data, group=self.group, scaler=None, fmt=version, axes=axes)
self.assert_data(f"{self.path}/test", shape=shape, fmt=version)

def test_requested_no_upgrade(self):
Expand Down
21 changes: 17 additions & 4 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ def create_data(self, shape, dtype=np.uint8, mean_val=10):
rng = np.random.default_rng(0)
return rng.poisson(mean_val, size=shape).astype(dtype)

@pytest.fixture(params=((1, 2, 1, 256, 256),))
@pytest.fixture(
params=(
(1, 2, 1, 256, 256),
(3, 512, 512),
(256, 256),
),
ids=["5D", "3D", "2D"],
)
def shape(self, request):
return request.param

Expand All @@ -49,18 +56,24 @@ def scaler(self, request):
def test_writer(self, shape, scaler, format_version):

data = self.create_data(shape)
version = format_version()
axes = "tczyx"[-len(shape) :]
write_image(
image=data,
group=self.group,
chunks=(128, 128),
scaler=scaler,
fmt=format_version(),
fmt=version,
axes=axes,
)

# Verify
reader = Reader(parse_url(f"{self.path}/test"))
node = list(reader())[0]
assert Multiscales.matches(node.zarr)
assert node.data[0].shape == shape
assert node.data[0].chunks == ((1,), (2,), (1,), (128, 128), (128, 128))
if version.version not in ("0.1", "0.2"):
# v0.1 and v0.2 MUST be 5D
assert node.data[0].shape == shape
else:
assert node.data[0].ndim == 5
assert np.allclose(data, node.data[0][...].compute())