Zarr: Optimize region="auto" detection #8997

Merged
merged 4 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
4 changes: 2 additions & 2 deletions doc/user-guide/io.rst
@@ -874,7 +874,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata
# The values of this dask array are entirely irrelevant; only the dtype,
# shape and chunks are used
dummies = dask.array.zeros(30, chunks=10)
-ds = xr.Dataset({"foo": ("x", dummies)})
+ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
path = "path/to/directory.zarr"
# Now we write the metadata without computing any array values
ds.to_zarr(path, compute=False)
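
As a quick sanity check (a hypothetical follow-up step, not part of this diff), the store can be opened immediately after the metadata-only write; the shape, dtype, and chunk layout are all in place even though no values were computed:

```python
import xarray as xr

# "path/to/directory.zarr" is the illustrative path from the docs above.
ds_meta = xr.open_zarr("path/to/directory.zarr")
print(ds_meta["foo"].shape)   # (30,)
print(ds_meta["foo"].chunks)  # ((10, 10, 10),)
```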
@@ -890,7 +890,7 @@ where the data should be written (in index space, not label space), e.g.,

# For convenience, we'll slice a single dataset, but in the real use-case
# we would create them separately possibly even from separate processes.
-ds = xr.Dataset({"foo": ("x", np.arange(30))})
+ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
# Any of the following region specifications are valid
ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
@dcherian (Contributor, Author) commented on May 3, 2024: This last line does not do what it looks like it's doing if there are no indexes!
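
To unpack that remark, a sketch of the pitfall (illustrative; it relies on the unindexed-dimension fallback in the new `_auto_detect_regions` further down, which maps a dimension without a coordinate to `slice(0, size)`):

```python
import numpy as np
import xarray as xr

path = "path/to/directory.zarr"  # illustrative path from the docs above

# Without an "x" coordinate there are no labels to match against the
# store, so "auto" cannot recover where this slice came from: the
# fallback maps the unindexed dimension to slice(0, 10), i.e. rows 0..9
# of the store rather than the intended rows 10..19.
no_index = xr.Dataset({"foo": ("x", np.arange(30))}).isel(x=slice(10, 20))
# no_index.to_zarr(path, region={"x": "auto"})  # would target the wrong rows

# With the coordinate present, label matching recovers slice(10, 20) as
# intended, which is why the docs now pass coords when creating the store.
ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
```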

115 changes: 10 additions & 105 deletions xarray/backends/api.py
@@ -27,7 +27,6 @@
_normalize_path,
)
from xarray.backends.locks import _get_scheduler
-from xarray.backends.zarr import open_zarr
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
@@ -1522,92 +1521,6 @@ def save_mfdataset(
)


def _auto_detect_region(ds_new, ds_orig, dim):
# Create a mapping array of coordinates to indices on the original array
coord = ds_orig[dim]
da_map = DataArray(np.arange(coord.size), coords={dim: coord})

try:
da_idxs = da_map.sel({dim: ds_new[dim]})
except KeyError as e:
if "not all values found" in str(e):
raise KeyError(
f"Not all values of coordinate '{dim}' in the new array were"
" found in the original store. Writing to a zarr region slice"
" requires that no dimensions or metadata are changed by the write."
)
else:
raise e

if (da_idxs.diff(dim) != 1).any():
raise ValueError(
f"The auto-detected region of coordinate '{dim}' for writing new data"
" to the original store had non-contiguous indices. Writing to a zarr"
" region slice requires that the new data constitute a contiguous subset"
" of the original store."
)

dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1)

return dim_slice
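
For context, the label-to-index trick used above can be exercised standalone: an `arange` indexed by the store's coordinate, selected with the incoming labels, yields the incoming data's integer positions in the store (toy labels assumed):

```python
import numpy as np
import xarray as xr

store_x = xr.DataArray(np.arange(100, 130), dims="x")  # the store's "x" labels
da_map = xr.DataArray(np.arange(store_x.size), coords={"x": store_x}, dims="x")
idxs = da_map.sel(x=store_x[10:20])  # labels 110..119 -> positions 10..19
print(slice(int(idxs[0]), int(idxs[-1]) + 1))  # slice(10, 20)
```

The replacement in `zarr.py` below performs the same mapping with a single vectorized `pandas.Index.get_indexer` call, and against a store object that is already open rather than one re-opened through `open_zarr`.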


def _auto_detect_regions(ds, region, open_kwargs):
ds_original = open_zarr(**open_kwargs)
for key, val in region.items():
if val == "auto":
region[key] = _auto_detect_region(ds, ds_original, key)
return region


def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]:
if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")

if any(v == "auto" for v in region.values()):
if mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}"
)
region = _auto_detect_regions(ds, region, open_kwargs)

for k, v in region.items():
if k not in ds.dims:
raise ValueError(
f"all keys in ``region`` are not in Dataset dimensions, got "
f"{list(region)} and {list(ds.dims)}"
)
if not isinstance(v, slice):
raise TypeError(
"all values in ``region`` must be slice objects, got "
f"region={region}"
)
if v.step not in {1, None}:
raise ValueError(
"step on all slices in ``region`` must be 1 or None, got "
f"region={region}"
)

non_matching_vars = [
k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
]
if non_matching_vars:
raise ValueError(
f"when setting `region` explicitly in to_zarr(), all "
f"variables in the dataset to write must have at least "
f"one dimension in common with the region's dimensions "
f"{list(region.keys())}, but that is not "
f"the case for some variables here. To drop these variables "
f"from this dataset before exporting to zarr, write: "
f".drop_vars({non_matching_vars!r})"
)

return region


def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
@@ -1768,24 +1681,6 @@ def to_zarr(
# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)

if region is not None:
open_kwargs = dict(
store=store,
synchronizer=synchronizer,
group=group,
consolidated=consolidated,
storage_options=storage_options,
zarr_version=zarr_version,
)
region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs)
# can't modify indexed variables with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)

if zarr_version is None:
# default to 2 if store doesn't specify its version (e.g. a path)
zarr_version = int(getattr(store, "_store_version", 2))
@@ -1815,6 +1710,16 @@
write_empty=write_empty_chunks,
)

if region is not None:
@dcherian (Contributor, Author) commented: Moved down so we only open the Zarr store once.
zstore._validate_and_autodetect_region(dataset)
# can't modify indexed variables with region writes
dataset = dataset.drop_vars(dataset.indexes)
if append_dim is not None and append_dim in region:
raise ValueError(
f"cannot list the same dimension in both ``append_dim`` and "
f"``region`` with to_zarr(), got {append_dim} in both"
)
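
A hypothetical call that trips the guard above, reusing `ds` and `path` from the docs example:

```python
# Hypothetical misuse: "x" is named both as the append dimension and as a
# region key. Appending grows a dimension while a region write overwrites
# a fixed slice of it, so to_zarr rejects the combination.
ds.isel(x=slice(0, 10)).to_zarr(
    path, mode="a", append_dim="x", region={"x": slice(0, 10)}
)
# ValueError: cannot list the same dimension in both ``append_dim`` and
# ``region`` with to_zarr(), got x in both
```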

if mode in ["a", "a-", "r+"]:
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
96 changes: 89 additions & 7 deletions xarray/backends/zarr.py
@@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
+import pandas as pd

from xarray import coding, conventions
from xarray.backends.common import (
@@ -509,7 +510,9 @@ def ds(self):
# TODO: consider deprecating this in favor of zarr_group
return self.zarr_group

-def open_store_variable(self, name, zarr_array):
+def open_store_variable(self, name, zarr_array=None):
+    if zarr_array is None:
+        zarr_array = self.zarr_group[name]
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
@@ -623,11 +626,7 @@ def store(
# avoid needing to load index variables into memory.
# TODO: consider making loading indexes lazy again?
existing_vars, _, _ = conventions.decode_cf_variables(
-{
-    k: v
-    for k, v in self.get_variables().items()
-    if k in existing_variable_names
-},
+{k: self.open_store_variable(name=k) for k in existing_variable_names},
@dcherian (Contributor, Author) commented: Just open the needed variables instead of opening all of them.
self.get_attrs(),
)
# Modified variables must use the same encoding as the store.
@@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None
region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

-def close(self):
+def close(self) -> None:
if self._close_store_on_close:
self.zarr_group.store.close()

def _auto_detect_regions(self, ds, region):
for dim, val in region.items():
if val != "auto":
continue

if dim not in ds._variables:
# unindexed dimension
region[dim] = slice(0, ds.sizes[dim])
continue

variable = conventions.decode_cf_variable(
dim, self.open_store_variable(dim).compute()
)
assert variable.dims == (dim,)
index = pd.Index(variable.data)
idxs = index.get_indexer(ds[dim].data)
@dcherian (Contributor, Author) commented on lines +812 to +817, May 3, 2024: This is the main logic change.
if any(idxs == -1):
raise KeyError(
f"Not all values of coordinate '{dim}' in the new array were"
" found in the original store. Writing to a zarr region slice"
" requires that no dimensions or metadata are changed by the write."
)

if (np.diff(idxs) != 1).any():
raise ValueError(
f"The auto-detected region of coordinate '{dim}' for writing new data"
" to the original store had non-contiguous indices. Writing to a zarr"
" region slice requires that the new data constitute a contiguous subset"
" of the original store."
)
region[dim] = slice(idxs[0], idxs[-1] + 1)
return region
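
The core of the optimization, isolated with toy labels: `pandas.Index.get_indexer` maps every incoming label to its store position in one vectorized call and returns -1 for absent labels, which is exactly what the two error checks above look for:

```python
import numpy as np
import pandas as pd

store_index = pd.Index(np.arange(100, 130))  # coordinate held by the store
incoming = np.arange(110, 120)               # coordinate being written
idxs = store_index.get_indexer(incoming)     # array([10, 11, ..., 19])
assert not (idxs == -1).any()                # every label found in the store
assert (np.diff(idxs) == 1).all()            # contiguous, step-1 region
print(slice(int(idxs[0]), int(idxs[-1]) + 1))  # slice(10, 20)
```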

def _validate_and_autodetect_region(self, ds) -> None:
region = self._write_region

if region == "auto":
region = {dim: "auto" for dim in ds.dims}

if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
if any(v == "auto" for v in region.values()):
if self._mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
)
region = self._auto_detect_regions(ds, region)

# the auto-detection above always returns valid slices, so these
# checks are aimed at user-supplied region entries.
for k, v in region.items():
if k not in ds.dims:
raise ValueError(
f"all keys in ``region`` are not in Dataset dimensions, got "
f"{list(region)} and {list(ds.dims)}"
)
if not isinstance(v, slice):
raise TypeError(
"all values in ``region`` must be slice objects, got "
f"region={region}"
)
if v.step not in {1, None}:
raise ValueError(
"step on all slices in ``region`` must be 1 or None, got "
f"region={region}"
)

non_matching_vars = [
k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
]
if non_matching_vars:
raise ValueError(
f"when setting `region` explicitly in to_zarr(), all "
f"variables in the dataset to write must have at least "
f"one dimension in common with the region's dimensions "
f"{list(region.keys())}, but that is not "
f"the case for some variables here. To drop these variables "
f"from this dataset before exporting to zarr, write: "
f".drop_vars({non_matching_vars!r})"
)

self._write_region = region
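
For reference, a few illustrative `region` values and how the checks above treat them, assuming a dataset whose only dimension is `x` with size 30:

```python
# Illustrative outcomes against dims {"x": 30} (assumed):
{"x": slice(0, 10)}     # accepted: slice with step 1
{"x": "auto"}           # accepted when mode="r+"; resolved by _auto_detect_regions
{"x": slice(0, 10, 2)}  # ValueError: step on all slices must be 1 or None
{"y": slice(0, 10)}     # ValueError: "y" is not a dimension of the dataset
[("x", slice(0, 10))]   # TypeError: ``region`` must be a dict
```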


def open_zarr(
store,