diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 09ef053bb39..c7ea19a53cb 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,11 +5,3 @@
- [ ] Passes `pre-commit run --all-files`
- [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst`
- [ ] New functions/methods are listed in `api.rst`
-
-
-
-
- Overriding CI behaviors
-
- By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a [skip-ci] tag to the first line of the commit message
-
diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py
new file mode 100644
index 00000000000..b218c0be870
--- /dev/null
+++ b/asv_bench/benchmarks/repr.py
@@ -0,0 +1,18 @@
+import pandas as pd
+
+import xarray as xr
+
+
+class ReprMultiIndex:
+ def setup(self):
+ index = pd.MultiIndex.from_product(
+ [range(10000), range(10000)], names=("level_0", "level_1")
+ )
+ series = pd.Series(range(100000000), index=index)
+ self.da = xr.DataArray(series)
+
+ def time_repr(self):
+ repr(self.da)
+
+ def time_repr_html(self):
+ self.da._repr_html_()
diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py
index 342475b96df..8d0c3932870 100644
--- a/asv_bench/benchmarks/unstacking.py
+++ b/asv_bench/benchmarks/unstacking.py
@@ -7,18 +7,23 @@
class Unstacking:
def setup(self):
- data = np.random.RandomState(0).randn(1, 1000, 500)
- self.ds = xr.DataArray(data).stack(flat_dim=["dim_1", "dim_2"])
+ data = np.random.RandomState(0).randn(500, 1000)
+ self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...])
+ self.da_missing = self.da_full[:-1]
+ self.df_missing = self.da_missing.to_pandas()
def time_unstack_fast(self):
- self.ds.unstack("flat_dim")
+ self.da_full.unstack("flat_dim")
def time_unstack_slow(self):
- self.ds[:, ::-1].unstack("flat_dim")
+ self.da_missing.unstack("flat_dim")
+
+ def time_unstack_pandas_slow(self):
+ self.df_missing.unstack()
class UnstackingDask(Unstacking):
def setup(self, *args, **kwargs):
requires_dask()
super().setup(**kwargs)
- self.ds = self.ds.chunk({"flat_dim": 50})
+ self.da_full = self.da_full.chunk({"flat_dim": 50})
diff --git a/doc/api.rst b/doc/api.rst
index ce866093db8..9add7a96109 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -126,6 +126,7 @@ Indexing
Dataset.isel
Dataset.sel
Dataset.drop_sel
+ Dataset.drop_isel
Dataset.head
Dataset.tail
Dataset.thin
@@ -308,6 +309,7 @@ Indexing
DataArray.isel
DataArray.sel
DataArray.drop_sel
+ DataArray.drop_isel
DataArray.head
DataArray.tail
DataArray.thin
diff --git a/doc/conf.py b/doc/conf.py
index d83e966f3fa..14b28b4e471 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -411,7 +411,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
- "iris": ("https://scitools.org.uk/iris/docs/latest", None),
+ "iris": ("https://scitools-iris.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
"numba": ("https://numba.pydata.org/numba-doc/latest", None),
diff --git a/doc/contributing.rst b/doc/contributing.rst
index 9c4ce5a0af2..439791cbbd6 100644
--- a/doc/contributing.rst
+++ b/doc/contributing.rst
@@ -836,6 +836,7 @@ PR checklist
- Write new tests if needed. See `"Test-driven development/code writing" `_.
- Test the code using `Pytest `_. Running all tests (type ``pytest`` in the root directory) takes a while, so feel free to only run the tests you think are needed based on your PR (example: ``pytest xarray/tests/test_dataarray.py``). CI will catch any failing tests.
+ - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a ``[test-upstream]`` tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a ``[skip-ci]`` tag to the first line of the commit message.
- **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit.
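For example, a documentation-only commit could opt out of CI with a first line like the following (hypothetical message, following the tags described above)::

    [skip-ci] Fix typo in contributing guide

and a commit that should exercise the upstream-dev builds would instead start its first line with ``[test-upstream]``.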
diff --git a/doc/faq.rst b/doc/faq.rst
index a2b8be47e06..a2151cc4b37 100644
--- a/doc/faq.rst
+++ b/doc/faq.rst
@@ -166,7 +166,7 @@ different approaches to handling metadata: Iris strictly interprets
`CF conventions`_. Iris particularly shines at mapping, thanks to its
integration with Cartopy_.
-.. _Iris: http://scitools.org.uk/iris/
+.. _Iris: https://scitools-iris.readthedocs.io/en/stable/
.. _Cartopy: http://scitools.org.uk/cartopy/docs/latest/
`UV-CDAT`__ is another Python library that implements in-memory netCDF-like
diff --git a/doc/related-projects.rst b/doc/related-projects.rst
index 456cb64197f..0a010195d6d 100644
--- a/doc/related-projects.rst
+++ b/doc/related-projects.rst
@@ -15,6 +15,7 @@ Geosciences
- `aospy `_: Automated analysis and management of gridded climate data.
- `climpred `_: Analysis of ensemble forecast models for climate prediction.
- `geocube `_: Tool to convert geopandas vector data into rasterized xarray data.
+- `GeoWombat `_: Utilities for analysis of remotely sensed and gridded raster data at scale (easily tame Landsat, Sentinel, Quickbird, and PlanetScope).
- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meterology data
- `marc_analysis `_: Analysis package for CESM/MARC experiments and output.
- `MetPy `_: A collection of tools in Python for reading, visualizing, and performing calculations with weather data.
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 502181487b9..bb81426e4dd 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -17,7 +17,7 @@ What's New
.. _whats-new.0.16.3:
-v0.16.3 (unreleased)
+v0.17.0 (unreleased)
--------------------
Breaking changes
@@ -39,16 +39,32 @@ Breaking changes
always be set such that ``int64`` values can be used. In the past, no units
finer than "seconds" were chosen, which would sometimes mean that ``float64``
values were required, which would lead to inaccurate I/O round-trips.
-- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull: `4725`).
- By `Aureliana Barghini `_
+- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`).
+ By `Aureliana Barghini `_.
+
+Deprecations
+~~~~~~~~~~~~
+
+- ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in
+ favour of a ``coord`` argument, for consistency with :py:meth:`Dataset.integrate`.
+ For now using ``dim`` issues a ``FutureWarning``. By `Tom Nicholas `_.
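As a minimal sketch of the migration described in this entry (hypothetical array; assumes the xarray version containing this PR)::

    import xarray as xr

    da = xr.DataArray([1.0, 2.0, 4.0], dims="x", coords={"x": [0.0, 0.1, 0.3]})
    da.integrate(coord="x")  # preferred spelling
    da.integrate(dim="x")    # still works, but now issues a FutureWarning
    da.integrate("x")        # positional use is unaffected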
New Features
~~~~~~~~~~~~
+- Significantly higher ``unstack`` performance on numpy-backed arrays which
+ contain missing values; 8x faster in our benchmark, and 2x faster than pandas.
+ (:pull:`4746`);
+ By `Maximilian Roos `_.
+
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
By `Deepak Cherian `_
- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables.
By `Deepak Cherian `_
+ By `Deepak Cherian `_.
+- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims
+ in the form of kwargs as well as a dict, like most similar methods.
+ By `Maximilian Roos `_.
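For the ``unstack`` and ``swap_dims`` entries above, a minimal sketch (hypothetical arrays, mirroring the updated ``asv`` benchmarks in this PR)::

    import numpy as np
    import xarray as xr

    # swap_dims now also accepts keyword arguments
    ds = xr.Dataset({"a": ("x", [0, 1])}, coords={"y": ("x", ["u", "v"])})
    ds.swap_dims(x="y")  # equivalent to ds.swap_dims({"x": "y"})

    # the faster unstack path applies to numpy-backed data with missing values,
    # e.g. a stacked array with one row dropped
    da = xr.DataArray(np.random.RandomState(0).randn(500, 1000), dims=list("ab"))
    stacked = da.stack(flat_dim=[...])
    stacked[:-1].unstack("flat_dim")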
Bug fixes
~~~~~~~~~
@@ -82,6 +98,7 @@ Bug fixes
- Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and
:py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`).
By `Julien Seguinot `_.
+- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` (:issue:`4658`, :pull:`4819`). By `Daniel Mesejo `_.
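A minimal sketch of the new method (hypothetical dataset, mirroring the docstring example added in this PR)::

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"A": (["x", "y"], np.arange(6).reshape(2, 3)), "y": ["a", "b", "c"]}
    )
    ds.drop_isel(y=[0, 2])     # drop by integer position along "y"
    ds.drop_sel(y=["a", "c"])  # equivalent label-based call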
Documentation
~~~~~~~~~~~~~
@@ -110,6 +127,8 @@ Internal Changes
By `Maximilian Roos `_.
- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion
in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_.
+- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends to specify how to voluntarily release
+ all resources (:pull:`4809`). By `Alessandro Amici `_.
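For backend authors, a minimal sketch of the new hook (hypothetical resource object; assumes the xarray version containing this PR)::

    import xarray as xr

    class DummyResource:
        # hypothetical stand-in for a file handle or data store
        def close(self):
            print("resource released")

    resource = DummyResource()
    ds = xr.Dataset({"a": ("x", [1, 2, 3])})
    ds.set_close(resource.close)  # register the cleanup callable
    ds.close()                    # calls resource.close() once, then clears the hook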
.. _whats-new.0.16.2:
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 4958062a262..81314588784 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks):
else:
ds2 = ds
- ds2._file_obj = ds._file_obj
+ ds2.set_close(ds._close)
return ds2
filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,7 +701,7 @@ def open_dataarray(
else:
(data_array,) = dataset.data_vars.values()
- data_array._file_obj = dataset._file_obj
+ data_array.set_close(dataset._close)
# Reset names if they were changed during saving
# to ensure that we can 'roundtrip' perfectly
@@ -715,17 +715,6 @@ def open_dataarray(
return data_array
-class _MultiFileCloser:
- __slots__ = ("file_objs",)
-
- def __init__(self, file_objs):
- self.file_objs = file_objs
-
- def close(self):
- for f in self.file_objs:
- f.close()
-
-
def open_mfdataset(
paths,
chunks=None,
@@ -918,14 +907,14 @@ def open_mfdataset(
getattr_ = getattr
datasets = [open_(p, **open_kwargs) for p in paths]
- file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
+ closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]
if parallel:
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
- datasets, file_objs = dask.compute(datasets, file_objs)
+ datasets, closers = dask.compute(datasets, closers)
# Combine all datasets, closing them in case of a ValueError
try:
@@ -963,7 +952,11 @@ def open_mfdataset(
ds.close()
raise
- combined._file_obj = _MultiFileCloser(file_objs)
+ def multi_file_closer():
+ for closer in closers:
+ closer()
+
+ combined.set_close(multi_file_closer)
# read global attributes from the attrs_file or from the first dataset
if attrs_file is not None:
diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index 0f98291983d..d31fc9ea773 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -90,7 +90,7 @@ def _dataset_from_backend_dataset(
**extra_tokens,
)
- ds._file_obj = backend_ds._file_obj
+ ds.set_close(backend_ds._close)
# Ensure source filename always stored in dataset object (GH issue #2550)
if "source" not in ds.encoding:
diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index d4933e370c7..65c5bc2a02b 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -5,9 +5,22 @@
from ..core import indexing
from ..core.utils import Frozen, FrozenDict, close_on_error
from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendArray,
+ BackendEntrypoint,
+)
from .locks import SerializableLock, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import cfgrib
+
+ has_cfgrib = True
+except ModuleNotFoundError:
+ has_cfgrib = False
+
# FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe
# in most circumstances. See:
@@ -38,7 +51,6 @@ class CfGribDataStore(AbstractDataStore):
"""
def __init__(self, filename, lock=None, **backend_kwargs):
- import cfgrib
if lock is None:
lock = ECCODES_LOCK
@@ -74,58 +86,58 @@ def get_encoding(self):
return encoding
-def guess_can_open_cfgrib(store_spec):
- try:
- _, ext = os.path.splitext(store_spec)
- except TypeError:
- return False
- return ext in {".grib", ".grib2", ".grb", ".grb2"}
-
-
-def open_backend_dataset_cfgrib(
- filename_or_obj,
- *,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- lock=None,
- indexpath="{path}.{short_hash}.idx",
- filter_by_keys={},
- read_keys=[],
- encode_cf=("parameter", "time", "geography", "vertical"),
- squeeze=True,
- time_dims=("time", "step"),
-):
-
- store = CfGribDataStore(
+class CfgribfBackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ try:
+ _, ext = os.path.splitext(store_spec)
+ except TypeError:
+ return False
+ return ext in {".grib", ".grib2", ".grb", ".grb2"}
+
+ def open_dataset(
+ self,
filename_or_obj,
- indexpath=indexpath,
- filter_by_keys=filter_by_keys,
- read_keys=read_keys,
- encode_cf=encode_cf,
- squeeze=squeeze,
- time_dims=time_dims,
- lock=lock,
- )
-
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ *,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ lock=None,
+ indexpath="{path}.{short_hash}.idx",
+ filter_by_keys={},
+ read_keys=[],
+ encode_cf=("parameter", "time", "geography", "vertical"),
+ squeeze=True,
+ time_dims=("time", "step"),
+ ):
+
+ store = CfGribDataStore(
+ filename_or_obj,
+ indexpath=indexpath,
+ filter_by_keys=filter_by_keys,
+ read_keys=read_keys,
+ encode_cf=encode_cf,
+ squeeze=squeeze,
+ time_dims=time_dims,
+ lock=lock,
)
- return ds
-
-
-cfgrib_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
-)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
+
+if has_cfgrib:
+ BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 72a63957662..e2905d0866b 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -1,6 +1,7 @@
import logging
import time
import traceback
+from typing import Dict, Tuple, Type, Union
import numpy as np
@@ -343,9 +344,13 @@ def encode(self, variables, attributes):
class BackendEntrypoint:
- __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters")
+ open_dataset_parameters: Union[Tuple, None] = None
- def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
- self.open_dataset = open_dataset
- self.open_dataset_parameters = open_dataset_parameters
- self.guess_can_open = guess_can_open
+ def open_dataset(self):
+ raise NotImplementedError
+
+ def guess_can_open(self, store_spec):
+ return False
+
+
+BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {}
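With the class-based API above, each built-in backend follows the same pattern visible throughout this diff: subclass ``BackendEntrypoint``, implement ``open_dataset`` (plus ``guess_can_open`` where a cheap check exists), and register the class in ``BACKEND_ENTRYPOINTS`` only when the underlying library imports. A minimal sketch with a hypothetical ``mybackend`` (names invented for illustration, not part of this PR)::

    import xarray as xr
    from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint

    class MyBackendEntrypoint(BackendEntrypoint):
        def guess_can_open(self, store_spec):
            # only claim paths with a made-up extension
            return isinstance(store_spec, str) and store_spec.endswith(".mydata")

        def open_dataset(self, filename_or_obj, *, drop_variables=None):
            # a real backend would open filename_or_obj here; this sketch
            # returns a constant dataset just to keep the example runnable
            ds = xr.Dataset({"x": ("dim0", [1, 2, 3])})
            if drop_variables:
                ds = ds.drop_vars(drop_variables)
            return ds

    # built-in backends register their class like this, guarded by has_<lib>;
    # third-party packages would instead expose an entrypoint via pkg_resources
    BACKEND_ENTRYPOINTS["mybackend"] = MyBackendEntrypoint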
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index b2996369ee7..aa892c4f89c 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -8,7 +8,12 @@
from ..core import indexing
from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
from ..core.variable import Variable
-from .common import BackendEntrypoint, WritableCFDataStore, find_root_and_group
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ BackendEntrypoint,
+ WritableCFDataStore,
+ find_root_and_group,
+)
from .file_manager import CachingFileManager, DummyFileManager
from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
from .netCDF4_ import (
@@ -18,7 +23,14 @@
_get_datatype,
_nc4_require_group,
)
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import h5netcdf
+
+ has_h5netcdf = True
+except ModuleNotFoundError:
+ has_h5netcdf = False
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
@@ -85,8 +97,6 @@ class H5NetCDFStore(WritableCFDataStore):
def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False):
- import h5netcdf
-
if isinstance(manager, (h5netcdf.File, h5netcdf.Group)):
if group is None:
root, group = find_root_and_group(manager)
@@ -122,7 +132,6 @@ def open(
invalid_netcdf=None,
phony_dims=None,
):
- import h5netcdf
if isinstance(filename, bytes):
raise ValueError(
@@ -319,59 +328,61 @@ def close(self, **kwargs):
self._manager.close(**kwargs)
-def guess_can_open_h5netcdf(store_spec):
- try:
- return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
- except TypeError:
- pass
-
- try:
- _, ext = os.path.splitext(store_spec)
- except TypeError:
- return False
-
- return ext in {".nc", ".nc4", ".cdf"}
-
-
-def open_backend_dataset_h5netcdf(
- filename_or_obj,
- *,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- format=None,
- group=None,
- lock=None,
- invalid_netcdf=None,
- phony_dims=None,
-):
-
- store = H5NetCDFStore.open(
+class H5netcdfBackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ try:
+ return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
+ except TypeError:
+ pass
+
+ try:
+ _, ext = os.path.splitext(store_spec)
+ except TypeError:
+ return False
+
+ return ext in {".nc", ".nc4", ".cdf"}
+
+ def open_dataset(
+ self,
filename_or_obj,
- format=format,
- group=group,
- lock=lock,
- invalid_netcdf=invalid_netcdf,
- phony_dims=phony_dims,
- )
+ *,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ format=None,
+ group=None,
+ lock=None,
+ invalid_netcdf=None,
+ phony_dims=None,
+ ):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
- )
- return ds
+ store = H5NetCDFStore.open(
+ filename_or_obj,
+ format=format,
+ group=group,
+ lock=lock,
+ invalid_netcdf=invalid_netcdf,
+ phony_dims=phony_dims,
+ )
+ store_entrypoint = StoreBackendEntrypoint()
+
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
-h5netcdf_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
-)
+
+if has_h5netcdf:
+ BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 0e35270ea9a..e3d87aaf83f 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -12,6 +12,7 @@
from ..core.utils import FrozenDict, close_on_error, is_remote_uri
from ..core.variable import Variable
from .common import (
+ BACKEND_ENTRYPOINTS,
BackendArray,
BackendEntrypoint,
WritableCFDataStore,
@@ -21,7 +22,15 @@
from .file_manager import CachingFileManager, DummyFileManager
from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock
from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import netCDF4
+
+ has_netcdf4 = True
+except ModuleNotFoundError:
+ has_netcdf4 = False
+
# This lookup table maps from dtype.byteorder to a readable endian
# string used by netCDF4.
@@ -298,7 +307,6 @@ class NetCDF4DataStore(WritableCFDataStore):
def __init__(
self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False
):
- import netCDF4
if isinstance(manager, netCDF4.Dataset):
if group is None:
@@ -335,7 +343,6 @@ def open(
lock_maker=None,
autoclose=False,
):
- import netCDF4
if isinstance(filename, pathlib.Path):
filename = os.fspath(filename)
@@ -505,61 +512,62 @@ def close(self, **kwargs):
self._manager.close(**kwargs)
-def guess_can_open_netcdf4(store_spec):
- if isinstance(store_spec, str) and is_remote_uri(store_spec):
- return True
- try:
- _, ext = os.path.splitext(store_spec)
- except TypeError:
- return False
- return ext in {".nc", ".nc4", ".cdf"}
-
-
-def open_backend_dataset_netcdf4(
- filename_or_obj,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- group=None,
- mode="r",
- format="NETCDF4",
- clobber=True,
- diskless=False,
- persist=False,
- lock=None,
- autoclose=False,
-):
+class NetCDF4BackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ if isinstance(store_spec, str) and is_remote_uri(store_spec):
+ return True
+ try:
+ _, ext = os.path.splitext(store_spec)
+ except TypeError:
+ return False
+ return ext in {".nc", ".nc4", ".cdf"}
- store = NetCDF4DataStore.open(
+ def open_dataset(
+ self,
filename_or_obj,
- mode=mode,
- format=format,
- group=group,
- clobber=clobber,
- diskless=diskless,
- persist=persist,
- lock=lock,
- autoclose=autoclose,
- )
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group=None,
+ mode="r",
+ format="NETCDF4",
+ clobber=True,
+ diskless=False,
+ persist=False,
+ lock=None,
+ autoclose=False,
+ ):
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ store = NetCDF4DataStore.open(
+ filename_or_obj,
+ mode=mode,
+ format=format,
+ group=group,
+ clobber=clobber,
+ diskless=diskless,
+ persist=persist,
+ lock=lock,
+ autoclose=autoclose,
)
- return ds
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
-netcdf4_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4
-)
+
+if has_netcdf4:
+ BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index d5799a78f91..b8cd2bf6378 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -2,33 +2,11 @@
import inspect
import itertools
import logging
-import typing as T
import warnings
import pkg_resources
-from .cfgrib_ import cfgrib_backend
-from .common import BackendEntrypoint
-from .h5netcdf_ import h5netcdf_backend
-from .netCDF4_ import netcdf4_backend
-from .pseudonetcdf_ import pseudonetcdf_backend
-from .pydap_ import pydap_backend
-from .pynio_ import pynio_backend
-from .scipy_ import scipy_backend
-from .store import store_backend
-from .zarr import zarr_backend
-
-BACKEND_ENTRYPOINTS: T.Dict[str, BackendEntrypoint] = {
- "store": store_backend,
- "netcdf4": netcdf4_backend,
- "h5netcdf": h5netcdf_backend,
- "scipy": scipy_backend,
- "pseudonetcdf": pseudonetcdf_backend,
- "zarr": zarr_backend,
- "cfgrib": cfgrib_backend,
- "pydap": pydap_backend,
- "pynio": pynio_backend,
-}
+from .common import BACKEND_ENTRYPOINTS
def remove_duplicates(backend_entrypoints):
@@ -58,6 +36,7 @@ def remove_duplicates(backend_entrypoints):
def detect_parameters(open_dataset):
signature = inspect.signature(open_dataset)
parameters = signature.parameters
+ parameters_list = []
for name, param in parameters.items():
if param.kind in (
inspect.Parameter.VAR_KEYWORD,
@@ -67,7 +46,9 @@ def detect_parameters(open_dataset):
f"All the parameters in {open_dataset!r} signature should be explicit. "
"*args and **kwargs is not supported"
)
- return tuple(parameters)
+ if name != "self":
+ parameters_list.append(name)
+ return tuple(parameters_list)
def create_engines_dict(backend_entrypoints):
@@ -79,8 +60,8 @@ def create_engines_dict(backend_entrypoints):
return engines
-def set_missing_parameters(engines):
- for name, backend in engines.items():
+def set_missing_parameters(backend_entrypoints):
+ for name, backend in backend_entrypoints.items():
if backend.open_dataset_parameters is None:
open_dataset = backend.open_dataset
backend.open_dataset_parameters = detect_parameters(open_dataset)
@@ -92,7 +73,10 @@ def build_engines(entrypoints):
external_backend_entrypoints = create_engines_dict(pkg_entrypoints)
backend_entrypoints.update(external_backend_entrypoints)
set_missing_parameters(backend_entrypoints)
- return backend_entrypoints
+ engines = {}
+ for name, backend in backend_entrypoints.items():
+ engines[name] = backend()
+ return engines
@functools.lru_cache(maxsize=1)
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index d9128d1d503..80485fce459 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -3,10 +3,23 @@
from ..core import indexing
from ..core.utils import Frozen, FrozenDict, close_on_error
from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendArray,
+ BackendEntrypoint,
+)
from .file_manager import CachingFileManager
from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ from PseudoNetCDF import pncopen
+
+ has_pseudonetcdf = True
+except ModuleNotFoundError:
+ has_pseudonetcdf = False
+
# psuedonetcdf can invoke netCDF libraries internally
PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK])
@@ -40,7 +53,6 @@ class PseudoNetCDFDataStore(AbstractDataStore):
@classmethod
def open(cls, filename, lock=None, mode=None, **format_kwargs):
- from PseudoNetCDF import pncopen
keywords = {"kwargs": format_kwargs}
# only include mode if explicitly passed
@@ -88,53 +100,55 @@ def close(self):
self._manager.close()
-def open_backend_dataset_pseudonetcdf(
- filename_or_obj,
- mask_and_scale=False,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- mode=None,
- lock=None,
- **format_kwargs,
-):
-
- store = PseudoNetCDFDataStore.open(
- filename_or_obj, lock=lock, mode=mode, **format_kwargs
+class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
+
+ # *args and **kwargs are not allowed in open_backend_dataset_ kwargs,
+ # unless the open_dataset_parameters are explicitly defined like this:
+ open_dataset_parameters = (
+ "filename_or_obj",
+ "mask_and_scale",
+ "decode_times",
+ "concat_characters",
+ "decode_coords",
+ "drop_variables",
+ "use_cftime",
+ "decode_timedelta",
+ "mode",
+ "lock",
)
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ def open_dataset(
+ self,
+ filename_or_obj,
+ mask_and_scale=False,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ mode=None,
+ lock=None,
+ **format_kwargs,
+ ):
+ store = PseudoNetCDFDataStore.open(
+ filename_or_obj, lock=lock, mode=mode, **format_kwargs
)
- return ds
-
-
-# *args and **kwargs are not allowed in open_backend_dataset_ kwargs,
-# unless the open_dataset_parameters are explicity defined like this:
-open_dataset_parameters = (
- "filename_or_obj",
- "mask_and_scale",
- "decode_times",
- "concat_characters",
- "decode_coords",
- "drop_variables",
- "use_cftime",
- "decode_timedelta",
- "mode",
- "lock",
-)
-pseudonetcdf_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_pseudonetcdf,
- open_dataset_parameters=open_dataset_parameters,
-)
+
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
+
+if has_pseudonetcdf:
+ BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 4995045a739..7f8622ca66e 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -4,8 +4,21 @@
from ..core.pycompat import integer_types
from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri
from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem
-from .store import open_backend_dataset_store
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendArray,
+ BackendEntrypoint,
+ robust_getitem,
+)
+from .store import StoreBackendEntrypoint
+
+try:
+ import pydap.client
+
+ has_pydap = True
+except ModuleNotFoundError:
+ has_pydap = False
class PydapArrayWrapper(BackendArray):
@@ -74,7 +87,6 @@ def __init__(self, ds):
@classmethod
def open(cls, url, session=None):
- import pydap.client
ds = pydap.client.open_url(url, session=session)
return cls(ds)
@@ -95,41 +107,41 @@ def get_dimensions(self):
return Frozen(self.ds.dimensions)
-def guess_can_open_pydap(store_spec):
- return isinstance(store_spec, str) and is_remote_uri(store_spec)
-
+class PydapBackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ return isinstance(store_spec, str) and is_remote_uri(store_spec)
-def open_backend_dataset_pydap(
- filename_or_obj,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- session=None,
-):
-
- store = PydapDataStore.open(
+ def open_dataset(
+ self,
filename_or_obj,
- session=session,
- )
-
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ session=None,
+ ):
+ store = PydapDataStore.open(
+ filename_or_obj,
+ session=session,
)
- return ds
+
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
-pydap_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
-)
+if has_pydap:
+ BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index dc6c47935e8..41c99efd076 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -3,10 +3,23 @@
from ..core import indexing
from ..core.utils import Frozen, FrozenDict, close_on_error
from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendArray,
+ BackendEntrypoint,
+)
from .file_manager import CachingFileManager
from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import Nio
+
+ has_pynio = True
+except ModuleNotFoundError:
+ has_pynio = False
+
# PyNIO can invoke netCDF libraries internally
# Add a dedicated lock just in case NCL as well isn't thread-safe.
@@ -45,7 +58,6 @@ class NioDataStore(AbstractDataStore):
"""Store for accessing datasets via PyNIO"""
def __init__(self, filename, mode="r", lock=None, **kwargs):
- import Nio
if lock is None:
lock = PYNIO_LOCK
@@ -85,37 +97,39 @@ def close(self):
self._manager.close()
-def open_backend_dataset_pynio(
- filename_or_obj,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- mode="r",
- lock=None,
-):
-
- store = NioDataStore(
+class PynioBackendEntrypoint(BackendEntrypoint):
+ def open_dataset(
+ self,
filename_or_obj,
- mode=mode,
- lock=lock,
- )
-
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ mode="r",
+ lock=None,
+ ):
+ store = NioDataStore(
+ filename_or_obj,
+ mode=mode,
+ lock=lock,
)
- return ds
-
-pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
+
+if has_pynio:
+ BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index a0500c7e1c2..c689c1e99d7 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
result = result.chunk(chunks, name_prefix=name_prefix, token=token)
# Make the file closeable
- result._file_obj = manager
+ result.set_close(manager.close)
return result
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index 873a91f9c07..ddc157ed8e4 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -6,11 +6,23 @@
from ..core.indexing import NumpyIndexingAdapter
from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
from ..core.variable import Variable
-from .common import BackendArray, BackendEntrypoint, WritableCFDataStore
+from .common import (
+ BACKEND_ENTRYPOINTS,
+ BackendArray,
+ BackendEntrypoint,
+ WritableCFDataStore,
+)
from .file_manager import CachingFileManager, DummyFileManager
from .locks import ensure_lock, get_write_lock
from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import scipy.io
+
+ has_scipy = True
+except ModuleNotFoundError:
+ has_scipy = False
def _decode_string(s):
@@ -61,8 +73,6 @@ def __setitem__(self, key, value):
def _open_scipy_netcdf(filename, mode, mmap, version):
import gzip
- import scipy.io
-
# if the string ends with .gz, then gunzip and open as netcdf file
if isinstance(filename, str) and filename.endswith(".gz"):
try:
@@ -222,52 +232,54 @@ def close(self):
self._manager.close()
-def guess_can_open_scipy(store_spec):
- try:
- return read_magic_number(store_spec).startswith(b"CDF")
- except TypeError:
- pass
+class ScipyBackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ try:
+ return read_magic_number(store_spec).startswith(b"CDF")
+ except TypeError:
+ pass
- try:
- _, ext = os.path.splitext(store_spec)
- except TypeError:
- return False
- return ext in {".nc", ".nc4", ".cdf", ".gz"}
-
-
-def open_backend_dataset_scipy(
- filename_or_obj,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- mode="r",
- format=None,
- group=None,
- mmap=None,
- lock=None,
-):
-
- store = ScipyDataStore(
- filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
- )
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ try:
+ _, ext = os.path.splitext(store_spec)
+ except TypeError:
+ return False
+ return ext in {".nc", ".nc4", ".cdf", ".gz"}
+
+ def open_dataset(
+ self,
+ filename_or_obj,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ mode="r",
+ format=None,
+ group=None,
+ mmap=None,
+ lock=None,
+ ):
+
+ store = ScipyDataStore(
+ filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
)
- return ds
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
-scipy_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy
-)
+
+if has_scipy:
+ BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index d314a9c3ca9..d57b3ab9df8 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,47 +1,45 @@
from .. import conventions
from ..core.dataset import Dataset
-from .common import AbstractDataStore, BackendEntrypoint
-
-
-def guess_can_open_store(store_spec):
- return isinstance(store_spec, AbstractDataStore)
-
-
-def open_backend_dataset_store(
- store,
- *,
- mask_and_scale=True,
- decode_times=True,
- concat_characters=True,
- decode_coords=True,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
-):
- vars, attrs = store.load()
- file_obj = store
- encoding = store.get_encoding()
-
- vars, attrs, coord_names = conventions.decode_cf_variables(
- vars,
- attrs,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
- )
-
- ds = Dataset(vars, attrs=attrs)
- ds = ds.set_coords(coord_names.intersection(vars))
- ds._file_obj = file_obj
- ds.encoding = encoding
-
- return ds
-
-
-store_backend = BackendEntrypoint(
- open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store
-)
+from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint
+
+
+class StoreBackendEntrypoint(BackendEntrypoint):
+ def guess_can_open(self, store_spec):
+ return isinstance(store_spec, AbstractDataStore)
+
+ def open_dataset(
+ self,
+ store,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ ):
+ vars, attrs = store.load()
+ encoding = store.get_encoding()
+
+ vars, attrs, coord_names = conventions.decode_cf_variables(
+ vars,
+ attrs,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+
+ ds = Dataset(vars, attrs=attrs)
+ ds = ds.set_coords(coord_names.intersection(vars))
+ ds.set_close(store.close)
+ ds.encoding = encoding
+
+ return ds
+
+
+BACKEND_ENTRYPOINTS["store"] = StoreBackendEntrypoint
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 3b4b3a3d9d5..1d667a38b53 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -9,12 +9,21 @@
from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error
from ..core.variable import Variable
from .common import (
+ BACKEND_ENTRYPOINTS,
AbstractWritableDataStore,
BackendArray,
BackendEntrypoint,
_encode_variable_name,
)
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
+
+try:
+ import zarr
+
+ has_zarr = True
+except ModuleNotFoundError:
+ has_zarr = False
+
# need some special secret attributes to tell us the dimensions
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
@@ -289,7 +298,6 @@ def open_group(
append_dim=None,
write_region=None,
):
- import zarr
# zarr doesn't support pathlib.Path objects yet. zarr-python#601
if isinstance(store, pathlib.Path):
@@ -409,7 +417,6 @@ def store(
dimension on which the zarray will be appended
only needed in append mode
"""
- import zarr
existing_variables = {
vn for vn in variables if _encode_variable_name(vn) in self.ds
@@ -663,45 +670,48 @@ def open_zarr(
return ds
-def open_backend_dataset_zarr(
- filename_or_obj,
- mask_and_scale=True,
- decode_times=None,
- concat_characters=None,
- decode_coords=None,
- drop_variables=None,
- use_cftime=None,
- decode_timedelta=None,
- group=None,
- mode="r",
- synchronizer=None,
- consolidated=False,
- consolidate_on_close=False,
- chunk_store=None,
-):
-
- store = ZarrStore.open_group(
+class ZarrBackendEntrypoint(BackendEntrypoint):
+ def open_dataset(
+ self,
filename_or_obj,
- group=group,
- mode=mode,
- synchronizer=synchronizer,
- consolidated=consolidated,
- consolidate_on_close=consolidate_on_close,
- chunk_store=chunk_store,
- )
-
- with close_on_error(store):
- ds = open_backend_dataset_store(
- store,
- mask_and_scale=mask_and_scale,
- decode_times=decode_times,
- concat_characters=concat_characters,
- decode_coords=decode_coords,
- drop_variables=drop_variables,
- use_cftime=use_cftime,
- decode_timedelta=decode_timedelta,
+ mask_and_scale=True,
+ decode_times=None,
+ concat_characters=None,
+ decode_coords=None,
+ drop_variables=None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group=None,
+ mode="r",
+ synchronizer=None,
+ consolidated=False,
+ consolidate_on_close=False,
+ chunk_store=None,
+ ):
+ store = ZarrStore.open_group(
+ filename_or_obj,
+ group=group,
+ mode=mode,
+ synchronizer=synchronizer,
+ consolidated=consolidated,
+ consolidate_on_close=consolidate_on_close,
+ chunk_store=chunk_store,
)
- return ds
+
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
-zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr)
+if has_zarr:
+ BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint
diff --git a/xarray/conventions.py b/xarray/conventions.py
index bb0b92c77a1..e33ae53b31d 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -576,12 +576,12 @@ def decode_cf(
vars = obj._variables
attrs = obj.attrs
extra_coords = set(obj.coords)
- file_obj = obj._file_obj
+ close = obj._close
encoding = obj.encoding
elif isinstance(obj, AbstractDataStore):
vars, attrs = obj.load()
extra_coords = set()
- file_obj = obj
+ close = obj.close
encoding = obj.get_encoding()
else:
raise TypeError("can only decode Dataset or DataStore objects")
@@ -599,7 +599,7 @@ def decode_cf(
)
ds = Dataset(vars, attrs=attrs)
ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
- ds._file_obj = file_obj
+ ds.set_close(close)
ds.encoding = encoding
return ds
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 283114770cf..c5836c68759 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -3,6 +3,7 @@
from html import escape
from textwrap import dedent
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -11,6 +12,7 @@
Iterator,
List,
Mapping,
+ Optional,
Tuple,
TypeVar,
Union,
@@ -31,6 +33,12 @@
ALL_DIMS = ...
+if TYPE_CHECKING:
+ from .dataarray import DataArray
+ from .weighted import Weighted
+
+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
C = TypeVar("C")
T = TypeVar("T")
@@ -330,7 +338,9 @@ def get_squeeze_dims(
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""
- __slots__ = ()
+ _close: Optional[Callable[[], None]]
+
+ __slots__ = ("_close",)
_rolling_exp_cls = RollingExp
@@ -769,7 +779,9 @@ def groupby_bins(
},
)
- def weighted(self, weights):
+ def weighted(
+ self: T_DataWithCoords, weights: "DataArray"
+ ) -> "Weighted[T_DataWithCoords]":
"""
Weighted operations.
@@ -1263,11 +1275,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
return ops.where_method(self, cond, other)
+ def set_close(self, close: Optional[Callable[[], None]]) -> None:
+ """Register the function that releases any resources linked to this object.
+
+ This method controls how xarray cleans up resources associated
+ with this object when the ``.close()`` method is called. It is mostly
+ intended for backend developers and it is rarely needed by regular
+ end-users.
+
+ Parameters
+ ----------
+ close : callable
+ The function that when called like ``close()`` releases
+ any resources linked to this object.
+ """
+ self._close = close
+
def close(self: Any) -> None:
- """Close any files linked to this object"""
- if self._file_obj is not None:
- self._file_obj.close()
- self._file_obj = None
+ """Release any resources linked to this object."""
+ if self._close is not None:
+ self._close()
+ self._close = None
def isnull(self, keep_attrs: bool = None):
"""Test each value in the array for whether it is a missing value.
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 6fdda8fc418..0155cdc4e19 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords):
_cache: Dict[str, Any]
_coords: Dict[Any, Variable]
+ _close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_name: Optional[Hashable]
_variable: Variable
@@ -351,7 +352,7 @@ class DataArray(AbstractArray, DataWithCoords):
__slots__ = (
"_cache",
"_coords",
- "_file_obj",
+ "_close",
"_indexes",
"_name",
"_variable",
@@ -421,7 +422,7 @@ def __init__(
# public interface.
self._indexes = indexes
- self._file_obj = None
+ self._close = None
def _replace(
self,
@@ -1698,7 +1699,9 @@ def rename(
new_name_or_name_dict = cast(Hashable, new_name_or_name_dict)
return self._replace(name=new_name_or_name_dict)
- def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
+ def swap_dims(
+ self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
+ ) -> "DataArray":
"""Returns a new DataArray with swapped dimensions.
Parameters
@@ -1707,6 +1710,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
Dictionary whose keys are current dimension names and whose values
are new names.
+ **dim_kwargs : {dim: new_dim, ...}, optional
+ The keyword arguments form of ``dims_dict``.
+ One of dims_dict or dims_kwargs must be provided.
+
Returns
-------
swapped : DataArray
@@ -1748,6 +1755,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
DataArray.rename
Dataset.swap_dims
"""
+ dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)
@@ -2247,6 +2255,28 @@ def drop_sel(
ds = self._to_temp_dataset().drop_sel(labels, errors=errors)
return self._from_temp_dataset(ds)
+ def drop_isel(self, indexers=None, **indexers_kwargs):
+ """Drop index positions from this DataArray.
+
+ Parameters
+ ----------
+ indexers : mapping of hashable to Any
+ Index locations to drop
+ **indexers_kwargs : {dim: position, ...}, optional
+ The keyword arguments form of ``dim`` and ``positions``
+
+ Returns
+ -------
+ dropped : DataArray
+
+ Raises
+ ------
+ IndexError
+ """
+ dataset = self._to_temp_dataset()
+ dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs)
+ return self._from_temp_dataset(dataset)
+
def dropna(
self, dim: Hashable, how: str = "any", thresh: int = None
) -> "DataArray":
@@ -3451,21 +3481,26 @@ def differentiate(
return self._from_temp_dataset(ds)
def integrate(
- self, dim: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None
+ self,
+ coord: Union[Hashable, Sequence[Hashable]] = None,
+ datetime_unit: str = None,
+ *,
+ dim: Union[Hashable, Sequence[Hashable]] = None,
) -> "DataArray":
- """ integrate the array with the trapezoidal rule.
+ """Integrate along the given coordinate using the trapezoidal rule.
.. note::
- This feature is limited to simple cartesian geometry, i.e. dim
+ This feature is limited to simple cartesian geometry, i.e. coord
must be one dimensional.
Parameters
----------
+ coord: hashable, or a sequence of hashable
+ Coordinate(s) used for the integration.
dim : hashable, or sequence of hashable
Coordinate(s) used for the integration.
- datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \
- "ps", "fs", "as"}, optional
- Can be used to specify the unit if datetime coordinate is used.
+ datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as'}, optional
+ Specify the unit if a datetime coordinate is used.
Returns
-------
@@ -3473,6 +3508,7 @@ def integrate(
See also
--------
+ Dataset.integrate
numpy.trapz: corresponding numpy function
Examples
@@ -3498,7 +3534,22 @@ def integrate(
array([5.4, 6.6, 7.8])
Dimensions without coordinates: y
"""
- ds = self._to_temp_dataset().integrate(dim, datetime_unit)
+ if dim is not None and coord is not None:
+ raise ValueError(
+ "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead."
+ )
+
+ if dim is not None and coord is None:
+ coord = dim
+ msg = (
+ "The `dim` keyword argument to `DataArray.integrate` is "
+ "being replaced with `coord`, for consistency with "
+ "`Dataset.integrate`. Please pass `coord` instead."
+ " `dim` will be removed in version 0.19.0."
+ )
+ warnings.warn(msg, FutureWarning, stacklevel=2)
+
+ ds = self._to_temp_dataset().integrate(coord, datetime_unit)
return self._from_temp_dataset(ds)
def unify_chunks(self) -> "DataArray":
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 7edc2fab067..8376b4875f9 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -4,6 +4,7 @@
import sys
import warnings
from collections import defaultdict
+from distutils.version import LooseVersion
from html import escape
from numbers import Number
from operator import methodcaller
@@ -79,7 +80,7 @@
)
from .missing import get_clean_interp_index
from .options import OPTIONS, _get_keep_attrs
-from .pycompat import is_duck_dask_array
+from .pycompat import is_duck_dask_array, sparse_array_type
from .utils import (
Default,
Frozen,
@@ -636,6 +637,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
_coord_names: Set[Hashable]
_dims: Dict[Hashable, int]
_encoding: Optional[Dict[Hashable, Any]]
+ _close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_variables: Dict[Hashable, Variable]
@@ -645,7 +647,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
"_coord_names",
"_dims",
"_encoding",
- "_file_obj",
+ "_close",
"_indexes",
"_variables",
"__weakref__",
@@ -687,7 +689,7 @@ def __init__(
)
self._attrs = dict(attrs) if attrs is not None else None
- self._file_obj = None
+ self._close = None
self._encoding = None
self._variables = variables
self._coord_names = coord_names
@@ -703,7 +705,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
if decoder:
variables, attributes = decoder(variables, attributes)
obj = cls(variables, attrs=attributes)
- obj._file_obj = store
+ obj.set_close(store.close)
return obj
@property
@@ -876,7 +878,7 @@ def __dask_postcompute__(self):
self._attrs,
self._indexes,
self._encoding,
- self._file_obj,
+ self._close,
)
return self._dask_postcompute, args
@@ -896,7 +898,7 @@ def __dask_postpersist__(self):
self._attrs,
self._indexes,
self._encoding,
- self._file_obj,
+ self._close,
)
return self._dask_postpersist, args
@@ -1007,7 +1009,7 @@ def _construct_direct(
attrs=None,
indexes=None,
encoding=None,
- file_obj=None,
+ close=None,
):
"""Shortcut around __init__ for internal use when we want to skip
costly validation
@@ -1020,7 +1022,7 @@ def _construct_direct(
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
- obj._file_obj = file_obj
+ obj._close = close
obj._encoding = encoding
return obj
@@ -2122,7 +2124,7 @@ def isel(
attrs=self._attrs,
indexes=indexes,
encoding=self._encoding,
- file_obj=self._file_obj,
+ close=self._close,
)
def _isel_fancy(
@@ -3154,7 +3156,9 @@ def rename_vars(
)
return self._replace(variables, coord_names, dims=dims, indexes=indexes)
- def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
+ def swap_dims(
+ self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
+ ) -> "Dataset":
"""Returns a new object with swapped dimensions.
Parameters
@@ -3163,6 +3167,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
Dictionary whose keys are current dimension names and whose values
are new names.
+ **dim_kwargs : {existing_dim: new_dim, ...}, optional
+ The keyword arguments form of ``dims_dict``.
+ One of dims_dict or dims_kwargs must be provided.
+
Returns
-------
swapped : Dataset
@@ -3213,6 +3221,8 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
"""
# TODO: deprecate this method in favor of a (less confusing)
# rename_dims() method that only renames dimensions.
+
+ dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
for k, v in dims_dict.items():
if k not in self.dims:
raise ValueError(
@@ -3706,7 +3716,40 @@ def ensure_stackable(val):
return data_array
- def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
+ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
+ index = self.get_index(dim)
+ index = remove_unused_levels_categories(index)
+
+ variables: Dict[Hashable, Variable] = {}
+ indexes = {k: v for k, v in self.indexes.items() if k != dim}
+
+ for name, var in self.variables.items():
+ if name != dim:
+ if dim in var.dims:
+ if isinstance(fill_value, Mapping):
+ fill_value_ = fill_value[name]
+ else:
+ fill_value_ = fill_value
+
+ variables[name] = var._unstack_once(
+ index=index, dim=dim, fill_value=fill_value_
+ )
+ else:
+ variables[name] = var
+
+ for name, lev in zip(index.names, index.levels):
+ variables[name] = IndexVariable(name, lev)
+ indexes[name] = lev
+
+ coord_names = set(self._coord_names) - {dim} | set(index.names)
+
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
+ def _unstack_full_reindex(
+ self, dim: Hashable, fill_value, sparse: bool
+ ) -> "Dataset":
index = self.get_index(dim)
index = remove_unused_levels_categories(index)
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
@@ -3803,7 +3846,38 @@ def unstack(
result = self.copy(deep=False)
for dim in dims:
- result = result._unstack_once(dim, fill_value, sparse)
+
+ if (
+ # Dask arrays don't support assignment by index, which the fast unstack
+ # function requires.
+ # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125
+ any(is_duck_dask_array(v.data) for v in self.variables.values())
+ # Sparse doesn't currently support (though we could special-case
+ # it)
+ # https://github.com/pydata/sparse/issues/422
+ or any(
+ isinstance(v.data, sparse_array_type)
+ for v in self.variables.values()
+ )
+ or sparse
+ # numpy full_like only added `shape` in 1.17
+ or LooseVersion(np.__version__) < LooseVersion("1.17")
+ # Until https://github.com/pydata/xarray/pull/4751 is resolved,
+ # we check explicitly whether it's a numpy array. Once that is
+ # resolved, explicitly exclude pint arrays.
+ # # pint doesn't implement `np.full_like` in a way that's
+ # # currently compatible.
+ # # https://github.com/pydata/xarray/pull/4746#issuecomment-753425173
+ # # or any(
+ # # isinstance(v.data, pint_array_type) for v in self.variables.values()
+ # # )
+ or any(
+ not isinstance(v.data, np.ndarray) for v in self.variables.values()
+ )
+ ):
+ result = result._unstack_full_reindex(dim, fill_value, sparse)
+ else:
+ result = result._unstack_once(dim, fill_value)
return result
def update(self, other: "CoercibleMapping") -> "Dataset":
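As a rough illustration (not from the patch) of what the fast path handles, a numpy-backed stacked array with one missing combination unstacks by filling the gap:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6).reshape(2, 3),
    dims=("a", "b"),
    coords={"a": [0, 1], "b": ["x", "y", "z"]},
).stack(flat=("a", "b"))[:-1]  # drop the (a=1, b='z') element

# numpy-backed data takes the assignment-based _unstack_once path;
# dask/sparse/other duck arrays fall back to _unstack_full_reindex.
unstacked = da.unstack("flat", fill_value=-1)
print(int(unstacked.sel(a=1, b="z")))  # -1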
@@ -4020,9 +4094,17 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
Examples
--------
- >>> data = np.random.randn(2, 3)
+ >>> data = np.arange(6).reshape(2, 3)
>>> labels = ["a", "b", "c"]
>>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels})
+        >>> ds
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 3)
+        Coordinates:
+          * y        (y) <U1 'a' 'b' 'c'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 0 1 2 3 4 5
         >>> ds.drop_sel(y=["a", "c"])
         <xarray.Dataset>
         Dimensions:  (x: 2, y: 1)
@@ -4030,7 +4112,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
           * y        (y) <U1 'b'
         Dimensions without coordinates: x
         Data variables:
+            A        (x, y) int64 1 4
         >>> ds.drop_sel(y="b")
         <xarray.Dataset>
         Dimensions:  (x: 2, y: 2)
@@ -4038,12 +4120,12 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
           * y        (y) <U1 'a' 'c'
         Dimensions without coordinates: x
         Data variables:
+            A        (x, y) int64 0 2 3 5
         """
+
+    def drop_isel(self, indexers=None, **indexers_kwargs):
+        """Drop index positions from this Dataset.
+
+        Parameters
+        ----------
+        indexers : mapping of hashable to Any
+            Index locations to drop.
+        **indexers_kwargs : {dim: position, ...}, optional
+            The keyword arguments form of ``indexers``.
+
+        Returns
+        -------
+        dropped : Dataset
+
+        Examples
+        --------
+        >>> data = np.arange(6).reshape(2, 3)
+        >>> labels = ["a", "b", "c"]
+        >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels})
+        >>> ds
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 3)
+        Coordinates:
+          * y        (y) <U1 'a' 'b' 'c'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 0 1 2 3 4 5
+        >>> ds.drop_isel(y=[0, 2])
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 1)
+        Coordinates:
+          * y        (y) <U1 'b'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 1 4
+        >>> ds.drop_isel(y=1)
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 2)
+        Coordinates:
+          * y        (y) <U1 'a' 'c'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 0 2 3 5
+        """
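A quick sanity check of the positional/label symmetry (a sketch against the patched xarray, mirroring the tests below):

import numpy as np
import xarray as xr

ds = xr.Dataset({"A": (["x", "y"], np.arange(6).reshape(2, 3)), "y": ["a", "b", "c"]})
# dropping positions 0 and 2 is the same as dropping labels "a" and "c"
assert ds.drop_isel(y=[0, 2]).identical(ds.drop_sel(y=["a", "c"]))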
@@ -4900,6 +5047,10 @@ def _set_numpy_data_from_dataframe(
self[name] = (dims, values)
return
+        # NB: similar, more general logic now exists in
+        # Variable._unstack_once; we could consider combining them at some
+        # point.
+
shape = tuple(lev.size for lev in idx.levels)
indexer = tuple(idx.codes)
@@ -5812,8 +5963,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None):
variables[k] = v
return self._replace(variables)
- def integrate(self, coord, datetime_unit=None):
- """ integrate the array with the trapezoidal rule.
+ def integrate(
+ self, coord: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None
+ ) -> "Dataset":
+ """Integrate along the given coordinate using the trapezoidal rule.
.. note::
This feature is limited to simple cartesian geometry, i.e. coord
@@ -5821,11 +5974,11 @@ def integrate(self, coord, datetime_unit=None):
Parameters
----------
- coord: str, or sequence of str
+ coord: hashable, or a sequence of hashable
Coordinate(s) used for the integration.
- datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", \
- "ps", "fs", "as"}, optional
- Can be specify the unit if datetime coordinate is used.
+ datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as'}, optional
+ Specify the unit if datetime coordinate is used.
Returns
-------
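A small usage sketch (not part of the patch) of the reworked signature; the numbers are just a hand-checked trapezoid:

import xarray as xr

ds = xr.Dataset({"v": ("x", [0.0, 1.0, 2.0])}, coords={"x": [0.0, 0.5, 1.0]})

# trapezoidal rule along the x coordinate:
# 0.5 * (0 + 1) * 0.5 + 0.5 * (1 + 2) * 0.5 = 1.0
print(float(ds.integrate("x")["v"]))  # 1.0
# DataArray.integrate(dim=...) now emits a FutureWarning in favour of coord=
# (see the test change further down).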
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 282620e3569..0c1be1cc175 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -300,11 +300,18 @@ def _summarize_coord_multiindex(coord, col_width, marker):
def _summarize_coord_levels(coord, col_width, marker="-"):
+ if len(coord) > 100 and col_width < len(coord):
+ n_values = col_width
+ indices = list(range(0, n_values)) + list(range(-n_values, 0))
+ subset = coord[indices]
+ else:
+ subset = coord
+
return "\n".join(
summarize_variable(
- lname, coord.get_level_variable(lname), col_width, marker=marker
+ lname, subset.get_level_variable(lname), col_width, marker=marker
)
- for lname in coord.level_names
+ for lname in subset.level_names
)
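The head/tail selection above, replayed on a plain numpy array (illustration only; `col_width` here is just a number):

import numpy as np

coord = np.arange(1_000)
col_width = 10
indices = list(range(0, col_width)) + list(range(-col_width, 0))
print(coord[indices])  # only the first and last `col_width` values get formatted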
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 797de65bbcf..64c1895da59 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -10,6 +10,7 @@
Any,
Dict,
Hashable,
+ List,
Mapping,
Optional,
Sequence,
@@ -1488,7 +1489,7 @@ def set_dims(self, dims, shape=None):
)
return expanded_var.transpose(*dims)
- def _stack_once(self, dims, new_dim):
+ def _stack_once(self, dims: List[Hashable], new_dim: Hashable):
if not set(dims) <= set(self.dims):
raise ValueError("invalid existing dimensions: %s" % dims)
@@ -1544,7 +1545,15 @@ def stack(self, dimensions=None, **dimensions_kwargs):
result = result._stack_once(dims, new_dim)
return result
- def _unstack_once(self, dims, old_dim):
+ def _unstack_once_full(
+ self, dims: Mapping[Hashable, int], old_dim: Hashable
+ ) -> "Variable":
+ """
+ Unstacks the variable without needing an index.
+
+ Unlike `_unstack_once`, this function requires the existing dimension to
+ contain the full product of the new dimensions.
+ """
new_dim_names = tuple(dims.keys())
new_dim_sizes = tuple(dims.values())
@@ -1573,6 +1582,53 @@ def _unstack_once(self, dims, old_dim):
return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)
+ def _unstack_once(
+ self,
+ index: pd.MultiIndex,
+ dim: Hashable,
+ fill_value=dtypes.NA,
+ ) -> "Variable":
+ """
+ Unstacks this variable given an index to unstack and the name of the
+ dimension to which the index refers.
+ """
+
+ reordered = self.transpose(..., dim)
+
+ new_dim_sizes = [lev.size for lev in index.levels]
+ new_dim_names = index.names
+ indexer = index.codes
+
+ # Potentially we could replace `len(other_dims)` with just `-1`
+ other_dims = [d for d in self.dims if d != dim]
+ new_shape = list(reordered.shape[: len(other_dims)]) + new_dim_sizes
+ new_dims = reordered.dims[: len(other_dims)] + new_dim_names
+
+ if fill_value is dtypes.NA:
+ is_missing_values = np.prod(new_shape) > np.prod(self.shape)
+ if is_missing_values:
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+ else:
+ dtype = self.dtype
+ fill_value = dtypes.get_fill_value(dtype)
+ else:
+ dtype = self.dtype
+
+ # Currently fails on sparse due to https://github.com/pydata/sparse/issues/422
+ data = np.full_like(
+ self.data,
+ fill_value=fill_value,
+ shape=new_shape,
+ dtype=dtype,
+ )
+
+ # Indexer is a list of lists of locations. Each list is the locations
+ # on the new dimension. This is robust to the data being sparse; in that
+ # case the destinations will be NaN / zero.
+ data[(..., *indexer)] = reordered
+
+ return self._replace(dims=new_dims, data=data)
+
def unstack(self, dimensions=None, **dimensions_kwargs):
"""
Unstack an existing dimension into multiple new dimensions.
@@ -1580,6 +1636,10 @@ def unstack(self, dimensions=None, **dimensions_kwargs):
New dimensions will be added at the end, and the order of the data
along each new dimension will be in contiguous (C) order.
+ Note that unlike ``DataArray.unstack`` and ``Dataset.unstack``, this
+ method requires the existing dimension to contain the full product of
+ the new dimensions.
+
Parameters
----------
dimensions : mapping of hashable to mapping of hashable to int
@@ -1598,11 +1658,13 @@ def unstack(self, dimensions=None, **dimensions_kwargs):
See also
--------
Variable.stack
+ DataArray.unstack
+ Dataset.unstack
"""
dimensions = either_dict_or_kwargs(dimensions, dimensions_kwargs, "unstack")
result = self
for old_dim, dims in dimensions.items():
- result = result._unstack_once(dims, old_dim)
+ result = result._unstack_once_full(dims, old_dim)
return result
def fillna(self, value):
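A standalone numpy/pandas sketch (not xarray API) of the scatter-style assignment that `Variable._unstack_once` performs:

import numpy as np
import pandas as pd

# stacked data with one (a, b) combination missing
index = pd.MultiIndex.from_tuples([(0, "x"), (0, "y"), (1, "x")], names=("a", "b"))
stacked = np.array([10.0, 11.0, 12.0])

new_shape = [lev.size for lev in index.levels]  # [2, 2]
out = np.full(new_shape, np.nan)                # pre-filled destination
out[tuple(index.codes)] = stacked               # the codes act as the indexer
print(out)                                      # [[10. 11.] [12. nan]]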
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index dbd4e1ad103..449a7200ee7 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -1,13 +1,16 @@
-from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
+from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union
from . import duck_array_ops
from .computation import dot
-from .options import _get_keep_attrs
from .pycompat import is_duck_dask_array
if TYPE_CHECKING:
+ from .common import DataWithCoords # noqa: F401
from .dataarray import DataArray, Dataset
+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
+
_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
@@ -56,7 +59,7 @@
"""
-class Weighted:
+class Weighted(Generic[T_DataWithCoords]):
"""An object that implements weighted operations.
You should create a Weighted object by using the ``DataArray.weighted`` or
@@ -70,15 +73,7 @@ class Weighted:
__slots__ = ("obj", "weights")
- @overload
- def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
- ...
-
- @overload
- def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
- ...
-
- def __init__(self, obj, weights):
+ def __init__(self, obj: T_DataWithCoords, weights: "DataArray"):
"""
Create a Weighted object
@@ -121,8 +116,8 @@ def _weight_check(w):
else:
_weight_check(weights.data)
- self.obj = obj
- self.weights = weights
+ self.obj: T_DataWithCoords = obj
+ self.weights: "DataArray" = weights
@staticmethod
def _reduce(
@@ -146,7 +141,6 @@ def _reduce(
# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
- # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
return dot(da, weights, dims=dim)
def _sum_of_weights(
@@ -203,7 +197,7 @@ def sum_of_weights(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
keep_attrs: Optional[bool] = None,
- ) -> Union["DataArray", "Dataset"]:
+ ) -> T_DataWithCoords:
return self._implementation(
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
@@ -214,7 +208,7 @@ def sum(
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
- ) -> Union["DataArray", "Dataset"]:
+ ) -> T_DataWithCoords:
return self._implementation(
self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -225,7 +219,7 @@ def mean(
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
- ) -> Union["DataArray", "Dataset"]:
+ ) -> T_DataWithCoords:
return self._implementation(
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -239,22 +233,15 @@ def __repr__(self):
return f"{klass} with weights along dimensions: {weight_dims}"
-class DataArrayWeighted(Weighted):
- def _implementation(self, func, dim, **kwargs):
-
- keep_attrs = kwargs.pop("keep_attrs")
- if keep_attrs is None:
- keep_attrs = _get_keep_attrs(default=False)
-
- weighted = func(self.obj, dim=dim, **kwargs)
-
- if keep_attrs:
- weighted.attrs = self.obj.attrs
+class DataArrayWeighted(Weighted["DataArray"]):
+ def _implementation(self, func, dim, **kwargs) -> "DataArray":
- return weighted
+ dataset = self.obj._to_temp_dataset()
+ dataset = dataset.map(func, dim=dim, **kwargs)
+ return self.obj._from_temp_dataset(dataset)
-class DatasetWeighted(Weighted):
+class DatasetWeighted(Weighted["Dataset"]):
def _implementation(self, func, dim, **kwargs) -> "Dataset":
return self.obj.map(func, dim=dim, **kwargs)
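The typing pattern in a nutshell (illustrative names, not the xarray classes): a generic wrapper lets mypy carry the wrapped type through to each method's return value:

from typing import Generic, TypeVar

T = TypeVar("T")

class Wrapper(Generic[T]):
    def __init__(self, obj: T) -> None:
        self.obj: T = obj

    def mean(self) -> T:
        # instances parametrized as Wrapper[DataArray] or Wrapper[Dataset]
        # get the matching return type instead of a Union
        return self.obj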
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 3ead427e22e..fc84687511e 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -1639,6 +1639,16 @@ def test_swap_dims(self):
expected.indexes[dim_name], actual.indexes[dim_name]
)
+ # as kwargs
+ array = DataArray(np.random.randn(3), {"x": list("abc")}, "x")
+ expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y")
+ actual = array.swap_dims(x="y")
+ assert_identical(expected, actual)
+ for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
+ pd.testing.assert_index_equal(
+ expected.indexes[dim_name], actual.indexes[dim_name]
+ )
+
# multiindex case
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x")
@@ -2327,6 +2337,12 @@ def test_drop_index_labels(self):
with pytest.warns(DeprecationWarning):
arr.drop([0, 1, 3], dim="y", errors="ignore")
+ def test_drop_index_positions(self):
+ arr = DataArray(np.random.randn(2, 3), dims=["x", "y"])
+ actual = arr.drop_isel(y=[0, 1])
+ expected = arr[:, 2:]
+ assert_identical(actual, expected)
+
def test_dropna(self):
x = np.random.randn(4, 4)
x[::2, 0] = np.nan
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index bd1938455b1..db47faa8d2b 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -2371,8 +2371,12 @@ def test_drop_index_labels(self):
data.drop(DataArray(["a", "b", "c"]), dim="x", errors="ignore")
assert_identical(expected, actual)
- with raises_regex(ValueError, "does not have coordinate labels"):
- data.drop_sel(y=1)
+ actual = data.drop_sel(y=[1])
+ expected = data.isel(y=[0, 2])
+ assert_identical(expected, actual)
+
+ with raises_regex(KeyError, "not found in axis"):
+ data.drop_sel(x=0)
def test_drop_labels_by_keyword(self):
data = Dataset(
@@ -2410,6 +2414,34 @@ def test_drop_labels_by_keyword(self):
with pytest.raises(ValueError):
data.drop(dim="x", x="a")
+ def test_drop_labels_by_position(self):
+ data = Dataset(
+ {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)}
+ )
+ # Basic functionality.
+ assert len(data.coords["x"]) == 2
+
+ actual = data.drop_isel(x=0)
+ expected = data.drop_sel(x="a")
+ assert_identical(expected, actual)
+
+ actual = data.drop_isel(x=[0])
+ expected = data.drop_sel(x=["a"])
+ assert_identical(expected, actual)
+
+ actual = data.drop_isel(x=[0, 1])
+ expected = data.drop_sel(x=["a", "b"])
+ assert_identical(expected, actual)
+ assert actual.coords["x"].size == 0
+
+ actual = data.drop_isel(x=[0, 1], y=range(0, 6, 2))
+ expected = data.drop_sel(x=["a", "b"], y=range(0, 6, 2))
+ assert_identical(expected, actual)
+ assert actual.coords["x"].size == 0
+
+ with pytest.raises(KeyError):
+ data.drop_isel(z=1)
+
def test_drop_dims(self):
data = xr.Dataset(
{
@@ -2716,6 +2748,13 @@ def test_swap_dims(self):
actual = original.swap_dims({"x": "u"})
assert_identical(expected, actual)
+ # as kwargs
+ expected = Dataset(
+ {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])}
+ )
+ actual = original.swap_dims(x="u")
+ assert_identical(expected, actual)
+
# handle multiindex case
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42})
@@ -6564,6 +6603,9 @@ def test_integrate(dask):
with pytest.raises(ValueError):
da.integrate("x2d")
+ with pytest.warns(FutureWarning):
+ da.integrate(dim="x")
+
@pytest.mark.parametrize("dask", [True, False])
@pytest.mark.parametrize("which_datetime", ["np", "cftime"])
diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
index 110ef47209f..64a1c563dba 100644
--- a/xarray/tests/test_plugins.py
+++ b/xarray/tests/test_plugins.py
@@ -6,19 +6,24 @@
from xarray.backends import common, plugins
-def dummy_open_dataset_args(filename_or_obj, *args):
- pass
+class DummyBackendEntrypointArgs(common.BackendEntrypoint):
+ def open_dataset(filename_or_obj, *args):
+ pass
-def dummy_open_dataset_kwargs(filename_or_obj, **kwargs):
- pass
+class DummyBackendEntrypointKwargs(common.BackendEntrypoint):
+ def open_dataset(filename_or_obj, **kwargs):
+ pass
-def dummy_open_dataset(filename_or_obj, *, decoder):
- pass
+class DummyBackendEntrypoint1(common.BackendEntrypoint):
+ def open_dataset(self, filename_or_obj, *, decoder):
+ pass
-dummy_cfgrib = common.BackendEntrypoint(dummy_open_dataset)
+class DummyBackendEntrypoint2(common.BackendEntrypoint):
+ def open_dataset(self, filename_or_obj, *, decoder):
+ pass
@pytest.fixture
@@ -65,46 +70,48 @@ def test_create_engines_dict():
def test_set_missing_parameters():
- backend_1 = common.BackendEntrypoint(dummy_open_dataset)
- backend_2 = common.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",))
+ backend_1 = DummyBackendEntrypoint1
+ backend_2 = DummyBackendEntrypoint2
+ backend_2.open_dataset_parameters = ("filename_or_obj",)
engines = {"engine_1": backend_1, "engine_2": backend_2}
plugins.set_missing_parameters(engines)
assert len(engines) == 2
- engine_1 = engines["engine_1"]
- assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder")
- engine_2 = engines["engine_2"]
- assert engine_2.open_dataset_parameters == ("filename_or_obj",)
+ assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder")
+ assert backend_2.open_dataset_parameters == ("filename_or_obj",)
+
+ backend = DummyBackendEntrypointKwargs()
+ backend.open_dataset_parameters = ("filename_or_obj", "decoder")
+ plugins.set_missing_parameters({"engine": backend})
+ assert backend.open_dataset_parameters == ("filename_or_obj", "decoder")
+
+ backend = DummyBackendEntrypointArgs()
+ backend.open_dataset_parameters = ("filename_or_obj", "decoder")
+ plugins.set_missing_parameters({"engine": backend})
+ assert backend.open_dataset_parameters == ("filename_or_obj", "decoder")
def test_set_missing_parameters_raise_error():
- backend = common.BackendEntrypoint(dummy_open_dataset_args)
+ backend = DummyBackendEntrypointKwargs()
with pytest.raises(TypeError):
plugins.set_missing_parameters({"engine": backend})
- backend = common.BackendEntrypoint(
- dummy_open_dataset_args, ("filename_or_obj", "decoder")
- )
- plugins.set_missing_parameters({"engine": backend})
-
- backend = common.BackendEntrypoint(dummy_open_dataset_kwargs)
+ backend = DummyBackendEntrypointArgs()
with pytest.raises(TypeError):
plugins.set_missing_parameters({"engine": backend})
- backend = plugins.BackendEntrypoint(
- dummy_open_dataset_kwargs, ("filename_or_obj", "decoder")
- )
- plugins.set_missing_parameters({"engine": backend})
-
-@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=dummy_cfgrib))
+@mock.patch(
+ "pkg_resources.EntryPoint.load",
+ mock.MagicMock(return_value=DummyBackendEntrypoint1),
+)
def test_build_engines():
- dummy_cfgrib_pkg_entrypoint = pkg_resources.EntryPoint.parse(
+ dummy_pkg_entrypoint = pkg_resources.EntryPoint.parse(
"cfgrib = xarray.tests.test_plugins:backend_1"
)
- backend_entrypoints = plugins.build_engines([dummy_cfgrib_pkg_entrypoint])
- assert backend_entrypoints["cfgrib"] is dummy_cfgrib
+ backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint])
+ assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1)
assert backend_entrypoints["cfgrib"].open_dataset_parameters == (
"filename_or_obj",
"decoder",
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index bb3127e90b5..76dd830de23 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -3681,7 +3681,7 @@ def test_stacking_reordering(self, func, dtype):
(
method("diff", dim="x"),
method("differentiate", coord="x"),
- method("integrate", dim="x"),
+ method("integrate", coord="x"),
method("quantile", q=[0.25, 0.75]),
method("reduce", func=np.sum, dim="x"),
pytest.param(lambda x: x.dot(x), id="method_dot"),