diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml
index 59e5889f5ec..cc1a2e12be3 100644
--- a/.github/ISSUE_TEMPLATE/bugreport.yml
+++ b/.github/ISSUE_TEMPLATE/bugreport.yml
@@ -44,6 +44,7 @@ body:
- label: Complete example — the example is self-contained, including all data and the text of any traceback.
- label: Verifiable example — the example copy & pastes into an IPython prompt or [Binder notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/blank_template.ipynb), returning the result.
- label: New issue — a search of GitHub Issues suggests this is not a duplicate.
+ - label: Recent environment — the issue occurs with the latest version of xarray and its dependencies.
- type: textarea
id: log-output
diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml
index ec1c192fd35..dc9cc2cd2fe 100644
--- a/.github/workflows/ci-additional.yaml
+++ b/.github/workflows/ci-additional.yaml
@@ -82,8 +82,6 @@ jobs:
name: Mypy
runs-on: "ubuntu-latest"
needs: detect-ci-trigger
- # temporarily skipping due to https://github.com/pydata/xarray/issues/6551
- if: needs.detect-ci-trigger.outputs.triggered == 'false'
defaults:
run:
shell: bash -l {0}
@@ -190,6 +188,126 @@ jobs:
+ pyright:
+ name: Pyright
+ runs-on: "ubuntu-latest"
+ needs: detect-ci-trigger
+ if: |
+ always()
+ && (
+ contains( github.event.pull_request.labels.*.name, 'run-pyright')
+ )
+ defaults:
+ run:
+ shell: bash -l {0}
+ env:
+ CONDA_ENV_FILE: ci/requirements/environment.yml
+ PYTHON_VERSION: "3.10"
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch all history for all branches and tags.
+
+ - name: set environment variables
+ run: |
+ echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
+ - name: Setup micromamba
+ uses: mamba-org/setup-micromamba@v1
+ with:
+ environment-file: ${{env.CONDA_ENV_FILE}}
+ environment-name: xarray-tests
+ create-args: >-
+ python=${{env.PYTHON_VERSION}}
+ conda
+ cache-environment: true
+ cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}"
+ - name: Install xarray
+ run: |
+ python -m pip install --no-deps -e .
+ - name: Version info
+ run: |
+ conda info -a
+ conda list
+ python xarray/util/print_versions.py
+ - name: Install pyright
+ run: |
+ python -m pip install pyright --force-reinstall
+
+ - name: Run pyright
+ run: |
+ python -m pyright xarray/
+
+ - name: Upload pyright coverage to Codecov
+ uses: codecov/codecov-action@v3.1.4
+ with:
+ file: pyright_report/cobertura.xml
+ flags: pyright
+ env_vars: PYTHON_VERSION
+ name: codecov-umbrella
+ fail_ci_if_error: false
+
+ pyright39:
+ name: Pyright 3.9
+ runs-on: "ubuntu-latest"
+ needs: detect-ci-trigger
+ if: |
+ always()
+ && (
+ contains( github.event.pull_request.labels.*.name, 'run-pyright')
+ )
+ defaults:
+ run:
+ shell: bash -l {0}
+ env:
+ CONDA_ENV_FILE: ci/requirements/environment.yml
+ PYTHON_VERSION: "3.9"
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Fetch all history for all branches and tags.
+
+ - name: set environment variables
+ run: |
+ echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
+ - name: Setup micromamba
+ uses: mamba-org/setup-micromamba@v1
+ with:
+ environment-file: ${{env.CONDA_ENV_FILE}}
+ environment-name: xarray-tests
+ create-args: >-
+ python=${{env.PYTHON_VERSION}}
+ conda
+ cache-environment: true
+ cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}"
+ - name: Install xarray
+ run: |
+ python -m pip install --no-deps -e .
+ - name: Version info
+ run: |
+ conda info -a
+ conda list
+ python xarray/util/print_versions.py
+ - name: Install pyright
+ run: |
+ python -m pip install pyright --force-reinstall
+
+ - name: Run pyright
+ run: |
+ python -m pyright xarray/
+
+ - name: Upload pyright coverage to Codecov
+ uses: codecov/codecov-action@v3.1.4
+ with:
+ file: pyright_report/cobertura.xml
+ flags: pyright39
+ env_vars: PYTHON_VERSION
+ name: codecov-umbrella
+ fail_ci_if_error: false
+
+
+
min-version-policy:
name: Minimum Version Policy
runs-on: "ubuntu-latest"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c2586a12aa2..5626f450ec0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,13 +18,13 @@ repos:
files: ^xarray/
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
- rev: 'v0.0.287'
+ rev: 'v0.0.292'
hooks:
- id: ruff
args: ["--fix"]
# https://github.com/python/black#version-control-integration
- repo: https://github.com/psf/black
- rev: 23.7.0
+ rev: 23.9.1
hooks:
- id: black-jupyter
- repo: https://github.com/keewis/blackdoc
@@ -32,7 +32,7 @@ repos:
hooks:
- id: blackdoc
exclude: "generate_aggregations.py"
- additional_dependencies: ["black==23.7.0"]
+ additional_dependencies: ["black==23.9.1"]
- id: blackdoc-autoupdate-black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py
index 1d3713f19bf..579f4f00fbc 100644
--- a/asv_bench/benchmarks/rolling.py
+++ b/asv_bench/benchmarks/rolling.py
@@ -5,10 +5,10 @@
from . import parameterized, randn, requires_dask
-nx = 300
+nx = 3000
long_nx = 30000
ny = 200
-nt = 100
+nt = 1000
window = 20
randn_xy = randn((nx, ny), frac_nan=0.1)
@@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
roll = self.ds.var3.rolling(t=100)
getattr(roll, func)()
+ @parameterized(["stride"], ([None, 5, 50]))
+ def peakmem_1drolling_construct(self, stride):
+ self.ds.var2.rolling(t=100).construct("w", stride=stride)
+ self.ds.var3.rolling(t=100).construct("w", stride=stride)
+
class DatasetRollingMemory(RollingMemory):
@parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False]))
@@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
with xr.set_options(use_bottleneck=use_bottleneck):
roll = self.ds.rolling(t=100)
getattr(roll, func)()
+
+ @parameterized(["stride"], ([None, 5, 50]))
+ def peakmem_1drolling_construct(self, stride):
+ self.ds.rolling(t=100).construct("w", stride=stride)
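The new ``peakmem_1drolling_construct`` benchmarks above measure peak memory of ``rolling(...).construct(...)`` with different strides. A minimal usage sketch of the pattern being benchmarked, with a made-up dataset::

    import numpy as np
    import xarray as xr

    # Small stand-in for the benchmark's dataset; names are illustrative only.
    ds = xr.Dataset({"var3": ("t", np.random.randn(1000))})

    # Materialize rolling windows along a new "w" dimension; stride=5 keeps
    # every 5th window, so the result is ~5x smaller along "t".
    windows = ds.var3.rolling(t=100).construct("w", stride=5)
    print(windows.sizes)  # t: 200, w: 100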
diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index d97c4010528..552d11a06dc 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -265,7 +265,7 @@
Variable.dims
Variable.dtype
Variable.encoding
- Variable.reset_encoding
+ Variable.drop_encoding
Variable.imag
Variable.nbytes
Variable.ndim
diff --git a/doc/api.rst b/doc/api.rst
index 0cf07f91df8..96b4864804f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -110,9 +110,9 @@ Dataset contents
Dataset.drop_indexes
Dataset.drop_duplicates
Dataset.drop_dims
+ Dataset.drop_encoding
Dataset.set_coords
Dataset.reset_coords
- Dataset.reset_encoding
Dataset.convert_calendar
Dataset.interp_calendar
Dataset.get_index
@@ -303,8 +303,8 @@ DataArray contents
DataArray.drop_vars
DataArray.drop_indexes
DataArray.drop_duplicates
+ DataArray.drop_encoding
DataArray.reset_coords
- DataArray.reset_encoding
DataArray.copy
DataArray.convert_calendar
DataArray.interp_calendar
diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst
index e6e970c6239..fc5ae963a1d 100644
--- a/doc/ecosystem.rst
+++ b/doc/ecosystem.rst
@@ -41,6 +41,7 @@ Geosciences
harmonic wind analysis in Python.
- `wradlib `_: An Open Source Library for Weather Radar Data Processing.
- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model.
+- `xarray-regrid `_: xarray extension for regridding rectilinear data.
- `xarray-simlab `_: xarray extension for computer model simulations.
- `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.)
- `xarray-topo `_: xarray extension for topographic analysis and modelling.
diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst
index 2ffc25b2009..ffded682035 100644
--- a/doc/user-guide/io.rst
+++ b/doc/user-guide/io.rst
@@ -260,12 +260,12 @@ Note that all operations that manipulate variables other than indexing
will remove encoding information.
In some cases it is useful to intentionally reset a dataset's original encoding values.
-This can be done with either the :py:meth:`Dataset.reset_encoding` or
-:py:meth:`DataArray.reset_encoding` methods.
+This can be done with either the :py:meth:`Dataset.drop_encoding` or
+:py:meth:`DataArray.drop_encoding` methods.
.. ipython:: python
- ds_no_encoding = ds_disk.reset_encoding()
+ ds_no_encoding = ds_disk.drop_encoding()
ds_no_encoding.encoding
.. _combining multiple files:
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index ee74411a004..40c50e158ad 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -14,25 +14,99 @@ What's New
np.random.seed(123456)
-.. _whats-new.2023.08.1:
+.. _whats-new.2023.09.1:
-v2023.08.1 (unreleased)
+v2023.09.1 (unreleased)
-----------------------
New Features
~~~~~~~~~~~~
+- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for
+ the ``other`` parameter, passing the object as the only argument. Previously,
+ this was only valid for the ``cond`` parameter. (:issue:`8255`)
+ By `Maximilian Roos `_.
+- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
+ the ``variables`` parameter, passing the object as the only argument.
+ By `Maximilian Roos `_.
+- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the
+ core dim has exactly one chunk. (:pull:`8284`).
+ By `Maximilian Roos `_.
+
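A minimal sketch of the two new callable forms listed above, assuming a small in-memory ``DataArray`` and ``Dataset`` (names are hypothetical)::

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(6.0), dims="x")
    # ``other`` may now be a callable that receives the object being masked.
    filled = da.where(da > 2, lambda d: d.mean())

    ds = xr.Dataset({"a": ("x", [3, 1, 2])})
    # ``sortby`` likewise accepts a callable returning the variable(s) to sort by.
    ds_sorted = ds.sortby(lambda d: d["a"])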
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and
+ :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_.
+- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword
+ arguments after the initial 7 positional arguments.
+ By `Maximilian Roos `_.
+
+
+Deprecations
+~~~~~~~~~~~~
+- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding`
+ to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for
+ consistency with other ``drop`` & ``reset`` methods — ``drop`` generally
+ removes something, while ``reset`` generally resets to some default or
+ standard value. (:pull:`8287`, :issue:`8259`)
+ By `Maximilian Roos `_.
+
+Bug fixes
+~~~~~~~~~
+- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning
+ when the operation was a no-op. (:issue:`8266`)
+ By `Simon Hansen `_.
+
+- Fix datetime encoding precision loss regression introduced in the previous
+ release for datetimes encoded with units requiring floating point values, and
+ a reference date not equal to the first value of the datetime array
+ (:issue:`8271`, :pull:`8272`). By `Spencer Clark
+ `_.
+
+
+Documentation
+~~~~~~~~~~~~~
+
+- Added xarray-regrid to the list of xarray-related projects (:pull:`8272`).
+ By `Bart Schilperoort `_.
+
+
+Internal Changes
+~~~~~~~~~~~~~~~~
+
+- More improvements to support the Python `array API standard `_
+ by using duck array ops in more places in the codebase. (:pull:`8267`)
+ By `Tom White `_.
+
+
+.. _whats-new.2023.09.0:
+
+v2023.09.0 (Sep 26, 2023)
+-------------------------
+
+This release continues work on the new :py:class:`xarray.Coordinates` object, allows providing `preferred_chunks` when
+reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions, and fixes several bugs.
+
+Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian,
+Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden,
+Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen,
+Tom Nicholas and Wiktor Kraśnicki.
+
+We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. Trevisan,
+Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki.
+
+New Features
+~~~~~~~~~~~~
+
- Added the :py:meth:`Coordinates.assign` method that can be used to combine
different collections of coordinates prior to assigning them all at once to a Dataset or
DataArray (:pull:`8102`).
By `Benoît Bovy `_.
- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`).
By `Martin Raspaud `_.
-- Improved static typing of reduction methods (:pull:`6746`).
- By `Richard Kleijn `_.
- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or
- dropping a :py:class:`Dataset`'s variables with missing core dimensions.
- (:pull:`8138`)
+ dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`).
By `Maximilian Roos `_.
Breaking changes
@@ -61,6 +135,8 @@ Deprecations
Bug fixes
~~~~~~~~~
+- Improved static typing of reduction methods (:pull:`6746`).
+ By `Richard Kleijn `_.
- Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`).
By `Mattia Almansi `_.
- Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes
@@ -71,26 +147,39 @@ Bug fixes
:pull:`8104`).
By `Benoît Bovy `_.
- Fix bug where :py:class:`DataArray` instances on the right-hand side
- of :py:meth:`DataArray.__setitem__` lose dimension names.
- (:issue:`7030`, :pull:`8067`) By `Darsh Ranjan `_.
+ of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`).
+ By `Darsh Ranjan `_.
- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and
- special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()`
+ special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar`
(:issue:`7928`, :pull:`8084`).
By `Kai Mühlbauer `_.
+- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes.
+ (:issue:`7021`, :pull:`7578`).
+ By `Amrest Chinkamol `_ and `Michael Niklas `_.
- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments
(:issue:`7552`, :pull:`8174`).
By `Wiktor Kraśnicki `_.
-- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying
- issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`,
+- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`,
:issue:`1064`, :pull:`7827`).
By `Kai Mühlbauer `_.
+- Fixed a bug where an inaccurate ``coordinates`` attribute silently prevented a variable from being decoded (:issue:`1809`, :pull:`8195`).
+  By `Kai Mühlbauer `_.
- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords
- (:issue:`6528`, :pull:`8114`)
+ (:issue:`6528`, :pull:`8114`).
By `Maximilian Roos `_.
+- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`).
+ By `Kai Mühlbauer `_.
+- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed.
+ Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`).
+ By `Michael Niklas `_.
+- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`).
+ By `Pieter Eendebak `_.
Documentation
~~~~~~~~~~~~~
+- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`).
+ By `Riulinchen `_.
Internal Changes
~~~~~~~~~~~~~~~~
@@ -103,6 +192,8 @@ Internal Changes
than ``.reduce``, as the start of a broader effort to move non-reducing
functions away from ``.reduce`` (:pull:`8114`).
By `Maximilian Roos `_.
+- Test a range of ``fill_value`` values in ``test_interpolate_pd_compat`` (:issue:`8146`, :pull:`8189`).
+ By `Kai Mühlbauer `_.
.. _whats-new.2023.08.0:
diff --git a/pyproject.toml b/pyproject.toml
index dd380937bd2..bdae33e4d0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,14 +72,16 @@ source = ["xarray"]
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
[tool.mypy]
+enable_error_code = "redundant-self"
exclude = 'xarray/util/generate_.*\.py'
files = "xarray"
show_error_codes = true
show_error_context = true
warn_redundant_casts = true
+warn_unused_configs = true
warn_unused_ignores = true
-# Most of the numerical computing stack doesn't have type annotations yet.
+# Much of the numerical computing stack doesn't have type annotations yet.
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
@@ -92,10 +94,10 @@ module = [
"cftime.*",
"cubed.*",
"cupy.*",
+ "dask.types.*",
"fsspec.*",
"h5netcdf.*",
"h5py.*",
- "importlib_metadata.*",
"iris.*",
"matplotlib.*",
"mpl_toolkits.*",
@@ -118,6 +120,108 @@ module = [
"numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped
]
+# Gradually we want to add more modules to this list, ratcheting up our total
+# coverage. Once a module is here, functions are checked by mypy regardless of
+# whether they have type annotations. It would be especially useful to have test
+# files listed here, because without them being checked, we don't have a great
+# way of testing our annotations.
+[[tool.mypy.overrides]]
+check_untyped_defs = true
+module = [
+ "xarray.core.accessor_dt",
+ "xarray.core.accessor_str",
+ "xarray.core.alignment",
+ "xarray.core.computation",
+ "xarray.core.rolling_exp",
+ "xarray.indexes.*",
+ "xarray.tests.*",
+]
+# This then excludes some modules from the above list. (So ideally we remove
+# modules from here over time...)
+[[tool.mypy.overrides]]
+check_untyped_defs = false
+module = [
+ "xarray.tests.test_coarsen",
+ "xarray.tests.test_coding_times",
+ "xarray.tests.test_combine",
+ "xarray.tests.test_computation",
+ "xarray.tests.test_concat",
+ "xarray.tests.test_coordinates",
+ "xarray.tests.test_dask",
+ "xarray.tests.test_dataarray",
+ "xarray.tests.test_duck_array_ops",
+ "xarray.tests.test_groupby",
+ "xarray.tests.test_indexing",
+ "xarray.tests.test_merge",
+ "xarray.tests.test_missing",
+ "xarray.tests.test_parallelcompat",
+ "xarray.tests.test_plot",
+ "xarray.tests.test_sparse",
+ "xarray.tests.test_ufuncs",
+ "xarray.tests.test_units",
+ "xarray.tests.test_utils",
+ "xarray.tests.test_variable",
+ "xarray.tests.test_weighted",
+]
+
+# Use strict = true once namedarray has become standalone. In the meantime,
+# don't forget to add all new files related to namedarray here:
+# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options
+[[tool.mypy.overrides]]
+# Start off with these
+warn_unused_ignores = true
+
+# Getting these passing should be easy
+strict_concatenate = true
+strict_equality = true
+
+# Strongly recommend enabling this one as soon as you can
+check_untyped_defs = true
+
+# These shouldn't be too much additional work, but may be tricky to
+# get passing if you use a lot of untyped libraries
+disallow_any_generics = true
+disallow_subclassing_any = true
+disallow_untyped_decorators = true
+
+# These next few are various gradations of forcing use of type annotations
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_defs = true
+
+# This one isn't too hard to get passing, but return on investment is lower
+no_implicit_reexport = true
+
+# This one can be tricky to get passing if you use a lot of untyped libraries
+warn_return_any = true
+
+module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"]
+
+[tool.pyright]
+# include = ["src"]
+# exclude = ["**/node_modules",
+# "**/__pycache__",
+# "src/experimental",
+# "src/typestubs"
+# ]
+# ignore = ["src/oldstuff"]
+defineConstant = {DEBUG = true}
+# stubPath = "src/stubs"
+# venv = "env367"
+
+reportMissingImports = true
+reportMissingTypeStubs = false
+
+# pythonVersion = "3.6"
+# pythonPlatform = "Linux"
+
+# executionEnvironments = [
+# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] },
+# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] },
+# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]},
+# { root = "src" }
+# ]
+
[tool.ruff]
builtins = ["ellipsis"]
exclude = [
@@ -146,15 +250,17 @@ select = [
known-first-party = ["xarray"]
[tool.pytest.ini_options]
-addopts = '--strict-markers'
+addopts = ["--strict-config", "--strict-markers"]
filterwarnings = [
"ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning",
]
+log_cli_level = "INFO"
markers = [
"flaky: flaky tests",
"network: tests requiring a network connection",
"slow: slow tests",
]
+minversion = "7"
python_files = "test_*.py"
testpaths = ["xarray/tests", "properties"]
diff --git a/xarray/__init__.py b/xarray/__init__.py
index b63b0d81470..1fd3b0c4336 100644
--- a/xarray/__init__.py
+++ b/xarray/__init__.py
@@ -1,3 +1,5 @@
+from importlib.metadata import version as _version
+
from xarray import testing, tutorial
from xarray.backends.api import (
load_dataarray,
@@ -41,12 +43,6 @@
from xarray.core.variable import IndexVariable, Variable, as_variable
from xarray.util.print_versions import show_versions
-try:
- from importlib.metadata import version as _version
-except ImportError:
- # if the fallback library is missing, we are doomed.
- from importlib_metadata import version as _version
-
try:
__version__ = _version("xarray")
except Exception:
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 58a05aeddce..27e155872de 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -488,6 +488,9 @@ def open_dataset(
as coordinate variables.
- "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
other attributes as coordinate variables.
+
+ Only existing variables can be set as coordinates. Missing variables
+ will be silently ignored.
drop_variables: str or iterable of str, optional
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
@@ -691,6 +694,9 @@ def open_dataarray(
as coordinate variables.
- "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
other attributes as coordinate variables.
+
+ Only existing variables can be set as coordinates. Missing variables
+ will be silently ignored.
drop_variables: str or iterable of str, optional
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
@@ -1522,6 +1528,7 @@ def to_zarr(
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
+ *,
compute: Literal[True] = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
@@ -1567,6 +1574,7 @@ def to_zarr(
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
+ *,
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
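A brief sketch of the user-facing behaviour documented above; the file and store names are placeholders::

    import xarray as xr

    # decode_coords="all" also promotes variables referenced via
    # "grid_mapping"/"bounds"; names listed in "coordinates" that are missing
    # from the file are silently ignored.
    ds = xr.open_dataset("example.nc", decode_coords="all")

    # "compute", "consolidated" and the remaining arguments are now keyword-only.
    ds.to_zarr("example.zarr", mode="w", consolidated=True)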
diff --git a/xarray/coding/times.py b/xarray/coding/times.py
index 79efbecfb7c..039fe371100 100644
--- a/xarray/coding/times.py
+++ b/xarray/coding/times.py
@@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray:
return num
+def _division(deltas, delta, floor):
+ if floor:
+ # calculate int64 floor division
+ # to preserve integer dtype if possible (GH 4045, GH7817).
+ num = deltas // delta.astype(np.int64)
+ num = num.astype(np.int64, copy=False)
+ else:
+ num = deltas / delta
+ return num
+
+
def encode_cf_datetime(
- dates, units: str | None = None, calendar: str | None = None
+ dates,
+ units: str | None = None,
+ calendar: str | None = None,
+ dtype: np.dtype | None = None,
) -> tuple[np.ndarray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
@@ -689,34 +703,47 @@ def encode_cf_datetime(
time_units, ref_date = _unpack_time_units_and_ref_date(units)
time_delta = _time_units_to_timedelta64(time_units)
+ # Wrap the dates in a DatetimeIndex to do the subtraction to ensure
+ # an OverflowError is raised if the ref_date is too far away from
+ # dates to be encoded (GH 2272).
+ dates_as_index = pd.DatetimeIndex(dates.ravel())
+ time_deltas = dates_as_index - ref_date
+
# retrieve needed units to faithfully encode to int64
needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
if data_units != units:
# this accounts for differences in the reference times
ref_delta = abs(data_ref_date - ref_date).to_timedelta64()
- if ref_delta > np.timedelta64(0, "ns"):
+ data_delta = _time_units_to_timedelta64(needed_units)
+ if (ref_delta % data_delta) > np.timedelta64(0, "ns"):
needed_units = _infer_time_units_from_diff(ref_delta)
- # Wrap the dates in a DatetimeIndex to do the subtraction to ensure
- # an OverflowError is raised if the ref_date is too far away from
- # dates to be encoded (GH 2272).
- dates_as_index = pd.DatetimeIndex(dates.ravel())
- time_deltas = dates_as_index - ref_date
-
# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
- if time_delta <= needed_time_delta:
- # calculate int64 floor division
- # to preserve integer dtype if possible (GH 4045, GH7817).
- num = time_deltas // time_delta.astype(np.int64)
- num = num.astype(np.int64, copy=False)
- else:
- emit_user_level_warning(
- f"Times can't be serialized faithfully with requested units {units!r}. "
- f"Resolution of {needed_units!r} needed. "
- f"Serializing timeseries to floating point."
- )
- num = time_deltas / time_delta
+
+ floor_division = True
+ if time_delta > needed_time_delta:
+ floor_division = False
+ if dtype is None:
+ emit_user_level_warning(
+ f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+ f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. "
+ f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+ f"Set encoding['dtype'] to floating point dtype to silence this warning."
+ )
+ elif np.issubdtype(dtype, np.integer):
+ new_units = f"{needed_units} since {format_timestamp(ref_date)}"
+ emit_user_level_warning(
+ f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+ f"Serializing with units {new_units!r} instead. "
+ f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+                    f"Set encoding['units'] to {new_units!r} to silence this warning."
+ )
+ units = new_units
+ time_delta = needed_time_delta
+ floor_division = True
+
+ num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(dates.shape)
except (OutOfBoundsDatetime, OverflowError, ValueError):
@@ -728,7 +755,9 @@ def encode_cf_datetime(
return (num, units, calendar)
-def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]:
+def encode_cf_timedelta(
+ timedeltas, units: str | None = None, dtype: np.dtype | None = None
+) -> tuple[np.ndarray, str]:
data_units = infer_timedelta_units(timedeltas)
if units is None:
@@ -744,18 +773,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra
# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
- if time_delta <= needed_time_delta:
- # calculate int64 floor division
- # to preserve integer dtype if possible
- num = time_deltas // time_delta.astype(np.int64)
- num = num.astype(np.int64, copy=False)
- else:
- emit_user_level_warning(
- f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
- f"Resolution of {needed_units!r} needed. "
- f"Serializing timedeltas to floating point."
- )
- num = time_deltas / time_delta
+
+ floor_division = True
+ if time_delta > needed_time_delta:
+ floor_division = False
+ if dtype is None:
+ emit_user_level_warning(
+ f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
+                    f"Resolution of {needed_units!r} needed. Serializing timedeltas to floating point instead. "
+ f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+ f"Set encoding['dtype'] to floating point dtype to silence this warning."
+ )
+ elif np.issubdtype(dtype, np.integer):
+ emit_user_level_warning(
+ f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
+ f"Serializing with units {needed_units!r} instead. "
+ f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+                    f"Set encoding['units'] to {needed_units!r} to silence this warning."
+ )
+ units = needed_units
+ time_delta = needed_time_delta
+ floor_division = True
+
+ num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(timedeltas.shape)
return (num, units)
@@ -772,7 +812,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
units = encoding.pop("units", None)
calendar = encoding.pop("calendar", None)
- (data, units, calendar) = encode_cf_datetime(data, units, calendar)
+ dtype = encoding.get("dtype", None)
+ (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)
safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)
@@ -807,7 +848,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)
- data, units = encode_cf_timedelta(data, encoding.pop("units", None))
+ data, units = encode_cf_timedelta(
+ data, encoding.pop("units", None), encoding.get("dtype", None)
+ )
safe_setitem(attrs, "units", units, name=name)
return Variable(dims, data, attrs, encoding, fastpath=True)
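A rough sketch of how the new ``dtype``-aware encoding behaves for users (the values and output path are hypothetical): when an integer ``dtype`` is requested but the requested units are too coarse to represent the data exactly, xarray now warns and switches to finer units instead of silently falling back to floating point::

    import numpy as np
    import xarray as xr

    times = np.array(["2000-01-01T06:00", "2000-01-02T12:00"], dtype="datetime64[ns]")
    ds = xr.Dataset(coords={"time": ("time", times)})

    # "days" cannot represent 6-hour offsets as integers, so this now warns and
    # serializes with e.g. "hours since 2000-01-01" while keeping an int64 dtype.
    ds["time"].encoding = {"units": "days since 2000-01-01", "dtype": np.dtype("int64")}
    ds.to_netcdf("times.nc")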
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index d694c531b15..c583afc93c2 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -304,6 +304,8 @@ def decode(self, variable: Variable, name: T_Name = None):
)
# special case DateTime to properly handle NaT
+ dtype: np.typing.DTypeLike
+ decoded_fill_value: Any
if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
else:
diff --git a/xarray/conventions.py b/xarray/conventions.py
index 596831e270a..cf207f0c37a 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -1,9 +1,8 @@
from __future__ import annotations
-import warnings
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping, MutableMapping
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Literal, Union
import numpy as np
import pandas as pd
@@ -16,6 +15,7 @@
contains_cftime_datetimes,
)
from xarray.core.pycompat import is_duck_dask_array
+from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import IndexVariable, Variable
CF_RELATED_DATA = (
@@ -111,13 +111,13 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
return var
if is_duck_dask_array(data):
- warnings.warn(
+ emit_user_level_warning(
f"variable {name} has data in the form of a dask array with "
"dtype=object, which means it is being loaded into memory "
"to determine a data type that can be safely stored on disk. "
"To avoid this, coerce this variable to a fixed-size dtype "
"with astype() before saving it.",
- SerializationWarning,
+ category=SerializationWarning,
)
data = data.compute()
@@ -354,15 +354,14 @@ def _update_bounds_encoding(variables: T_Variables) -> None:
and "bounds" in attrs
and attrs["bounds"] in variables
):
- warnings.warn(
- "Variable '{0}' has datetime type and a "
- "bounds variable but {0}.encoding does not have "
- "units specified. The units encodings for '{0}' "
- "and '{1}' will be determined independently "
+ emit_user_level_warning(
+ f"Variable {name:s} has datetime type and a "
+ f"bounds variable but {name:s}.encoding does not have "
+ f"units specified. The units encodings for {name:s} "
+ f"and {attrs['bounds']} will be determined independently "
"and may not be equal, counter to CF-conventions. "
"If this is a concern, specify a units encoding for "
- "'{0}' before writing to a file.".format(name, attrs["bounds"]),
- UserWarning,
+ f"{name:s} before writing to a file.",
)
if has_date_units and "bounds" in attrs:
@@ -379,7 +378,7 @@ def decode_cf_variables(
concat_characters: bool = True,
mask_and_scale: bool = True,
decode_times: bool = True,
- decode_coords: bool = True,
+ decode_coords: bool | Literal["coordinates", "all"] = True,
drop_variables: T_DropVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
@@ -441,11 +440,14 @@ def stackable(dim: Hashable) -> bool:
if decode_coords in [True, "coordinates", "all"]:
var_attrs = new_vars[k].attrs
if "coordinates" in var_attrs:
- coord_str = var_attrs["coordinates"]
- var_coord_names = coord_str.split()
- if all(k in variables for k in var_coord_names):
- new_vars[k].encoding["coordinates"] = coord_str
- del var_attrs["coordinates"]
+ var_coord_names = [
+ c for c in var_attrs["coordinates"].split() if c in variables
+ ]
+ # propagate as is
+ new_vars[k].encoding["coordinates"] = var_attrs["coordinates"]
+ del var_attrs["coordinates"]
+ # but only use as coordinate if existing
+ if var_coord_names:
coord_names.update(var_coord_names)
if decode_coords == "all":
@@ -461,8 +463,8 @@ def stackable(dim: Hashable) -> bool:
for role_or_name in part.split()
]
if len(roles_and_names) % 2 == 1:
- warnings.warn(
- f"Attribute {attr_name:s} malformed", stacklevel=5
+ emit_user_level_warning(
+ f"Attribute {attr_name:s} malformed"
)
var_names = roles_and_names[1::2]
if all(var_name in variables for var_name in var_names):
@@ -474,9 +476,8 @@ def stackable(dim: Hashable) -> bool:
for proj_name in var_names
if proj_name not in variables
]
- warnings.warn(
+ emit_user_level_warning(
f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}",
- stacklevel=5,
)
del var_attrs[attr_name]
@@ -493,7 +494,7 @@ def decode_cf(
concat_characters: bool = True,
mask_and_scale: bool = True,
decode_times: bool = True,
- decode_coords: bool = True,
+ decode_coords: bool | Literal["coordinates", "all"] = True,
drop_variables: T_DropVariables = None,
use_cftime: bool | None = None,
decode_timedelta: bool | None = None,
@@ -632,12 +633,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
for name in list(non_dim_coord_names):
if isinstance(name, str) and " " in name:
- warnings.warn(
+ emit_user_level_warning(
f"coordinate {name!r} has a space in its name, which means it "
"cannot be marked as a coordinate on disk and will be "
"saved as a data variable instead",
- SerializationWarning,
- stacklevel=6,
+ category=SerializationWarning,
)
non_dim_coord_names.discard(name)
@@ -710,11 +710,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
if global_coordinates:
attributes = dict(attributes)
if "coordinates" in attributes:
- warnings.warn(
+ emit_user_level_warning(
f"cannot serialize global coordinates {global_coordinates!r} because the global "
f"attribute 'coordinates' already exists. This may prevent faithful roundtripping"
f"of xarray datasets",
- SerializationWarning,
+ category=SerializationWarning,
)
else:
attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates)))
diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py
index d3a783be45d..9b79ed46a9c 100644
--- a/xarray/core/_typed_ops.py
+++ b/xarray/core/_typed_ops.py
@@ -1,165 +1,182 @@
"""Mixin classes with arithmetic operators."""
# This file was generated using xarray.util.generate_ops. Do not edit manually.
+from __future__ import annotations
+
import operator
+from typing import TYPE_CHECKING, Any, Callable, overload
from xarray.core import nputils, ops
+from xarray.core.types import (
+ DaCompatible,
+ DsCompatible,
+ GroupByCompatible,
+ Self,
+ T_DataArray,
+ T_Xarray,
+ VarCompatible,
+)
+
+if TYPE_CHECKING:
+ from xarray.core.dataset import Dataset
class DatasetOpsMixin:
__slots__ = ()
- def _binary_op(self, other, f, reflexive=False):
+ def _binary_op(
+ self, other: DsCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
raise NotImplementedError
- def __add__(self, other):
+ def __add__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.add)
- def __sub__(self, other):
+ def __sub__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.sub)
- def __mul__(self, other):
+ def __mul__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mul)
- def __pow__(self, other):
+ def __pow__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other):
+ def __truediv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other):
+ def __floordiv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other):
+ def __mod__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mod)
- def __and__(self, other):
+ def __and__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.and_)
- def __xor__(self, other):
+ def __xor__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.xor)
- def __or__(self, other):
+ def __or__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other):
+ def __lshift__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other):
+ def __rshift__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other):
+ def __lt__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.lt)
- def __le__(self, other):
+ def __le__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.le)
- def __gt__(self, other):
+ def __gt__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.gt)
- def __ge__(self, other):
+ def __ge__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.ge)
- def __eq__(self, other):
+ def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other):
+ def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- def __radd__(self, other):
+ def __radd__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other):
+ def __rsub__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other):
+ def __rmul__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other):
+ def __rpow__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other):
+ def __rmod__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other):
+ def __rand__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other):
+ def __rxor__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other):
+ def __ror__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def _inplace_binary_op(self, other, f):
+ def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self:
raise NotImplementedError
- def __iadd__(self, other):
+ def __iadd__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other):
+ def __isub__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other):
+ def __imul__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other):
+ def __ipow__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other):
+ def __itruediv__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other):
+ def __ifloordiv__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other):
+ def __imod__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other):
+ def __iand__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other):
+ def __ixor__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other):
+ def __ior__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other):
+ def __ilshift__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other):
+ def __irshift__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.irshift)
- def _unary_op(self, f, *args, **kwargs):
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError
- def __neg__(self):
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self):
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self):
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self):
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
- def round(self, *args, **kwargs):
+ def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.round_, *args, **kwargs)
- def argsort(self, *args, **kwargs):
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.argsort, *args, **kwargs)
- def conj(self, *args, **kwargs):
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conj, *args, **kwargs)
- def conjugate(self, *args, **kwargs):
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conjugate, *args, **kwargs)
__add__.__doc__ = operator.add.__doc__
@@ -215,157 +232,159 @@ def conjugate(self, *args, **kwargs):
class DataArrayOpsMixin:
__slots__ = ()
- def _binary_op(self, other, f, reflexive=False):
+ def _binary_op(
+ self, other: DaCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
raise NotImplementedError
- def __add__(self, other):
+ def __add__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.add)
- def __sub__(self, other):
+ def __sub__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.sub)
- def __mul__(self, other):
+ def __mul__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mul)
- def __pow__(self, other):
+ def __pow__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other):
+ def __truediv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other):
+ def __floordiv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other):
+ def __mod__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mod)
- def __and__(self, other):
+ def __and__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.and_)
- def __xor__(self, other):
+ def __xor__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.xor)
- def __or__(self, other):
+ def __or__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other):
+ def __lshift__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other):
+ def __rshift__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other):
+ def __lt__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.lt)
- def __le__(self, other):
+ def __le__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.le)
- def __gt__(self, other):
+ def __gt__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.gt)
- def __ge__(self, other):
+ def __ge__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.ge)
- def __eq__(self, other):
+ def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other):
+ def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- def __radd__(self, other):
+ def __radd__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other):
+ def __rsub__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other):
+ def __rmul__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other):
+ def __rpow__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other):
+ def __rmod__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other):
+ def __rand__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other):
+ def __rxor__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other):
+ def __ror__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def _inplace_binary_op(self, other, f):
+ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
raise NotImplementedError
- def __iadd__(self, other):
+ def __iadd__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other):
+ def __isub__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other):
+ def __imul__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other):
+ def __ipow__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other):
+ def __itruediv__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other):
+ def __ifloordiv__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other):
+ def __imod__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other):
+ def __iand__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other):
+ def __ixor__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other):
+ def __ior__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other):
+ def __ilshift__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other):
+ def __irshift__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.irshift)
- def _unary_op(self, f, *args, **kwargs):
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError
- def __neg__(self):
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self):
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self):
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self):
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
- def round(self, *args, **kwargs):
+ def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.round_, *args, **kwargs)
- def argsort(self, *args, **kwargs):
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.argsort, *args, **kwargs)
- def conj(self, *args, **kwargs):
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conj, *args, **kwargs)
- def conjugate(self, *args, **kwargs):
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conjugate, *args, **kwargs)
__add__.__doc__ = operator.add.__doc__
@@ -421,157 +440,303 @@ def conjugate(self, *args, **kwargs):
class VariableOpsMixin:
__slots__ = ()
- def _binary_op(self, other, f, reflexive=False):
+ def _binary_op(
+ self, other: VarCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
raise NotImplementedError
- def __add__(self, other):
+ @overload
+ def __add__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __add__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __add__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.add)
- def __sub__(self, other):
+ @overload
+ def __sub__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __sub__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __sub__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.sub)
- def __mul__(self, other):
+ @overload
+ def __mul__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __mul__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __mul__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.mul)
- def __pow__(self, other):
+ @overload
+ def __pow__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __pow__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __pow__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other):
+ @overload
+ def __truediv__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __truediv__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __truediv__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other):
+ @overload
+ def __floordiv__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __floordiv__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other):
+ @overload
+ def __mod__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __mod__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __mod__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.mod)
- def __and__(self, other):
+ @overload
+ def __and__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __and__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __and__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.and_)
- def __xor__(self, other):
+ @overload
+ def __xor__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __xor__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __xor__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.xor)
- def __or__(self, other):
+ @overload
+ def __or__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __or__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __or__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other):
+ @overload
+ def __lshift__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __lshift__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __lshift__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other):
+ @overload
+ def __rshift__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __rshift__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __rshift__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other):
+ @overload
+ def __lt__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __lt__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __lt__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.lt)
- def __le__(self, other):
+ @overload
+ def __le__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __le__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __le__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.le)
- def __gt__(self, other):
+ @overload
+ def __gt__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __gt__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __gt__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.gt)
- def __ge__(self, other):
+ @overload
+ def __ge__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __ge__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __ge__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, operator.ge)
- def __eq__(self, other):
+ @overload # type:ignore[override]
+ def __eq__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __eq__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __eq__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other):
+ @overload # type:ignore[override]
+ def __ne__(self, other: T_DataArray) -> T_DataArray:
+ ...
+
+ @overload
+ def __ne__(self, other: VarCompatible) -> Self:
+ ...
+
+ def __ne__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, nputils.array_ne)
- def __radd__(self, other):
+ def __radd__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other):
+ def __rsub__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other):
+ def __rmul__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other):
+ def __rpow__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other):
+ def __rmod__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other):
+ def __rand__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other):
+ def __rxor__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other):
+ def __ror__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def _inplace_binary_op(self, other, f):
+ def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self:
raise NotImplementedError
- def __iadd__(self, other):
+ def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other):
+ def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other):
+ def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other):
+ def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other):
+ def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other):
+ def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other):
+ def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other):
+ def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other):
+ def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other):
+ def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other):
+ def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other):
+ def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.irshift)
- def _unary_op(self, f, *args, **kwargs):
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError
- def __neg__(self):
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self):
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self):
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self):
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
- def round(self, *args, **kwargs):
+ def round(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.round_, *args, **kwargs)
- def argsort(self, *args, **kwargs):
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.argsort, *args, **kwargs)
- def conj(self, *args, **kwargs):
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conj, *args, **kwargs)
- def conjugate(self, *args, **kwargs):
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op(ops.conjugate, *args, **kwargs)
__add__.__doc__ = operator.add.__doc__
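# Illustrative sketch (not part of the patch): how the VariableOpsMixin overloads above
# are meant to resolve for a static type checker. Binary ops with a DataArray fall into
# the T_DataArray overload and keep the DataArray type, while ops with scalars or other
# Variables stay typed as Self. Names and data below are made up for the example.
import numpy as np
import xarray as xr

var = xr.Variable(dims=("x",), data=np.arange(3))
da = xr.DataArray(np.ones(3), dims="x")

v2 = var + 1    # VarCompatible overload -> Variable (Self)
d2 = var + da   # T_DataArray overload -> DataArray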
@@ -627,91 +792,93 @@ def conjugate(self, *args, **kwargs):
class DatasetGroupByOpsMixin:
__slots__ = ()
- def _binary_op(self, other, f, reflexive=False):
+ def _binary_op(
+ self, other: GroupByCompatible, f: Callable, reflexive: bool = False
+ ) -> Dataset:
raise NotImplementedError
- def __add__(self, other):
+ def __add__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.add)
- def __sub__(self, other):
+ def __sub__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.sub)
- def __mul__(self, other):
+ def __mul__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.mul)
- def __pow__(self, other):
+ def __pow__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other):
+ def __truediv__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other):
+ def __floordiv__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other):
+ def __mod__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.mod)
- def __and__(self, other):
+ def __and__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.and_)
- def __xor__(self, other):
+ def __xor__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.xor)
- def __or__(self, other):
+ def __or__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other):
+ def __lshift__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other):
+ def __rshift__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other):
+ def __lt__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.lt)
- def __le__(self, other):
+ def __le__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.le)
- def __gt__(self, other):
+ def __gt__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.gt)
- def __ge__(self, other):
+ def __ge__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.ge)
- def __eq__(self, other):
+ def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other):
+ def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- def __radd__(self, other):
+ def __radd__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other):
+ def __rsub__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other):
+ def __rmul__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other):
+ def __rpow__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other):
+ def __rmod__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other):
+ def __rand__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other):
+ def __rxor__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other):
+ def __ror__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.or_, reflexive=True)
__add__.__doc__ = operator.add.__doc__
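# Illustrative sketch (not part of the patch): with the GroupByCompatible annotations,
# arithmetic on a DatasetGroupBy is typed as returning a Dataset. A hedged example of
# the usual anomaly-from-climatology pattern this covers (data is made up):
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t": ("time", np.arange(6.0))},
    coords={"month": ("time", [1, 1, 2, 2, 3, 3])},
)
climatology = ds.groupby("month").mean()
anomaly = ds.groupby("month") - climatology  # evaluated and typed as a Dataset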
@@ -747,91 +914,93 @@ def __ror__(self, other):
class DataArrayGroupByOpsMixin:
__slots__ = ()
- def _binary_op(self, other, f, reflexive=False):
+ def _binary_op(
+ self, other: T_Xarray, f: Callable, reflexive: bool = False
+ ) -> T_Xarray:
raise NotImplementedError
- def __add__(self, other):
+ def __add__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.add)
- def __sub__(self, other):
+ def __sub__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.sub)
- def __mul__(self, other):
+ def __mul__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mul)
- def __pow__(self, other):
+ def __pow__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other):
+ def __truediv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other):
+ def __floordiv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other):
+ def __mod__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mod)
- def __and__(self, other):
+ def __and__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.and_)
- def __xor__(self, other):
+ def __xor__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.xor)
- def __or__(self, other):
+ def __or__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other):
+ def __lshift__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other):
+ def __rshift__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other):
+ def __lt__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.lt)
- def __le__(self, other):
+ def __le__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.le)
- def __gt__(self, other):
+ def __gt__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.gt)
- def __ge__(self, other):
+ def __ge__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.ge)
- def __eq__(self, other):
+ def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other):
+ def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- def __radd__(self, other):
+ def __radd__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other):
+ def __rsub__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other):
+ def __rmul__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other):
+ def __rpow__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other):
+ def __rmod__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other):
+ def __rand__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other):
+ def __rxor__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other):
+ def __ror__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.or_, reflexive=True)
__add__.__doc__ = operator.add.__doc__
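# Illustrative sketch (not part of the patch): DataArrayGroupBy ops use the T_Xarray
# TypeVar, so the result follows the operand: subtracting a DataArray yields a DataArray,
# subtracting a Dataset yields a Dataset. Hedged example with made-up data:
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6.0),
    dims="time",
    coords={"month": ("time", [1, 1, 2, 2, 3, 3])},
)
da_anom = da.groupby("month") - da.groupby("month").mean()  # DataArray in, DataArray out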
diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi
deleted file mode 100644
index 9e2ba2d3a06..00000000000
--- a/xarray/core/_typed_ops.pyi
+++ /dev/null
@@ -1,782 +0,0 @@
-"""Stub file for mixin classes with arithmetic operators."""
-# This file was generated using xarray.util.generate_ops. Do not edit manually.
-
-from typing import NoReturn, TypeVar, overload
-
-import numpy as np
-from numpy.typing import ArrayLike
-
-from .dataarray import DataArray
-from .dataset import Dataset
-from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
-from .types import (
- DaCompatible,
- DsCompatible,
- GroupByIncompatible,
- ScalarOrArray,
- VarCompatible,
-)
-from .variable import Variable
-
-try:
- from dask.array import Array as DaskArray
-except ImportError:
- DaskArray = np.ndarray # type: ignore
-
-# DatasetOpsMixin etc. are parent classes of Dataset etc.
-# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
-# we use the ones in `types`. (We're open to refining this, and potentially integrating
-# the `py` & `pyi` files to simplify them.)
-T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
-T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
-T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")
-
-class DatasetOpsMixin:
- __slots__ = ()
- def _binary_op(self, other, f, reflexive=...): ...
- def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override]
- def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override]
- def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
- def _inplace_binary_op(self, other, f): ...
- def _unary_op(self, f, *args, **kwargs): ...
- def __neg__(self: T_Dataset) -> T_Dataset: ...
- def __pos__(self: T_Dataset) -> T_Dataset: ...
- def __abs__(self: T_Dataset) -> T_Dataset: ...
- def __invert__(self: T_Dataset) -> T_Dataset: ...
- def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
- def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
- def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
- def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ...
-
-class DataArrayOpsMixin:
- __slots__ = ()
- def _binary_op(self, other, f, reflexive=...): ...
- @overload
- def __add__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __add__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __sub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __mul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __pow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __truediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __floordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __mod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __and__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __and__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __xor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __or__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __or__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __lt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __le__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __le__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __gt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __ge__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload # type: ignore[override]
- def __eq__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload # type: ignore[override]
- def __ne__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __radd__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rsub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rmul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rpow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rmod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rand__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __rxor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- @overload
- def __ror__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
- def _inplace_binary_op(self, other, f): ...
- def _unary_op(self, f, *args, **kwargs): ...
- def __neg__(self: T_DataArray) -> T_DataArray: ...
- def __pos__(self: T_DataArray) -> T_DataArray: ...
- def __abs__(self: T_DataArray) -> T_DataArray: ...
- def __invert__(self: T_DataArray) -> T_DataArray: ...
- def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
- def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
- def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
- def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ...
-
-class VariableOpsMixin:
- __slots__ = ()
- def _binary_op(self, other, f, reflexive=...): ...
- @overload
- def __add__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __add__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __sub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __sub__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __mul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mul__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __pow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __pow__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __truediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __truediv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __floordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __floordiv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __mod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mod__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __and__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __and__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __xor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __xor__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __or__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __or__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __lt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lt__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __le__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __le__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __gt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __gt__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __ge__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ge__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload # type: ignore[override]
- def __eq__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __eq__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload # type: ignore[override]
- def __ne__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ne__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __radd__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __radd__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rsub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rsub__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rmul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmul__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rpow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rpow__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rmod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmod__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rand__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rand__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __rxor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rxor__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- @overload
- def __ror__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ror__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
- def _inplace_binary_op(self, other, f): ...
- def _unary_op(self, f, *args, **kwargs): ...
- def __neg__(self: T_Variable) -> T_Variable: ...
- def __pos__(self: T_Variable) -> T_Variable: ...
- def __abs__(self: T_Variable) -> T_Variable: ...
- def __invert__(self: T_Variable) -> T_Variable: ...
- def round(self: T_Variable, *args, **kwargs) -> T_Variable: ...
- def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ...
- def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ...
- def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ...
-
-class DatasetGroupByOpsMixin:
- __slots__ = ()
- def _binary_op(self, other, f, reflexive=...): ...
- @overload
- def __add__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __add__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __add__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __sub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __sub__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __sub__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __mul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mul__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __mul__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __pow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __pow__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __pow__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __truediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __truediv__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __floordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __floordiv__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __mod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mod__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __mod__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __and__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __and__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __and__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __xor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __xor__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __xor__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __or__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __or__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lshift__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rshift__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __lt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lt__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __lt__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __le__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __le__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __le__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __gt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __gt__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __gt__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __ge__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ge__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __ge__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload # type: ignore[override]
- def __eq__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __eq__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __eq__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload # type: ignore[override]
- def __ne__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ne__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __ne__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __radd__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __radd__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __radd__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rsub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rsub__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rmul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmul__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rpow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rpow__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rtruediv__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rfloordiv__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rmod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmod__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rand__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rand__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rand__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rxor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rxor__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __ror__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ror__(self, other: "DataArray") -> "Dataset": ...
- @overload
- def __ror__(self, other: GroupByIncompatible) -> NoReturn: ...
-
-class DataArrayGroupByOpsMixin:
- __slots__ = ()
- def _binary_op(self, other, f, reflexive=...): ...
- @overload
- def __add__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __add__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __add__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __sub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __sub__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __sub__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __mul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mul__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __mul__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __pow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __pow__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __pow__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __truediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __truediv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __floordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __floordiv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __mod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __mod__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __mod__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __and__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __and__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __and__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __xor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __xor__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __xor__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __or__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __or__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __lt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __lt__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __lt__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __le__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __le__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __le__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __gt__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __gt__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __gt__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __ge__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ge__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ge__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload # type: ignore[override]
- def __eq__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __eq__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __eq__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload # type: ignore[override]
- def __ne__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ne__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ne__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __radd__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __radd__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __radd__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rsub__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rsub__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rmul__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmul__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rpow__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rpow__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rmod__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rmod__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rand__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rand__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rand__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __rxor__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __rxor__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ...
- @overload
- def __ror__(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def __ror__(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def __ror__(self, other: GroupByIncompatible) -> NoReturn: ...
diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py
index 4c1ce4b5c48..8255e2a5232 100644
--- a/xarray/core/accessor_dt.py
+++ b/xarray/core/accessor_dt.py
@@ -7,6 +7,7 @@
import pandas as pd
from xarray.coding.times import infer_calendar_name
+from xarray.core import duck_array_ops
from xarray.core.common import (
_contains_datetime_like_objects,
is_np_datetime_like,
@@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name):
from xarray.coding.cftimeindex import CFTimeIndex
if not isinstance(values, CFTimeIndex):
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
else:
values_as_cftimeindex = values
if name == "season":
@@ -69,7 +70,7 @@ def _access_through_series(values, name):
"""Coerce an array of datetime-like values to a pandas Series and
access requested datetime component
"""
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
if name == "season":
months = values_as_series.dt.month.values
field_values = _season_from_months(months)
@@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq):
from xarray.coding.cftimeindex import CFTimeIndex
if is_np_datetime_like(values.dtype):
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
method = getattr(values_as_series.dt, name)
else:
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
method = getattr(values_as_cftimeindex, name)
field_values = method(freq=freq).values
@@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str):
"""
from xarray.coding.cftimeindex import CFTimeIndex
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
field_values = values_as_cftimeindex.strftime(date_format)
return field_values.values.reshape(values.shape)
@@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str):
"""Coerce an array of datetime-like values to a pandas Series and
apply string formatting
"""
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
strs = values_as_series.dt.strftime(date_format)
return strs.values.reshape(values.shape)
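# Illustrative sketch (not part of the patch): the accessor changes above route flattening
# through xarray.core.duck_array_ops.ravel so that duck arrays (e.g. dask) are handled
# without assuming a NumPy-style .ravel() method on the input. A minimal sketch of the
# idea only; the real helper in duck_array_ops may be implemented differently:
def ravel(array):
    # reshape(-1) flattens NumPy, dask and other duck arrays through their own machinery
    return array.reshape(-1)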
diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index aa6dc2c7114..573200b5c88 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -2386,7 +2386,7 @@ def _partitioner(
# _apply breaks on an empty array in this case
if not self._obj.size:
- return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value]
+ return self._obj.copy().expand_dims({dim: 0}, axis=-1)
arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype)
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index ff2ecbc74a1..732ec5d3ea6 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Generic, cast
+from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload
import numpy as np
import pandas as pd
@@ -26,7 +26,13 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray
+ from xarray.core.types import (
+ Alignable,
+ JoinOptions,
+ T_DataArray,
+ T_Dataset,
+ T_DuckArray,
+ )
def reindex_variables(
@@ -128,7 +134,7 @@ def __init__(
objects: Iterable[T_Alignable],
join: str = "inner",
indexes: Mapping[Any, Any] | None = None,
- exclude_dims: Iterable = frozenset(),
+ exclude_dims: str | Iterable[Hashable] = frozenset(),
exclude_vars: Iterable[Hashable] = frozenset(),
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
@@ -576,12 +582,111 @@ def align(self) -> None:
self.reindex_all()
+T_Obj1 = TypeVar("T_Obj1", bound="Alignable")
+T_Obj2 = TypeVar("T_Obj2", bound="Alignable")
+T_Obj3 = TypeVar("T_Obj3", bound="Alignable")
+T_Obj4 = TypeVar("T_Obj4", bound="Alignable")
+T_Obj5 = TypeVar("T_Obj5", bound="Alignable")
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1]:
+ ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2]:
+ ...
+
+
+@overload
def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3]:
+ ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]:
+ ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ obj5: T_Obj5,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]:
+ ...
+
+
+@overload
+def align(
+ *objects: T_Alignable,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Alignable, ...]:
+ ...
+
+
+def align( # type: ignore[misc]
*objects: T_Alignable,
join: JoinOptions = "inner",
copy: bool = True,
indexes=None,
- exclude=frozenset(),
+ exclude: str | Iterable[Hashable] = frozenset(),
fill_value=dtypes.NA,
) -> tuple[T_Alignable, ...]:
"""
@@ -620,7 +725,7 @@ def align(
indexes : dict-like, optional
Any indexes explicitly provided with the `indexes` argument should be
used in preference to the aligned indexes.
- exclude : sequence of str, optional
+ exclude : str, iterable of hashable or None, optional
Dimensions that must be excluded from alignment
fill_value : scalar or dict-like, optional
Value to use for newly missing values. If a dict-like, maps
@@ -787,12 +892,12 @@ def align(
def deep_align(
objects: Iterable[Any],
join: JoinOptions = "inner",
- copy=True,
+ copy: bool = True,
indexes=None,
- exclude=frozenset(),
- raise_on_invalid=True,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ raise_on_invalid: bool = True,
fill_value=dtypes.NA,
-):
+) -> list[Any]:
"""Align objects for merging, recursing into dictionary values.
This function is not public API.
@@ -807,12 +912,12 @@ def deep_align(
def is_alignable(obj):
return isinstance(obj, (Coordinates, DataArray, Dataset))
- positions = []
- keys = []
- out = []
- targets = []
- no_key = object()
- not_replaced = object()
+ positions: list[int] = []
+ keys: list[type[object] | Hashable] = []
+ out: list[Any] = []
+ targets: list[Alignable] = []
+ no_key: Final = object()
+ not_replaced: Final = object()
for position, variables in enumerate(objects):
if is_alignable(variables):
positions.append(position)
@@ -857,7 +962,7 @@ def is_alignable(obj):
if key is no_key:
out[position] = aligned_obj
else:
- out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this?
+ out[position][key] = aligned_obj
return out
@@ -988,9 +1093,69 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
raise ValueError("all input must be Dataset or DataArray objects")
-# TODO: this typing is too restrictive since it cannot deal with mixed
-# DataArray and Dataset types...? Is this a problem?
-def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]:
+@overload
+def broadcast(
+ obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Obj1]:
+ ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Obj1, T_Obj2]:
+ ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3]:
+ ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]:
+ ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ obj5: T_Obj5,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]:
+ ...
+
+
+@overload
+def broadcast(
+ *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Alignable, ...]:
+ ...
+
+
+def broadcast( # type: ignore[misc]
+ *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Alignable, ...]:
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.
@@ -1004,7 +1169,7 @@ def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]:
----------
*args : DataArray or Dataset
Arrays to broadcast against each other.
- exclude : sequence of str, optional
+ exclude : str, iterable of hashable or None, optional
Dimensions that must not be broadcasted
Returns
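# Illustrative sketch (not part of the patch): the positional-only overloads added to
# align() and broadcast() above let type checkers keep a distinct type per argument when
# unpacking the returned tuple, instead of collapsing everything to one TypeVar.
# Hedged example with made-up data:
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(3.0), dims="x", coords={"x": [0, 1, 2]})
ds = xr.Dataset({"v": ("x", np.arange(4.0))}, coords={"x": [0, 1, 2, 3]})

da2, ds2 = xr.align(da, ds, join="inner")  # da2 stays a DataArray, ds2 a Dataset
da3, ds3 = xr.broadcast(da2, ds2)          # same per-object typing for broadcast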
diff --git a/xarray/core/common.py b/xarray/core/common.py
index ade701457c6..ab8a4d84261 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -45,6 +45,7 @@
DatetimeLike,
DTypeLikeSave,
ScalarOrArray,
+ Self,
SideOptions,
T_Chunks,
T_DataWithCoords,
@@ -222,7 +223,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int:
raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}")
@property
- def sizes(self: Any) -> Frozen[Hashable, int]:
+ def sizes(self: Any) -> Mapping[Hashable, int]:
"""Ordered mapping from dimension names to lengths.
Immutable.
@@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin):
__slots__ = ("_close",)
def squeeze(
- self: T_DataWithCoords,
+ self,
dim: Hashable | Iterable[Hashable] | None = None,
drop: bool = False,
axis: int | Iterable[int] | None = None,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""Return a new object with squeezed data.
Parameters
@@ -414,12 +415,12 @@ def squeeze(
return self.isel(drop=drop, **{d: 0 for d in dims})
def clip(
- self: T_DataWithCoords,
+ self,
min: ScalarOrArray | None = None,
max: ScalarOrArray | None = None,
*,
keep_attrs: bool | None = None,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""
Return an array whose values are limited to ``[min, max]``.
At least one of max or min must be given.
@@ -472,10 +473,10 @@ def _calc_assign_results(
return {k: v(self) if callable(v) else v for k, v in kwargs.items()}
def assign_coords(
- self: T_DataWithCoords,
+ self,
coords: Mapping[Any, Any] | None = None,
**coords_kwargs: Any,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""Assign new coordinates to this object.
Returns a new object with all the original data in addition to the new
@@ -620,9 +621,7 @@ def assign_coords(
data.coords.update(results)
return data
- def assign_attrs(
- self: T_DataWithCoords, *args: Any, **kwargs: Any
- ) -> T_DataWithCoords:
+ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
"""Assign new attrs to this object.
Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``.
@@ -1061,11 +1060,12 @@ def _resample(
restore_coord_dims=restore_coord_dims,
)
- def where(
- self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False
- ) -> T_DataWithCoords:
+ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
"""Filter elements from this object according to a condition.
+ Returns elements from this object where 'cond' is True,
+ otherwise fills in values from 'other'.
+
This operation follows the normal broadcasting and alignment rules that
xarray uses for binary arithmetic.
@@ -1073,10 +1073,12 @@ def where(
----------
cond : DataArray, Dataset, or callable
Locations at which to preserve this object's values. dtype must be `bool`.
- If a callable, it must expect this object as its only parameter.
- other : scalar, DataArray or Dataset, optional
+ If a callable, the callable is passed this object, and the result is used as
+ the value for cond.
+ other : scalar, DataArray, Dataset, or callable, optional
Value to use for locations in this object where ``cond`` is False.
- By default, these locations filled with NA.
+ By default, these locations are filled with NA. If a callable, it must
+ expect this object as its only parameter.
drop : bool, default: False
If True, coordinate labels that only correspond to False values of
the condition are dropped from the result.
@@ -1124,7 +1126,16 @@ def where(
[15., nan, nan, nan]])
Dimensions without coordinates: x, y
- >>> a.where(lambda x: x.x + x.y < 4, drop=True)
+ >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x)
+
+ array([[ 0, 1, 2, 3, -4],
+ [ 5, 6, 7, -8, -9],
+ [ 10, 11, -12, -13, -14],
+ [ 15, -16, -17, -18, -19],
+ [-20, -21, -22, -23, -24]])
+ Dimensions without coordinates: x, y
+
+ >>> a.where(a.x + a.y < 4, drop=True)
array([[ 0., 1., 2., 3.],
[ 5., 6., 7., nan],
@@ -1132,14 +1143,6 @@ def where(
[15., nan, nan, nan]])
Dimensions without coordinates: x, y
- >>> a.where(a.x + a.y < 4, -1, drop=True)
-
- array([[ 0, 1, 2, 3],
- [ 5, 6, 7, -1],
- [10, 11, -1, -1],
- [15, -1, -1, -1]])
- Dimensions without coordinates: x, y
-
See Also
--------
numpy.where : corresponding numpy function
@@ -1151,14 +1154,16 @@ def where(
if callable(cond):
cond = cond(self)
+ if callable(other):
+ other = other(self)
if drop:
if not isinstance(cond, (Dataset, DataArray)):
raise TypeError(
- f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}"
+ f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)."
)
- self, cond = align(self, cond) # type: ignore[assignment]
+ self, cond = align(self, cond)
def _dataarray_indexer(dim: Hashable) -> DataArray:
return cond.any(dim=(d for d in cond.dims if d != dim))
@@ -1205,9 +1210,7 @@ def close(self) -> None:
self._close()
self._close = None
- def isnull(
- self: T_DataWithCoords, keep_attrs: bool | None = None
- ) -> T_DataWithCoords:
+ def isnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is a missing value.
Parameters
@@ -1250,9 +1253,7 @@ def isnull(
keep_attrs=keep_attrs,
)
- def notnull(
- self: T_DataWithCoords, keep_attrs: bool | None = None
- ) -> T_DataWithCoords:
+ def notnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is not a missing value.
Parameters
@@ -1295,7 +1296,7 @@ def notnull(
keep_attrs=keep_attrs,
)
- def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
+ def isin(self, test_elements: Any) -> Self:
"""Tests each value in the array for whether it is in test elements.
Parameters
@@ -1344,7 +1345,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
)
def astype(
- self: T_DataWithCoords,
+ self,
dtype,
*,
order=None,
@@ -1352,7 +1353,7 @@ def astype(
subok=None,
copy=None,
keep_attrs=True,
- ) -> T_DataWithCoords:
+ ) -> Self:
"""
Copy of the xarray object, with data cast to a specified type.
Leaves coordinate dtype unchanged.
@@ -1419,7 +1420,7 @@ def astype(
dask="allowed",
)
- def __enter__(self: T_DataWithCoords) -> T_DataWithCoords:
+ def __enter__(self) -> Self:
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
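The recurring edit in this file, swapping the bound ``T_DataWithCoords`` TypeVar for ``Self``, follows the pattern sketched below. This is an illustrative example rather than xarray code, and it assumes Python 3.11 (older interpreters can import ``Self`` from ``typing_extensions``).

    from typing import Self

    class Base:
        def __init__(self, data: list[int] | None = None) -> None:
            self.data = list(data or [])

        def copy(self) -> Self:
            # ``Self`` keeps the subclass type in inherited signatures without
            # having to annotate ``self`` with a bound TypeVar.
            return type(self)(self.data)

    class Derived(Base):
        pass

    d: Derived = Derived([1, 2, 3]).copy()  # checkers infer Derived, not Base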
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 971f036b394..9cb60e0c424 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -8,7 +8,7 @@
import operator
import warnings
from collections import Counter
-from collections.abc import Hashable, Iterable, Mapping, Sequence, Set
+from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload
import numpy as np
@@ -163,7 +163,7 @@ def to_gufunc_string(self, exclude_dims=frozenset()):
if exclude_dims:
exclude_dims = [self.dims_map[dim] for dim in exclude_dims]
- counter = Counter()
+ counter: Counter = Counter()
def _enumerate(dim):
if dim in exclude_dims:
@@ -289,8 +289,14 @@ def apply_dataarray_vfunc(
from xarray.core.dataarray import DataArray
if len(args) > 1:
- args = deep_align(
- args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
+ args = tuple(
+ deep_align(
+ args,
+ join=join,
+ copy=False,
+ exclude=exclude_dims,
+ raise_on_invalid=False,
+ )
)
objs = _all_of_type(args, DataArray)
@@ -506,8 +512,14 @@ def apply_dataset_vfunc(
objs = _all_of_type(args, Dataset)
if len(args) > 1:
- args = deep_align(
- args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
+ args = tuple(
+ deep_align(
+ args,
+ join=join,
+ copy=False,
+ exclude=exclude_dims,
+ raise_on_invalid=False,
+ )
)
list_of_coords, list_of_indexes = build_output_coords_and_indexes(
@@ -571,7 +583,7 @@ def apply_groupby_func(func, *args):
assert groupbys, "must have at least one groupby to iterate over"
first_groupby = groupbys[0]
(grouper,) = first_groupby.groupers
- if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]):
+ if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr]
raise ValueError(
"apply_ufunc can only perform operations over "
"multiple GroupBy objects at once if they are all "
@@ -583,6 +595,7 @@ def apply_groupby_func(func, *args):
iterators = []
for arg in args:
+ iterator: Iterator[Any]
if isinstance(arg, GroupBy):
iterator = (value for _, value in arg)
elif hasattr(arg, "dims") and grouped_dim in arg.dims:
@@ -597,9 +610,9 @@ def apply_groupby_func(func, *args):
iterator = itertools.repeat(arg)
iterators.append(iterator)
- applied = (func(*zipped_args) for zipped_args in zip(*iterators))
+ applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators))
applied_example, applied = peek_at(applied)
- combine = first_groupby._combine
+ combine = first_groupby._combine # type: ignore[attr-defined]
if isinstance(applied_example, tuple):
combined = tuple(combine(output) for output in zip(*applied))
else:
@@ -893,7 +906,7 @@ def apply_ufunc(
dataset_fill_value: object = _NO_FILL_VALUE,
keep_attrs: bool | str | None = None,
kwargs: Mapping | None = None,
- dask: str = "forbidden",
+ dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden",
output_dtypes: Sequence | None = None,
output_sizes: Mapping[Any, int] | None = None,
meta: Any = None,
@@ -2122,7 +2135,8 @@ def _calc_idxminmax(
chunkmanager = get_chunked_array_type(array.data)
chunks = dict(zip(array.dims, array.chunks))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
- res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
+ data = dask_coord[duck_array_ops.ravel(indx.data)]
+ res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
# we need to attach back the dim name
res.name = dim
else:
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index a76bb6b0033..a136480b2fb 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Hashable, Iterable
-from typing import TYPE_CHECKING, Any, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Union, overload
import numpy as np
import pandas as pd
@@ -504,8 +504,7 @@ def _dataset_concat(
# case where concat dimension is a coordinate or data_var but not a dimension
if (dim in coord_names or dim in data_names) and dim not in dim_names:
- # TODO: Overriding type because .expand_dims has incorrect typing:
- datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]
+ datasets = [ds.expand_dims(dim) for ds in datasets]
# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -708,8 +707,7 @@ def _dataarray_concat(
if compat == "identical":
raise ValueError("array names not identical")
else:
- # TODO: Overriding type because .rename has incorrect typing:
- arr = cast(T_DataArray, arr.rename(name))
+ arr = arr.rename(name)
datasets.append(arr._to_temp_dataset())
ds = _dataset_concat(
diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index e20c022e637..0c85b2a2d69 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -23,7 +23,7 @@
create_default_index_implicit,
)
from xarray.core.merge import merge_coordinates_without_align, merge_coords
-from xarray.core.types import Self, T_DataArray
+from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray
from xarray.core.utils import (
Frozen,
ReprObject,
@@ -425,7 +425,7 @@ def __delitem__(self, key: Hashable) -> None:
# redirect to DatasetCoordinates.__delitem__
del self._data.coords[key]
- def equals(self, other: Coordinates) -> bool:
+ def equals(self, other: Self) -> bool:
"""Two Coordinates objects are equal if they have matching variables,
all of which are equal.
@@ -437,7 +437,7 @@ def equals(self, other: Coordinates) -> bool:
return False
return self.to_dataset().equals(other.to_dataset())
- def identical(self, other: Coordinates) -> bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks all variable attributes.
See Also
@@ -565,9 +565,7 @@ def update(self, other: Mapping[Any, Any]) -> None:
self._update_coords(coords, indexes)
- def assign(
- self, coords: Mapping | None = None, **coords_kwargs: Any
- ) -> Coordinates:
+ def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self:
"""Assign new coordinates (and indexes) to a Coordinates object, returning
a new object with all the original coordinates in addition to the new ones.
@@ -656,7 +654,7 @@ def copy(
self,
deep: bool = False,
memo: dict[int, Any] | None = None,
- ) -> Coordinates:
+ ) -> Self:
"""Return a copy of this Coordinates object."""
# do not copy indexes (may corrupt multi-coordinate indexes)
# TODO: disable variables deepcopy? it may also be problematic when they
@@ -664,8 +662,16 @@ def copy(
variables = {
k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items()
}
- return Coordinates._construct_direct(
- coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
+
+ # TODO: getting an error with `self._construct_direct`, possibly because of how
+ # a subclass implements `_construct_direct`. (This was originally the same
+ # runtime code, but we switched the type definitions in #8216, which
+ # necessitates the cast.)
+ return cast(
+ Self,
+ Coordinates._construct_direct(
+ coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
+ ),
)
@@ -915,9 +921,7 @@ def drop_indexed_coords(
return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes)
-def assert_coordinate_consistent(
- obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
-) -> None:
+def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.
obj: DataArray or Dataset
@@ -933,7 +937,7 @@ def assert_coordinate_consistent(
def create_coords_with_default_indexes(
- coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None
+ coords: Mapping[Any, Any], data_vars: DataVars | None = None
) -> Coordinates:
"""Returns a Coordinates object from a mapping of coordinates (arbitrary objects).
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 791aad5cd17..391b4ed9412 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -4,7 +4,15 @@
import warnings
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from os import PathLike
-from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Literal,
+ NoReturn,
+ overload,
+)
import numpy as np
import pandas as pd
@@ -41,6 +49,7 @@
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import PANDAS_TYPES, MergeError
from xarray.core.options import OPTIONS, _get_keep_attrs
+from xarray.core.types import DaCompatible, T_DataArray, T_DataArrayOrSet
from xarray.core.utils import (
Default,
HybridMappingProxy,
@@ -57,6 +66,7 @@
)
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
+from xarray.util.deprecation_helpers import _deprecate_positional_args
if TYPE_CHECKING:
from typing import TypeVar, Union
@@ -100,8 +110,9 @@
QueryEngineOptions,
QueryParserOptions,
ReindexMethodOptions,
+ Self,
SideOptions,
- T_DataArray,
+ T_Chunks,
T_Xarray,
)
from xarray.core.weighted import DataArrayWeighted
@@ -119,7 +130,7 @@ def _check_coords_dims(shape, coords, dims):
f"dimensions {dims}"
)
- for d, s in zip(v.dims, v.shape):
+ for d, s in v.sizes.items():
if s != sizes[d]:
raise ValueError(
f"conflicting sizes for dimension {d!r}: "
@@ -127,13 +138,6 @@ def _check_coords_dims(shape, coords, dims):
f"coordinate {k!r}"
)
- if k in sizes and v.shape != (sizes[k],):
- raise ValueError(
- f"coordinate {k!r} is a DataArray dimension, but "
- f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} "
- "matching the dimension size"
- )
-
def _infer_coords_and_dims(
shape, coords, dims
@@ -213,13 +217,13 @@ def _check_data_shape(data, coords, dims):
return data
-class _LocIndexer:
+class _LocIndexer(Generic[T_DataArray]):
__slots__ = ("data_array",)
- def __init__(self, data_array: DataArray):
+ def __init__(self, data_array: T_DataArray):
self.data_array = data_array
- def __getitem__(self, key) -> DataArray:
+ def __getitem__(self, key) -> T_DataArray:
if not utils.is_dict_like(key):
# expand the indexer so we can handle Ellipsis
labels = indexing.expanded_indexer(key, self.data_array.ndim)
@@ -462,12 +466,12 @@ def __init__(
@classmethod
def _construct_direct(
- cls: type[T_DataArray],
+ cls,
variable: Variable,
coords: dict[Any, Variable],
name: Hashable,
indexes: dict[Hashable, Index],
- ) -> T_DataArray:
+ ) -> Self:
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
@@ -480,12 +484,12 @@ def _construct_direct(
return obj
def _replace(
- self: T_DataArray,
+ self,
variable: Variable | None = None,
coords=None,
name: Hashable | None | Default = _default,
indexes=None,
- ) -> T_DataArray:
+ ) -> Self:
if variable is None:
variable = self.variable
if coords is None:
@@ -497,10 +501,10 @@ def _replace(
return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True)
def _replace_maybe_drop_dims(
- self: T_DataArray,
+ self,
variable: Variable,
name: Hashable | None | Default = _default,
- ) -> T_DataArray:
+ ) -> Self:
if variable.dims == self.dims and variable.shape == self.shape:
coords = self._coords.copy()
indexes = self._indexes
@@ -522,12 +526,12 @@ def _replace_maybe_drop_dims(
return self._replace(variable, coords, name, indexes=indexes)
def _overwrite_indexes(
- self: T_DataArray,
+ self,
indexes: Mapping[Any, Index],
variables: Mapping[Any, Variable] | None = None,
drop_coords: list[Hashable] | None = None,
rename_dims: Mapping[Any, Any] | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Maybe replace indexes and their corresponding coordinates."""
if not indexes:
return self
@@ -560,8 +564,8 @@ def _to_temp_dataset(self) -> Dataset:
return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False)
def _from_temp_dataset(
- self: T_DataArray, dataset: Dataset, name: Hashable | None | Default = _default
- ) -> T_DataArray:
+ self, dataset: Dataset, name: Hashable | None | Default = _default
+ ) -> Self:
variable = dataset._variables.pop(_THIS_ARRAY)
coords = dataset._variables
indexes = dataset._indexes
@@ -773,7 +777,7 @@ def to_numpy(self) -> np.ndarray:
"""
return self.variable.to_numpy()
- def as_numpy(self: T_DataArray) -> T_DataArray:
+ def as_numpy(self) -> Self:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.
@@ -828,7 +832,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]:
key = indexing.expanded_indexer(key, self.ndim)
return dict(zip(self.dims, key))
- def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray:
+ def _getitem_coord(self, key: Any) -> Self:
from xarray.core.dataset import _get_virtual_variable
try:
@@ -839,7 +843,7 @@ def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray:
return self._replace_maybe_drop_dims(var, name=key)
- def __getitem__(self: T_DataArray, key: Any) -> T_DataArray:
+ def __getitem__(self, key: Any) -> Self:
if isinstance(key, str):
return self._getitem_coord(key)
else:
@@ -909,10 +913,16 @@ def encoding(self) -> dict[Any, Any]:
def encoding(self, value: Mapping[Any, Any]) -> None:
self.variable.encoding = dict(value)
- def reset_encoding(self: T_DataArray) -> T_DataArray:
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
+
+ def drop_encoding(self) -> Self:
"""Return a new DataArray without encoding on the array or any attached
coords."""
- ds = self._to_temp_dataset().reset_encoding()
+ ds = self._to_temp_dataset().drop_encoding()
return self._from_temp_dataset(ds)
@property
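A minimal usage sketch of the rename introduced here (the DataArray construction is illustrative):

    import xarray as xr

    da = xr.DataArray([1, 2, 3], dims="x")
    da.encoding["dtype"] = "int16"

    clean = da.drop_encoding()    # new spelling; returns a copy with empty encoding
    assert clean.encoding == {}
    legacy = da.reset_encoding()  # still available, but now emits a deprecation warning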
@@ -949,26 +959,29 @@ def coords(self) -> DataArrayCoordinates:
@overload
def reset_coords(
- self: T_DataArray,
+ self,
names: Dims = None,
+ *,
drop: Literal[False] = False,
) -> Dataset:
...
@overload
def reset_coords(
- self: T_DataArray,
+ self,
names: Dims = None,
*,
drop: Literal[True],
- ) -> T_DataArray:
+ ) -> Self:
...
+ @_deprecate_positional_args("v2023.10.0")
def reset_coords(
- self: T_DataArray,
+ self,
names: Dims = None,
+ *,
drop: bool = False,
- ) -> T_DataArray | Dataset:
+ ) -> Self | Dataset:
"""Given names of coordinates, reset them to become variables.
Parameters
@@ -1080,15 +1093,15 @@ def __dask_postpersist__(self):
func, args = self._to_temp_dataset().__dask_postpersist__()
return self._dask_finalize, (self.name, func) + args
- @staticmethod
- def _dask_finalize(results, name, func, *args, **kwargs) -> DataArray:
+ @classmethod
+ def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self:
ds = func(results, *args, **kwargs)
variable = ds._variables.pop(_THIS_ARRAY)
coords = ds._variables
indexes = ds._indexes
- return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)
+ return cls(variable, coords, name=name, indexes=indexes, fastpath=True)
- def load(self: T_DataArray, **kwargs) -> T_DataArray:
+ def load(self, **kwargs) -> Self:
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return this array.
@@ -1112,7 +1125,7 @@ def load(self: T_DataArray, **kwargs) -> T_DataArray:
self._coords = new._coords
return self
- def compute(self: T_DataArray, **kwargs) -> T_DataArray:
+ def compute(self, **kwargs) -> Self:
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array. The original is
left unaltered.
@@ -1134,7 +1147,7 @@ def compute(self: T_DataArray, **kwargs) -> T_DataArray:
new = self.copy(deep=False)
return new.load(**kwargs)
- def persist(self: T_DataArray, **kwargs) -> T_DataArray:
+ def persist(self, **kwargs) -> Self:
"""Trigger computation in constituent dask arrays
This keeps them as dask arrays but encourages them to keep data in
@@ -1153,7 +1166,7 @@ def persist(self: T_DataArray, **kwargs) -> T_DataArray:
ds = self._to_temp_dataset().persist(**kwargs)
return self._from_temp_dataset(ds)
- def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
+ def copy(self, deep: bool = True, data: Any = None) -> Self:
"""Returns a copy of this array.
If `deep=True`, a deep copy is made of the data array.
@@ -1224,11 +1237,11 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
return self._copy(deep=deep, data=data)
def _copy(
- self: T_DataArray,
+ self,
deep: bool = True,
data: Any = None,
memo: dict[int, Any] | None = None,
- ) -> T_DataArray:
+ ) -> Self:
variable = self.variable._copy(deep=deep, data=data, memo=memo)
indexes, index_vars = self.xindexes.copy_indexes(deep=deep)
@@ -1241,12 +1254,10 @@ def _copy(
return self._replace(variable, coords, indexes=indexes)
- def __copy__(self: T_DataArray) -> T_DataArray:
+ def __copy__(self) -> Self:
return self._copy(deep=False)
- def __deepcopy__(
- self: T_DataArray, memo: dict[int, Any] | None = None
- ) -> T_DataArray:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
return self._copy(deep=True, memo=memo)
# mutable objects should not be Hashable
@@ -1286,15 +1297,11 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
all_variables = [self.variable] + [c.variable for c in self.coords.values()]
return get_chunksizes(all_variables)
+ @_deprecate_positional_args("v2023.10.0")
def chunk(
- self: T_DataArray,
- chunks: (
- int
- | Literal["auto"]
- | tuple[int, ...]
- | tuple[tuple[int, ...], ...]
- | Mapping[Any, None | int | tuple[int, ...]]
- ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
+ self,
+ chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
+ *,
name_prefix: str = "xarray-",
token: str | None = None,
lock: bool = False,
@@ -1302,7 +1309,7 @@ def chunk(
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
**chunks_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Coerce this array's data into a dask arrays with the given chunks.
If this variable is a non-dask array, it will be converted to dask
@@ -1362,7 +1369,7 @@ def chunk(
if isinstance(chunks, (float, str, int)):
# ignoring type; unclear why it won't accept a Literal into the value.
- chunks = dict.fromkeys(self.dims, chunks) # type: ignore
+ chunks = dict.fromkeys(self.dims, chunks)
elif isinstance(chunks, (tuple, list)):
chunks = dict(zip(self.dims, chunks))
else:
@@ -1380,12 +1387,12 @@ def chunk(
return self._from_temp_dataset(ds)
def isel(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, Any] | None = None,
drop: bool = False,
missing_dims: ErrorOptionsWithWarn = "raise",
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray whose data is given by selecting indexes
along the specified dimension(s).
@@ -1471,13 +1478,13 @@ def isel(
return self._replace(variable=variable, coords=coords, indexes=indexes)
def sel(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, Any] | None = None,
method: str | None = None,
tolerance=None,
drop: bool = False,
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray whose data is given by selecting index
labels along the specified dimension(s).
@@ -1590,10 +1597,10 @@ def sel(
return self._from_temp_dataset(ds)
def head(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray whose data is given by the the first `n`
values along the specified dimension(s). Default `n` = 5
@@ -1633,10 +1640,10 @@ def head(
return self._from_temp_dataset(ds)
def tail(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray whose data is given by the the last `n`
values along the specified dimension(s). Default `n` = 5
@@ -1680,10 +1687,10 @@ def tail(
return self._from_temp_dataset(ds)
def thin(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray whose data is given by each `n` value
along the specified dimension(s).
@@ -1729,11 +1736,13 @@ def thin(
ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs)
return self._from_temp_dataset(ds)
+ @_deprecate_positional_args("v2023.10.0")
def broadcast_like(
- self: T_DataArray,
- other: DataArray | Dataset,
+ self,
+ other: T_DataArrayOrSet,
+ *,
exclude: Iterable[Hashable] | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]
@@ -1803,12 +1812,10 @@ def broadcast_like(
dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
- return _broadcast_helper(
- cast("T_DataArray", args[1]), exclude, dims_map, common_coords
- )
+ return _broadcast_helper(args[1], exclude, dims_map, common_coords)
def _reindex_callback(
- self: T_DataArray,
+ self,
aligner: alignment.Aligner,
dim_pos_indexers: dict[Hashable, Any],
variables: dict[Hashable, Variable],
@@ -1816,7 +1823,7 @@ def _reindex_callback(
fill_value: Any,
exclude_dims: frozenset[Hashable],
exclude_vars: frozenset[Hashable],
- ) -> T_DataArray:
+ ) -> Self:
"""Callback called from ``Aligner`` to create a new reindexed DataArray."""
if isinstance(fill_value, dict):
@@ -1842,14 +1849,16 @@ def _reindex_callback(
return da
+ @_deprecate_positional_args("v2023.10.0")
def reindex_like(
- self: T_DataArray,
- other: DataArray | Dataset,
+ self,
+ other: T_DataArrayOrSet,
+ *,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value=dtypes.NA,
- ) -> T_DataArray:
+ ) -> Self:
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -2012,15 +2021,17 @@ def reindex_like(
fill_value=fill_value,
)
+ @_deprecate_positional_args("v2023.10.0")
def reindex(
- self: T_DataArray,
+ self,
indexers: Mapping[Any, Any] | None = None,
+ *,
method: ReindexMethodOptions = None,
tolerance: float | Iterable[float] | None = None,
copy: bool = True,
fill_value=dtypes.NA,
**indexers_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -2104,13 +2115,13 @@ def reindex(
)
def interp(
- self: T_DataArray,
+ self,
coords: Mapping[Any, Any] | None = None,
method: InterpOptions = "linear",
assume_sorted: bool = False,
kwargs: Mapping[str, Any] | None = None,
**coords_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Interpolate a DataArray onto new coordinates
Performs univariate or multivariate interpolation of a DataArray onto
@@ -2247,12 +2258,12 @@ def interp(
return self._from_temp_dataset(ds)
def interp_like(
- self: T_DataArray,
- other: DataArray | Dataset,
+ self,
+ other: T_Xarray,
method: InterpOptions = "linear",
assume_sorted: bool = False,
kwargs: Mapping[str, Any] | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Interpolate this object onto the coordinates of another object,
filling out of range values with NaN.
@@ -2369,13 +2380,11 @@ def interp_like(
)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def rename(
self,
new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None,
**names: Hashable,
- ) -> DataArray:
+ ) -> Self:
"""Returns a new DataArray with renamed coordinates, dimensions or a new name.
Parameters
@@ -2416,10 +2425,10 @@ def rename(
return self._replace(name=new_name_or_name_dict)
def swap_dims(
- self: T_DataArray,
+ self,
dims_dict: Mapping[Any, Hashable] | None = None,
**dims_kwargs,
- ) -> T_DataArray:
+ ) -> Self:
"""Returns a new DataArray with swapped dimensions.
Parameters
@@ -2474,14 +2483,12 @@ def swap_dims(
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
**dim_kwargs: Any,
- ) -> DataArray:
+ ) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape. The new object is a
view into the underlying array, not a copy.
@@ -2570,14 +2577,12 @@ def expand_dims(
ds = self._to_temp_dataset().expand_dims(dim, axis)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def set_index(
self,
indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None,
append: bool = False,
**indexes_kwargs: Hashable | Sequence[Hashable],
- ) -> DataArray:
+ ) -> Self:
"""Set DataArray (multi-)indexes using one or more existing
coordinates.
@@ -2635,13 +2640,11 @@ def set_index(
ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def reset_index(
self,
dims_or_levels: Hashable | Sequence[Hashable],
drop: bool = False,
- ) -> DataArray:
+ ) -> Self:
"""Reset the specified index(es) or multi-index level(s).
This legacy method is specific to pandas (multi-)indexes and
@@ -2675,11 +2678,11 @@ def reset_index(
return self._from_temp_dataset(ds)
def set_xindex(
- self: T_DataArray,
+ self,
coord_names: str | Sequence[Hashable],
index_cls: type[Index] | None = None,
**options,
- ) -> T_DataArray:
+ ) -> Self:
"""Set a new, Xarray-compatible index from one or more existing
coordinate(s).
@@ -2704,10 +2707,10 @@ def set_xindex(
return self._from_temp_dataset(ds)
def reorder_levels(
- self: T_DataArray,
+ self,
dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
**dim_order_kwargs: Sequence[int | Hashable],
- ) -> T_DataArray:
+ ) -> Self:
"""Rearrange index levels using input order.
Parameters
@@ -2730,12 +2733,12 @@ def reorder_levels(
return self._from_temp_dataset(ds)
def stack(
- self: T_DataArray,
+ self,
dimensions: Mapping[Any, Sequence[Hashable]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable],
- ) -> T_DataArray:
+ ) -> Self:
"""
Stack any number of existing dimensions into a single new dimension.
@@ -2802,14 +2805,14 @@ def stack(
)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
+ @_deprecate_positional_args("v2023.10.0")
def unstack(
self,
dim: Dims = None,
+ *,
fill_value: Any = dtypes.NA,
sparse: bool = False,
- ) -> DataArray:
+ ) -> Self:
"""
Unstack existing dimensions corresponding to MultiIndexes into
multiple new dimensions.
@@ -2864,7 +2867,7 @@ def unstack(
--------
DataArray.stack
"""
- ds = self._to_temp_dataset().unstack(dim, fill_value, sparse)
+ ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse)
return self._from_temp_dataset(ds)
def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset:
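The ``_deprecate_positional_args("v2023.10.0")`` decorator applied throughout this file changes call sites roughly as sketched below, using ``dropna`` (decorated later in this diff) as the example; the exact warning category comes from the helper and is not assumed here.

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, np.nan, 3.0], dims="x")

    da.dropna("x", how="any")  # keyword form: the supported spelling going forward
    da.dropna("x", "any")      # positional ``how`` still works for now, but warns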
@@ -2933,11 +2936,11 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data
return Dataset(data_dict)
def transpose(
- self: T_DataArray,
+ self,
*dims: Hashable,
transpose_coords: bool = True,
missing_dims: ErrorOptionsWithWarn = "raise",
- ) -> T_DataArray:
+ ) -> Self:
"""Return a new DataArray object with transposed dimensions.
Parameters
@@ -2983,17 +2986,15 @@ def transpose(
return self._replace(variable)
@property
- def T(self: T_DataArray) -> T_DataArray:
+ def T(self) -> Self:
return self.transpose()
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def drop_vars(
self,
names: Hashable | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
- ) -> DataArray:
+ ) -> Self:
"""Returns an array with dropped variables.
Parameters
@@ -3054,11 +3055,11 @@ def drop_vars(
return self._from_temp_dataset(ds)
def drop_indexes(
- self: T_DataArray,
+ self,
coord_names: Hashable | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
- ) -> T_DataArray:
+ ) -> Self:
"""Drop the indexes assigned to the given coordinates.
Parameters
@@ -3079,13 +3080,13 @@ def drop_indexes(
return self._from_temp_dataset(ds)
def drop(
- self: T_DataArray,
+ self,
labels: Mapping[Any, Any] | None = None,
dim: Hashable | None = None,
*,
errors: ErrorOptions = "raise",
**labels_kwargs,
- ) -> T_DataArray:
+ ) -> Self:
"""Backward compatible method based on `drop_vars` and `drop_sel`
Using either `drop_vars` or `drop_sel` is encouraged
@@ -3099,12 +3100,12 @@ def drop(
return self._from_temp_dataset(ds)
def drop_sel(
- self: T_DataArray,
+ self,
labels: Mapping[Any, Any] | None = None,
*,
errors: ErrorOptions = "raise",
**labels_kwargs,
- ) -> T_DataArray:
+ ) -> Self:
"""Drop index labels from this DataArray.
Parameters
@@ -3167,8 +3168,8 @@ def drop_sel(
return self._from_temp_dataset(ds)
def drop_isel(
- self: T_DataArray, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs
- ) -> T_DataArray:
+ self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs
+ ) -> Self:
"""Drop index positions from this DataArray.
Parameters
@@ -3217,12 +3218,14 @@ def drop_isel(
dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs)
return self._from_temp_dataset(dataset)
+ @_deprecate_positional_args("v2023.10.0")
def dropna(
- self: T_DataArray,
+ self,
dim: Hashable,
+ *,
how: Literal["any", "all"] = "any",
thresh: int | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Returns a new array with dropped labels for missing values along
the provided dimension.
@@ -3293,7 +3296,7 @@ def dropna(
ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh)
return self._from_temp_dataset(ds)
- def fillna(self: T_DataArray, value: Any) -> T_DataArray:
+ def fillna(self, value: Any) -> Self:
"""Fill missing values in this object.
This operation follows the normal broadcasting and alignment rules that
@@ -3356,7 +3359,7 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray:
return out
def interpolate_na(
- self: T_DataArray,
+ self,
dim: Hashable | None = None,
method: InterpOptions = "linear",
limit: int | None = None,
@@ -3372,7 +3375,7 @@ def interpolate_na(
) = None,
keep_attrs: bool | None = None,
**kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Fill in NaNs by interpolating according to different methods.
Parameters
@@ -3479,9 +3482,7 @@ def interpolate_na(
**kwargs,
)
- def ffill(
- self: T_DataArray, dim: Hashable, limit: int | None = None
- ) -> T_DataArray:
+ def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values forward
*Requires bottleneck.*
@@ -3565,9 +3566,7 @@ def ffill(
return ffill(self, dim, limit=limit)
- def bfill(
- self: T_DataArray, dim: Hashable, limit: int | None = None
- ) -> T_DataArray:
+ def bfill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values backward
*Requires bottleneck.*
@@ -3651,7 +3650,7 @@ def bfill(
return bfill(self, dim, limit=limit)
- def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
+ def combine_first(self, other: Self) -> Self:
"""Combine two DataArray objects, with union of coordinates.
This operation follows the normal broadcasting and alignment rules of
@@ -3670,7 +3669,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray:
return ops.fillna(self, other, join="outer")
def reduce(
- self: T_DataArray,
+ self,
func: Callable[..., Any],
dim: Dims = None,
*,
@@ -3678,7 +3677,7 @@ def reduce(
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Reduce this array by applying `func` along some dimension(s).
Parameters
@@ -3716,7 +3715,7 @@ def reduce(
var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
return self._replace_maybe_drop_dims(var)
- def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame:
+ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
"""Convert this array into a pandas object with the same shape.
The type of the returned object depends on the number of DataArray
@@ -4033,6 +4032,7 @@ def to_zarr(
mode: Literal["w", "w-", "a", "r+", None] = None,
synchronizer=None,
group: str | None = None,
+ *,
encoding: Mapping | None = None,
compute: Literal[True] = True,
consolidated: bool | None = None,
@@ -4073,6 +4073,7 @@ def to_zarr(
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
+ *,
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
@@ -4270,7 +4271,7 @@ def to_dict(
return d
@classmethod
- def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray:
+ def from_dict(cls, d: Mapping[str, Any]) -> Self:
"""Convert a dictionary into an xarray.DataArray
Parameters
@@ -4387,7 +4388,7 @@ def to_cdms2(self) -> cdms2_Variable:
return to_cdms2(self)
@classmethod
- def from_cdms2(cls, variable: cdms2_Variable) -> DataArray:
+ def from_cdms2(cls, variable: cdms2_Variable) -> Self:
"""Convert a cdms2.Variable into an xarray.DataArray
.. deprecated:: 2023.06.0
@@ -4414,13 +4415,13 @@ def to_iris(self) -> iris_Cube:
return to_iris(self)
@classmethod
- def from_iris(cls, cube: iris_Cube) -> DataArray:
+ def from_iris(cls, cube: iris_Cube) -> Self:
"""Convert a iris.cube.Cube into an xarray.DataArray"""
from xarray.convert import from_iris
return from_iris(cube)
- def _all_compat(self: T_DataArray, other: T_DataArray, compat_str: str) -> bool:
+ def _all_compat(self, other: Self, compat_str: str) -> bool:
"""Helper function for equals, broadcast_equals, and identical"""
def compat(x, y):
@@ -4430,7 +4431,7 @@ def compat(x, y):
self, other
)
- def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool:
+ def broadcast_equals(self, other: Self) -> bool:
"""Two DataArrays are broadcast equal if they are equal after
broadcasting them against each other such that they have the same
dimensions.
@@ -4479,7 +4480,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool:
except (TypeError, AttributeError):
return False
- def equals(self: T_DataArray, other: T_DataArray) -> bool:
+ def equals(self, other: Self) -> bool:
"""True if two DataArrays have the same dimensions, coordinates and
values; otherwise False.
@@ -4541,7 +4542,7 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool:
except (TypeError, AttributeError):
return False
- def identical(self: T_DataArray, other: T_DataArray) -> bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks the array name and attributes, and
attributes on all coordinates.
@@ -4608,19 +4609,19 @@ def _result_name(self, other: Any = None) -> Hashable | None:
else:
return None
- def __array_wrap__(self: T_DataArray, obj, context=None) -> T_DataArray:
+ def __array_wrap__(self, obj, context=None) -> Self:
new_var = self.variable.__array_wrap__(obj, context)
return self._replace(new_var)
- def __matmul__(self: T_DataArray, obj: T_DataArray) -> T_DataArray:
+ def __matmul__(self, obj: T_Xarray) -> T_Xarray:
return self.dot(obj)
- def __rmatmul__(self: T_DataArray, other: T_DataArray) -> T_DataArray:
+ def __rmatmul__(self, other: T_Xarray) -> T_Xarray:
# currently somewhat duplicative, as only other DataArrays are
# compatible with matmul
return computation.dot(other, self)
- def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray:
+ def _unary_op(self, f: Callable, *args, **kwargs) -> Self:
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
@@ -4636,32 +4637,29 @@ def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray:
return da
def _binary_op(
- self: T_DataArray,
- other: Any,
- f: Callable,
- reflexive: bool = False,
- ) -> T_DataArray:
+ self, other: DaCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
from xarray.core.groupby import GroupBy
if isinstance(other, (Dataset, GroupBy)):
return NotImplemented
if isinstance(other, DataArray):
align_type = OPTIONS["arithmetic_join"]
- self, other = align(self, other, join=align_type, copy=False) # type: ignore
- other_variable = getattr(other, "variable", other)
+ self, other = align(self, other, join=align_type, copy=False)
+ other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other)
other_coords = getattr(other, "coords", None)
variable = (
- f(self.variable, other_variable)
+ f(self.variable, other_variable_or_arraylike)
if not reflexive
- else f(other_variable, self.variable)
+ else f(other_variable_or_arraylike, self.variable)
)
coords, indexes = self.coords._merge_raw(other_coords, reflexive)
name = self._result_name(other)
return self._replace(variable, coords, name, indexes=indexes)
- def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray:
+ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
from xarray.core.groupby import GroupBy
if isinstance(other, GroupBy):
@@ -4720,12 +4718,14 @@ def _title_for_slice(self, truncate: int = 50) -> str:
return title
+ @_deprecate_positional_args("v2023.10.0")
def diff(
- self: T_DataArray,
+ self,
dim: Hashable,
n: int = 1,
+ *,
label: Literal["upper", "lower"] = "upper",
- ) -> T_DataArray:
+ ) -> Self:
"""Calculate the n-th order discrete difference along given axis.
Parameters
@@ -4771,11 +4771,11 @@ def diff(
return self._from_temp_dataset(ds)
def shift(
- self: T_DataArray,
+ self,
shifts: Mapping[Any, int] | None = None,
fill_value: Any = dtypes.NA,
**shifts_kwargs: int,
- ) -> T_DataArray:
+ ) -> Self:
"""Shift this DataArray by an offset along one or more dimensions.
Only the data is moved; coordinates stay in place. This is consistent
@@ -4821,11 +4821,11 @@ def shift(
return self._replace(variable=variable)
def roll(
- self: T_DataArray,
+ self,
shifts: Mapping[Hashable, int] | None = None,
roll_coords: bool = False,
**shifts_kwargs: int,
- ) -> T_DataArray:
+ ) -> Self:
"""Roll this array by an offset along one or more dimensions.
Unlike shift, roll treats the given dimensions as periodic, so will not
@@ -4870,7 +4870,7 @@ def roll(
return self._from_temp_dataset(ds)
@property
- def real(self: T_DataArray) -> T_DataArray:
+ def real(self) -> Self:
"""
The real part of the array.
@@ -4881,7 +4881,7 @@ def real(self: T_DataArray) -> T_DataArray:
return self._replace(self.variable.real)
@property
- def imag(self: T_DataArray) -> T_DataArray:
+ def imag(self) -> Self:
"""
The imaginary part of the array.
@@ -4892,10 +4892,10 @@ def imag(self: T_DataArray) -> T_DataArray:
return self._replace(self.variable.imag)
def dot(
- self: T_DataArray,
- other: T_DataArray,
+ self,
+ other: T_Xarray,
dims: Dims = None,
- ) -> T_DataArray:
+ ) -> T_Xarray:
"""Perform dot product of two DataArrays along their shared dims.
Equivalent to taking tensordot over all shared dims.
@@ -4945,13 +4945,14 @@ def dot(
return computation.dot(self, other, dims=dims)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def sortby(
self,
- variables: Hashable | DataArray | Sequence[Hashable | DataArray],
+ variables: Hashable
+ | DataArray
+ | Sequence[Hashable | DataArray]
+ | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]],
ascending: bool = True,
- ) -> DataArray:
+ ) -> Self:
"""Sort object by labels or values (along an axis).
Sorts the dataarray, either along specified dimensions,
@@ -4970,9 +4971,10 @@ def sortby(
Parameters
----------
- variables : Hashable, DataArray, or sequence of Hashable or DataArray
- 1D DataArray objects or name(s) of 1D variable(s) in
- coords whose values are used to sort this array.
+ variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
+ 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
+ used to sort this array. If a callable, the callable is passed this object,
+ and the result is used as the values to sort by.
ascending : bool, default: True
Whether to sort by ascending or descending order.
@@ -4992,34 +4994,47 @@ def sortby(
Examples
--------
>>> da = xr.DataArray(
- ... np.random.rand(5),
+ ... np.arange(5, 0, -1),
... coords=[pd.date_range("1/1/2000", periods=5)],
... dims="time",
... )
>>> da
- array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ])
+ array([5, 4, 3, 2, 1])
Coordinates:
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05
>>> da.sortby(da)
- array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937])
+ array([1, 2, 3, 4, 5])
Coordinates:
- * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02
+ * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
+
+ >>> da.sortby(lambda x: x)
+ <xarray.DataArray (time: 5)>
+ array([1, 2, 3, 4, 5])
+ Coordinates:
+ * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
"""
+ # We need to convert the callable here rather than pass it through to the
+ # dataset method, since otherwise the dataset method would try to call the
+ # callable with the dataset as the object
+ if callable(variables):
+ variables = variables(self)
ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
return self._from_temp_dataset(ds)
+ @_deprecate_positional_args("v2023.10.0")
def quantile(
- self: T_DataArray,
+ self,
q: ArrayLike,
dim: Dims = None,
+ *,
method: QuantileMethods = "linear",
keep_attrs: bool | None = None,
skipna: bool | None = None,
interpolation: QuantileMethods | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements.
@@ -5129,12 +5144,14 @@ def quantile(
)
return self._from_temp_dataset(ds)
+ @_deprecate_positional_args("v2023.10.0")
def rank(
- self: T_DataArray,
+ self,
dim: Hashable,
+ *,
pct: bool = False,
keep_attrs: bool | None = None,
- ) -> T_DataArray:
+ ) -> Self:
"""Ranks the data.
Equal values are assigned a rank that is the average of the ranks that
@@ -5174,11 +5191,11 @@ def rank(
return self._from_temp_dataset(ds)
def differentiate(
- self: T_DataArray,
+ self,
coord: Hashable,
edge_order: Literal[1, 2] = 1,
datetime_unit: DatetimeUnitOptions = None,
- ) -> T_DataArray:
+ ) -> Self:
""" Differentiate the array with the second order accurate central
differences.
@@ -5236,13 +5253,11 @@ def differentiate(
ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def integrate(
self,
coord: Hashable | Sequence[Hashable] = None,
datetime_unit: DatetimeUnitOptions = None,
- ) -> DataArray:
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -5292,13 +5307,11 @@ def integrate(
ds = self._to_temp_dataset().integrate(coord, datetime_unit)
return self._from_temp_dataset(ds)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
def cumulative_integrate(
self,
coord: Hashable | Sequence[Hashable] = None,
datetime_unit: DatetimeUnitOptions = None,
- ) -> DataArray:
+ ) -> Self:
"""Integrate cumulatively along the given coordinate using the trapezoidal rule.
.. note::
@@ -5356,7 +5369,7 @@ def cumulative_integrate(
ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit)
return self._from_temp_dataset(ds)
- def unify_chunks(self) -> DataArray:
+ def unify_chunks(self) -> Self:
"""Unify chunk size along all chunked dimensions of this DataArray.
Returns
@@ -5541,7 +5554,7 @@ def polyfit(
)
def pad(
- self: T_DataArray,
+ self,
pad_width: Mapping[Any, int | tuple[int, int]] | None = None,
mode: PadModeOptions = "constant",
stat_length: int
@@ -5556,7 +5569,7 @@ def pad(
reflect_type: PadReflectOptions = None,
keep_attrs: bool | None = None,
**pad_width_kwargs: Any,
- ) -> T_DataArray:
+ ) -> Self:
"""Pad this array along one or more dimensions.
.. warning::
@@ -5708,13 +5721,15 @@ def pad(
)
return self._from_temp_dataset(ds)
+ @_deprecate_positional_args("v2023.10.0")
def idxmin(
self,
dim: Hashable | None = None,
+ *,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool | None = None,
- ) -> DataArray:
+ ) -> Self:
"""Return the coordinate label of the minimum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
@@ -5804,13 +5819,15 @@ def idxmin(
keep_attrs=keep_attrs,
)
+ @_deprecate_positional_args("v2023.10.0")
def idxmax(
self,
dim: Hashable = None,
+ *,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
keep_attrs: bool | None = None,
- ) -> DataArray:
+ ) -> Self:
"""Return the coordinate label of the maximum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
@@ -5900,15 +5917,15 @@ def idxmax(
keep_attrs=keep_attrs,
)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
+ @_deprecate_positional_args("v2023.10.0")
def argmin(
self,
dim: Dims = None,
+ *,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
- ) -> DataArray | dict[Hashable, DataArray]:
+ ) -> Self | dict[Hashable, Self]:
"""Index or indices of the minimum of the DataArray over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
@@ -6002,15 +6019,15 @@ def argmin(
else:
return self._replace_maybe_drop_dims(result)
- # change type of self and return to T_DataArray once
- # https://github.com/python/mypy/issues/12846 is resolved
+ @_deprecate_positional_args("v2023.10.0")
def argmax(
self,
dim: Dims = None,
+ *,
axis: int | None = None,
keep_attrs: bool | None = None,
skipna: bool | None = None,
- ) -> DataArray | dict[Hashable, DataArray]:
+ ) -> Self | dict[Hashable, Self]:
"""Index or indices of the maximum of the DataArray over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
@@ -6351,11 +6368,13 @@ def curvefit(
kwargs=kwargs,
)
+ @_deprecate_positional_args("v2023.10.0")
def drop_duplicates(
- self: T_DataArray,
+ self,
dim: Hashable | Iterable[Hashable],
+ *,
keep: Literal["first", "last", False] = "first",
- ) -> T_DataArray:
+ ) -> Self:
"""Returns a new DataArray with duplicate dimension values removed.
Parameters
@@ -6437,7 +6456,7 @@ def convert_calendar(
align_on: str | None = None,
missing: Any | None = None,
use_cftime: bool | None = None,
- ) -> DataArray:
+ ) -> Self:
"""Convert the DataArray to another calendar.
Only converts the individual timestamps, does not modify any data except
@@ -6557,7 +6576,7 @@ def interp_calendar(
self,
target: pd.DatetimeIndex | CFTimeIndex | DataArray,
dim: str = "time",
- ) -> DataArray:
+ ) -> Self:
"""Interpolates the DataArray to another calendar based on decimal year measure.
Each timestamp in `source` and `target` is first converted to its decimal
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 48e25f7e1c7..ebd6fb6f51f 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -93,7 +93,14 @@
is_duck_array,
is_duck_dask_array,
)
-from xarray.core.types import QuantileMethods, T_Dataset
+from xarray.core.types import (
+ QuantileMethods,
+ Self,
+ T_ChunkDim,
+ T_Chunks,
+ T_DataArrayOrSet,
+ T_Dataset,
+)
from xarray.core.utils import (
Default,
Frozen,
@@ -116,6 +123,7 @@
calculate_dimensions,
)
from xarray.plot.accessor import DatasetPlotAccessor
+from xarray.util.deprecation_helpers import _deprecate_positional_args
if TYPE_CHECKING:
from numpy.typing import ArrayLike
@@ -124,7 +132,7 @@
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.dataarray import DataArray
from xarray.core.groupby import DatasetGroupBy
- from xarray.core.merge import CoercibleMapping, CoercibleValue
+ from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
from xarray.core.parallelcompat import ChunkManagerEntrypoint
from xarray.core.resample import DatasetResample
from xarray.core.rolling import DatasetCoarsen, DatasetRolling
@@ -133,6 +141,7 @@
CoarsenBoundaryOptions,
CombineAttrsOptions,
CompatOptions,
+ DataVars,
DatetimeLike,
DatetimeUnitOptions,
Dims,
@@ -404,7 +413,7 @@ def _initialize_feasible(lb, ub):
return param_defaults, bounds_defaults
-def merge_data_and_coords(data_vars, coords):
+def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
"""Used in Dataset.__init__."""
if isinstance(coords, Coordinates):
coords = coords.copy()
@@ -666,7 +675,7 @@ def __init__(
self,
# could make a VariableArgs to use more generally, and refine these
# categories
- data_vars: Mapping[Any, Any] | None = None,
+ data_vars: DataVars | None = None,
coords: Mapping[Any, Any] | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> None:
@@ -698,11 +707,11 @@ def __init__(
# TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping
# related to https://github.com/python/mypy/issues/9319?
- def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: # type: ignore[override]
+ def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override]
return super().__eq__(other)
@classmethod
- def load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset:
+ def load_store(cls, store, decoder=None) -> Self:
"""Create a new dataset from the contents of a backends.*DataStore
object
"""
@@ -746,10 +755,16 @@ def encoding(self) -> dict[Any, Any]:
def encoding(self, value: Mapping[Any, Any]) -> None:
self._encoding = dict(value)
- def reset_encoding(self: T_Dataset) -> T_Dataset:
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
+
+ def drop_encoding(self) -> Self:
"""Return a new Dataset without encoding on the dataset or any of its
variables/coords."""
- variables = {k: v.reset_encoding() for k, v in self.variables.items()}
+ variables = {k: v.drop_encoding() for k, v in self.variables.items()}
return self._replace(variables=variables, encoding={})
@property
@@ -802,7 +817,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]:
}
)
- def load(self: T_Dataset, **kwargs) -> T_Dataset:
+ def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return this dataset.
Unlike compute, the original dataset is modified and returned.
@@ -902,7 +917,7 @@ def __dask_postcompute__(self):
def __dask_postpersist__(self):
return self._dask_postpersist, ()
- def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset:
+ def _dask_postcompute(self, results: Iterable[Variable]) -> Self:
import dask
variables = {}
@@ -925,8 +940,8 @@ def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset
)
def _dask_postpersist(
- self: T_Dataset, dsk: Mapping, *, rename: Mapping[str, str] | None = None
- ) -> T_Dataset:
+ self, dsk: Mapping, *, rename: Mapping[str, str] | None = None
+ ) -> Self:
from dask import is_dask_collection
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import cull
@@ -975,7 +990,7 @@ def _dask_postpersist(
self._close,
)
- def compute(self: T_Dataset, **kwargs) -> T_Dataset:
+ def compute(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return a new dataset.
Unlike load, the original dataset is left unaltered.
@@ -997,7 +1012,7 @@ def compute(self: T_Dataset, **kwargs) -> T_Dataset:
new = self.copy(deep=False)
return new.load(**kwargs)
- def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset:
+ def _persist_inplace(self, **kwargs) -> Self:
"""Persist all Dask arrays in memory"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
@@ -1014,7 +1029,7 @@ def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset:
return self
- def persist(self: T_Dataset, **kwargs) -> T_Dataset:
+ def persist(self, **kwargs) -> Self:
"""Trigger computation, keeping data as dask arrays
This operation can be used to trigger computation on underlying dask
@@ -1037,7 +1052,7 @@ def persist(self: T_Dataset, **kwargs) -> T_Dataset:
@classmethod
def _construct_direct(
- cls: type[T_Dataset],
+ cls,
variables: dict[Any, Variable],
coord_names: set[Hashable],
dims: dict[Any, int] | None = None,
@@ -1045,7 +1060,7 @@ def _construct_direct(
indexes: dict[Any, Index] | None = None,
encoding: dict | None = None,
close: Callable[[], None] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
@@ -1064,7 +1079,7 @@ def _construct_direct(
return obj
def _replace(
- self: T_Dataset,
+ self,
variables: dict[Hashable, Variable] | None = None,
coord_names: set[Hashable] | None = None,
dims: dict[Any, int] | None = None,
@@ -1072,7 +1087,7 @@ def _replace(
indexes: dict[Hashable, Index] | None = None,
encoding: dict | None | Default = _default,
inplace: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""Fastpath constructor for internal use.
Returns an object with optionally with replaced attributes.
@@ -1114,13 +1129,13 @@ def _replace(
return obj
def _replace_with_new_dims(
- self: T_Dataset,
+ self,
variables: dict[Hashable, Variable],
coord_names: set | None = None,
attrs: dict[Hashable, Any] | None | Default = _default,
indexes: dict[Hashable, Index] | None = None,
inplace: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""Replace variables with recalculated dimensions."""
dims = calculate_dimensions(variables)
return self._replace(
@@ -1128,13 +1143,13 @@ def _replace_with_new_dims(
)
def _replace_vars_and_dims(
- self: T_Dataset,
+ self,
variables: dict[Hashable, Variable],
coord_names: set | None = None,
dims: dict[Hashable, int] | None = None,
attrs: dict[Hashable, Any] | None | Default = _default,
inplace: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""Deprecated version of _replace_with_new_dims().
Unlike _replace_with_new_dims(), this method always recalculates
@@ -1147,13 +1162,13 @@ def _replace_vars_and_dims(
)
def _overwrite_indexes(
- self: T_Dataset,
+ self,
indexes: Mapping[Hashable, Index],
variables: Mapping[Hashable, Variable] | None = None,
drop_variables: list[Hashable] | None = None,
drop_indexes: list[Hashable] | None = None,
rename_dims: Mapping[Hashable, Hashable] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Maybe replace indexes.
This function may do a lot more depending on index query
@@ -1220,9 +1235,7 @@ def _overwrite_indexes(
else:
return replaced
- def copy(
- self: T_Dataset, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None
- ) -> T_Dataset:
+ def copy(self, deep: bool = False, data: DataVars | None = None) -> Self:
"""Returns a copy of this dataset.
If `deep=True`, a deep copy is made of each of the component variables.
@@ -1322,11 +1335,11 @@ def copy(
return self._copy(deep=deep, data=data)
def _copy(
- self: T_Dataset,
+ self,
deep: bool = False,
- data: Mapping[Any, ArrayLike] | None = None,
+ data: DataVars | None = None,
memo: dict[int, Any] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
if data is None:
data = {}
elif not utils.is_dict_like(data):
@@ -1364,13 +1377,13 @@ def _copy(
return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding)
- def __copy__(self: T_Dataset) -> T_Dataset:
+ def __copy__(self) -> Self:
return self._copy(deep=False)
- def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
return self._copy(deep=True, memo=memo)
- def as_numpy(self: T_Dataset) -> T_Dataset:
+ def as_numpy(self) -> Self:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
@@ -1382,7 +1395,7 @@ def as_numpy(self: T_Dataset) -> T_Dataset:
numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
return self._replace(variables=numpy_variables)
- def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset:
+ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
"""Create a new Dataset with the listed variables from this dataset and
all the relevant coordinates. Skips all validation.
"""
@@ -1476,13 +1489,20 @@ def __bool__(self) -> bool:
def __iter__(self) -> Iterator[Hashable]:
return iter(self.data_vars)
- def __array__(self, dtype=None):
- raise TypeError(
- "cannot directly convert an xarray.Dataset into a "
- "numpy array. Instead, create an xarray.DataArray "
- "first, either with indexing on the Dataset or by "
- "invoking the `to_array()` method."
- )
+ if TYPE_CHECKING:
+ # needed because __getattr__ is returning Any and otherwise
+ # this class counts as part of the SupportsArray Protocol
+ __array__ = None # type: ignore[var-annotated,unused-ignore]
+
+ else:
+
+ def __array__(self, dtype=None):
+ raise TypeError(
+ "cannot directly convert an xarray.Dataset into a "
+ "numpy array. Instead, create an xarray.DataArray "
+ "first, either with indexing on the Dataset or by "
+ "invoking the `to_array()` method."
+ )
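The `TYPE_CHECKING` guard hides `__array__` from type checkers so Dataset no longer matches the SupportsArray protocol, while the runtime error is unchanged; a small usage sketch with made-up data:

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3])})
try:
    np.asarray(ds)  # still raises: a Dataset does not coerce to a single array
except TypeError as err:
    print(err)
np.asarray(ds.to_array())  # convert via a DataArray instead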
@property
def nbytes(self) -> int:
@@ -1495,7 +1515,7 @@ def nbytes(self) -> int:
return sum(v.nbytes for v in self.variables.values())
@property
- def loc(self: T_Dataset) -> _LocIndexer[T_Dataset]:
+ def loc(self) -> _LocIndexer[Self]:
"""Attribute for location based indexing. Only supports __getitem__,
and only when the key is a dict of the form {dim: labels}.
"""
@@ -1507,12 +1527,12 @@ def __getitem__(self, key: Hashable) -> DataArray:
# Mapping is Iterable
@overload
- def __getitem__(self: T_Dataset, key: Iterable[Hashable]) -> T_Dataset:
+ def __getitem__(self, key: Iterable[Hashable]) -> Self:
...
def __getitem__(
- self: T_Dataset, key: Mapping[Any, Any] | Hashable | Iterable[Hashable]
- ) -> T_Dataset | DataArray:
+ self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable]
+ ) -> Self | DataArray:
"""Access variables or coordinates of this dataset as a
:py:class:`~xarray.DataArray` or a subset of variables or an indexed dataset.
@@ -1677,7 +1697,7 @@ def __delitem__(self, key: Hashable) -> None:
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore[assignment]
- def _all_compat(self, other: Dataset, compat_str: str) -> bool:
+ def _all_compat(self, other: Self, compat_str: str) -> bool:
"""Helper function for equals and identical"""
# some stores (e.g., scipy) do not seem to preserve order, so don't
@@ -1689,7 +1709,7 @@ def compat(x: Variable, y: Variable) -> bool:
self._variables, other._variables, compat=compat
)
- def broadcast_equals(self, other: Dataset) -> bool:
+ def broadcast_equals(self, other: Self) -> bool:
"""Two Datasets are broadcast equal if they are equal after
broadcasting all variables against each other.
@@ -1756,7 +1776,7 @@ def broadcast_equals(self, other: Dataset) -> bool:
except (TypeError, AttributeError):
return False
- def equals(self, other: Dataset) -> bool:
+ def equals(self, other: Self) -> bool:
"""Two Datasets are equal if they have matching variables and
coordinates, all of which are equal.
@@ -1837,7 +1857,7 @@ def equals(self, other: Dataset) -> bool:
except (TypeError, AttributeError):
return False
- def identical(self, other: Dataset) -> bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks all dataset attributes and the
attributes on all variables and coordinates.
@@ -1950,7 +1970,7 @@ def data_vars(self) -> DataVariables:
"""Dictionary of DataArray objects corresponding to data variables"""
return DataVariables(self)
- def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset:
+ def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self:
"""Given names of one or more variables, set them as coordinates
Parameters
@@ -2008,10 +2028,10 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas
return obj
def reset_coords(
- self: T_Dataset,
+ self,
names: Dims = None,
drop: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""Given names of coordinates, reset them to become variables
Parameters
@@ -2280,6 +2300,7 @@ def to_zarr(
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
+ *,
compute: Literal[True] = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
@@ -2323,6 +2344,7 @@ def to_zarr(
synchronizer=None,
group: str | None = None,
encoding: Mapping | None = None,
+ *,
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
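The added `*` makes `compute` and the arguments after it keyword-only in both `to_zarr` overloads; a hedged sketch of the call style this expects (path and data are placeholders, and the zarr backend must be installed):

import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3])})
ds.to_zarr("out.zarr", mode="w", compute=True)  # keyword form still works
# ds.to_zarr("out.zarr", None, "w", None, None, None, True)  # positional `compute` no longer matches the overloads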
@@ -2562,18 +2584,16 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
return get_chunksizes(self.variables.values())
def chunk(
- self: T_Dataset,
- chunks: (
- int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]]
- ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
+ self,
+ chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
name_prefix: str = "xarray-",
token: str | None = None,
lock: bool = False,
inline_array: bool = False,
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
- **chunks_kwargs: None | int | str | tuple[int, ...],
- ) -> T_Dataset:
+ **chunks_kwargs: T_ChunkDim,
+ ) -> Self:
"""Coerce all arrays in this dataset into dask arrays with the given
chunks.
@@ -2623,20 +2643,20 @@ def chunk(
xarray.unify_chunks
dask.array.from_array
"""
- if chunks is None and chunks_kwargs is None:
+ if chunks is None and not chunks_kwargs:
warnings.warn(
"None value for 'chunks' is deprecated. "
"It will raise an error in the future. Use instead '{}'",
category=FutureWarning,
)
chunks = {}
-
- if isinstance(chunks, (Number, str, int)):
- chunks = dict.fromkeys(self.dims, chunks)
+ chunks_mapping: Mapping[Any, Any]
+ if not isinstance(chunks, Mapping) and chunks is not None:
+ chunks_mapping = dict.fromkeys(self.dims, chunks)
else:
- chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
+ chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
- bad_dims = chunks.keys() - self.dims.keys()
+ bad_dims = chunks_mapping.keys() - self.dims.keys()
if bad_dims:
raise ValueError(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}"
@@ -2650,7 +2670,7 @@ def chunk(
k: _maybe_chunk(
k,
v,
- chunks,
+ chunks_mapping,
token,
lock,
name_prefix,
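`T_Chunks`/`T_ChunkDim` cover the same call styles `chunk` accepted before; for example (dask required, sizes arbitrary):

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": (("x", "y"), np.zeros((100, 50)))})
ds.chunk({"x": 25})        # mapping form
ds.chunk(x=25, y="auto")   # keyword form
ds.chunk(10)               # one size applied to every dimension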
@@ -2767,12 +2787,12 @@ def _get_indexers_coords_and_indexes(self, indexers):
return attached_coords, attached_indexes
def isel(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, Any] | None = None,
drop: bool = False,
missing_dims: ErrorOptionsWithWarn = "raise",
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with each array indexed along the specified
dimension(s).
@@ -2915,12 +2935,12 @@ def isel(
)
def _isel_fancy(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, Any],
*,
drop: bool,
missing_dims: ErrorOptionsWithWarn = "raise",
- ) -> T_Dataset:
+ ) -> Self:
valid_indexers = dict(self._validate_indexers(indexers, missing_dims))
variables: dict[Hashable, Variable] = {}
@@ -2956,13 +2976,13 @@ def _isel_fancy(
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
def sel(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, Any] | None = None,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
drop: bool = False,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with each array indexed by tick labels
along the specified dimension(s).
@@ -3042,10 +3062,10 @@ def sel(
return result._overwrite_indexes(*query_results.as_tuple()[1:])
def head(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with the first `n` values of each array
for the specified dimension(s).
@@ -3132,10 +3152,10 @@ def head(
return self.isel(indexers_slices)
def tail(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with the last `n` values of each array
for the specified dimension(s).
@@ -3223,10 +3243,10 @@ def tail(
return self.isel(indexers_slices)
def thin(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, int] | int | None = None,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with each array indexed along every `n`-th
value for the specified dimension(s)
@@ -3308,10 +3328,10 @@ def thin(
return self.isel(indexers_slices)
def broadcast_like(
- self: T_Dataset,
- other: Dataset | DataArray,
+ self,
+ other: T_DataArrayOrSet,
exclude: Iterable[Hashable] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]
@@ -3331,12 +3351,10 @@ def broadcast_like(
dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
- return _broadcast_helper(
- cast("T_Dataset", args[1]), exclude, dims_map, common_coords
- )
+ return _broadcast_helper(args[1], exclude, dims_map, common_coords)
def _reindex_callback(
- self: T_Dataset,
+ self,
aligner: alignment.Aligner,
dim_pos_indexers: dict[Hashable, Any],
variables: dict[Hashable, Variable],
@@ -3344,7 +3362,7 @@ def _reindex_callback(
fill_value: Any,
exclude_dims: frozenset[Hashable],
exclude_vars: frozenset[Hashable],
- ) -> T_Dataset:
+ ) -> Self:
"""Callback called from ``Aligner`` to create a new reindexed Dataset."""
new_variables = variables.copy()
@@ -3397,13 +3415,13 @@ def _reindex_callback(
return reindexed
def reindex_like(
- self: T_Dataset,
- other: Dataset | DataArray,
+ self,
+ other: T_Xarray,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
- ) -> T_Dataset:
+ ) -> Self:
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -3463,14 +3481,14 @@ def reindex_like(
)
def reindex(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, Any] | None = None,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Conform this object onto a new set of indexes, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -3679,7 +3697,7 @@ def reindex(
)
def _reindex(
- self: T_Dataset,
+ self,
indexers: Mapping[Any, Any] | None = None,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
@@ -3687,7 +3705,7 @@ def _reindex(
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
**indexers_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""
Same as reindex but supports sparse option.
"""
@@ -3703,14 +3721,14 @@ def _reindex(
)
def interp(
- self: T_Dataset,
+ self,
coords: Mapping[Any, Any] | None = None,
method: InterpOptions = "linear",
assume_sorted: bool = False,
kwargs: Mapping[str, Any] | None = None,
method_non_numeric: str = "nearest",
**coords_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Interpolate a Dataset onto new coordinates
Performs univariate or multivariate interpolation of a Dataset onto
@@ -3983,12 +4001,12 @@ def _validate_interp_indexer(x, new_x):
def interp_like(
self,
- other: Dataset | DataArray,
+ other: T_Xarray,
method: InterpOptions = "linear",
assume_sorted: bool = False,
kwargs: Mapping[str, Any] | None = None,
method_non_numeric: str = "nearest",
- ) -> Dataset:
+ ) -> Self:
"""Interpolate this object onto the coordinates of another object,
filling the out of range values with NaN.
@@ -4138,10 +4156,10 @@ def _rename_all(
return variables, coord_names, dims, indexes
def _rename(
- self: T_Dataset,
+ self,
name_dict: Mapping[Any, Hashable] | None = None,
**names: Hashable,
- ) -> T_Dataset:
+ ) -> Self:
"""Also used internally by DataArray so that the warning (if any)
is raised at the right stack level.
"""
@@ -4156,6 +4174,9 @@ def _rename(
create_dim_coord = False
new_k = name_dict[k]
+ if k == new_k:
+ continue # Same name, nothing to do
+
if k in self.dims and new_k in self._coord_names:
coord_dims = self._variables[name_dict[k]].dims
if coord_dims == (k,):
@@ -4180,10 +4201,10 @@ def _rename(
return self._replace(variables, coord_names, dims=dims, indexes=indexes)
def rename(
- self: T_Dataset,
+ self,
name_dict: Mapping[Any, Hashable] | None = None,
**names: Hashable,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new object with renamed variables, coordinates and dimensions.
Parameters
@@ -4210,10 +4231,10 @@ def rename(
return self._rename(name_dict=name_dict, **names)
def rename_dims(
- self: T_Dataset,
+ self,
dims_dict: Mapping[Any, Hashable] | None = None,
**dims: Hashable,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new object with renamed dimensions only.
Parameters
@@ -4257,10 +4278,10 @@ def rename_dims(
return self._replace(variables, coord_names, dims=sizes, indexes=indexes)
def rename_vars(
- self: T_Dataset,
+ self,
name_dict: Mapping[Any, Hashable] | None = None,
**names: Hashable,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new object with renamed variables including coordinates
Parameters
@@ -4297,8 +4318,8 @@ def rename_vars(
return self._replace(variables, coord_names, dims=dims, indexes=indexes)
def swap_dims(
- self: T_Dataset, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs
- ) -> T_Dataset:
+ self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs
+ ) -> Self:
"""Returns a new object with swapped dimensions.
Parameters
@@ -4401,14 +4422,12 @@ def swap_dims(
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
- # change type of self and return to T_Dataset once
- # https://github.com/python/mypy/issues/12846 is resolved
def expand_dims(
self,
dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
axis: None | int | Sequence[int] = None,
**dim_kwargs: Any,
- ) -> Dataset:
+ ) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape. The new object is a
view into the underlying array, not a copy.
@@ -4598,14 +4617,12 @@ def expand_dims(
variables, coord_names=coord_names, indexes=indexes
)
- # change type of self and return to T_Dataset once
- # https://github.com/python/mypy/issues/12846 is resolved
def set_index(
self,
indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None,
append: bool = False,
**indexes_kwargs: Hashable | Sequence[Hashable],
- ) -> Dataset:
+ ) -> Self:
"""Set Dataset (multi-)indexes using one or more existing coordinates
or variables.
@@ -4765,11 +4782,13 @@ def set_index(
variables, coord_names=coord_names, indexes=indexes_
)
+ @_deprecate_positional_args("v2023.10.0")
def reset_index(
- self: T_Dataset,
+ self,
dims_or_levels: Hashable | Sequence[Hashable],
+ *,
drop: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""Reset the specified index(es) or multi-index level(s).
This legacy method is specific to pandas (multi-)indexes and
@@ -4877,11 +4896,11 @@ def drop_or_convert(var_names):
)
def set_xindex(
- self: T_Dataset,
+ self,
coord_names: str | Sequence[Hashable],
index_cls: type[Index] | None = None,
**options,
- ) -> T_Dataset:
+ ) -> Self:
"""Set a new, Xarray-compatible index from one or more existing
coordinate(s).
@@ -4989,10 +5008,10 @@ def set_xindex(
)
def reorder_levels(
- self: T_Dataset,
+ self,
dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
**dim_order_kwargs: Sequence[int | Hashable],
- ) -> T_Dataset:
+ ) -> Self:
"""Rearrange index levels using input order.
Parameters
@@ -5093,12 +5112,12 @@ def _get_stack_index(
return stack_index, stack_coords
def _stack_once(
- self: T_Dataset,
+ self,
dims: Sequence[Hashable | ellipsis],
new_dim: Hashable,
index_cls: type[Index],
create_index: bool | None = True,
- ) -> T_Dataset:
+ ) -> Self:
if dims == ...:
raise ValueError("Please use [...] for dims, rather than just ...")
if ... in dims:
@@ -5152,12 +5171,12 @@ def _stack_once(
)
def stack(
- self: T_Dataset,
+ self,
dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
create_index: bool | None = True,
index_cls: type[Index] = PandasMultiIndex,
**dimensions_kwargs: Sequence[Hashable | ellipsis],
- ) -> T_Dataset:
+ ) -> Self:
"""
Stack any number of existing dimensions into a single new dimension.
@@ -5312,12 +5331,12 @@ def stack_dataarray(da):
return data_array
def _unstack_once(
- self: T_Dataset,
+ self,
dim: Hashable,
index_and_vars: tuple[Index, dict[Hashable, Variable]],
fill_value,
sparse: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
index, index_vars = index_and_vars
variables: dict[Hashable, Variable] = {}
indexes = {k: v for k, v in self._indexes.items() if k != dim}
@@ -5352,12 +5371,12 @@ def _unstack_once(
)
def _unstack_full_reindex(
- self: T_Dataset,
+ self,
dim: Hashable,
index_and_vars: tuple[Index, dict[Hashable, Variable]],
fill_value,
sparse: bool,
- ) -> T_Dataset:
+ ) -> Self:
index, index_vars = index_and_vars
variables: dict[Hashable, Variable] = {}
indexes = {k: v for k, v in self._indexes.items() if k != dim}
@@ -5402,12 +5421,14 @@ def _unstack_full_reindex(
variables, coord_names=coord_names, indexes=indexes
)
+ @_deprecate_positional_args("v2023.10.0")
def unstack(
- self: T_Dataset,
+ self,
dim: Dims = None,
+ *,
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
- ) -> T_Dataset:
+ ) -> Self:
"""
Unstack existing dimensions corresponding to MultiIndexes into
multiple new dimensions.
@@ -5504,7 +5525,7 @@ def unstack(
result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
return result
- def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
+ def update(self, other: CoercibleMapping) -> Self:
"""Update this dataset's variables with those from another dataset.
Just like :py:meth:`dict.update`, this is an in-place operation.
@@ -5544,14 +5565,14 @@ def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
return self._replace(inplace=True, **merge_result._asdict())
def merge(
- self: T_Dataset,
+ self,
other: CoercibleMapping | DataArray,
overwrite_vars: Hashable | Iterable[Hashable] = frozenset(),
compat: CompatOptions = "no_conflicts",
join: JoinOptions = "outer",
fill_value: Any = xrdtypes.NA,
combine_attrs: CombineAttrsOptions = "override",
- ) -> T_Dataset:
+ ) -> Self:
"""Merge the arrays of two datasets into a single dataset.
This method generally does not allow for overriding data, with the
@@ -5655,11 +5676,11 @@ def _assert_all_in_dataset(
)
def drop_vars(
- self: T_Dataset,
+ self,
names: Hashable | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
- ) -> T_Dataset:
+ ) -> Self:
"""Drop variables from this dataset.
Parameters
@@ -5801,11 +5822,11 @@ def drop_vars(
)
def drop_indexes(
- self: T_Dataset,
+ self,
coord_names: Hashable | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
- ) -> T_Dataset:
+ ) -> Self:
"""Drop the indexes assigned to the given coordinates.
Parameters
@@ -5857,13 +5878,13 @@ def drop_indexes(
return self._replace(variables=variables, indexes=indexes)
def drop(
- self: T_Dataset,
+ self,
labels=None,
dim=None,
*,
errors: ErrorOptions = "raise",
**labels_kwargs,
- ) -> T_Dataset:
+ ) -> Self:
"""Backward compatible method based on `drop_vars` and `drop_sel`
Using either `drop_vars` or `drop_sel` is encouraged
@@ -5913,8 +5934,8 @@ def drop(
return self.drop_sel(labels, errors=errors)
def drop_sel(
- self: T_Dataset, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs
- ) -> T_Dataset:
+ self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs
+ ) -> Self:
"""Drop index labels from this dataset.
Parameters
@@ -5983,7 +6004,7 @@ def drop_sel(
ds = ds.loc[{dim: new_index}]
return ds
- def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset:
+ def drop_isel(self, indexers=None, **indexers_kwargs) -> Self:
"""Drop index positions from this Dataset.
Parameters
@@ -6049,11 +6070,11 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset:
return ds
def drop_dims(
- self: T_Dataset,
+ self,
drop_dims: str | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
- ) -> T_Dataset:
+ ) -> Self:
"""Drop dimensions and associated variables from this dataset.
Parameters
@@ -6090,10 +6111,10 @@ def drop_dims(
return self.drop_vars(drop_vars)
def transpose(
- self: T_Dataset,
+ self,
*dims: Hashable,
missing_dims: ErrorOptionsWithWarn = "raise",
- ) -> T_Dataset:
+ ) -> Self:
"""Return a new Dataset object with all array dimensions transposed.
Although the order of dimensions on each array will change, the dataset
@@ -6145,13 +6166,15 @@ def transpose(
ds._variables[name] = var.transpose(*var_dims)
return ds
+ @_deprecate_positional_args("v2023.10.0")
def dropna(
- self: T_Dataset,
+ self,
dim: Hashable,
+ *,
how: Literal["any", "all"] = "any",
thresh: int | None = None,
subset: Iterable[Hashable] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new dataset with dropped labels for missing values along
the provided dimension.
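With `@_deprecate_positional_args`, everything after `dim` in `dropna` becomes keyword-only (positional use emits a deprecation warning for now); a sketch with made-up data:

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("time", [1.0, np.nan, 3.0])}, coords={"time": [0, 1, 2]})
ds.dropna("time", how="any")   # keyword style, unchanged behaviour
# ds.dropna("time", "any")     # positional `how` now warns via the decorator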
@@ -6273,7 +6296,7 @@ def dropna(
return self.isel({dim: mask})
- def fillna(self: T_Dataset, value: Any) -> T_Dataset:
+ def fillna(self, value: Any) -> Self:
"""Fill missing values in this object.
This operation follows the normal broadcasting and alignment rules that
@@ -6354,7 +6377,7 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset:
return out
def interpolate_na(
- self: T_Dataset,
+ self,
dim: Hashable | None = None,
method: InterpOptions = "linear",
limit: int | None = None,
@@ -6363,7 +6386,7 @@ def interpolate_na(
int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta
) = None,
**kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Fill in NaNs by interpolating according to different methods.
Parameters
@@ -6493,7 +6516,7 @@ def interpolate_na(
)
return new
- def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset:
+ def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values forward
*Requires bottleneck.*
@@ -6557,7 +6580,7 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset
new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit)
return new
- def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset:
+ def bfill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values backward
*Requires bottleneck.*
@@ -6622,7 +6645,7 @@ def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset
new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit)
return new
- def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
+ def combine_first(self, other: Self) -> Self:
"""Combine two Datasets, default to data_vars of self.
The new coordinates follow the normal broadcasting and alignment rules
@@ -6642,7 +6665,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset:
return out
def reduce(
- self: T_Dataset,
+ self,
func: Callable,
dim: Dims = None,
*,
@@ -6650,7 +6673,7 @@ def reduce(
keepdims: bool = False,
numeric_only: bool = False,
**kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Reduce this dataset by applying `func` along some dimension(s).
Parameters
@@ -6775,12 +6798,12 @@ def reduce(
)
def map(
- self: T_Dataset,
+ self,
func: Callable,
keep_attrs: bool | None = None,
args: Iterable[Any] = (),
**kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Apply a function to each data variable in this dataset
Parameters
@@ -6835,12 +6858,12 @@ def map(
return type(self)(variables, attrs=attrs)
def apply(
- self: T_Dataset,
+ self,
func: Callable,
keep_attrs: bool | None = None,
args: Iterable[Any] = (),
**kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""
Backward compatible implementation of ``map``
@@ -6856,10 +6879,10 @@ def apply(
return self.map(func, keep_attrs, args, **kwargs)
def assign(
- self: T_Dataset,
+ self,
variables: Mapping[Any, Any] | None = None,
**variables_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Assign new data variables to a Dataset, returning a new object
with all the original variables in addition to the new ones.
@@ -7164,9 +7187,7 @@ def _set_numpy_data_from_dataframe(
self[name] = (dims, data)
@classmethod
- def from_dataframe(
- cls: type[T_Dataset], dataframe: pd.DataFrame, sparse: bool = False
- ) -> T_Dataset:
+ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"""Convert a pandas.DataFrame into an xarray.Dataset
Each column will be converted into an independent variable in the
@@ -7380,7 +7401,7 @@ def to_dict(
return d
@classmethod
- def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset:
+ def from_dict(cls, d: Mapping[Any, Any]) -> Self:
"""Convert a dictionary into an xarray.Dataset.
Parameters
@@ -7470,7 +7491,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset:
return obj
- def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset:
+ def _unary_op(self, f, *args, **kwargs) -> Self:
variables = {}
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
@@ -7493,7 +7514,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
return NotImplemented
align_type = OPTIONS["arithmetic_join"] if join is None else join
if isinstance(other, (DataArray, Dataset)):
- self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment]
+ self, other = align(self, other, join=align_type, copy=False)
g = f if not reflexive else lambda x, y: f(y, x)
ds = self._calculate_binary_op(g, other, join=align_type)
keep_attrs = _get_keep_attrs(default=False)
@@ -7501,7 +7522,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
ds.attrs = self.attrs
return ds
- def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset:
+ def _inplace_binary_op(self, other, f) -> Self:
from xarray.core.dataarray import DataArray
from xarray.core.groupby import GroupBy
@@ -7575,12 +7596,14 @@ def _copy_attrs_from(self, other):
if v in self.variables:
self.variables[v].attrs = other.variables[v].attrs
+ @_deprecate_positional_args("v2023.10.0")
def diff(
- self: T_Dataset,
+ self,
dim: Hashable,
n: int = 1,
+ *,
label: Literal["upper", "lower"] = "upper",
- ) -> T_Dataset:
+ ) -> Self:
"""Calculate the n-th order discrete difference along given axis.
Parameters
@@ -7663,11 +7686,11 @@ def diff(
return difference
def shift(
- self: T_Dataset,
+ self,
shifts: Mapping[Any, int] | None = None,
fill_value: Any = xrdtypes.NA,
**shifts_kwargs: int,
- ) -> T_Dataset:
+ ) -> Self:
"""Shift this dataset by an offset along one or more dimensions.
Only data variables are moved; coordinates stay in place. This is
@@ -7734,11 +7757,11 @@ def shift(
return self._replace(variables)
def roll(
- self: T_Dataset,
+ self,
shifts: Mapping[Any, int] | None = None,
roll_coords: bool = False,
**shifts_kwargs: int,
- ) -> T_Dataset:
+ ) -> Self:
"""Roll this dataset by an offset along one or more dimensions.
Unlike shift, roll treats the given dimensions as periodic, so will not
@@ -7820,10 +7843,13 @@ def roll(
return self._replace(variables, indexes=indexes)
def sortby(
- self: T_Dataset,
- variables: Hashable | DataArray | list[Hashable | DataArray],
+ self,
+ variables: Hashable
+ | DataArray
+ | Sequence[Hashable | DataArray]
+ | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]],
ascending: bool = True,
- ) -> T_Dataset:
+ ) -> Self:
"""
Sort object by labels or values (along an axis).
@@ -7843,9 +7869,10 @@ def sortby(
Parameters
----------
- variables : Hashable, DataArray, or list of hashable or DataArray
- 1D DataArray objects or name(s) of 1D variable(s) in
- coords/data_vars whose values are used to sort the dataset.
+ variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
+ 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
+ used to sort this dataset. If a callable, the callable is passed this object,
+ and the result is used as the values to sort by.
ascending : bool, default: True
Whether to sort by ascending or descending order.
@@ -7871,8 +7898,7 @@ def sortby(
... },
... coords={"x": ["b", "a"], "y": [1, 0]},
... )
- >>> ds = ds.sortby("x")
- >>> ds
+ >>> ds.sortby("x")
<xarray.Dataset>
Dimensions: (x: 2, y: 2)
Coordinates:
@@ -7881,17 +7907,28 @@ def sortby(
Data variables:
A (x, y) int64 3 4 1 2
B (x, y) int64 7 8 5 6
+ >>> ds.sortby(lambda x: -x["y"])
+ <xarray.Dataset>
+ Dimensions: (x: 2, y: 2)
+ Coordinates:
+ * x (x) <U1 'b' 'a'
+ * y (y) int64 1 0
+ Data variables:
+ A (x, y) int64 1 2 3 4
+ B (x, y) int64 5 6 7 8
- ) -> T_Dataset:
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantile(s) of the array elements for each variable
@@ -8083,12 +8122,14 @@ def quantile(
)
return new.assign_coords(quantile=q)
+ @_deprecate_positional_args("v2023.10.0")
def rank(
- self: T_Dataset,
+ self,
dim: Hashable,
+ *,
pct: bool = False,
keep_attrs: bool | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Ranks the data.
Equal values are assigned a rank that is the average of the ranks that
@@ -8142,11 +8183,11 @@ def rank(
return self._replace(variables, coord_names, attrs=attrs)
def differentiate(
- self: T_Dataset,
+ self,
coord: Hashable,
edge_order: Literal[1, 2] = 1,
datetime_unit: DatetimeUnitOptions | None = None,
- ) -> T_Dataset:
+ ) -> Self:
""" Differentiate with the second order accurate central
differences.
@@ -8214,10 +8255,10 @@ def differentiate(
return self._replace(variables)
def integrate(
- self: T_Dataset,
+ self,
coord: Hashable | Sequence[Hashable],
datetime_unit: DatetimeUnitOptions = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -8333,10 +8374,10 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False):
)
def cumulative_integrate(
- self: T_Dataset,
+ self,
coord: Hashable | Sequence[Hashable],
datetime_unit: DatetimeUnitOptions = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -8408,7 +8449,7 @@ def cumulative_integrate(
return result
@property
- def real(self: T_Dataset) -> T_Dataset:
+ def real(self) -> Self:
"""
The real part of each data variable.
@@ -8419,7 +8460,7 @@ def real(self: T_Dataset) -> T_Dataset:
return self.map(lambda x: x.real, keep_attrs=True)
@property
- def imag(self: T_Dataset) -> T_Dataset:
+ def imag(self) -> Self:
"""
The imaginary part of each data variable.
@@ -8431,7 +8472,7 @@ def imag(self: T_Dataset) -> T_Dataset:
plot = utils.UncachedAccessor(DatasetPlotAccessor)
- def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
+ def filter_by_attrs(self, **kwargs) -> Self:
"""Returns a ``Dataset`` with variables that match specific conditions.
Can pass in ``key=value`` or ``key=callable``. A Dataset is returned
@@ -8526,7 +8567,7 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
selection.append(var_name)
return self[selection]
- def unify_chunks(self: T_Dataset) -> T_Dataset:
+ def unify_chunks(self) -> Self:
"""Unify chunk size along all chunked dimensions of this Dataset.
Returns
@@ -8648,7 +8689,7 @@ def map_blocks(
return map_blocks(func, self, args, kwargs, template)
def polyfit(
- self: T_Dataset,
+ self,
dim: Hashable,
deg: int,
skipna: bool | None = None,
@@ -8656,7 +8697,7 @@ def polyfit(
w: Hashable | Any = None,
full: bool = False,
cov: bool | Literal["unscaled"] = False,
- ) -> T_Dataset:
+ ) -> Self:
"""
Least squares polynomial fit.
@@ -8844,7 +8885,7 @@ def polyfit(
return type(self)(data_vars=variables, attrs=self.attrs.copy())
def pad(
- self: T_Dataset,
+ self,
pad_width: Mapping[Any, int | tuple[int, int]] | None = None,
mode: PadModeOptions = "constant",
stat_length: int
@@ -8858,7 +8899,7 @@ def pad(
reflect_type: PadReflectOptions = None,
keep_attrs: bool | None = None,
**pad_width_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Pad this dataset along one or more dimensions.
.. warning::
@@ -9029,13 +9070,15 @@ def pad(
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs)
+ @_deprecate_positional_args("v2023.10.0")
def idxmin(
- self: T_Dataset,
+ self,
dim: Hashable | None = None,
+ *,
skipna: bool | None = None,
fill_value: Any = xrdtypes.NA,
keep_attrs: bool | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Return the coordinate label of the minimum value along a dimension.
Returns a new `Dataset` named after the dimension with the values of
@@ -9126,13 +9169,15 @@ def idxmin(
)
)
+ @_deprecate_positional_args("v2023.10.0")
def idxmax(
- self: T_Dataset,
+ self,
dim: Hashable | None = None,
+ *,
skipna: bool | None = None,
fill_value: Any = xrdtypes.NA,
keep_attrs: bool | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Return the coordinate label of the maximum value along a dimension.
Returns a new `Dataset` named after the dimension with the values of
@@ -9223,7 +9268,7 @@ def idxmax(
)
)
- def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
+ def argmin(self, dim: Hashable | None = None, **kwargs) -> Self:
"""Indices of the minima of the member variables.
If there are multiple minima, the indices of the first one found will be
@@ -9326,7 +9371,7 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
"Dataset.argmin() with a sequence or ... for dim"
)
- def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
+ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"""Indices of the maxima of the member variables.
If there are multiple maxima, the indices of the first one found will be
@@ -9420,13 +9465,13 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset:
)
def query(
- self: T_Dataset,
+ self,
queries: Mapping[Any, Any] | None = None,
parser: QueryParserOptions = "pandas",
engine: QueryEngineOptions = None,
missing_dims: ErrorOptionsWithWarn = "raise",
**queries_kwargs: Any,
- ) -> T_Dataset:
+ ) -> Self:
"""Return a new dataset with each array indexed along the specified
dimension(s), where the indexers are given as strings containing
Python expressions to be evaluated against the data variables in the
@@ -9516,7 +9561,7 @@ def query(
return self.isel(indexers, missing_dims=missing_dims)
def curvefit(
- self: T_Dataset,
+ self,
coords: str | DataArray | Iterable[str | DataArray],
func: Callable[..., Any],
reduce_dims: Dims = None,
@@ -9526,7 +9571,7 @@ def curvefit(
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""
Curve fitting optimization for arbitrary functions.
@@ -9749,11 +9794,13 @@ def _wrapper(Y, *args, **kwargs):
return result
+ @_deprecate_positional_args("v2023.10.0")
def drop_duplicates(
- self: T_Dataset,
+ self,
dim: Hashable | Iterable[Hashable],
+ *,
keep: Literal["first", "last", False] = "first",
- ) -> T_Dataset:
+ ) -> Self:
"""Returns a new Dataset with duplicate dimension values removed.
Parameters
@@ -9793,13 +9840,13 @@ def drop_duplicates(
return self.isel(indexes)
def convert_calendar(
- self: T_Dataset,
+ self,
calendar: CFCalendar,
dim: Hashable = "time",
align_on: Literal["date", "year", None] = None,
missing: Any | None = None,
use_cftime: bool | None = None,
- ) -> T_Dataset:
+ ) -> Self:
"""Convert the Dataset to another calendar.
Only converts the individual timestamps, does not modify any data except
@@ -9916,10 +9963,10 @@ def convert_calendar(
)
def interp_calendar(
- self: T_Dataset,
+ self,
target: pd.DatetimeIndex | CFTimeIndex | DataArray,
dim: Hashable = "time",
- ) -> T_Dataset:
+ ) -> Self:
"""Interpolates the Dataset to another calendar based on decimal year measure.
Each timestamp in `source` and `target` is first converted to its decimal
diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py
index 0762fa03112..ccf84146819 100644
--- a/xarray/core/dtypes.py
+++ b/xarray/core/dtypes.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import functools
+from typing import Any
import numpy as np
@@ -44,7 +45,7 @@ def __eq__(self, other):
)
-def maybe_promote(dtype):
+def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
"""Simpler equivalent of pandas.core.common._maybe_promote
Parameters
@@ -57,27 +58,33 @@ def maybe_promote(dtype):
fill_value : Valid missing value for the promoted dtype.
"""
# N.B. these casting rules should match pandas
+ dtype_: np.typing.DTypeLike
+ fill_value: Any
if np.issubdtype(dtype, np.floating):
+ dtype_ = dtype
fill_value = np.nan
elif np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
+ dtype_ = dtype
elif np.issubdtype(dtype, np.integer):
- dtype = np.float32 if dtype.itemsize <= 2 else np.float64
+ dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
fill_value = np.nan
elif np.issubdtype(dtype, np.complexfloating):
+ dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
+ dtype_ = dtype
fill_value = np.datetime64("NaT")
else:
- dtype = object
+ dtype_ = object
fill_value = np.nan
- dtype = np.dtype(dtype)
- fill_value = dtype.type(fill_value)
- return dtype, fill_value
+ dtype_out = np.dtype(dtype_)
+ fill_value = dtype_out.type(fill_value)
+ return dtype_out, fill_value
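A quick illustration of the promotion rules encoded above (this calls the internal helper directly, purely to show the table):

import numpy as np
from xarray.core.dtypes import maybe_promote

print(maybe_promote(np.dtype("int16")))   # (dtype('float32'), nan): small ints widen to float32
print(maybe_promote(np.dtype("int64")))   # (dtype('float64'), nan)
print(maybe_promote(np.dtype("M8[ns]")))  # (dtype('<M8[ns]'), numpy.datetime64('NaT'))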
NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 4f245e59f73..078aab0ed63 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -337,6 +337,10 @@ def reshape(array, shape):
return xp.reshape(array, shape)
+def ravel(array):
+ return reshape(array, (-1,))
+
+
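`ravel` simply defers to the array-API `reshape` above, so it works on any wrapped duck array; for example:

import numpy as np
from xarray.core.duck_array_ops import ravel

print(ravel(np.arange(6).reshape(2, 3)))  # [0 1 2 3 4 5]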
@contextlib.contextmanager
def _ignore_warnings_if(condition):
if condition:
@@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs):
values = asarray(values)
if coerce_strings and values.dtype.kind in "SU":
- values = values.astype(object)
+ values = astype(values, object)
func = None
if skipna or (skipna is None and values.dtype.kind in "cfO"):
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 86e96ca2deb..09b7fafc7bf 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -178,6 +178,8 @@ def format_item(x, timedelta_format=None, quote_strings=True):
if isinstance(x, (np.timedelta64, timedelta)):
return format_timedelta(x, timedelta_format=timedelta_format)
elif isinstance(x, (str, bytes)):
+ if hasattr(x, "dtype"):
+ x = x.item()
return repr(x) if quote_strings else x
elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating):
return f"{x.item():.4}"
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 9894a4a4daf..8ed7148e2a1 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -43,6 +43,7 @@
peek_at,
)
from xarray.core.variable import IndexVariable, Variable
+from xarray.util.deprecation_helpers import _deprecate_positional_args
if TYPE_CHECKING:
from numpy.typing import ArrayLike
@@ -699,7 +700,7 @@ class GroupBy(Generic[T_Xarray]):
_groups: dict[GroupKey, GroupIndex] | None
_dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
- _sizes: Frozen[Hashable, int] | None
+ _sizes: Mapping[Hashable, int] | None
def __init__(
self,
@@ -746,7 +747,7 @@ def __init__(
self._sizes = None
@property
- def sizes(self) -> Frozen[Hashable, int]:
+ def sizes(self) -> Mapping[Hashable, int]:
"""Ordered mapping from dimension names to lengths.
Immutable.
@@ -1092,10 +1093,12 @@ def fillna(self, value: Any) -> T_Xarray:
"""
return ops.fillna(self, value)
+ @_deprecate_positional_args("v2023.10.0")
def quantile(
self,
q: ArrayLike,
dim: Dims = None,
+ *,
method: QuantileMethods = "linear",
keep_attrs: bool | None = None,
skipna: bool | None = None,
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 9972896d6df..1697762f7ae 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -24,7 +24,7 @@
)
if TYPE_CHECKING:
- from xarray.core.types import ErrorOptions, JoinOptions, T_Index
+ from xarray.core.types import ErrorOptions, JoinOptions, Self
from xarray.core.variable import Variable
@@ -60,11 +60,11 @@ class Index:
@classmethod
def from_variables(
- cls: type[T_Index],
+ cls,
variables: Mapping[Any, Variable],
*,
options: Mapping[str, Any],
- ) -> T_Index:
+ ) -> Self:
"""Create a new index object from one or more coordinate variables.
This factory method must be implemented in all subclasses of Index.
@@ -88,11 +88,11 @@ def from_variables(
@classmethod
def concat(
- cls: type[T_Index],
- indexes: Sequence[T_Index],
+ cls,
+ indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
- ) -> T_Index:
+ ) -> Self:
"""Create a new index by concatenating one or more indexes of the same
type.
@@ -120,9 +120,7 @@ def concat(
raise NotImplementedError()
@classmethod
- def stack(
- cls: type[T_Index], variables: Mapping[Any, Variable], dim: Hashable
- ) -> T_Index:
+ def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self:
"""Create a new index by stacking coordinate variables into a single new
dimension.
@@ -208,8 +206,8 @@ def to_pandas_index(self) -> pd.Index:
raise TypeError(f"{self!r} cannot be cast to a pandas.Index object")
def isel(
- self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
- ) -> T_Index | None:
+ self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
+ ) -> Self | None:
"""Maybe returns a new index from the current index itself indexed by
positional indexers.
@@ -264,7 +262,7 @@ def sel(self, labels: dict[Any, Any]) -> IndexSelResult:
"""
raise NotImplementedError(f"{self!r} doesn't support label-based selection")
- def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
+ def join(self, other: Self, how: JoinOptions = "inner") -> Self:
"""Return a new index from the combination of this index with another
index of the same type.
@@ -286,7 +284,7 @@ def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
f"{self!r} doesn't support alignment with inner/outer join method"
)
- def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:
+ def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""Query the index with another index of the same type.
Implementation is optional but required in order to support alignment.
@@ -304,7 +302,7 @@ def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:
"""
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")
- def equals(self: T_Index, other: T_Index) -> bool:
+ def equals(self, other: Self) -> bool:
"""Compare this index with another index of the same type.
Implementation is optional but required in order to support alignment.
@@ -321,7 +319,7 @@ def equals(self: T_Index, other: T_Index) -> bool:
"""
raise NotImplementedError()
- def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
+ def roll(self, shifts: Mapping[Any, int]) -> Self | None:
"""Roll this index by an offset along one or more dimensions.
This method can be re-implemented in subclasses of Index, e.g., when the
@@ -347,10 +345,10 @@ def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None:
return None
def rename(
- self: T_Index,
+ self,
name_dict: Mapping[Any, Hashable],
dims_dict: Mapping[Any, Hashable],
- ) -> T_Index:
+ ) -> Self:
"""Maybe update the index with new coordinate and dimension names.
This method should be re-implemented in subclasses of Index if it has
@@ -377,7 +375,7 @@ def rename(
"""
return self
- def copy(self: T_Index, deep: bool = True) -> T_Index:
+ def copy(self, deep: bool = True) -> Self:
"""Return a (deep) copy of this index.
Implementation in subclasses of Index is optional. The base class
@@ -396,15 +394,13 @@ def copy(self: T_Index, deep: bool = True) -> T_Index:
"""
return self._copy(deep=deep)
- def __copy__(self: T_Index) -> T_Index:
+ def __copy__(self) -> Self:
return self.copy(deep=False)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index:
return self._copy(deep=True, memo=memo)
- def _copy(
- self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None
- ) -> T_Index:
+ def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self:
cls = self.__class__
copied = cls.__new__(cls)
if deep:
@@ -414,7 +410,7 @@ def _copy(
copied.__dict__.update(self.__dict__)
return copied
- def __getitem__(self: T_Index, indexer: Any) -> T_Index:
+ def __getitem__(self, indexer: Any) -> Self:
raise NotImplementedError()
def _repr_inline_(self, max_width):
@@ -674,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index:
@classmethod
def concat(
cls,
- indexes: Sequence[PandasIndex],
+ indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
- ) -> PandasIndex:
+ ) -> Self:
new_pd_index = cls._concat_indexes(indexes, dim, positions)
if not indexes:
@@ -800,7 +796,11 @@ def equals(self, other: Index):
return False
return self.index.equals(other.index) and self.dim == other.dim
- def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex:
+ def join(
+ self,
+ other: Self,
+ how: str = "inner",
+ ) -> Self:
if how == "outer":
index = self.index.union(other.index)
else:
@@ -811,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd
return type(self)(index, self.dim, coord_dtype=coord_dtype)
def reindex_like(
- self, other: PandasIndex, method=None, tolerance=None
+ self, other: Self, method=None, tolerance=None
) -> dict[Hashable, Any]:
if not self.index.is_unique:
raise ValueError(
@@ -963,12 +963,12 @@ def from_variables(
return obj
@classmethod
- def concat( # type: ignore[override]
+ def concat(
cls,
- indexes: Sequence[PandasMultiIndex],
+ indexes: Sequence[Self],
dim: Hashable,
positions: Iterable[Iterable[int]] | None = None,
- ) -> PandasMultiIndex:
+ ) -> Self:
new_pd_index = cls._concat_indexes(indexes, dim, positions)
if not indexes:
@@ -1602,7 +1602,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]:
return Indexes(indexes, self._variables, index_type=pd.Index)
def copy_indexes(
- self, deep: bool = True, memo: dict[int, Any] | None = None
+ self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None
) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]:
"""Return a new dictionary with copies of indexes, preserving
unique indexes.
@@ -1619,6 +1619,7 @@ def copy_indexes(
new_indexes = {}
new_index_vars = {}
+ idx: T_PandasOrXarrayIndex
for idx, coords in self.group_by_index():
if isinstance(idx, pd.Index):
convert_new_idx = True
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index 3475db4a010..a8e54ad1231 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -474,10 +474,11 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- out = []
+ out: list[DatasetLike] = []
for obj in objects:
+ variables: DatasetLike
if isinstance(obj, (Dataset, Coordinates)):
- variables: DatasetLike = obj
+ variables = obj
else:
variables = {}
if isinstance(obj, PANDAS_TYPES):
@@ -491,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik
def _get_priority_vars_and_indexes(
- objects: list[DatasetLike],
+ objects: Sequence[DatasetLike],
priority_arg: int | None,
compat: CompatOptions = "equals",
) -> dict[Hashable, MergeElement]:
@@ -503,7 +504,7 @@ def _get_priority_vars_and_indexes(
Parameters
----------
- objects : list of dict-like of Variable
+ objects : sequence of dict-like of Variable
Dictionaries in which to find the priority variables.
priority_arg : int or None
Integer object whose variable should take priority.
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 3b8ddfe032d..fc7240139aa 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -4,7 +4,7 @@
import numpy as np
-from xarray.core import dtypes, nputils, utils
+from xarray.core import dtypes, duck_array_ops, nputils, utils
from xarray.core.duck_array_ops import (
astype,
count,
@@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1):
xarray version of pandas.core.nanops._maybe_null_out
"""
if axis is not None and getattr(result, "ndim", False):
- null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
+ null_mask = (
+ np.take(mask.shape, axis).prod()
+ - duck_array_ops.sum(mask, axis)
+ - min_count
+ ) < 0
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
- null_mask = mask.size - mask.sum()
+ null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)
return result
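Routing the mask reduction through `duck_array_ops.sum` keeps `min_count` working on non-NumPy backends; the user-facing semantics are unchanged, e.g.:

import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, np.nan], dims="x")
print(da.sum("x", min_count=1).item())  # 1.0 -- enough valid values
print(da.sum("x", min_count=3).item())  # nan -- fewer than 3 valid values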
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 07c3c606bf2..949576b4ee8 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -443,7 +443,7 @@ def subset_dataset_to_block(
for dim in variable.dims:
chunk = chunk[chunk_index[dim]]
- chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple
+ chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
graph[chunk_variable_task] = (
tuple,
[variable.dims, chunk, variable.attrs],
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index d49cb6e13a4..b85092982e3 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]):
__slots__ = ("obj", "window", "min_periods", "center", "dim")
_attributes = ("window", "min_periods", "center", "dim")
+ dim: list[Hashable]
+ window: list[int]
+ center: list[bool]
+ obj: T_Xarray
+ min_periods: int
def __init__(
self,
@@ -91,8 +96,8 @@ def __init__(
-------
rolling : type of input argument
"""
- self.dim: list[Hashable] = []
- self.window: list[int] = []
+ self.dim = []
+ self.window = []
for d, w in windows.items():
self.dim.append(d)
if w <= 0:
@@ -100,7 +105,7 @@ def __init__(
self.window.append(w)
self.center = self._mapping_to_list(center, default=False)
- self.obj: T_Xarray = obj
+ self.obj = obj
missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims)
if missing_dims:
@@ -785,11 +790,14 @@ def construct(
if not keep_attrs:
dataset[key].attrs = {}
+ # Need to stride coords as well. TODO: is there a better way?
+ coords = self.obj.isel(
+ {d: slice(None, None, s) for d, s in zip(self.dim, strides)}
+ ).coords
+
attrs = self.obj.attrs if keep_attrs else {}
- return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
- {d: slice(None, None, s) for d, s in zip(self.dim, strides)}
- )
+ return Dataset(dataset, coords=coords, attrs=attrs)
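Because the coordinates are now strided together with the data, `construct` keeps labels aligned with the kept windows; a sketch with made-up data:

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("time", np.arange(6.0))}, coords={"time": np.arange(6)})
out = ds.rolling(time=3).construct("window", stride=2)
print(dict(out.sizes))     # {'time': 3, 'window': 3}
print(out["time"].values)  # [0 2 4] -- coordinate strided along with the data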
class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
@@ -811,6 +819,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
)
_attributes = ("windows", "side", "trim_excess")
obj: T_Xarray
+ windows: Mapping[Hashable, int]
+ side: SideOptions | Mapping[Hashable, SideOptions]
+ boundary: CoarsenBoundaryOptions
+ coord_func: Mapping[Hashable, str | Callable]
def __init__(
self,
@@ -852,12 +864,15 @@ def __init__(
f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} "
f"dimensions {tuple(self.obj.dims)}"
)
- if not utils.is_dict_like(coord_func):
- coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc]
+
+ if utils.is_dict_like(coord_func):
+ coord_func_map = coord_func
+ else:
+ coord_func_map = {d: coord_func for d in self.obj.dims}
for c in self.obj.coords:
- if c not in coord_func:
- coord_func[c] = duck_array_ops.mean # type: ignore[index]
- self.coord_func: Mapping[Hashable, str | Callable] = coord_func
+ if c not in coord_func_map:
+ coord_func_map[c] = duck_array_ops.mean # type: ignore[index]
+ self.coord_func = coord_func_map
def _get_keep_attrs(self, keep_attrs):
if keep_attrs is None:
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py
index bd30c634aae..cb77358869c 100644
--- a/xarray/core/rolling_exp.py
+++ b/xarray/core/rolling_exp.py
@@ -9,10 +9,15 @@
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.types import T_DataWithCoords
+from xarray.core.types import T_DataWithCoords, T_DuckArray
-def _get_alpha(com=None, span=None, halflife=None, alpha=None):
+def _get_alpha(
+ com: float | None = None,
+ span: float | None = None,
+ halflife: float | None = None,
+ alpha: float | None = None,
+) -> float:
# pandas defines in terms of com (converting to alpha in the algo)
# so use its function to get a com and then convert to alpha
@@ -20,7 +25,7 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None):
return 1 / (1 + com)
-def move_exp_nanmean(array, *, axis, alpha):
+def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
if is_duck_dask_array(array):
raise TypeError("rolling_exp is not currently support for dask-like arrays")
import numbagg
@@ -32,7 +37,7 @@ def move_exp_nanmean(array, *, axis, alpha):
return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)
-def move_exp_nansum(array, *, axis, alpha):
+def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
if is_duck_dask_array(array):
raise TypeError("rolling_exp is not currently supported for dask-like arrays")
import numbagg
@@ -40,7 +45,12 @@ def move_exp_nansum(array, *, axis, alpha):
return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)
-def _get_center_of_mass(comass, span, halflife, alpha):
+def _get_center_of_mass(
+ comass: float | None,
+ span: float | None,
+ halflife: float | None,
+ alpha: float | None,
+) -> float:
"""
Vendored from pandas.core.window.common._get_center_of_mass
@@ -137,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
output_core_dims=[[self.dim]],
- exclude_dims={self.dim},
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
+ dask="parallelized",
).transpose(*dim_order)
def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
@@ -173,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
input_core_dims=[[self.dim]],
kwargs=dict(alpha=self.alpha, axis=-1),
output_core_dims=[[self.dim]],
- exclude_dims={self.dim},
keep_attrs=keep_attrs,
on_missing_core_dim="copy",
+ dask="parallelized",
).transpose(*dim_order)
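Usage of the exponential rolling reductions is unchanged by the `apply_ufunc` tweaks above (numbagg is required); a minimal sketch:

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(100), dims="time")
smoothed = da.rolling_exp(time=10).mean()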
diff --git a/xarray/core/types.py b/xarray/core/types.py
index f80c2c52cd7..2af9591d22a 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -19,9 +19,9 @@
try:
if sys.version_info >= (3, 11):
- from typing import Self
+ from typing import Self, TypeAlias
else:
- from typing_extensions import Self
+ from typing_extensions import Self, TypeAlias
except ImportError:
if TYPE_CHECKING:
raise
@@ -38,7 +38,6 @@
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.groupby import DataArrayGroupBy, GroupBy
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
@@ -107,7 +106,7 @@ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]:
...
@property
- def sizes(self) -> Frozen[Hashable, int]:
+ def sizes(self) -> Mapping[Hashable, int]:
...
@property
@@ -146,6 +145,8 @@ def copy(
...
+T_Alignable = TypeVar("T_Alignable", bound="Alignable")
+
T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint")
T_Dataset = TypeVar("T_Dataset", bound="Dataset")
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
@@ -154,29 +155,46 @@ def copy(
T_Array = TypeVar("T_Array", bound="AbstractArray")
T_Index = TypeVar("T_Index", bound="Index")
+# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used
+# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of
+# the same concrete type, either "DataArray" or "Dataset". This is generally preferred
+# over `T_DataArrayOrSet`, since the type system can then determine the exact type.
+T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+
+# `T_DataArrayOrSet` is a type variable bound to the union of "DataArray" and
+# "Dataset". Use it for functions that might return either type, but where the exact
+# type cannot be determined statically by the type system.
T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
-# Maybe we rename this to T_Data or something less Fortran-y?
-T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+# For working directly with `DataWithCoords`. It will only allow using methods defined
+# on `DataWithCoords`.
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
-T_Alignable = TypeVar("T_Alignable", bound="Alignable")
+
# Temporary placeholder for indicating an array api compliant type.
# hopefully in the future we can narrow this down more:
T_DuckArray = TypeVar("T_DuckArray", bound=Any)
ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
-DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
-DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
-GroupByIncompatible = Union["Variable", "GroupBy"]
+DaCompatible = Union["DataArray", "VarCompatible"]
+DsCompatible = Union["Dataset", "DaCompatible"]
+GroupByCompatible = Union["Dataset", "DataArray"]
Dims = Union[str, Iterable[Hashable], "ellipsis", None]
OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
-T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None]
+# FYI in some cases we don't allow `None`, which this alias doesn't take into account.
+T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
+# We allow the tuple form of this (though arguably we could transition to named dims only)
+T_Chunks: TypeAlias = Union[
+ T_ChunkDim, Mapping[Any, T_ChunkDim], tuple[T_ChunkDim, ...]
+]
T_NormalizedChunks = tuple[tuple[int, ...], ...]
+DataVars = Mapping[Any, Any]
+
+
ErrorOptions = Literal["raise", "ignore"]
ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
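
As a hedged illustration of the comments above (not part of this diff), a generic helper annotated with `T_Xarray` lets a type checker carry the concrete type through the call:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import xarray as xr

if TYPE_CHECKING:
    # Imported only for annotations; T_Xarray is defined in xarray.core.types.
    from xarray.core.types import T_Xarray


def scale(obj: T_Xarray, factor: float) -> T_Xarray:
    # Passing a DataArray returns a DataArray and passing a Dataset returns a
    # Dataset, as far as mypy/pyright are concerned.
    return obj * factor


da = xr.DataArray([1.0, 2.0, 3.0], dims="x")
print(scale(da, 2.0))
```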
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index f459f044751..3f399dcac7c 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -26,10 +26,7 @@
as_indexable,
)
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.parallelcompat import (
- get_chunked_array_type,
- guess_chunkmanager,
-)
+from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.core.pycompat import (
integer_types,
is_0d_dask_array,
@@ -38,8 +35,6 @@
to_numpy,
)
from xarray.core.utils import (
- Frozen,
- NdimSizeLenMixin,
OrderedSet,
_default,
decode_numpy_dict_values,
@@ -50,6 +45,7 @@
is_duck_array,
maybe_coerce_to_str,
)
+from xarray.namedarray.core import NamedArray
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
@@ -66,8 +62,8 @@
PadModeOptions,
PadReflectOptions,
QuantileMethods,
+ Self,
T_DuckArray,
- T_Variable,
)
NON_NANOSECOND_WARNING = (
@@ -268,7 +264,7 @@ def as_compatible_data(
data = np.timedelta64(getattr(data, "value", data), "ns")
# we don't want nested self-described arrays
- if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)):
+ if isinstance(data, (pd.Series, pd.DataFrame)):
data = data.values
if isinstance(data, np.ma.MaskedArray):
@@ -315,7 +311,7 @@ def _as_array_or_item(data):
return data
-class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
+class Variable(NamedArray, AbstractArray, VariableArithmetic):
"""A netcdf-like variable consisting of dimensions, data and attributes
which describe a single Array. A single Variable object is not fully
described outside the context of its parent Dataset (if you want such a
@@ -365,51 +361,14 @@ def __init__(
Well-behaved code to serialize a Variable should ignore
unrecognized encoding items.
"""
- self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath)
- self._dims = self._parse_dimensions(dims)
- self._attrs: dict[Any, Any] | None = None
+ super().__init__(
+ dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
+ )
+
self._encoding = None
- if attrs is not None:
- self.attrs = attrs
if encoding is not None:
self.encoding = encoding
- @property
- def dtype(self) -> np.dtype:
- """
- Data-type of the array’s elements.
-
- See Also
- --------
- ndarray.dtype
- numpy.dtype
- """
- return self._data.dtype
-
- @property
- def shape(self) -> tuple[int, ...]:
- """
- Tuple of array dimensions.
-
- See Also
- --------
- numpy.ndarray.shape
- """
- return self._data.shape
-
- @property
- def nbytes(self) -> int:
- """
- Total bytes consumed by the elements of the data array.
-
- If the underlying data array does not include ``nbytes``, estimates
- the bytes consumed based on the ``size`` and ``dtype``.
- """
- if hasattr(self._data, "nbytes"):
- return self._data.nbytes
- else:
- return self.size * self.dtype.itemsize
-
@property
def _in_memory(self):
return isinstance(
@@ -420,7 +379,7 @@ def _in_memory(self):
)
@property
- def data(self: T_Variable):
+ def data(self):
"""
The Variable's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
@@ -439,17 +398,13 @@ def data(self: T_Variable):
return self.values
@data.setter
- def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None:
+ def data(self, data: T_DuckArray | ArrayLike) -> None:
data = as_compatible_data(data)
- if data.shape != self.shape: # type: ignore[attr-defined]
- raise ValueError(
- f"replacement data must match the Variable's shape. "
- f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined]
- )
+ self._check_shape(data)
self._data = data
def astype(
- self: T_Variable,
+ self,
dtype,
*,
order=None,
@@ -457,7 +412,7 @@ def astype(
subok=None,
copy=None,
keep_attrs=True,
- ) -> T_Variable:
+ ) -> Self:
"""
Copy of the Variable object, with data cast to a specified type.
@@ -571,41 +526,6 @@ def compute(self, **kwargs):
new = self.copy(deep=False)
return new.load(**kwargs)
- def __dask_tokenize__(self):
- # Use v.data, instead of v._data, in order to cope with the wrappers
- # around NetCDF and the like
- from dask.base import normalize_token
-
- return normalize_token((type(self), self._dims, self.data, self.attrs))
-
- def __dask_graph__(self):
- if is_duck_dask_array(self._data):
- return self._data.__dask_graph__()
- else:
- return None
-
- def __dask_keys__(self):
- return self._data.__dask_keys__()
-
- def __dask_layers__(self):
- return self._data.__dask_layers__()
-
- @property
- def __dask_optimize__(self):
- return self._data.__dask_optimize__
-
- @property
- def __dask_scheduler__(self):
- return self._data.__dask_scheduler__
-
- def __dask_postcompute__(self):
- array_func, array_args = self._data.__dask_postcompute__()
- return self._dask_finalize, (array_func,) + array_args
-
- def __dask_postpersist__(self):
- array_func, array_args = self._data.__dask_postpersist__()
- return self._dask_finalize, (array_func,) + array_args
-
def _dask_finalize(self, results, array_func, *args, **kwargs):
data = array_func(results, *args, **kwargs)
return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)
@@ -667,27 +587,6 @@ def to_dict(
return item
- @property
- def dims(self) -> tuple[Hashable, ...]:
- """Tuple of dimension names with which this variable is associated."""
- return self._dims
-
- @dims.setter
- def dims(self, value: str | Iterable[Hashable]) -> None:
- self._dims = self._parse_dimensions(value)
-
- def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]:
- if isinstance(dims, str):
- dims = (dims,)
- else:
- dims = tuple(dims)
- if len(dims) != self.ndim:
- raise ValueError(
- f"dimensions {dims} must have the same length as the "
- f"number of data dimensions, ndim={self.ndim}"
- )
- return dims
-
def _item_key_to_tuple(self, key):
if utils.is_dict_like(key):
return tuple(key.get(dim, slice(None)) for dim in self.dims)
@@ -820,13 +719,6 @@ def _broadcast_indexes_outer(self, key):
return dims, OuterIndexer(tuple(new_key)), None
- def _nonzero(self):
- """Equivalent numpy's nonzero but returns a tuple of Variables."""
- # TODO we should replace dask's native nonzero
- # after https://github.com/dask/dask/issues/1076 is implemented.
- nonzeros = np.nonzero(self.data)
- return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims))
-
def _broadcast_indexes_vectorized(self, key):
variables = []
out_dims_set = OrderedSet()
@@ -883,7 +775,7 @@ def _broadcast_indexes_vectorized(self, key):
return out_dims, VectorizedIndexer(tuple(out_key)), new_order
- def __getitem__(self: T_Variable, key) -> T_Variable:
+ def __getitem__(self, key) -> Self:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.
@@ -902,7 +794,7 @@ def __getitem__(self: T_Variable, key) -> T_Variable:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)
- def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable:
+ def _finalize_indexing_result(self, dims, data) -> Self:
"""Used by IndexVariable to return IndexVariable objects when possible."""
return self._replace(dims=dims, data=data)
@@ -976,17 +868,6 @@ def __setitem__(self, key, value):
indexable = as_indexable(self._data)
indexable[index_tuple] = value
- @property
- def attrs(self) -> dict[Any, Any]:
- """Dictionary of local attributes on this variable."""
- if self._attrs is None:
- self._attrs = {}
- return self._attrs
-
- @attrs.setter
- def attrs(self, value: Mapping[Any, Any]) -> None:
- self._attrs = dict(value)
-
@property
def encoding(self) -> dict[Any, Any]:
"""Dictionary of encodings on this variable."""
@@ -1001,76 +882,22 @@ def encoding(self, value):
except ValueError:
raise ValueError("encoding must be castable to a dictionary")
- def reset_encoding(self: T_Variable) -> T_Variable:
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
+
+ def drop_encoding(self) -> Self:
"""Return a new Variable without encoding."""
return self._replace(encoding={})
- def copy(
- self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None
- ) -> T_Variable:
- """Returns a copy of this object.
-
- If `deep=True`, the data array is loaded into memory and copied onto
- the new object. Dimensions, attributes and encodings are always copied.
-
- Use `data` to create a new object with the same structure as
- original but entirely new data.
-
- Parameters
- ----------
- deep : bool, default: True
- Whether the data array is loaded into memory and copied onto
- the new object. Default is True.
- data : array_like, optional
- Data to use in the new object. Must have same shape as original.
- When `data` is used, `deep` is ignored.
-
- Returns
- -------
- object : Variable
- New object with dimensions, attributes, encodings, and optionally
- data copied from original.
-
- Examples
- --------
- Shallow copy versus deep copy
-
- >>> var = xr.Variable(data=[1, 2, 3], dims="x")
- >>> var.copy()
-        <xarray.Variable (x: 3)>
- array([1, 2, 3])
- >>> var_0 = var.copy(deep=False)
- >>> var_0[0] = 7
- >>> var_0
-        <xarray.Variable (x: 3)>
- array([7, 2, 3])
- >>> var
-        <xarray.Variable (x: 3)>
- array([7, 2, 3])
-
- Changing the data using the ``data`` argument maintains the
- structure of the original object, but with the new data. Original
- object is unaffected.
-
- >>> var.copy(data=[0.1, 0.2, 0.3])
-        <xarray.Variable (x: 3)>
- array([0.1, 0.2, 0.3])
- >>> var
-        <xarray.Variable (x: 3)>
- array([7, 2, 3])
-
- See Also
- --------
- pandas.DataFrame.copy
- """
- return self._copy(deep=deep, data=data)
-
def _copy(
- self: T_Variable,
+ self,
deep: bool = True,
data: T_DuckArray | ArrayLike | None = None,
memo: dict[int, Any] | None = None,
- ) -> T_Variable:
+ ) -> Self:
if data is None:
data_old = self._data
@@ -1099,71 +926,23 @@ def _copy(
return self._replace(data=ndata, attrs=attrs, encoding=encoding)
def _replace(
- self: T_Variable,
+ self,
dims=_default,
data=_default,
attrs=_default,
encoding=_default,
- ) -> T_Variable:
+ ) -> Self:
if dims is _default:
dims = copy.copy(self._dims)
if data is _default:
data = copy.copy(self.data)
if attrs is _default:
attrs = copy.copy(self._attrs)
+
if encoding is _default:
encoding = copy.copy(self._encoding)
return type(self)(dims, data, attrs, encoding, fastpath=True)
- def __copy__(self: T_Variable) -> T_Variable:
- return self._copy(deep=False)
-
- def __deepcopy__(
- self: T_Variable, memo: dict[int, Any] | None = None
- ) -> T_Variable:
- return self._copy(deep=True, memo=memo)
-
- # mutable objects should not be hashable
- # https://github.com/python/mypy/issues/4266
- __hash__ = None # type: ignore[assignment]
-
- @property
- def chunks(self) -> tuple[tuple[int, ...], ...] | None:
- """
- Tuple of block lengths for this dataarray's data, in order of dimensions, or None if
- the underlying data is not a dask array.
-
- See Also
- --------
- Variable.chunk
- Variable.chunksizes
- xarray.unify_chunks
- """
- return getattr(self._data, "chunks", None)
-
- @property
- def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
- """
- Mapping from dimension names to block lengths for this variable's data, or None if
- the underlying data is not a dask array.
- Cannot be modified directly, but can be modified by calling .chunk().
-
- Differs from variable.chunks because it returns a mapping of dimensions to chunk shapes
- instead of a tuple of chunk shapes.
-
- See Also
- --------
- Variable.chunk
- Variable.chunks
- xarray.unify_chunks
- """
- if hasattr(self._data, "chunks"):
- return Frozen({dim: c for dim, c in zip(self.dims, self.data.chunks)})
- else:
- return {}
-
- _array_counter = itertools.count()
-
def chunk(
self,
chunks: (
@@ -1179,7 +958,7 @@ def chunk(
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
**chunks_kwargs: Any,
- ) -> Variable:
+ ) -> Self:
"""Coerce this array's data into a dask array with the given chunks.
If this variable is a non-dask array, it will be converted to dask
@@ -1262,7 +1041,7 @@ def chunk(
data_old = self._data
if chunkmanager.is_chunked_array(data_old):
- data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
+ data_chunked = chunkmanager.rechunk(data_old, chunks)
else:
if isinstance(data_old, indexing.ExplicitlyIndexed):
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
@@ -1284,7 +1063,7 @@ def chunk(
data_chunked = chunkmanager.from_array(
ndata,
- chunks, # type: ignore[arg-type]
+ chunks,
**_from_array_kwargs,
)
@@ -1295,46 +1074,16 @@ def to_numpy(self) -> np.ndarray:
# TODO an entrypoint so array libraries can choose coercion method?
return to_numpy(self._data)
- def as_numpy(self: T_Variable) -> T_Variable:
+ def as_numpy(self) -> Self:
"""Coerces wrapped data into a numpy array, returning a Variable."""
return self._replace(data=self.to_numpy())
- def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
- """
- use sparse-array as backend.
- """
- import sparse
-
- # TODO: what to do if dask-backended?
- if fill_value is dtypes.NA:
- dtype, fill_value = dtypes.maybe_promote(self.dtype)
- else:
- dtype = dtypes.result_type(self.dtype, fill_value)
-
- if sparse_format is _default:
- sparse_format = "coo"
- try:
- as_sparse = getattr(sparse, f"as_{sparse_format.lower()}")
- except AttributeError:
- raise ValueError(f"{sparse_format} is not a valid sparse format")
-
- data = as_sparse(self.data.astype(dtype), fill_value=fill_value)
- return self._replace(data=data)
-
- def _to_dense(self):
- """
- Change backend from sparse to np.array
- """
- if hasattr(self._data, "todense"):
- return self._replace(data=self._data.todense())
- return self.copy(deep=False)
-
def isel(
- self: T_Variable,
+ self,
indexers: Mapping[Any, Any] | None = None,
missing_dims: ErrorOptionsWithWarn = "raise",
**indexers_kwargs: Any,
- ) -> T_Variable:
+ ) -> Self:
"""Return a new array indexed along the specified dimension(s).
Parameters
@@ -1621,7 +1370,7 @@ def transpose(
self,
*dims: Hashable | ellipsis,
missing_dims: ErrorOptionsWithWarn = "raise",
- ) -> Variable:
+ ) -> Self:
"""Return a new Variable object with transposed dimensions.
Parameters
@@ -1666,7 +1415,7 @@ def transpose(
return self._replace(dims=dims, data=data)
@property
- def T(self) -> Variable:
+ def T(self) -> Self:
return self.transpose()
def set_dims(self, dims, shape=None):
@@ -1774,9 +1523,7 @@ def stack(self, dimensions=None, **dimensions_kwargs):
result = result._stack_once(dims, new_dim)
return result
- def _unstack_once_full(
- self, dims: Mapping[Any, int], old_dim: Hashable
- ) -> Variable:
+ def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self:
"""
Unstacks the variable without needing an index.
@@ -1809,7 +1556,9 @@ def _unstack_once_full(
new_data = reordered.data.reshape(new_shape)
new_dims = reordered.dims[: len(other_dims)] + new_dim_names
- return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)
+ return type(self)(
+ new_dims, new_data, self._attrs, self._encoding, fastpath=True
+ )
def _unstack_once(
self,
@@ -1817,7 +1566,7 @@ def _unstack_once(
dim: Hashable,
fill_value=dtypes.NA,
sparse: bool = False,
- ) -> Variable:
+ ) -> Self:
"""
Unstacks this variable given an index to unstack and the name of the
dimension to which the index refers.
@@ -2029,6 +1778,8 @@ def reduce(
keep_attrs = _get_keep_attrs(default=False)
attrs = self._attrs if keep_attrs else None
+ # We need to return `Variable` rather than the type of `self` at the moment, ref
+ # #8216
return Variable(dims, data, attrs=attrs)
@classmethod
@@ -2178,7 +1929,7 @@ def quantile(
keep_attrs: bool | None = None,
skipna: bool | None = None,
interpolation: QuantileMethods | None = None,
- ) -> Variable:
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements.
@@ -2564,7 +2315,7 @@ def coarsen_reshape(self, windows, boundary, side):
else:
shape.append(variable.shape[i])
- return variable.data.reshape(shape), tuple(axes)
+ return duck_array_ops.reshape(variable.data, shape), tuple(axes)
def isnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is a missing value.
@@ -2634,28 +2385,6 @@ def notnull(self, keep_attrs: bool | None = None):
keep_attrs=keep_attrs,
)
- @property
- def real(self):
- """
- The real part of the variable.
-
- See Also
- --------
- numpy.ndarray.real
- """
- return self._replace(data=self.data.real)
-
- @property
- def imag(self):
- """
- The imaginary part of the variable.
-
- See Also
- --------
- numpy.ndarray.imag
- """
- return self._replace(data=self.data.imag)
-
def __array_wrap__(self, obj, context=None):
return Variable(self.dims, obj)
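
Most of the removals above are properties and dask collection hooks that `Variable` now inherits from `NamedArray`. A quick hedged sanity check of that relationship (not part of this diff):

```python
import numpy as np
import xarray as xr
from xarray.namedarray.core import NamedArray

var = xr.Variable(("x", "y"), np.zeros((2, 3)), attrs={"units": "m"})

# dims, shape, sizes, attrs and nbytes are now provided by NamedArray rather
# than being re-implemented on Variable itself.
assert isinstance(var, NamedArray)
assert var.dims == ("x", "y")
assert dict(var.sizes) == {"x": 2, "y": 3}
assert var.attrs == {"units": "m"}
```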
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index 82ffe684ec7..28740a99020 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -11,6 +11,7 @@
from xarray.core.computation import apply_ufunc, dot
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import Dims, T_Xarray
+from xarray.util.deprecation_helpers import _deprecate_positional_args
# Weighted quantile methods are a subset of the numpy supported quantile methods.
QUANTILE_METHODS = Literal[
@@ -324,6 +325,7 @@ def _weighted_quantile(
def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
"""Return the interpolation parameter."""
# Note that options are not yet exposed in the public API.
+ h: np.ndarray
if method == "linear":
h = (n - 1) * q + 1
elif method == "interpolated_inverted_cdf":
@@ -449,18 +451,22 @@ def _weighted_quantile_1d(
def _implementation(self, func, dim, **kwargs):
raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
+ @_deprecate_positional_args("v2023.10.0")
def sum_of_weights(
self,
dim: Dims = None,
+ *,
keep_attrs: bool | None = None,
) -> T_Xarray:
return self._implementation(
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
)
+ @_deprecate_positional_args("v2023.10.0")
def sum_of_squares(
self,
dim: Dims = None,
+ *,
skipna: bool | None = None,
keep_attrs: bool | None = None,
) -> T_Xarray:
@@ -468,9 +474,11 @@ def sum_of_squares(
self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
+ @_deprecate_positional_args("v2023.10.0")
def sum(
self,
dim: Dims = None,
+ *,
skipna: bool | None = None,
keep_attrs: bool | None = None,
) -> T_Xarray:
@@ -478,9 +486,11 @@ def sum(
self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
+ @_deprecate_positional_args("v2023.10.0")
def mean(
self,
dim: Dims = None,
+ *,
skipna: bool | None = None,
keep_attrs: bool | None = None,
) -> T_Xarray:
@@ -488,9 +498,11 @@ def mean(
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
+ @_deprecate_positional_args("v2023.10.0")
def var(
self,
dim: Dims = None,
+ *,
skipna: bool | None = None,
keep_attrs: bool | None = None,
) -> T_Xarray:
@@ -498,9 +510,11 @@ def var(
self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
+ @_deprecate_positional_args("v2023.10.0")
def std(
self,
dim: Dims = None,
+ *,
skipna: bool | None = None,
keep_attrs: bool | None = None,
) -> T_Xarray:
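
A hedged sketch of what the `_deprecate_positional_args` decorator means for callers (not part of this diff): `dim` stays positional, while everything after the new `*` must eventually be passed by keyword.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="time")
weights = xr.DataArray([1.0, 1.0, 2.0, 2.0], dims="time")

# Still fine: dim can be given positionally.
print(da.weighted(weights).mean("time"))

# skipna and keep_attrs are now keyword-only; passing them positionally emits
# a deprecation warning as of v2023.10.0.
print(da.weighted(weights).mean("time", skipna=True, keep_attrs=True))
```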
diff --git a/xarray/namedarray/__init__.py b/xarray/namedarray/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py
new file mode 100644
index 00000000000..ec3d8fa171b
--- /dev/null
+++ b/xarray/namedarray/core.py
@@ -0,0 +1,509 @@
+from __future__ import annotations
+
+import copy
+import math
+from collections.abc import Hashable, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
+
+import numpy as np
+
+# TODO: get rid of this after migrating this class to array API
+from xarray.core import dtypes
+from xarray.core.indexing import ExplicitlyIndexed
+from xarray.namedarray.utils import (
+ Default,
+ T_DuckArray,
+ _default,
+ is_chunked_duck_array,
+ is_duck_array,
+ is_duck_dask_array,
+ to_0d_object_array,
+)
+
+if TYPE_CHECKING:
+ from xarray.namedarray.utils import Self # type: ignore[attr-defined]
+
+ try:
+ from dask.typing import (
+ Graph,
+ NestedKeys,
+ PostComputeCallable,
+ PostPersistCallable,
+ SchedulerGetCallable,
+ )
+ except ImportError:
+ Graph: Any # type: ignore[no-redef]
+ NestedKeys: Any # type: ignore[no-redef]
+ SchedulerGetCallable: Any # type: ignore[no-redef]
+ PostComputeCallable: Any # type: ignore[no-redef]
+ PostPersistCallable: Any # type: ignore[no-redef]
+
+ # T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]")
+ DimsInput = Union[str, Iterable[Hashable]]
+ Dims = tuple[Hashable, ...]
+ AttrsInput = Union[Mapping[Any, Any], None]
+
+
+# TODO: Add tests!
+def as_compatible_data(
+ data: T_DuckArray | np.typing.ArrayLike, fastpath: bool = False
+) -> T_DuckArray:
+ if fastpath and getattr(data, "ndim", 0) > 0:
+ # can't use fastpath (yet) for scalars
+ return cast(T_DuckArray, data)
+
+ if isinstance(data, np.ma.MaskedArray):
+ mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call]
+ if mask.any():
+ # TODO: requires refactoring/vendoring xarray.core.dtypes and xarray.core.duck_array_ops
+ raise NotImplementedError("MaskedArray is not supported yet")
+ else:
+ return cast(T_DuckArray, np.asarray(data))
+ if is_duck_array(data):
+ return data
+ if isinstance(data, NamedArray):
+ return cast(T_DuckArray, data.data)
+
+ if isinstance(data, ExplicitlyIndexed):
+ # TODO: better that is_duck_array(ExplicitlyIndexed) -> True
+ return cast(T_DuckArray, data)
+
+ if isinstance(data, tuple):
+ data = to_0d_object_array(data)
+
+    # validate that the data has a valid data type.
+ return cast(T_DuckArray, np.asarray(data))
+
+
+class NamedArray(Generic[T_DuckArray]):
+
+ """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array.
+ Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names,
+ rather than axis order."""
+
+ __slots__ = ("_data", "_dims", "_attrs")
+
+ _data: T_DuckArray
+ _dims: Dims
+ _attrs: dict[Any, Any] | None
+
+ def __init__(
+ self,
+ dims: DimsInput,
+ data: T_DuckArray | np.typing.ArrayLike,
+ attrs: AttrsInput = None,
+ fastpath: bool = False,
+ ):
+ """
+ Parameters
+ ----------
+ dims : str or iterable of str
+ Name(s) of the dimension(s).
+ data : T_DuckArray or np.typing.ArrayLike
+ The actual data that populates the array. Should match the shape specified by `dims`.
+ attrs : dict, optional
+ A dictionary containing any additional information or attributes you want to store with the array.
+ Default is None, meaning no attributes will be stored.
+ fastpath : bool, optional
+ A flag to indicate if certain validations should be skipped for performance reasons.
+ Should only be True if you are certain about the integrity of the input data.
+ Default is False.
+
+ Raises
+ ------
+ ValueError
+ If the `dims` length does not match the number of data dimensions (ndim).
+
+
+ """
+ self._data = as_compatible_data(data, fastpath=fastpath)
+ self._dims = self._parse_dimensions(dims)
+ self._attrs = dict(attrs) if attrs else None
+
+ @property
+ def ndim(self) -> int:
+ """
+ Number of array dimensions.
+
+ See Also
+ --------
+ numpy.ndarray.ndim
+ """
+ return len(self.shape)
+
+ @property
+ def size(self) -> int:
+ """
+ Number of elements in the array.
+
+ Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions.
+
+ See Also
+ --------
+ numpy.ndarray.size
+ """
+ return math.prod(self.shape)
+
+ def __len__(self) -> int:
+ try:
+ return self.shape[0]
+ except Exception as exc:
+ raise TypeError("len() of unsized object") from exc
+
+ @property
+ def dtype(self) -> np.dtype[Any]:
+ """
+ Data-type of the array’s elements.
+
+ See Also
+ --------
+ ndarray.dtype
+ numpy.dtype
+ """
+ return self._data.dtype
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ """
+
+
+ Returns
+ -------
+ shape : tuple of ints
+ Tuple of array dimensions.
+
+
+
+ See Also
+ --------
+ numpy.ndarray.shape
+ """
+ return self._data.shape
+
+ @property
+ def nbytes(self) -> int:
+ """
+ Total bytes consumed by the elements of the data array.
+
+ If the underlying data array does not include ``nbytes``, estimates
+ the bytes consumed based on the ``size`` and ``dtype``.
+ """
+ if hasattr(self._data, "nbytes"):
+ return self._data.nbytes # type: ignore[no-any-return]
+ else:
+ return self.size * self.dtype.itemsize
+
+ @property
+ def dims(self) -> Dims:
+ """Tuple of dimension names with which this NamedArray is associated."""
+ return self._dims
+
+ @dims.setter
+ def dims(self, value: DimsInput) -> None:
+ self._dims = self._parse_dimensions(value)
+
+ def _parse_dimensions(self, dims: DimsInput) -> Dims:
+ dims = (dims,) if isinstance(dims, str) else tuple(dims)
+ if len(dims) != self.ndim:
+ raise ValueError(
+ f"dimensions {dims} must have the same length as the "
+ f"number of data dimensions, ndim={self.ndim}"
+ )
+ return dims
+
+ @property
+ def attrs(self) -> dict[Any, Any]:
+ """Dictionary of local attributes on this NamedArray."""
+ if self._attrs is None:
+ self._attrs = {}
+ return self._attrs
+
+ @attrs.setter
+ def attrs(self, value: Mapping[Any, Any]) -> None:
+ self._attrs = dict(value)
+
+ def _check_shape(self, new_data: T_DuckArray) -> None:
+ if new_data.shape != self.shape:
+ raise ValueError(
+ f"replacement data must match the {self.__class__.__name__}'s shape. "
+ f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}"
+ )
+
+ @property
+ def data(self) -> T_DuckArray:
+ """
+ The NamedArray's data as an array. The underlying array type
+ (e.g. dask, sparse, pint) is preserved.
+
+ """
+
+ return self._data
+
+ @data.setter
+ def data(self, data: T_DuckArray | np.typing.ArrayLike) -> None:
+ data = as_compatible_data(data)
+ self._check_shape(data)
+ self._data = data
+
+ @property
+ def real(self) -> Self:
+ """
+ The real part of the NamedArray.
+
+ See Also
+ --------
+ numpy.ndarray.real
+ """
+ return self._replace(data=self.data.real)
+
+ @property
+ def imag(self) -> Self:
+ """
+ The imaginary part of the NamedArray.
+
+ See Also
+ --------
+ numpy.ndarray.imag
+ """
+ return self._replace(data=self.data.imag)
+
+ def __dask_tokenize__(self) -> Hashable:
+ # Use v.data, instead of v._data, in order to cope with the wrappers
+ # around NetCDF and the like
+ from dask.base import normalize_token
+
+ s, d, a, attrs = type(self), self._dims, self.data, self.attrs
+ return normalize_token((s, d, a, attrs)) # type: ignore[no-any-return]
+
+ def __dask_graph__(self) -> Graph | None:
+ if is_duck_dask_array(self._data):
+ return self._data.__dask_graph__()
+ else:
+ # TODO: Should this method just raise instead?
+ # raise NotImplementedError("Method requires self.data to be a dask array")
+ return None
+
+ def __dask_keys__(self) -> NestedKeys:
+ if is_duck_dask_array(self._data):
+ return self._data.__dask_keys__()
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ def __dask_layers__(self) -> Sequence[str]:
+ if is_duck_dask_array(self._data):
+ return self._data.__dask_layers__()
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ @property
+ def __dask_optimize__(
+ self,
+ ) -> Callable[..., dict[Any, Any]]:
+ if is_duck_dask_array(self._data):
+ return self._data.__dask_optimize__ # type: ignore[no-any-return]
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ @property
+ def __dask_scheduler__(self) -> SchedulerGetCallable:
+ if is_duck_dask_array(self._data):
+ return self._data.__dask_scheduler__
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ def __dask_postcompute__(
+ self,
+ ) -> tuple[PostComputeCallable, tuple[Any, ...]]:
+ if is_duck_dask_array(self._data):
+ array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call]
+ return self._dask_finalize, (array_func,) + array_args
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ def __dask_postpersist__(
+ self,
+ ) -> tuple[
+ Callable[
+ [Graph, PostPersistCallable[Any], Any, Any],
+ Self,
+ ],
+ tuple[Any, ...],
+ ]:
+ if is_duck_dask_array(self._data):
+ a: tuple[PostPersistCallable[Any], tuple[Any, ...]]
+ a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call]
+ array_func, array_args = a
+
+ return self._dask_finalize, (array_func,) + array_args
+ else:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ def _dask_finalize(
+ self,
+ results: Graph,
+ array_func: PostPersistCallable[Any],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Self:
+ data = array_func(results, *args, **kwargs)
+ return type(self)(self._dims, data, attrs=self._attrs)
+
+ @property
+ def chunks(self) -> tuple[tuple[int, ...], ...] | None:
+ """
+ Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if
+ the underlying data is not a dask array.
+
+ See Also
+ --------
+ NamedArray.chunk
+ NamedArray.chunksizes
+ xarray.unify_chunks
+ """
+ data = self._data
+ if is_chunked_duck_array(data):
+ return data.chunks
+ else:
+ return None
+
+ @property
+ def chunksizes(
+ self,
+ ) -> Mapping[Any, tuple[int, ...]]:
+ """
+        Mapping from dimension names to block lengths for this NamedArray's data, or None if
+ the underlying data is not a dask array.
+ Cannot be modified directly, but can be modified by calling .chunk().
+
+ Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes
+ instead of a tuple of chunk shapes.
+
+ See Also
+ --------
+ NamedArray.chunk
+ NamedArray.chunks
+ xarray.unify_chunks
+ """
+ data = self._data
+ if is_chunked_duck_array(data):
+ return dict(zip(self.dims, data.chunks))
+ else:
+ return {}
+
+ @property
+ def sizes(self) -> dict[Hashable, int]:
+ """Ordered mapping from dimension names to lengths."""
+ return dict(zip(self.dims, self.shape))
+
+ def _replace(
+ self,
+ dims: DimsInput | Default = _default,
+ data: T_DuckArray | np.typing.ArrayLike | Default = _default,
+ attrs: AttrsInput | Default = _default,
+ ) -> Self:
+ if dims is _default:
+ dims = copy.copy(self._dims)
+ if data is _default:
+ data = copy.copy(self._data)
+ if attrs is _default:
+ attrs = copy.copy(self._attrs)
+ return type(self)(dims, data, attrs)
+
+ def _copy(
+ self,
+ deep: bool = True,
+ data: T_DuckArray | np.typing.ArrayLike | None = None,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
+ if data is None:
+ ndata = self._data
+ if deep:
+ ndata = copy.deepcopy(ndata, memo=memo)
+ else:
+ ndata = as_compatible_data(data)
+ self._check_shape(ndata)
+
+ attrs = (
+ copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs)
+ )
+
+ return self._replace(data=ndata, attrs=attrs)
+
+ def __copy__(self) -> Self:
+ return self._copy(deep=False)
+
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
+ return self._copy(deep=True, memo=memo)
+
+ def copy(
+ self,
+ deep: bool = True,
+ data: T_DuckArray | np.typing.ArrayLike | None = None,
+ ) -> Self:
+ """Returns a copy of this object.
+
+ If `deep=True`, the data array is loaded into memory and copied onto
+ the new object. Dimensions, attributes and encodings are always copied.
+
+ Use `data` to create a new object with the same structure as
+ original but entirely new data.
+
+ Parameters
+ ----------
+ deep : bool, default: True
+ Whether the data array is loaded into memory and copied onto
+ the new object. Default is True.
+ data : array_like, optional
+ Data to use in the new object. Must have same shape as original.
+ When `data` is used, `deep` is ignored.
+
+ Returns
+ -------
+ object : NamedArray
+ New object with dimensions, attributes, and optionally
+ data copied from original.
+
+
+ """
+ return self._copy(deep=deep, data=data)
+
+ def _nonzero(self) -> tuple[Self, ...]:
+ """Equivalent numpy's nonzero but returns a tuple of NamedArrays."""
+ # TODO we should replace dask's native nonzero
+ # after https://github.com/dask/dask/issues/1076 is implemented.
+ nonzeros = np.nonzero(self.data)
+ return tuple(type(self)((dim,), nz) for nz, dim in zip(nonzeros, self.dims))
+
+ def _as_sparse(
+ self,
+ sparse_format: str | Default = _default,
+ fill_value: np.typing.ArrayLike | Default = _default,
+ ) -> Self:
+ """
+ use sparse-array as backend.
+ """
+ import sparse
+
+ # TODO: what to do if dask-backended?
+ if fill_value is _default:
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+ else:
+ dtype = dtypes.result_type(self.dtype, fill_value)
+
+ if sparse_format is _default:
+ sparse_format = "coo"
+ try:
+ as_sparse = getattr(sparse, f"as_{sparse_format.lower()}")
+ except AttributeError as exc:
+ raise ValueError(f"{sparse_format} is not a valid sparse format") from exc
+
+ data = as_sparse(self.data.astype(dtype), fill_value=fill_value)
+ return self._replace(data=data)
+
+ def _to_dense(self) -> Self:
+ """
+ Change backend from sparse to np.array
+ """
+ if hasattr(self._data, "todense"):
+ return self._replace(data=self._data.todense())
+ return self.copy(deep=False)
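
A hedged sketch (not part of this diff) of constructing the new class directly, exercising the behaviour defined above:

```python
import numpy as np
from xarray.namedarray.core import NamedArray

arr = NamedArray(("x", "y"), np.arange(6).reshape(2, 3), attrs={"long_name": "demo"})

assert arr.dims == ("x", "y")
assert arr.sizes == {"x": 2, "y": 3}
assert arr.nbytes == arr.size * arr.dtype.itemsize

# The data setter routes through as_compatible_data and _check_shape, so a
# shape mismatch raises ValueError.
try:
    arr.data = np.zeros((3, 3))
except ValueError as err:
    print(err)
```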
diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py
new file mode 100644
index 00000000000..7a83bd17064
--- /dev/null
+++ b/xarray/namedarray/dtypes.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import functools
+import sys
+from typing import Any, Literal
+
+if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+else:
+ from typing_extensions import TypeGuard
+
+import numpy as np
+
+from xarray.namedarray import utils
+
+# Use as a sentinel value to indicate a dtype appropriate NA value.
+NA = utils.ReprObject("")
+
+
+@functools.total_ordering
+class AlwaysGreaterThan:
+ def __gt__(self, other: Any) -> Literal[True]:
+ return True
+
+ def __eq__(self, other: Any) -> bool:
+ return isinstance(other, type(self))
+
+
+@functools.total_ordering
+class AlwaysLessThan:
+ def __lt__(self, other: Any) -> Literal[True]:
+ return True
+
+ def __eq__(self, other: Any) -> bool:
+ return isinstance(other, type(self))
+
+
+# Equivalence to np.inf (-np.inf) for object-type
+INF = AlwaysGreaterThan()
+NINF = AlwaysLessThan()
+
+
+# Pairs of types that, if both found, should be promoted to object dtype
+# instead of following NumPy's own type-promotion rules. These type promotion
+# rules match pandas instead. For reference, see the NumPy type hierarchy:
+# https://numpy.org/doc/stable/reference/arrays.scalars.html
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+ (np.number, np.character), # numpy promotes to character
+ (np.bool_, np.character), # numpy promotes to character
+ (np.bytes_, np.str_), # numpy promotes to unicode
+)
+
+
+def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]:
+ """Simpler equivalent of pandas.core.common._maybe_promote
+
+ Parameters
+ ----------
+ dtype : np.dtype
+
+ Returns
+ -------
+ dtype : Promoted dtype that can hold missing values.
+ fill_value : Valid missing value for the promoted dtype.
+ """
+ # N.B. these casting rules should match pandas
+ dtype_: np.typing.DTypeLike
+ fill_value: Any
+ if np.issubdtype(dtype, np.floating):
+ dtype_ = dtype
+ fill_value = np.nan
+ elif np.issubdtype(dtype, np.timedelta64):
+ # See https://github.com/numpy/numpy/issues/10685
+ # np.timedelta64 is a subclass of np.integer
+ # Check np.timedelta64 before np.integer
+ fill_value = np.timedelta64("NaT")
+ dtype_ = dtype
+ elif np.issubdtype(dtype, np.integer):
+ dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
+ fill_value = np.nan
+ elif np.issubdtype(dtype, np.complexfloating):
+ dtype_ = dtype
+ fill_value = np.nan + np.nan * 1j
+ elif np.issubdtype(dtype, np.datetime64):
+ dtype_ = dtype
+ fill_value = np.datetime64("NaT")
+ else:
+ dtype_ = object
+ fill_value = np.nan
+
+ dtype_out = np.dtype(dtype_)
+ fill_value = dtype_out.type(fill_value)
+ return dtype_out, fill_value
+
+
+NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
+
+
+def get_fill_value(dtype: np.dtype[np.generic]) -> Any:
+ """Return an appropriate fill value for this dtype.
+
+ Parameters
+ ----------
+ dtype : np.dtype
+
+ Returns
+ -------
+ fill_value : Missing value corresponding to this dtype.
+ """
+ _, fill_value = maybe_promote(dtype)
+ return fill_value
+
+
+def get_pos_infinity(
+ dtype: np.dtype[np.generic], max_for_int: bool = False
+) -> float | complex | AlwaysGreaterThan:
+ """Return an appropriate positive infinity for this dtype.
+
+ Parameters
+ ----------
+ dtype : np.dtype
+ max_for_int : bool
+ Return np.iinfo(dtype).max instead of np.inf
+
+ Returns
+ -------
+ fill_value : positive infinity value corresponding to this dtype.
+ """
+ if issubclass(dtype.type, np.floating):
+ return np.inf
+
+ if issubclass(dtype.type, np.integer):
+ return np.iinfo(dtype.type).max if max_for_int else np.inf
+ if issubclass(dtype.type, np.complexfloating):
+ return np.inf + 1j * np.inf
+
+ return INF
+
+
+def get_neg_infinity(
+ dtype: np.dtype[np.generic], min_for_int: bool = False
+) -> float | complex | AlwaysLessThan:
+ """Return an appropriate positive infinity for this dtype.
+
+ Parameters
+ ----------
+ dtype : np.dtype
+ min_for_int : bool
+ Return np.iinfo(dtype).min instead of -np.inf
+
+ Returns
+ -------
+    fill_value : negative infinity value corresponding to this dtype.
+ """
+ if issubclass(dtype.type, np.floating):
+ return -np.inf
+
+ if issubclass(dtype.type, np.integer):
+ return np.iinfo(dtype.type).min if min_for_int else -np.inf
+ if issubclass(dtype.type, np.complexfloating):
+ return -np.inf - 1j * np.inf
+
+ return NINF
+
+
+def is_datetime_like(
+ dtype: np.dtype[np.generic],
+) -> TypeGuard[np.datetime64 | np.timedelta64]:
+ """Check if a dtype is a subclass of the numpy datetime types"""
+ return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
+
+
+def result_type(
+ *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+) -> np.dtype[np.generic]:
+ """Like np.result_type, but with type promotion rules matching pandas.
+
+ Examples of changed behavior:
+ number + string -> object (not string)
+ bytes + unicode -> object (not unicode)
+
+ Parameters
+ ----------
+ *arrays_and_dtypes : list of arrays and dtypes
+ The dtype is extracted from both numpy and dask arrays.
+
+ Returns
+ -------
+ numpy.dtype for the result.
+ """
+ types = {np.result_type(t).type for t in arrays_and_dtypes}
+
+ for left, right in PROMOTE_TO_OBJECT:
+ if any(issubclass(t, left) for t in types) and any(
+ issubclass(t, right) for t in types
+ ):
+ return np.dtype(object)
+
+ return np.result_type(*arrays_and_dtypes)
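
A hedged example (not part of this diff) of the pandas-style promotion rules implemented above:

```python
import numpy as np
from xarray.namedarray import dtypes

# Mixing numbers and strings promotes to object dtype, matching pandas,
# where plain np.result_type would promote to a string dtype.
assert dtypes.result_type(np.array([1, 2]), np.array(["a"])) == np.dtype(object)

# maybe_promote returns a dtype that can hold missing values plus a fill value.
promoted, fill = dtypes.maybe_promote(np.dtype("int16"))
assert promoted == np.dtype("float32") and np.isnan(fill)

# NaT is the fill value for datetime-like dtypes.
print(dtypes.get_fill_value(np.dtype("datetime64[ns]")))
```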
diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py
new file mode 100644
index 00000000000..6f7658ea00b
--- /dev/null
+++ b/xarray/namedarray/utils.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+import importlib
+import sys
+from collections.abc import Hashable
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Final, Protocol, TypeVar
+
+import numpy as np
+
+if TYPE_CHECKING:
+ if sys.version_info >= (3, 10):
+ from typing import TypeGuard
+ else:
+ from typing_extensions import TypeGuard
+
+ if sys.version_info >= (3, 11):
+ from typing import Self
+ else:
+ from typing_extensions import Self
+
+ try:
+ from dask.array import Array as DaskArray
+ from dask.types import DaskCollection
+ except ImportError:
+ DaskArray = np.ndarray # type: ignore
+ DaskCollection: Any = np.ndarray # type: ignore
+
+
+# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
+T_DType_co = TypeVar("T_DType_co", bound=np.dtype[np.generic], covariant=True)
+# T_DType = TypeVar("T_DType", bound=np.dtype[np.generic])
+
+
+class _Array(Protocol[T_DType_co]):
+ @property
+ def dtype(self) -> T_DType_co:
+ ...
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ ...
+
+ @property
+ def real(self) -> Self:
+ ...
+
+ @property
+ def imag(self) -> Self:
+ ...
+
+ def astype(self, dtype: np.typing.DTypeLike) -> Self:
+ ...
+
+ # TODO: numpy doesn't use any inputs:
+ # https://github.com/numpy/numpy/blob/v1.24.3/numpy/_typing/_array_like.py#L38
+ def __array__(self) -> np.ndarray[Any, T_DType_co]:
+ ...
+
+
+class _ChunkedArray(_Array[T_DType_co], Protocol[T_DType_co]):
+ @property
+ def chunks(self) -> tuple[tuple[int, ...], ...]:
+ ...
+
+
+# temporary placeholder for indicating an array api compliant type.
+# hopefully in the future we can narrow this down more
+T_DuckArray = TypeVar("T_DuckArray", bound=_Array[np.dtype[np.generic]])
+T_ChunkedArray = TypeVar("T_ChunkedArray", bound=_ChunkedArray[np.dtype[np.generic]])
+
+
+# Singleton type, as per https://github.com/python/typing/pull/240
+class Default(Enum):
+ token: Final = 0
+
+
+_default = Default.token
+
+
+def module_available(module: str) -> bool:
+ """Checks whether a module is installed without importing it.
+
+ Use this for a lightweight check and lazy imports.
+
+ Parameters
+ ----------
+ module : str
+ Name of the module.
+
+ Returns
+ -------
+ available : bool
+ Whether the module is installed.
+ """
+ return importlib.util.find_spec(module) is not None
+
+
+def is_dask_collection(x: object) -> TypeGuard[DaskCollection]:
+ if module_available("dask"):
+ from dask.typing import DaskCollection
+
+ return isinstance(x, DaskCollection)
+ return False
+
+
+def is_duck_array(value: object) -> TypeGuard[T_DuckArray]:
+ if isinstance(value, np.ndarray):
+ return True
+ return (
+ hasattr(value, "ndim")
+ and hasattr(value, "shape")
+ and hasattr(value, "dtype")
+ and (
+ (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
+ or hasattr(value, "__array_namespace__")
+ )
+ )
+
+
+def is_duck_dask_array(x: T_DuckArray) -> TypeGuard[DaskArray]:
+ return is_dask_collection(x)
+
+
+def is_chunked_duck_array(
+ x: T_DuckArray,
+) -> TypeGuard[_ChunkedArray[np.dtype[np.generic]]]:
+ return hasattr(x, "chunks")
+
+
+def to_0d_object_array(
+ value: object,
+) -> np.ndarray[Any, np.dtype[np.object_]]:
+ """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object."""
+ result = np.empty((), dtype=object)
+ result[()] = value
+ return result
+
+
+class ReprObject:
+ """Object that prints as the given value, for use with sentinel values."""
+
+ __slots__ = ("_value",)
+
+ _value: str
+
+ def __init__(self, value: str):
+ self._value = value
+
+ def __repr__(self) -> str:
+ return self._value
+
+ def __eq__(self, other: ReprObject | Any) -> bool:
+ # TODO: What type can other be? ArrayLike?
+ return self._value == other._value if isinstance(other, ReprObject) else False
+
+ def __hash__(self) -> int:
+ return hash((type(self), self._value))
+
+ def __dask_tokenize__(self) -> Hashable:
+ from dask.base import normalize_token
+
+ return normalize_token((type(self), self._value)) # type: ignore[no-any-return]
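
A hedged look (not part of this diff) at the duck-typing helpers defined above:

```python
import numpy as np
from xarray.namedarray.utils import ReprObject, is_duck_array, to_0d_object_array

assert is_duck_array(np.ones((2, 2)))
assert not is_duck_array([1, 2, 3])  # plain lists lack ndim/shape/dtype

# Tuples and other scalars are wrapped as 0-d object arrays before storage.
wrapped = to_0d_object_array(("a", "b"))
assert wrapped.shape == () and wrapped.dtype == object

# ReprObject is a hashable sentinel that prints as the given value.
sentinel = ReprObject("<NA>")
assert repr(sentinel) == "<NA>"
```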
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py
index 8e930d0731c..61f2014fbc3 100644
--- a/xarray/plot/dataarray_plot.py
+++ b/xarray/plot/dataarray_plot.py
@@ -1348,7 +1348,7 @@ def _plot2d(plotfunc):
         `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_.
Note: if ``cmap`` is a seaborn color palette and the plot type
is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified.
- center : float, optional
+ center : float or False, optional
The value at which to center the colormap. Passing this value implies
use of a diverging colormap. Setting it to ``False`` prevents use of a
diverging colormap.
@@ -1432,7 +1432,7 @@ def newplotfunc(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1692,7 +1692,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1733,7 +1733,7 @@ def imshow(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1774,7 +1774,7 @@ def imshow(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1911,7 +1911,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1952,7 +1952,7 @@ def contour(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -1993,7 +1993,7 @@ def contour(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2047,7 +2047,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2088,7 +2088,7 @@ def contourf(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2129,7 +2129,7 @@ def contourf(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2183,7 +2183,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2224,7 +2224,7 @@ def pcolormesh(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2265,7 +2265,7 @@ def pcolormesh(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2370,7 +2370,7 @@ def surface(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2411,7 +2411,7 @@ def surface(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
@@ -2452,7 +2452,7 @@ def surface(
vmin: float | None = None,
vmax: float | None = None,
cmap: str | Colormap | None = None,
- center: float | None = None,
+ center: float | Literal[False] | None = None,
robust: bool = False,
extend: ExtendOptions = None,
levels: ArrayLike | None = None,
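
A hedged example (not part of this diff, and assuming matplotlib is installed) of the call pattern the widened `center` annotation describes:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.default_rng(0).normal(size=(10, 10)), dims=("y", "x")
)

# center=0 implies a diverging colormap around zero ...
da.plot.pcolormesh(center=0)

# ... while center=False (now reflected in the annotations above) explicitly
# opts out of the diverging-colormap heuristic.
da.plot.pcolormesh(center=False)
```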
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 5bd517098f1..9ec67bf47dc 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -714,9 +714,6 @@ def multiple_indexing(indexers):
]
multiple_indexing(indexers5)
- @pytest.mark.xfail(
- reason="zarr without dask handles negative steps in slices incorrectly",
- )
def test_vectorized_indexing_negative_step(self) -> None:
# use dask explicitly when present
open_kwargs: dict[str, Any] | None
@@ -1842,8 +1839,8 @@ def test_unsorted_index_raises(self) -> None:
# dask first pulls items by block.
pass
+ @pytest.mark.skip(reason="caching behavior differs for dask")
def test_dataset_caching(self) -> None:
- # caching behavior differs for dask
pass
def test_write_inconsistent_chunks(self) -> None:
@@ -2261,9 +2258,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None:
# not relevant for zarr, since we don't use EncodedStringCoder
pass
- # TODO: someone who understand caching figure out whether caching
- # makes sense for Zarr backend
- @pytest.mark.xfail(reason="Zarr caching not implemented")
def test_dataset_caching(self) -> None:
super().test_dataset_caching()
@@ -2712,6 +2706,14 @@ def test_attributes(self, obj) -> None:
with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
ds.to_zarr(store_target, **self.version_kwargs)
+ def test_vectorized_indexing_negative_step(self) -> None:
+ if not has_dask:
+ pytest.xfail(
+ reason="zarr without dask handles negative steps in slices incorrectly"
+ )
+
+ super().test_vectorized_indexing_negative_step()
+
@requires_zarr
class TestZarrDictStore(ZarrBase):
@@ -3378,8 +3380,8 @@ def roundtrip(
) as ds:
yield ds
+ @pytest.mark.skip(reason="caching behavior differs for dask")
def test_dataset_caching(self) -> None:
- # caching behavior differs for dask
pass
def test_write_inconsistent_chunks(self) -> None:
@@ -3457,6 +3459,7 @@ def skip_if_not_engine(engine):
@requires_dask
@pytest.mark.filterwarnings("ignore:use make_scale(name) instead")
+@pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this")
def test_open_mfdataset_manyfiles(
readengine, nfiles, parallel, chunks, file_cache_maxsize
):
@@ -3982,7 +3985,6 @@ def test_open_mfdataset_raise_on_bad_combine_args(self) -> None:
with pytest.raises(ValueError, match="`concat_dim` has no effect"):
open_mfdataset([tmp1, tmp2], concat_dim="x")
- @pytest.mark.xfail(reason="mfdataset loses encoding currently.")
def test_encoding_mfdataset(self) -> None:
original = Dataset(
{
@@ -4195,7 +4197,6 @@ def test_dataarray_compute(self) -> None:
assert computed._in_memory
assert_allclose(actual, computed, decode_bytes=False)
- @pytest.mark.xfail
def test_save_mfdataset_compute_false_roundtrip(self) -> None:
from dask.delayed import Delayed
@@ -5125,15 +5126,17 @@ def test_open_fsspec() -> None:
ds2 = open_dataset(url, engine="zarr")
xr.testing.assert_equal(ds0, ds2)
- # multi dataset
- url = "memory://out*.zarr"
- ds2 = open_mfdataset(url, engine="zarr")
- xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
-
- # multi dataset with caching
- url = "simplecache::memory://out*.zarr"
- ds2 = open_mfdataset(url, engine="zarr")
- xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
+ # open_mfdataset requires dask
+ if has_dask:
+ # multi dataset
+ url = "memory://out*.zarr"
+ ds2 = open_mfdataset(url, engine="zarr")
+ xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
+
+ # multi dataset with caching
+ url = "simplecache::memory://out*.zarr"
+ ds2 = open_mfdataset(url, engine="zarr")
+ xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2)
@requires_h5netcdf
diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py
index f58a6490632..1a1df6b81fe 100644
--- a/xarray/tests/test_cftimeindex.py
+++ b/xarray/tests/test_cftimeindex.py
@@ -1135,7 +1135,6 @@ def test_to_datetimeindex_feb_29(calendar):
@requires_cftime
-@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263")
def test_multiindex():
index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day")
mindex = pd.MultiIndex.from_arrays([index])
diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py
index e345ae691ec..01d5393e289 100644
--- a/xarray/tests/test_coarsen.py
+++ b/xarray/tests/test_coarsen.py
@@ -6,6 +6,7 @@
import xarray as xr
from xarray import DataArray, Dataset, set_options
+from xarray.core import duck_array_ops
from xarray.tests import (
assert_allclose,
assert_equal,
@@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None:
expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (
("year", "month"),
- ds.vart.data.reshape((-1, 12)),
+ duck_array_ops.reshape(ds.vart.data, (-1, 12)),
{"a": "b"},
)
expected["varx"] = (
("x", "x_reshaped"),
- ds.varx.data.reshape((-1, 5)),
+ duck_array_ops.reshape(ds.varx.data, (-1, 5)),
{"a": "b"},
)
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
- ds.vartx.data.reshape(2, 5, 4, 12),
+ duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)),
{"a": "b"},
)
expected["vary"] = ds.vary
- expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
+ expected.coords["time"] = (
+ ("year", "month"),
+ duck_array_ops.reshape(ds.time.data, (-1, 12)),
+ )
with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py
index 079e432b565..423e48bd155 100644
--- a/xarray/tests/test_coding_times.py
+++ b/xarray/tests/test_coding_times.py
@@ -30,7 +30,7 @@
from xarray.coding.variables import SerializationWarning
from xarray.conventions import _update_bounds_attributes, cf_encoder
from xarray.core.common import contains_cftime_datetimes
-from xarray.testing import assert_allclose, assert_equal, assert_identical
+from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
FirstElementAccessibleArray,
arm_xfail,
@@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(
pytest.skip("Nanosecond frequency is not valid for cftime dates.")
times = date_range("2000", periods=3, freq=freq)
units = f"{encoding_units} since 2000-01-01"
- encoded, _, _ = coding.times.encode_cf_datetime(times, units)
+ encoded, _units, _ = coding.times.encode_cf_datetime(times, units)
numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units)
encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit)
@@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None:
("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False),
("1677-09-21T00:12:43.145225", "us", np.int64, None, False),
("1970-01-01T00:00:01.000001", "us", np.int64, None, False),
+ ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True),
],
)
def test_roundtrip_datetime64_nanosecond_precision(
@@ -1261,14 +1262,52 @@ def test_roundtrip_datetime64_nanosecond_precision_warning() -> None:
]
units = "days since 1970-01-10T01:01:00"
needed_units = "hours"
- encoding = dict(_FillValue=20, units=units)
+ new_units = f"{needed_units} since 1970-01-10T01:01:00"
+
+ encoding = dict(dtype=None, _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
- wmsg = (
- f"Times can't be serialized faithfully with requested units {units!r}. "
- f"Resolution of {needed_units!r} needed. "
- )
- with pytest.warns(UserWarning, match=wmsg):
+ with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."):
+ encoded_var = conventions.encode_cf_variable(var)
+ assert encoded_var.dtype == np.float64
+ assert encoded_var.attrs["units"] == units
+ assert encoded_var.attrs["_FillValue"] == 20.0
+
+ decoded_var = conventions.decode_cf_variable("foo", encoded_var)
+ assert_identical(var, decoded_var)
+
+ encoding = dict(dtype="int64", _FillValue=20, units=units)
+ var = Variable(["time"], times, encoding=encoding)
+ with pytest.warns(
+ UserWarning, match=f"Serializing with units {new_units!r} instead."
+ ):
encoded_var = conventions.encode_cf_variable(var)
+ assert encoded_var.dtype == np.int64
+ assert encoded_var.attrs["units"] == new_units
+ assert encoded_var.attrs["_FillValue"] == 20
+
+ decoded_var = conventions.decode_cf_variable("foo", encoded_var)
+ assert_identical(var, decoded_var)
+
+ encoding = dict(dtype="float64", _FillValue=20, units=units)
+ var = Variable(["time"], times, encoding=encoding)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ encoded_var = conventions.encode_cf_variable(var)
+ assert encoded_var.dtype == np.float64
+ assert encoded_var.attrs["units"] == units
+ assert encoded_var.attrs["_FillValue"] == 20.0
+
+ decoded_var = conventions.decode_cf_variable("foo", encoded_var)
+ assert_identical(var, decoded_var)
+
+ encoding = dict(dtype="int64", _FillValue=20, units=new_units)
+ var = Variable(["time"], times, encoding=encoding)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ encoded_var = conventions.encode_cf_variable(var)
+ assert encoded_var.dtype == np.int64
+ assert encoded_var.attrs["units"] == new_units
+ assert encoded_var.attrs["_FillValue"] == 20
decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)
@@ -1309,21 +1348,30 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None:
needed_units = "hours"
wmsg = (
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
- f"Resolution of {needed_units!r} needed. "
+ f"Serializing with units {needed_units!r} instead."
)
- encoding = dict(_FillValue=20, units=units)
+ encoding = dict(dtype=np.int64, _FillValue=20, units=units)
var = Variable(["time"], timedelta_values, encoding=encoding)
with pytest.warns(UserWarning, match=wmsg):
encoded_var = conventions.encode_cf_variable(var)
+ assert encoded_var.dtype == np.int64
+ assert encoded_var.attrs["units"] == needed_units
+ assert encoded_var.attrs["_FillValue"] == 20
decoded_var = conventions.decode_cf_variable("foo", encoded_var)
- assert_allclose(var, decoded_var)
+ assert_identical(var, decoded_var)
+ assert decoded_var.encoding["dtype"] == np.int64
def test_roundtrip_float_times() -> None:
+ # Regression test for GitHub issue #8271
fill_value = 20.0
- times = [np.datetime64("2000-01-01 12:00:00", "ns"), np.datetime64("NaT", "ns")]
+ times = [
+ np.datetime64("1970-01-01 00:00:00", "ns"),
+ np.datetime64("1970-01-01 06:00:00", "ns"),
+ np.datetime64("NaT", "ns"),
+ ]
- units = "days since 2000-01-01"
+ units = "days since 1960-01-01"
var = Variable(
["time"],
times,
@@ -1331,7 +1379,7 @@ def test_roundtrip_float_times() -> None:
)
encoded_var = conventions.encode_cf_variable(var)
- np.testing.assert_array_equal(encoded_var, np.array([0.5, 20.0]))
+ np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0]))
assert encoded_var.attrs["units"] == units
assert encoded_var.attrs["_FillValue"] == fill_value
diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py
index b75e80db2da..87f8328e441 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -1190,7 +1190,7 @@ def test_apply_dask() -> None:
# unknown setting for dask array handling
with pytest.raises(ValueError):
- apply_ufunc(identity, array, dask="unknown")
+ apply_ufunc(identity, array, dask="unknown") # type: ignore
def dask_safe_identity(x):
return apply_ufunc(identity, x, dask="allowed")
diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py
index 4dae7809be9..5157688b629 100644
--- a/xarray/tests/test_conventions.py
+++ b/xarray/tests/test_conventions.py
@@ -80,6 +80,28 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None:
assert_identical(actual, expected)
+def test_decode_cf_variable_with_mismatched_coordinates() -> None:
+ # tests for decoding mismatched coordinates attributes
+ # see GH #1809
+ zeros1 = np.zeros((1, 5, 3))
+ orig = Dataset(
+ {
+ "XLONG": (["x", "y"], zeros1.squeeze(0), {}),
+ "XLAT": (["x", "y"], zeros1.squeeze(0), {}),
+ "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}),
+ "time": ("time", [0.0], {"units": "hours since 2017-01-01"}),
+ }
+ )
+ decoded = conventions.decode_cf(orig, decode_coords=True)
+ assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT"
+ assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"]
+
+ decoded = conventions.decode_cf(orig, decode_coords=False)
+ assert "coordinates" not in decoded["foo"].encoding
+ assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT"
+ assert list(decoded.coords.keys()) == ["time"]
+
+
@requires_cftime
class TestEncodeCFVariable:
def test_incompatible_attributes(self) -> None:
@@ -246,9 +268,12 @@ def test_dataset(self) -> None:
assert_identical(expected, actual)
def test_invalid_coordinates(self) -> None:
- # regression test for GH308
+ # regression test for GH308, GH1809
original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})})
+ decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})})
actual = conventions.decode_cf(original)
+ assert_identical(decoded, actual)
+ actual = conventions.decode_cf(original, decode_coords=False)
assert_identical(original, actual)
def test_decode_coordinates(self) -> None:
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 66bc69966d2..5eb5394d58e 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -38,6 +38,7 @@
from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords
from xarray.core.types import QueryEngineOptions, QueryParserOptions
from xarray.core.utils import is_scalar
+from xarray.testing import _assert_internal_invariants
from xarray.tests import (
InaccessibleArray,
ReturnItem,
@@ -286,7 +287,7 @@ def test_encoding(self) -> None:
self.dv.encoding = expected2
assert expected2 is not self.dv.encoding
- def test_reset_encoding(self) -> None:
+ def test_drop_encoding(self) -> None:
array = self.mda
encoding = {"scale_factor": 10}
array.encoding = encoding
@@ -295,7 +296,7 @@ def test_reset_encoding(self) -> None:
assert array.encoding == encoding
assert array["x"].encoding == encoding
- actual = array.reset_encoding()
+ actual = array.drop_encoding()
# did not modify in place
assert array.encoding == encoding
@@ -415,9 +416,6 @@ def test_constructor_invalid(self) -> None:
with pytest.raises(ValueError, match=r"conflicting MultiIndex"):
DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))])
- with pytest.raises(ValueError, match=r"matching the dimension size"):
- DataArray(data, coords={"x": 0}, dims=["x", "y"])
-
def test_constructor_from_self_described(self) -> None:
data = [[-0.1, 21], [0, 2]]
expected = DataArray(
@@ -1885,6 +1883,16 @@ def test_rename_dimension_coord_warnings(self) -> None:
):
da.rename(x="y")
+ # Renaming to the same name (a no-op) should not raise a warning
+ da = xr.DataArray(
+ data=np.ones((2, 3)),
+ dims=["x", "y"],
+ coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])},
+ )
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ da.rename(x="x")
+
def test_init_value(self) -> None:
expected = DataArray(
np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)]
@@ -2719,6 +2727,14 @@ def test_where_lambda(self) -> None:
actual = arr.where(lambda x: x.y < 2, drop=True)
assert_identical(actual, expected)
+ def test_where_other_lambda(self) -> None:
+ arr = DataArray(np.arange(4), dims="y")
+ expected = xr.concat(
+ [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y"
+ )
+ actual = arr.where(lambda x: x.y < 2, lambda x: x + 1)
+ assert_identical(actual, expected)
+
def test_where_string(self) -> None:
array = DataArray(["a", "b"])
expected = DataArray(np.array(["a", np.nan], dtype=object))
@@ -4011,7 +4027,7 @@ def test_dot(self) -> None:
assert_equal(expected5, actual5)
with pytest.raises(NotImplementedError):
- da.dot(dm3.to_dataset(name="dm")) # type: ignore
+ da.dot(dm3.to_dataset(name="dm"))
with pytest.raises(TypeError):
da.dot(dm3.values) # type: ignore
@@ -7112,3 +7128,35 @@ def test_error_on_ellipsis_without_list(self) -> None:
da = DataArray([[1, 2], [1, 2]], dims=("x", "y"))
with pytest.raises(ValueError):
da.stack(flat=...) # type: ignore
+
+
+def test_nD_coord_dataarray() -> None:
+ # should succeed
+ da = DataArray(
+ np.ones((2, 4)),
+ dims=("x", "y"),
+ coords={
+ "x": (("x", "y"), np.arange(8).reshape((2, 4))),
+ "y": ("y", np.arange(4)),
+ },
+ )
+ _assert_internal_invariants(da, check_default_indexes=True)
+
+ da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))})
+ da3 = DataArray(np.ones(4), dims=("z"))
+
+ _, actual = xr.align(da, da2)
+ assert_identical(da2, actual)
+
+ expected = da.drop_vars("x")
+ _, actual = xr.broadcast(da, da2)
+ assert_identical(expected, actual)
+
+ actual, _ = xr.broadcast(da, da3)
+ expected = da.expand_dims(z=4, axis=-1)
+ assert_identical(actual, expected)
+
+ da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"])
+ _assert_internal_invariants(da4, check_default_indexes=True)
+ assert "x" not in da4.xindexes
+ assert "x" in da4.coords
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index c832663ecff..687aae8f1dc 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -411,10 +411,14 @@ def test_repr_nep18(self) -> None:
class Array:
def __init__(self):
self.shape = (2,)
+ self.ndim = 1
self.dtype = np.dtype(np.float64)
def __array_function__(self, *args, **kwargs):
- pass
+ return NotImplemented
+
+ def __array_ufunc__(self, *args, **kwargs):
+ return NotImplemented
def __repr__(self):
return "Custom\nArray"
@@ -2328,9 +2332,9 @@ def test_align(self) -> None:
assert np.isnan(left2["var3"][-2:]).all()
with pytest.raises(ValueError, match=r"invalid value for join"):
- align(left, right, join="foobar") # type: ignore[arg-type]
+ align(left, right, join="foobar") # type: ignore[call-overload]
with pytest.raises(TypeError):
- align(left, right, foo="bar") # type: ignore[call-arg]
+ align(left, right, foo="bar") # type: ignore[call-overload]
def test_align_exact(self) -> None:
left = xr.Dataset(coords={"x": [0, 1]})
@@ -2955,7 +2959,7 @@ def test_copy_with_data_errors(self) -> None:
with pytest.raises(ValueError, match=r"contain all variables in original"):
orig.copy(data={"var1": new_var1})
- def test_reset_encoding(self) -> None:
+ def test_drop_encoding(self) -> None:
orig = create_test_data()
vencoding = {"scale_factor": 10}
orig.encoding = {"foo": "bar"}
@@ -2963,7 +2967,7 @@ def test_reset_encoding(self) -> None:
for k, v in orig.variables.items():
orig[k].encoding = vencoding
- actual = orig.reset_encoding()
+ actual = orig.drop_encoding()
assert actual.encoding == {}
for k, v in actual.variables.items():
assert v.encoding == {}
@@ -3028,8 +3032,7 @@ def test_rename_old_name(self) -> None:
def test_rename_same_name(self) -> None:
data = create_test_data()
newnames = {"var1": "var1", "dim2": "dim2"}
- with pytest.warns(UserWarning, match="does not create an index anymore"):
- renamed = data.rename(newnames)
+ renamed = data.rename(newnames)
assert_identical(renamed, data)
def test_rename_dims(self) -> None:
@@ -3099,6 +3102,15 @@ def test_rename_dimension_coord_warnings(self) -> None:
):
ds.rename(x="y")
+ # Renaming to the same name (a no-op) should not raise a warning
+ ds = Dataset(
+ data_vars={"data": (("x", "y"), np.ones((2, 3)))},
+ coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])},
+ )
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ ds.rename(x="x")
+
def test_rename_multiindex(self) -> None:
midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"])
midx_coords = Coordinates.from_pandas_multiindex(midx, "x")
@@ -4116,7 +4128,8 @@ def test_setitem(self) -> None:
data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values
data5 = data4.astype(str)
data5["var4"] = data4["var1"]
- err_msg = "could not convert string to float: 'a'"
+ # convert to `np.str_('a')` once `numpy<2.0` has been dropped
+ err_msg = "could not convert string to float: .*'a'.*"
with pytest.raises(ValueError, match=err_msg):
data5[{"dim2": 1}] = "a"
@@ -5046,9 +5059,9 @@ def test_dropna(self) -> None:
):
ds.dropna("foo")
with pytest.raises(ValueError, match=r"invalid how"):
- ds.dropna("a", how="somehow") # type: ignore
+ ds.dropna("a", how="somehow") # type: ignore[arg-type]
with pytest.raises(TypeError, match=r"must specify how or thresh"):
- ds.dropna("a", how=None) # type: ignore
+ ds.dropna("a", how=None) # type: ignore[arg-type]
def test_fillna(self) -> None:
ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]})
diff --git a/xarray/tests/test_deprecation_helpers.py b/xarray/tests/test_deprecation_helpers.py
index 35128829073..f21c8097060 100644
--- a/xarray/tests/test_deprecation_helpers.py
+++ b/xarray/tests/test_deprecation_helpers.py
@@ -15,15 +15,15 @@ def f1(a, b, *, c="c", d="d"):
assert result == (1, 2, 3, 4)
with pytest.warns(FutureWarning, match=r".*v0.1"):
- result = f1(1, 2, 3)
+ result = f1(1, 2, 3) # type: ignore[misc]
assert result == (1, 2, 3, "d")
with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"):
- result = f1(1, 2, 3)
+ result = f1(1, 2, 3) # type: ignore[misc]
assert result == (1, 2, 3, "d")
with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"):
- result = f1(1, 2, 3, 4)
+ result = f1(1, 2, 3, 4) # type: ignore[misc]
assert result == (1, 2, 3, 4)
@_deprecate_positional_args("v0.1")
@@ -31,7 +31,7 @@ def f2(a="a", *, b="b", c="c", d="d"):
return a, b, c, d
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = f2(1, 2)
+ result = f2(1, 2) # type: ignore[misc]
assert result == (1, 2, "c", "d")
@_deprecate_positional_args("v0.1")
@@ -39,11 +39,11 @@ def f3(a, *, b="b", **kwargs):
return a, b, kwargs
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = f3(1, 2)
+ result = f3(1, 2) # type: ignore[misc]
assert result == (1, 2, {})
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = f3(1, 2, f="f")
+ result = f3(1, 2, f="f") # type: ignore[misc]
assert result == (1, 2, {"f": "f"})
@_deprecate_positional_args("v0.1")
@@ -57,7 +57,7 @@ def f4(a, /, *, b="b", **kwargs):
assert result == (1, 2, {"f": "f"})
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = f4(1, 2, f="f")
+ result = f4(1, 2, f="f") # type: ignore[misc]
assert result == (1, 2, {"f": "f"})
with pytest.raises(TypeError, match=r"Keyword-only param without default"):
@@ -80,15 +80,15 @@ def method(self, a, b, *, c="c", d="d"):
assert result == (1, 2, 3, 4)
with pytest.warns(FutureWarning, match=r".*v0.1"):
- result = A1().method(1, 2, 3)
+ result = A1().method(1, 2, 3) # type: ignore[misc]
assert result == (1, 2, 3, "d")
with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"):
- result = A1().method(1, 2, 3)
+ result = A1().method(1, 2, 3) # type: ignore[misc]
assert result == (1, 2, 3, "d")
with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"):
- result = A1().method(1, 2, 3, 4)
+ result = A1().method(1, 2, 3, 4) # type: ignore[misc]
assert result == (1, 2, 3, 4)
class A2:
@@ -97,11 +97,11 @@ def method(self, a=1, b=1, *, c="c", d="d"):
return a, b, c, d
with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"):
- result = A2().method(1, 2, 3)
+ result = A2().method(1, 2, 3) # type: ignore[misc]
assert result == (1, 2, 3, "d")
with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"):
- result = A2().method(1, 2, 3, 4)
+ result = A2().method(1, 2, 3, 4) # type: ignore[misc]
assert result == (1, 2, 3, 4)
class A3:
@@ -110,11 +110,11 @@ def method(self, a, *, b="b", **kwargs):
return a, b, kwargs
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = A3().method(1, 2)
+ result = A3().method(1, 2) # type: ignore[misc]
assert result == (1, 2, {})
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = A3().method(1, 2, f="f")
+ result = A3().method(1, 2, f="f") # type: ignore[misc]
assert result == (1, 2, {"f": "f"})
class A4:
@@ -129,7 +129,7 @@ def method(self, a, /, *, b="b", **kwargs):
assert result == (1, 2, {"f": "f"})
with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"):
- result = A4().method(1, 2, f="f")
+ result = A4().method(1, 2, f="f") # type: ignore[misc]
assert result == (1, 2, {"f": "f"})
with pytest.raises(TypeError, match=r"Keyword-only param without default"):
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 6a8cd9c457b..bfc37121597 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -168,6 +168,10 @@ def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path):
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path):
+ if parallel:
+ pytest.skip(
+ "Flaky in CI. Would be a welcome contribution to make a similar test reliable."
+ )
lon = np.arange(100)
time = xr.cftime_range("20010101", periods=100, calendar="360_day")
data = np.random.random((time.size, lon.size))
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 7670b77322c..5ca134503e8 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -549,7 +549,7 @@ def _repr_inline_(self, width):
return formatted
- def __array_function__(self, *args, **kwargs):
+ def __array_namespace__(self, *args, **kwargs):
return NotImplemented
@property
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index e143e2b8e03..320ba999318 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -810,9 +810,9 @@ def test_groupby_math_more() -> None:
with pytest.raises(TypeError, match=r"only support binary ops"):
grouped + 1 # type: ignore[operator]
with pytest.raises(TypeError, match=r"only support binary ops"):
- grouped + grouped
+ grouped + grouped # type: ignore[operator]
with pytest.raises(TypeError, match=r"in-place operations"):
- ds += grouped
+ ds += grouped # type: ignore[arg-type]
ds = Dataset(
{
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index fe2cdc58807..c57d84c927d 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -92,26 +92,36 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False
return da, df
+@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11])
+@pytest.mark.parametrize(
+ "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]
+)
@requires_scipy
-def test_interpolate_pd_compat():
+def test_interpolate_pd_compat(method, fill_value) -> None:
shapes = [(8, 8), (1, 20), (20, 1), (100, 100)]
frac_nans = [0, 0.5, 1]
- methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]
- for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods):
+ for shape, frac_nan in itertools.product(shapes, frac_nans):
da, df = make_interpolate_example_data(shape, frac_nan)
for dim in ["time", "x"]:
- actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan)
+ actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value)
+ # need limit_direction="both" here, to let pandas fill
+ # in both directions instead of the default forward direction only
expected = df.interpolate(
method=method,
axis=da.get_axis_num(dim),
+ limit_direction="both",
+ fill_value=fill_value,
)
- # Note, Pandas does some odd things with the left/right fill_value
- # for the linear methods. This next line inforces the xarray
- # fill_value convention on the pandas output. Therefore, this test
- # only checks that interpolated values are the same (not nans)
- expected.values[pd.isnull(actual.values)] = np.nan
+
+ if method == "linear":
+ # Note, Pandas does not take left/right fill_value into account
+ # for the numpy linear methods.
+ # see https://github.com/pandas-dev/pandas/issues/55144
+ # This aligns the pandas output with the xarray output
+ expected.values[pd.isnull(actual.values)] = np.nan
+ expected.values[actual.values == fill_value] = fill_value
np.testing.assert_allclose(actual.values, expected.values)
diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py
new file mode 100644
index 00000000000..ea1588bf554
--- /dev/null
+++ b/xarray/tests/test_namedarray.py
@@ -0,0 +1,207 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import pytest
+
+import xarray as xr
+from xarray.namedarray.core import NamedArray, as_compatible_data
+from xarray.namedarray.utils import T_DuckArray
+
+if TYPE_CHECKING:
+ from xarray.namedarray.utils import Self # type: ignore[attr-defined]
+
+
+@pytest.fixture
+def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]:
+ return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_output",
+ [
+ ([1, 2, 3], np.array([1, 2, 3])),
+ (np.array([4, 5, 6]), np.array([4, 5, 6])),
+ (NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3])),
+ (2, np.array(2)),
+ ],
+)
+def test_as_compatible_data(
+ input_data: T_DuckArray, expected_output: T_DuckArray
+) -> None:
+ output: T_DuckArray = as_compatible_data(input_data)
+ assert np.array_equal(output, expected_output)
+
+
+def test_as_compatible_data_with_masked_array() -> None:
+ masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call]
+ with pytest.raises(NotImplementedError):
+ as_compatible_data(masked_array)
+
+
+def test_as_compatible_data_with_0d_object() -> None:
+ data = np.empty((), dtype=object)
+ data[()] = (10, 12, 12)
+ assert np.array_equal(as_compatible_data(data), data)
+
+
+def test_as_compatible_data_with_explicitly_indexed(
+ random_inputs: np.ndarray[Any, Any]
+) -> None:
+ # TODO: Make xr.core.indexing.ExplicitlyIndexed pass is_duck_array and remove this test.
+ class CustomArrayBase(xr.core.indexing.NDArrayMixin):
+ def __init__(self, array: T_DuckArray) -> None:
+ self.array = array
+
+ @property
+ def dtype(self) -> np.dtype[np.generic]:
+ return self.array.dtype
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ return self.array.shape
+
+ @property
+ def real(self) -> Self:
+ raise NotImplementedError
+
+ @property
+ def imag(self) -> Self:
+ raise NotImplementedError
+
+ def astype(self, dtype: np.typing.DTypeLike) -> Self:
+ raise NotImplementedError
+
+ class CustomArray(CustomArrayBase):
+ def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]:
+ return np.array(self.array)
+
+ class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed):
+ pass
+
+ array = CustomArray(random_inputs)
+ output: CustomArray = as_compatible_data(array)
+ assert isinstance(output, np.ndarray)
+
+ array2 = CustomArrayIndexable(random_inputs)
+ output2: CustomArrayIndexable = as_compatible_data(array2)
+ assert isinstance(output2, CustomArrayIndexable)
+
+
+def test_properties() -> None:
+ data = 0.5 * np.arange(10).reshape(2, 5)
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray(["x", "y"], data, {"key": "value"})
+ assert named_array.dims == ("x", "y")
+ assert np.array_equal(named_array.data, data)
+ assert named_array.attrs == {"key": "value"}
+ assert named_array.ndim == 2
+ assert named_array.sizes == {"x": 2, "y": 5}
+ assert named_array.size == 10
+ assert named_array.nbytes == 80
+ assert len(named_array) == 2
+
+
+def test_attrs() -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5))
+ assert named_array.attrs == {}
+ named_array.attrs["key"] = "value"
+ assert named_array.attrs == {"key": "value"}
+ named_array.attrs = {"key": "value2"}
+ assert named_array.attrs == {"key": "value2"}
+
+
+def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray(["x", "y", "z"], random_inputs)
+ assert np.array_equal(named_array.data, random_inputs)
+ with pytest.raises(ValueError):
+ named_array.data = np.random.random((3, 4)).astype(np.float64)
+
+
+# Additional tests ported from the earlier class-based test suite
+@pytest.mark.parametrize(
+ "data, dtype",
+ [
+ ("foo", np.dtype("U3")),
+ (np.bytes_("foo"), np.dtype("S3")),
+ ],
+)
+def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray([], data)
+ assert named_array.data == data
+ assert named_array.dims == ()
+ assert named_array.sizes == {}
+ assert named_array.attrs == {}
+ assert named_array.ndim == 0
+ assert named_array.size == 1
+ assert named_array.dtype == dtype
+
+
+def test_0d_object() -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray([], (10, 12, 12))
+ expected_data = np.empty((), dtype=object)
+ expected_data[()] = (10, 12, 12)
+ assert np.array_equal(named_array.data, expected_data)
+
+ assert named_array.dims == ()
+ assert named_array.sizes == {}
+ assert named_array.attrs == {}
+ assert named_array.ndim == 0
+ assert named_array.size == 1
+ assert named_array.dtype == np.dtype("O")
+
+
+def test_0d_datetime() -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray([], np.datetime64("2000-01-01"))
+ assert named_array.dtype == np.dtype("datetime64[D]")
+
+
+@pytest.mark.parametrize(
+ "timedelta, expected_dtype",
+ [
+ (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")),
+ (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")),
+ (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")),
+ (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")),
+ (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")),
+ (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")),
+ (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")),
+ (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")),
+ (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")),
+ ],
+)
+def test_0d_timedelta(
+ timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64]
+) -> None:
+ named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]]
+ named_array = NamedArray([], timedelta)
+ assert named_array.dtype == expected_dtype
+ assert named_array.data == timedelta
+
+
+@pytest.mark.parametrize(
+ "dims, data_shape, new_dims, raises",
+ [
+ (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False),
+ (["x", "y", "z"], (2, 3, 4), ["a", "b"], True),
+ (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True),
+ ([], [], (), False),
+ ([], [], ("x",), True),
+ ],
+)
+def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None:
+ named_array: NamedArray[np.ndarray[Any, Any]]
+ named_array = NamedArray(dims, np.random.random(data_shape))
+ assert named_array.dims == tuple(dims)
+ if raises:
+ with pytest.raises(ValueError):
+ named_array.dims = new_dims
+ else:
+ named_array.dims = new_dims
+ assert named_array.dims == tuple(new_dims)
diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py
index 3cecf1b52ec..8ad1cbe11be 100644
--- a/xarray/tests/test_options.py
+++ b/xarray/tests/test_options.py
@@ -165,7 +165,6 @@ def test_concat_attr_retention(self) -> None:
result = concat([ds1, ds2], dim="dim1")
assert result.attrs == original_attrs
- @pytest.mark.xfail
def test_merge_attr_retention(self) -> None:
da1 = create_test_dataarray_attrs(var="var1")
da2 = create_test_dataarray_attrs(var="var2")
diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py
index 9a15696b004..da834b76124 100644
--- a/xarray/tests/test_rolling.py
+++ b/xarray/tests/test_rolling.py
@@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
- def test_rolling_construct(self, center, window) -> None:
+ def test_rolling_construct(self, center: bool, window: int) -> None:
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)
@@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
- def test_rolling_construct(self, center, window) -> None:
+ def test_rolling_construct(self, center: bool, window: int) -> None:
df = pd.DataFrame(
{
"x": np.random.randn(20),
@@ -627,12 +627,6 @@ def test_rolling_construct(self, center, window) -> None:
np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values)
np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"])
- # with stride
- ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window")
- np.testing.assert_allclose(
- df_rolling["x"][::2].values, ds_rolling_mean["x"].values
- )
- np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"])
# with fill_value
ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean(
"window"
@@ -640,6 +634,51 @@ def test_rolling_construct(self, center, window) -> None:
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0
+ @pytest.mark.parametrize("center", (True, False))
+ @pytest.mark.parametrize("window", (1, 2, 3, 4))
+ def test_rolling_construct_stride(self, center: bool, window: int) -> None:
+ df = pd.DataFrame(
+ {
+ "x": np.random.randn(20),
+ "y": np.random.randn(20),
+ "time": np.linspace(0, 1, 20),
+ }
+ )
+ ds = Dataset.from_dataframe(df)
+ df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean()
+
+ # With an index (dimension coordinate)
+ ds_rolling = ds.rolling(index=window, center=center)
+ ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w")
+ np.testing.assert_allclose(
+ df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values
+ )
+ np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"])
+
+ # Without index (https://github.com/pydata/xarray/issues/7021)
+ ds2 = ds.drop_vars("index")
+ ds2_rolling = ds2.rolling(index=window, center=center)
+ ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w")
+ np.testing.assert_allclose(
+ df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values
+ )
+
+ # Mixed coordinates, indexes and 2D coordinates
+ ds3 = xr.Dataset(
+ {"x": ("t", range(20)), "x2": ("y", range(5))},
+ {
+ "t": range(20),
+ "y": ("y", range(5)),
+ "t2": ("t", range(20)),
+ "y2": ("y", range(5)),
+ "yt": (["t", "y"], np.ones((20, 5))),
+ },
+ )
+ ds3_rolling = ds3.rolling(t=window, center=center)
+ ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w")
+ for coord in ds3.coords:
+ assert coord in ds3_rolling_mean.coords
+
@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
@@ -727,9 +766,7 @@ def test_ndrolling_construct(self, center, fill_value, dask) -> None:
)
assert_allclose(actual, expected)
- @pytest.mark.xfail(
- reason="See https://github.com/pydata/xarray/pull/4369 or docstring"
- )
+ @requires_dask
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("ds", (2,), indirect=True)
@pytest.mark.parametrize("name", ("mean", "max"))
@@ -751,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
@requires_numbagg
class TestDatasetRollingExp:
- @pytest.mark.parametrize("backend", ["numpy"], indirect=True)
+ @pytest.mark.parametrize(
+ "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True
+ )
def test_rolling_exp(self, ds) -> None:
result = ds.rolling_exp(time=10, window_type="span").mean()
assert isinstance(result, Dataset)
diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
index f64ce9338d7..489836b70fd 100644
--- a/xarray/tests/test_sparse.py
+++ b/xarray/tests/test_sparse.py
@@ -147,7 +147,6 @@ def test_variable_property(prop):
],
),
True,
- marks=xfail(reason="Coercion to dense"),
),
param(
do("conjugate"),
@@ -201,7 +200,6 @@ def test_variable_property(prop):
param(
do("reduce", func="sum", dim="x"),
True,
- marks=xfail(reason="Coercion to dense"),
),
param(
do("rolling_window", dim="x", window=2, window_dim="x_win"),
@@ -218,7 +216,7 @@ def test_variable_property(prop):
param(
do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar")
),
- param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")),
+ param(do("to_dict"), False),
(do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True),
],
ids=repr,
@@ -237,7 +235,14 @@ def test_variable_method(func, sparse_output):
assert isinstance(ret_s.data, sparse.SparseArray)
assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True)
else:
- assert np.allclose(ret_s, ret_d, equal_nan=True)
+ if func.meth != "to_dict":
+ assert np.allclose(ret_s, ret_d)
+ else:
+ # pop the arrays from the dict
+ arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data")
+
+ assert np.allclose(arr_s, arr_d)
+ assert ret_s == ret_d
@pytest.mark.parametrize(
diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py
new file mode 100644
index 00000000000..1d4ef89ae29
--- /dev/null
+++ b/xarray/tests/test_typed_ops.py
@@ -0,0 +1,246 @@
+import numpy as np
+
+from xarray import DataArray, Dataset, Variable
+
+
+def test_variable_typed_ops() -> None:
+ """Tests for type checking of typed_ops on Variable"""
+
+ var = Variable(dims=["t"], data=[1, 2, 3])
+
+ def _test(var: Variable) -> None:
+ # mypy checks the input type
+ assert isinstance(var, Variable)
+
+ _int: int = 1
+ _list = [1, 2, 3]
+ _ndarray = np.array([1, 2, 3])
+
+ # __add__ as an example of binary ops
+ _test(var + _int)
+ _test(var + _list)
+ _test(var + _ndarray)
+ _test(var + var)
+
+ # __radd__ as an example of reflexive binary ops
+ _test(_int + var)
+ _test(_list + var)
+ _test(_ndarray + var) # type: ignore[arg-type] # numpy problem
+
+ # __eq__ as an example of cmp ops
+ _test(var == _int)
+ _test(var == _list)
+ _test(var == _ndarray)
+ _test(_int == var) # type: ignore[arg-type] # typeshed problem
+ _test(_list == var) # type: ignore[arg-type] # typeshed problem
+ _test(_ndarray == var)
+
+ # __lt__ as another example of cmp ops
+ _test(var < _int)
+ _test(var < _list)
+ _test(var < _ndarray)
+ _test(_int > var)
+ _test(_list > var)
+ _test(_ndarray > var) # type: ignore[arg-type] # numpy problem
+
+ # __iadd__ as an example of inplace binary ops
+ var += _int
+ var += _list
+ var += _ndarray
+
+ # __neg__ as an example of unary ops
+ _test(-var)
+
+
+def test_dataarray_typed_ops() -> None:
+ """Tests for type checking of typed_ops on DataArray"""
+
+ da = DataArray([1, 2, 3], dims=["t"])
+
+ def _test(da: DataArray) -> None:
+ # mypy checks the input type
+ assert isinstance(da, DataArray)
+
+ _int: int = 1
+ _list = [1, 2, 3]
+ _ndarray = np.array([1, 2, 3])
+ _var = Variable(dims=["t"], data=[1, 2, 3])
+
+ # __add__ as an example of binary ops
+ _test(da + _int)
+ _test(da + _list)
+ _test(da + _ndarray)
+ _test(da + _var)
+ _test(da + da)
+
+ # __radd__ as an example of reflexive binary ops
+ _test(_int + da)
+ _test(_list + da)
+ _test(_ndarray + da) # type: ignore[arg-type] # numpy problem
+ _test(_var + da)
+
+ # __eq__ as an example of cmp ops
+ _test(da == _int)
+ _test(da == _list)
+ _test(da == _ndarray)
+ _test(da == _var)
+ _test(_int == da) # type: ignore[arg-type] # typeshed problem
+ _test(_list == da) # type: ignore[arg-type] # typeshed problem
+ _test(_ndarray == da)
+ _test(_var == da)
+
+ # __lt__ as another example of cmp ops
+ _test(da < _int)
+ _test(da < _list)
+ _test(da < _ndarray)
+ _test(da < _var)
+ _test(_int > da)
+ _test(_list > da)
+ _test(_ndarray > da) # type: ignore[arg-type] # numpy problem
+ _test(_var > da)
+
+ # __iadd__ as an example of inplace binary ops
+ da += _int
+ da += _list
+ da += _ndarray
+ da += _var
+
+ # __neg__ as an example of unary ops
+ _test(-da)
+
+
+def test_dataset_typed_ops() -> None:
+ """Tests for type checking of typed_ops on Dataset"""
+
+ ds = Dataset({"a": ("t", [1, 2, 3])})
+
+ def _test(ds: Dataset) -> None:
+ # mypy checks the input type
+ assert isinstance(ds, Dataset)
+
+ _int: int = 1
+ _list = [1, 2, 3]
+ _ndarray = np.array([1, 2, 3])
+ _var = Variable(dims=["t"], data=[1, 2, 3])
+ _da = DataArray([1, 2, 3], dims=["t"])
+
+ # __add__ as an example of binary ops
+ _test(ds + _int)
+ _test(ds + _list)
+ _test(ds + _ndarray)
+ _test(ds + _var)
+ _test(ds + _da)
+ _test(ds + ds)
+
+ # __radd__ as an example of reflexive binary ops
+ _test(_int + ds)
+ _test(_list + ds)
+ _test(_ndarray + ds)
+ _test(_var + ds)
+ _test(_da + ds)
+
+ # __eq__ as an example of cmp ops
+ _test(ds == _int)
+ _test(ds == _list)
+ _test(ds == _ndarray)
+ _test(ds == _var)
+ _test(ds == _da)
+ _test(_int == ds) # type: ignore[arg-type] # typeshed problem
+ _test(_list == ds) # type: ignore[arg-type] # typeshed problem
+ _test(_ndarray == ds)
+ _test(_var == ds)
+ _test(_da == ds)
+
+ # __lt__ as another example of cmp ops
+ _test(ds < _int)
+ _test(ds < _list)
+ _test(ds < _ndarray)
+ _test(ds < _var)
+ _test(ds < _da)
+ _test(_int > ds)
+ _test(_list > ds)
+ _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem
+ _test(_var > ds)
+ _test(_da > ds)
+
+ # __iadd__ as an example of inplace binary ops
+ ds += _int
+ ds += _list
+ ds += _ndarray
+ ds += _var
+ ds += _da
+
+ # __neg__ as an example of unary ops
+ _test(-ds)
+
+
+def test_dataarray_groupby_typed_ops() -> None:
+ """Tests for type checking of typed_ops on DataArrayGroupBy"""
+
+ da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"])
+ grp = da.groupby("x")
+
+ def _testda(da: DataArray) -> None:
+ # mypy checks the input type
+ assert isinstance(da, DataArray)
+
+ def _testds(ds: Dataset) -> None:
+ # mypy checks the input type
+ assert isinstance(ds, Dataset)
+
+ _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x")
+ _ds = _da.to_dataset(name="a")
+
+ # __add__ as an example of binary ops
+ _testda(grp + _da)
+ _testds(grp + _ds)
+
+ # __radd__ as an example of reflexive binary ops
+ _testda(_da + grp)
+ _testds(_ds + grp)
+
+ # __eq__ as an example of cmp ops
+ _testda(grp == _da)
+ _testda(_da == grp)
+ _testds(grp == _ds)
+ _testds(_ds == grp)
+
+ # __lt__ as another example of cmp ops
+ _testda(grp < _da)
+ _testda(_da > grp)
+ _testds(grp < _ds)
+ _testds(_ds > grp)
+
+
+def test_dataset_groupby_typed_ops() -> None:
+ """Tests for type checking of typed_ops on DatasetGroupBy"""
+
+ ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])})
+ grp = ds.groupby("x")
+
+ def _test(ds: Dataset) -> None:
+ # mypy checks the input type
+ assert isinstance(ds, Dataset)
+
+ _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x")
+ _ds = _da.to_dataset(name="a")
+
+ # __add__ as an example of binary ops
+ _test(grp + _da)
+ _test(grp + _ds)
+
+ # __radd__ as an example of reflexive binary ops
+ _test(_da + grp)
+ _test(_ds + grp)
+
+ # __eq__ as an example of cmp ops
+ _test(grp == _da)
+ _test(_da == grp)
+ _test(grp == _ds)
+ _test(_ds == grp)
+
+ # __lt__ as another example of cmp ops
+ _test(grp < _da)
+ _test(_da > grp)
+ _test(grp < _ds)
+ _test(_ds > grp)
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index addd7587544..7e1105e2e5d 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -18,6 +18,7 @@
assert_identical,
requires_dask,
requires_matplotlib,
+ requires_numbagg,
)
from xarray.tests.test_plot import PlotTestCase
from xarray.tests.test_variable import _PAD_XR_NP_ARGS
@@ -304,11 +305,13 @@ def __call__(self, obj, *args, **kwargs):
all_args = merge_args(self.args, args)
all_kwargs = {**self.kwargs, **kwargs}
+ from xarray.core.groupby import GroupBy
+
xarray_classes = (
xr.Variable,
xr.DataArray,
xr.Dataset,
- xr.core.groupby.GroupBy,
+ GroupBy,
)
if not isinstance(obj, xarray_classes):
@@ -2548,7 +2551,6 @@ def test_univariate_ufunc(self, units, error, dtype):
assert_units_equal(expected, actual)
assert_identical(expected, actual)
- @pytest.mark.xfail(reason="needs the type register system for __array_ufunc__")
@pytest.mark.parametrize(
"unit,error",
(
@@ -3849,23 +3851,21 @@ def test_computation(self, func, variant, dtype):
method("groupby", "x"),
method("groupby_bins", "y", bins=4),
method("coarsen", y=2),
- pytest.param(
- method("rolling", y=3),
- marks=pytest.mark.xfail(
- reason="numpy.lib.stride_tricks.as_strided converts to ndarray"
- ),
- ),
- pytest.param(
- method("rolling_exp", y=3),
- marks=pytest.mark.xfail(
- reason="numbagg functions are not supported by pint"
- ),
- ),
+ method("rolling", y=3),
+ pytest.param(method("rolling_exp", y=3), marks=requires_numbagg),
method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")),
),
ids=repr,
)
def test_computation_objects(self, func, variant, dtype):
+ if variant == "data":
+ if func.name == "rolling_exp":
+ pytest.xfail(reason="numbagg functions are not supported by pint")
+ elif func.name == "rolling":
+ pytest.xfail(
+ reason="numpy.lib.stride_tricks.as_strided converts to ndarray"
+ )
+
unit = unit_registry.m
variants = {
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 4fcd5f98d8f..73238b6ae3a 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -473,12 +473,12 @@ def test_encoding_preserved(self):
assert_identical(expected.to_base_variable(), actual.to_base_variable())
assert expected.encoding == actual.encoding
- def test_reset_encoding(self) -> None:
+ def test_drop_encoding(self) -> None:
encoding1 = {"scale_factor": 1}
# encoding set via cls constructor
v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1)
assert v1.encoding == encoding1
- v2 = v1.reset_encoding()
+ v2 = v1.drop_encoding()
assert v1.encoding == encoding1
assert v2.encoding == {}
@@ -486,7 +486,7 @@ def test_reset_encoding(self) -> None:
encoding3 = {"scale_factor": 10}
v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3)
assert v3.encoding == encoding3
- v4 = v3.reset_encoding()
+ v4 = v3.drop_encoding()
assert v3.encoding == encoding3
assert v4.encoding == {}
@@ -885,20 +885,10 @@ def test_getitem_error(self):
"mode",
[
"mean",
- pytest.param(
- "median",
- marks=pytest.mark.xfail(reason="median is not implemented by Dask"),
- ),
- pytest.param(
- "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug")
- ),
+ "median",
+ "reflect",
"edge",
- pytest.param(
- "linear_ramp",
- marks=pytest.mark.xfail(
- reason="pint bug: https://github.com/hgrecco/pint/issues/1026"
- ),
- ),
+ "linear_ramp",
"maximum",
"minimum",
"symmetric",
@@ -926,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg):
actual = v.pad(**xr_arg)
expected = np.pad(
- np.array(v.data.astype(float)),
+ np.array(duck_array_ops.astype(v.data, float)),
np_arg,
mode="constant",
constant_values=np.nan,
@@ -2345,12 +2335,35 @@ def test_dask_rolling(self, dim, window, center):
assert actual.shape == expected.shape
assert_equal(actual, expected)
- @pytest.mark.xfail(
- reason="https://github.com/pydata/xarray/issues/6209#issuecomment-1025116203"
- )
def test_multiindex(self):
super().test_multiindex()
+ @pytest.mark.parametrize(
+ "mode",
+ [
+ "mean",
+ pytest.param(
+ "median",
+ marks=pytest.mark.xfail(reason="median is not implemented by Dask"),
+ ),
+ pytest.param(
+ "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug")
+ ),
+ "edge",
+ "linear_ramp",
+ "maximum",
+ "minimum",
+ "symmetric",
+ "wrap",
+ ],
+ )
+ @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS)
+ @pytest.mark.filterwarnings(
+ r"ignore:dask.array.pad.+? converts integers to floats."
+ )
+ def test_pad(self, mode, xr_arg, np_arg):
+ super().test_pad(mode, xr_arg, np_arg)
+
@requires_sparse
class TestVariableWithSparse:
diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py
index e9681bdf398..7b4cf901aa1 100644
--- a/xarray/util/deprecation_helpers.py
+++ b/xarray/util/deprecation_helpers.py
@@ -34,6 +34,9 @@
import inspect
import warnings
from functools import wraps
+from typing import Callable, TypeVar
+
+T = TypeVar("T", bound=Callable)
POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
@@ -41,7 +44,7 @@
EMPTY = inspect.Parameter.empty
-def _deprecate_positional_args(version):
+def _deprecate_positional_args(version) -> Callable[[T], T]:
"""Decorator for methods that issues warnings for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the
diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py
index cf0673e7cca..f339470884a 100644
--- a/xarray/util/generate_ops.py
+++ b/xarray/util/generate_ops.py
@@ -3,14 +3,16 @@
For internal xarray development use only.
Usage:
- python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py
- python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi
+ python xarray/util/generate_ops.py > xarray/core/_typed_ops.py
"""
# Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some
# background to some of the design choices made here.
-import sys
+from __future__ import annotations
+
+from collections.abc import Iterator, Sequence
+from typing import Optional
BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne"))
BINOPS_CMP = (
@@ -74,155 +76,180 @@
("conjugate", "ops.conjugate"),
)
+
+required_method_binary = """
+ def _binary_op(
+ self, other: {other_type}, f: Callable, reflexive: bool = False
+ ) -> {return_type}:
+ raise NotImplementedError"""
template_binop = """
- def {method}(self, other):
+ def {method}(self, other: {other_type}) -> {return_type}:{type_ignore}
+ return self._binary_op(other, {func})"""
+template_binop_overload = """
+ @overload{overload_type_ignore}
+ def {method}(self, other: {overload_type}) -> {overload_type}:
+ ...
+
+ @overload
+ def {method}(self, other: {other_type}) -> {return_type}:
+ ...
+
+ def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore}
return self._binary_op(other, {func})"""
template_reflexive = """
- def {method}(self, other):
+ def {method}(self, other: {other_type}) -> {return_type}:
return self._binary_op(other, {func}, reflexive=True)"""
+
+required_method_inplace = """
+ def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self:
+ raise NotImplementedError"""
template_inplace = """
- def {method}(self, other):
+ def {method}(self, other: {other_type}) -> Self:{type_ignore}
return self._inplace_binary_op(other, {func})"""
+
+required_method_unary = """
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError"""
template_unary = """
- def {method}(self):
+ def {method}(self) -> Self:
return self._unary_op({func})"""
template_other_unary = """
- def {method}(self, *args, **kwargs):
+ def {method}(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op({func}, *args, **kwargs)"""
-required_method_unary = """
- def _unary_op(self, f, *args, **kwargs):
- raise NotImplementedError"""
-required_method_binary = """
- def _binary_op(self, other, f, reflexive=False):
- raise NotImplementedError"""
-required_method_inplace = """
- def _inplace_binary_op(self, other, f):
- raise NotImplementedError"""
# For some methods we override return type `bool` defined by base class `object`.
-OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"}
-NO_OVERRIDE = {"override": ""}
-
-# Note: in some of the overloads below the return value in reality is NotImplemented,
-# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented]
-# or type(NotImplemented) are not allowed and NoReturn has a different meaning.
-# In such cases we are lending the type checkers a hand by specifying the return type
-# of the corresponding reflexive method on `other` which will be called instead.
-stub_ds = """\
- def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}"""
-stub_da = """\
- @overload{override}
- def {method}(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def {method}(self, other: "DatasetGroupBy") -> "Dataset": ...
- @overload
- def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ..."""
-stub_var = """\
- @overload{override}
- def {method}(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def {method}(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ..."""
-stub_dsgb = """\
- @overload{override}
- def {method}(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def {method}(self, other: "DataArray") -> "Dataset": ...
- @overload
- def {method}(self, other: GroupByIncompatible) -> NoReturn: ..."""
-stub_dagb = """\
- @overload{override}
- def {method}(self, other: T_Dataset) -> T_Dataset: ...
- @overload
- def {method}(self, other: T_DataArray) -> T_DataArray: ...
- @overload
- def {method}(self, other: GroupByIncompatible) -> NoReturn: ..."""
-stub_unary = """\
- def {method}(self: {self_type}) -> {self_type}: ..."""
-stub_other_unary = """\
- def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ..."""
-stub_required_unary = """\
- def _unary_op(self, f, *args, **kwargs): ..."""
-stub_required_binary = """\
- def _binary_op(self, other, f, reflexive=...): ..."""
-stub_required_inplace = """\
- def _inplace_binary_op(self, other, f): ..."""
-
-
-def unops(self_type):
- extra_context = {"self_type": self_type}
+# We need to add "# type: ignore[override]"
+# Keep an eye out for:
+# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240
+# The type ignores might not be necessary anymore at some point.
+#
+# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray
+# In reality this returns NotImplemented, but this is not a valid type in Python 3.9.
+# Therefore, we annotate the return as DataArray; at runtime Python falls back to
+# the reflexive method on DataArray (e.g. DataArray.__radd__(Variable)).
+# TODO: change once python 3.10 is the minimum.
+#
+# Mypy seems to require that __iadd__ and __add__ have the same signature.
+# This requires some extra type: ignores[misc] in the inplace methods :/
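+#
+# Illustrative sketch only (not part of the generated file): with the overloads
+# emitted by binops_overload(other_type="VarCompatible", overload_type="T_DataArray")
+# below, a type checker resolves, e.g.:
+#
+#     var: Variable
+#     da: DataArray
+#     reveal_type(var + da)   # DataArray, via the T_DataArray overload
+#     reveal_type(var + var)  # Variable, via the VarCompatible -> Self overload
+#
+# At runtime `var + da` returns NotImplemented (as noted above), and Python falls
+# back to the reflexive method on DataArray.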
+
+
+def _type_ignore(ignore: str) -> str:
+ return f" # type:ignore[{ignore}]" if ignore else ""
+
+
+FuncType = Sequence[tuple[Optional[str], Optional[str]]]
+OpsType = tuple[FuncType, str, dict[str, str]]
+
+
+def binops(
+ other_type: str, return_type: str = "Self", type_ignore_eq: str = "override"
+) -> list[OpsType]:
+ extras = {"other_type": other_type, "return_type": return_type}
+ return [
+ ([(None, None)], required_method_binary, extras),
+ (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}),
+ (
+ BINOPS_EQNE,
+ template_binop,
+ extras | {"type_ignore": _type_ignore(type_ignore_eq)},
+ ),
+ (BINOPS_REFLEXIVE, template_reflexive, extras),
+ ]
+
+
+def binops_overload(
+ other_type: str,
+ overload_type: str,
+ return_type: str = "Self",
+ type_ignore_eq: str = "override",
+) -> list[OpsType]:
+ extras = {"other_type": other_type, "return_type": return_type}
return [
- ([(None, None)], required_method_unary, stub_required_unary, {}),
- (UNARY_OPS, template_unary, stub_unary, extra_context),
- (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context),
+ ([(None, None)], required_method_binary, extras),
+ (
+ BINOPS_NUM + BINOPS_CMP,
+ template_binop_overload,
+ extras
+ | {
+ "overload_type": overload_type,
+ "type_ignore": "",
+ "overload_type_ignore": "",
+ },
+ ),
+ (
+ BINOPS_EQNE,
+ template_binop_overload,
+ extras
+ | {
+ "overload_type": overload_type,
+ "type_ignore": "",
+ "overload_type_ignore": _type_ignore(type_ignore_eq),
+ },
+ ),
+ (BINOPS_REFLEXIVE, template_reflexive, extras),
]
-def binops(stub=""):
+def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]:
+ extras = {"other_type": other_type}
return [
- ([(None, None)], required_method_binary, stub_required_binary, {}),
- (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE),
- (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED),
- (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE),
+ ([(None, None)], required_method_inplace, extras),
+ (
+ BINOPS_INPLACE,
+ template_inplace,
+ extras | {"type_ignore": _type_ignore(type_ignore)},
+ ),
]
-def inplace():
+def unops() -> list[OpsType]:
return [
- ([(None, None)], required_method_inplace, stub_required_inplace, {}),
- (BINOPS_INPLACE, template_inplace, "", {}),
+ ([(None, None)], required_method_unary, {}),
+ (UNARY_OPS, template_unary, {}),
+ (OTHER_UNARY_METHODS, template_other_unary, {}),
]
ops_info = {}
-ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset")
-ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray")
-ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable")
-ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb)
-ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb)
+ops_info["DatasetOpsMixin"] = (
+ binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops()
+)
+ops_info["DataArrayOpsMixin"] = (
+ binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops()
+)
+ops_info["VariableOpsMixin"] = (
+ binops_overload(other_type="VarCompatible", overload_type="T_DataArray")
+ + inplace(other_type="VarCompatible", type_ignore="misc")
+ + unops()
+)
+ops_info["DatasetGroupByOpsMixin"] = binops(
+ other_type="GroupByCompatible", return_type="Dataset"
+)
+ops_info["DataArrayGroupByOpsMixin"] = binops(
+ other_type="T_Xarray", return_type="T_Xarray"
+)
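+
+# Illustrative sketch only (assuming BINOPS_NUM pairs "__add__" with "operator.add",
+# which is not shown in this hunk): with the extras above, template_binop renders
+# roughly the following method for DatasetOpsMixin:
+#
+#     def __add__(self, other: DsCompatible) -> Self:
+#         return self._binary_op(other, operator.add)
+#
+# The real output is produced by render() below and written to _typed_ops.py.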
MODULE_PREAMBLE = '''\
"""Mixin classes with arithmetic operators."""
# This file was generated using xarray.util.generate_ops. Do not edit manually.
-import operator
-
-from . import nputils, ops'''
+from __future__ import annotations
-STUBFILE_PREAMBLE = '''\
-"""Stub file for mixin classes with arithmetic operators."""
-# This file was generated using xarray.util.generate_ops. Do not edit manually.
-
-from typing import NoReturn, TypeVar, overload
-
-import numpy as np
-from numpy.typing import ArrayLike
+import operator
+from typing import TYPE_CHECKING, Any, Callable, overload
-from .dataarray import DataArray
-from .dataset import Dataset
-from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
-from .types import (
+from xarray.core import nputils, ops
+from xarray.core.types import (
DaCompatible,
DsCompatible,
- GroupByIncompatible,
- ScalarOrArray,
+ GroupByCompatible,
+ Self,
+ T_DataArray,
+ T_Xarray,
VarCompatible,
)
-from .variable import Variable
-try:
- from dask.array import Array as DaskArray
-except ImportError:
- DaskArray = np.ndarray # type: ignore
-
-# DatasetOpsMixin etc. are parent classes of Dataset etc.
-# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
-# we use the ones in `types`. (We're open to refining this, and potentially integrating
-# the `py` & `pyi` files to simplify them.)
-T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
-T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
-T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")'''
+if TYPE_CHECKING:
+ from xarray.core.dataset import Dataset'''
CLASS_PREAMBLE = """{newline}
@@ -233,35 +260,28 @@ class {cls_name}:
{method}.__doc__ = {func}.__doc__"""
-def render(ops_info, is_module):
+def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]:
"""Render the module or stub file."""
- yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE
+ yield MODULE_PREAMBLE
for cls_name, method_blocks in ops_info.items():
- yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module)
- yield from _render_classbody(method_blocks, is_module)
+ yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n")
+ yield from _render_classbody(method_blocks)
-def _render_classbody(method_blocks, is_module):
- for method_func_pairs, method_template, stub_template, extra in method_blocks:
- template = method_template if is_module else stub_template
+def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]:
+ for method_func_pairs, template, extra in method_blocks:
if template:
for method, func in method_func_pairs:
yield template.format(method=method, func=func, **extra)
- if is_module:
- yield ""
- for method_func_pairs, *_ in method_blocks:
- for method, func in method_func_pairs:
- if method and func:
- yield COPY_DOCSTRING.format(method=method, func=func)
+ yield ""
+ for method_func_pairs, *_ in method_blocks:
+ for method, func in method_func_pairs:
+ if method and func:
+ yield COPY_DOCSTRING.format(method=method, func=func)
if __name__ == "__main__":
- option = sys.argv[1].lower() if len(sys.argv) == 2 else None
- if option not in {"--module", "--stubs"}:
- raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs")
- is_module = option == "--module"
-
- for line in render(ops_info, is_module):
+ for line in render(ops_info):
print(line)