diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index 59e5889f5ec..cc1a2e12be3 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -44,6 +44,7 @@ body: - label: Complete example — the example is self-contained, including all data and the text of any traceback. - label: Verifiable example — the example copy & pastes into an IPython prompt or [Binder notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/blank_template.ipynb), returning the result. - label: New issue — a search of GitHub Issues suggests this is not a duplicate. + - label: Recent environment — the issue occurs with the latest version of xarray and its dependencies. - type: textarea id: log-output diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index ec1c192fd35..dc9cc2cd2fe 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -82,8 +82,6 @@ jobs: name: Mypy runs-on: "ubuntu-latest" needs: detect-ci-trigger - # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} @@ -190,6 +188,126 @@ jobs: + pyright: + name: Pyright + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.10" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + pyright39: + name: Pyright 3.9 + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.9" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright39 + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + + min-version-policy: name: Minimum Version Policy runs-on: "ubuntu-latest" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c2586a12aa2..5626f450ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,13 +18,13 @@ repos: files: ^xarray/ - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.287' + rev: 'v0.0.292' hooks: - id: ruff args: ["--fix"] # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black-jupyter - repo: https://github.com/keewis/blackdoc @@ -32,7 +32,7 @@ repos: hooks: - id: blackdoc exclude: "generate_aggregations.py" - additional_dependencies: ["black==23.7.0"] + additional_dependencies: ["black==23.9.1"] - id: blackdoc-autoupdate-black - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.5.1 diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 1d3713f19bf..579f4f00fbc 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -5,10 +5,10 @@ from . import parameterized, randn, requires_dask -nx = 300 +nx = 3000 long_nx = 30000 ny = 200 -nt = 100 +nt = 1000 window = 20 randn_xy = randn((nx, ny), frac_nan=0.1) @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): roll = self.ds.var3.rolling(t=100) getattr(roll, func)() + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.var2.rolling(t=100).construct("w", stride=stride) + self.ds.var3.rolling(t=100).construct("w", stride=stride) + class DatasetRollingMemory(RollingMemory): @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): with xr.set_options(use_bottleneck=use_bottleneck): roll = self.ds.rolling(t=100) getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.rolling(t=100).construct("w", stride=stride) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index d97c4010528..552d11a06dc 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -265,7 +265,7 @@ Variable.dims Variable.dtype Variable.encoding - Variable.reset_encoding + Variable.drop_encoding Variable.imag Variable.nbytes Variable.ndim diff --git a/doc/api.rst b/doc/api.rst index 0cf07f91df8..96b4864804f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -110,9 +110,9 @@ Dataset contents Dataset.drop_indexes Dataset.drop_duplicates Dataset.drop_dims + Dataset.drop_encoding Dataset.set_coords Dataset.reset_coords - Dataset.reset_encoding Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index @@ -303,8 +303,8 @@ DataArray contents DataArray.drop_vars DataArray.drop_indexes DataArray.drop_duplicates + DataArray.drop_encoding DataArray.reset_coords - DataArray.reset_encoding DataArray.copy DataArray.convert_calendar DataArray.interp_calendar diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index e6e970c6239..fc5ae963a1d 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -41,6 +41,7 @@ Geosciences harmonic wind analysis in Python. - `wradlib `_: An Open Source Library for Weather Radar Data Processing. - `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-regrid `_: xarray extension for regridding rectilinear data. - `xarray-simlab `_: xarray extension for computer model simulations. - `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) - `xarray-topo `_: xarray extension for topographic analysis and modelling. diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 2ffc25b2009..ffded682035 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -260,12 +260,12 @@ Note that all operations that manipulate variables other than indexing will remove encoding information. In some cases it is useful to intentionally reset a dataset's original encoding values. -This can be done with either the :py:meth:`Dataset.reset_encoding` or -:py:meth:`DataArray.reset_encoding` methods. +This can be done with either the :py:meth:`Dataset.drop_encoding` or +:py:meth:`DataArray.drop_encoding` methods. .. ipython:: python - ds_no_encoding = ds_disk.reset_encoding() + ds_no_encoding = ds_disk.drop_encoding() ds_no_encoding.encoding .. _combining multiple files: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ee74411a004..40c50e158ad 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,25 +14,99 @@ What's New np.random.seed(123456) -.. _whats-new.2023.08.1: +.. _whats-new.2023.09.1: -v2023.08.1 (unreleased) +v2023.09.1 (unreleased) ----------------------- New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the only argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. +- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for + the ``variables`` parameter, passing the object as the only argument. + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and + :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_. +- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword + arguments after the initial 7 positional arguments. + By `Maximilian Roos `_. + + +Deprecations +~~~~~~~~~~~~ +- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding` + to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for + consistency with other ``drop`` & ``reset`` methods — ``drop`` generally + removes something, while ``reset`` generally resets to some default or + standard value. (:pull:`8287`, :issue:`8259`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ +- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning + when the operation was a no-op. (:issue:`8266`) + By `Simon Hansen `_. + +- Fix datetime encoding precision loss regression introduced in the previous + release for datetimes encoded with units requiring floating point values, and + a reference date not equal to the first value of the datetime array + (:issue:`8271`, :pull:`8272`). By `Spencer Clark + `_. + + +Documentation +~~~~~~~~~~~~~ + +- Added xarray-regrid to the list of xarray related projects (:pull:`8272`). + By `Bart Schilperoort `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- More improvements to support the Python `array API standard `_ + by using duck array ops in more places in the codebase. (:pull:`8267`) + By `Tom White `_. + + +.. _whats-new.2023.09.0: + +v2023.09.0 (Sep 26, 2023) +------------------------- + +This release continues work on the new :py:class:`xarray.Coordinates` object, allows to provide `preferred_chunks` when +reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions and fixes several bugs. + +Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian, +Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden, +Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen, +Tom Nicholas and Wiktor Kraśnicki. + +We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. Trevisan, +Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki. + +New Features +~~~~~~~~~~~~ + - Added the :py:meth:`Coordinates.assign` method that can be used to combine different collections of coordinates prior to assign them to a Dataset or DataArray (:pull:`8102`) at once. By `Benoît Bovy `_. - Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`). By `Martin Raspaud `_. -- Improved static typing of reduction methods (:pull:`6746`). - By `Richard Kleijn `_. - Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or - dropping a :py:class:`Dataset`'s variables with missing core dimensions. - (:pull:`8138`) + dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`). By `Maximilian Roos `_. Breaking changes @@ -61,6 +135,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Improved static typing of reduction methods (:pull:`6746`). + By `Richard Kleijn `_. - Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`). By `Mattia Almansi `_. - Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes @@ -71,26 +147,39 @@ Bug fixes :pull:`8104`). By `Benoît Bovy `_. - Fix bug where :py:class:`DataArray` instances on the right-hand side - of :py:meth:`DataArray.__setitem__` lose dimension names. - (:issue:`7030`, :pull:`8067`) By `Darsh Ranjan `_. + of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`). + By `Darsh Ranjan `_. - Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and - special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()` + special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` (:issue:`7928`, :pull:`8084`). By `Kai Mühlbauer `_. +- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes. + (:issue:`7021`, :pull:`7578`). + By `Amrest Chinkamol `_ and `Michael Niklas `_. - Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments (:issue:`7552`, :pull:`8174`). By `Wiktor Kraśnicki `_. -- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying - issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, +- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, :issue:`1064`, :pull:`7827`). By `Kai Mühlbauer `_. +- Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). + By `Kai Mühlbauer `_ - ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords - (:issue:`6528`, :pull:`8114`) + (:issue:`6528`, :pull:`8114`). By `Maximilian Roos `_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. +- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`). + By `Pieter Eendebak `_. Documentation ~~~~~~~~~~~~~ +- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`). + By `Riulinchen `_. Internal Changes ~~~~~~~~~~~~~~~~ @@ -103,6 +192,8 @@ Internal Changes than `.reduce`, as the start of a broader effort to move non-reducing functions away from ```.reduce``, (:pull:`8114`). By `Maximilian Roos `_. +- Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). + By `Kai Mühlbauer `_. .. _whats-new.2023.08.0: diff --git a/pyproject.toml b/pyproject.toml index dd380937bd2..bdae33e4d0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,14 +72,16 @@ source = ["xarray"] exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] +enable_error_code = "redundant-self" exclude = 'xarray/util/generate_.*\.py' files = "xarray" show_error_codes = true show_error_context = true warn_redundant_casts = true +warn_unused_configs = true warn_unused_ignores = true -# Most of the numerical computing stack doesn't have type annotations yet. +# Much of the numerical computing stack doesn't have type annotations yet. [[tool.mypy.overrides]] ignore_missing_imports = true module = [ @@ -92,10 +94,10 @@ module = [ "cftime.*", "cubed.*", "cupy.*", + "dask.types.*", "fsspec.*", "h5netcdf.*", "h5py.*", - "importlib_metadata.*", "iris.*", "matplotlib.*", "mpl_toolkits.*", @@ -118,6 +120,108 @@ module = [ "numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped ] +# Gradually we want to add more modules to this list, ratcheting up our total +# coverage. Once a module is here, functions are checked by mypy regardless of +# whether they have type annotations. It would be especially useful to have test +# files listed here, because without them being checked, we don't have a great +# way of testing our annotations. +[[tool.mypy.overrides]] +check_untyped_defs = true +module = [ + "xarray.core.accessor_dt", + "xarray.core.accessor_str", + "xarray.core.alignment", + "xarray.core.computation", + "xarray.core.rolling_exp", + "xarray.indexes.*", + "xarray.tests.*", +] +# This then excludes some modules from the above list. (So ideally we remove +# from here in time...) +[[tool.mypy.overrides]] +check_untyped_defs = false +module = [ + "xarray.tests.test_coarsen", + "xarray.tests.test_coding_times", + "xarray.tests.test_combine", + "xarray.tests.test_computation", + "xarray.tests.test_concat", + "xarray.tests.test_coordinates", + "xarray.tests.test_dask", + "xarray.tests.test_dataarray", + "xarray.tests.test_duck_array_ops", + "xarray.tests.test_groupby", + "xarray.tests.test_indexing", + "xarray.tests.test_merge", + "xarray.tests.test_missing", + "xarray.tests.test_parallelcompat", + "xarray.tests.test_plot", + "xarray.tests.test_sparse", + "xarray.tests.test_ufuncs", + "xarray.tests.test_units", + "xarray.tests.test_utils", + "xarray.tests.test_variable", + "xarray.tests.test_weighted", +] + +# Use strict = true whenever namedarray has become standalone. In the meantime +# don't forget to add all new files related to namedarray here: +# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +[[tool.mypy.overrides]] +# Start off with these +warn_unused_ignores = true + +# Getting these passing should be easy +strict_concatenate = true +strict_equality = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_any_generics = true +disallow_subclassing_any = true +disallow_untyped_decorators = true + +# These next few are various gradations of forcing use of type annotations +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] + +[tool.pyright] +# include = ["src"] +# exclude = ["**/node_modules", +# "**/__pycache__", +# "src/experimental", +# "src/typestubs" +# ] +# ignore = ["src/oldstuff"] +defineConstant = {DEBUG = true} +# stubPath = "src/stubs" +# venv = "env367" + +reportMissingImports = true +reportMissingTypeStubs = false + +# pythonVersion = "3.6" +# pythonPlatform = "Linux" + +# executionEnvironments = [ +# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, +# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, +# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, +# { root = "src" } +# ] + [tool.ruff] builtins = ["ellipsis"] exclude = [ @@ -146,15 +250,17 @@ select = [ known-first-party = ["xarray"] [tool.pytest.ini_options] -addopts = '--strict-markers' +addopts = ["--strict-config", "--strict-markers"] filterwarnings = [ "ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", ] +log_cli_level = "INFO" markers = [ "flaky: flaky tests", "network: tests requiring a network connection", "slow: slow tests", ] +minversion = "7" python_files = "test_*.py" testpaths = ["xarray/tests", "properties"] diff --git a/xarray/__init__.py b/xarray/__init__.py index b63b0d81470..1fd3b0c4336 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,3 +1,5 @@ +from importlib.metadata import version as _version + from xarray import testing, tutorial from xarray.backends.api import ( load_dataarray, @@ -41,12 +43,6 @@ from xarray.core.variable import IndexVariable, Variable, as_variable from xarray.util.print_versions import show_versions -try: - from importlib.metadata import version as _version -except ImportError: - # if the fallback library is missing, we are doomed. - from importlib_metadata import version as _version - try: __version__ = _version("xarray") except Exception: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 58a05aeddce..27e155872de 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -488,6 +488,9 @@ def open_dataset( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -691,6 +694,9 @@ def open_dataarray( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -1522,6 +1528,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -1567,6 +1574,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 79efbecfb7c..039fe371100 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray: return num +def _division(deltas, delta, floor): + if floor: + # calculate int64 floor division + # to preserve integer dtype if possible (GH 4045, GH7817). + num = deltas // delta.astype(np.int64) + num = num.astype(np.int64, copy=False) + else: + num = deltas / delta + return num + + def encode_cf_datetime( - dates, units: str | None = None, calendar: str | None = None + dates, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, ) -> tuple[np.ndarray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -689,34 +703,47 @@ def encode_cf_datetime( time_units, ref_date = _unpack_time_units_and_ref_date(units) time_delta = _time_units_to_timedelta64(time_units) + # Wrap the dates in a DatetimeIndex to do the subtraction to ensure + # an OverflowError is raised if the ref_date is too far away from + # dates to be encoded (GH 2272). + dates_as_index = pd.DatetimeIndex(dates.ravel()) + time_deltas = dates_as_index - ref_date + # retrieve needed units to faithfully encode to int64 needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units) if data_units != units: # this accounts for differences in the reference times ref_delta = abs(data_ref_date - ref_date).to_timedelta64() - if ref_delta > np.timedelta64(0, "ns"): + data_delta = _time_units_to_timedelta64(needed_units) + if (ref_delta % data_delta) > np.timedelta64(0, "ns"): needed_units = _infer_time_units_from_diff(ref_delta) - # Wrap the dates in a DatetimeIndex to do the subtraction to ensure - # an OverflowError is raised if the ref_date is too far away from - # dates to be encoded (GH 2272). - dates_as_index = pd.DatetimeIndex(dates.ravel()) - time_deltas = dates_as_index - ref_date - # needed time delta to encode faithfully to int64 needed_time_delta = _time_units_to_timedelta64(needed_units) - if time_delta <= needed_time_delta: - # calculate int64 floor division - # to preserve integer dtype if possible (GH 4045, GH7817). - num = time_deltas // time_delta.astype(np.int64) - num = num.astype(np.int64, copy=False) - else: - emit_user_level_warning( - f"Times can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - f"Serializing timeseries to floating point." - ) - num = time_deltas / time_delta + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + new_units = f"{needed_units} since {format_timestamp(ref_date)}" + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Serializing with units {new_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {new_units!r} to silence this warning ." + ) + units = new_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): @@ -728,7 +755,9 @@ def encode_cf_datetime( return (num, units, calendar) -def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]: +def encode_cf_timedelta( + timedeltas, units: str | None = None, dtype: np.dtype | None = None +) -> tuple[np.ndarray, str]: data_units = infer_timedelta_units(timedeltas) if units is None: @@ -744,18 +773,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra # needed time delta to encode faithfully to int64 needed_time_delta = _time_units_to_timedelta64(needed_units) - if time_delta <= needed_time_delta: - # calculate int64 floor division - # to preserve integer dtype if possible - num = time_deltas // time_delta.astype(np.int64) - num = num.astype(np.int64, copy=False) - else: - emit_user_level_warning( - f"Timedeltas can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - f"Serializing timedeltas to floating point." - ) - num = time_deltas / time_delta + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer): + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {needed_units!r} to silence this warning ." + ) + units = needed_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(timedeltas.shape) return (num, units) @@ -772,7 +812,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: units = encoding.pop("units", None) calendar = encoding.pop("calendar", None) - (data, units, calendar) = encode_cf_datetime(data, units, calendar) + dtype = encoding.get("dtype", None) + (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype) safe_setitem(attrs, "units", units, name=name) safe_setitem(attrs, "calendar", calendar, name=name) @@ -807,7 +848,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - data, units = encode_cf_timedelta(data, encoding.pop("units", None)) + data, units = encode_cf_timedelta( + data, encoding.pop("units", None), encoding.get("dtype", None) + ) safe_setitem(attrs, "units", units, name=name) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index d694c531b15..c583afc93c2 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -304,6 +304,8 @@ def decode(self, variable: Variable, name: T_Name = None): ) # special case DateTime to properly handle NaT + dtype: np.typing.DTypeLike + decoded_fill_value: Any if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu": dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min else: diff --git a/xarray/conventions.py b/xarray/conventions.py index 596831e270a..cf207f0c37a 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,9 +1,8 @@ from __future__ import annotations -import warnings from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union import numpy as np import pandas as pd @@ -16,6 +15,7 @@ contains_cftime_datetimes, ) from xarray.core.pycompat import is_duck_dask_array +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import IndexVariable, Variable CF_RELATED_DATA = ( @@ -111,13 +111,13 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: return var if is_duck_dask_array(data): - warnings.warn( + emit_user_level_warning( f"variable {name} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " "to determine a data type that can be safely stored on disk. " "To avoid this, coerce this variable to a fixed-size dtype " "with astype() before saving it.", - SerializationWarning, + category=SerializationWarning, ) data = data.compute() @@ -354,15 +354,14 @@ def _update_bounds_encoding(variables: T_Variables) -> None: and "bounds" in attrs and attrs["bounds"] in variables ): - warnings.warn( - "Variable '{0}' has datetime type and a " - "bounds variable but {0}.encoding does not have " - "units specified. The units encodings for '{0}' " - "and '{1}' will be determined independently " + emit_user_level_warning( + f"Variable {name:s} has datetime type and a " + f"bounds variable but {name:s}.encoding does not have " + f"units specified. The units encodings for {name:s} " + f"and {attrs['bounds']} will be determined independently " "and may not be equal, counter to CF-conventions. " "If this is a concern, specify a units encoding for " - "'{0}' before writing to a file.".format(name, attrs["bounds"]), - UserWarning, + f"{name:s} before writing to a file.", ) if has_date_units and "bounds" in attrs: @@ -379,7 +378,7 @@ def decode_cf_variables( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -441,11 +440,14 @@ def stackable(dim: Hashable) -> bool: if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: - coord_str = var_attrs["coordinates"] - var_coord_names = coord_str.split() - if all(k in variables for k in var_coord_names): - new_vars[k].encoding["coordinates"] = coord_str - del var_attrs["coordinates"] + var_coord_names = [ + c for c in var_attrs["coordinates"].split() if c in variables + ] + # propagate as is + new_vars[k].encoding["coordinates"] = var_attrs["coordinates"] + del var_attrs["coordinates"] + # but only use as coordinate if existing + if var_coord_names: coord_names.update(var_coord_names) if decode_coords == "all": @@ -461,8 +463,8 @@ def stackable(dim: Hashable) -> bool: for role_or_name in part.split() ] if len(roles_and_names) % 2 == 1: - warnings.warn( - f"Attribute {attr_name:s} malformed", stacklevel=5 + emit_user_level_warning( + f"Attribute {attr_name:s} malformed" ) var_names = roles_and_names[1::2] if all(var_name in variables for var_name in var_names): @@ -474,9 +476,8 @@ def stackable(dim: Hashable) -> bool: for proj_name in var_names if proj_name not in variables ] - warnings.warn( + emit_user_level_warning( f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}", - stacklevel=5, ) del var_attrs[attr_name] @@ -493,7 +494,7 @@ def decode_cf( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -632,12 +633,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): for name in list(non_dim_coord_names): if isinstance(name, str) and " " in name: - warnings.warn( + emit_user_level_warning( f"coordinate {name!r} has a space in its name, which means it " "cannot be marked as a coordinate on disk and will be " "saved as a data variable instead", - SerializationWarning, - stacklevel=6, + category=SerializationWarning, ) non_dim_coord_names.discard(name) @@ -710,11 +710,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if global_coordinates: attributes = dict(attributes) if "coordinates" in attributes: - warnings.warn( + emit_user_level_warning( f"cannot serialize global coordinates {global_coordinates!r} because the global " f"attribute 'coordinates' already exists. This may prevent faithful roundtripping" f"of xarray datasets", - SerializationWarning, + category=SerializationWarning, ) else: attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates))) diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index d3a783be45d..9b79ed46a9c 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -1,165 +1,182 @@ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. +from __future__ import annotations + import operator +from typing import TYPE_CHECKING, Any, Callable, overload from xarray.core import nputils, ops +from xarray.core.types import ( + DaCompatible, + DsCompatible, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, + VarCompatible, +) + +if TYPE_CHECKING: + from xarray.core.dataset import Dataset class DatasetOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DsCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -215,157 +232,159 @@ def conjugate(self, *args, **kwargs): class DataArrayOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -421,157 +440,303 @@ def conjugate(self, *args, **kwargs): class VariableOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: VarCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + @overload + def __add__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __add__(self, other: VarCompatible) -> Self: + ... + + def __add__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.add) - def __sub__(self, other): + @overload + def __sub__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __sub__(self, other: VarCompatible) -> Self: + ... + + def __sub__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + @overload + def __mul__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __mul__(self, other: VarCompatible) -> Self: + ... + + def __mul__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + @overload + def __pow__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __pow__(self, other: VarCompatible) -> Self: + ... + + def __pow__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + @overload + def __truediv__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __truediv__(self, other: VarCompatible) -> Self: + ... + + def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + @overload + def __floordiv__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __floordiv__(self, other: VarCompatible) -> Self: + ... + + def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + @overload + def __mod__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __mod__(self, other: VarCompatible) -> Self: + ... + + def __mod__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mod) - def __and__(self, other): + @overload + def __and__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __and__(self, other: VarCompatible) -> Self: + ... + + def __and__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + @overload + def __xor__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __xor__(self, other: VarCompatible) -> Self: + ... + + def __xor__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.xor) - def __or__(self, other): + @overload + def __or__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __or__(self, other: VarCompatible) -> Self: + ... + + def __or__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + @overload + def __lshift__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __lshift__(self, other: VarCompatible) -> Self: + ... + + def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + @overload + def __rshift__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __rshift__(self, other: VarCompatible) -> Self: + ... + + def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + @overload + def __lt__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __lt__(self, other: VarCompatible) -> Self: + ... + + def __lt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lt) - def __le__(self, other): + @overload + def __le__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __le__(self, other: VarCompatible) -> Self: + ... + + def __le__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.le) - def __gt__(self, other): + @overload + def __gt__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __gt__(self, other: VarCompatible) -> Self: + ... + + def __gt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + @overload + def __ge__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __ge__(self, other: VarCompatible) -> Self: + ... + + def __ge__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + @overload # type:ignore[override] + def __eq__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __eq__(self, other: VarCompatible) -> Self: + ... + + def __eq__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + @overload # type:ignore[override] + def __ne__(self, other: T_DataArray) -> T_DataArray: + ... + + @overload + def __ne__(self, other: VarCompatible) -> Self: + ... + + def __ne__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -627,91 +792,93 @@ def conjugate(self, *args, **kwargs): class DatasetGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: GroupByCompatible, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -747,91 +914,93 @@ def __ror__(self, other): class DataArrayGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: T_Xarray, f: Callable, reflexive: bool = False + ) -> T_Xarray: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi deleted file mode 100644 index 9e2ba2d3a06..00000000000 --- a/xarray/core/_typed_ops.pyi +++ /dev/null @@ -1,782 +0,0 @@ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike - -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( - DaCompatible, - DsCompatible, - GroupByIncompatible, - ScalarOrArray, - VarCompatible, -) -from .variable import Variable - -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -class DatasetOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Dataset) -> T_Dataset: ... - def __pos__(self: T_Dataset) -> T_Dataset: ... - def __abs__(self: T_Dataset) -> T_Dataset: ... - def __invert__(self: T_Dataset) -> T_Dataset: ... - def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - -class DataArrayOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_DataArray) -> T_DataArray: ... - def __pos__(self: T_DataArray) -> T_DataArray: ... - def __abs__(self: T_DataArray) -> T_DataArray: ... - def __invert__(self: T_DataArray) -> T_DataArray: ... - def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - -class VariableOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Variable) -> T_Variable: ... - def __pos__(self: T_Variable) -> T_Variable: ... - def __abs__(self: T_Variable) -> T_Variable: ... - def __invert__(self: T_Variable) -> T_Variable: ... - def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... - -class DatasetGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DataArray") -> "Dataset": ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DataArray") -> "Dataset": ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DataArray") -> "Dataset": ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DataArray") -> "Dataset": ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DataArray") -> "Dataset": ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DataArray") -> "Dataset": ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... - -class DataArrayGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 4c1ce4b5c48..8255e2a5232 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -7,6 +7,7 @@ import pandas as pd from xarray.coding.times import infer_calendar_name +from xarray.core import duck_array_ops from xarray.core.common import ( _contains_datetime_like_objects, is_np_datetime_like, @@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name): from xarray.coding.cftimeindex import CFTimeIndex if not isinstance(values, CFTimeIndex): - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) else: values_as_cftimeindex = values if name == "season": @@ -69,7 +70,7 @@ def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) if name == "season": months = values_as_series.dt.month.values field_values = _season_from_months(months) @@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq): from xarray.coding.cftimeindex import CFTimeIndex if is_np_datetime_like(values.dtype): - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) method = getattr(values_as_series.dt, name) else: - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) method = getattr(values_as_cftimeindex, name) field_values = method(freq=freq).values @@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str): """ from xarray.coding.cftimeindex import CFTimeIndex - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) field_values = values_as_cftimeindex.strftime(date_format) return field_values.values.reshape(values.shape) @@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str): """Coerce an array of datetime-like values to a pandas Series and apply string formatting """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) strs = values_as_series.dt.strftime(date_format) return strs.values.reshape(values.shape) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index aa6dc2c7114..573200b5c88 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -2386,7 +2386,7 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value] + return self._obj.copy().expand_dims({dim: 0}, axis=-1) arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index ff2ecbc74a1..732ec5d3ea6 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Generic, cast +from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload import numpy as np import pandas as pd @@ -26,7 +26,13 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray + from xarray.core.types import ( + Alignable, + JoinOptions, + T_DataArray, + T_Dataset, + T_DuckArray, + ) def reindex_variables( @@ -128,7 +134,7 @@ def __init__( objects: Iterable[T_Alignable], join: str = "inner", indexes: Mapping[Any, Any] | None = None, - exclude_dims: Iterable = frozenset(), + exclude_dims: str | Iterable[Hashable] = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -576,12 +582,111 @@ def align(self) -> None: self.reindex_all() +T_Obj1 = TypeVar("T_Obj1", bound="Alignable") +T_Obj2 = TypeVar("T_Obj2", bound="Alignable") +T_Obj3 = TypeVar("T_Obj3", bound="Alignable") +T_Obj4 = TypeVar("T_Obj4", bound="Alignable") +T_Obj5 = TypeVar("T_Obj5", bound="Alignable") + + +@overload +def align( + obj1: T_Obj1, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload +def align( + *objects: T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Alignable, ...]: + ... + + +def align( # type: ignore[misc] *objects: T_Alignable, join: JoinOptions = "inner", copy: bool = True, indexes=None, - exclude=frozenset(), + exclude: str | Iterable[Hashable] = frozenset(), fill_value=dtypes.NA, ) -> tuple[T_Alignable, ...]: """ @@ -620,7 +725,7 @@ def align( indexes : dict-like, optional Any indexes explicitly provided with the `indexes` argument should be used in preference to the aligned indexes. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must be excluded from alignment fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps @@ -787,12 +892,12 @@ def align( def deep_align( objects: Iterable[Any], join: JoinOptions = "inner", - copy=True, + copy: bool = True, indexes=None, - exclude=frozenset(), - raise_on_invalid=True, + exclude: str | Iterable[Hashable] = frozenset(), + raise_on_invalid: bool = True, fill_value=dtypes.NA, -): +) -> list[Any]: """Align objects for merging, recursing into dictionary values. This function is not public API. @@ -807,12 +912,12 @@ def deep_align( def is_alignable(obj): return isinstance(obj, (Coordinates, DataArray, Dataset)) - positions = [] - keys = [] - out = [] - targets = [] - no_key = object() - not_replaced = object() + positions: list[int] = [] + keys: list[type[object] | Hashable] = [] + out: list[Any] = [] + targets: list[Alignable] = [] + no_key: Final = object() + not_replaced: Final = object() for position, variables in enumerate(objects): if is_alignable(variables): positions.append(position) @@ -857,7 +962,7 @@ def is_alignable(obj): if key is no_key: out[position] = aligned_obj else: - out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this? + out[position][key] = aligned_obj return out @@ -988,9 +1093,69 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset: raise ValueError("all input must be Dataset or DataArray objects") -# TODO: this typing is too restrictive since it cannot deal with mixed -# DataArray and Dataset types...? Is this a problem? -def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]: +@overload +def broadcast( + obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload +def broadcast( + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: + ... + + +def broadcast( # type: ignore[misc] + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: """Explicitly broadcast any number of DataArray or Dataset objects against one another. @@ -1004,7 +1169,7 @@ def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]: ---------- *args : DataArray or Dataset Arrays to broadcast against each other. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must not be broadcasted Returns diff --git a/xarray/core/common.py b/xarray/core/common.py index ade701457c6..ab8a4d84261 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -45,6 +45,7 @@ DatetimeLike, DTypeLikeSave, ScalarOrArray, + Self, SideOptions, T_Chunks, T_DataWithCoords, @@ -222,7 +223,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @property - def sizes(self: Any) -> Frozen[Hashable, int]: + def sizes(self: Any) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin): __slots__ = ("_close",) def squeeze( - self: T_DataWithCoords, + self, dim: Hashable | Iterable[Hashable] | None = None, drop: bool = False, axis: int | Iterable[int] | None = None, - ) -> T_DataWithCoords: + ) -> Self: """Return a new object with squeezed data. Parameters @@ -414,12 +415,12 @@ def squeeze( return self.isel(drop=drop, **{d: 0 for d in dims}) def clip( - self: T_DataWithCoords, + self, min: ScalarOrArray | None = None, max: ScalarOrArray | None = None, *, keep_attrs: bool | None = None, - ) -> T_DataWithCoords: + ) -> Self: """ Return an array whose values are limited to ``[min, max]``. At least one of max or min must be given. @@ -472,10 +473,10 @@ def _calc_assign_results( return {k: v(self) if callable(v) else v for k, v in kwargs.items()} def assign_coords( - self: T_DataWithCoords, + self, coords: Mapping[Any, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataWithCoords: + ) -> Self: """Assign new coordinates to this object. Returns a new object with all the original data in addition to the new @@ -620,9 +621,7 @@ def assign_coords( data.coords.update(results) return data - def assign_attrs( - self: T_DataWithCoords, *args: Any, **kwargs: Any - ) -> T_DataWithCoords: + def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: """Assign new attrs to this object. Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``. @@ -1061,11 +1060,12 @@ def _resample( restore_coord_dims=restore_coord_dims, ) - def where( - self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False - ) -> T_DataWithCoords: + def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: """Filter elements from this object according to a condition. + Returns elements from 'DataArray', where 'cond' is True, + otherwise fill in 'other'. + This operation follows the normal broadcasting and alignment rules that xarray uses for binary arithmetic. @@ -1073,10 +1073,12 @@ def where( ---------- cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. - If a callable, it must expect this object as its only parameter. - other : scalar, DataArray or Dataset, optional + If a callable, the callable is passed this object, and the result is used as + the value for cond. + other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. - By default, these locations filled with NA. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. drop : bool, default: False If True, coordinate labels that only correspond to False values of the condition are dropped from the result. @@ -1124,7 +1126,16 @@ def where( [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(lambda x: x.x + x.y < 4, drop=True) + >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x) + + array([[ 0, 1, 2, 3, -4], + [ 5, 6, 7, -8, -9], + [ 10, 11, -12, -13, -14], + [ 15, -16, -17, -18, -19], + [-20, -21, -22, -23, -24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) array([[ 0., 1., 2., 3.], [ 5., 6., 7., nan], @@ -1132,14 +1143,6 @@ def where( [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(a.x + a.y < 4, -1, drop=True) - - array([[ 0, 1, 2, 3], - [ 5, 6, 7, -1], - [10, 11, -1, -1], - [15, -1, -1, -1]]) - Dimensions without coordinates: x, y - See Also -------- numpy.where : corresponding numpy function @@ -1151,14 +1154,16 @@ def where( if callable(cond): cond = cond(self) + if callable(other): + other = other(self) if drop: if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." ) - self, cond = align(self, cond) # type: ignore[assignment] + self, cond = align(self, cond) def _dataarray_indexer(dim: Hashable) -> DataArray: return cond.any(dim=(d for d in cond.dims if d != dim)) @@ -1205,9 +1210,7 @@ def close(self) -> None: self._close() self._close = None - def isnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def isnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is a missing value. Parameters @@ -1250,9 +1253,7 @@ def isnull( keep_attrs=keep_attrs, ) - def notnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def notnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is not a missing value. Parameters @@ -1295,7 +1296,7 @@ def notnull( keep_attrs=keep_attrs, ) - def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: + def isin(self, test_elements: Any) -> Self: """Tests each value in the array for whether it is in test elements. Parameters @@ -1344,7 +1345,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: ) def astype( - self: T_DataWithCoords, + self, dtype, *, order=None, @@ -1352,7 +1353,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_DataWithCoords: + ) -> Self: """ Copy of the xarray object, with data cast to a specified type. Leaves coordinate dtype unchanged. @@ -1419,7 +1420,7 @@ def astype( dask="allowed", ) - def __enter__(self: T_DataWithCoords) -> T_DataWithCoords: + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_value, traceback) -> None: diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 971f036b394..9cb60e0c424 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -8,7 +8,7 @@ import operator import warnings from collections import Counter -from collections.abc import Hashable, Iterable, Mapping, Sequence, Set +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload import numpy as np @@ -163,7 +163,7 @@ def to_gufunc_string(self, exclude_dims=frozenset()): if exclude_dims: exclude_dims = [self.dims_map[dim] for dim in exclude_dims] - counter = Counter() + counter: Counter = Counter() def _enumerate(dim): if dim in exclude_dims: @@ -289,8 +289,14 @@ def apply_dataarray_vfunc( from xarray.core.dataarray import DataArray if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) objs = _all_of_type(args, DataArray) @@ -506,8 +512,14 @@ def apply_dataset_vfunc( objs = _all_of_type(args, Dataset) if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) list_of_coords, list_of_indexes = build_output_coords_and_indexes( @@ -571,7 +583,7 @@ def apply_groupby_func(func, *args): assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] (grouper,) = first_groupby.groupers - if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr] raise ValueError( "apply_ufunc can only perform operations over " "multiple GroupBy objects at once if they are all " @@ -583,6 +595,7 @@ def apply_groupby_func(func, *args): iterators = [] for arg in args: + iterator: Iterator[Any] if isinstance(arg, GroupBy): iterator = (value for _, value in arg) elif hasattr(arg, "dims") and grouped_dim in arg.dims: @@ -597,9 +610,9 @@ def apply_groupby_func(func, *args): iterator = itertools.repeat(arg) iterators.append(iterator) - applied = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) applied_example, applied = peek_at(applied) - combine = first_groupby._combine + combine = first_groupby._combine # type: ignore[attr-defined] if isinstance(applied_example, tuple): combined = tuple(combine(output) for output in zip(*applied)) else: @@ -893,7 +906,7 @@ def apply_ufunc( dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool | str | None = None, kwargs: Mapping | None = None, - dask: str = "forbidden", + dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden", output_dtypes: Sequence | None = None, output_sizes: Mapping[Any, int] | None = None, meta: Any = None, @@ -2122,7 +2135,8 @@ def _calc_idxminmax( chunkmanager = get_chunked_array_type(array.data) chunks = dict(zip(array.dims, array.chunks)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) + data = dask_coord[duck_array_ops.ravel(indx.data)] + res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) # we need to attach back the dim name res.name = dim else: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index a76bb6b0033..a136480b2fb 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, Union, cast, overload +from typing import TYPE_CHECKING, Any, Union, overload import numpy as np import pandas as pd @@ -504,8 +504,7 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - # TODO: Overriding type because .expand_dims has incorrect typing: - datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets] + datasets = [ds.expand_dims(dim) for ds in datasets] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -708,8 +707,7 @@ def _dataarray_concat( if compat == "identical": raise ValueError("array names not identical") else: - # TODO: Overriding type because .rename has incorrect typing: - arr = cast(T_DataArray, arr.rename(name)) + arr = arr.rename(name) datasets.append(arr._to_temp_dataset()) ds = _dataset_concat( diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index e20c022e637..0c85b2a2d69 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -23,7 +23,7 @@ create_default_index_implicit, ) from xarray.core.merge import merge_coordinates_without_align, merge_coords -from xarray.core.types import Self, T_DataArray +from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray from xarray.core.utils import ( Frozen, ReprObject, @@ -425,7 +425,7 @@ def __delitem__(self, key: Hashable) -> None: # redirect to DatasetCoordinates.__delitem__ del self._data.coords[key] - def equals(self, other: Coordinates) -> bool: + def equals(self, other: Self) -> bool: """Two Coordinates objects are equal if they have matching variables, all of which are equal. @@ -437,7 +437,7 @@ def equals(self, other: Coordinates) -> bool: return False return self.to_dataset().equals(other.to_dataset()) - def identical(self, other: Coordinates) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all variable attributes. See Also @@ -565,9 +565,7 @@ def update(self, other: Mapping[Any, Any]) -> None: self._update_coords(coords, indexes) - def assign( - self, coords: Mapping | None = None, **coords_kwargs: Any - ) -> Coordinates: + def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self: """Assign new coordinates (and indexes) to a Coordinates object, returning a new object with all the original coordinates in addition to the new ones. @@ -656,7 +654,7 @@ def copy( self, deep: bool = False, memo: dict[int, Any] | None = None, - ) -> Coordinates: + ) -> Self: """Return a copy of this Coordinates object.""" # do not copy indexes (may corrupt multi-coordinate indexes) # TODO: disable variables deepcopy? it may also be problematic when they @@ -664,8 +662,16 @@ def copy( variables = { k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items() } - return Coordinates._construct_direct( - coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + + # TODO: getting an error with `self._construct_direct`, possibly because of how + # a subclass implements `_construct_direct`. (This was originally the same + # runtime code, but we switched the type definitions in #8216, which + # necessitates the cast.) + return cast( + Self, + Coordinates._construct_direct( + coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + ), ) @@ -915,9 +921,7 @@ def drop_indexed_coords( return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes) -def assert_coordinate_consistent( - obj: T_DataArray | Dataset, coords: Mapping[Any, Variable] -) -> None: +def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None: """Make sure the dimension coordinate of obj is consistent with coords. obj: DataArray or Dataset @@ -933,7 +937,7 @@ def assert_coordinate_consistent( def create_coords_with_default_indexes( - coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None + coords: Mapping[Any, Any], data_vars: DataVars | None = None ) -> Coordinates: """Returns a Coordinates object from a mapping of coordinates (arbitrary objects). diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 791aad5cd17..391b4ed9412 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4,7 +4,15 @@ import warnings from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + overload, +) import numpy as np import pandas as pd @@ -41,6 +49,7 @@ from xarray.core.indexing import is_fancy_indexer, map_index_queries from xarray.core.merge import PANDAS_TYPES, MergeError from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import DaCompatible, T_DataArray, T_DataArrayOrSet from xarray.core.utils import ( Default, HybridMappingProxy, @@ -57,6 +66,7 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from typing import TypeVar, Union @@ -100,8 +110,9 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + Self, SideOptions, - T_DataArray, + T_Chunks, T_Xarray, ) from xarray.core.weighted import DataArrayWeighted @@ -119,7 +130,7 @@ def _check_coords_dims(shape, coords, dims): f"dimensions {dims}" ) - for d, s in zip(v.dims, v.shape): + for d, s in v.sizes.items(): if s != sizes[d]: raise ValueError( f"conflicting sizes for dimension {d!r}: " @@ -127,13 +138,6 @@ def _check_coords_dims(shape, coords, dims): f"coordinate {k!r}" ) - if k in sizes and v.shape != (sizes[k],): - raise ValueError( - f"coordinate {k!r} is a DataArray dimension, but " - f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " - "matching the dimension size" - ) - def _infer_coords_and_dims( shape, coords, dims @@ -213,13 +217,13 @@ def _check_data_shape(data, coords, dims): return data -class _LocIndexer: +class _LocIndexer(Generic[T_DataArray]): __slots__ = ("data_array",) - def __init__(self, data_array: DataArray): + def __init__(self, data_array: T_DataArray): self.data_array = data_array - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> T_DataArray: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) @@ -462,12 +466,12 @@ def __init__( @classmethod def _construct_direct( - cls: type[T_DataArray], + cls, variable: Variable, coords: dict[Any, Variable], name: Hashable, indexes: dict[Hashable, Index], - ) -> T_DataArray: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -480,12 +484,12 @@ def _construct_direct( return obj def _replace( - self: T_DataArray, + self, variable: Variable | None = None, coords=None, name: Hashable | None | Default = _default, indexes=None, - ) -> T_DataArray: + ) -> Self: if variable is None: variable = self.variable if coords is None: @@ -497,10 +501,10 @@ def _replace( return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( - self: T_DataArray, + self, variable: Variable, name: Hashable | None | Default = _default, - ) -> T_DataArray: + ) -> Self: if variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() indexes = self._indexes @@ -522,12 +526,12 @@ def _replace_maybe_drop_dims( return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( - self: T_DataArray, + self, indexes: Mapping[Any, Index], variables: Mapping[Any, Variable] | None = None, drop_coords: list[Hashable] | None = None, rename_dims: Mapping[Any, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Maybe replace indexes and their corresponding coordinates.""" if not indexes: return self @@ -560,8 +564,8 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self: T_DataArray, dataset: Dataset, name: Hashable | None | Default = _default - ) -> T_DataArray: + self, dataset: Dataset, name: Hashable | None | Default = _default + ) -> Self: variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes @@ -773,7 +777,7 @@ def to_numpy(self) -> np.ndarray: """ return self.variable.to_numpy() - def as_numpy(self: T_DataArray) -> T_DataArray: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. @@ -828,7 +832,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: + def _getitem_coord(self, key: Any) -> Self: from xarray.core.dataset import _get_virtual_variable try: @@ -839,7 +843,7 @@ def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: return self._replace_maybe_drop_dims(var, name=key) - def __getitem__(self: T_DataArray, key: Any) -> T_DataArray: + def __getitem__(self, key: Any) -> Self: if isinstance(key, str): return self._getitem_coord(key) else: @@ -909,10 +913,16 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = dict(value) - def reset_encoding(self: T_DataArray) -> T_DataArray: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new DataArray without encoding on the array or any attached coords.""" - ds = self._to_temp_dataset().reset_encoding() + ds = self._to_temp_dataset().drop_encoding() return self._from_temp_dataset(ds) @property @@ -949,26 +959,29 @@ def coords(self) -> DataArrayCoordinates: @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: Literal[False] = False, ) -> Dataset: ... @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, *, drop: Literal[True], - ) -> T_DataArray: + ) -> Self: ... + @_deprecate_positional_args("v2023.10.0") def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: bool = False, - ) -> T_DataArray | Dataset: + ) -> Self | Dataset: """Given names of coordinates, reset them to become variables. Parameters @@ -1080,15 +1093,15 @@ def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() return self._dask_finalize, (self.name, func) + args - @staticmethod - def _dask_finalize(results, name, func, *args, **kwargs) -> DataArray: + @classmethod + def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables indexes = ds._indexes - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return cls(variable, coords, name=name, indexes=indexes, fastpath=True) - def load(self: T_DataArray, **kwargs) -> T_DataArray: + def load(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -1112,7 +1125,7 @@ def load(self: T_DataArray, **kwargs) -> T_DataArray: self._coords = new._coords return self - def compute(self: T_DataArray, **kwargs) -> T_DataArray: + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -1134,7 +1147,7 @@ def compute(self: T_DataArray, **kwargs) -> T_DataArray: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataArray, **kwargs) -> T_DataArray: + def persist(self, **kwargs) -> Self: """Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in @@ -1153,7 +1166,7 @@ def persist(self: T_DataArray, **kwargs) -> T_DataArray: ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: + def copy(self, deep: bool = True, data: Any = None) -> Self: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -1224,11 +1237,11 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: return self._copy(deep=deep, data=data) def _copy( - self: T_DataArray, + self, deep: bool = True, data: Any = None, memo: dict[int, Any] | None = None, - ) -> T_DataArray: + ) -> Self: variable = self.variable._copy(deep=deep, data=data, memo=memo) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1241,12 +1254,10 @@ def _copy( return self._replace(variable, coords, indexes=indexes) - def __copy__(self: T_DataArray) -> T_DataArray: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__( - self: T_DataArray, memo: dict[int, Any] | None = None - ) -> T_DataArray: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) # mutable objects should not be Hashable @@ -1286,15 +1297,11 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: all_variables = [self.variable] + [c.variable for c in self.coords.values()] return get_chunksizes(all_variables) + @_deprecate_positional_args("v2023.10.0") def chunk( - self: T_DataArray, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + *, name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, @@ -1302,7 +1309,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Coerce this array's data into a dask arrays with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1362,7 +1369,7 @@ def chunk( if isinstance(chunks, (float, str, int)): # ignoring type; unclear why it won't accept a Literal into the value. - chunks = dict.fromkeys(self.dims, chunks) # type: ignore + chunks = dict.fromkeys(self.dims, chunks) elif isinstance(chunks, (tuple, list)): chunks = dict(zip(self.dims, chunks)) else: @@ -1380,12 +1387,12 @@ def chunk( return self._from_temp_dataset(ds) def isel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting indexes along the specified dimension(s). @@ -1471,13 +1478,13 @@ def isel( return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance=None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -1590,10 +1597,10 @@ def sel( return self._from_temp_dataset(ds) def head( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the first `n` values along the specified dimension(s). Default `n` = 5 @@ -1633,10 +1640,10 @@ def head( return self._from_temp_dataset(ds) def tail( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the last `n` values along the specified dimension(s). Default `n` = 5 @@ -1680,10 +1687,10 @@ def tail( return self._from_temp_dataset(ds) def thin( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by each `n` value along the specified dimension(s). @@ -1729,11 +1736,13 @@ def thin( ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def broadcast_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, exclude: Iterable[Hashable] | None = None, - ) -> T_DataArray: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -1803,12 +1812,10 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_DataArray", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( - self: T_DataArray, + self, aligner: alignment.Aligner, dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable, Variable], @@ -1816,7 +1823,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> T_DataArray: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed DataArray.""" if isinstance(fill_value, dict): @@ -1842,14 +1849,16 @@ def _reindex_callback( return da + @_deprecate_positional_args("v2023.10.0") def reindex_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value=dtypes.NA, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2012,15 +2021,17 @@ def reindex_like( fill_value=fill_value, ) + @_deprecate_positional_args("v2023.10.0") def reindex( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, + *, method: ReindexMethodOptions = None, tolerance: float | Iterable[float] | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2104,13 +2115,13 @@ def reindex( ) def interp( - self: T_DataArray, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Interpolate a DataArray onto new coordinates Performs univariate or multivariate interpolation of a DataArray onto @@ -2247,12 +2258,12 @@ def interp( return self._from_temp_dataset(ds) def interp_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling out of range values with NaN. @@ -2369,13 +2380,11 @@ def interp_like( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def rename( self, new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> DataArray: + ) -> Self: """Returns a new DataArray with renamed coordinates, dimensions or a new name. Parameters @@ -2416,10 +2425,10 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self: T_DataArray, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs, - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with swapped dimensions. Parameters @@ -2474,14 +2483,12 @@ def swap_dims( ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> DataArray: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -2570,14 +2577,12 @@ def expand_dims( ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> DataArray: + ) -> Self: """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -2635,13 +2640,11 @@ def set_index( ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], drop: bool = False, - ) -> DataArray: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -2675,11 +2678,11 @@ def reset_index( return self._from_temp_dataset(ds) def set_xindex( - self: T_DataArray, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_DataArray: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -2704,10 +2707,10 @@ def set_xindex( return self._from_temp_dataset(ds) def reorder_levels( - self: T_DataArray, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_DataArray: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -2730,12 +2733,12 @@ def reorder_levels( return self._from_temp_dataset(ds) def stack( - self: T_DataArray, + self, dimensions: Mapping[Any, Sequence[Hashable]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], - ) -> T_DataArray: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -2802,14 +2805,14 @@ def stack( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, + *, fill_value: Any = dtypes.NA, sparse: bool = False, - ) -> DataArray: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -2864,7 +2867,7 @@ def unstack( -------- DataArray.stack """ - ds = self._to_temp_dataset().unstack(dim, fill_value, sparse) + ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse) return self._from_temp_dataset(ds) def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset: @@ -2933,11 +2936,11 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data return Dataset(data_dict) def transpose( - self: T_DataArray, + self, *dims: Hashable, transpose_coords: bool = True, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_DataArray: + ) -> Self: """Return a new DataArray object with transposed dimensions. Parameters @@ -2983,17 +2986,15 @@ def transpose( return self._replace(variable) @property - def T(self: T_DataArray) -> T_DataArray: + def T(self) -> Self: return self.transpose() - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def drop_vars( self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> DataArray: + ) -> Self: """Returns an array with dropped variables. Parameters @@ -3054,11 +3055,11 @@ def drop_vars( return self._from_temp_dataset(ds) def drop_indexes( - self: T_DataArray, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_DataArray: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -3079,13 +3080,13 @@ def drop_indexes( return self._from_temp_dataset(ds) def drop( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, dim: Hashable | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -3099,12 +3100,12 @@ def drop( return self._from_temp_dataset(ds) def drop_sel( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Drop index labels from this DataArray. Parameters @@ -3167,8 +3168,8 @@ def drop_sel( return self._from_temp_dataset(ds) def drop_isel( - self: T_DataArray, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs - ) -> T_DataArray: + self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs + ) -> Self: """Drop index positions from this DataArray. Parameters @@ -3217,12 +3218,14 @@ def drop_isel( dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) return self._from_temp_dataset(dataset) + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_DataArray, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, - ) -> T_DataArray: + ) -> Self: """Returns a new array with dropped labels for missing values along the provided dimension. @@ -3293,7 +3296,7 @@ def dropna( ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) return self._from_temp_dataset(ds) - def fillna(self: T_DataArray, value: Any) -> T_DataArray: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3356,7 +3359,7 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray: return out def interpolate_na( - self: T_DataArray, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -3372,7 +3375,7 @@ def interpolate_na( ) = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -3479,9 +3482,7 @@ def interpolate_na( **kwargs, ) - def ffill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -3565,9 +3566,7 @@ def ffill( return ffill(self, dim, limit=limit) - def bfill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -3651,7 +3650,7 @@ def bfill( return bfill(self, dim, limit=limit) - def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. This operation follows the normal broadcasting and alignment rules of @@ -3670,7 +3669,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: return ops.fillna(self, other, join="outer") def reduce( - self: T_DataArray, + self, func: Callable[..., Any], dim: Dims = None, *, @@ -3678,7 +3677,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Reduce this array by applying `func` along some dimension(s). Parameters @@ -3716,7 +3715,7 @@ def reduce( var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) - def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame: + def to_pandas(self) -> Self | pd.Series | pd.DataFrame: """Convert this array into a pandas object with the same shape. The type of the returned object depends on the number of DataArray @@ -4033,6 +4032,7 @@ def to_zarr( mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, group: str | None = None, + *, encoding: Mapping | None = None, compute: Literal[True] = True, consolidated: bool | None = None, @@ -4073,6 +4073,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -4270,7 +4271,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: + def from_dict(cls, d: Mapping[str, Any]) -> Self: """Convert a dictionary into an xarray.DataArray Parameters @@ -4387,7 +4388,7 @@ def to_cdms2(self) -> cdms2_Variable: return to_cdms2(self) @classmethod - def from_cdms2(cls, variable: cdms2_Variable) -> DataArray: + def from_cdms2(cls, variable: cdms2_Variable) -> Self: """Convert a cdms2.Variable into an xarray.DataArray .. deprecated:: 2023.06.0 @@ -4414,13 +4415,13 @@ def to_iris(self) -> iris_Cube: return to_iris(self) @classmethod - def from_iris(cls, cube: iris_Cube) -> DataArray: + def from_iris(cls, cube: iris_Cube) -> Self: """Convert a iris.cube.Cube into an xarray.DataArray""" from xarray.convert import from_iris return from_iris(cube) - def _all_compat(self: T_DataArray, other: T_DataArray, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals, broadcast_equals, and identical""" def compat(x, y): @@ -4430,7 +4431,7 @@ def compat(x, y): self, other ) - def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two DataArrays are broadcast equal if they are equal after broadcasting them against each other such that they have the same dimensions. @@ -4479,7 +4480,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def equals(self: T_DataArray, other: T_DataArray) -> bool: + def equals(self, other: Self) -> bool: """True if two DataArrays have the same dimensions, coordinates and values; otherwise False. @@ -4541,7 +4542,7 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def identical(self: T_DataArray, other: T_DataArray) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks the array name and attributes, and attributes on all coordinates. @@ -4608,19 +4609,19 @@ def _result_name(self, other: Any = None) -> Hashable | None: else: return None - def __array_wrap__(self: T_DataArray, obj, context=None) -> T_DataArray: + def __array_wrap__(self, obj, context=None) -> Self: new_var = self.variable.__array_wrap__(obj, context) return self._replace(new_var) - def __matmul__(self: T_DataArray, obj: T_DataArray) -> T_DataArray: + def __matmul__(self, obj: T_Xarray) -> T_Xarray: return self.dot(obj) - def __rmatmul__(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def __rmatmul__(self, other: T_Xarray) -> T_Xarray: # currently somewhat duplicative, as only other DataArrays are # compatible with matmul return computation.dot(other, self) - def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: + def _unary_op(self, f: Callable, *args, **kwargs) -> Self: keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -4636,32 +4637,29 @@ def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: return da def _binary_op( - self: T_DataArray, - other: Any, - f: Callable, - reflexive: bool = False, - ) -> T_DataArray: + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) # type: ignore - other_variable = getattr(other, "variable", other) + self, other = align(self, other, join=align_type, copy=False) + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) variable = ( - f(self.variable, other_variable) + f(self.variable, other_variable_or_arraylike) if not reflexive - else f(other_variable, self.variable) + else f(other_variable_or_arraylike, self.variable) ) coords, indexes = self.coords._merge_raw(other_coords, reflexive) name = self._result_name(other) return self._replace(variable, coords, name, indexes=indexes) - def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray: + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, GroupBy): @@ -4720,12 +4718,14 @@ def _title_for_slice(self, truncate: int = 50) -> str: return title + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_DataArray, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_DataArray: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -4771,11 +4771,11 @@ def diff( return self._from_temp_dataset(ds) def shift( - self: T_DataArray, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Shift this DataArray by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. This is consistent @@ -4821,11 +4821,11 @@ def shift( return self._replace(variable=variable) def roll( - self: T_DataArray, + self, shifts: Mapping[Hashable, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Roll this array by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -4870,7 +4870,7 @@ def roll( return self._from_temp_dataset(ds) @property - def real(self: T_DataArray) -> T_DataArray: + def real(self) -> Self: """ The real part of the array. @@ -4881,7 +4881,7 @@ def real(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.real) @property - def imag(self: T_DataArray) -> T_DataArray: + def imag(self) -> Self: """ The imaginary part of the array. @@ -4892,10 +4892,10 @@ def imag(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.imag) def dot( - self: T_DataArray, - other: T_DataArray, + self, + other: T_Xarray, dims: Dims = None, - ) -> T_DataArray: + ) -> T_Xarray: """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -4945,13 +4945,14 @@ def dot( return computation.dot(self, other, dims=dims) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def sortby( self, - variables: Hashable | DataArray | Sequence[Hashable | DataArray], + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]], ascending: bool = True, - ) -> DataArray: + ) -> Self: """Sort object by labels or values (along an axis). Sorts the dataarray, either along specified dimensions, @@ -4970,9 +4971,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or sequence of Hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords whose values are used to sort this array. + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -4992,34 +4994,47 @@ def sortby( Examples -------- >>> da = xr.DataArray( - ... np.random.rand(5), + ... np.arange(5, 0, -1), ... coords=[pd.date_range("1/1/2000", periods=5)], ... dims="time", ... ) >>> da - array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ]) + array([5, 4, 3, 2, 1]) Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 >>> da.sortby(da) - array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937]) + array([1, 2, 3, 4, 5]) Coordinates: - * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02 + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 + + >>> da.sortby(lambda x: x) + + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 """ + # We need to convert the callable here rather than pass it through to the + # dataset method, since otherwise the dataset method would try to call the + # callable with the dataset as the object + if callable(variables): + variables = variables(self) ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def quantile( - self: T_DataArray, + self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> T_DataArray: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -5129,12 +5144,14 @@ def quantile( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_DataArray, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_DataArray: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -5174,11 +5191,11 @@ def rank( return self._from_temp_dataset(ds) def differentiate( - self: T_DataArray, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions = None, - ) -> T_DataArray: + ) -> Self: """ Differentiate the array with the second order accurate central differences. @@ -5236,13 +5253,11 @@ def differentiate( ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -5292,13 +5307,11 @@ def integrate( ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def cumulative_integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate cumulatively along the given coordinate using the trapezoidal rule. .. note:: @@ -5356,7 +5369,7 @@ def cumulative_integrate( ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - def unify_chunks(self) -> DataArray: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this DataArray. Returns @@ -5541,7 +5554,7 @@ def polyfit( ) def pad( - self: T_DataArray, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -5556,7 +5569,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Pad this array along one or more dimensions. .. warning:: @@ -5708,13 +5721,15 @@ def pad( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5804,13 +5819,15 @@ def idxmin( keep_attrs=keep_attrs, ) + @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5900,15 +5917,15 @@ def idxmax( keep_attrs=keep_attrs, ) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmin( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the minimum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -6002,15 +6019,15 @@ def argmin( else: return self._replace_maybe_drop_dims(result) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmax( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the maximum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -6351,11 +6368,13 @@ def curvefit( kwargs=kwargs, ) + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( - self: T_DataArray, + self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with duplicate dimension values removed. Parameters @@ -6437,7 +6456,7 @@ def convert_calendar( align_on: str | None = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> DataArray: + ) -> Self: """Convert the DataArray to another calendar. Only converts the individual timestamps, does not modify any data except @@ -6557,7 +6576,7 @@ def interp_calendar( self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: str = "time", - ) -> DataArray: + ) -> Self: """Interpolates the DataArray to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 48e25f7e1c7..ebd6fb6f51f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -93,7 +93,14 @@ is_duck_array, is_duck_dask_array, ) -from xarray.core.types import QuantileMethods, T_Dataset +from xarray.core.types import ( + QuantileMethods, + Self, + T_ChunkDim, + T_Chunks, + T_DataArrayOrSet, + T_Dataset, +) from xarray.core.utils import ( Default, Frozen, @@ -116,6 +123,7 @@ calculate_dimensions, ) from xarray.plot.accessor import DatasetPlotAccessor +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -124,7 +132,7 @@ from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy - from xarray.core.merge import CoercibleMapping, CoercibleValue + from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DatasetResample from xarray.core.rolling import DatasetCoarsen, DatasetRolling @@ -133,6 +141,7 @@ CoarsenBoundaryOptions, CombineAttrsOptions, CompatOptions, + DataVars, DatetimeLike, DatetimeUnitOptions, Dims, @@ -404,7 +413,7 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults -def merge_data_and_coords(data_vars, coords): +def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult: """Used in Dataset.__init__.""" if isinstance(coords, Coordinates): coords = coords.copy() @@ -666,7 +675,7 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Any, Any] | None = None, + data_vars: DataVars | None = None, coords: Mapping[Any, Any] | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: @@ -698,11 +707,11 @@ def __init__( # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping # related to https://github.com/python/mypy/issues/9319? - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: # type: ignore[override] + def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override] return super().__eq__(other) @classmethod - def load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset: + def load_store(cls, store, decoder=None) -> Self: """Create a new dataset from the contents of a backends.*DataStore object """ @@ -746,10 +755,16 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self._encoding = dict(value) - def reset_encoding(self: T_Dataset) -> T_Dataset: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Dataset without encoding on the dataset or any of its variables/coords.""" - variables = {k: v.reset_encoding() for k, v in self.variables.items()} + variables = {k: v.drop_encoding() for k, v in self.variables.items()} return self._replace(variables=variables, encoding={}) @property @@ -802,7 +817,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: } ) - def load(self: T_Dataset, **kwargs) -> T_Dataset: + def load(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return this dataset. Unlike compute, the original dataset is modified and returned. @@ -902,7 +917,7 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): return self._dask_postpersist, () - def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset: + def _dask_postcompute(self, results: Iterable[Variable]) -> Self: import dask variables = {} @@ -925,8 +940,8 @@ def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset ) def _dask_postpersist( - self: T_Dataset, dsk: Mapping, *, rename: Mapping[str, str] | None = None - ) -> T_Dataset: + self, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> Self: from dask import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.optimization import cull @@ -975,7 +990,7 @@ def _dask_postpersist( self._close, ) - def compute(self: T_Dataset, **kwargs) -> T_Dataset: + def compute(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return a new dataset. Unlike load, the original dataset is left unaltered. @@ -997,7 +1012,7 @@ def compute(self: T_Dataset, **kwargs) -> T_Dataset: new = self.copy(deep=False) return new.load(**kwargs) - def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: + def _persist_inplace(self, **kwargs) -> Self: """Persist all Dask arrays in memory""" # access .data to coerce everything to numpy or dask arrays lazy_data = { @@ -1014,7 +1029,7 @@ def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: return self - def persist(self: T_Dataset, **kwargs) -> T_Dataset: + def persist(self, **kwargs) -> Self: """Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask @@ -1037,7 +1052,7 @@ def persist(self: T_Dataset, **kwargs) -> T_Dataset: @classmethod def _construct_direct( - cls: type[T_Dataset], + cls, variables: dict[Any, Variable], coord_names: set[Hashable], dims: dict[Any, int] | None = None, @@ -1045,7 +1060,7 @@ def _construct_direct( indexes: dict[Any, Index] | None = None, encoding: dict | None = None, close: Callable[[], None] | None = None, - ) -> T_Dataset: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -1064,7 +1079,7 @@ def _construct_direct( return obj def _replace( - self: T_Dataset, + self, variables: dict[Hashable, Variable] | None = None, coord_names: set[Hashable] | None = None, dims: dict[Any, int] | None = None, @@ -1072,7 +1087,7 @@ def _replace( indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Fastpath constructor for internal use. Returns an object with optionally with replaced attributes. @@ -1114,13 +1129,13 @@ def _replace( return obj def _replace_with_new_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, attrs: dict[Hashable, Any] | None | Default = _default, indexes: dict[Hashable, Index] | None = None, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Replace variables with recalculated dimensions.""" dims = calculate_dimensions(variables) return self._replace( @@ -1128,13 +1143,13 @@ def _replace_with_new_dims( ) def _replace_vars_and_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, dims: dict[Hashable, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Deprecated version of _replace_with_new_dims(). Unlike _replace_with_new_dims(), this method always recalculates @@ -1147,13 +1162,13 @@ def _replace_vars_and_dims( ) def _overwrite_indexes( - self: T_Dataset, + self, indexes: Mapping[Hashable, Index], variables: Mapping[Hashable, Variable] | None = None, drop_variables: list[Hashable] | None = None, drop_indexes: list[Hashable] | None = None, rename_dims: Mapping[Hashable, Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Maybe replace indexes. This function may do a lot more depending on index query @@ -1220,9 +1235,7 @@ def _overwrite_indexes( else: return replaced - def copy( - self: T_Dataset, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None - ) -> T_Dataset: + def copy(self, deep: bool = False, data: DataVars | None = None) -> Self: """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. @@ -1322,11 +1335,11 @@ def copy( return self._copy(deep=deep, data=data) def _copy( - self: T_Dataset, + self, deep: bool = False, - data: Mapping[Any, ArrayLike] | None = None, + data: DataVars | None = None, memo: dict[int, Any] | None = None, - ) -> T_Dataset: + ) -> Self: if data is None: data = {} elif not utils.is_dict_like(data): @@ -1364,13 +1377,13 @@ def _copy( return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) - def __copy__(self: T_Dataset) -> T_Dataset: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) - def as_numpy(self: T_Dataset) -> T_Dataset: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. @@ -1382,7 +1395,7 @@ def as_numpy(self: T_Dataset) -> T_Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. """ @@ -1476,13 +1489,20 @@ def __bool__(self) -> bool: def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) - def __array__(self, dtype=None): - raise TypeError( - "cannot directly convert an xarray.Dataset into a " - "numpy array. Instead, create an xarray.DataArray " - "first, either with indexing on the Dataset or by " - "invoking the `to_array()` method." - ) + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None # type: ignore[var-annotated,unused-ignore] + + else: + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_array()` method." + ) @property def nbytes(self) -> int: @@ -1495,7 +1515,7 @@ def nbytes(self) -> int: return sum(v.nbytes for v in self.variables.values()) @property - def loc(self: T_Dataset) -> _LocIndexer[T_Dataset]: + def loc(self) -> _LocIndexer[Self]: """Attribute for location based indexing. Only supports __getitem__, and only when the key is a dict of the form {dim: labels}. """ @@ -1507,12 +1527,12 @@ def __getitem__(self, key: Hashable) -> DataArray: # Mapping is Iterable @overload - def __getitem__(self: T_Dataset, key: Iterable[Hashable]) -> T_Dataset: + def __getitem__(self, key: Iterable[Hashable]) -> Self: ... def __getitem__( - self: T_Dataset, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] - ) -> T_Dataset | DataArray: + self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] + ) -> Self | DataArray: """Access variables or coordinates of this dataset as a :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset. @@ -1677,7 +1697,7 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore[assignment] - def _all_compat(self, other: Dataset, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals and identical""" # some stores (e.g., scipy) do not seem to preserve order, so don't @@ -1689,7 +1709,7 @@ def compat(x: Variable, y: Variable) -> bool: self._variables, other._variables, compat=compat ) - def broadcast_equals(self, other: Dataset) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two Datasets are broadcast equal if they are equal after broadcasting all variables against each other. @@ -1756,7 +1776,7 @@ def broadcast_equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def equals(self, other: Dataset) -> bool: + def equals(self, other: Self) -> bool: """Two Datasets are equal if they have matching variables and coordinates, all of which are equal. @@ -1837,7 +1857,7 @@ def equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def identical(self, other: Dataset) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates. @@ -1950,7 +1970,7 @@ def data_vars(self) -> DataVariables: """Dictionary of DataArray objects corresponding to data variables""" return DataVariables(self) - def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset: + def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self: """Given names of one or more variables, set them as coordinates Parameters @@ -2008,10 +2028,10 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas return obj def reset_coords( - self: T_Dataset, + self, names: Dims = None, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Given names of coordinates, reset them to become variables Parameters @@ -2280,6 +2300,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -2323,6 +2344,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -2562,18 +2584,16 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: return get_chunksizes(self.variables.values()) def chunk( - self: T_Dataset, - chunks: ( - int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, - **chunks_kwargs: None | int | str | tuple[int, ...], - ) -> T_Dataset: + **chunks_kwargs: T_ChunkDim, + ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. @@ -2623,20 +2643,20 @@ def chunk( xarray.unify_chunks dask.array.from_array """ - if chunks is None and chunks_kwargs is None: + if chunks is None and not chunks_kwargs: warnings.warn( "None value for 'chunks' is deprecated. " "It will raise an error in the future. Use instead '{}'", category=FutureWarning, ) chunks = {} - - if isinstance(chunks, (Number, str, int)): - chunks = dict.fromkeys(self.dims, chunks) + chunks_mapping: Mapping[Any, Any] + if not isinstance(chunks, Mapping) and chunks is not None: + chunks_mapping = dict.fromkeys(self.dims, chunks) else: - chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") - bad_dims = chunks.keys() - self.dims.keys() + bad_dims = chunks_mapping.keys() - self.dims.keys() if bad_dims: raise ValueError( f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}" @@ -2650,7 +2670,7 @@ def chunk( k: _maybe_chunk( k, v, - chunks, + chunks_mapping, token, lock, name_prefix, @@ -2767,12 +2787,12 @@ def _get_indexers_coords_and_indexes(self, indexers): return attached_coords, attached_indexes def isel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along the specified dimension(s). @@ -2915,12 +2935,12 @@ def isel( ) def _isel_fancy( - self: T_Dataset, + self, indexers: Mapping[Any, Any], *, drop: bool, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} @@ -2956,13 +2976,13 @@ def _isel_fancy( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -3042,10 +3062,10 @@ def sel( return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -3132,10 +3152,10 @@ def head( return self.isel(indexers_slices) def tail( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the last `n` values of each array for the specified dimension(s). @@ -3223,10 +3243,10 @@ def tail( return self.isel(indexers_slices) def thin( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along every `n`-th value for the specified dimension(s) @@ -3308,10 +3328,10 @@ def thin( return self.isel(indexers_slices) def broadcast_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_DataArrayOrSet, exclude: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -3331,12 +3351,10 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_Dataset", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( - self: T_Dataset, + self, aligner: alignment.Aligner, dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable, Variable], @@ -3344,7 +3362,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> T_Dataset: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed Dataset.""" new_variables = variables.copy() @@ -3397,13 +3415,13 @@ def _reindex_callback( return reindexed def reindex_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_Xarray, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, - ) -> T_Dataset: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -3463,14 +3481,14 @@ def reindex_like( ) def reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -3679,7 +3697,7 @@ def reindex( ) def _reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -3687,7 +3705,7 @@ def _reindex( fill_value: Any = xrdtypes.NA, sparse: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Same as reindex but supports sparse option. """ @@ -3703,14 +3721,14 @@ def _reindex( ) def interp( - self: T_Dataset, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", **coords_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Interpolate a Dataset onto new coordinates Performs univariate or multivariate interpolation of a Dataset onto @@ -3983,12 +4001,12 @@ def _validate_interp_indexer(x, new_x): def interp_like( self, - other: Dataset | DataArray, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", - ) -> Dataset: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -4138,10 +4156,10 @@ def _rename_all( return variables, coord_names, dims, indexes def _rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Also used internally by DataArray so that the warning (if any) is raised at the right stack level. """ @@ -4156,6 +4174,9 @@ def _rename( create_dim_coord = False new_k = name_dict[k] + if k == new_k: + continue # Same name, nothing to do + if k in self.dims and new_k in self._coord_names: coord_dims = self._variables[name_dict[k]].dims if coord_dims == (k,): @@ -4180,10 +4201,10 @@ def _rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables, coordinates and dimensions. Parameters @@ -4210,10 +4231,10 @@ def rename( return self._rename(name_dict=name_dict, **names) def rename_dims( - self: T_Dataset, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed dimensions only. Parameters @@ -4257,10 +4278,10 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables including coordinates Parameters @@ -4297,8 +4318,8 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self: T_Dataset, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs - ) -> T_Dataset: + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs + ) -> Self: """Returns a new object with swapped dimensions. Parameters @@ -4401,14 +4422,12 @@ def swap_dims( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> Dataset: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -4598,14 +4617,12 @@ def expand_dims( variables, coord_names=coord_names, indexes=indexes ) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> Dataset: + ) -> Self: """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -4765,11 +4782,13 @@ def set_index( variables, coord_names=coord_names, indexes=indexes_ ) + @_deprecate_positional_args("v2023.10.0") def reset_index( - self: T_Dataset, + self, dims_or_levels: Hashable | Sequence[Hashable], + *, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -4877,11 +4896,11 @@ def drop_or_convert(var_names): ) def set_xindex( - self: T_Dataset, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_Dataset: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -4989,10 +5008,10 @@ def set_xindex( ) def reorder_levels( - self: T_Dataset, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_Dataset: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -5093,12 +5112,12 @@ def _get_stack_index( return stack_index, stack_coords def _stack_once( - self: T_Dataset, + self, dims: Sequence[Hashable | ellipsis], new_dim: Hashable, index_cls: type[Index], create_index: bool | None = True, - ) -> T_Dataset: + ) -> Self: if dims == ...: raise ValueError("Please use [...] for dims, rather than just ...") if ... in dims: @@ -5152,12 +5171,12 @@ def _stack_once( ) def stack( - self: T_Dataset, + self, dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable | ellipsis], - ) -> T_Dataset: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -5312,12 +5331,12 @@ def stack_dataarray(da): return data_array def _unstack_once( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -5352,12 +5371,12 @@ def _unstack_once( ) def _unstack_full_reindex( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -5402,12 +5421,14 @@ def _unstack_full_reindex( variables, coord_names=coord_names, indexes=indexes ) + @_deprecate_positional_args("v2023.10.0") def unstack( - self: T_Dataset, + self, dim: Dims = None, + *, fill_value: Any = xrdtypes.NA, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -5504,7 +5525,7 @@ def unstack( result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse) return result - def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: + def update(self, other: CoercibleMapping) -> Self: """Update this dataset's variables with those from another dataset. Just like :py:meth:`dict.update` this is a in-place operation. @@ -5544,14 +5565,14 @@ def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: return self._replace(inplace=True, **merge_result._asdict()) def merge( - self: T_Dataset, + self, other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), compat: CompatOptions = "no_conflicts", join: JoinOptions = "outer", fill_value: Any = xrdtypes.NA, combine_attrs: CombineAttrsOptions = "override", - ) -> T_Dataset: + ) -> Self: """Merge the arrays of two datasets into a single dataset. This method generally does not allow for overriding data, with the @@ -5655,11 +5676,11 @@ def _assert_all_in_dataset( ) def drop_vars( - self: T_Dataset, + self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop variables from this dataset. Parameters @@ -5801,11 +5822,11 @@ def drop_vars( ) def drop_indexes( - self: T_Dataset, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -5857,13 +5878,13 @@ def drop_indexes( return self._replace(variables=variables, indexes=indexes) def drop( - self: T_Dataset, + self, labels=None, dim=None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_Dataset: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -5913,8 +5934,8 @@ def drop( return self.drop_sel(labels, errors=errors) def drop_sel( - self: T_Dataset, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs - ) -> T_Dataset: + self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs + ) -> Self: """Drop index labels from this dataset. Parameters @@ -5983,7 +6004,7 @@ def drop_sel( ds = ds.loc[{dim: new_index}] return ds - def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: + def drop_isel(self, indexers=None, **indexers_kwargs) -> Self: """Drop index positions from this Dataset. Parameters @@ -6049,11 +6070,11 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: return ds def drop_dims( - self: T_Dataset, + self, drop_dims: str | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop dimensions and associated variables from this dataset. Parameters @@ -6090,10 +6111,10 @@ def drop_dims( return self.drop_vars(drop_vars) def transpose( - self: T_Dataset, + self, *dims: Hashable, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -6145,13 +6166,15 @@ def transpose( ds._variables[name] = var.transpose(*var_dims) return ds + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_Dataset, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, subset: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with dropped labels for missing values along the provided dimension. @@ -6273,7 +6296,7 @@ def dropna( return self.isel({dim: mask}) - def fillna(self: T_Dataset, value: Any) -> T_Dataset: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -6354,7 +6377,7 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset: return out def interpolate_na( - self: T_Dataset, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -6363,7 +6386,7 @@ def interpolate_na( int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta ) = None, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -6493,7 +6516,7 @@ def interpolate_na( ) return new - def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -6557,7 +6580,7 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -6622,7 +6645,7 @@ def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: + def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -6642,7 +6665,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: return out def reduce( - self: T_Dataset, + self, func: Callable, dim: Dims = None, *, @@ -6650,7 +6673,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Reduce this dataset by applying `func` along some dimension(s). Parameters @@ -6775,12 +6798,12 @@ def reduce( ) def map( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Apply a function to each data variable in this dataset Parameters @@ -6835,12 +6858,12 @@ def map( return type(self)(variables, attrs=attrs) def apply( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Backward compatible implementation of ``map`` @@ -6856,10 +6879,10 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self: T_Dataset, + self, variables: Mapping[Any, Any] | None = None, **variables_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -7164,9 +7187,7 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, data) @classmethod - def from_dataframe( - cls: type[T_Dataset], dataframe: pd.DataFrame, sparse: bool = False - ) -> T_Dataset: + def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: """Convert a pandas.DataFrame into an xarray.Dataset Each column will be converted into an independent variable in the @@ -7380,7 +7401,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: + def from_dict(cls, d: Mapping[Any, Any]) -> Self: """Convert a dictionary into an xarray.Dataset. Parameters @@ -7470,7 +7491,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: return obj - def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: + def _unary_op(self, f, *args, **kwargs) -> Self: variables = {} keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: @@ -7493,7 +7514,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment] + self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) keep_attrs = _get_keep_attrs(default=False) @@ -7501,7 +7522,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: ds.attrs = self.attrs return ds - def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset: + def _inplace_binary_op(self, other, f) -> Self: from xarray.core.dataarray import DataArray from xarray.core.groupby import GroupBy @@ -7575,12 +7596,14 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_Dataset, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_Dataset: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -7663,11 +7686,11 @@ def diff( return difference def shift( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = xrdtypes.NA, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -7734,11 +7757,11 @@ def shift( return self._replace(variables) def roll( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Roll this dataset by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -7820,10 +7843,13 @@ def roll( return self._replace(variables, indexes=indexes) def sortby( - self: T_Dataset, - variables: Hashable | DataArray | list[Hashable | DataArray], + self, + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]], ascending: bool = True, - ) -> T_Dataset: + ) -> Self: """ Sort object by labels or values (along an axis). @@ -7843,9 +7869,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or list of hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords/data_vars whose values are used to sort the dataset. + kariables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -7871,8 +7898,7 @@ def sortby( ... }, ... coords={"x": ["b", "a"], "y": [1, 0]}, ... ) - >>> ds = ds.sortby("x") - >>> ds + >>> ds.sortby("x") Dimensions: (x: 2, y: 2) Coordinates: @@ -7881,17 +7907,28 @@ def sortby( Data variables: A (x, y) int64 3 4 1 2 B (x, y) int64 7 8 5 6 + >>> ds.sortby(lambda x: -x["y"]) + + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) T_Dataset: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements for each variable @@ -8083,12 +8122,14 @@ def quantile( ) return new.assign_coords(quantile=q) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_Dataset, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -8142,11 +8183,11 @@ def rank( return self._replace(variables, coord_names, attrs=attrs) def differentiate( - self: T_Dataset, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions | None = None, - ) -> T_Dataset: + ) -> Self: """ Differentiate with the second order accurate central differences. @@ -8214,10 +8255,10 @@ def differentiate( return self._replace(variables) def integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -8333,10 +8374,10 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): ) def cumulative_integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -8408,7 +8449,7 @@ def cumulative_integrate( return result @property - def real(self: T_Dataset) -> T_Dataset: + def real(self) -> Self: """ The real part of each data variable. @@ -8419,7 +8460,7 @@ def real(self: T_Dataset) -> T_Dataset: return self.map(lambda x: x.real, keep_attrs=True) @property - def imag(self: T_Dataset) -> T_Dataset: + def imag(self) -> Self: """ The imaginary part of each data variable. @@ -8431,7 +8472,7 @@ def imag(self: T_Dataset) -> T_Dataset: plot = utils.UncachedAccessor(DatasetPlotAccessor) - def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: + def filter_by_attrs(self, **kwargs) -> Self: """Returns a ``Dataset`` with variables that match specific conditions. Can pass in ``key=value`` or ``key=callable``. A Dataset is returned @@ -8526,7 +8567,7 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: selection.append(var_name) return self[selection] - def unify_chunks(self: T_Dataset) -> T_Dataset: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this Dataset. Returns @@ -8648,7 +8689,7 @@ def map_blocks( return map_blocks(func, self, args, kwargs, template) def polyfit( - self: T_Dataset, + self, dim: Hashable, deg: int, skipna: bool | None = None, @@ -8656,7 +8697,7 @@ def polyfit( w: Hashable | Any = None, full: bool = False, cov: bool | Literal["unscaled"] = False, - ) -> T_Dataset: + ) -> Self: """ Least squares polynomial fit. @@ -8844,7 +8885,7 @@ def polyfit( return type(self)(data_vars=variables, attrs=self.attrs.copy()) def pad( - self: T_Dataset, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -8858,7 +8899,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Pad this dataset along one or more dimensions. .. warning:: @@ -9029,13 +9070,15 @@ def pad( attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) + @_deprecate_positional_args("v2023.10.0") def idxmin( - self: T_Dataset, + self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -9126,13 +9169,15 @@ def idxmin( ) ) + @_deprecate_positional_args("v2023.10.0") def idxmax( - self: T_Dataset, + self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -9223,7 +9268,7 @@ def idxmax( ) ) - def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmin(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -9326,7 +9371,7 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -9420,13 +9465,13 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: ) def query( - self: T_Dataset, + self, queries: Mapping[Any, Any] | None = None, parser: QueryParserOptions = "pandas", engine: QueryEngineOptions = None, missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Return a new dataset with each array indexed along the specified dimension(s), where the indexers are given as strings containing Python expressions to be evaluated against the data variables in the @@ -9516,7 +9561,7 @@ def query( return self.isel(indexers, missing_dims=missing_dims) def curvefit( - self: T_Dataset, + self, coords: str | DataArray | Iterable[str | DataArray], func: Callable[..., Any], reduce_dims: Dims = None, @@ -9526,7 +9571,7 @@ def curvefit( param_names: Sequence[str] | None = None, errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, - ) -> T_Dataset: + ) -> Self: """ Curve fitting optimization for arbitrary functions. @@ -9749,11 +9794,13 @@ def _wrapper(Y, *args, **kwargs): return result + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( - self: T_Dataset, + self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", - ) -> T_Dataset: + ) -> Self: """Returns a new Dataset with duplicate dimension values removed. Parameters @@ -9793,13 +9840,13 @@ def drop_duplicates( return self.isel(indexes) def convert_calendar( - self: T_Dataset, + self, calendar: CFCalendar, dim: Hashable = "time", align_on: Literal["date", "year", None] = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Convert the Dataset to another calendar. Only converts the individual timestamps, does not modify any data except @@ -9916,10 +9963,10 @@ def convert_calendar( ) def interp_calendar( - self: T_Dataset, + self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: Hashable = "time", - ) -> T_Dataset: + ) -> Self: """Interpolates the Dataset to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 0762fa03112..ccf84146819 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from typing import Any import numpy as np @@ -44,7 +45,7 @@ def __eq__(self, other): ) -def maybe_promote(dtype): +def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: """Simpler equivalent of pandas.core.common._maybe_promote Parameters @@ -57,27 +58,33 @@ def maybe_promote(dtype): fill_value : Valid missing value for the promoted dtype. """ # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any if np.issubdtype(dtype, np.floating): + dtype_ = dtype fill_value = np.nan elif np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") + dtype_ = dtype elif np.issubdtype(dtype, np.integer): - dtype = np.float32 if dtype.itemsize <= 2 else np.float64 + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype fill_value = np.datetime64("NaT") else: - dtype = object + dtype_ = object fill_value = np.nan - dtype = np.dtype(dtype) - fill_value = dtype.type(fill_value) - return dtype, fill_value + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4f245e59f73..078aab0ed63 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -337,6 +337,10 @@ def reshape(array, shape): return xp.reshape(array, shape) +def ravel(array): + return reshape(array, (-1,)) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs): values = asarray(values) if coerce_strings and values.dtype.kind in "SU": - values = values.astype(object) + values = astype(values, object) func = None if skipna or (skipna is None and values.dtype.kind in "cfO"): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 86e96ca2deb..09b7fafc7bf 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -178,6 +178,8 @@ def format_item(x, timedelta_format=None, quote_strings=True): if isinstance(x, (np.timedelta64, timedelta)): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): + if hasattr(x, "dtype"): + x = x.item() return repr(x) if quote_strings else x elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating): return f"{x.item():.4}" diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9894a4a4daf..8ed7148e2a1 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -43,6 +43,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -699,7 +700,7 @@ class GroupBy(Generic[T_Xarray]): _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None - _sizes: Frozen[Hashable, int] | None + _sizes: Mapping[Hashable, int] | None def __init__( self, @@ -746,7 +747,7 @@ def __init__( self._sizes = None @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -1092,10 +1093,12 @@ def fillna(self, value: Any) -> T_Xarray: """ return ops.fillna(self, value) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9972896d6df..1697762f7ae 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -24,7 +24,7 @@ ) if TYPE_CHECKING: - from xarray.core.types import ErrorOptions, JoinOptions, T_Index + from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -60,11 +60,11 @@ class Index: @classmethod def from_variables( - cls: type[T_Index], + cls, variables: Mapping[Any, Variable], *, options: Mapping[str, Any], - ) -> T_Index: + ) -> Self: """Create a new index object from one or more coordinate variables. This factory method must be implemented in all subclasses of Index. @@ -88,11 +88,11 @@ def from_variables( @classmethod def concat( - cls: type[T_Index], - indexes: Sequence[T_Index], + cls, + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> T_Index: + ) -> Self: """Create a new index by concatenating one or more indexes of the same type. @@ -120,9 +120,7 @@ def concat( raise NotImplementedError() @classmethod - def stack( - cls: type[T_Index], variables: Mapping[Any, Variable], dim: Hashable - ) -> T_Index: + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self: """Create a new index by stacking coordinate variables into a single new dimension. @@ -208,8 +206,8 @@ def to_pandas_index(self) -> pd.Index: raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def isel( - self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable] - ) -> T_Index | None: + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: """Maybe returns a new index from the current index itself indexed by positional indexers. @@ -264,7 +262,7 @@ def sel(self, labels: dict[Any, Any]) -> IndexSelResult: """ raise NotImplementedError(f"{self!r} doesn't support label-based selection") - def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: + def join(self, other: Self, how: JoinOptions = "inner") -> Self: """Return a new index from the combination of this index with another index of the same type. @@ -286,7 +284,7 @@ def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: f"{self!r} doesn't support alignment with inner/outer join method" ) - def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + def reindex_like(self, other: Self) -> dict[Hashable, Any]: """Query the index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -304,7 +302,7 @@ def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self: T_Index, other: T_Index) -> bool: + def equals(self, other: Self) -> bool: """Compare this index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -321,7 +319,7 @@ def equals(self: T_Index, other: T_Index) -> bool: """ raise NotImplementedError() - def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: + def roll(self, shifts: Mapping[Any, int]) -> Self | None: """Roll this index by an offset along one or more dimensions. This method can be re-implemented in subclasses of Index, e.g., when the @@ -347,10 +345,10 @@ def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: return None def rename( - self: T_Index, + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable], - ) -> T_Index: + ) -> Self: """Maybe update the index with new coordinate and dimension names. This method should be re-implemented in subclasses of Index if it has @@ -377,7 +375,7 @@ def rename( """ return self - def copy(self: T_Index, deep: bool = True) -> T_Index: + def copy(self, deep: bool = True) -> Self: """Return a (deep) copy of this index. Implementation in subclasses of Index is optional. The base class @@ -396,15 +394,13 @@ def copy(self: T_Index, deep: bool = True) -> T_Index: """ return self._copy(deep=deep) - def __copy__(self: T_Index) -> T_Index: + def __copy__(self) -> Self: return self.copy(deep=False) def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: return self._copy(deep=True, memo=memo) - def _copy( - self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None - ) -> T_Index: + def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self: cls = self.__class__ copied = cls.__new__(cls) if deep: @@ -414,7 +410,7 @@ def _copy( copied.__dict__.update(self.__dict__) return copied - def __getitem__(self: T_Index, indexer: Any) -> T_Index: + def __getitem__(self, indexer: Any) -> Self: raise NotImplementedError() def _repr_inline_(self, max_width): @@ -674,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @classmethod def concat( cls, - indexes: Sequence[PandasIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -800,7 +796,11 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + def join( + self, + other: Self, + how: str = "inner", + ) -> Self: if how == "outer": index = self.index.union(other.index) else: @@ -811,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( - self, other: PandasIndex, method=None, tolerance=None + self, other: Self, method=None, tolerance=None ) -> dict[Hashable, Any]: if not self.index.is_unique: raise ValueError( @@ -963,12 +963,12 @@ def from_variables( return obj @classmethod - def concat( # type: ignore[override] + def concat( cls, - indexes: Sequence[PandasMultiIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasMultiIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -1602,7 +1602,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: return Indexes(indexes, self._variables, index_type=pd.Index) def copy_indexes( - self, deep: bool = True, memo: dict[int, Any] | None = None + self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: """Return a new dictionary with copies of indexes, preserving unique indexes. @@ -1619,6 +1619,7 @@ def copy_indexes( new_indexes = {} new_index_vars = {} + idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 3475db4a010..a8e54ad1231 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -474,10 +474,11 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - out = [] + out: list[DatasetLike] = [] for obj in objects: + variables: DatasetLike if isinstance(obj, (Dataset, Coordinates)): - variables: DatasetLike = obj + variables = obj else: variables = {} if isinstance(obj, PANDAS_TYPES): @@ -491,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], + objects: Sequence[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals", ) -> dict[Hashable, MergeElement]: @@ -503,7 +504,7 @@ def _get_priority_vars_and_indexes( Parameters ---------- - objects : list of dict-like of Variable + objects : sequence of dict-like of Variable Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 3b8ddfe032d..fc7240139aa 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -4,7 +4,7 @@ import numpy as np -from xarray.core import dtypes, nputils, utils +from xarray.core import dtypes, duck_array_ops, nputils, utils from xarray.core.duck_array_ops import ( astype, count, @@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1): xarray version of pandas.core.nanops._maybe_null_out """ if axis is not None and getattr(result, "ndim", False): - null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 + null_mask = ( + np.take(mask.shape, axis).prod() + - duck_array_ops.sum(mask, axis) + - min_count + ) < 0 dtype, fill_value = dtypes.maybe_promote(result.dtype) result = where(null_mask, fill_value, astype(result, dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: - null_mask = mask.size - mask.sum() + null_mask = mask.size - duck_array_ops.sum(mask) result = where(null_mask < min_count, np.nan, result) return result diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 07c3c606bf2..949576b4ee8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -443,7 +443,7 @@ def subset_dataset_to_block( for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index d49cb6e13a4..b85092982e3 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int def __init__( self, @@ -91,8 +96,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -100,7 +105,7 @@ def __init__( self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj: T_Xarray = obj + self.obj = obj missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) if missing_dims: @@ -785,11 +790,14 @@ def construct( if not keep_attrs: dataset[key].attrs = {} + # Need to stride coords as well. TODO: is there a better way? + coords = self.obj.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + ).coords + attrs = self.obj.attrs if keep_attrs else {} - return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( - {d: slice(None, None, s) for d, s in zip(self.dim, strides)} - ) + return Dataset(dataset, coords=coords, attrs=attrs) class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): @@ -811,6 +819,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -852,12 +864,15 @@ def __init__( f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index bd30c634aae..cb77358869c 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -9,10 +9,15 @@ from xarray.core.options import _get_keep_attrs from xarray.core.pdcompat import count_not_none from xarray.core.pycompat import is_duck_dask_array -from xarray.core.types import T_DataWithCoords +from xarray.core.types import T_DataWithCoords, T_DuckArray -def _get_alpha(com=None, span=None, halflife=None, alpha=None): +def _get_alpha( + com: float | None = None, + span: float | None = None, + halflife: float | None = None, + alpha: float | None = None, +) -> float: # pandas defines in terms of com (converting to alpha in the algo) # so use its function to get a com and then convert to alpha @@ -20,7 +25,7 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None): return 1 / (1 + com) -def move_exp_nanmean(array, *, axis, alpha): +def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray: if is_duck_dask_array(array): raise TypeError("rolling_exp is not currently support for dask-like arrays") import numbagg @@ -32,7 +37,7 @@ def move_exp_nanmean(array, *, axis, alpha): return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) -def move_exp_nansum(array, *, axis, alpha): +def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray: if is_duck_dask_array(array): raise TypeError("rolling_exp is not currently supported for dask-like arrays") import numbagg @@ -40,7 +45,12 @@ def move_exp_nansum(array, *, axis, alpha): return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha) -def _get_center_of_mass(comass, span, halflife, alpha): +def _get_center_of_mass( + comass: float | None, + span: float | None, + halflife: float | None, + alpha: float | None, +) -> float: """ Vendored from pandas.core.window.common._get_center_of_mass @@ -137,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: @@ -173,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) diff --git a/xarray/core/types.py b/xarray/core/types.py index f80c2c52cd7..2af9591d22a 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -19,9 +19,9 @@ try: if sys.version_info >= (3, 11): - from typing import Self + from typing import Self, TypeAlias else: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias except ImportError: if TYPE_CHECKING: raise @@ -38,7 +38,6 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupby import DataArrayGroupBy, GroupBy from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen from xarray.core.variable import Variable @@ -107,7 +106,7 @@ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: ... @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: ... @property @@ -146,6 +145,8 @@ def copy( ... +T_Alignable = TypeVar("T_Alignable", bound="Alignable") + T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint") T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") @@ -154,29 +155,46 @@ def copy( T_Array = TypeVar("T_Array", bound="AbstractArray") T_Index = TypeVar("T_Index", bound="Index") +# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used +# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of +# the same concrete type, either "DataArray" or "Dataset". This is generally preferred +# over `T_DataArrayOrSet`, given the type system can determine the exact type. +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or +# "Dataset". Use it for functions that might return either type, but where the exact +# type cannot be determined statically using the type system. T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"]) -# Maybe we rename this to T_Data or something less Fortran-y? -T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") +# For working directly with `DataWithCoords`. It will only allow using methods defined +# on `DataWithCoords`. T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") -T_Alignable = TypeVar("T_Alignable", bound="Alignable") + # Temporary placeholder for indicating an array api compliant type. # hopefully in the future we can narrow this down more: T_DuckArray = TypeVar("T_DuckArray", bound=Any) ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] -DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] -DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] -GroupByIncompatible = Union["Variable", "GroupBy"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] Dims = Union[str, Iterable[Hashable], "ellipsis", None] OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] -T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[ + T_ChunkDim, Mapping[Any, T_ChunkDim], tuple[T_ChunkDim, ...] +] T_NormalizedChunks = tuple[tuple[int, ...], ...] +DataVars = Mapping[Any, Any] + + ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f459f044751..3f399dcac7c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -26,10 +26,7 @@ as_indexable, ) from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.parallelcompat import ( - get_chunked_array_type, - guess_chunkmanager, -) +from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager from xarray.core.pycompat import ( integer_types, is_0d_dask_array, @@ -38,8 +35,6 @@ to_numpy, ) from xarray.core.utils import ( - Frozen, - NdimSizeLenMixin, OrderedSet, _default, decode_numpy_dict_values, @@ -50,6 +45,7 @@ is_duck_array, maybe_coerce_to_str, ) +from xarray.namedarray.core import NamedArray NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, @@ -66,8 +62,8 @@ PadModeOptions, PadReflectOptions, QuantileMethods, + Self, T_DuckArray, - T_Variable, ) NON_NANOSECOND_WARNING = ( @@ -268,7 +264,7 @@ def as_compatible_data( data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)): + if isinstance(data, (pd.Series, pd.DataFrame)): data = data.values if isinstance(data, np.ma.MaskedArray): @@ -315,7 +311,7 @@ def _as_array_or_item(data): return data -class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): +class Variable(NamedArray, AbstractArray, VariableArithmetic): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -365,51 +361,14 @@ def __init__( Well-behaved code to serialize a Variable should ignore unrecognized encoding items. """ - self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) - self._dims = self._parse_dimensions(dims) - self._attrs: dict[Any, Any] | None = None + super().__init__( + dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs + ) + self._encoding = None - if attrs is not None: - self.attrs = attrs if encoding is not None: self.encoding = encoding - @property - def dtype(self) -> np.dtype: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def shape(self) -> tuple[int, ...]: - """ - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape - - @property - def nbytes(self) -> int: - """ - Total bytes consumed by the elements of the data array. - - If the underlying data array does not include ``nbytes``, estimates - the bytes consumed based on the ``size`` and ``dtype``. - """ - if hasattr(self._data, "nbytes"): - return self._data.nbytes - else: - return self.size * self.dtype.itemsize - @property def _in_memory(self): return isinstance( @@ -420,7 +379,7 @@ def _in_memory(self): ) @property - def data(self: T_Variable): + def data(self): """ The Variable's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. @@ -439,17 +398,13 @@ def data(self: T_Variable): return self.values @data.setter - def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None: + def data(self, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) - if data.shape != self.shape: # type: ignore[attr-defined] - raise ValueError( - f"replacement data must match the Variable's shape. " - f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined] - ) + self._check_shape(data) self._data = data def astype( - self: T_Variable, + self, dtype, *, order=None, @@ -457,7 +412,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_Variable: + ) -> Self: """ Copy of the Variable object, with data cast to a specified type. @@ -571,41 +526,6 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def __dask_tokenize__(self): - # Use v.data, instead of v._data, in order to cope with the wrappers - # around NetCDF and the like - from dask.base import normalize_token - - return normalize_token((type(self), self._dims, self.data, self.attrs)) - - def __dask_graph__(self): - if is_duck_dask_array(self._data): - return self._data.__dask_graph__() - else: - return None - - def __dask_keys__(self): - return self._data.__dask_keys__() - - def __dask_layers__(self): - return self._data.__dask_layers__() - - @property - def __dask_optimize__(self): - return self._data.__dask_optimize__ - - @property - def __dask_scheduler__(self): - return self._data.__dask_scheduler__ - - def __dask_postcompute__(self): - array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func,) + array_args - - def __dask_postpersist__(self): - array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func,) + array_args - def _dask_finalize(self, results, array_func, *args, **kwargs): data = array_func(results, *args, **kwargs) return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @@ -667,27 +587,6 @@ def to_dict( return item - @property - def dims(self) -> tuple[Hashable, ...]: - """Tuple of dimension names with which this variable is associated.""" - return self._dims - - @dims.setter - def dims(self, value: str | Iterable[Hashable]) -> None: - self._dims = self._parse_dimensions(value) - - def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: - if isinstance(dims, str): - dims = (dims,) - else: - dims = tuple(dims) - if len(dims) != self.ndim: - raise ValueError( - f"dimensions {dims} must have the same length as the " - f"number of data dimensions, ndim={self.ndim}" - ) - return dims - def _item_key_to_tuple(self, key): if utils.is_dict_like(key): return tuple(key.get(dim, slice(None)) for dim in self.dims) @@ -820,13 +719,6 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None - def _nonzero(self): - """Equivalent numpy's nonzero but returns a tuple of Variables.""" - # TODO we should replace dask's native nonzero - # after https://github.com/dask/dask/issues/1076 is implemented. - nonzeros = np.nonzero(self.data) - return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims)) - def _broadcast_indexes_vectorized(self, key): variables = [] out_dims_set = OrderedSet() @@ -883,7 +775,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: T_Variable, key) -> T_Variable: + def __getitem__(self, key) -> Self: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -902,7 +794,7 @@ def __getitem__(self: T_Variable, key) -> T_Variable: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: + def _finalize_indexing_result(self, dims, data) -> Self: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -976,17 +868,6 @@ def __setitem__(self, key, value): indexable = as_indexable(self._data) indexable[index_tuple] = value - @property - def attrs(self) -> dict[Any, Any]: - """Dictionary of local attributes on this variable.""" - if self._attrs is None: - self._attrs = {} - return self._attrs - - @attrs.setter - def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) - @property def encoding(self) -> dict[Any, Any]: """Dictionary of encodings on this variable.""" @@ -1001,76 +882,22 @@ def encoding(self, value): except ValueError: raise ValueError("encoding must be castable to a dictionary") - def reset_encoding(self: T_Variable) -> T_Variable: + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Variable without encoding.""" return self._replace(encoding={}) - def copy( - self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None - ) -> T_Variable: - """Returns a copy of this object. - - If `deep=True`, the data array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. - - Use `data` to create a new object with the same structure as - original but entirely new data. - - Parameters - ---------- - deep : bool, default: True - Whether the data array is loaded into memory and copied onto - the new object. Default is True. - data : array_like, optional - Data to use in the new object. Must have same shape as original. - When `data` is used, `deep` is ignored. - - Returns - ------- - object : Variable - New object with dimensions, attributes, encodings, and optionally - data copied from original. - - Examples - -------- - Shallow copy versus deep copy - - >>> var = xr.Variable(data=[1, 2, 3], dims="x") - >>> var.copy() - - array([1, 2, 3]) - >>> var_0 = var.copy(deep=False) - >>> var_0[0] = 7 - >>> var_0 - - array([7, 2, 3]) - >>> var - - array([7, 2, 3]) - - Changing the data using the ``data`` argument maintains the - structure of the original object, but with the new data. Original - object is unaffected. - - >>> var.copy(data=[0.1, 0.2, 0.3]) - - array([0.1, 0.2, 0.3]) - >>> var - - array([7, 2, 3]) - - See Also - -------- - pandas.DataFrame.copy - """ - return self._copy(deep=deep, data=data) - def _copy( - self: T_Variable, + self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None, memo: dict[int, Any] | None = None, - ) -> T_Variable: + ) -> Self: if data is None: data_old = self._data @@ -1099,71 +926,23 @@ def _copy( return self._replace(data=ndata, attrs=attrs, encoding=encoding) def _replace( - self: T_Variable, + self, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> T_Variable: + ) -> Self: if dims is _default: dims = copy.copy(self._dims) if data is _default: data = copy.copy(self.data) if attrs is _default: attrs = copy.copy(self._attrs) + if encoding is _default: encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self: T_Variable) -> T_Variable: - return self._copy(deep=False) - - def __deepcopy__( - self: T_Variable, memo: dict[int, Any] | None = None - ) -> T_Variable: - return self._copy(deep=True, memo=memo) - - # mutable objects should not be hashable - # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore[assignment] - - @property - def chunks(self) -> tuple[tuple[int, ...], ...] | None: - """ - Tuple of block lengths for this dataarray's data, in order of dimensions, or None if - the underlying data is not a dask array. - - See Also - -------- - Variable.chunk - Variable.chunksizes - xarray.unify_chunks - """ - return getattr(self._data, "chunks", None) - - @property - def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: - """ - Mapping from dimension names to block lengths for this variable's data, or None if - the underlying data is not a dask array. - Cannot be modified directly, but can be modified by calling .chunk(). - - Differs from variable.chunks because it returns a mapping of dimensions to chunk shapes - instead of a tuple of chunk shapes. - - See Also - -------- - Variable.chunk - Variable.chunks - xarray.unify_chunks - """ - if hasattr(self._data, "chunks"): - return Frozen({dim: c for dim, c in zip(self.dims, self.data.chunks)}) - else: - return {} - - _array_counter = itertools.count() - def chunk( self, chunks: ( @@ -1179,7 +958,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> Variable: + ) -> Self: """Coerce this array's data into a dask array with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1262,7 +1041,7 @@ def chunk( data_old = self._data if chunkmanager.is_chunked_array(data_old): - data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] + data_chunked = chunkmanager.rechunk(data_old, chunks) else: if isinstance(data_old, indexing.ExplicitlyIndexed): # Unambiguously handle array storage backends (like NetCDF4 and h5py) @@ -1284,7 +1063,7 @@ def chunk( data_chunked = chunkmanager.from_array( ndata, - chunks, # type: ignore[arg-type] + chunks, **_from_array_kwargs, ) @@ -1295,46 +1074,16 @@ def to_numpy(self) -> np.ndarray: # TODO an entrypoint so array libraries can choose coercion method? return to_numpy(self._data) - def as_numpy(self: T_Variable) -> T_Variable: + def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): - """ - use sparse-array as backend. - """ - import sparse - - # TODO: what to do if dask-backended? - if fill_value is dtypes.NA: - dtype, fill_value = dtypes.maybe_promote(self.dtype) - else: - dtype = dtypes.result_type(self.dtype, fill_value) - - if sparse_format is _default: - sparse_format = "coo" - try: - as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") - except AttributeError: - raise ValueError(f"{sparse_format} is not a valid sparse format") - - data = as_sparse(self.data.astype(dtype), fill_value=fill_value) - return self._replace(data=data) - - def _to_dense(self): - """ - Change backend from sparse to np.array - """ - if hasattr(self._data, "todense"): - return self._replace(data=self._data.todense()) - return self.copy(deep=False) - def isel( - self: T_Variable, + self, indexers: Mapping[Any, Any] | None = None, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Variable: + ) -> Self: """Return a new array indexed along the specified dimension(s). Parameters @@ -1621,7 +1370,7 @@ def transpose( self, *dims: Hashable | ellipsis, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> Variable: + ) -> Self: """Return a new Variable object with transposed dimensions. Parameters @@ -1666,7 +1415,7 @@ def transpose( return self._replace(dims=dims, data=data) @property - def T(self) -> Variable: + def T(self) -> Self: return self.transpose() def set_dims(self, dims, shape=None): @@ -1774,9 +1523,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full( - self, dims: Mapping[Any, int], old_dim: Hashable - ) -> Variable: + def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. @@ -1809,7 +1556,9 @@ def _unstack_once_full( new_data = reordered.data.reshape(new_shape) new_dims = reordered.dims[: len(other_dims)] + new_dim_names - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def _unstack_once( self, @@ -1817,7 +1566,7 @@ def _unstack_once( dim: Hashable, fill_value=dtypes.NA, sparse: bool = False, - ) -> Variable: + ) -> Self: """ Unstacks this variable given an index to unstack and the name of the dimension to which the index refers. @@ -2029,6 +1778,8 @@ def reduce( keep_attrs = _get_keep_attrs(default=False) attrs = self._attrs if keep_attrs else None + # We need to return `Variable` rather than the type of `self` at the moment, ref + # #8216 return Variable(dims, data, attrs=attrs) @classmethod @@ -2178,7 +1929,7 @@ def quantile( keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> Variable: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -2564,7 +2315,7 @@ def coarsen_reshape(self, windows, boundary, side): else: shape.append(variable.shape[i]) - return variable.data.reshape(shape), tuple(axes) + return duck_array_ops.reshape(variable.data, shape), tuple(axes) def isnull(self, keep_attrs: bool | None = None): """Test each value in the array for whether it is a missing value. @@ -2634,28 +2385,6 @@ def notnull(self, keep_attrs: bool | None = None): keep_attrs=keep_attrs, ) - @property - def real(self): - """ - The real part of the variable. - - See Also - -------- - numpy.ndarray.real - """ - return self._replace(data=self.data.real) - - @property - def imag(self): - """ - The imaginary part of the variable. - - See Also - -------- - numpy.ndarray.imag - """ - return self._replace(data=self.data.imag) - def __array_wrap__(self, obj, context=None): return Variable(self.dims, obj) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 82ffe684ec7..28740a99020 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -11,6 +11,7 @@ from xarray.core.computation import apply_ufunc, dot from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import Dims, T_Xarray +from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. QUANTILE_METHODS = Literal[ @@ -324,6 +325,7 @@ def _weighted_quantile( def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: """Return the interpolation parameter.""" # Note that options are not yet exposed in the public API. + h: np.ndarray if method == "linear": h = (n - 1) * q + 1 elif method == "interpolated_inverted_cdf": @@ -449,18 +451,22 @@ def _weighted_quantile_1d( def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + @_deprecate_positional_args("v2023.10.0") def sum_of_weights( self, dim: Dims = None, + *, keep_attrs: bool | None = None, ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum_of_squares( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -468,9 +474,11 @@ def sum_of_squares( self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -478,9 +486,11 @@ def sum( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def mean( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -488,9 +498,11 @@ def mean( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def var( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -498,9 +510,11 @@ def var( self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def std( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: diff --git a/xarray/namedarray/__init__.py b/xarray/namedarray/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py new file mode 100644 index 00000000000..ec3d8fa171b --- /dev/null +++ b/xarray/namedarray/core.py @@ -0,0 +1,509 @@ +from __future__ import annotations + +import copy +import math +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast + +import numpy as np + +# TODO: get rid of this after migrating this class to array API +from xarray.core import dtypes +from xarray.core.indexing import ExplicitlyIndexed +from xarray.namedarray.utils import ( + Default, + T_DuckArray, + _default, + is_chunked_duck_array, + is_duck_array, + is_duck_dask_array, + to_0d_object_array, +) + +if TYPE_CHECKING: + from xarray.namedarray.utils import Self # type: ignore[attr-defined] + + try: + from dask.typing import ( + Graph, + NestedKeys, + PostComputeCallable, + PostPersistCallable, + SchedulerGetCallable, + ) + except ImportError: + Graph: Any # type: ignore[no-redef] + NestedKeys: Any # type: ignore[no-redef] + SchedulerGetCallable: Any # type: ignore[no-redef] + PostComputeCallable: Any # type: ignore[no-redef] + PostPersistCallable: Any # type: ignore[no-redef] + + # T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]") + DimsInput = Union[str, Iterable[Hashable]] + Dims = tuple[Hashable, ...] + AttrsInput = Union[Mapping[Any, Any], None] + + +# TODO: Add tests! +def as_compatible_data( + data: T_DuckArray | np.typing.ArrayLike, fastpath: bool = False +) -> T_DuckArray: + if fastpath and getattr(data, "ndim", 0) > 0: + # can't use fastpath (yet) for scalars + return cast(T_DuckArray, data) + + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + else: + return cast(T_DuckArray, np.asarray(data)) + if is_duck_array(data): + return data + if isinstance(data, NamedArray): + return cast(T_DuckArray, data.data) + + if isinstance(data, ExplicitlyIndexed): + # TODO: better that is_duck_array(ExplicitlyIndexed) -> True + return cast(T_DuckArray, data) + + if isinstance(data, tuple): + data = to_0d_object_array(data) + + # validate whether the data is valid data types. + return cast(T_DuckArray, np.asarray(data)) + + +class NamedArray(Generic[T_DuckArray]): + + """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. + Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, + rather than axis order.""" + + __slots__ = ("_data", "_dims", "_attrs") + + _data: T_DuckArray + _dims: Dims + _attrs: dict[Any, Any] | None + + def __init__( + self, + dims: DimsInput, + data: T_DuckArray | np.typing.ArrayLike, + attrs: AttrsInput = None, + fastpath: bool = False, + ): + """ + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or np.typing.ArrayLike + The actual data that populates the array. Should match the shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + fastpath : bool, optional + A flag to indicate if certain validations should be skipped for performance reasons. + Should only be True if you are certain about the integrity of the input data. + Default is False. + + Raises + ------ + ValueError + If the `dims` length does not match the number of data dimensions (ndim). + + + """ + self._data = as_compatible_data(data, fastpath=fastpath) + self._dims = self._parse_dimensions(dims) + self._attrs = dict(attrs) if attrs else None + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self) -> int: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self) -> int: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + @property + def dtype(self) -> np.dtype[Any]: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def shape(self) -> tuple[int, ...]: + """ + + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def nbytes(self) -> int: + """ + Total bytes consumed by the elements of the data array. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. + """ + if hasattr(self._data, "nbytes"): + return self._data.nbytes # type: ignore[no-any-return] + else: + return self.size * self.dtype.itemsize + + @property + def dims(self) -> Dims: + """Tuple of dimension names with which this NamedArray is associated.""" + return self._dims + + @dims.setter + def dims(self, value: DimsInput) -> None: + self._dims = self._parse_dimensions(value) + + def _parse_dimensions(self, dims: DimsInput) -> Dims: + dims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(dims) != self.ndim: + raise ValueError( + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" + ) + return dims + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of local attributes on this NamedArray.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + def _check_shape(self, new_data: T_DuckArray) -> None: + if new_data.shape != self.shape: + raise ValueError( + f"replacement data must match the {self.__class__.__name__}'s shape. " + f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}" + ) + + @property + def data(self) -> T_DuckArray: + """ + The NamedArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + """ + + return self._data + + @data.setter + def data(self, data: T_DuckArray | np.typing.ArrayLike) -> None: + data = as_compatible_data(data) + self._check_shape(data) + self._data = data + + @property + def real(self) -> Self: + """ + The real part of the NamedArray. + + See Also + -------- + numpy.ndarray.real + """ + return self._replace(data=self.data.real) + + @property + def imag(self) -> Self: + """ + The imaginary part of the NamedArray. + + See Also + -------- + numpy.ndarray.imag + """ + return self._replace(data=self.data.imag) + + def __dask_tokenize__(self) -> Hashable: + # Use v.data, instead of v._data, in order to cope with the wrappers + # around NetCDF and the like + from dask.base import normalize_token + + s, d, a, attrs = type(self), self._dims, self.data, self.attrs + return normalize_token((s, d, a, attrs)) # type: ignore[no-any-return] + + def __dask_graph__(self) -> Graph | None: + if is_duck_dask_array(self._data): + return self._data.__dask_graph__() + else: + # TODO: Should this method just raise instead? + # raise NotImplementedError("Method requires self.data to be a dask array") + return None + + def __dask_keys__(self) -> NestedKeys: + if is_duck_dask_array(self._data): + return self._data.__dask_keys__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_layers__(self) -> Sequence[str]: + if is_duck_dask_array(self._data): + return self._data.__dask_layers__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_optimize__( + self, + ) -> Callable[..., dict[Any, Any]]: + if is_duck_dask_array(self._data): + return self._data.__dask_optimize__ # type: ignore[no-any-return] + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_scheduler__(self) -> SchedulerGetCallable: + if is_duck_dask_array(self._data): + return self._data.__dask_scheduler__ + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postcompute__( + self, + ) -> tuple[PostComputeCallable, tuple[Any, ...]]: + if is_duck_dask_array(self._data): + array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postpersist__( + self, + ) -> tuple[ + Callable[ + [Graph, PostPersistCallable[Any], Any, Any], + Self, + ], + tuple[Any, ...], + ]: + if is_duck_dask_array(self._data): + a: tuple[PostPersistCallable[Any], tuple[Any, ...]] + a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] + array_func, array_args = a + + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def _dask_finalize( + self, + results: Graph, + array_func: PostPersistCallable[Any], + *args: Any, + **kwargs: Any, + ) -> Self: + data = array_func(results, *args, **kwargs) + return type(self)(self._dims, data, attrs=self._attrs) + + @property + def chunks(self) -> tuple[tuple[int, ...], ...] | None: + """ + Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if + the underlying data is not a dask array. + + See Also + -------- + NamedArray.chunk + NamedArray.chunksizes + xarray.unify_chunks + """ + data = self._data + if is_chunked_duck_array(data): + return data.chunks + else: + return None + + @property + def chunksizes( + self, + ) -> Mapping[Any, tuple[int, ...]]: + """ + Mapping from dimension names to block lengths for this namedArray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + NamedArray.chunk + NamedArray.chunks + xarray.unify_chunks + """ + data = self._data + if is_chunked_duck_array(data): + return dict(zip(self.dims, data.chunks)) + else: + return {} + + @property + def sizes(self) -> dict[Hashable, int]: + """Ordered mapping from dimension names to lengths.""" + return dict(zip(self.dims, self.shape)) + + def _replace( + self, + dims: DimsInput | Default = _default, + data: T_DuckArray | np.typing.ArrayLike | Default = _default, + attrs: AttrsInput | Default = _default, + ) -> Self: + if dims is _default: + dims = copy.copy(self._dims) + if data is _default: + data = copy.copy(self._data) + if attrs is _default: + attrs = copy.copy(self._attrs) + return type(self)(dims, data, attrs) + + def _copy( + self, + deep: bool = True, + data: T_DuckArray | np.typing.ArrayLike | None = None, + memo: dict[int, Any] | None = None, + ) -> Self: + if data is None: + ndata = self._data + if deep: + ndata = copy.deepcopy(ndata, memo=memo) + else: + ndata = as_compatible_data(data) + self._check_shape(ndata) + + attrs = ( + copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs) + ) + + return self._replace(data=ndata, attrs=attrs) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def copy( + self, + deep: bool = True, + data: T_DuckArray | np.typing.ArrayLike | None = None, + ) -> Self: + """Returns a copy of this object. + + If `deep=True`, the data array is loaded into memory and copied onto + the new object. Dimensions, attributes and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : NamedArray + New object with dimensions, attributes, and optionally + data copied from original. + + + """ + return self._copy(deep=deep, data=data) + + def _nonzero(self) -> tuple[Self, ...]: + """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" + # TODO we should replace dask's native nonzero + # after https://github.com/dask/dask/issues/1076 is implemented. + nonzeros = np.nonzero(self.data) + return tuple(type(self)((dim,), nz) for nz, dim in zip(nonzeros, self.dims)) + + def _as_sparse( + self, + sparse_format: str | Default = _default, + fill_value: np.typing.ArrayLike | Default = _default, + ) -> Self: + """ + use sparse-array as backend. + """ + import sparse + + # TODO: what to do if dask-backended? + if fill_value is _default: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") + except AttributeError as exc: + raise ValueError(f"{sparse_format} is not a valid sparse format") from exc + + data = as_sparse(self.data.astype(dtype), fill_value=fill_value) + return self._replace(data=data) + + def _to_dense(self) -> Self: + """ + Change backend from sparse to np.array + """ + if hasattr(self._data, "todense"): + return self._replace(data=self._data.todense()) + return self.copy(deep=False) diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py new file mode 100644 index 00000000000..7a83bd17064 --- /dev/null +++ b/xarray/namedarray/dtypes.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import functools +import sys +from typing import Any, Literal + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +import numpy as np + +from xarray.namedarray import utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]: + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any + if np.issubdtype(dtype, np.floating): + dtype_ = dtype + fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + dtype_ = dtype + elif np.issubdtype(dtype, np.integer): + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype + fill_value = np.datetime64("NaT") + else: + dtype_ = object + fill_value = np.nan + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype: np.dtype[np.generic]) -> Any: + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity( + dtype: np.dtype[np.generic], max_for_int: bool = False +) -> float | complex | AlwaysGreaterThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).max if max_for_int else np.inf + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity( + dtype: np.dtype[np.generic], min_for_int: bool = False +) -> float | complex | AlwaysLessThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return -np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).min if min_for_int else -np.inf + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + +def is_datetime_like( + dtype: np.dtype[np.generic], +) -> TypeGuard[np.datetime64 | np.timedelta64]: + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype[np.generic]: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py new file mode 100644 index 00000000000..6f7658ea00b --- /dev/null +++ b/xarray/namedarray/utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import importlib +import sys +from collections.abc import Hashable +from enum import Enum +from typing import TYPE_CHECKING, Any, Final, Protocol, TypeVar + +import numpy as np + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + try: + from dask.array import Array as DaskArray + from dask.types import DaskCollection + except ImportError: + DaskArray = np.ndarray # type: ignore + DaskCollection: Any = np.ndarray # type: ignore + + +# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array +T_DType_co = TypeVar("T_DType_co", bound=np.dtype[np.generic], covariant=True) +# T_DType = TypeVar("T_DType", bound=np.dtype[np.generic]) + + +class _Array(Protocol[T_DType_co]): + @property + def dtype(self) -> T_DType_co: + ... + + @property + def shape(self) -> tuple[int, ...]: + ... + + @property + def real(self) -> Self: + ... + + @property + def imag(self) -> Self: + ... + + def astype(self, dtype: np.typing.DTypeLike) -> Self: + ... + + # TODO: numpy doesn't use any inputs: + # https://github.com/numpy/numpy/blob/v1.24.3/numpy/_typing/_array_like.py#L38 + def __array__(self) -> np.ndarray[Any, T_DType_co]: + ... + + +class _ChunkedArray(_Array[T_DType_co], Protocol[T_DType_co]): + @property + def chunks(self) -> tuple[tuple[int, ...], ...]: + ... + + +# temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more +T_DuckArray = TypeVar("T_DuckArray", bound=_Array[np.dtype[np.generic]]) +T_ChunkedArray = TypeVar("T_ChunkedArray", bound=_ChunkedArray[np.dtype[np.generic]]) + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + + +def module_available(module: str) -> bool: + """Checks whether a module is installed without importing it. + + Use this for a lightweight check and lazy imports. + + Parameters + ---------- + module : str + Name of the module. + + Returns + ------- + available : bool + Whether the module is installed. + """ + return importlib.util.find_spec(module) is not None + + +def is_dask_collection(x: object) -> TypeGuard[DaskCollection]: + if module_available("dask"): + from dask.typing import DaskCollection + + return isinstance(x, DaskCollection) + return False + + +def is_duck_array(value: object) -> TypeGuard[T_DuckArray]: + if isinstance(value, np.ndarray): + return True + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) + ) + + +def is_duck_dask_array(x: T_DuckArray) -> TypeGuard[DaskArray]: + return is_dask_collection(x) + + +def is_chunked_duck_array( + x: T_DuckArray, +) -> TypeGuard[_ChunkedArray[np.dtype[np.generic]]]: + return hasattr(x, "chunks") + + +def to_0d_object_array( + value: object, +) -> np.ndarray[Any, np.dtype[np.object_]]: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" + result = np.empty((), dtype=object) + result[()] = value + return result + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + _value: str + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other: ReprObject | Any) -> bool: + # TODO: What type can other be? ArrayLike? + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self) -> Hashable: + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) # type: ignore[no-any-return] diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 8e930d0731c..61f2014fbc3 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -1348,7 +1348,7 @@ def _plot2d(plotfunc): `seaborn color palette `_. Note: if ``cmap`` is a seaborn color palette and the plot type is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. - center : float, optional + center : float or False, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a diverging colormap. @@ -1432,7 +1432,7 @@ def newplotfunc( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1692,7 +1692,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1733,7 +1733,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1774,7 +1774,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1911,7 +1911,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1952,7 +1952,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1993,7 +1993,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2047,7 +2047,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2088,7 +2088,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2129,7 +2129,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2183,7 +2183,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2224,7 +2224,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2265,7 +2265,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2370,7 +2370,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2411,7 +2411,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2452,7 +2452,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5bd517098f1..9ec67bf47dc 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -714,9 +714,6 @@ def multiple_indexing(indexers): ] multiple_indexing(indexers5) - @pytest.mark.xfail( - reason="zarr without dask handles negative steps in slices incorrectly", - ) def test_vectorized_indexing_negative_step(self) -> None: # use dask explicitly when present open_kwargs: dict[str, Any] | None @@ -1842,8 +1839,8 @@ def test_unsorted_index_raises(self) -> None: # dask first pulls items by block. pass + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -2261,9 +2258,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: # not relevant for zarr, since we don't use EncodedStringCoder pass - # TODO: someone who understand caching figure out whether caching - # makes sense for Zarr backend - @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self) -> None: super().test_dataset_caching() @@ -2712,6 +2706,14 @@ def test_attributes(self, obj) -> None: with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."): ds.to_zarr(store_target, **self.version_kwargs) + def test_vectorized_indexing_negative_step(self) -> None: + if not has_dask: + pytest.xfail( + reason="zarr without dask handles negative steps in slices incorrectly" + ) + + super().test_vectorized_indexing_negative_step() + @requires_zarr class TestZarrDictStore(ZarrBase): @@ -3378,8 +3380,8 @@ def roundtrip( ) as ds: yield ds + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -3457,6 +3459,7 @@ def skip_if_not_engine(engine): @requires_dask @pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +@pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") def test_open_mfdataset_manyfiles( readengine, nfiles, parallel, chunks, file_cache_maxsize ): @@ -3982,7 +3985,6 @@ def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: with pytest.raises(ValueError, match="`concat_dim` has no effect"): open_mfdataset([tmp1, tmp2], concat_dim="x") - @pytest.mark.xfail(reason="mfdataset loses encoding currently.") def test_encoding_mfdataset(self) -> None: original = Dataset( { @@ -4195,7 +4197,6 @@ def test_dataarray_compute(self) -> None: assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - @pytest.mark.xfail def test_save_mfdataset_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed @@ -5125,15 +5126,17 @@ def test_open_fsspec() -> None: ds2 = open_dataset(url, engine="zarr") xr.testing.assert_equal(ds0, ds2) - # multi dataset - url = "memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) - - # multi dataset with caching - url = "simplecache::memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # open_mfdataset requires dask + if has_dask: + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) @requires_h5netcdf diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f58a6490632..1a1df6b81fe 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1135,7 +1135,6 @@ def test_to_datetimeindex_feb_29(calendar): @requires_cftime -@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263") def test_multiindex(): index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") mindex = pd.MultiIndex.from_arrays([index]) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index e345ae691ec..01d5393e289 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -6,6 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset, set_options +from xarray.core import duck_array_ops from xarray.tests import ( assert_allclose, assert_equal, @@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None: expected = xr.Dataset(attrs={"foo": "bar"}) expected["vart"] = ( ("year", "month"), - ds.vart.data.reshape((-1, 12)), + duck_array_ops.reshape(ds.vart.data, (-1, 12)), {"a": "b"}, ) expected["varx"] = ( ("x", "x_reshaped"), - ds.varx.data.reshape((-1, 5)), + duck_array_ops.reshape(ds.varx.data, (-1, 5)), {"a": "b"}, ) expected["vartx"] = ( ("x", "x_reshaped", "year", "month"), - ds.vartx.data.reshape(2, 5, 4, 12), + duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)), {"a": "b"}, ) expected["vary"] = ds.vary - expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12))) + expected.coords["time"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.time.data, (-1, 12)), + ) with raise_if_dask_computes(): actual = ds.coarsen(time=12, x=5).construct( diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 079e432b565..423e48bd155 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -30,7 +30,7 @@ from xarray.coding.variables import SerializationWarning from xarray.conventions import _update_bounds_attributes, cf_encoder from xarray.core.common import contains_cftime_datetimes -from xarray.testing import assert_allclose, assert_equal, assert_identical +from xarray.testing import assert_equal, assert_identical from xarray.tests import ( FirstElementAccessibleArray, arm_xfail, @@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = coding.times.encode_cf_datetime(times, units) numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) @@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None: ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False), ("1677-09-21T00:12:43.145225", "us", np.int64, None, False), ("1970-01-01T00:00:01.000001", "us", np.int64, None, False), + ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True), ], ) def test_roundtrip_datetime64_nanosecond_precision( @@ -1261,14 +1262,52 @@ def test_roundtrip_datetime64_nanosecond_precision_warning() -> None: ] units = "days since 1970-01-10T01:01:00" needed_units = "hours" - encoding = dict(_FillValue=20, units=units) + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) var = Variable(["time"], times, encoding=encoding) - wmsg = ( - f"Times can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - ) - with pytest.warns(UserWarning, match=wmsg): + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." + ): encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) assert_identical(var, decoded_var) @@ -1309,21 +1348,30 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: needed_units = "hours" wmsg = ( f"Timedeltas can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " + f"Serializing with units {needed_units!r} instead." ) - encoding = dict(_FillValue=20, units=units) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) var = Variable(["time"], timedelta_values, encoding=encoding) with pytest.warns(UserWarning, match=wmsg): encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) - assert_allclose(var, decoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 def test_roundtrip_float_times() -> None: + # Regression test for GitHub issue #8271 fill_value = 20.0 - times = [np.datetime64("2000-01-01 12:00:00", "ns"), np.datetime64("NaT", "ns")] + times = [ + np.datetime64("1970-01-01 00:00:00", "ns"), + np.datetime64("1970-01-01 06:00:00", "ns"), + np.datetime64("NaT", "ns"), + ] - units = "days since 2000-01-01" + units = "days since 1960-01-01" var = Variable( ["time"], times, @@ -1331,7 +1379,7 @@ def test_roundtrip_float_times() -> None: ) encoded_var = conventions.encode_cf_variable(var) - np.testing.assert_array_equal(encoded_var, np.array([0.5, 20.0])) + np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0])) assert encoded_var.attrs["units"] == units assert encoded_var.attrs["_FillValue"] == fill_value diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index b75e80db2da..87f8328e441 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1190,7 +1190,7 @@ def test_apply_dask() -> None: # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask="unknown") + apply_ufunc(identity, array, dask="unknown") # type: ignore def dask_safe_identity(x): return apply_ufunc(identity, x, dask="allowed") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 4dae7809be9..5157688b629 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -80,6 +80,28 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: assert_identical(actual, expected) +def test_decode_cf_variable_with_mismatched_coordinates() -> None: + # tests for decoding mismatched coordinates attributes + # see GH #1809 + zeros1 = np.zeros((1, 5, 3)) + orig = Dataset( + { + "XLONG": (["x", "y"], zeros1.squeeze(0), {}), + "XLAT": (["x", "y"], zeros1.squeeze(0), {}), + "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + decoded = conventions.decode_cf(orig, decode_coords=True) + assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"] + + decoded = conventions.decode_cf(orig, decode_coords=False) + assert "coordinates" not in decoded["foo"].encoding + assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["time"] + + @requires_cftime class TestEncodeCFVariable: def test_incompatible_attributes(self) -> None: @@ -246,9 +268,12 @@ def test_dataset(self) -> None: assert_identical(expected, actual) def test_invalid_coordinates(self) -> None: - # regression test for GH308 + # regression test for GH308, GH1809 original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) + decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) + assert_identical(decoded, actual) + actual = conventions.decode_cf(original, decode_coords=False) assert_identical(original, actual) def test_decode_coordinates(self) -> None: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 66bc69966d2..5eb5394d58e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -38,6 +38,7 @@ from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar +from xarray.testing import _assert_internal_invariants from xarray.tests import ( InaccessibleArray, ReturnItem, @@ -286,7 +287,7 @@ def test_encoding(self) -> None: self.dv.encoding = expected2 assert expected2 is not self.dv.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: array = self.mda encoding = {"scale_factor": 10} array.encoding = encoding @@ -295,7 +296,7 @@ def test_reset_encoding(self) -> None: assert array.encoding == encoding assert array["x"].encoding == encoding - actual = array.reset_encoding() + actual = array.drop_encoding() # did not modify in place assert array.encoding == encoding @@ -415,9 +416,6 @@ def test_constructor_invalid(self) -> None: with pytest.raises(ValueError, match=r"conflicting MultiIndex"): DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) - with pytest.raises(ValueError, match=r"matching the dimension size"): - DataArray(data, coords={"x": 0}, dims=["x", "y"]) - def test_constructor_from_self_described(self) -> None: data = [[-0.1, 21], [0, 2]] expected = DataArray( @@ -1885,6 +1883,16 @@ def test_rename_dimension_coord_warnings(self) -> None: ): da.rename(x="y") + # No operation should not raise a warning + da = xr.DataArray( + data=np.ones((2, 3)), + dims=["x", "y"], + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + da.rename(x="x") + def test_init_value(self) -> None: expected = DataArray( np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)] @@ -2719,6 +2727,14 @@ def test_where_lambda(self) -> None: actual = arr.where(lambda x: x.y < 2, drop=True) assert_identical(actual, expected) + def test_where_other_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = xr.concat( + [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y" + ) + actual = arr.where(lambda x: x.y < 2, lambda x: x + 1) + assert_identical(actual, expected) + def test_where_string(self) -> None: array = DataArray(["a", "b"]) expected = DataArray(np.array(["a", np.nan], dtype=object)) @@ -4011,7 +4027,7 @@ def test_dot(self) -> None: assert_equal(expected5, actual5) with pytest.raises(NotImplementedError): - da.dot(dm3.to_dataset(name="dm")) # type: ignore + da.dot(dm3.to_dataset(name="dm")) with pytest.raises(TypeError): da.dot(dm3.values) # type: ignore @@ -7112,3 +7128,35 @@ def test_error_on_ellipsis_without_list(self) -> None: da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) with pytest.raises(ValueError): da.stack(flat=...) # type: ignore + + +def test_nD_coord_dataarray() -> None: + # should succeed + da = DataArray( + np.ones((2, 4)), + dims=("x", "y"), + coords={ + "x": (("x", "y"), np.arange(8).reshape((2, 4))), + "y": ("y", np.arange(4)), + }, + ) + _assert_internal_invariants(da, check_default_indexes=True) + + da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))}) + da3 = DataArray(np.ones(4), dims=("z")) + + _, actual = xr.align(da, da2) + assert_identical(da2, actual) + + expected = da.drop_vars("x") + _, actual = xr.broadcast(da, da2) + assert_identical(expected, actual) + + actual, _ = xr.broadcast(da, da3) + expected = da.expand_dims(z=4, axis=-1) + assert_identical(actual, expected) + + da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"]) + _assert_internal_invariants(da4, check_default_indexes=True) + assert "x" not in da4.xindexes + assert "x" in da4.coords diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c832663ecff..687aae8f1dc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -411,10 +411,14 @@ def test_repr_nep18(self) -> None: class Array: def __init__(self): self.shape = (2,) + self.ndim = 1 self.dtype = np.dtype(np.float64) def __array_function__(self, *args, **kwargs): - pass + return NotImplemented + + def __array_ufunc__(self, *args, **kwargs): + return NotImplemented def __repr__(self): return "Custom\nArray" @@ -2328,9 +2332,9 @@ def test_align(self) -> None: assert np.isnan(left2["var3"][-2:]).all() with pytest.raises(ValueError, match=r"invalid value for join"): - align(left, right, join="foobar") # type: ignore[arg-type] + align(left, right, join="foobar") # type: ignore[call-overload] with pytest.raises(TypeError): - align(left, right, foo="bar") # type: ignore[call-arg] + align(left, right, foo="bar") # type: ignore[call-overload] def test_align_exact(self) -> None: left = xr.Dataset(coords={"x": [0, 1]}) @@ -2955,7 +2959,7 @@ def test_copy_with_data_errors(self) -> None: with pytest.raises(ValueError, match=r"contain all variables in original"): orig.copy(data={"var1": new_var1}) - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: orig = create_test_data() vencoding = {"scale_factor": 10} orig.encoding = {"foo": "bar"} @@ -2963,7 +2967,7 @@ def test_reset_encoding(self) -> None: for k, v in orig.variables.items(): orig[k].encoding = vencoding - actual = orig.reset_encoding() + actual = orig.drop_encoding() assert actual.encoding == {} for k, v in actual.variables.items(): assert v.encoding == {} @@ -3028,8 +3032,7 @@ def test_rename_old_name(self) -> None: def test_rename_same_name(self) -> None: data = create_test_data() newnames = {"var1": "var1", "dim2": "dim2"} - with pytest.warns(UserWarning, match="does not create an index anymore"): - renamed = data.rename(newnames) + renamed = data.rename(newnames) assert_identical(renamed, data) def test_rename_dims(self) -> None: @@ -3099,6 +3102,15 @@ def test_rename_dimension_coord_warnings(self) -> None: ): ds.rename(x="y") + # No operation should not raise a warning + ds = Dataset( + data_vars={"data": (("x", "y"), np.ones((2, 3)))}, + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ds.rename(x="x") + def test_rename_multiindex(self) -> None: midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) midx_coords = Coordinates.from_pandas_multiindex(midx, "x") @@ -4116,7 +4128,8 @@ def test_setitem(self) -> None: data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values data5 = data4.astype(str) data5["var4"] = data4["var1"] - err_msg = "could not convert string to float: 'a'" + # convert to `np.str_('a')` once `numpy<2.0` has been dropped + err_msg = "could not convert string to float: .*'a'.*" with pytest.raises(ValueError, match=err_msg): data5[{"dim2": 1}] = "a" @@ -5046,9 +5059,9 @@ def test_dropna(self) -> None: ): ds.dropna("foo") with pytest.raises(ValueError, match=r"invalid how"): - ds.dropna("a", how="somehow") # type: ignore + ds.dropna("a", how="somehow") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"must specify how or thresh"): - ds.dropna("a", how=None) # type: ignore + ds.dropna("a", how=None) # type: ignore[arg-type] def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) diff --git a/xarray/tests/test_deprecation_helpers.py b/xarray/tests/test_deprecation_helpers.py index 35128829073..f21c8097060 100644 --- a/xarray/tests/test_deprecation_helpers.py +++ b/xarray/tests/test_deprecation_helpers.py @@ -15,15 +15,15 @@ def f1(a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = f1(1, 2, 3, 4) + result = f1(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) @_deprecate_positional_args("v0.1") @@ -31,7 +31,7 @@ def f2(a="a", *, b="b", c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f2(1, 2) + result = f2(1, 2) # type: ignore[misc] assert result == (1, 2, "c", "d") @_deprecate_positional_args("v0.1") @@ -39,11 +39,11 @@ def f3(a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2) + result = f3(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2, f="f") + result = f3(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) @_deprecate_positional_args("v0.1") @@ -57,7 +57,7 @@ def f4(a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f4(1, 2, f="f") + result = f4(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): @@ -80,15 +80,15 @@ def method(self, a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A1().method(1, 2, 3, 4) + result = A1().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A2: @@ -97,11 +97,11 @@ def method(self, a=1, b=1, *, c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A2().method(1, 2, 3) + result = A2().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A2().method(1, 2, 3, 4) + result = A2().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A3: @@ -110,11 +110,11 @@ def method(self, a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2) + result = A3().method(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2, f="f") + result = A3().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) class A4: @@ -129,7 +129,7 @@ def method(self, a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A4().method(1, 2, f="f") + result = A4().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 6a8cd9c457b..bfc37121597 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -168,6 +168,10 @@ def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): @requires_netCDF4 @pytest.mark.parametrize("parallel", (True, False)) def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + if parallel: + pytest.skip( + "Flaky in CI. Would be a welcome contribution to make a similar test reliable." + ) lon = np.arange(100) time = xr.cftime_range("20010101", periods=100, calendar="360_day") data = np.random.random((time.size, lon.size)) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 7670b77322c..5ca134503e8 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -549,7 +549,7 @@ def _repr_inline_(self, width): return formatted - def __array_function__(self, *args, **kwargs): + def __array_namespace__(self, *args, **kwargs): return NotImplemented @property diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e143e2b8e03..320ba999318 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -810,9 +810,9 @@ def test_groupby_math_more() -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 # type: ignore[operator] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[operator] with pytest.raises(TypeError, match=r"in-place operations"): - ds += grouped + ds += grouped # type: ignore[arg-type] ds = Dataset( { diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index fe2cdc58807..c57d84c927d 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -92,26 +92,36 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False return da, df +@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11]) +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) @requires_scipy -def test_interpolate_pd_compat(): +def test_interpolate_pd_compat(method, fill_value) -> None: shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] - for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods): + for shape, frac_nan in itertools.product(shapes, frac_nans): da, df = make_interpolate_example_data(shape, frac_nan) for dim in ["time", "x"]: - actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan) + actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value) + # need limit_direction="both" here, to let pandas fill + # in both directions instead of default forward direction only expected = df.interpolate( method=method, axis=da.get_axis_num(dim), + limit_direction="both", + fill_value=fill_value, ) - # Note, Pandas does some odd things with the left/right fill_value - # for the linear methods. This next line inforces the xarray - # fill_value convention on the pandas output. Therefore, this test - # only checks that interpolated values are the same (not nans) - expected.values[pd.isnull(actual.values)] = np.nan + + if method == "linear": + # Note, Pandas does not take left/right fill_value into account + # for the numpy linear methods. + # see https://github.com/pandas-dev/pandas/issues/55144 + # This aligns the pandas output with the xarray output + expected.values[pd.isnull(actual.values)] = np.nan + expected.values[actual.values == fill_value] = fill_value np.testing.assert_allclose(actual.values, expected.values) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py new file mode 100644 index 00000000000..ea1588bf554 --- /dev/null +++ b/xarray/tests/test_namedarray.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import xarray as xr +from xarray.namedarray.core import NamedArray, as_compatible_data +from xarray.namedarray.utils import T_DuckArray + +if TYPE_CHECKING: + from xarray.namedarray.utils import Self # type: ignore[attr-defined] + + +@pytest.fixture +def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ([1, 2, 3], np.array([1, 2, 3])), + (np.array([4, 5, 6]), np.array([4, 5, 6])), + (NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3])), + (2, np.array(2)), + ], +) +def test_as_compatible_data( + input_data: T_DuckArray, expected_output: T_DuckArray +) -> None: + output: T_DuckArray = as_compatible_data(input_data) + assert np.array_equal(output, expected_output) + + +def test_as_compatible_data_with_masked_array() -> None: + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] + with pytest.raises(NotImplementedError): + as_compatible_data(masked_array) + + +def test_as_compatible_data_with_0d_object() -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + np.array_equal(as_compatible_data(data), data) + + +def test_as_compatible_data_with_explicitly_indexed( + random_inputs: np.ndarray[Any, Any] +) -> None: + # TODO: Make xr.core.indexing.ExplicitlyIndexed pass is_duck_array and remove this test. + class CustomArrayBase(xr.core.indexing.NDArrayMixin): + def __init__(self, array: T_DuckArray) -> None: + self.array = array + + @property + def dtype(self) -> np.dtype[np.generic]: + return self.array.dtype + + @property + def shape(self) -> tuple[int, ...]: + return self.array.shape + + @property + def real(self) -> Self: + raise NotImplementedError + + @property + def imag(self) -> Self: + raise NotImplementedError + + def astype(self, dtype: np.typing.DTypeLike) -> Self: + raise NotImplementedError + + class CustomArray(CustomArrayBase): + def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]: + return np.array(self.array) + + class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed): + pass + + array = CustomArray(random_inputs) + output: CustomArray = as_compatible_data(array) + assert isinstance(output, np.ndarray) + + array2 = CustomArrayIndexable(random_inputs) + output2: CustomArrayIndexable = as_compatible_data(array2) + assert isinstance(output2, CustomArrayIndexable) + + +def test_properties() -> None: + data = 0.5 * np.arange(10).reshape(2, 5) + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y"], data, {"key": "value"}) + assert named_array.dims == ("x", "y") + assert np.array_equal(named_array.data, data) + assert named_array.attrs == {"key": "value"} + assert named_array.ndim == 2 + assert named_array.sizes == {"x": 2, "y": 5} + assert named_array.size == 10 + assert named_array.nbytes == 80 + assert len(named_array) == 2 + + +def test_attrs() -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) + assert named_array.attrs == {} + named_array.attrs["key"] = "value" + assert named_array.attrs == {"key": "value"} + named_array.attrs = {"key": "value2"} + assert named_array.attrs == {"key": "value2"} + + +def test_data(random_inputs: np.ndarray[Any, Any]) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y", "z"], random_inputs) + assert np.array_equal(named_array.data, random_inputs) + with pytest.raises(ValueError): + named_array.data = np.random.random((3, 4)).astype(np.float64) + + +# Additional tests as per your original class-based code +@pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (np.bytes_("foo"), np.dtype("S3")), + ], +) +def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + +def test_0d_object() -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert np.array_equal(named_array.data, expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + +def test_0d_datetime() -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + +@pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], +) +def test_0d_timedelta( + timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] +) -> None: + named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]] + named_array = NamedArray([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + +@pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], +) +def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(dims, np.random.random(data_shape)) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 3cecf1b52ec..8ad1cbe11be 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -165,7 +165,6 @@ def test_concat_attr_retention(self) -> None: result = concat([ds1, ds2], dim="dim1") assert result.attrs == original_attrs - @pytest.mark.xfail def test_merge_attr_retention(self) -> None: da1 = create_test_dataarray_attrs(var="var1") da2 = create_test_dataarray_attrs(var="var2") diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 9a15696b004..da834b76124 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: df = pd.DataFrame( { "x": np.random.randn(20), @@ -627,12 +627,6 @@ def test_rolling_construct(self, center, window) -> None: np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) - # with stride - ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose( - df_rolling["x"][::2].values, ds_rolling_mean["x"].values - ) - np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"]) # with fill_value ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( "window" @@ -640,6 +634,51 @@ def test_rolling_construct(self, center, window) -> None: assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct_stride(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean() + + # With an index (dimension coordinate) + ds_rolling = ds.rolling(index=window, center=center) + ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + ) + np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) + + # Without index (https://github.com/pydata/xarray/issues/7021) + ds2 = ds.drop_vars("index") + ds2_rolling = ds2.rolling(index=window, center=center) + ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + ) + + # Mixed coordinates, indexes and 2D coordinates + ds3 = xr.Dataset( + {"x": ("t", range(20)), "x2": ("y", range(5))}, + { + "t": range(20), + "y": ("y", range(5)), + "t2": ("t", range(20)), + "y2": ("y", range(5)), + "yt": (["t", "y"], np.ones((20, 5))), + }, + ) + ds3_rolling = ds3.rolling(t=window, center=center) + ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w") + for coord in ds3.coords: + assert coord in ds3_rolling_mean.coords + @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False)) @@ -727,9 +766,7 @@ def test_ndrolling_construct(self, center, fill_value, dask) -> None: ) assert_allclose(actual, expected) - @pytest.mark.xfail( - reason="See https://github.com/pydata/xarray/pull/4369 or docstring" - ) + @requires_dask @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("ds", (2,), indirect=True) @pytest.mark.parametrize("name", ("mean", "max")) @@ -751,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: @requires_numbagg class TestDatasetRollingExp: - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) def test_rolling_exp(self, ds) -> None: result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset) diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f64ce9338d7..489836b70fd 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -147,7 +147,6 @@ def test_variable_property(prop): ], ), True, - marks=xfail(reason="Coercion to dense"), ), param( do("conjugate"), @@ -201,7 +200,6 @@ def test_variable_property(prop): param( do("reduce", func="sum", dim="x"), True, - marks=xfail(reason="Coercion to dense"), ), param( do("rolling_window", dim="x", window=2, window_dim="x_win"), @@ -218,7 +216,7 @@ def test_variable_property(prop): param( do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") ), - param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_dict"), False), (do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True), ], ids=repr, @@ -237,7 +235,14 @@ def test_variable_method(func, sparse_output): assert isinstance(ret_s.data, sparse.SparseArray) assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) else: - assert np.allclose(ret_s, ret_d, equal_nan=True) + if func.meth != "to_dict": + assert np.allclose(ret_s, ret_d) + else: + # pop the arrays from the dict + arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data") + + assert np.allclose(arr_s, arr_d) + assert ret_s == ret_d @pytest.mark.parametrize( diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py new file mode 100644 index 00000000000..1d4ef89ae29 --- /dev/null +++ b/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ as an example of inplace binary ops + da += _int + da += _list + da += _ndarray + da += _var + + # __neg__ as an example of unary ops + _test(-da) + + +def test_dataset_typed_ops() -> None: + """Tests for type checking of typed_ops on Dataset""" + + ds = Dataset({"a": ("t", [1, 2, 3])}) + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + _da = DataArray([1, 2, 3], dims=["t"]) + + # __add__ as an example of binary ops + _test(ds + _int) + _test(ds + _list) + _test(ds + _ndarray) + _test(ds + _var) + _test(ds + _da) + _test(ds + ds) + + # __radd__ as an example of reflexive binary ops + _test(_int + ds) + _test(_list + ds) + _test(_ndarray + ds) + _test(_var + ds) + _test(_da + ds) + + # __eq__ as an example of cmp ops + _test(ds == _int) + _test(ds == _list) + _test(ds == _ndarray) + _test(ds == _var) + _test(ds == _da) + _test(_int == ds) # type: ignore[arg-type] # typeshed problem + _test(_list == ds) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == ds) + _test(_var == ds) + _test(_da == ds) + + # __lt__ as another example of cmp ops + _test(ds < _int) + _test(ds < _list) + _test(ds < _ndarray) + _test(ds < _var) + _test(ds < _da) + _test(_int > ds) + _test(_list > ds) + _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem + _test(_var > ds) + _test(_da > ds) + + # __iadd__ as an example of inplace binary ops + ds += _int + ds += _list + ds += _ndarray + ds += _var + ds += _da + + # __neg__ as an example of unary ops + _test(-ds) + + +def test_dataarray_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArrayGroupBy""" + + da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"]) + grp = da.groupby("x") + + def _testda(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + def _testds(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _testda(grp + _da) + _testds(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _testda(_da + grp) + _testds(_ds + grp) + + # __eq__ as an example of cmp ops + _testda(grp == _da) + _testda(_da == grp) + _testds(grp == _ds) + _testds(_ds == grp) + + # __lt__ as another example of cmp ops + _testda(grp < _da) + _testda(_da > grp) + _testds(grp < _ds) + _testds(_ds > grp) + + +def test_dataset_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DatasetGroupBy""" + + ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])}) + grp = ds.groupby("x") + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _test(grp + _da) + _test(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _test(_da + grp) + _test(_ds + grp) + + # __eq__ as an example of cmp ops + _test(grp == _da) + _test(_da == grp) + _test(grp == _ds) + _test(_ds == grp) + + # __lt__ as another example of cmp ops + _test(grp < _da) + _test(_da > grp) + _test(grp < _ds) + _test(_ds > grp) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index addd7587544..7e1105e2e5d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -18,6 +18,7 @@ assert_identical, requires_dask, requires_matplotlib, + requires_numbagg, ) from xarray.tests.test_plot import PlotTestCase from xarray.tests.test_variable import _PAD_XR_NP_ARGS @@ -304,11 +305,13 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + from xarray.core.groupby import GroupBy + xarray_classes = ( xr.Variable, xr.DataArray, xr.Dataset, - xr.core.groupby.GroupBy, + GroupBy, ) if not isinstance(obj, xarray_classes): @@ -2548,7 +2551,6 @@ def test_univariate_ufunc(self, units, error, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) - @pytest.mark.xfail(reason="needs the type register system for __array_ufunc__") @pytest.mark.parametrize( "unit,error", ( @@ -3849,23 +3851,21 @@ def test_computation(self, func, variant, dtype): method("groupby", "x"), method("groupby_bins", "y", bins=4), method("coarsen", y=2), - pytest.param( - method("rolling", y=3), - marks=pytest.mark.xfail( - reason="numpy.lib.stride_tricks.as_strided converts to ndarray" - ), - ), - pytest.param( - method("rolling_exp", y=3), - marks=pytest.mark.xfail( - reason="numbagg functions are not supported by pint" - ), - ), + method("rolling", y=3), + pytest.param(method("rolling_exp", y=3), marks=requires_numbagg), method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")), ), ids=repr, ) def test_computation_objects(self, func, variant, dtype): + if variant == "data": + if func.name == "rolling_exp": + pytest.xfail(reason="numbagg functions are not supported by pint") + elif func.name == "rolling": + pytest.xfail( + reason="numpy.lib.stride_tricks.as_strided converts to ndarray" + ) + unit = unit_registry.m variants = { diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 4fcd5f98d8f..73238b6ae3a 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -473,12 +473,12 @@ def test_encoding_preserved(self): assert_identical(expected.to_base_variable(), actual.to_base_variable()) assert expected.encoding == actual.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: encoding1 = {"scale_factor": 1} # encoding set via cls constructor v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1) assert v1.encoding == encoding1 - v2 = v1.reset_encoding() + v2 = v1.drop_encoding() assert v1.encoding == encoding1 assert v2.encoding == {} @@ -486,7 +486,7 @@ def test_reset_encoding(self) -> None: encoding3 = {"scale_factor": 10} v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3) assert v3.encoding == encoding3 - v4 = v3.reset_encoding() + v4 = v3.drop_encoding() assert v3.encoding == encoding3 assert v4.encoding == {} @@ -885,20 +885,10 @@ def test_getitem_error(self): "mode", [ "mean", - pytest.param( - "median", - marks=pytest.mark.xfail(reason="median is not implemented by Dask"), - ), - pytest.param( - "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") - ), + "median", + "reflect", "edge", - pytest.param( - "linear_ramp", - marks=pytest.mark.xfail( - reason="pint bug: https://github.com/hgrecco/pint/issues/1026" - ), - ), + "linear_ramp", "maximum", "minimum", "symmetric", @@ -926,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg): actual = v.pad(**xr_arg) expected = np.pad( - np.array(v.data.astype(float)), + np.array(duck_array_ops.astype(v.data, float)), np_arg, mode="constant", constant_values=np.nan, @@ -2345,12 +2335,35 @@ def test_dask_rolling(self, dim, window, center): assert actual.shape == expected.shape assert_equal(actual, expected) - @pytest.mark.xfail( - reason="https://github.com/pydata/xarray/issues/6209#issuecomment-1025116203" - ) def test_multiindex(self): super().test_multiindex() + @pytest.mark.parametrize( + "mode", + [ + "mean", + pytest.param( + "median", + marks=pytest.mark.xfail(reason="median is not implemented by Dask"), + ), + pytest.param( + "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") + ), + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + @requires_sparse class TestVariableWithSparse: diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index e9681bdf398..7b4cf901aa1 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -34,6 +34,9 @@ import inspect import warnings from functools import wraps +from typing import Callable, TypeVar + +T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -41,7 +44,7 @@ EMPTY = inspect.Parameter.empty -def _deprecate_positional_args(version): +def _deprecate_positional_args(version) -> Callable[[T], T]: """Decorator for methods that issues warnings for positional arguments Using the keyword-only argument syntax in pep 3102, arguments after the diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index cf0673e7cca..f339470884a 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -3,14 +3,16 @@ For internal xarray development use only. Usage: - python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py - python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py """ # Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some # background to some of the design choices made here. -import sys +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) BINOPS_CMP = ( @@ -74,155 +76,180 @@ ("conjugate", "ops.conjugate"), ) + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" template_binop = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} + return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> {overload_type}: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... + + def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} return self._binary_op(other, {func})""" template_reflexive = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" template_inplace = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> Self:{type_ignore} return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" template_unary = """ - def {method}(self): + def {method}(self) -> Self: return self._unary_op({func})""" template_other_unary = """ - def {method}(self, *args, **kwargs): + def {method}(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op({func}, *args, **kwargs)""" -required_method_unary = """ - def _unary_op(self, f, *args, **kwargs): - raise NotImplementedError""" -required_method_binary = """ - def _binary_op(self, other, f, reflexive=False): - raise NotImplementedError""" -required_method_inplace = """ - def _inplace_binary_op(self, other, f): - raise NotImplementedError""" # For some methods we override return type `bool` defined by base class `object`. -OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} -NO_OVERRIDE = {"override": ""} - -# Note: in some of the overloads below the return value in reality is NotImplemented, -# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] -# or type(NotImplemented) are not allowed and NoReturn has a different meaning. -# In such cases we are lending the type checkers a hand by specifying the return type -# of the corresponding reflexive method on `other` which will be called instead. -stub_ds = """\ - def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" -stub_da = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" -stub_var = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ...""" -stub_dsgb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DataArray") -> "Dataset": ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_dagb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_unary = """\ - def {method}(self: {self_type}) -> {self_type}: ...""" -stub_other_unary = """\ - def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ...""" -stub_required_unary = """\ - def _unary_op(self, f, *args, **kwargs): ...""" -stub_required_binary = """\ - def _binary_op(self, other, f, reflexive=...): ...""" -stub_required_inplace = """\ - def _inplace_binary_op(self, other, f): ...""" - - -def unops(self_type): - extra_context = {"self_type": self_type} +# We need to add "# type: ignore[override]" +# Keep an eye out for: +# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240 +# The type ignores might not be neccesary anymore at some point. +# +# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray +# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) +# TODO: change once python 3.10 is the minimum. +# +# Mypy seems to require that __iadd__ and __add__ have the same signature. +# This requires some extra type: ignores[misc] in the inplace methods :/ + + +def _type_ignore(ignore: str) -> str: + return f" # type:ignore[{ignore}]" if ignore else "" + + +FuncType = Sequence[tuple[Optional[str], Optional[str]]] +OpsType = tuple[FuncType, str, dict[str, str]] + + +def binops( + other_type: str, return_type: str = "Self", type_ignore_eq: str = "override" +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} + return [ + ([(None, None)], required_method_binary, extras), + (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}), + ( + BINOPS_EQNE, + template_binop, + extras | {"type_ignore": _type_ignore(type_ignore_eq)}, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), + ] + + +def binops_overload( + other_type: str, + overload_type: str, + return_type: str = "Self", + type_ignore_eq: str = "override", +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} return [ - ([(None, None)], required_method_unary, stub_required_unary, {}), - (UNARY_OPS, template_unary, stub_unary, extra_context), - (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context), + ([(None, None)], required_method_binary, extras), + ( + BINOPS_NUM + BINOPS_CMP, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": "", + }, + ), + ( + BINOPS_EQNE, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def binops(stub=""): +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} return [ - ([(None, None)], required_method_binary, stub_required_binary, {}), - (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE), - (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED), - (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE), + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), ] -def inplace(): +def unops() -> list[OpsType]: return [ - ([(None, None)], required_method_inplace, stub_required_inplace, {}), - (BINOPS_INPLACE, template_inplace, "", {}), + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), ] ops_info = {} -ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") -ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") -ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") -ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) -ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. -import operator - -from . import nputils, ops''' +from __future__ import annotations -STUBFILE_PREAMBLE = '''\ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike +import operator +from typing import TYPE_CHECKING, Any, Callable, overload -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( +from xarray.core import nputils, ops +from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByIncompatible, - ScalarOrArray, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, VarCompatible, ) -from .variable import Variable -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")''' +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' CLASS_PREAMBLE = """{newline} @@ -233,35 +260,28 @@ class {cls_name}: {method}.__doc__ = {func}.__doc__""" -def render(ops_info, is_module): +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: """Render the module or stub file.""" - yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + yield MODULE_PREAMBLE for cls_name, method_blocks in ops_info.items(): - yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) - yield from _render_classbody(method_blocks, is_module) + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) -def _render_classbody(method_blocks, is_module): - for method_func_pairs, method_template, stub_template, extra in method_blocks: - template = method_template if is_module else stub_template +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: if template: for method, func in method_func_pairs: yield template.format(method=method, func=func, **extra) - if is_module: - yield "" - for method_func_pairs, *_ in method_blocks: - for method, func in method_func_pairs: - if method and func: - yield COPY_DOCSTRING.format(method=method, func=func) + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) if __name__ == "__main__": - option = sys.argv[1].lower() if len(sys.argv) == 2 else None - if option not in {"--module", "--stubs"}: - raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") - is_module = option == "--module" - - for line in render(ops_info, is_module): + for line in render(ops_info): print(line)