diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
index 2399b5b17..c976f7d96 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -9,6 +9,7 @@
 import lazy_property
 import numpy as np
 import zarr
+from zarr.n5 import N5FSStore
 
 from collections import OrderedDict
 import logging
@@ -103,10 +104,11 @@ def __init__(self, array_config):
         self.name = array_config.name
         self.file_name = array_config.file_name
         self.dataset = array_config.dataset
-
+        self._mode = array_config.mode
         self._attributes = self.data.attrs
         self._axes = array_config._axes
         self.snap_to_grid = array_config.snap_to_grid
+
 
     def __str__(self):
         """
@@ -142,6 +144,14 @@ def __repr__(self):
         """
         return f"ZarrArray({self.file_name}, {self.dataset})"
 
+
+    @property
+    def mode(self):
+        if not hasattr(self, "_mode"):
+            self._mode = "a"
+        if self._mode not in ["r", "w", "a"]:
+            raise ValueError(f"Mode {self._mode} not in ['r', 'w', 'a']")
+        return self._mode
 
     @property
     def attrs(self):
@@ -358,7 +368,12 @@ def data(self) -> Any:
         Notes:
             This method is used to return the data of the array.
         """
-        zarr_container = zarr.open(str(self.file_name))
+        file_name = str(self.file_name)
+        # Zarr library does not detect the store for N5 datasets
+        if file_name.endswith(".n5"):
+            zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
+        else:
+            zarr_container = zarr.open(str(file_name), mode=self.mode)
         return zarr_container[self.dataset]
 
     def __getitem__(self, roi: Roi) -> np.ndarray:
@@ -406,6 +421,7 @@ def create_from_array_identifier(
         num_channels,
         voxel_size,
         dtype,
+        mode="a",
         write_size=None,
         name=None,
         overwrite=False,
diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py
index af6f9dd20..b67717647 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py
@@ -55,6 +55,9 @@ class ZarrArrayConfig(ArrayConfig):
     _axes: Optional[List[str]] = attr.ib(
         default=None, metadata={"help_text": "The axes of your data!"}
     )
+    mode: Optional[str] = attr.ib(
+        default="a", metadata={"help_text": "The access mode!"}
+    )
 
     def verify(self) -> Tuple[bool, str]:
         """
diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py
index fac37ed7e..54df330b6 100644
--- a/dacapo/experiments/datasplits/datasplit_generator.py
+++ b/dacapo/experiments/datasplits/datasplit_generator.py
@@ -1,11 +1,12 @@
 from dacapo.experiments.tasks import TaskConfig
-from pathlib import Path
+from upath import UPath as Path
 from typing import List
 from enum import Enum, EnumMeta
 from funlib.geometry import Coordinate
 from typing import Union, Optional
 
 import zarr
+from zarr.n5 import N5FSStore
 from dacapo.experiments.datasplits.datasets.arrays import (
     ZarrArrayConfig,
     ZarrArray,
@@ -21,7 +22,7 @@
 logger = logging.getLogger(__name__)
 
 
-def is_zarr_group(file_name: str, dataset: str):
+def is_zarr_group(file_name: Path, dataset: str):
     """
     Check if the dataset is a Zarr group.
     If the dataset is a Zarr group, it will return True, otherwise False.
@@ -40,7 +41,10 @@
     Notes:
         This function is used to check if the dataset is a Zarr group.
     """
-    zarr_file = zarr.open(str(file_name))
+    if file_name.suffix == ".n5":
+        zarr_file = zarr.open(N5FSStore(str(file_name)), mode="r")
+    else:
+        zarr_file = zarr.open(str(file_name), mode="r")
     return isinstance(zarr_file[dataset], zarr.hierarchy.Group)
 
 
@@ -121,6 +125,7 @@ def get_right_resolution_array_config(
         file_name=container,
         dataset=str(current_dataset_path),
         snap_to_grid=target_resolution,
+        mode="r",
     )
     zarr_array = ZarrArray(zarr_config)
     while (
@@ -133,6 +138,7 @@ def get_right_resolution_array_config(
             file_name=container,
             dataset=str(Path(dataset, f"s{level}")),
             snap_to_grid=target_resolution,
+            mode="r",
         )
         zarr_array = ZarrArray(zarr_config)
 
@@ -762,7 +768,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
         #     f"Processing raw_container:{raw_container} raw_dataset:{raw_dataset} gt_path:{gt_path} gt_dataset:{gt_dataset}"
         # )
 
-        if is_zarr_group(str(raw_container), raw_dataset):
+        if is_zarr_group(raw_container, raw_dataset):
             raw_config = get_right_resolution_array_config(
                 raw_container, raw_dataset, self.input_resolution, "raw"
             )
@@ -772,6 +778,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
                     name=f"raw_{raw_container.stem}_uint8",
                     file_name=raw_container,
                     dataset=raw_dataset,
+                    mode="r",
                 ),
                 self.input_resolution,
                 "raw",
@@ -789,7 +796,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
                 raise FileNotFoundError(
                     f"GT path {gt_path/current_class_dataset} does not exist."
                 )
-            if is_zarr_group(str(gt_path), current_class_dataset):
+            if is_zarr_group(gt_path, current_class_dataset):
                 gt_config = get_right_resolution_array_config(
                     gt_path, current_class_dataset, self.output_resolution, "gt"
                 )
@@ -799,6 +806,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
                     name=f"gt_{gt_path.stem}_{current_class_dataset}_uint8",
                     file_name=gt_path,
                     dataset=current_class_dataset,
+                    mode="r",
                 ),
                 self.output_resolution,
                 "gt",