Add MultiZarrToZarr.append() #404

Merged
merged 9 commits
Jan 29, 2024
2 changes: 1 addition & 1 deletion docs/source/reference.rst
@@ -61,7 +61,7 @@ Combining
kerchunk.combine.drop

.. autoclass:: kerchunk.combine.MultiZarrToZarr
:members: __init__, translate
:members: __init__, append, translate

.. autofunction:: kerchunk.combine.merge_vars

12 changes: 12 additions & 0 deletions kerchunk/__init__.py
@@ -10,3 +10,15 @@
__version__ = "9999"

__all__ = ["__version__"]


def set_reference_filesystem_cachable(cachable=True):
"""While experimenting with kerchunk and referenceFS, it can be convenient to not cache FS instances

You may wish to call this function with ``False`` before any kerchunking session; leaving
the instances cachable (the default) is what end-users will want, since it will be
more efficient.
"""
import fsspec

fsspec.get_filesystem_class("reference").cachable = cachable
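
A minimal usage sketch of the helper added above (the kerchunking steps in between are placeholders):

import kerchunk

# Turn off caching of "reference" filesystem instances while iterating on a
# reference set, so repeated fsspec.filesystem("reference", ...) calls do not
# hand back a stale cached instance.
kerchunk.set_reference_filesystem_cachable(False)

# ... build and inspect reference sets here ...

# Restore the default, more efficient behaviour for end users.
kerchunk.set_reference_filesystem_cachable(True)
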
138 changes: 119 additions & 19 deletions kerchunk/combine.py
@@ -90,6 +90,9 @@ class MultiZarrToZarr:
This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
to write out parquet as the references get filled, or some other dictionary-like class
to customise how references get stored
:param append: bool
If True, will load the references specified by ``out`` and add to them rather than starting
from scratch. Assumes the same coordinates are being concatenated.
"""

def __init__(
@@ -141,8 +144,100 @@ def __init__(
self.preprocess = preprocess
self.postprocess = postprocess
self.out = out or {}
self.coos = None
self.done = set()

@classmethod
def append(
cls,
path,
original_refs,
remote_protocol=None,
remote_options=None,
target_options=None,
**kwargs,
):
"""
Update an existing combined reference set with new references

There are two main usage patterns:

- if the input ``original_refs`` is JSON, the combine happens in memory and the
output should be written to JSON. This could then be optionally converted to parquet in a
separate step
- if ``original_refs`` is a lazy parquet reference set, then it will be amended in-place

If you want to extend JSON references and output to parquet, you must first convert to
parquet in the location you would like the final product to live.

The other arguments should be the same as they were at the creation of the original combined
reference set.

NOTE: if the original combine used a postprocess function, appending may not behave
as expected, since the combine is done "before" postprocessing. Functions that only
add information (such as setting attrs) should be OK.

Parameters
----------
path: list of reference sets to add. If remote/target options would be different
to those of ``original_refs``, these can be given as dicts or LazyReferenceMapper instances
original_refs: combined reference set to be extended
remote_protocol, remote_options, target_options: referring to ``original_refs``
kwargs: passed to MultiZarrToZarr

Returns
-------
MultiZarrToZarr
"""
import xarray as xr

fs = fsspec.filesystem(
"reference",
fo=original_refs,
remote_protocol=remote_protocol,
remote_options=remote_options,
target_options=target_options,
)
ds = xr.open_dataset(
fs.get_mapper(), engine="zarr", backend_kwargs={"consolidated": False}
)
mzz = MultiZarrToZarr(
path,
out=fs.references, # dict or parquet/lazy
remote_protocol=remote_protocol,
remote_options=remote_options,
target_options=target_options,
**kwargs,
)
mzz.coos = {}
for var, selector in mzz.coo_map.items():
if selector.startswith("cf:") and "M" not in mzz.coo_dtypes.get(var, ""):
import cftime
import datetime

# undoing CF recoding in original input
mzz.coos[var] = set()
for c in ds[var].values:
value = cftime.date2num(
datetime.datetime.fromisoformat(str(c).split(".")[0]),
calendar=ds[var].attrs.get(
"calendar", ds[var].encoding.get("calendar", "standard")
),
units=ds[var].attrs.get("units", ds[var].encoding["units"]),
)
value2 = cftime.num2date(
value,
calendar=ds[var].attrs.get(
"calendar", ds[var].encoding.get("calendar", "standard")
),
units=ds[var].attrs.get("units", ds[var].encoding["units"]),
)
mzz.coos[var].add(value2)

else:
mzz.coos[var] = set(ds[var].values)
return mzz
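
As a rough sketch of the JSON usage pattern described in the docstring above (file names, protocol and combine arguments here are hypothetical):

import ujson
from kerchunk.combine import MultiZarrToZarr

# load an existing combined reference set previously written by translate()
with open("combined.json") as f:       # hypothetical path
    original = ujson.load(f)

# append a new input, repeating the arguments used for the original combine
mzz = MultiZarrToZarr.append(
    ["new_input.json"],                # hypothetical new reference set
    original,
    remote_protocol="s3",              # assumed protocol of the underlying data
    concat_dims=["time"],
    coo_dtypes={"time": "int16"},
)
out = mzz.translate()

# for the JSON pattern, write the extended reference set back out
with open("combined.json", "w") as f:
    ujson.dump(out, f)
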

@property
def fss(self):
"""filesystem instances being analysed, one per input dataset"""
@@ -232,10 +327,13 @@ def _get_value(self, index, z, var, fn=None):
units = datavar.attrs.get("units")
calendar = datavar.attrs.get("calendar", "standard")
o = cftime.num2date(o, units=units, calendar=calendar)
if self.cf_units is None:
self.cf_units = {}
if var not in self.cf_units:
self.cf_units[var] = dict(units=units, calendar=calendar)
if "M" in self.coo_dtypes.get(var, ""):
o = np.array([_.isoformat() for _ in o], dtype=self.coo_dtypes[var])
else:
if self.cf_units is None:
self.cf_units = {}
if var not in self.cf_units:
self.cf_units[var] = dict(units=units, calendar=calendar)
else:
o = selector # must be a non-number constant - error?
logger.debug("Decode: %s -> %s", (selector, index, var, fn), o)
@@ -244,7 +342,7 @@ def first_pass(self):
def first_pass(self):
"""Accumulate the set of concat coords values across all inputs"""

coos = {c: set() for c in self.coo_map}
coos = self.coos or {c: set() for c in self.coo_map}
for i, fs in enumerate(self.fss):
if self.preprocess:
self.preprocess(fs.references)
@@ -278,7 +376,6 @@ def store_coords(self):
"""
Write coordinate arrays into the output
"""
self.out.clear()
group = zarr.open(self.out)
m = self.fss[0].get_mapper("")
z = zarr.open(m)
@@ -290,12 +387,7 @@ def store_coords(self):
compression = numcodecs.Zstd() if len(v) > 100 else None
kw = {}
if self.cf_units and k in self.cf_units:
if "M" in self.coo_dtypes.get(k, ""):
# explicit time format
data = np.array(
[_.isoformat() for _ in v], dtype=self.coo_dtypes[k]
)
else:
if "M" not in self.coo_dtypes.get(k, ""):
import cftime

data = cftime.date2num(v, **self.cf_units[k]).ravel()
@@ -348,10 +440,11 @@ def store_coords(self):
def second_pass(self):
"""map every input chunk to the output"""
# TODO: this stage cannot be rerun without clearing and rerunning store_coords too,
# because some code runs dependant on the current state f self.out
# because some code runs dependent on the current state of self.out
chunk_sizes = {} #
skip = set()
dont_skip = set()
did_them = set()
no_deps = None

for i, fs in enumerate(self.fss):
@@ -422,7 +515,10 @@ def second_pass(self):
# a coordinate is any array appearing in its own or other array's _ARRAY_DIMENSIONS
skip.add(v)
for k in fs.ls(v, detail=False):
self.out[k] = fs.references[k]
if k.rsplit("/", 1)[-1].startswith(".z"):
self.out[k] = fs.cat(k)
else:
self.out[k] = fs.references[k]
continue

dont_skip.add(v) # don't check for coord or identical again
@@ -432,7 +528,8 @@ def second_pass(self):
] + coords

# create output array, accounting for shape, chunks and dim dependencies
if f"{var or v}/.zarray" not in self.out:
if (var or v) not in did_them:
did_them.add(var or v)
shape = []
ch = []
for c in coord_order:
@@ -495,7 +592,7 @@ def translate(self, filename=None, storage_options=None):
"""Perform all stages and return the resultant references dict

If filename and storage options are given, the output is written to this
file using ujson and fsspec instead of being returned.
file using ujson and fsspec.
"""
if 1 not in self.done:
self.first_pass()
Expand All @@ -507,12 +604,15 @@ def translate(self, filename=None, storage_options=None):
if self.postprocess is not None:
self.out = self.postprocess(self.out)
self.done.add(4)
out = consolidate(self.out)
if filename is None:
return out
if isinstance(self.out, dict):
out = consolidate(self.out)
else:
self.out.flush()
out = self.out
if filename is not None:
with fsspec.open(filename, mode="wt", **(storage_options or {})) as f:
ujson.dump(out, f)
return out
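
A brief sketch of the two output modes handled above (input file names and the parquet location are illustrative):

import fsspec
from fsspec.implementations.reference import LazyReferenceMapper
from kerchunk.combine import MultiZarrToZarr

# dict output (default): translate() returns the consolidated references and,
# if a filename is given, also writes them out as JSON
mzz = MultiZarrToZarr(["in1.json", "in2.json"], concat_dims=["time"])
refs = mzz.translate(filename="combined.json")

# parquet output: pass a LazyReferenceMapper as `out`; translate() flushes it
# in place and returns the mapper rather than a consolidated dict
fs = fsspec.filesystem("file")
lazy = LazyReferenceMapper.create("combined.parq", fs=fs)
mzz = MultiZarrToZarr(["in1.json", "in2.json"], concat_dims=["time"], out=lazy)
lazy = mzz.translate()
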


def _reorganise(coos):
81 changes: 81 additions & 0 deletions kerchunk/tests/test_combine.py
@@ -64,6 +64,17 @@
"memory://single2.zarr"
)

data = xr.DataArray(
data=arr,
coords={"time": np.array([3])},
dims=["time", "x", "y"],
name="data",
attrs={"attr0": 4},
)
xr.Dataset({"data": data, "static": static}, attrs={"attr1": 6}).to_zarr(
"memory://single3.zarr"
)

data = xr.DataArray(
data=np.vstack([arr] * 4),
coords={"time": np.array([1, 2, 3, 4])},
@@ -303,6 +314,76 @@ def test_single(refs):
assert (z.data[1].values == arr).all()


def test_single_append(refs):
mzz = MultiZarrToZarr(
[refs["single1"], refs["single2"]],
remote_protocol="memory",
concat_dims=["time"],
coo_dtypes={"time": "int16"},
)
out = mzz.translate()
mzz = MultiZarrToZarr.append(
[refs["single3"]],
out,
remote_protocol="memory",
concat_dims=["time"],
coo_dtypes={"time": "int16"},
)
out = mzz.translate()
z = xr.open_dataset(
"reference://",
backend_kwargs={
"storage_options": {"fo": out, "remote_protocol": "memory"},
"consolidated": False,
},
engine="zarr",
decode_cf=False,
)
assert z.data.shape == (3, 10, 10)
assert out["refs"]["data/1.0.0"] == ["memory:///single2.zarr/data/0.0.0"]
assert out["refs"]["data/2.0.0"] == ["memory:///single3.zarr/data/0.0.0"]
assert z.time.values.tolist() == [1, 2, 3]


def test_single_append_parquet(refs):
from fsspec.implementations.reference import LazyReferenceMapper

m = fsspec.filesystem("memory")
out = LazyReferenceMapper.create("memory://refs/out.parquet", fs=m)
mzz = MultiZarrToZarr(
[refs["single1"], refs["single2"]],
remote_protocol="memory",
concat_dims=["time"],
coo_dtypes={"time": "int16"},
out=out,
)
mzz.translate()

# reload here due to unknown bug after flush
out = LazyReferenceMapper("memory://refs/out.parquet", fs=m)
mzz = MultiZarrToZarr.append(
[refs["single3"]],
out,
remote_protocol="memory",
concat_dims=["time"],
coo_dtypes={"time": "int16"},
)
out = mzz.translate()

z = xr.open_dataset(
out,
backend_kwargs={
"storage_options": {"remote_protocol": "memory"},
},
engine="kerchunk",
decode_cf=False,
)
assert z.data.shape == (3, 10, 10)
assert out["data/1.0.0"] == ["memory:///single2.zarr/data/0.0.0"]
assert out["data/2.0.0"] == ["memory:///single3.zarr/data/0.0.0"]
assert z.time.values.tolist() == [1, 2, 3]


def test_lazy_filler(tmpdir, refs):
pd = pytest.importorskip("pandas")
pytest.importorskip("fastparquet")
16 changes: 16 additions & 0 deletions kerchunk/tests/test_zarr.py
@@ -3,7 +3,9 @@
import pandas as pd
import pytest
import xarray as xr
import fsspec.implementations.reference as reffs

import kerchunk.combine
import kerchunk.zarr
import kerchunk.utils

Expand Down Expand Up @@ -68,3 +70,17 @@ def test_zarr_in_zip(zarr_in_zip, ds):
"reference", fo=out, remote_protocol="zip", remote_options={"fo": zarr_in_zip}
)
assert isinstance(fs.references["temp/.zarray"], (str, bytes))


def test_zarr_combine(tmpdir, ds):
fn1 = f"{tmpdir}/test1.zarr"
ds.to_zarr(fn1)

one = kerchunk.zarr.ZarrToZarr(fn1, inline_threshold=0).translate()
fn = f"{tmpdir}/out.parq"
out = reffs.LazyReferenceMapper.create(fn)
mzz = kerchunk.combine.MultiZarrToZarr([one], concat_dims=["time"], out=out)
mzz.translate()

ds2 = xr.open_dataset(fn, engine="kerchunk")
assert ds.equals(ds2)