diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index f73df554aa8..e2d9812e803 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -94,6 +94,8 @@ Internal Changes
 
 - Use Python 3.6 idioms throughout the codebase. (:pull:`3419`)
   By `Maximilian Roos <https://github.com/max-sixty>`_
+- Implement :py:func:`__dask_tokenize__` for xarray objects.
+  By `Deepak Cherian <https://github.com/dcherian>`_ and `Guido Imperiale <https://github.com/crusaderky>`_.
 
 .. _whats-new.0.14.0:
 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 62890f9cefa..eb9cc128dad 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -754,6 +754,9 @@ def reset_coords(
         dataset[self.name] = self.variable
         return dataset
 
+    def __dask_tokenize__(self):
+        return (DataArray, self._variable, self._coords, self._name)
+
     def __dask_graph__(self):
         return self._to_temp_dataset().__dask_graph__()
 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 05d9772cb7a..14e336fdb99 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -649,6 +649,9 @@ def load(self, **kwargs) -> "Dataset":
 
         return self
 
+    def __dask_tokenize__(self):
+        return (Dataset, self._variables, self._coord_names, self._attrs)
+
     def __dask_graph__(self):
         graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
         graphs = {k: v for k, v in graphs.items() if v is not None}
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index a9f599069eb..866a27ce478 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -390,6 +390,9 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
+    def __dask_tokenize__(self):
+        return (Variable, self._dims, self.data, self._attrs)
+
     def __dask_graph__(self):
         if isinstance(self._data, dask_array_type):
             return self._data.__dask_graph__()
@@ -1967,6 +1970,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         if not isinstance(self._data, PandasIndexAdapter):
             self._data = PandasIndexAdapter(self._data)
 
+    def __dask_tokenize__(self):
+        return (IndexVariable, self._dims, self._data.array, self._attrs)
+
     def load(self):
         # data is already loaded into memory for IndexVariable
         return self
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index cde9faa44b7..0bbb6734a8e 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1,5 +1,6 @@
 import operator
 import pickle
+import sys
 from contextlib import suppress
 from distutils.version import LooseVersion
 from textwrap import dedent
@@ -23,11 +24,14 @@
     raises_regex,
 )
 from ..core.duck_array_ops import lazy_array_equiv
+from .test_backends import create_tmp_file
 
 dask = pytest.importorskip("dask")
 da = pytest.importorskip("dask.array")
 dd = pytest.importorskip("dask.dataframe")
 
+ON_WINDOWS = sys.platform == "win32"
+
 
 class CountingScheduler:
     """ Simple dask scheduler counting the number of computes.
@@ -1221,3 +1225,57 @@ def test_lazy_array_equiv():
         "no_conflicts",
     ]:
         xr.merge([lons1, lons2], compat=compat)
+
+
+@pytest.mark.parametrize("obj", [make_da(), make_ds()])
+@pytest.mark.parametrize(
+    "transform",
+    [
+        lambda x: x.reset_coords(),
+        lambda x: x.reset_coords(drop=True),
+        lambda x: x.isel(x=1),
+        lambda x: x.attrs.update(new_attrs=1),
+        lambda x: x.assign_coords(cxy=1),
+        lambda x: x.rename({"x": "xnew"}),
+        lambda x: x.rename({"cxy": "cxynew"}),
+    ],
+)
+def test_normalize_token_not_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
+        assert not dask.base.tokenize(obj.compute()) == dask.base.tokenize(
+            transform(obj.compute())
+        )
+
+
+@pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.compute()])
+def test_normalize_differently_when_data_changes(transform):
+    obj = transform(make_ds())
+    new = obj.copy(deep=True)
+    new["a"] *= 2
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)
+
+    obj = transform(make_da())
+    new = obj.copy(deep=True)
+    new *= 2
+    with raise_if_dask_computes():
+        assert not dask.base.tokenize(obj) == dask.base.tokenize(new)
+
+
+@pytest.mark.parametrize(
+    "transform", [lambda x: x, lambda x: x.copy(), lambda x: x.copy(deep=True)]
+)
+@pytest.mark.parametrize(
+    "obj", [make_da(), make_ds(), make_da().indexes["x"], make_ds().variables["a"]]
+)
+def test_normalize_token_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
+
+
+def test_normalize_token_netcdf_backend(map_ds):
+    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file:
+        map_ds.to_netcdf(tmp_file)
+        read = xr.open_dataset(tmp_file)
+        assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)
diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
index 73c4b9b8c74..ce4ee374f21 100644
--- a/xarray/tests/test_sparse.py
+++ b/xarray/tests/test_sparse.py
@@ -22,6 +22,7 @@
 )
 
 sparse = pytest.importorskip("sparse")
+dask = pytest.importorskip("dask")
 
 
 def assert_sparse_equal(a, b):
@@ -849,3 +850,14 @@ def test_chunk():
     dsc = ds.chunk(2)
     assert dsc.chunks == {"dim_0": (2, 2)}
     assert_identical(dsc, ds)
+
+
+def test_normalize_token():
+    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
+    a = DataArray(s)
+    dask.base.tokenize(a)
+    assert isinstance(a.data, sparse.COO)
+
+    ac = a.chunk(2)
+    dask.base.tokenize(ac)
+    assert isinstance(ac.data._meta, sparse.COO)
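
A minimal sketch (not part of the patch) of the behaviour these hooks enable,
mirroring the new tests above. It assumes nothing beyond numpy, dask, and
xarray with this change applied:

    import dask.base
    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4.0))}, coords={"x": np.arange(4)})

    # Identical objects, including deep copies, now hash to the same
    # deterministic token, without computing any lazy data ...
    assert dask.base.tokenize(ds) == dask.base.tokenize(ds.copy(deep=True))

    # ... while changing data, names, coords, or attrs changes the token.
    assert dask.base.tokenize(ds) != dask.base.tokenize(ds.rename({"x": "xnew"}))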