CLI rechunking #186

Draft: wants to merge 2 commits into main

47 changes: 47 additions & 0 deletions iohub/cli/cli.py
@@ -1,6 +1,7 @@
import click

from iohub._version import __version__
from iohub.cli.rechunk import rechunking
from iohub.convert import TIFFConverter
from iohub.reader import print_info

@@ -101,3 +102,49 @@ def convert(input, output, format, scale_voxels, grid_layout, chunks):
chunks=chunks,
)
converter.run()


@cli.command()
@click.argument(
"input_zarr",
nargs=-1,
required=True,
type=_DATASET_PATH,
)
@click.option(
"--output",
"-o",
required=True,
type=click.Path(exists=False, resolve_path=True),
help="Output zarr store (/**/converted.zarr)",
)
@click.help_option("-h", "--help")
@click.option(
"--chunks",
"-c",
required=False,
type=(int, int, int),
default=None,
help="New chunksize given as (Z,Y,X) tuple argument. The ZYX chunk size will be limited to 500 MB.",
)
@click.option(
"--num-processes",
"-j",
default=1,
help="Number of simultaneous processes",
required=False,
type=int,
)
def rechunk(input_zarr, output, chunks, num_processes):
"""Rechunks OME-Zarr dataset to input chunk_size"""
rechunking(input_zarr, output, chunks, num_processes)


# if __name__ == "__main__":
# from iohub.cli.rechunk import rechunking

# input = "/hpc/projects/comp.micro/mantis/2023_08_09_HEK_PCNA_H2B/xx-mbl_course_H2B/cropped_dataset_v3_small.zarr"
# output = "./output_test.zarr"
# chunks = (1, 1, 1)
# num_processes = 4
# rechunking(input, output, chunks, num_processes)
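
As a quick sanity check of the new subcommand, here is a hypothetical invocation via click's test runner; the paths and chunk sizes below are placeholders, and the input store must already exist on disk:

from click.testing import CliRunner

from iohub.cli.cli import cli

# Rechunk to (1, 512, 512) ZYX chunks using 4 worker processes.
runner = CliRunner()
result = runner.invoke(
    cli,
    ["rechunk", "input.zarr", "-o", "rechunked.zarr", "-c", "1", "512", "512", "-j", "4"],
)
print(result.output)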
169 changes: 169 additions & 0 deletions iohub/cli/rechunk.py
@@ -0,0 +1,169 @@
import itertools
import logging
import multiprocessing as mp
from functools import partial
from glob import glob
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
from natsort import natsorted

from iohub.ngff import Plate, TransformationMeta, open_ome_zarr

MAX_CHUNK_SIZE = 500e6 # in bytes

Review comment (Member):

Hi @edyoshikun, I bumped into this because I'm having issues with chunking when copying data from Daxi.

MAX_CHUNK_SIZE should not be a hard constraint; the only real limit is zarr's 2147483647-byte chunk size. Otherwise, it decreases iohub's flexibility.
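
To make the suggestion concrete, here is a minimal sketch that treats zarr's 2147483647-byte chunk limit as the only hard cap; the helper name and defaults are illustrative, not part of this PR:

import numpy as np

ZARR_MAX_CHUNK_BYTES = 2147483647  # zarr's hard per-chunk limit, cited above

def fit_chunk_to_limit(chunk_zyx, itemsize, max_bytes=ZARR_MAX_CHUNK_BYTES):
    """Hypothetical helper: halve Z until the chunk fits under max_bytes."""
    chunk_zyx = list(chunk_zyx)
    while chunk_zyx[0] > 1 and np.prod(chunk_zyx) * itemsize > max_bytes:
        chunk_zyx[0] = int(np.ceil(chunk_zyx[0] / 2))
    return tuple(chunk_zyx)

# fit_chunk_to_limit((512, 2048, 2048), itemsize=4) -> (64, 2048, 2048);
# 128 Z-slices would be exactly 2**31 bytes, one byte over the limit.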


# Borrowed from mantis
def get_output_paths(
input_paths: list[Path], output_zarr_path: Path
) -> list[Path]:
"""Generates a mirrored output path list given an input list of positions"""
list_output_path = []
for path in input_paths:
# Select the Row/Column/FOV parts of input path
path_strings = Path(path).parts[-3:]
# Append the same Row/Column/FOV to the output zarr path
list_output_path.append(Path(output_zarr_path, *path_strings))
return list_output_path
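
# For example (hypothetical paths), an input FOV at in.zarr/A/1/0 maps to the
# mirrored output path:
#   get_output_paths([Path("in.zarr/A/1/0")], Path("out.zarr"))
#   -> [Path("out.zarr/A/1/0")]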


# Borrowed from mantis
def create_empty_zarr(
position_paths: list[Path],
output_path: Path,
    output_zyx_shape: Tuple[int, int, int],
    chunk_zyx_shape: Optional[Tuple[int, int, int]] = None,
    voxel_size: Tuple[float, float, float] = (1, 1, 1),
) -> None:
"""Create an empty zarr store mirroring another store"""
DTYPE = np.float32
bytes_per_pixel = np.dtype(DTYPE).itemsize

# Load the first position to infer dataset information
input_dataset = open_ome_zarr(str(position_paths[0]), mode="r")
T, C, Z, Y, X = input_dataset.data.shape

logging.info("Creating empty array...")

# Handle transforms and metadata
transform = TransformationMeta(
type="scale",
scale=2 * (1,) + voxel_size,
)

# Prepare output dataset
channel_names = input_dataset.channel_names

# Output shape based on the type of reconstruction
output_shape = (T, len(channel_names)) + output_zyx_shape
logging.info(f"Number of positions: {len(position_paths)}")
logging.info(f"Output shape: {output_shape}")

# Create output dataset
output_dataset = open_ome_zarr(
output_path, layout="hcs", mode="w", channel_names=channel_names
)
if chunk_zyx_shape is None:
chunk_zyx_shape = list(output_zyx_shape)
# chunk_zyx_shape[-3] > 1 ensures while loop will not stall if single
# XY image is larger than MAX_CHUNK_SIZE
while (
chunk_zyx_shape[-3] > 1
and np.prod(chunk_zyx_shape) * bytes_per_pixel > MAX_CHUNK_SIZE
):
chunk_zyx_shape[-3] = np.ceil(chunk_zyx_shape[-3] / 2).astype(int)
chunk_zyx_shape = tuple(chunk_zyx_shape)
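        # e.g. float32 with output ZYX (256, 2048, 2048) is ~4.3 GB per chunk,
        # so Z is halved to 16: (16, 2048, 2048) is ~268 MB < MAX_CHUNK_SIZE.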

chunk_size = 2 * (1,) + chunk_zyx_shape
logging.info(f"Chunk size: {chunk_size}")

    # Handle both a single position and multiple positions (from wildcard expansion)
for path in position_paths:
path_strings = Path(path).parts[-3:]
pos = output_dataset.create_position(
str(path_strings[0]), str(path_strings[1]), str(path_strings[2])
)

_ = pos.create_zeros(
name="0",
shape=output_shape,
chunks=chunk_size,
dtype=DTYPE,
transform=[transform],
)
input_dataset.close()


def copy_n_paste(
    position: Path,
    output_path: Path,
    t_idx: int,
    c_idx: int,
) -> None:
    """Load a ZYX array from the input position and write it to the output store"""
    logging.info(f"Processing c={c_idx}, t={t_idx}")

data_array = open_ome_zarr(position)
zyx_data = data_array[0][t_idx, c_idx]

with open_ome_zarr(output_path, mode="r+") as output_dataset:
output_dataset[0][t_idx, c_idx] = zyx_data

data_array.close()
logging.info(f"Finished Writing.. c={c_idx}, t={t_idx}")


def rechunking(
    input_zarr_path: Tuple[Path, ...],
    output_zarr_path: Path,
    chunk_size_zyx: Tuple[int, int, int],
    num_processes: int = 1,
):
    """
    Rechunk an OME-Zarr dataset to the given 3D (Z, Y, X) chunk size
    """
    logging.info("Starting rechunking")
    logging.info(f"{input_zarr_path} -> {output_zarr_path}, chunks={chunk_size_zyx}")
assert len(input_zarr_path) == 1

input_zarr_path = input_zarr_path[0]
output_zarr_path = Path(output_zarr_path)

# Check we are given a plate
with open_ome_zarr(input_zarr_path) as plate:
assert isinstance(plate, Plate)
# Check chunksize is 3D
chunk_size_zyx = tuple(chunk_size_zyx)
assert len(chunk_size_zyx) == 3

    # Expand with wildcards to enumerate and mirror every FOV in the input zarr
    input_zarr_path = input_zarr_path / "*" / "*" / "*"
    logging.debug(f"Position glob: {input_zarr_path}")
input_zarr_paths = natsorted(glob(str(input_zarr_path)))
input_zarr_paths = [Path(path) for path in input_zarr_paths]
output_zarr_paths = get_output_paths(input_zarr_paths, output_zarr_path)
    # Use the first FOV to infer the output shape
with open_ome_zarr(input_zarr_paths[0]) as position:
T, C, Z, Y, X = position[0].shape

# Create empty zarr
create_empty_zarr(
position_paths=input_zarr_paths,
output_path=output_zarr_path,
output_zyx_shape=(Z, Y, X),
chunk_zyx_shape=chunk_size_zyx,
)

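    # partial() pre-binds the input and output paths, so each worker receives
    # one (t_idx, c_idx) pair from itertools.product via starmap.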
for input_path, output_path in zip(input_zarr_paths, output_zarr_paths):
with mp.Pool(num_processes) as p:
p.starmap(
partial(copy_n_paste, input_path, output_path),
itertools.product(range(T), range(C)),
)