diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 09ef053bb39..c7ea19a53cb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,11 +5,3 @@ - [ ] Passes `pre-commit run --all-files` - [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst` - [ ] New functions/methods are listed in `api.rst` - - - -

- Overriding CI behaviors -

- By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a [skip-ci] tag to the first line of the commit message -
diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py new file mode 100644 index 00000000000..b218c0be870 --- /dev/null +++ b/asv_bench/benchmarks/repr.py @@ -0,0 +1,18 @@ +import pandas as pd + +import xarray as xr + + +class ReprMultiIndex: + def setup(self, key): + index = pd.MultiIndex.from_product( + [range(10000), range(10000)], names=("level_0", "level_1") + ) + series = pd.Series(range(100000000), index=index) + self.da = xr.DataArray(series) + + def time_repr(self): + repr(self.da) + + def time_repr_html(self): + self.da._repr_html_() diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py index 342475b96df..8d0c3932870 100644 --- a/asv_bench/benchmarks/unstacking.py +++ b/asv_bench/benchmarks/unstacking.py @@ -7,18 +7,23 @@ class Unstacking: def setup(self): - data = np.random.RandomState(0).randn(1, 1000, 500) - self.ds = xr.DataArray(data).stack(flat_dim=["dim_1", "dim_2"]) + data = np.random.RandomState(0).randn(500, 1000) + self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...]) + self.da_missing = self.da_full[:-1] + self.df_missing = self.da_missing.to_pandas() def time_unstack_fast(self): - self.ds.unstack("flat_dim") + self.da_full.unstack("flat_dim") def time_unstack_slow(self): - self.ds[:, ::-1].unstack("flat_dim") + self.da_missing.unstack("flat_dim") + + def time_unstack_pandas_slow(self): + self.df_missing.unstack() class UnstackingDask(Unstacking): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds = self.ds.chunk({"flat_dim": 50}) + self.da_full = self.da_full.chunk({"flat_dim": 50}) diff --git a/doc/api.rst b/doc/api.rst index ce866093db8..9add7a96109 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -126,6 +126,7 @@ Indexing Dataset.isel Dataset.sel Dataset.drop_sel + Dataset.drop_isel Dataset.head Dataset.tail Dataset.thin @@ -308,6 +309,7 @@ Indexing DataArray.isel DataArray.sel DataArray.drop_sel + DataArray.drop_isel DataArray.head DataArray.tail DataArray.thin diff --git a/doc/conf.py b/doc/conf.py index d83e966f3fa..14b28b4e471 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -411,7 +411,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), - "iris": ("https://scitools.org.uk/iris/docs/latest", None), + "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "numba": ("https://numba.pydata.org/numba-doc/latest", None), diff --git a/doc/contributing.rst b/doc/contributing.rst index 9c4ce5a0af2..439791cbbd6 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -836,6 +836,7 @@ PR checklist - Write new tests if needed. See `"Test-driven development/code writing" `_. - Test the code using `Pytest `_. Running all tests (type ``pytest`` in the root directory) takes a while, so feel free to only run the tests you think are needed based on your PR (example: ``pytest xarray/tests/test_dataarray.py``). CI will catch any failing tests. + - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a "[skip-ci]" tag to the first line of the commit message. 
- **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit. diff --git a/doc/faq.rst b/doc/faq.rst index a2b8be47e06..a2151cc4b37 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -166,7 +166,7 @@ different approaches to handling metadata: Iris strictly interprets `CF conventions`_. Iris particularly shines at mapping, thanks to its integration with Cartopy_. -.. _Iris: http://scitools.org.uk/iris/ +.. _Iris: https://scitools-iris.readthedocs.io/en/stable/ .. _Cartopy: http://scitools.org.uk/cartopy/docs/latest/ `UV-CDAT`__ is another Python library that implements in-memory netCDF-like diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 456cb64197f..0a010195d6d 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -15,6 +15,7 @@ Geosciences - `aospy `_: Automated analysis and management of gridded climate data. - `climpred `_: Analysis of ensemble forecast models for climate prediction. - `geocube `_: Tool to convert geopandas vector data into rasterized xarray data. +- `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope). - `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data - `marc_analysis `_: Analysis package for CESM/MARC experiments and output. - `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 502181487b9..bb81426e4dd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,7 +17,7 @@ What's New .. _whats-new.0.16.3: -v0.16.3 (unreleased) +v0.17.0 (unreleased) -------------------- Breaking changes @@ -39,16 +39,32 @@ Breaking changes always be set such that ``int64`` values can be used. In the past, no units finer than "seconds" were chosen, which would sometimes mean that ``float64`` values were required, which would lead to inaccurate I/O round-trips. -- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull: `4725`). - By `Aureliana Barghini `_ +- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`). + By `Aureliana Barghini `_. + +Deprecations +~~~~~~~~~~~~ + +- ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in + favour of a ``coord`` argument, for consistency with :py:meth:`Dataset.integrate`. + For now using ``dim`` issues a ``FutureWarning``. By `Tom Nicholas `_. New Features ~~~~~~~~~~~~ +- Significantly higher ``unstack`` performance on numpy-backed arrays which + contain missing values; 8x faster in our benchmark, and 2x faster than pandas. + (:pull:`4746`); + By `Maximilian Roos `_. + - Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. By `Deepak Cherian `_ - Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. By `Deepak Cherian `_ + By `Deepak Cherian `_. +- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims + in the form of kwargs as well as a dict, like most similar methods. + By `Maximilian Roos `_. Bug fixes ~~~~~~~~~ @@ -82,6 +98,7 @@ Bug fixes - Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`). By `Julien Seguinot `_. 
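As a quick, hedged illustration of the `swap_dims` and `integrate` changes noted in the release notes above (the array values and the new dimension name are invented for the example, not taken from this PR):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(12.0).reshape(4, 3),
    dims=("x", "y"),
    coords={"x": [0.0, 0.1, 1.1, 1.2]},
)

# swap_dims now also accepts keyword arguments, like most similar methods.
swapped = da.swap_dims(x="z")  # equivalent to da.swap_dims({"x": "z"})

# integrate now takes `coord`, for consistency with Dataset.integrate;
# the old spelling da.integrate(dim="x") still works but issues a FutureWarning.
integrated = da.integrate(coord="x")
```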
+- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` (:issue:`4658`, :pull:`4819`). By `Daniel Mesejo `_. Documentation ~~~~~~~~~~~~~ @@ -110,6 +127,8 @@ Internal Changes By `Maximilian Roos `_. - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_. +- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for beckends to specify how to voluntary release + all resources. (:pull:`#4809`), By `Alessandro Amici `_. .. _whats-new.0.16.2: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4958062a262..81314588784 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks): else: ds2 = ds - ds2._file_obj = ds._file_obj + ds2.set_close(ds._close) return ds2 filename_or_obj = _normalize_path(filename_or_obj) @@ -701,7 +701,7 @@ def open_dataarray( else: (data_array,) = dataset.data_vars.values() - data_array._file_obj = dataset._file_obj + data_array.set_close(dataset._close) # Reset names if they were changed during saving # to ensure that we can 'roundtrip' perfectly @@ -715,17 +715,6 @@ def open_dataarray( return data_array -class _MultiFileCloser: - __slots__ = ("file_objs",) - - def __init__(self, file_objs): - self.file_objs = file_objs - - def close(self): - for f in self.file_objs: - f.close() - - def open_mfdataset( paths, chunks=None, @@ -918,14 +907,14 @@ def open_mfdataset( getattr_ = getattr datasets = [open_(p, **open_kwargs) for p in paths] - file_objs = [getattr_(ds, "_file_obj") for ds in datasets] + closers = [getattr_(ds, "_close") for ds in datasets] if preprocess is not None: datasets = [preprocess(ds) for ds in datasets] if parallel: # calling compute here will return the datasets/file_objs lists, # the underlying datasets will still be stored as dask arrays - datasets, file_objs = dask.compute(datasets, file_objs) + datasets, closers = dask.compute(datasets, closers) # Combine all datasets, closing them in case of a ValueError try: @@ -963,7 +952,11 @@ def open_mfdataset( ds.close() raise - combined._file_obj = _MultiFileCloser(file_objs) + def multi_file_closer(): + for closer in closers: + closer() + + combined.set_close(multi_file_closer) # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 0f98291983d..d31fc9ea773 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -90,7 +90,7 @@ def _dataset_from_backend_dataset( **extra_tokens, ) - ds._file_obj = backend_ds._file_obj + ds.set_close(backend_ds._close) # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index d4933e370c7..65c5bc2a02b 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -5,9 +5,22 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .locks import SerializableLock, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import cfgrib + + has_cfgrib = True +except ModuleNotFoundError: + has_cfgrib = False 
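The backend changes that follow replace the module-level `open_backend_dataset_*` functions with `BackendEntrypoint` subclasses that are registered in the shared `BACKEND_ENTRYPOINTS` dict only when their optional dependency imports, and that hand resource cleanup to the new `set_close` machinery. A minimal sketch of that pattern, using a made-up in-memory engine rather than any backend touched in this PR:

```python
import numpy as np

import xarray as xr
from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint


class InMemoryBackendEntrypoint(BackendEntrypoint):
    """Toy engine: ignores the path and returns a small constant dataset."""

    def guess_can_open(self, store_spec):
        # Only claim paths with a made-up extension.
        return isinstance(store_spec, str) and store_spec.endswith(".toy")

    def open_dataset(self, filename_or_obj, *, drop_variables=None):
        ds = xr.Dataset({"a": ("x", np.arange(3))})
        if drop_variables:
            ds = ds.drop_vars(drop_variables)
        # Backends now register cleanup via set_close() instead of assigning
        # a _file_obj; there is nothing real to release here, so use a no-op.
        ds.set_close(lambda: None)
        return ds


# Real backends guard this with `if has_<lib>: ...`, as in the modules below;
# plugins.build_engines() later instantiates every registered class.
BACKEND_ENTRYPOINTS["toy"] = InMemoryBackendEntrypoint
```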
+ # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe # in most circumstances. See: @@ -38,7 +51,6 @@ class CfGribDataStore(AbstractDataStore): """ def __init__(self, filename, lock=None, **backend_kwargs): - import cfgrib if lock is None: lock = ECCODES_LOCK @@ -74,58 +86,58 @@ def get_encoding(self): return encoding -def guess_can_open_cfgrib(store_spec): - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".grib", ".grib2", ".grb", ".grb2"} - - -def open_backend_dataset_cfgrib( - filename_or_obj, - *, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - lock=None, - indexpath="{path}.{short_hash}.idx", - filter_by_keys={}, - read_keys=[], - encode_cf=("parameter", "time", "geography", "vertical"), - squeeze=True, - time_dims=("time", "step"), -): - - store = CfGribDataStore( +class CfgribfBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".grib", ".grib2", ".grb", ".grb2"} + + def open_dataset( + self, filename_or_obj, - indexpath=indexpath, - filter_by_keys=filter_by_keys, - read_keys=read_keys, - encode_cf=encode_cf, - squeeze=squeeze, - time_dims=time_dims, - lock=lock, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + lock=None, + indexpath="{path}.{short_hash}.idx", + filter_by_keys={}, + read_keys=[], + encode_cf=("parameter", "time", "geography", "vertical"), + squeeze=True, + time_dims=("time", "step"), + ): + + store = CfGribDataStore( + filename_or_obj, + indexpath=indexpath, + filter_by_keys=filter_by_keys, + read_keys=read_keys, + encode_cf=encode_cf, + squeeze=squeeze, + time_dims=time_dims, + lock=lock, ) - return ds - - -cfgrib_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib -) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +if has_cfgrib: + BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 72a63957662..e2905d0866b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,6 +1,7 @@ import logging import time import traceback +from typing import Dict, Tuple, Type, Union import numpy as np @@ -343,9 +344,13 @@ def encode(self, variables, attributes): class BackendEntrypoint: - __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters") + open_dataset_parameters: Union[Tuple, None] = None - def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None): - self.open_dataset = open_dataset - self.open_dataset_parameters = open_dataset_parameters - 
self.guess_can_open = guess_can_open + def open_dataset(self): + raise NotImplementedError + + def guess_can_open(self, store_spec): + return False + + +BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {} diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index b2996369ee7..aa892c4f89c 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,7 +8,12 @@ from ..core import indexing from ..core.utils import FrozenDict, is_remote_uri, read_magic_number from ..core.variable import Variable -from .common import BackendEntrypoint, WritableCFDataStore, find_root_and_group +from .common import ( + BACKEND_ENTRYPOINTS, + BackendEntrypoint, + WritableCFDataStore, + find_root_and_group, +) from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( @@ -18,7 +23,14 @@ _get_datatype, _nc4_require_group, ) -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import h5netcdf + + has_h5netcdf = True +except ModuleNotFoundError: + has_h5netcdf = False class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -85,8 +97,6 @@ class H5NetCDFStore(WritableCFDataStore): def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): - import h5netcdf - if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): if group is None: root, group = find_root_and_group(manager) @@ -122,7 +132,6 @@ def open( invalid_netcdf=None, phony_dims=None, ): - import h5netcdf if isinstance(filename, bytes): raise ValueError( @@ -319,59 +328,61 @@ def close(self, **kwargs): self._manager.close(**kwargs) -def guess_can_open_h5netcdf(store_spec): - try: - return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") - except TypeError: - pass - - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - - return ext in {".nc", ".nc4", ".cdf"} - - -def open_backend_dataset_h5netcdf( - filename_or_obj, - *, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - format=None, - group=None, - lock=None, - invalid_netcdf=None, - phony_dims=None, -): - - store = H5NetCDFStore.open( +class H5netcdfBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") + except TypeError: + pass + + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + + return ext in {".nc", ".nc4", ".cdf"} + + def open_dataset( + self, filename_or_obj, - format=format, - group=group, - lock=lock, - invalid_netcdf=invalid_netcdf, - phony_dims=phony_dims, - ) + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + format=None, + group=None, + lock=None, + invalid_netcdf=None, + phony_dims=None, + ): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - return ds + store = H5NetCDFStore.open( + filename_or_obj, + format=format, + group=group, + lock=lock, + invalid_netcdf=invalid_netcdf, + phony_dims=phony_dims, + ) + store_entrypoint = StoreBackendEntrypoint() + + ds = 
store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -h5netcdf_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf -) + +if has_h5netcdf: + BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0e35270ea9a..e3d87aaf83f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -12,6 +12,7 @@ from ..core.utils import FrozenDict, close_on_error, is_remote_uri from ..core.variable import Variable from .common import ( + BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, @@ -21,7 +22,15 @@ from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import netCDF4 + + has_netcdf4 = True +except ModuleNotFoundError: + has_netcdf4 = False + # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. @@ -298,7 +307,6 @@ class NetCDF4DataStore(WritableCFDataStore): def __init__( self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False ): - import netCDF4 if isinstance(manager, netCDF4.Dataset): if group is None: @@ -335,7 +343,6 @@ def open( lock_maker=None, autoclose=False, ): - import netCDF4 if isinstance(filename, pathlib.Path): filename = os.fspath(filename) @@ -505,61 +512,62 @@ def close(self, **kwargs): self._manager.close(**kwargs) -def guess_can_open_netcdf4(store_spec): - if isinstance(store_spec, str) and is_remote_uri(store_spec): - return True - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf"} - - -def open_backend_dataset_netcdf4( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - format="NETCDF4", - clobber=True, - diskless=False, - persist=False, - lock=None, - autoclose=False, -): +class NetCDF4BackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + if isinstance(store_spec, str) and is_remote_uri(store_spec): + return True + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf"} - store = NetCDF4DataStore.open( + def open_dataset( + self, filename_or_obj, - mode=mode, - format=format, - group=group, - clobber=clobber, - diskless=diskless, - persist=persist, - lock=lock, - autoclose=autoclose, - ) + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + format="NETCDF4", + clobber=True, + diskless=False, + persist=False, + lock=None, + autoclose=False, + ): - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - 
decode_timedelta=decode_timedelta, + store = NetCDF4DataStore.open( + filename_or_obj, + mode=mode, + format=format, + group=group, + clobber=clobber, + diskless=diskless, + persist=persist, + lock=lock, + autoclose=autoclose, ) - return ds + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -netcdf4_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4 -) + +if has_netcdf4: + BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index d5799a78f91..b8cd2bf6378 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -2,33 +2,11 @@ import inspect import itertools import logging -import typing as T import warnings import pkg_resources -from .cfgrib_ import cfgrib_backend -from .common import BackendEntrypoint -from .h5netcdf_ import h5netcdf_backend -from .netCDF4_ import netcdf4_backend -from .pseudonetcdf_ import pseudonetcdf_backend -from .pydap_ import pydap_backend -from .pynio_ import pynio_backend -from .scipy_ import scipy_backend -from .store import store_backend -from .zarr import zarr_backend - -BACKEND_ENTRYPOINTS: T.Dict[str, BackendEntrypoint] = { - "store": store_backend, - "netcdf4": netcdf4_backend, - "h5netcdf": h5netcdf_backend, - "scipy": scipy_backend, - "pseudonetcdf": pseudonetcdf_backend, - "zarr": zarr_backend, - "cfgrib": cfgrib_backend, - "pydap": pydap_backend, - "pynio": pynio_backend, -} +from .common import BACKEND_ENTRYPOINTS def remove_duplicates(backend_entrypoints): @@ -58,6 +36,7 @@ def remove_duplicates(backend_entrypoints): def detect_parameters(open_dataset): signature = inspect.signature(open_dataset) parameters = signature.parameters + parameters_list = [] for name, param in parameters.items(): if param.kind in ( inspect.Parameter.VAR_KEYWORD, @@ -67,7 +46,9 @@ def detect_parameters(open_dataset): f"All the parameters in {open_dataset!r} signature should be explicit. 
" "*args and **kwargs is not supported" ) - return tuple(parameters) + if name != "self": + parameters_list.append(name) + return tuple(parameters_list) def create_engines_dict(backend_entrypoints): @@ -79,8 +60,8 @@ def create_engines_dict(backend_entrypoints): return engines -def set_missing_parameters(engines): - for name, backend in engines.items(): +def set_missing_parameters(backend_entrypoints): + for name, backend in backend_entrypoints.items(): if backend.open_dataset_parameters is None: open_dataset = backend.open_dataset backend.open_dataset_parameters = detect_parameters(open_dataset) @@ -92,7 +73,10 @@ def build_engines(entrypoints): external_backend_entrypoints = create_engines_dict(pkg_entrypoints) backend_entrypoints.update(external_backend_entrypoints) set_missing_parameters(backend_entrypoints) - return backend_entrypoints + engines = {} + for name, backend in backend_entrypoints.items(): + engines[name] = backend() + return engines @functools.lru_cache(maxsize=1) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index d9128d1d503..80485fce459 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -3,10 +3,23 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + from PseudoNetCDF import pncopen + + has_pseudonetcdf = True +except ModuleNotFoundError: + has_pseudonetcdf = False + # psuedonetcdf can invoke netCDF libraries internally PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) @@ -40,7 +53,6 @@ class PseudoNetCDFDataStore(AbstractDataStore): @classmethod def open(cls, filename, lock=None, mode=None, **format_kwargs): - from PseudoNetCDF import pncopen keywords = {"kwargs": format_kwargs} # only include mode if explicitly passed @@ -88,53 +100,55 @@ def close(self): self._manager.close() -def open_backend_dataset_pseudonetcdf( - filename_or_obj, - mask_and_scale=False, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode=None, - lock=None, - **format_kwargs, -): - - store = PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, mode=mode, **format_kwargs +class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): + + # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, + # unless the open_dataset_parameters are explicity defined like this: + open_dataset_parameters = ( + "filename_or_obj", + "mask_and_scale", + "decode_times", + "concat_characters", + "decode_coords", + "drop_variables", + "use_cftime", + "decode_timedelta", + "mode", + "lock", ) - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + def open_dataset( + self, + filename_or_obj, + mask_and_scale=False, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + 
decode_timedelta=None, + mode=None, + lock=None, + **format_kwargs, + ): + store = PseudoNetCDFDataStore.open( + filename_or_obj, lock=lock, mode=mode, **format_kwargs ) - return ds - - -# *args and **kwargs are not allowed in open_backend_dataset_ kwargs, -# unless the open_dataset_parameters are explicity defined like this: -open_dataset_parameters = ( - "filename_or_obj", - "mask_and_scale", - "decode_times", - "concat_characters", - "decode_coords", - "drop_variables", - "use_cftime", - "decode_timedelta", - "mode", - "lock", -) -pseudonetcdf_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_pseudonetcdf, - open_dataset_parameters=open_dataset_parameters, -) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +if has_pseudonetcdf: + BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 4995045a739..7f8622ca66e 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -4,8 +4,21 @@ from ..core.pycompat import integer_types from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem -from .store import open_backend_dataset_store +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, + robust_getitem, +) +from .store import StoreBackendEntrypoint + +try: + import pydap.client + + has_pydap = True +except ModuleNotFoundError: + has_pydap = False class PydapArrayWrapper(BackendArray): @@ -74,7 +87,6 @@ def __init__(self, ds): @classmethod def open(cls, url, session=None): - import pydap.client ds = pydap.client.open_url(url, session=session) return cls(ds) @@ -95,41 +107,41 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) -def guess_can_open_pydap(store_spec): - return isinstance(store_spec, str) and is_remote_uri(store_spec) - +class PydapBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + return isinstance(store_spec, str) and is_remote_uri(store_spec) -def open_backend_dataset_pydap( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - session=None, -): - - store = PydapDataStore.open( + def open_dataset( + self, filename_or_obj, - session=session, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + session=None, + ): + store = PydapDataStore.open( + filename_or_obj, + session=session, ) - return ds + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + 
concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -pydap_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap -) +if has_pydap: + BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index dc6c47935e8..41c99efd076 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -3,10 +3,23 @@ from ..core import indexing from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray, BackendEntrypoint +from .common import ( + BACKEND_ENTRYPOINTS, + AbstractDataStore, + BackendArray, + BackendEntrypoint, +) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import Nio + + has_pynio = True +except ModuleNotFoundError: + has_pynio = False + # PyNIO can invoke netCDF libraries internally # Add a dedicated lock just in case NCL as well isn't thread-safe. @@ -45,7 +58,6 @@ class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO""" def __init__(self, filename, mode="r", lock=None, **kwargs): - import Nio if lock is None: lock = PYNIO_LOCK @@ -85,37 +97,39 @@ def close(self): self._manager.close() -def open_backend_dataset_pynio( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode="r", - lock=None, -): - - store = NioDataStore( +class PynioBackendEntrypoint(BackendEntrypoint): + def open_dataset( filename_or_obj, - mode=mode, - lock=lock, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + lock=None, + ): + store = NioDataStore( + filename_or_obj, + mode=mode, + lock=lock, ) - return ds - -pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio) + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +if has_pynio: + BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index a0500c7e1c2..c689c1e99d7 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable - result._file_obj = manager + result.set_close(manager.close) return result diff --git a/xarray/backends/scipy_.py 
b/xarray/backends/scipy_.py index 873a91f9c07..ddc157ed8e4 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -6,11 +6,23 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number from ..core.variable import Variable -from .common import BackendArray, BackendEntrypoint, WritableCFDataStore +from .common import ( + BACKEND_ENTRYPOINTS, + BackendArray, + BackendEntrypoint, + WritableCFDataStore, +) from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import scipy.io + + has_scipy = True +except ModuleNotFoundError: + has_scipy = False def _decode_string(s): @@ -61,8 +73,6 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): import gzip - import scipy.io - # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -222,52 +232,54 @@ def close(self): self._manager.close() -def guess_can_open_scipy(store_spec): - try: - return read_magic_number(store_spec).startswith(b"CDF") - except TypeError: - pass +class ScipyBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + try: + return read_magic_number(store_spec).startswith(b"CDF") + except TypeError: + pass - try: - _, ext = os.path.splitext(store_spec) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf", ".gz"} - - -def open_backend_dataset_scipy( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode="r", - format=None, - group=None, - mmap=None, - lock=None, -): - - store = ScipyDataStore( - filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock - ) - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf", ".gz"} + + def open_dataset( + self, + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + format=None, + group=None, + mmap=None, + lock=None, + ): + + store = ScipyDataStore( + filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock ) - return ds + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -scipy_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy -) + +if has_scipy: + BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint diff --git a/xarray/backends/store.py b/xarray/backends/store.py index d314a9c3ca9..d57b3ab9df8 
100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -1,47 +1,45 @@ from .. import conventions from ..core.dataset import Dataset -from .common import AbstractDataStore, BackendEntrypoint - - -def guess_can_open_store(store_spec): - return isinstance(store_spec, AbstractDataStore) - - -def open_backend_dataset_store( - store, - *, - mask_and_scale=True, - decode_times=True, - concat_characters=True, - decode_coords=True, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, -): - vars, attrs = store.load() - file_obj = store - encoding = store.get_encoding() - - vars, attrs, coord_names = conventions.decode_cf_variables( - vars, - attrs, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - - ds = Dataset(vars, attrs=attrs) - ds = ds.set_coords(coord_names.intersection(vars)) - ds._file_obj = file_obj - ds.encoding = encoding - - return ds - - -store_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store -) +from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint + + +class StoreBackendEntrypoint(BackendEntrypoint): + def guess_can_open(self, store_spec): + return isinstance(store_spec, AbstractDataStore) + + def open_dataset( + self, + store, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + ): + vars, attrs = store.load() + encoding = store.get_encoding() + + vars, attrs, coord_names = conventions.decode_cf_variables( + vars, + attrs, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = Dataset(vars, attrs=attrs) + ds = ds.set_coords(coord_names.intersection(vars)) + ds.set_close(store.close) + ds.encoding = encoding + + return ds + + +BACKEND_ENTRYPOINTS["store"] = StoreBackendEntrypoint diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 3b4b3a3d9d5..1d667a38b53 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -9,12 +9,21 @@ from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error from ..core.variable import Variable from .common import ( + BACKEND_ENTRYPOINTS, AbstractWritableDataStore, BackendArray, BackendEntrypoint, _encode_variable_name, ) -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint + +try: + import zarr + + has_zarr = True +except ModuleNotFoundError: + has_zarr = False + # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -289,7 +298,6 @@ def open_group( append_dim=None, write_region=None, ): - import zarr # zarr doesn't support pathlib.Path objects yet. 
zarr-python#601 if isinstance(store, pathlib.Path): @@ -409,7 +417,6 @@ def store( dimension on which the zarray will be appended only needed in append mode """ - import zarr existing_variables = { vn for vn in variables if _encode_variable_name(vn) in self.ds @@ -663,45 +670,48 @@ def open_zarr( return ds -def open_backend_dataset_zarr( - filename_or_obj, - mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - group=None, - mode="r", - synchronizer=None, - consolidated=False, - consolidate_on_close=False, - chunk_store=None, -): - - store = ZarrStore.open_group( +class ZarrBackendEntrypoint(BackendEntrypoint): + def open_dataset( + self, filename_or_obj, - group=group, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, - consolidate_on_close=consolidate_on_close, - chunk_store=chunk_store, - ) - - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + synchronizer=None, + consolidated=False, + consolidate_on_close=False, + chunk_store=None, + ): + store = ZarrStore.open_group( + filename_or_obj, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_store, ) - return ds + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds -zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr) +if has_zarr: + BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint diff --git a/xarray/conventions.py b/xarray/conventions.py index bb0b92c77a1..e33ae53b31d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -576,12 +576,12 @@ def decode_cf( vars = obj._variables attrs = obj.attrs extra_coords = set(obj.coords) - file_obj = obj._file_obj + close = obj._close encoding = obj.encoding elif isinstance(obj, AbstractDataStore): vars, attrs = obj.load() extra_coords = set() - file_obj = obj + close = obj.close encoding = obj.get_encoding() else: raise TypeError("can only decode Dataset or DataStore objects") @@ -599,7 +599,7 @@ def decode_cf( ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) - ds._file_obj = file_obj + ds.set_close(close) ds.encoding = encoding return ds diff --git a/xarray/core/common.py b/xarray/core/common.py index 283114770cf..c5836c68759 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -3,6 +3,7 @@ from html import escape from textwrap import dedent from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -11,6 +12,7 @@ Iterator, List, Mapping, + Optional, Tuple, TypeVar, Union, @@ -31,6 +33,12 @@ ALL_DIMS = ... 
+if TYPE_CHECKING: + from .dataarray import DataArray + from .weighted import Weighted + +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + C = TypeVar("C") T = TypeVar("T") @@ -330,7 +338,9 @@ def get_squeeze_dims( class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" - __slots__ = () + _close: Optional[Callable[[], None]] + + __slots__ = ("_close",) _rolling_exp_cls = RollingExp @@ -769,7 +779,9 @@ def groupby_bins( }, ) - def weighted(self, weights): + def weighted( + self: T_DataWithCoords, weights: "DataArray" + ) -> "Weighted[T_DataWithCoords]": """ Weighted operations. @@ -1263,11 +1275,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) + def set_close(self, close: Optional[Callable[[], None]]) -> None: + """Register the function that releases any resources linked to this object. + + This method controls how xarray cleans up resources associated + with this object when the ``.close()`` method is called. It is mostly + intended for backend developers and it is rarely needed by regular + end-users. + + Parameters + ---------- + close : callable + The function that when called like ``close()`` releases + any resources linked to this object. + """ + self._close = close + def close(self: Any) -> None: - """Close any files linked to this object""" - if self._file_obj is not None: - self._file_obj.close() - self._file_obj = None + """Release any resources linked to this object.""" + if self._close is not None: + self._close() + self._close = None def isnull(self, keep_attrs: bool = None): """Test each value in the array for whether it is a missing value. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6fdda8fc418..0155cdc4e19 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords): _cache: Dict[str, Any] _coords: Dict[Any, Variable] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _name: Optional[Hashable] _variable: Variable @@ -351,7 +352,7 @@ class DataArray(AbstractArray, DataWithCoords): __slots__ = ( "_cache", "_coords", - "_file_obj", + "_close", "_indexes", "_name", "_variable", @@ -421,7 +422,7 @@ def __init__( # public interface. self._indexes = indexes - self._file_obj = None + self._close = None def _replace( self, @@ -1698,7 +1699,9 @@ def rename( new_name_or_name_dict = cast(Hashable, new_name_or_name_dict) return self._replace(name=new_name_or_name_dict) - def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": + def swap_dims( + self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + ) -> "DataArray": """Returns a new DataArray with swapped dimensions. Parameters @@ -1707,6 +1710,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": Dictionary whose keys are current dimension names and whose values are new names. + **dim_kwargs : {dim: , ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. 
+ Returns ------- swapped : DataArray @@ -1748,6 +1755,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": DataArray.rename Dataset.swap_dims """ + dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims") ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) @@ -2247,6 +2255,28 @@ def drop_sel( ds = self._to_temp_dataset().drop_sel(labels, errors=errors) return self._from_temp_dataset(ds) + def drop_isel(self, indexers=None, **indexers_kwargs): + """Drop index positions from this DataArray. + + Parameters + ---------- + indexers : mapping of hashable to Any + Index locations to drop + **indexers_kwargs : {dim: position, ...}, optional + The keyword arguments form of ``dim`` and ``positions`` + + Returns + ------- + dropped : DataArray + + Raises + ------ + IndexError + """ + dataset = self._to_temp_dataset() + dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) + return self._from_temp_dataset(dataset) + def dropna( self, dim: Hashable, how: str = "any", thresh: int = None ) -> "DataArray": @@ -3451,21 +3481,26 @@ def differentiate( return self._from_temp_dataset(ds) def integrate( - self, dim: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + self, + coord: Union[Hashable, Sequence[Hashable]] = None, + datetime_unit: str = None, + *, + dim: Union[Hashable, Sequence[Hashable]] = None, ) -> "DataArray": - """ integrate the array with the trapezoidal rule. + """Integrate along the given coordinate using the trapezoidal rule. .. note:: - This feature is limited to simple cartesian geometry, i.e. dim + This feature is limited to simple cartesian geometry, i.e. coord must be one dimensional. Parameters ---------- + coord: hashable, or a sequence of hashable + Coordinate(s) used for the integration. dim : hashable, or sequence of hashable Coordinate(s) used for the integration. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \ - "ps", "fs", "as"}, optional - Can be used to specify the unit if datetime coordinate is used. + datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional Returns ------- @@ -3473,6 +3508,7 @@ def integrate( See also -------- + Dataset.integrate numpy.trapz: corresponding numpy function Examples @@ -3498,7 +3534,22 @@ def integrate( array([5.4, 6.6, 7.8]) Dimensions without coordinates: y """ - ds = self._to_temp_dataset().integrate(dim, datetime_unit) + if dim is not None and coord is not None: + raise ValueError( + "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead." + ) + + if dim is not None and coord is None: + coord = dim + msg = ( + "The `dim` keyword argument to `DataArray.integrate` is " + "being replaced with `coord`, for consistency with " + "`Dataset.integrate`. Please pass `coord` instead." + " `dim` will be removed in version 0.19.0." 
+ ) + warnings.warn(msg, FutureWarning, stacklevel=2) + + ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) def unify_chunks(self) -> "DataArray": diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7edc2fab067..8376b4875f9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4,6 +4,7 @@ import sys import warnings from collections import defaultdict +from distutils.version import LooseVersion from html import escape from numbers import Number from operator import methodcaller @@ -79,7 +80,7 @@ ) from .missing import get_clean_interp_index from .options import OPTIONS, _get_keep_attrs -from .pycompat import is_duck_dask_array +from .pycompat import is_duck_dask_array, sparse_array_type from .utils import ( Default, Frozen, @@ -636,6 +637,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _coord_names: Set[Hashable] _dims: Dict[Hashable, int] _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _variables: Dict[Hashable, Variable] @@ -645,7 +647,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): "_coord_names", "_dims", "_encoding", - "_file_obj", + "_close", "_indexes", "_variables", "__weakref__", @@ -687,7 +689,7 @@ def __init__( ) self._attrs = dict(attrs) if attrs is not None else None - self._file_obj = None + self._close = None self._encoding = None self._variables = variables self._coord_names = coord_names @@ -703,7 +705,7 @@ def load_store(cls, store, decoder=None) -> "Dataset": if decoder: variables, attributes = decoder(variables, attributes) obj = cls(variables, attrs=attributes) - obj._file_obj = store + obj.set_close(store.close) return obj @property @@ -876,7 +878,7 @@ def __dask_postcompute__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postcompute, args @@ -896,7 +898,7 @@ def __dask_postpersist__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postpersist, args @@ -1007,7 +1009,7 @@ def _construct_direct( attrs=None, indexes=None, encoding=None, - file_obj=None, + close=None, ): """Shortcut around __init__ for internal use when we want to skip costly validation @@ -1020,7 +1022,7 @@ def _construct_direct( obj._dims = dims obj._indexes = indexes obj._attrs = attrs - obj._file_obj = file_obj + obj._close = close obj._encoding = encoding return obj @@ -2122,7 +2124,7 @@ def isel( attrs=self._attrs, indexes=indexes, encoding=self._encoding, - file_obj=self._file_obj, + close=self._close, ) def _isel_fancy( @@ -3154,7 +3156,9 @@ def rename_vars( ) return self._replace(variables, coord_names, dims=dims, indexes=indexes) - def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": + def swap_dims( + self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + ) -> "Dataset": """Returns a new object with swapped dimensions. Parameters @@ -3163,6 +3167,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": Dictionary whose keys are current dimension names and whose values are new names. + **dim_kwargs : {existing_dim: new_dim, ...}, optional + The keyword arguments form of ``dims_dict``. + One of dims_dict or dims_kwargs must be provided. 
+ Returns ------- swapped : Dataset @@ -3213,6 +3221,8 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": """ # TODO: deprecate this method in favor of a (less confusing) # rename_dims() method that only renames dimensions. + + dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims") for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( @@ -3706,7 +3716,40 @@ def ensure_stackable(val): return data_array - def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset": + def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": + index = self.get_index(dim) + index = remove_unused_levels_categories(index) + + variables: Dict[Hashable, Variable] = {} + indexes = {k: v for k, v in self.indexes.items() if k != dim} + + for name, var in self.variables.items(): + if name != dim: + if dim in var.dims: + if isinstance(fill_value, Mapping): + fill_value_ = fill_value[name] + else: + fill_value_ = fill_value + + variables[name] = var._unstack_once( + index=index, dim=dim, fill_value=fill_value_ + ) + else: + variables[name] = var + + for name, lev in zip(index.names, index.levels): + variables[name] = IndexVariable(name, lev) + indexes[name] = lev + + coord_names = set(self._coord_names) - {dim} | set(index.names) + + return self._replace_with_new_dims( + variables, coord_names=coord_names, indexes=indexes + ) + + def _unstack_full_reindex( + self, dim: Hashable, fill_value, sparse: bool + ) -> "Dataset": index = self.get_index(dim) index = remove_unused_levels_categories(index) full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) @@ -3803,7 +3846,38 @@ def unstack( result = self.copy(deep=False) for dim in dims: - result = result._unstack_once(dim, fill_value, sparse) + + if ( + # Dask arrays don't support assignment by index, which the fast unstack + # function requires. + # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125 + any(is_duck_dask_array(v.data) for v in self.variables.values()) + # Sparse doesn't currently support (though we could special-case + # it) + # https://github.com/pydata/sparse/issues/422 + or any( + isinstance(v.data, sparse_array_type) + for v in self.variables.values() + ) + or sparse + # numpy full_like only added `shape` in 1.17 + or LooseVersion(np.__version__) < LooseVersion("1.17") + # Until https://github.com/pydata/xarray/pull/4751 is resolved, + # we check explicitly whether it's a numpy array. Once that is + # resolved, explicitly exclude pint arrays. + # # pint doesn't implement `np.full_like` in a way that's + # # currently compatible. 
+ # # https://github.com/pydata/xarray/pull/4746#issuecomment-753425173 + # # or any( + # # isinstance(v.data, pint_array_type) for v in self.variables.values() + # # ) + or any( + not isinstance(v.data, np.ndarray) for v in self.variables.values() + ) + ): + result = result._unstack_full_reindex(dim, fill_value, sparse) + else: + result = result._unstack_once(dim, fill_value) return result def update(self, other: "CoercibleMapping") -> "Dataset": @@ -4020,9 +4094,17 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): Examples -------- - >>> data = np.random.randn(2, 3) + >>> data = np.arange(6).reshape(2, 3) >>> labels = ["a", "b", "c"] >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_sel(y=["a", "c"]) Dimensions: (x: 2, y: 1) @@ -4030,7 +4112,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): * y (y) >> ds.drop_sel(y="b") Dimensions: (x: 2, y: 2) @@ -4038,12 +4120,12 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): * y (y) >> data = np.arange(6).reshape(2, 3) + >>> labels = ["a", "b", "c"] + >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) + >>> ds + + Dimensions: (x: 2, y: 3) + Coordinates: + * y (y) >> ds.drop_isel(y=[0, 2]) + + Dimensions: (x: 2, y: 1) + Coordinates: + * y (y) >> ds.drop_isel(y=1) + + Dimensions: (x: 2, y: 2) + Coordinates: + * y (y) "Dataset": @@ -4900,6 +5047,10 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, values) return + # NB: similar, more general logic, now exists in + # variable.unstack_once; we could consider combining them at some + # point. + shape = tuple(lev.size for lev in idx.levels) indexer = tuple(idx.codes) @@ -5812,8 +5963,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): variables[k] = v return self._replace(variables) - def integrate(self, coord, datetime_unit=None): - """ integrate the array with the trapezoidal rule. + def integrate( + self, coord: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + ) -> "Dataset": + """Integrate along the given coordinate using the trapezoidal rule. .. note:: This feature is limited to simple cartesian geometry, i.e. coord @@ -5821,11 +5974,11 @@ def integrate(self, coord, datetime_unit=None): Parameters ---------- - coord: str, or sequence of str + coord: hashable, or a sequence of hashable Coordinate(s) used for the integration. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \ - "ps", "fs", "as"}, optional - Can be specify the unit if datetime coordinate is used. + datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional + Specify the unit if datetime coordinate is used. 
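The `drop_sel`/`drop_isel` docstring examples above translate directly into a runnable check; assuming a build that includes the new `drop_isel` method:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"A": (["x", "y"], np.arange(6).reshape(2, 3)), "y": ["a", "b", "c"]})

by_label = ds.drop_sel(y=["a", "c"])     # drop by coordinate label
by_position = ds.drop_isel(y=[0, 2])     # drop by integer position (new here)
assert by_label.identical(by_position)   # both keep only y == 'b'
print(by_position["A"].values)           # [[1], [4]]
```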
Returns ------- diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 282620e3569..0c1be1cc175 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -300,11 +300,18 @@ def _summarize_coord_multiindex(coord, col_width, marker): def _summarize_coord_levels(coord, col_width, marker="-"): + if len(coord) > 100 and col_width < len(coord): + n_values = col_width + indices = list(range(0, n_values)) + list(range(-n_values, 0)) + subset = coord[indices] + else: + subset = coord + return "\n".join( summarize_variable( - lname, coord.get_level_variable(lname), col_width, marker=marker + lname, subset.get_level_variable(lname), col_width, marker=marker ) - for lname in coord.level_names + for lname in subset.level_names ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 797de65bbcf..64c1895da59 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -10,6 +10,7 @@ Any, Dict, Hashable, + List, Mapping, Optional, Sequence, @@ -1488,7 +1489,7 @@ def set_dims(self, dims, shape=None): ) return expanded_var.transpose(*dims) - def _stack_once(self, dims, new_dim): + def _stack_once(self, dims: List[Hashable], new_dim: Hashable): if not set(dims) <= set(self.dims): raise ValueError("invalid existing dimensions: %s" % dims) @@ -1544,7 +1545,15 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once(self, dims, old_dim): + def _unstack_once_full( + self, dims: Mapping[Hashable, int], old_dim: Hashable + ) -> "Variable": + """ + Unstacks the variable without needing an index. + + Unlike `_unstack_once`, this function requires the existing dimension to + contain the full product of the new dimensions. + """ new_dim_names = tuple(dims.keys()) new_dim_sizes = tuple(dims.values()) @@ -1573,6 +1582,53 @@ def _unstack_once(self, dims, old_dim): return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + def _unstack_once( + self, + index: pd.MultiIndex, + dim: Hashable, + fill_value=dtypes.NA, + ) -> "Variable": + """ + Unstacks this variable given an index to unstack and the name of the + dimension to which the index refers. + """ + + reordered = self.transpose(..., dim) + + new_dim_sizes = [lev.size for lev in index.levels] + new_dim_names = index.names + indexer = index.codes + + # Potentially we could replace `len(other_dims)` with just `-1` + other_dims = [d for d in self.dims if d != dim] + new_shape = list(reordered.shape[: len(other_dims)]) + new_dim_sizes + new_dims = reordered.dims[: len(other_dims)] + new_dim_names + + if fill_value is dtypes.NA: + is_missing_values = np.prod(new_shape) > np.prod(self.shape) + if is_missing_values: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = self.dtype + fill_value = dtypes.get_fill_value(dtype) + else: + dtype = self.dtype + + # Currently fails on sparse due to https://github.com/pydata/sparse/issues/422 + data = np.full_like( + self.data, + fill_value=fill_value, + shape=new_shape, + dtype=dtype, + ) + + # Indexer is a list of lists of locations. Each list is the locations + # on the new dimension. This is robust to the data being sparse; in that + # case the destinations will be NaN / zero. + data[(..., *indexer)] = reordered + + return self._replace(dims=new_dims, data=data) + def unstack(self, dimensions=None, **dimensions_kwargs): """ Unstack an existing dimension into multiple new dimensions. 
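A toy illustration of the fast path implemented in `Variable._unstack_once` above, with made-up data: allocate the full output via `np.full_like(..., shape=...)` (hence the NumPy >= 1.17 check in the dispatch), promote the dtype so the fill value fits, and scatter the stacked values using the MultiIndex codes:

```python
import numpy as np
import pandas as pd

index = pd.MultiIndex.from_tuples(
    [("a", 0), ("a", 1), ("b", 0)], names=["letter", "number"]
)
stacked = np.array([10, 11, 20])                     # int64, one value per index row

new_shape = tuple(lev.size for lev in index.levels)  # (2, 2): one slot will be missing
dtype, fill = np.dtype("float64"), np.nan            # promote int -> float so NaN fits
full = np.full_like(stacked, fill_value=fill, shape=new_shape, dtype=dtype)

full[tuple(index.codes)] = stacked                   # scatter by integer codes
print(full)
# [[10. 11.]
#  [20. nan]]
```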
@@ -1580,6 +1636,10 @@ def unstack(self, dimensions=None, **dimensions_kwargs): New dimensions will be added at the end, and the order of the data along each new dimension will be in contiguous (C) order. + Note that unlike ``DataArray.unstack`` and ``Dataset.unstack``, this + method requires the existing dimension to contain the full product of + the new dimensions. + Parameters ---------- dimensions : mapping of hashable to mapping of hashable to int @@ -1598,11 +1658,13 @@ def unstack(self, dimensions=None, **dimensions_kwargs): See also -------- Variable.stack + DataArray.unstack + Dataset.unstack """ dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "unstack") result = self for old_dim, dims in dimensions.items(): - result = result._unstack_once(dims, old_dim) + result = result._unstack_once_full(dims, old_dim) return result def fillna(self, value): diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index dbd4e1ad103..449a7200ee7 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,13 +1,16 @@ -from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union from . import duck_array_ops from .computation import dot -from .options import _get_keep_attrs from .pycompat import is_duck_dask_array if TYPE_CHECKING: + from .common import DataWithCoords # noqa: F401 from .dataarray import DataArray, Dataset +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + + _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -56,7 +59,7 @@ """ -class Weighted: +class Weighted(Generic[T_DataWithCoords]): """An object that implements weighted operations. You should create a Weighted object by using the ``DataArray.weighted`` or @@ -70,15 +73,7 @@ class Weighted: __slots__ = ("obj", "weights") - @overload - def __init__(self, obj: "DataArray", weights: "DataArray") -> None: - ... - - @overload - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: - ... 
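The `Weighted` changes in this hunk replace the `__init__` overloads with a `Generic` base parametrized by a `TypeVar`. A minimal sketch of that typing pattern, using made-up classes rather than xarray's:

```python
from typing import Generic, TypeVar


class Base:
    label = "base"


class Concrete(Base):
    label = "concrete"


T = TypeVar("T", bound=Base)


class Wrapper(Generic[T]):
    # Binding the wrapped type once lets each subclass declare its concrete
    # return type, instead of overloading ``__init__`` per wrapped class.
    def __init__(self, obj: T) -> None:
        self.obj: T = obj

    def unwrap(self) -> T:
        return self.obj


class ConcreteWrapper(Wrapper[Concrete]):
    pass


print(ConcreteWrapper(Concrete()).unwrap().label)  # type checkers see Concrete here
```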
- - def __init__(self, obj, weights): + def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): """ Create a Weighted object @@ -121,8 +116,8 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj = obj - self.weights = weights + self.obj: T_DataWithCoords = obj + self.weights: "DataArray" = weights @staticmethod def _reduce( @@ -146,7 +141,6 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`) return dot(da, weights, dims=dim) def _sum_of_weights( @@ -203,7 +197,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -214,7 +208,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -225,7 +219,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> Union["DataArray", "Dataset"]: + ) -> T_DataWithCoords: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -239,22 +233,15 @@ def __repr__(self): return f"{klass} with weights along dimensions: {weight_dims}" -class DataArrayWeighted(Weighted): - def _implementation(self, func, dim, **kwargs): - - keep_attrs = kwargs.pop("keep_attrs") - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - - weighted = func(self.obj, dim=dim, **kwargs) - - if keep_attrs: - weighted.attrs = self.obj.attrs +class DataArrayWeighted(Weighted["DataArray"]): + def _implementation(self, func, dim, **kwargs) -> "DataArray": - return weighted + dataset = self.obj._to_temp_dataset() + dataset = dataset.map(func, dim=dim, **kwargs) + return self.obj._from_temp_dataset(dataset) -class DatasetWeighted(Weighted): +class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> "Dataset": return self.obj.map(func, dim=dim, **kwargs) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 3ead427e22e..fc84687511e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1639,6 +1639,16 @@ def test_swap_dims(self): expected.indexes[dim_name], actual.indexes[dim_name] ) + # as kwargs + array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") + expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") + actual = array.swap_dims(x="y") + assert_identical(expected, actual) + for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + pd.testing.assert_index_equal( + expected.indexes[dim_name], actual.indexes[dim_name] + ) + # multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x") @@ -2327,6 +2337,12 @@ def test_drop_index_labels(self): with pytest.warns(DeprecationWarning): arr.drop([0, 1, 3], dim="y", errors="ignore") + def test_drop_index_positions(self): + arr = DataArray(np.random.randn(2, 3), dims=["x", "y"]) + actual = arr.drop_isel(y=[0, 1]) + expected = arr[:, 2:] + 
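For context on the `Weighted._implementation` refactor above, a quick check of the public behaviour it must preserve (routing both DataArray and Dataset reductions through `Dataset.map` should not change the numbers):

```python
import xarray as xr

da = xr.DataArray([1.0, 2.0, 4.0], dims="x")
weights = xr.DataArray([1.0, 1.0, 2.0], dims="x")

weighted = da.weighted(weights)
print(weighted.sum("x").item())   # 1*1 + 1*2 + 2*4 = 11.0
print(weighted.mean("x").item())  # 11 / (1 + 1 + 2) = 2.75
```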
assert_identical(actual, expected) + def test_dropna(self): x = np.random.randn(4, 4) x[::2, 0] = np.nan diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bd1938455b1..db47faa8d2b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2371,8 +2371,12 @@ def test_drop_index_labels(self): data.drop(DataArray(["a", "b", "c"]), dim="x", errors="ignore") assert_identical(expected, actual) - with raises_regex(ValueError, "does not have coordinate labels"): - data.drop_sel(y=1) + actual = data.drop_sel(y=[1]) + expected = data.isel(y=[0, 2]) + assert_identical(expected, actual) + + with raises_regex(KeyError, "not found in axis"): + data.drop_sel(x=0) def test_drop_labels_by_keyword(self): data = Dataset( @@ -2410,6 +2414,34 @@ def test_drop_labels_by_keyword(self): with pytest.raises(ValueError): data.drop(dim="x", x="a") + def test_drop_labels_by_position(self): + data = Dataset( + {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)} + ) + # Basic functionality. + assert len(data.coords["x"]) == 2 + + actual = data.drop_isel(x=0) + expected = data.drop_sel(x="a") + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0]) + expected = data.drop_sel(x=["a"]) + assert_identical(expected, actual) + + actual = data.drop_isel(x=[0, 1]) + expected = data.drop_sel(x=["a", "b"]) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + actual = data.drop_isel(x=[0, 1], y=range(0, 6, 2)) + expected = data.drop_sel(x=["a", "b"], y=range(0, 6, 2)) + assert_identical(expected, actual) + assert actual.coords["x"].size == 0 + + with pytest.raises(KeyError): + data.drop_isel(z=1) + def test_drop_dims(self): data = xr.Dataset( { @@ -2716,6 +2748,13 @@ def test_swap_dims(self): actual = original.swap_dims({"x": "u"}) assert_identical(expected, actual) + # as kwargs + expected = Dataset( + {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])} + ) + actual = original.swap_dims(x="u") + assert_identical(expected, actual) + # handle multiindex case idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42}) @@ -6564,6 +6603,9 @@ def test_integrate(dask): with pytest.raises(ValueError): da.integrate("x2d") + with pytest.warns(FutureWarning): + da.integrate(dim="x") + @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 110ef47209f..64a1c563dba 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -6,19 +6,24 @@ from xarray.backends import common, plugins -def dummy_open_dataset_args(filename_or_obj, *args): - pass +class DummyBackendEntrypointArgs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, *args): + pass -def dummy_open_dataset_kwargs(filename_or_obj, **kwargs): - pass +class DummyBackendEntrypointKwargs(common.BackendEntrypoint): + def open_dataset(filename_or_obj, **kwargs): + pass -def dummy_open_dataset(filename_or_obj, *, decoder): - pass +class DummyBackendEntrypoint1(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass -dummy_cfgrib = common.BackendEntrypoint(dummy_open_dataset) +class DummyBackendEntrypoint2(common.BackendEntrypoint): + def open_dataset(self, filename_or_obj, *, decoder): + pass @pytest.fixture @@ -65,46 +70,48 @@ def test_create_engines_dict(): def 
test_set_missing_parameters(): - backend_1 = common.BackendEntrypoint(dummy_open_dataset) - backend_2 = common.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",)) + backend_1 = DummyBackendEntrypoint1 + backend_2 = DummyBackendEntrypoint2 + backend_2.open_dataset_parameters = ("filename_or_obj",) engines = {"engine_1": backend_1, "engine_2": backend_2} plugins.set_missing_parameters(engines) assert len(engines) == 2 - engine_1 = engines["engine_1"] - assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder") - engine_2 = engines["engine_2"] - assert engine_2.open_dataset_parameters == ("filename_or_obj",) + assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder") + assert backend_2.open_dataset_parameters == ("filename_or_obj",) + + backend = DummyBackendEntrypointKwargs() + backend.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend}) + assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") + + backend = DummyBackendEntrypointArgs() + backend.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend}) + assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") def test_set_missing_parameters_raise_error(): - backend = common.BackendEntrypoint(dummy_open_dataset_args) + backend = DummyBackendEntrypointKwargs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = common.BackendEntrypoint( - dummy_open_dataset_args, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - - backend = common.BackendEntrypoint(dummy_open_dataset_kwargs) + backend = DummyBackendEntrypointArgs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = plugins.BackendEntrypoint( - dummy_open_dataset_kwargs, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - -@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=dummy_cfgrib)) +@mock.patch( + "pkg_resources.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) def test_build_engines(): - dummy_cfgrib_pkg_entrypoint = pkg_resources.EntryPoint.parse( + dummy_pkg_entrypoint = pkg_resources.EntryPoint.parse( "cfgrib = xarray.tests.test_plugins:backend_1" ) - backend_entrypoints = plugins.build_engines([dummy_cfgrib_pkg_entrypoint]) - assert backend_entrypoints["cfgrib"] is dummy_cfgrib + backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1) assert backend_entrypoints["cfgrib"].open_dataset_parameters == ( "filename_or_obj", "decoder", diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index bb3127e90b5..76dd830de23 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3681,7 +3681,7 @@ def test_stacking_reordering(self, func, dtype): ( method("diff", dim="x"), method("differentiate", coord="x"), - method("integrate", dim="x"), + method("integrate", coord="x"), method("quantile", q=[0.25, 0.75]), method("reduce", func=np.sum, dim="x"), pytest.param(lambda x: x.dot(x), id="method_dot"),
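The reworked plugin tests above rely on class-based entrypoints whose missing `open_dataset_parameters` are inferred from the method signature. A standalone sketch of that idea, with stand-in classes rather than xarray's `BackendEntrypoint` or `plugins.set_missing_parameters`:

```python
import inspect


class BackendBase:
    # Hypothetical stand-in for a class-based backend entrypoint.
    open_dataset_parameters = None

    def open_dataset(self, filename_or_obj, **kwargs):
        raise NotImplementedError


class MyBackend(BackendBase):
    def open_dataset(self, filename_or_obj, *, decoder=None):
        return {"source": filename_or_obj, "decoder": decoder}


def infer_open_dataset_parameters(backend_cls):
    # If a backend does not declare its parameters explicitly, read the names
    # off the ``open_dataset`` signature, dropping ``self``.
    if backend_cls.open_dataset_parameters is None:
        sig = inspect.signature(backend_cls.open_dataset)
        backend_cls.open_dataset_parameters = tuple(
            name for name in sig.parameters if name != "self"
        )
    return backend_cls.open_dataset_parameters


print(infer_open_dataset_parameters(MyBackend))  # ('filename_or_obj', 'decoder')
```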