Skip to content

Commit

Permalink
S3 Support (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Apr 16, 2024
2 parents a7adab7 + e3afa41 commit 9117256
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
20 changes: 18 additions & 2 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lazy_property
import numpy as np
import zarr
from zarr.n5 import N5FSStore

from collections import OrderedDict
import logging
Expand Down Expand Up @@ -103,10 +104,11 @@ def __init__(self, array_config):
self.name = array_config.name
self.file_name = array_config.file_name
self.dataset = array_config.dataset

self._mode = array_config.mode
self._attributes = self.data.attrs
self._axes = array_config._axes
self.snap_to_grid = array_config.snap_to_grid


def __str__(self):
"""
Expand Down Expand Up @@ -142,6 +144,14 @@ def __repr__(self):
"""
return f"ZarrArray({self.file_name}, {self.dataset})"

@property
def mode(self):
if not hasattr(self, "_mode"):
self._mode = "a"
if self._mode not in ["r", "w", "a"]:
raise ValueError(f"Mode {self._mode} not in ['r', 'w', 'a']")
return self._mode

@property
def attrs(self):
Expand Down Expand Up @@ -358,7 +368,12 @@ def data(self) -> Any:
Notes:
This method is used to return the data of the array.
"""
zarr_container = zarr.open(str(self.file_name))
file_name = str(self.file_name)
# Zarr library does not detect the store for N5 datasets
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
else:
zarr_container = zarr.open(str(file_name), mode=self.mode)
return zarr_container[self.dataset]

def __getitem__(self, roi: Roi) -> np.ndarray:
Expand Down Expand Up @@ -406,6 +421,7 @@ def create_from_array_identifier(
num_channels,
voxel_size,
dtype,
mode = "a",
write_size=None,
name=None,
overwrite=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class ZarrArrayConfig(ArrayConfig):
_axes: Optional[List[str]] = attr.ib(
default=None, metadata={"help_text": "The axes of your data!"}
)
mode: Optional[str] = attr.ib(
default="a", metadata={"help_text": "The access mode!"}
)

def verify(self) -> Tuple[bool, str]:
"""
Expand Down
18 changes: 13 additions & 5 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dacapo.experiments.tasks import TaskConfig
from pathlib import Path
from upath import UPath as Path
from typing import List
from enum import Enum, EnumMeta
from funlib.geometry import Coordinate
from typing import Union, Optional

import zarr
from zarr.n5 import N5FSStore
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
ZarrArray,
Expand All @@ -21,7 +22,7 @@
logger = logging.getLogger(__name__)


def is_zarr_group(file_name: str, dataset: str):
def is_zarr_group(file_name: Path, dataset: str):
"""
Check if the dataset is a Zarr group. If the dataset is a Zarr group, it will return True, otherwise False.
Expand All @@ -40,7 +41,10 @@ def is_zarr_group(file_name: str, dataset: str):
Notes:
This function is used to check if the dataset is a Zarr group.
"""
zarr_file = zarr.open(str(file_name))
if file_name.suffix == ".n5":
zarr_file = zarr.open(N5FSStore(str(file_name)), mode="r")
else:
zarr_file = zarr.open(str(file_name), mode="r")
return isinstance(zarr_file[dataset], zarr.hierarchy.Group)


Expand Down Expand Up @@ -121,6 +125,7 @@ def get_right_resolution_array_config(
file_name=container,
dataset=str(current_dataset_path),
snap_to_grid=target_resolution,
mode = "r"
)
zarr_array = ZarrArray(zarr_config)
while (
Expand All @@ -133,6 +138,7 @@ def get_right_resolution_array_config(
file_name=container,
dataset=str(Path(dataset, f"s{level}")),
snap_to_grid=target_resolution,
mode = "r"
)

zarr_array = ZarrArray(zarr_config)
Expand Down Expand Up @@ -762,7 +768,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
# f"Processing raw_container:{raw_container} raw_dataset:{raw_dataset} gt_path:{gt_path} gt_dataset:{gt_dataset}"
# )

if is_zarr_group(str(raw_container), raw_dataset):
if is_zarr_group(raw_container, raw_dataset):
raw_config = get_right_resolution_array_config(
raw_container, raw_dataset, self.input_resolution, "raw"
)
Expand All @@ -772,6 +778,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
name=f"raw_{raw_container.stem}_uint8",
file_name=raw_container,
dataset=raw_dataset,
mode="r",
),
self.input_resolution,
"raw",
Expand All @@ -789,7 +796,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
raise FileNotFoundError(
f"GT path {gt_path/current_class_dataset} does not exist."
)
if is_zarr_group(str(gt_path), current_class_dataset):
if is_zarr_group(gt_path, current_class_dataset):
gt_config = get_right_resolution_array_config(
gt_path, current_class_dataset, self.output_resolution, "gt"
)
Expand All @@ -799,6 +806,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
name=f"gt_{gt_path.stem}_{current_class_dataset}_uint8",
file_name=gt_path,
dataset=current_class_dataset,
mode="r",
),
self.output_resolution,
"gt",
Expand Down

0 comments on commit 9117256

Please sign in to comment.