diff --git a/.gitignore b/.gitignore
index 648cffd..30a2a24 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,8 +6,10 @@ __pycache__
 build/
 dist/
 rechunker/_version.py
+.idea/
 
 # ignore temp data created during tests and nb execution
 *.zarr
 .ipynb_checkpoints
-dask-worker-space
\ No newline at end of file
+dask-worker-space
+
diff --git a/rechunker/api.py b/rechunker/api.py
index f36b59b..ed81fa3 100644
--- a/rechunker/api.py
+++ b/rechunker/api.py
@@ -1,13 +1,19 @@
 """User-facing functions."""
+from __future__ import annotations
+
+import contextlib
 import html
 import textwrap
 from collections import defaultdict
-from typing import Union
+from typing import Iterator, Optional, Union
 
 import dask
 import dask.array
+import fsspec
 import xarray
 import zarr
+from fsspec import AbstractFileSystem
+from fsspec.implementations.local import LocalFileSystem
 from xarray.backends.zarr import (
     DIMENSION_KEY,
     encode_zarr_attr_value,
@@ -33,15 +39,17 @@ class Rechunked:
     >>> source = zarr.ones((4, 4), chunks=(2, 2), store="source.zarr")
     >>> intermediate = "intermediate.zarr"
     >>> target = "target.zarr"
-    >>> rechunked = rechunk(source, target_chunks=(4, 1), target_store=target,
-    ...                     max_mem=256000,
-    ...                     temp_store=intermediate)
-    >>> rechunked
+    >>> with api.rechunk(source,
+    ...                  target_chunks=(4, 1),
+    ...                  target_store=target,
+    ...                  max_mem=256000,
+    ...                  temp_store=intermediate) as rechunked:
+    ...     rechunked
     <Rechunked>
     * Source      : <zarr.core.Array (4, 4) float64>
     * Intermediate: dask.array<from-zarr, ... >
     * Target      : <zarr.core.Array (4, 4) float64>
-    >>> rechunked.execute()
+    ...     rechunked.execute()
     <zarr.core.Array (4, 4) float64>
     """
 
@@ -218,7 +226,7 @@ class PythonCopySpecExecutor(PythonPipelineExecutor, CopySpecToPipelinesMixin):
        raise ValueError(f"unrecognized executor {name}")
 
 
-def rechunk(
+def _unsafe_rechunk(
    source,
    target_chunks,
    max_mem,
@@ -579,3 +587,52 @@ def _setup_array_rechunk(
     int_proxy = ArrayProxy(int_array, int_chunks)
     write_proxy = ArrayProxy(target_array, write_chunks)
     return CopySpec(read_proxy, int_proxy, write_proxy)
+
+
+@contextlib.contextmanager
+def rechunk(
+    source,
+    target_chunks,
+    max_mem,
+    target_store: str,
+    target_options: Optional[dict] = None,
+    temp_store: Optional[str] = None,
+    temp_options: Optional[dict] = None,
+    executor: Union[str, CopySpecExecutor] = "dask",
+    target_filesystem: Union[str, AbstractFileSystem] = LocalFileSystem(),
+    temp_filesystem: Union[str, AbstractFileSystem] = LocalFileSystem(),
+    keep_target_store: bool = True,
+) -> Iterator[Rechunked]:
+    try:
+        target_options = target_options or {}
+        temp_options = temp_options or {}
+        if isinstance(target_filesystem, str):
+            target_filesystem = fsspec.filesystem(target_filesystem, **target_options)
+        if isinstance(temp_filesystem, str):
+            temp_filesystem = fsspec.filesystem(temp_filesystem, **temp_options)
+        if target_filesystem.exists(target_store):
+            raise FileExistsError(target_store)
+        if temp_store is not None:
+            _rm_store(temp_store, temp_filesystem)
+        yield _unsafe_rechunk(
+            source=source,
+            target_chunks=target_chunks,
+            max_mem=max_mem,
+            target_store=target_store,
+            target_options=target_options,
+            temp_store=temp_store,
+            temp_options=temp_options,
+            executor=executor,
+        )
+    finally:
+        if temp_store is not None:
+            _rm_store(temp_store, temp_filesystem)
+        if not keep_target_store:
+            _rm_store(target_store, target_filesystem)
+
+
+def _rm_store(store: str, filesystem: AbstractFileSystem):
+    try:
+        filesystem.rm(store, recursive=True, maxdepth=100)
+    except FileNotFoundError:
+        pass
diff --git a/tests/test_rechunk.py b/tests/test_rechunk.py
index 09a32f0..924804c 100644
--- a/tests/test_rechunk.py
+++ b/tests/test_rechunk.py
@@ -1,6 +1,7 @@
 import importlib
 from functools import partial
 from pathlib import Path
+from unittest.mock import MagicMock, patch
 
 import dask
 import dask.array as dsa
@@ -10,10 +11,22 @@
 import pytest
 import xarray
 import zarr
+from fsspec.implementations.local import LocalFileSystem
+from fsspec.implementations.memory import MemoryFileSystem
 
 from rechunker import api
 
 _DIMENSION_KEY = "_ARRAY_DIMENSIONS"
+TEST_DATASET = xarray.DataArray(
+    data=np.empty((10, 10)),
+    coords={"x": range(0, 10), "y": range(0, 10)},
+    dims=["x", "y"],
+    name="test_data",
+).to_dataset()
+LOCAL_FS = LocalFileSystem()
+MEM_FS = MemoryFileSystem()
+TARGET_STORE_NAME = "target_store.zarr"
+TMP_STORE_NAME = "tmp_store.zarr"
 
 
 def requires_import(module, *args):
@@ -200,7 +213,7 @@ def test_rechunk_dataset(
             _FillValue=-9999,
         )
     )
-    rechunked = api.rechunk(
+    rechunked = api._unsafe_rechunk(
         ds,
         target_chunks=target_chunks,
         max_mem=max_mem,
@@ -261,7 +274,7 @@ def test_rechunk_dataset_dimchunks(
             _FillValue=-9999,
         )
     )
-    rechunked = api.rechunk(
+    rechunked = api._unsafe_rechunk(
         ds,
         target_chunks=target_chunks,
         max_mem=max_mem,
@@ -331,7 +344,7 @@ def test_rechunk_array(
     target_store = str(tmp_path / "target.zarr")
     temp_store = str(tmp_path / "temp.zarr")
 
-    rechunked = api.rechunk(
+    rechunked = api._unsafe_rechunk(
         source_array,
         target_chunks,
         max_mem,
@@ -374,7 +387,7 @@ def test_rechunk_dask_array(
     target_store = str(tmp_path / "target.zarr")
     temp_store = str(tmp_path / "temp.zarr")
 
-    rechunked = api.rechunk(
+    rechunked = api._unsafe_rechunk(
         source_array, target_chunks, max_mem, target_store, temp_store=temp_store
     )
     assert isinstance(rechunked, api.Rechunked)
@@ -417,7 +430,7 @@ def test_rechunk_group(tmp_path, executor, source_store, target_store, temp_stor
     max_mem = 1600  # should force a two-step plan for a
     target_chunks = {"a": (5, 10, 4), "b": (20,)}
 
-    rechunked = api.rechunk(
+    rechunked = api._unsafe_rechunk(
         group,
         target_chunks,
         max_mem,
@@ -516,7 +529,7 @@ def rechunk_args(tmp_path, request):
 
 @pytest.fixture()
 def rechunked(rechunk_args):
-    return api.rechunk(**rechunk_args)
+    return api._unsafe_rechunk(**rechunk_args)
 
 
 def test_repr(rechunked):
@@ -546,13 +559,13 @@ def _wrap_options(source, options):
 
 
 def test_rechunk_option_overwrite(rechunk_args):
-    api.rechunk(**rechunk_args).execute()
+    api._unsafe_rechunk(**rechunk_args).execute()
     # TODO: make this match more reliable based on outcome of
     # https://github.com/zarr-developers/zarr-python/issues/605
     with pytest.raises(ValueError, match=r"path .* contains an array"):
-        api.rechunk(**rechunk_args).execute()
+        api._unsafe_rechunk(**rechunk_args).execute()
     options = _wrap_options(rechunk_args["source"], dict(overwrite=True))
-    api.rechunk(**rechunk_args, target_options=options).execute()
+    api._unsafe_rechunk(**rechunk_args, target_options=options).execute()
 
 
 def test_rechunk_passthrough(rechunk_args):
@@ -561,7 +574,7 @@ def test_rechunk_passthrough(rechunk_args):
         rechunk_args["target_chunks"] = {v: None for v in rechunk_args["source"]}
     else:
         rechunk_args["target_chunks"] = None
-    api.rechunk(**rechunk_args).execute()
+    api._unsafe_rechunk(**rechunk_args).execute()
 
 
 def test_rechunk_no_temp_dir_provided_error(rechunk_args):
@@ -569,7 +582,7 @@ def test_rechunk_no_temp_dir_provided_error(rechunk_args):
     # and the chunks to write differ from the chunks to read
     args = {k: v for k, v in rechunk_args.items() if k != "temp_store"}
     with pytest.raises(ValueError, match="A temporary store location must be provided"):
-        api.rechunk(**args).execute()
+        api._unsafe_rechunk(**args).execute()
 
 
 def test_rechunk_option_compression(rechunk_args):
@@ -577,7 +590,7 @@ def rechunk(compressor):
         options = _wrap_options(
             rechunk_args["source"], dict(overwrite=True, compressor=compressor)
         )
-        rechunked = api.rechunk(**rechunk_args, target_options=options)
+        rechunked = api._unsafe_rechunk(**rechunk_args, target_options=options)
         rechunked.execute()
         return sum(
             file.stat().st_size
@@ -600,14 +613,14 @@ def test_rechunk_invalid_option(rechunk_args):
             ValueError,
             match="Chunks must be provided in ``target_chunks`` rather than options",
         ):
-            api.rechunk(**rechunk_args, target_options=options)
+            api._unsafe_rechunk(**rechunk_args, target_options=options)
     else:
         for o in ["shape", "chunks", "dtype", "store", "name", "unknown"]:
             options = _wrap_options(rechunk_args["source"], {o: True})
             with pytest.raises(ValueError, match=f"Zarr options must not include {o}"):
-                api.rechunk(**rechunk_args, temp_options=options)
+                api._unsafe_rechunk(**rechunk_args, temp_options=options)
             with pytest.raises(ValueError, match=f"Zarr options must not include {o}"):
-                api.rechunk(**rechunk_args, target_options=options)
+                api._unsafe_rechunk(**rechunk_args, target_options=options)
 
 
 def test_rechunk_bad_target_chunks(rechunk_args):
@@ -618,7 +631,7 @@ def test_rechunk_bad_target_chunks(rechunk_args):
     with pytest.raises(
         ValueError, match="You must specify ``target-chunks`` as a dict"
     ):
-        api.rechunk(**rechunk_args)
+        api._unsafe_rechunk(**rechunk_args)
 
 
 def test_rechunk_invalid_source(tmp_path):
@@ -626,7 +639,7 @@ def test_rechunk_invalid_source(tmp_path):
         ValueError,
         match="Source must be a Zarr Array, Zarr Group, Dask Array or Xarray Dataset",
     ):
-        api.rechunk(
+        api._unsafe_rechunk(
             [[1, 2], [3, 4]], target_chunks=(10, 10), max_mem=100, target_store=tmp_path
         )
 
@@ -637,7 +650,7 @@ def test_rechunk_no_target_chunks(rechunk_args):
         rechunk_args["target_chunks"] = {v: None for v in rechunk_args["source"]}
     else:
         rechunk_args["target_chunks"] = None
-    api.rechunk(**rechunk_args)
+    api._unsafe_rechunk(**rechunk_args)
 
 
 def test_no_intermediate():
@@ -662,8 +675,129 @@ def test_no_intermediate_fused(tmp_path):
 
     target_store = str(tmp_path / "target.zarr")
 
-    rechunked = api.rechunk(source_array, target_chunks, max_mem, target_store)
+    rechunked = api._unsafe_rechunk(source_array, target_chunks, max_mem, target_store)
 
     # rechunked.plan is a list of dask delayed objects
     num_tasks = len([v for v in rechunked.plan[0].dask.values() if dask.core.istask(v)])
     assert num_tasks < 20  # less than if no fuse
+
+
+class Test_rechunk_context_manager:
+    def _clean(self, stores):
+        for s in stores:
+            try:
+                LOCAL_FS.rm(s, recursive=True, maxdepth=100)
+            except:
+                pass
+
+    @pytest.fixture(autouse=True)
+    def _wrap(self):
+        self._clean([TMP_STORE_NAME, TARGET_STORE_NAME])
+        with dask.config.set(scheduler="single-threaded"):
+            yield
+        self._clean([TMP_STORE_NAME, TARGET_STORE_NAME])
+
+    @patch("rechunker.api._unsafe_rechunk")
+    def test_rechunk__args_sent_as_is(self, rechunk_func: MagicMock):
+        with api.rechunk(
+            source="source",
+            target_chunks={"truc": "bidule"},
+            max_mem="42KB",
+            target_store="target_store.zarr",
+            temp_store="tmp_store.zarr",
+            target_options=None,
+            temp_options=None,
+            executor="dask",
+            target_filesystem=LOCAL_FS,
+            temp_filesystem=LOCAL_FS,
+            keep_target_store=False,
+        ):
+            rechunk_func.assert_called_with(
+                source="source",
+                target_chunks={"truc": "bidule"},
+                max_mem="42KB",
+                target_store="target_store.zarr",
+                target_options={},
+                temp_store="tmp_store.zarr",
+                temp_options={},
+                executor="dask",
+            )
+
+    def test_rechunk__remove_every_stores(self):
+        with api.rechunk(
+            source=TEST_DATASET,
+            target_chunks={"x": 2, "y": 2},
+            max_mem="42KB",
+            target_store="target_store.zarr",
+            temp_store="tmp_store.zarr",
+            target_options=None,
+            temp_options=None,
+            executor="dask",
+            target_filesystem=LOCAL_FS,
+            temp_filesystem=LOCAL_FS,
+            keep_target_store=False,
+        ) as plan:
+            plan.execute()
+            assert LOCAL_FS.exists("target_store.zarr")
+            assert LOCAL_FS.exists("tmp_store.zarr")
+        assert not LOCAL_FS.exists("tmp_store.zarr")
+        assert not LOCAL_FS.exists("target_store.zarr")
+
+    def test_rechunk__keep_target(self):
+        with api.rechunk(
+            source=TEST_DATASET,
+            target_chunks={"x": 2, "y": 2},
+            max_mem="42KB",
+            target_store="target_store.zarr",
+            temp_store="tmp_store.zarr",
+            target_options=None,
+            temp_options=None,
+            executor="dask",
+            target_filesystem=LOCAL_FS,
+            temp_filesystem=LOCAL_FS,
+            keep_target_store=True,
+        ) as plan:
+            plan.execute()
+            assert LOCAL_FS.exists("target_store.zarr")
+            assert LOCAL_FS.exists("tmp_store.zarr")
+        assert LOCAL_FS.exists("target_store.zarr")
+        assert not LOCAL_FS.exists("tmp_store.zarr")
+
+    def test_rechunk__error_target_exist(self):
+        f = LOCAL_FS.open("target_store.zarr", "x")
+        f.close()
+        with pytest.raises(FileExistsError):
+            with api.rechunk(
+                source=TEST_DATASET,
+                target_chunks={"x": 2, "y": 2},
+                max_mem="42KB",
+                target_store="target_store.zarr",
+                temp_store="tmp_store.zarr",
+                target_options=None,
+                temp_options=None,
+                executor="dask",
+                target_filesystem=LOCAL_FS,
+                temp_filesystem=LOCAL_FS,
+                keep_target_store=False,
+            ):
+                pass
+
+    def test_rechunk__memory_filesystem(self):
+        with api.rechunk(
+            source=TEST_DATASET,
+            target_chunks={"x": 2, "y": 2},
+            max_mem="42KB",
+            target_store="memory://target_store.zarr",
+            temp_store="memory://tmp_store.zarr",
+            target_options={"mode": "rw"},
+            temp_options={"mode": "rw"},
+            executor="dask",
+            target_filesystem=MEM_FS,
+            temp_filesystem=MEM_FS,
+            keep_target_store=True,
+        ) as plan:
+            plan.execute()
+            assert MEM_FS.exists("target_store.zarr")
+            assert MEM_FS.exists("tmp_store.zarr")
+        assert MEM_FS.exists("target_store.zarr")
+        assert not MEM_FS.exists("tmp_store.zarr")
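
For reference, a minimal usage sketch of the rechunk context manager added in rechunker/api.py above. It mirrors the arguments the new Test_rechunk_context_manager tests exercise; the example dataset, chunk sizes, memory budget, and store names are illustrative placeholders borrowed from those tests, not a prescribed usage.

import numpy as np
import xarray

from rechunker import api

# Toy source dataset, shaped like the TEST_DATASET constant used in the tests.
ds = xarray.DataArray(
    data=np.empty((10, 10)),
    coords={"x": range(0, 10), "y": range(0, 10)},
    dims=["x", "y"],
    name="test_data",
).to_dataset()

# The context manager refuses to overwrite an existing target store, clears any
# stale temp store before planning, and always removes the temp store on exit;
# passing keep_target_store=False removes the target as well.
with api.rechunk(
    source=ds,
    target_chunks={"x": 2, "y": 2},
    max_mem="42KB",
    target_store="target_store.zarr",
    temp_store="tmp_store.zarr",
    executor="dask",
    keep_target_store=True,
) as plan:
    plan.execute()

Executing the plan inside the with block matters: once the block exits, the temporary store is deleted, so a plan executed afterwards may no longer find its intermediate data.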