diff --git a/.binder/environment.yml b/.binder/environment.yml index 99a7d9f2494..053b12dfc86 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -2,11 +2,10 @@ name: xarray-examples channels: - conda-forge dependencies: - - python=3.9 + - python=3.10 - boto3 - bottleneck - cartopy - - cdms2 - cfgrib - cftime - coveralls @@ -25,7 +24,7 @@ dependencies: - numpy - packaging - pandas - - pint + - pint>=0.22 - pip - pooch - pydap @@ -38,5 +37,4 @@ dependencies: - toolz - xarray - zarr - - pip: - - numbagg + - numbagg diff --git a/.codecov.yml b/.codecov.yml index f3a055c09d4..d0bec9539f8 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,16 +1,38 @@ codecov: - ci: - # by default, codecov doesn't recognize azure as a CI provider - - dev.azure.com - require_ci_to_pass: yes + require_ci_to_pass: true coverage: status: project: default: # Require 1% coverage, i.e., always succeed - target: 1 + target: 1% + flags: + - unittests + paths: + - "!xarray/tests/" + unittests: + target: 90% + flags: + - unittests + paths: + - "!xarray/tests/" + mypy: + target: 20% + flags: + - mypy patch: false changes: false -comment: off +comment: false + +flags: + unittests: + paths: + - "xarray" + - "!xarray/tests" + carryforward: false + mypy: + paths: + - "xarray" + carryforward: false diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index 59e5889f5ec..cc1a2e12be3 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -44,6 +44,7 @@ body: - label: Complete example — the example is self-contained, including all data and the text of any traceback. - label: Verifiable example — the example copy & pastes into an IPython prompt or [Binder notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/blank_template.ipynb), returning the result. - label: New issue — a search of GitHub Issues suggests this is not a duplicate. + - label: Recent environment — the issue occurs with the latest version of xarray and its dependencies. - type: textarea id: log-output diff --git a/.github/config.yml b/.github/config.yml new file mode 100644 index 00000000000..c64c2e28e59 --- /dev/null +++ b/.github/config.yml @@ -0,0 +1,23 @@ +# Comment to be posted to on first time issues +newIssueWelcomeComment: > + Thanks for opening your first issue here at xarray! Be sure to follow the issue template! + + If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. + + See the [Contributing Guide](https://docs.xarray.dev/en/latest/contributing.html) for more. + + It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. + + Thank you! + +# Comment to be posted to on PRs from first time contributors in your repository +newPRWelcomeComment: > + Thank you for opening this pull request! It may take us a few days to respond here, so thank you for being patient. + + If you have questions, some answers may be found in our [contributing guidelines](http://docs.xarray.dev/en/stable/contributing.html). + +# Comment to be posted to on pull requests merged by a first time user +firstPRMergeComment: > + Congratulations on completing your first pull request! Welcome to Xarray! + We are proud of you, and hope to see you again! + ![celebration gif](https://media.giphy.com/media/umYMU8G2ixG5mJBDo5/giphy.gif) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index bad6ba3f62a..bd72c5b9396 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,3 +5,7 @@ updates: schedule: # Check for updates once a week interval: 'weekly' + groups: + actions: + patterns: + - "*" diff --git a/.github/labeler.yml b/.github/labeler.yml deleted file mode 100644 index d866a7342fe..00000000000 --- a/.github/labeler.yml +++ /dev/null @@ -1,82 +0,0 @@ -Automation: - - .github/* - - .github/**/* - -CI: - - ci/* - - ci/**/* - -dependencies: - - requirements.txt - - ci/requirements/* - -topic-arrays: - - xarray/core/duck_array_ops.py - -topic-backends: - - xarray/backends/* - - xarray/backends/**/* - -topic-cftime: - - xarray/coding/*time* - -topic-CF conventions: - - xarray/conventions.py - -topic-combine: - - xarray/core/combine.py - -topic-dask: - - xarray/core/dask* - - xarray/core/parallel.py - -topic-DataTree: - - xarray/core/datatree* - -# topic-documentation: -# - ['doc/*', '!doc/whats-new.rst'] -# - doc/**/* - -topic-faq: - - doc/howdoi.rst - -topic-groupby: - - xarray/core/groupby.py - -topic-html-repr: - - xarray/core/formatting_html.py - -topic-hypothesis: - - xarray/properties/* - - xarray/testing/strategies/* - -topic-indexing: - - xarray/core/indexes.py - - xarray/core/indexing.py - -topic-performance: - - asv_bench/benchmarks/* - - asv_bench/benchmarks/**/* - -topic-plotting: - - xarray/plot/* - - xarray/plot/**/* - -topic-rolling: - - xarray/core/rolling.py - - xarray/core/rolling_exp.py - -topic-testing: - - conftest.py - - xarray/testing.py - - xarray/testing/* - -topic-typing: - - xarray/core/types.py - -topic-zarr: - - xarray/backends/zarr.py - -io: - - xarray/backends/* - - xarray/backends/**/* diff --git a/.github/workflows/benchmarks-last-release.yml b/.github/workflows/benchmarks-last-release.yml new file mode 100644 index 00000000000..794f35300ba --- /dev/null +++ b/.github/workflows/benchmarks-last-release.yml @@ -0,0 +1,79 @@ +name: Benchmark compare last release + +on: + push: + branches: + - main + workflow_dispatch: + +jobs: + benchmark: + name: Linux + runs-on: ubuntu-20.04 + env: + ASV_DIR: "./asv_bench" + CONDA_ENV_FILE: ci/requirements/environment.yml + + steps: + # We need the full repo to avoid this issue + # https://github.com/actions/checkout/issues/23 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up conda environment + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}-benchmark" + create-args: >- + asv + + - name: 'Get Previous tag' + id: previoustag + uses: "WyriHaximus/github-action-get-previous-tag@v1" + # with: + # fallback: 1.0.0 # Optional fallback tag to use when no tag can be found + + - name: Run benchmarks + shell: bash -l {0} + id: benchmark + env: + OPENBLAS_NUM_THREADS: 1 + MKL_NUM_THREADS: 1 + OMP_NUM_THREADS: 1 + ASV_FACTOR: 1.5 + ASV_SKIP_SLOW: 1 + run: | + set -x + # ID this runner + asv machine --yes + echo "Baseline: ${{ steps.previoustag.outputs.tag }} " + echo "Contender: ${{ github.sha }}" + # Use mamba for env creation + # export CONDA_EXE=$(which mamba) + export CONDA_EXE=$(which conda) + # Run benchmarks for current commit against base + ASV_OPTIONS="--split --show-stderr --factor $ASV_FACTOR" + asv continuous $ASV_OPTIONS ${{ steps.previoustag.outputs.tag }} ${{ github.sha }} \ + | sed "/Traceback \|failed$\|PERFORMANCE DECREASED/ s/^/::error::/" \ + | tee benchmarks.log + # Report and export results for subsequent steps + if grep "Traceback \|failed\|PERFORMANCE DECREASED" benchmarks.log > /dev/null ; then + exit 1 + fi + working-directory: ${{ env.ASV_DIR }} + + - name: Add instructions to artifact + if: always() + run: | + cp benchmarks/README_CI.md benchmarks.log .asv/results/ + working-directory: ${{ env.ASV_DIR }} + + - uses: actions/upload-artifact@v4 + if: always() + with: + name: asv-benchmark-results-${{ runner.os }} + path: ${{ env.ASV_DIR }}/.asv/results diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index b9a8d773c5a..7969847c61f 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -7,7 +7,7 @@ on: jobs: benchmark: - if: ${{ contains( github.event.pull_request.labels.*.name, 'run-benchmark') && github.event_name == 'pull_request' || contains( github.event.pull_request.labels.*.name, 'topic-performance') && github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }} + if: ${{ contains( github.event.pull_request.labels.*.name, 'run-benchmark') && github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }} name: Linux runs-on: ubuntu-20.04 env: @@ -17,18 +17,18 @@ jobs: steps: # We need the full repo to avoid this issue # https://github.com/actions/checkout/issues/23 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up conda environment - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests - cache-env: true - cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}-benchmark" - extra-specs: | + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}-benchmark" + create-args: >- asv @@ -67,7 +67,7 @@ jobs: cp benchmarks/README_CI.md benchmarks.log .asv/results/ working-directory: ${{ env.ASV_DIR }} - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 if: always() with: name: asv-benchmark-results-${{ runner.os }} diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 6f069af5da6..9aa3b17746f 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -2,10 +2,10 @@ name: CI Additional on: push: branches: - - "*" + - "main" pull_request: branches: - - "*" + - "main" workflow_dispatch: # allows you to trigger manually concurrency: @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -41,10 +41,10 @@ jobs: env: CONDA_ENV_FILE: ci/requirements/environment.yml - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.11" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -53,15 +53,15 @@ jobs: echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests - extra-specs: | + create-args: >- python=${{env.PYTHON_VERSION}} conda - cache-env: true - cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" - name: Install xarray run: | @@ -76,23 +76,25 @@ jobs: # Raise an error if there are warnings in the doctests, with `-Werror`. # This is a trial; if it presents an problem, feel free to remove. # See https://github.com/pydata/xarray/issues/7164 for more info. - python -m pytest --doctest-modules xarray --ignore xarray/tests -Werror + # + # If dependencies emit warnings we can't do anything about, add ignores to + # `xarray/tests/__init__.py`. + # [MHS, 01/25/2024] Skip datatree_ documentation remove after #8572 + python -m pytest --doctest-modules xarray --ignore xarray/tests --ignore xarray/datatree_ -Werror mypy: name: Mypy runs-on: "ubuntu-latest" needs: detect-ci-trigger - # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} env: CONDA_ENV_FILE: ci/requirements/environment.yml - PYTHON_VERSION: "3.10" + PYTHON_VERSION: "3.11" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -100,15 +102,15 @@ jobs: run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests - extra-specs: | + create-args: >- python=${{env.PYTHON_VERSION}} conda - cache-env: true - cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" - name: Install xarray run: | python -m pip install --no-deps -e . @@ -119,14 +121,14 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install 'mypy<0.990' + python -m pip install "mypy<1.9" --force-reinstall - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v3.1.1 + uses: codecov/codecov-action@v4.1.0 with: file: mypy_report/cobertura.xml flags: mypy @@ -146,7 +148,7 @@ jobs: PYTHON_VERSION: "3.9" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -154,15 +156,15 @@ jobs: run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests - extra-specs: | + create-args: >- python=${{env.PYTHON_VERSION}} conda - cache-env: true - cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" - name: Install xarray run: | python -m pip install --no-deps -e . @@ -173,14 +175,14 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install 'mypy<0.990' + python -m pip install "mypy<1.9" --force-reinstall - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v3.1.1 + uses: codecov/codecov-action@v4.1.0 with: file: mypy_report/cobertura.xml flags: mypy39 @@ -190,6 +192,126 @@ jobs: + pyright: + name: Pyright + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.10" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v4.1.0 + with: + file: pyright_report/cobertura.xml + flags: pyright + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + pyright39: + name: Pyright 3.9 + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.9" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v4.1.0 + with: + file: pyright_report/cobertura.xml + flags: pyright39 + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + + min-version-policy: name: Minimum Version Policy runs-on: "ubuntu-latest" @@ -199,28 +321,25 @@ jobs: run: shell: bash -l {0} - strategy: - matrix: - environment-file: ["bare-minimum", "min-all-deps"] - fail-fast: false - steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-name: xarray-tests - environment-file: false - extra-specs: | - python=3.10 + create-args: >- + python=3.11 pyyaml conda python-dateutil - channels: conda-forge - - name: minimum versions policy + - name: All-deps minimum versions policy + run: | + python ci/min_deps_check.py ci/requirements/min-all-deps.yml + + - name: Bare minimum versions policy run: | - python ci/min_deps_check.py ci/requirements/${{ matrix.environment-file }}.yml + python ci/min_deps_check.py ci/requirements/bare-minimum.yml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index acace7aab95..a37ff876e20 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,10 +2,10 @@ name: CI on: push: branches: - - "*" + - "main" pull_request: branches: - - "*" + - "main" workflow_dispatch: # allows you to trigger manually concurrency: @@ -22,7 +22,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -34,6 +34,8 @@ jobs: runs-on: ${{ matrix.os }} needs: detect-ci-trigger if: needs.detect-ci-trigger.outputs.triggered == 'false' + env: + ZARR_V3_EXPERIMENTAL_API: 1 defaults: run: shell: bash -l {0} @@ -42,7 +44,7 @@ jobs: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] # Bookend python versions - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.11", "3.12"] env: [""] include: # Minimum python version: @@ -60,22 +62,20 @@ jobs: python-version: "3.10" os: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set environment variables run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - if [[ "${{matrix.python-version}}" == "3.11" ]]; then - if [[ ${{matrix.os}} == windows* ]]; then - echo "CONDA_ENV_FILE=ci/requirements/environment-windows-py311.yml" >> $GITHUB_ENV + if [[ ${{ matrix.os }} == windows* ]] ; + then + if [[ ${{ matrix.python-version }} != "3.12" ]]; then + echo "CONDA_ENV_FILE=ci/requirements/environment-windows.yml" >> $GITHUB_ENV else - echo "CONDA_ENV_FILE=ci/requirements/environment-py311.yml" >> $GITHUB_ENV + echo "CONDA_ENV_FILE=ci/requirements/environment-windows-3.12.yml" >> $GITHUB_ENV fi - elif [[ ${{ matrix.os }} == windows* ]] ; - then - echo "CONDA_ENV_FILE=ci/requirements/environment-windows.yml" >> $GITHUB_ENV elif [[ "${{ matrix.env }}" != "" ]] ; then if [[ "${{ matrix.env }}" == "flaky" ]] ; @@ -86,19 +86,23 @@ jobs: echo "CONDA_ENV_FILE=ci/requirements/${{ matrix.env }}.yml" >> $GITHUB_ENV fi else - echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + if [[ ${{ matrix.python-version }} != "3.12" ]]; then + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + else + echo "CONDA_ENV_FILE=ci/requirements/environment-3.12.yml" >> $GITHUB_ENV + fi fi echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ${{ env.CONDA_ENV_FILE }} environment-name: xarray-tests - cache-env: true - cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{matrix.python-version}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" - extra-specs: | + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{matrix.python-version}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + create-args: >- python=${{matrix.python-version}} conda @@ -133,13 +137,13 @@ jobs: - name: Upload test results if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: Test results for ${{ runner.os }}-${{ matrix.python-version }} + name: Test results for ${{ runner.os }}-${{ matrix.python-version }} ${{ matrix.env }} path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v3.1.1 + uses: codecov/codecov-action@v4.1.0 with: file: ./coverage.xml flags: unittests @@ -153,7 +157,7 @@ jobs: if: github.repository == 'pydata/xarray' steps: - name: Upload - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: Event File path: ${{ github.event_path }} diff --git a/.github/workflows/label-all.yml b/.github/workflows/label-all.yml deleted file mode 100644 index 9d09c42e734..00000000000 --- a/.github/workflows/label-all.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: "Issue and PR Labeler" -on: - pull_request: - types: [opened] - issues: - types: [opened, reopened] -jobs: - label-all-on-open: - runs-on: ubuntu-latest - steps: - - uses: andymckay/labeler@1.0.4 - with: - add-labels: "needs triage" - ignore-if-labeled: false diff --git a/.github/workflows/label-prs.yml b/.github/workflows/label-prs.yml deleted file mode 100644 index ec39e68a3ff..00000000000 --- a/.github/workflows/label-prs.yml +++ /dev/null @@ -1,12 +0,0 @@ -name: "PR Labeler" -on: -- pull_request_target - -jobs: - label: - runs-on: ubuntu-latest - steps: - - uses: actions/labeler@main - with: - repo-token: "${{ secrets.GITHUB_TOKEN }}" - sync-labels: false diff --git a/.github/workflows/nightly-wheels.yml b/.github/workflows/nightly-wheels.yml new file mode 100644 index 00000000000..9aac7ab8775 --- /dev/null +++ b/.github/workflows/nightly-wheels.yml @@ -0,0 +1,44 @@ +name: Upload nightly wheels +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" +jobs: + cron: + runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build twine + + - name: Build tarball and wheels + run: | + git clean -xdf + git restore -SW . + python -m build + + - name: Check built artifacts + run: | + python -m twine check --strict dist/* + pwd + if [ -f dist/xarray-0.0.0.tar.gz ]; then + echo "❌ INVALID VERSION NUMBER" + exit 1 + else + echo "✅ Looks good" + fi + + - name: Upload wheel + uses: scientific-python/upload-nightly-action@b67d7fcc0396e1128a474d1ab2b48aa94680f9fc # 0.5.0 + with: + anaconda_nightly_upload_token: ${{ secrets.ANACONDA_NIGHTLY }} + artifacts_path: dist diff --git a/.github/workflows/parse_logs.py b/.github/workflows/parse_logs.py deleted file mode 100644 index c0674aeac0b..00000000000 --- a/.github/workflows/parse_logs.py +++ /dev/null @@ -1,102 +0,0 @@ -# type: ignore -import argparse -import functools -import json -import pathlib -import textwrap -from dataclasses import dataclass - -from pytest import CollectReport, TestReport - - -@dataclass -class SessionStart: - pytest_version: str - outcome: str = "status" - - @classmethod - def _from_json(cls, json): - json_ = json.copy() - json_.pop("$report_type") - return cls(**json_) - - -@dataclass -class SessionFinish: - exitstatus: str - outcome: str = "status" - - @classmethod - def _from_json(cls, json): - json_ = json.copy() - json_.pop("$report_type") - return cls(**json_) - - -def parse_record(record): - report_types = { - "TestReport": TestReport, - "CollectReport": CollectReport, - "SessionStart": SessionStart, - "SessionFinish": SessionFinish, - } - cls = report_types.get(record["$report_type"]) - if cls is None: - raise ValueError(f"unknown report type: {record['$report_type']}") - - return cls._from_json(record) - - -@functools.singledispatch -def format_summary(report): - return f"{report.nodeid}: {report}" - - -@format_summary.register -def _(report: TestReport): - message = report.longrepr.chain[0][1].message - return f"{report.nodeid}: {message}" - - -@format_summary.register -def _(report: CollectReport): - message = report.longrepr.split("\n")[-1].removeprefix("E").lstrip() - return f"{report.nodeid}: {message}" - - -def format_report(reports, py_version): - newline = "\n" - summaries = newline.join(format_summary(r) for r in reports) - message = textwrap.dedent( - """\ -
Python {py_version} Test Summary - - ``` - {summaries} - ``` - -
- """ - ).format(summaries=summaries, py_version=py_version) - return message - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("filepath", type=pathlib.Path) - args = parser.parse_args() - - py_version = args.filepath.stem.split("-")[1] - - print("Parsing logs ...") - - lines = args.filepath.read_text().splitlines() - reports = [parse_record(json.loads(line)) for line in lines] - - failed = [report for report in reports if report.outcome == "failed"] - - message = format_report(failed, py_version=py_version) - - output_file = pathlib.Path("pytest-logs.txt") - print(f"Writing output file to: {output_file.absolute()}") - output_file.write_text(message) diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index a1e38644045..354a8b59d4e 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -12,10 +12,10 @@ jobs: runs-on: ubuntu-latest if: github.repository == 'pydata/xarray' steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 name: Install Python with: python-version: "3.11" @@ -41,7 +41,7 @@ jobs: else echo "✅ Looks good" fi - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: releases path: dist @@ -50,11 +50,11 @@ jobs: needs: build-artifacts runs-on: ubuntu-latest steps: - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 name: Install Python with: python-version: "3.11" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: releases path: dist @@ -70,12 +70,26 @@ jobs: python -m pip install dist/xarray*.whl python -m xarray.util.print_versions + upload-to-test-pypi: + needs: test-built-dist + if: github.event_name == 'push' + runs-on: ubuntu-latest + + environment: + name: pypi + url: https://test.pypi.org/p/xarray + permissions: + id-token: write + + steps: + - uses: actions/download-artifact@v4 + with: + name: releases + path: dist - name: Publish package to TestPyPI if: github.event_name == 'push' - uses: pypa/gh-action-pypi-publish@v1.8.3 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: - user: __token__ - password: ${{ secrets.TESTPYPI_TOKEN }} repository_url: https://test.pypi.org/legacy/ verbose: true @@ -84,14 +98,19 @@ jobs: needs: test-built-dist if: github.event_name == 'release' runs-on: ubuntu-latest + + environment: + name: pypi + url: https://pypi.org/p/xarray + permissions: + id-token: write + steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: releases path: dist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.3 + uses: pypa/gh-action-pypi-publish@v1.8.14 with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN }} verbose: true diff --git a/.github/workflows/testpypi-release.yaml b/.github/workflows/testpypi-release.yaml deleted file mode 100644 index b892e97268f..00000000000 --- a/.github/workflows/testpypi-release.yaml +++ /dev/null @@ -1,86 +0,0 @@ -name: Build and Upload xarray to PyPI -on: - push: - branches: - - 'main' - -# no need for concurrency limits - -jobs: - build-artifacts: - runs-on: ubuntu-latest - if: github.repository == 'pydata/xarray' - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - - uses: actions/setup-python@v4 - name: Install Python - with: - python-version: "3.10" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install build twine - python -m pip install tomli tomli_w - - - name: Disable local versions - run: | - python .github/workflows/configure-testpypi-version.py pyproject.toml - git update-index --assume-unchanged pyproject.toml - cat pyproject.toml - - - name: Build tarball and wheels - run: | - git clean -xdf - python -m build - - - name: Check built artifacts - run: | - python -m twine check --strict dist/* - if [ -f dist/xarray-0.0.0.tar.gz ]; then - echo "❌ INVALID VERSION NUMBER" - exit 1 - else - echo "✅ Looks good" - fi - - - uses: actions/upload-artifact@v3 - with: - name: releases - path: dist - - test-built-dist: - needs: build-artifacts - runs-on: ubuntu-latest - steps: - - uses: actions/setup-python@v4 - name: Install Python - with: - python-version: "3.10" - - uses: actions/download-artifact@v3 - with: - name: releases - path: dist - - name: List contents of built dist - run: | - ls -ltrh - ls -ltrh dist - - - name: Verify the built dist/wheel is valid - if: github.event_name == 'push' - run: | - python -m pip install --upgrade pip - python -m pip install dist/xarray*.whl - python -m xarray.util.print_versions - - - name: Publish package to TestPyPI - if: github.event_name == 'push' - uses: pypa/gh-action-pypi-publish@v1.8.3 - with: - user: __token__ - password: ${{ secrets.TESTPYPI_TOKEN }} - repository_url: https://test.pypi.org/legacy/ - verbose: true diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index c7ccee73414..872b2d865fb 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -6,6 +6,7 @@ on: pull_request: branches: - main + types: [opened, reopened, synchronize, labeled] schedule: - cron: "0 0 * * *" # Daily “At 00:00” UTC workflow_dispatch: # allows you to trigger the workflow run manually @@ -24,7 +25,7 @@ jobs: outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 2 - uses: xarray-contrib/ci-trigger@v1 @@ -36,11 +37,14 @@ jobs: name: upstream-dev runs-on: ubuntu-latest needs: detect-ci-trigger + env: + ZARR_V3_EXPERIMENTAL_API: 1 if: | always() && ( (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') || needs.detect-ci-trigger.outputs.triggered == 'true' + || contains( github.event.pull_request.labels.*.name, 'run-upstream') ) defaults: run: @@ -48,17 +52,17 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up conda environment - uses: mamba-org/provision-with-micromamba@v15 + uses: mamba-org/setup-micromamba@v1 with: environment-file: ci/requirements/environment.yml environment-name: xarray-tests - extra-specs: | + create-args: >- python=${{ matrix.python-version }} pytest-reportlog conda @@ -80,7 +84,6 @@ jobs: if: success() id: status run: | - export ZARR_V3_EXPERIMENTAL_API=1 python -m pytest --timeout=60 -rf \ --report-log output-${{ matrix.python-version }}-log.jsonl - name: Generate and publish the report @@ -92,3 +95,58 @@ jobs: uses: xarray-contrib/issue-from-pytest-log@v1 with: log-path: output-${{ matrix.python-version }}-log.jsonl + + mypy-upstream-dev: + name: mypy-upstream-dev + runs-on: ubuntu-latest + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-upstream') + ) + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + - name: Set up conda environment + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ci/requirements/environment.yml + environment-name: xarray-tests + create-args: >- + python=${{ matrix.python-version }} + pytest-reportlog + conda + - name: Install upstream versions + run: | + bash ci/install-upstream-wheels.sh + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install mypy + run: | + python -m pip install mypy --force-reinstall + - name: Run mypy + run: | + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report + - name: Upload mypy coverage to Codecov + uses: codecov/codecov-action@v4.1.0 + with: + file: mypy_report/cobertura.xml + flags: mypy + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45c15da8236..74d77e2f2ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,40 +1,36 @@ # https://pre-commit.com/ +ci: + autoupdate_schedule: monthly +exclude: 'xarray/datatree_.*' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - id: debug-statements - id: mixed-line-ending - - repo: https://github.com/MarcoGorelli/absolufy-imports - rev: v0.3.1 - hooks: - - id: absolufy-imports - name: absolufy-imports - files: ^xarray/ - - repo: https://github.com/charliermarsh/ruff-pre-commit + - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.259' + rev: 'v0.2.0' hooks: - id: ruff - args: ["--fix"] + args: ["--fix", "--show-fixes"] # https://github.com/python/black#version-control-integration - - repo: https://github.com/psf/black - rev: 23.1.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.1.1 hooks: - - id: black - id: black-jupyter - repo: https://github.com/keewis/blackdoc - rev: v0.3.8 + rev: v0.3.9 hooks: - id: blackdoc exclude: "generate_aggregations.py" - additional_dependencies: ["black==23.1.0"] + additional_dependencies: ["black==24.1.1"] - id: blackdoc-autoupdate-black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + rev: v1.8.0 hooks: - id: mypy # Copied from setup.cfg @@ -48,7 +44,7 @@ repos: types-pkg_resources, types-PyYAML, types-pytz, - typing-extensions==3.10.0.0, + typing-extensions>=4.1.0, numpy, ] - repo: https://github.com/citation-file-format/cff-converter-python diff --git a/.readthedocs.yaml b/.readthedocs.yaml index db2e1cd0b9a..55fea717f71 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,6 +7,7 @@ build: jobs: post_checkout: - (git --no-pager log --pretty="tformat:%s" -1 | grep -vqF "[skip-rtd]") || exit 183 + - git fetch --unshallow || true pre_install: - git update-index --assume-unchanged doc/conf.py ci/requirements/doc.yml diff --git a/CORE_TEAM_GUIDE.md b/CORE_TEAM_GUIDE.md new file mode 100644 index 00000000000..9eb91f4e586 --- /dev/null +++ b/CORE_TEAM_GUIDE.md @@ -0,0 +1,322 @@ +> **_Note:_** This Core Team Member Guide was adapted from the [napari project's Core Developer Guide](https://napari.org/stable/developers/core_dev_guide.html) and the [Pandas maintainers guide](https://pandas.pydata.org/docs/development/maintaining.html). + +# Core Team Member Guide + +Welcome, new core team member! We appreciate the quality of your work, and enjoy working with you! +Thank you for your numerous contributions to the project so far. + +By accepting the invitation to become a core team member you are **not required to commit to doing any more work** - +xarray is a volunteer project, and we value the contributions you have made already. + +You can see a list of all the current core team members on our +[@pydata/xarray](https://github.com/orgs/pydata/teams/xarray) +GitHub team. Once accepted, you should now be on that list too. +This document offers guidelines for your new role. + +## Tasks + +Xarray values a wide range of contributions, only some of which involve writing code. +As such, we do not currently make a distinction between a "core team member", "core developer", "maintainer", +or "triage team member" as some projects do (e.g. [pandas](https://pandas.pydata.org/docs/development/maintaining.html)). +That said, if you prefer to refer to your role as one of the other titles above then that is fine by us! + +Xarray is mostly a volunteer project, so these tasks shouldn’t be read as “expectations”. +**There are no strict expectations**, other than to adhere to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +Rather, the tasks that follow are general descriptions of what it might mean to be a core team member: + +- Facilitate a welcoming environment for those who file issues, make pull requests, and open discussion topics, +- Triage newly filed issues, +- Review newly opened pull requests, +- Respond to updates on existing issues and pull requests, +- Drive discussion and decisions on stalled issues and pull requests, +- Provide experience / wisdom on API design questions to ensure consistency and maintainability, +- Project organization (run developer meetings, coordinate with sponsors), +- Project evangelism (advertise xarray to new users), +- Community contact (represent xarray in user communities such as [Pangeo](https://pangeo.io/)), +- Key project contact (represent xarray's perspective within key related projects like NumPy, Zarr or Dask), +- Project fundraising (help write and administrate grants that will support xarray), +- Improve documentation or tutorials (especially on [`tutorial.xarray.dev`](https://tutorial.xarray.dev/)), +- Presenting or running tutorials (such as those we have given at the SciPy conference), +- Help maintain the [`xarray.dev`](https://xarray.dev/) landing page and website, the [code for which is here](https://github.com/xarray-contrib/xarray.dev), +- Write blog posts on the [xarray blog](https://xarray.dev/blog), +- Help maintain xarray's various Continuous Integration Workflows, +- Help maintain a regular release schedule (we aim for one or more releases per month), +- Attend the bi-weekly community meeting ([issue](https://github.com/pydata/xarray/issues/4001)), +- Contribute to the xarray codebase. + +(Matt Rocklin's post on [the role of a maintainer](https://matthewrocklin.com/blog/2019/05/18/maintainer) may be +interesting background reading, but should not be taken to strictly apply to the Xarray project.) + +Obviously you are not expected to contribute in all (or even more than one) of these ways! +They are listed so as to indicate the many types of work that go into maintaining xarray. + +It is natural that your available time and enthusiasm for the project will wax and wane - this is fine and expected! +It is also common for core team members to have a "niche" - a particular part of the codebase they have specific expertise +with, or certain types of task above which they primarily perform. + +If however you feel that is unlikely you will be able to be actively contribute in the foreseeable future +(or especially if you won't be available to answer questions about pieces of code that you wrote previously) +then you may want to consider letting us know you would rather be listed as an "Emeritus Core Team Member", +as this would help us in evaluating the overall health of the project. + +## Issue triage + +One of the main ways you might spend your contribution time is by responding to or triaging new issues. +Here’s a typical workflow for triaging a newly opened issue or discussion: + +1. **Thank the reporter for opening an issue.** + + The issue tracker is many people’s first interaction with the xarray project itself, beyond just using the library. + It may also be their first open-source contribution of any kind. As such, we want it to be a welcoming, pleasant experience. + +2. **Is the necessary information provided?** + + Ideally reporters would fill out the issue template, but many don’t. If crucial information (like the version of xarray they used), + is missing feel free to ask for that and label the issue with “needs info”. + The report should follow the [guidelines for xarray discussions](https://github.com/pydata/xarray/discussions/5404). + You may want to link to that if they didn’t follow the template. + + Make sure that the title accurately reflects the issue. Edit it yourself if it’s not clear. + Remember also that issues can be converted to discussions and vice versa if appropriate. + +3. **Is this a duplicate issue?** + + We have many open issues. If a new issue is clearly a duplicate, label the new issue as “duplicate”, and close the issue with a link to the original issue. + Make sure to still thank the reporter, and encourage them to chime in on the original issue, and perhaps try to fix it. + + If the new issue provides relevant information, such as a better or slightly different example, add it to the original issue as a comment or an edit to the original post. + +4. **Is the issue minimal and reproducible?** + + For bug reports, we ask that the reporter provide a minimal reproducible example. + See [minimal-bug-reports](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) for a good explanation. + If the example is not reproducible, or if it’s clearly not minimal, feel free to ask the reporter if they can provide and example or simplify the provided one. + Do acknowledge that writing minimal reproducible examples is hard work. If the reporter is struggling, you can try to write one yourself and we’ll edit the original post to include it. + + If a nice reproducible example has been provided, thank the reporter for that. + If a reproducible example can’t be provided, add the “needs mcve” label. + + If a reproducible example is provided, but you see a simplification, edit the original post with your simpler reproducible example. + +5. **Is this a clearly defined feature request?** + + Generally, xarray prefers to discuss and design new features in issues, before a pull request is made. + Encourage the submitter to include a proposed API for the new feature. Having them write a full docstring is a good way to pin down specifics. + + We may need a discussion from several xarray maintainers before deciding whether the proposal is in scope for xarray. + +6. **Is this a usage question?** + + We prefer that usage questions are asked on StackOverflow with the [`python-xarray` tag](https://stackoverflow.com/questions/tagged/python-xarray +) or as a [GitHub discussion topic](https://github.com/pydata/xarray/discussions). + + If it’s easy to answer, feel free to link to the relevant documentation section, let them know that in the future this kind of question should be on StackOverflow, and close the issue. + +7. **What labels and milestones should I add?** + + Apply the relevant labels. This is a bit of an art, and comes with experience. Look at similar issues to get a feel for how things are labeled. + Labels used for labelling issues that relate to particular features or parts of the codebase normally have the form `topic-`. + + If the issue is clearly defined and the fix seems relatively straightforward, label the issue as `contrib-good-first-issue`. + You can also remove the `needs triage` label that is automatically applied to all newly-opened issues. + +8. **Where should the poster look to fix the issue?** + + If you can, it is very helpful to point to the approximate location in the codebase where a contributor might begin to fix the issue. + This helps ease the way in for new contributors to the repository. + +## Code review and contributions + +As a core team member, you are a representative of the project, +and trusted to make decisions that will serve the long term interests +of all users. You also gain the responsibility of shepherding +other contributors through the review process; here are some +guidelines for how to do that. + +### All contributors are treated the same + +You should now have gained the ability to merge or approve +other contributors' pull requests. Merging contributions is a shared power: +only merge contributions you yourself have carefully reviewed, and that are +clear improvements for the project. When in doubt, and especially for more +complex changes, wait until at least one other core team member has approved. +(See [Reviewing](#reviewing) and especially +[Merge Only Changes You Understand](#merge-only-changes-you-understand) below.) + +It should also be considered best practice to leave a reasonable (24hr) time window +after approval before merge to ensure that other core team members have a reasonable +chance to weigh in. +Adding the `plan-to-merge` label notifies developers of the imminent merge. + +We are also an international community, with contributors from many different time zones, +some of whom will only contribute during their working hours, others who might only be able +to contribute during nights and weekends. It is important to be respectful of other peoples +schedules and working habits, even if it slows the project down slightly - we are in this +for the long run. In the same vein you also shouldn't feel pressured to be constantly +available or online, and users or contributors who are overly demanding and unreasonable +to the point of harassment will be directed to our [Code of Conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +We value sustainable development practices over mad rushes. + +When merging, we automatically use GitHub's +[Squash and Merge](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request#merging-a-pull-request) +to ensure a clean git history. + +You should also continue to make your own pull requests as before and in accordance +with the [general contributing guide](https://docs.xarray.dev/en/stable/contributing.html). These pull requests still +require the approval of another core team member before they can be merged. + +### How to conduct a good review + +*Always* be kind to contributors. Contributors are often doing +volunteer work, for which we are tremendously grateful. Provide +constructive criticism on ideas and implementations, and remind +yourself of how it felt when your own work was being evaluated as a +novice. + +``xarray`` strongly values mentorship in code review. New users +often need more handholding, having little to no git +experience. Repeat yourself liberally, and, if you don’t recognize a +contributor, point them to our development guide, or other GitHub +workflow tutorials around the web. Do not assume that they know how +GitHub works (many don't realize that adding a commit +automatically updates a pull request, for example). Gentle, polite, kind +encouragement can make the difference between a new core team member and +an abandoned pull request. + +When reviewing, focus on the following: + +1. **Usability and generality:** `xarray` is a user-facing package that strives to be accessible +to both novice and advanced users, and new features should ultimately be +accessible to everyone using the package. `xarray` targets the scientific user +community broadly, and core features should be domain-agnostic and general purpose. +Custom functionality is meant to be provided through our various types of interoperability. + +2. **Performance and benchmarks:** As `xarray` targets scientific applications that often involve +large multidimensional datasets, high performance is a key value of `xarray`. While +every new feature won't scale equally to all sizes of data, keeping in mind performance +and our [benchmarks](https://github.com/pydata/xarray/tree/main/asv_bench) during a review may be important, and you may +need to ask for benchmarks to be run and reported or new benchmarks to be added. +You can run the CI benchmarking suite on any PR by tagging it with the ``run-benchmark`` label. + +3. **APIs and stability:** Coding users and developers will make +extensive use of our APIs. The foundation of a healthy ecosystem will be +a fully capable and stable set of APIs, so as `xarray` matures it will +very important to ensure our APIs are stable. Spending the extra time to consider names of public facing +variables and methods, alongside function signatures, could save us considerable +trouble in the future. We do our best to provide [deprecation cycles](https://docs.xarray.dev/en/stable/contributing.html#backwards-compatibility) +when making backwards-incompatible changes. + +4. **Documentation and tutorials:** All new methods should have appropriate doc +strings following [PEP257](https://peps.python.org/pep-0257/) and the +[NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style). +For any major new features, accompanying changes should be made to our +[tutorials](https://tutorial.xarray.dev). These should not only +illustrates the new feature, but explains it. + +5. **Implementations and algorithms:** You should understand the code being modified +or added before approving it. (See [Merge Only Changes You Understand](#merge-only-changes-you-understand) +below.) Implementations should do what they claim and be simple, readable, and efficient +in that order. + +6. **Tests:** All contributions *must* be tested, and each added line of code +should be covered by at least one test. Good tests not only execute the code, +but explore corner cases. It can be tempting not to review tests, but please +do so. + +Other changes may be *nitpicky*: spelling mistakes, formatting, +etc. Do not insist contributors make these changes, but instead you should offer +to make these changes by [pushing to their branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/committing-changes-to-a-pull-request-branch-created-from-a-fork), +or using GitHub’s [suggestion](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/commenting-on-a-pull-request) +[feature](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/incorporating-feedback-in-your-pull-request), and +be prepared to make them yourself if needed. Using the suggestion feature is preferred because +it gives the contributor a choice in whether to accept the changes. + +Unless you know that a contributor is experienced with git, don’t +ask for a rebase when merge conflicts arise. Instead, rebase the +branch yourself, force-push to their branch, and advise the contributor to force-pull. If the contributor is +no longer active, you may take over their branch by submitting a new pull +request and closing the original, including a reference to the original pull +request. In doing so, ensure you communicate that you are not throwing the +contributor's work away! If appropriate it is a good idea to acknowledge other contributions +to the pull request using the `Co-authored-by` +[syntax](https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors) in the commit message. + +### Merge only changes you understand + +*Long-term maintainability* is an important concern. Code doesn't +merely have to *work*, but should be *understood* by multiple core +developers. Changes will have to be made in the future, and the +original contributor may have moved on. + +Therefore, *do not merge a code change unless you understand it*. Ask +for help freely: we can consult community members, or even external developers, +for added insight where needed, and see this as a great learning opportunity. + +While we collectively "own" any patches (and bugs!) that become part +of the code base, you are vouching for changes you merge. Please take +that responsibility seriously. + +Feel free to ping other active maintainers with any questions you may have. + +## Further resources + +As a core member, you should be familiar with community and developer +resources such as: + +- Our [contributor guide](https://docs.xarray.dev/en/stable/contributing.html). +- Our [code of conduct](https://github.com/pydata/xarray/tree/main/CODE_OF_CONDUCT.md). +- Our [philosophy and development roadmap](https://docs.xarray.dev/en/stable/roadmap.html). +- [PEP8](https://peps.python.org/pep-0008/) for Python style. +- [PEP257](https://peps.python.org/pep-0257/) and the + [NumPy documentation guide](https://numpy.org/devdocs/dev/howto-docs.html#documentation-style) + for docstring conventions. +- [`pre-commit`](https://pre-commit.com) hooks for autoformatting. +- [`black`](https://github.com/psf/black) autoformatting. +- [`flake8`](https://github.com/PyCQA/flake8) linting. +- [python-xarray](https://stackoverflow.com/questions/tagged/python-xarray) on Stack Overflow. +- [@xarray_dev](https://twitter.com/xarray_dev) on Twitter. +- [xarray-dev](https://discord.gg/bsSGdwBn) discord community (normally only used for remote synchronous chat during sprints). + +You are not required to monitor any of the social resources. + +Where possible we prefer to point people towards asynchronous forms of communication +like github issues instead of realtime chat options as they are far easier +for a global community to consume and refer back to. + +We hold a [bi-weekly developers meeting](https://docs.xarray.dev/en/stable/developers-meeting.html) via video call. +This is a great place to bring up any questions you have, raise visibility of an issue and/or gather more perspectives. +Attendance is absolutely optional, and we keep the meeting to 30 minutes in respect of your valuable time. +This meeting is public, so we occasionally have non-core team members join us. + +We also have a private mailing list for core team members +`xarray-core-team@googlegroups.com` which is sparingly used for discussions +that are required to be private, such as nominating new core members and discussing financial issues. + +## Inviting new core members + +Any core member may nominate other contributors to join the core team. +While there is no hard-and-fast rule about who can be nominated, ideally, +they should have: been part of the project for at least two months, contributed +significant changes of their own, contributed to the discussion and +review of others' work, and collaborated in a way befitting our +community values. **We strongly encourage nominating anyone who has made significant non-code contributions +to the Xarray community in any way**. After nomination voting will happen on a private mailing list. +While it is expected that most votes will be unanimous, a two-thirds majority of +the cast votes is enough. + +Core team members can choose to become emeritus core team members and suspend +their approval and voting rights until they become active again. + +## Contribute to this guide (!) + +This guide reflects the experience of the current core team members. We +may well have missed things that, by now, have become second +nature—things that you, as a new team member, will spot more easily. +Please ask the other core team members if you have any questions, and +submit a pull request with insights gained. + +## Conclusion + +We are excited to have you on board! We look forward to your +contributions to the code base and the community. Thank you in +advance! diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index 3bbd551415b..9d1164547b9 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -18,11 +18,16 @@ upstream https://github.com/pydata/xarray (push) git switch main git pull upstream main ``` - 2. Add a list of contributors with: + 2. Add a list of contributors. + First fetch all previous release tags so we can see the version number of the last release was: + ```sh + git fetch upstream --tags + ``` + This will return a list of all the contributors since the last release: ```sh git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | perl -pe 's/\n/$1, /' ``` - This will return the number of contributors: + This will return the total number of contributors: ```sh git log "$(git tag --sort=v:refname | tail -1).." --format=%aN | sort -u | wc -l ``` @@ -47,14 +52,14 @@ upstream https://github.com/pydata/xarray (push) ```sh pytest ``` - 8. Check that the ReadTheDocs build is passing on the `main` branch. + 8. Check that the [ReadTheDocs build](https://readthedocs.org/projects/xray/) is passing on the `latest` build version (which is built from the `main` branch). 9. Issue the release on GitHub. Click on "Draft a new release" at . Type in the version number (with a "v") and paste the release summary in the notes. 10. This should automatically trigger an upload of the new build to PyPI via GitHub Actions. Check this has run [here](https://github.com/pydata/xarray/actions/workflows/pypi-release.yaml), and that the version number you expect is displayed [on PyPI](https://pypi.org/project/xarray/) -11. Add a section for the next release {YYYY.MM.X+1} to doc/whats-new.rst: +11. Add a section for the next release {YYYY.MM.X+1} to doc/whats-new.rst (we avoid doing this earlier so that it doesn't show up in the RTD build): ```rst .. _whats-new.YYYY.MM.X+1: diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000000..032b620f433 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +prune xarray/datatree_* diff --git a/README.md b/README.md index 41db66fd395..432d535d1b1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # xarray: N-D labeled arrays and datasets [![CI](https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=main)](https://github.com/pydata/xarray/actions?query=workflow%3ACI) -[![Code coverage](https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg)](https://codecov.io/gh/pydata/xarray) +[![Code coverage](https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg?flag=unittests)](https://codecov.io/gh/pydata/xarray) [![Docs](https://readthedocs.org/projects/xray/badge/?version=latest)](https://docs.xarray.dev/) [![Benchmarked with asv](https://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat)](https://pandas.pydata.org/speed/xarray/) [![Available on pypi](https://img.shields.io/pypi/v/xarray.svg)](https://pypi.python.org/pypi/xarray/) @@ -108,7 +108,7 @@ Thanks to our many contributors! ## License -Copyright 2014-2019, xarray Developers +Copyright 2014-2023, xarray Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may @@ -125,12 +125,12 @@ limitations under the License. Xarray bundles portions of pandas, NumPy and Seaborn, all of which are available under a "3-clause BSD" license: -- pandas: setup.py, xarray/util/print_versions.py -- NumPy: xarray/core/npcompat.py -- Seaborn: _determine_cmap_params in xarray/core/plot/utils.py +- pandas: `setup.py`, `xarray/util/print_versions.py` +- NumPy: `xarray/core/npcompat.py` +- Seaborn: `_determine_cmap_params` in `xarray/core/plot/utils.py` Xarray also bundles portions of CPython, which is available under the -"Python Software Foundation License" in xarray/core/pycompat.py. +"Python Software Foundation License" in `xarray/core/pycompat.py`. Xarray uses icons from the icomoon package (free version), which is available under the "CC BY 4.0" license. diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 6f8a306fc43..a709d0a51a7 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -30,6 +30,7 @@ // determined by looking for tools on the PATH environment // variable. "environment_type": "conda", + "conda_channels": ["conda-forge"], // timeout in seconds for installing any dependencies in environment // defaults to 10 min @@ -68,7 +69,8 @@ "distributed": [""], "flox": [""], "numpy_groupies": [""], - "sparse": [""] + "sparse": [""], + "cftime": [""] }, diff --git a/asv_bench/benchmarks/accessors.py b/asv_bench/benchmarks/accessors.py new file mode 100644 index 00000000000..f9eb95851cc --- /dev/null +++ b/asv_bench/benchmarks/accessors.py @@ -0,0 +1,25 @@ +import numpy as np + +import xarray as xr + +from . import parameterized + +NTIME = 365 * 30 + + +@parameterized(["calendar"], [("standard", "noleap")]) +class DateTimeAccessor: + def setup(self, calendar): + np.random.randn(NTIME) + time = xr.date_range("2000", periods=30 * 365, calendar=calendar) + data = np.ones((NTIME,)) + self.da = xr.DataArray(data, dims="time", coords={"time": time}) + + def time_dayofyear(self, calendar): + self.da.time.dt.dayofyear + + def time_year(self, calendar): + self.da.time.dt.year + + def time_floor(self, calendar): + self.da.time.dt.floor("D") diff --git a/asv_bench/benchmarks/alignment.py b/asv_bench/benchmarks/alignment.py new file mode 100644 index 00000000000..5a6ee3fa0a6 --- /dev/null +++ b/asv_bench/benchmarks/alignment.py @@ -0,0 +1,54 @@ +import numpy as np + +import xarray as xr + +from . import parameterized, requires_dask + +ntime = 365 * 30 +nx = 50 +ny = 50 + +rng = np.random.default_rng(0) + + +class Align: + def setup(self, *args, **kwargs): + data = rng.standard_normal((ntime, nx, ny)) + self.ds = xr.Dataset( + {"temperature": (("time", "x", "y"), data)}, + coords={ + "time": xr.date_range("2000", periods=ntime), + "x": np.arange(nx), + "y": np.arange(ny), + }, + ) + self.year = self.ds.time.dt.year + self.idx = np.unique(rng.integers(low=0, high=ntime, size=ntime // 2)) + self.year_subset = self.year.isel(time=self.idx) + + @parameterized(["join"], [("outer", "inner", "left", "right", "exact", "override")]) + def time_already_aligned(self, join): + xr.align(self.ds, self.year, join=join) + + @parameterized(["join"], [("outer", "inner", "left", "right")]) + def time_not_aligned(self, join): + xr.align(self.ds, self.year[-100:], join=join) + + @parameterized(["join"], [("outer", "inner", "left", "right")]) + def time_not_aligned_random_integers(self, join): + xr.align(self.ds, self.year_subset, join=join) + + +class AlignCFTime(Align): + def setup(self, *args, **kwargs): + super().setup() + self.ds["time"] = xr.date_range("2000", periods=ntime, calendar="noleap") + self.year = self.ds.time.dt.year + self.year_subset = self.year.isel(time=self.idx) + + +class AlignDask(Align): + def setup(self, *args, **kwargs): + requires_dask() + super().setup() + self.ds = self.ds.chunk({"time": 100}) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index a4f8db2786b..772d888306c 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -2,8 +2,49 @@ import xarray as xr +from . import requires_dask -class Combine: + +class Combine1d: + """Benchmark concatenating and merging large datasets""" + + def setup(self) -> None: + """Create 2 datasets with two different variables""" + + t_size = 8000 + t = np.arange(t_size) + data = np.random.randn(t_size) + + self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))}) + self.dsA1 = xr.Dataset( + {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))} + ) + + def time_combine_by_coords(self) -> None: + """Also has to load and arrange t coordinate""" + datasets = [self.dsA0, self.dsA1] + + xr.combine_by_coords(datasets) + + +class Combine1dDask(Combine1d): + """Benchmark concatenating and merging large datasets""" + + def setup(self) -> None: + """Create 2 datasets with two different variables""" + requires_dask() + + t_size = 8000 + t = np.arange(t_size) + var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk() + + data_vars = {f"long_name_{v}": ("T", var) for v in range(500)} + + self.dsA0 = xr.Dataset(data_vars, coords={"T": t}) + self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size}) + + +class Combine3d: """Benchmark concatenating and merging large datasets""" def setup(self): diff --git a/asv_bench/benchmarks/dataset.py b/asv_bench/benchmarks/dataset.py new file mode 100644 index 00000000000..d8a6d6df9d8 --- /dev/null +++ b/asv_bench/benchmarks/dataset.py @@ -0,0 +1,32 @@ +import numpy as np + +from xarray import Dataset + +from . import requires_dask + + +class DatasetBinaryOp: + def setup(self): + self.ds = Dataset( + { + "a": (("x", "y"), np.ones((300, 400))), + "b": (("x", "y"), np.ones((300, 400))), + } + ) + self.mean = self.ds.mean() + self.std = self.ds.std() + + def time_normalize(self): + (self.ds - self.mean) / self.std + + +class DatasetChunk: + def setup(self): + requires_dask() + self.ds = Dataset() + array = np.ones(1000) + for i in range(250): + self.ds[f"var{i}"] = ("x", array) + + def time_chunk(self): + self.ds.chunk(x=(1,) * 1000) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 0af8084dd21..dcc2de0473b 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -527,7 +527,7 @@ def time_read_dataset(self, engine, chunks): class IOReadCustomEngine: def setup(self, *args, **kwargs): """ - The custom backend does the bare mininum to be considered a lazy backend. But + The custom backend does the bare minimum to be considered a lazy backend. But the data in it is still in memory so slow file reading shouldn't affect the results. """ @@ -593,7 +593,7 @@ def load(self) -> tuple: n_variables = 2000 # Important to have a shape and dtype for lazy loading. - shape = (1,) + shape = (1000,) dtype = np.dtype(int) variables = { f"long_variable_name_{v}": xr.Variable( @@ -643,7 +643,7 @@ def open_dataset( self.engine = PerformanceBackend - @parameterized(["chunks"], ([None, {}])) + @parameterized(["chunks"], ([None, {}, {"time": 10}])) def time_open_dataset(self, chunks): """ Time how fast xr.open_dataset is without the slow data reading part. diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 8cd23f3947c..1b3e55fa659 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -18,7 +18,7 @@ def setup(self, *args, **kwargs): "c": xr.DataArray(np.arange(2 * self.n)), } ) - self.ds2d = self.ds1d.expand_dims(z=10) + self.ds2d = self.ds1d.expand_dims(z=10).copy() self.ds1d_mean = self.ds1d.groupby("b").mean() self.ds2d_mean = self.ds2d.groupby("b").mean() @@ -26,15 +26,21 @@ def setup(self, *args, **kwargs): def time_init(self, ndim): getattr(self, f"ds{ndim}d").groupby("b") - @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)]) - def time_agg_small_num_groups(self, method, ndim): + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_small_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") - getattr(ds.groupby("a"), method)().compute() + with xr.set_options(use_flox=use_flox): + getattr(ds.groupby("a"), method)().compute() - @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)]) - def time_agg_large_num_groups(self, method, ndim): + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_large_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") - getattr(ds.groupby("b"), method)().compute() + with xr.set_options(use_flox=use_flox): + getattr(ds.groupby("b"), method)().compute() def time_binary_op_1d(self): (self.ds1d.groupby("b") - self.ds1d_mean).compute() @@ -115,15 +121,21 @@ def setup(self, *args, **kwargs): def time_init(self, ndim): getattr(self, f"ds{ndim}d").resample(time="D") - @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)]) - def time_agg_small_num_groups(self, method, ndim): + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_small_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") - getattr(ds.resample(time="3M"), method)().compute() + with xr.set_options(use_flox=use_flox): + getattr(ds.resample(time="3M"), method)().compute() - @parameterized(["method", "ndim"], [("sum", "mean"), (1, 2)]) - def time_agg_large_num_groups(self, method, ndim): + @parameterized( + ["method", "ndim", "use_flox"], [("sum", "mean"), (1, 2), (True, False)] + ) + def time_agg_large_num_groups(self, method, ndim, use_flox): ds = getattr(self, f"ds{ndim}d") - getattr(ds.resample(time="48H"), method)().compute() + with xr.set_options(use_flox=use_flox): + getattr(ds.resample(time="48H"), method)().compute() class ResampleDask(Resample): @@ -132,3 +144,32 @@ def setup(self, *args, **kwargs): super().setup(**kwargs) self.ds1d = self.ds1d.chunk({"time": 50}) self.ds2d = self.ds2d.chunk({"time": 50, "z": 4}) + + +class ResampleCFTime(Resample): + def setup(self, *args, **kwargs): + self.ds1d = xr.Dataset( + { + "b": ("time", np.arange(365.0 * 24)), + }, + coords={ + "time": xr.date_range( + "2001-01-01", freq="H", periods=365 * 24, calendar="noleap" + ) + }, + ) + self.ds2d = self.ds1d.expand_dims(z=10) + self.ds1d_mean = self.ds1d.resample(time="48H").mean() + self.ds2d_mean = self.ds2d.resample(time="48H").mean() + + +@parameterized(["use_cftime", "use_flox"], [[True, False], [True, False]]) +class GroupByLongTime: + def setup(self, use_cftime, use_flox): + arr = np.random.randn(10, 10, 365 * 30) + time = xr.date_range("2000", periods=30 * 365, use_cftime=use_cftime) + self.da = xr.DataArray(arr, dims=("y", "x", "time"), coords={"time": time}) + + def time_mean(self, use_cftime, use_flox): + with xr.set_options(use_flox=use_flox): + self.da.groupby("time.year").mean() diff --git a/asv_bench/benchmarks/merge.py b/asv_bench/benchmarks/merge.py index 043de35bdf7..6c8c1e9da90 100644 --- a/asv_bench/benchmarks/merge.py +++ b/asv_bench/benchmarks/merge.py @@ -41,7 +41,7 @@ def setup(self, strategy, count): data = np.array(["0", "b"], dtype=str) self.dataset_coords = dict(time=np.array([0, 1])) self.dataset_attrs = dict(description="Test data") - attrs = dict(units="Celcius") + attrs = dict(units="Celsius") if strategy == "dict_of_DataArrays": def create_data_vars(): diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py index 2a296ecc4d0..ebe61081916 100644 --- a/asv_bench/benchmarks/pandas.py +++ b/asv_bench/benchmarks/pandas.py @@ -13,7 +13,7 @@ def setup(self, dtype, subset): [ list("abcdefhijk"), list("abcdefhijk"), - pd.date_range(start="2000-01-01", periods=1000, freq="B"), + pd.date_range(start="2000-01-01", periods=1000, freq="D"), ] ) series = pd.Series(data, index) @@ -29,19 +29,20 @@ def time_from_series(self, dtype, subset): class ToDataFrame: def setup(self, *args, **kwargs): xp = kwargs.get("xp", np) + nvars = kwargs.get("nvars", 1) random_kws = kwargs.get("random_kws", {}) method = kwargs.get("method", "to_dataframe") dim1 = 10_000 dim2 = 10_000 + + var = xr.Variable( + dims=("dim1", "dim2"), data=xp.random.random((dim1, dim2), **random_kws) + ) + data_vars = {f"long_name_{v}": (("dim1", "dim2"), var) for v in range(nvars)} + ds = xr.Dataset( - { - "x": xr.DataArray( - data=xp.random.random((dim1, dim2), **random_kws), - dims=["dim1", "dim2"], - coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)}, - ) - } + data_vars, coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)} ) self.to_frame = getattr(ds, method) @@ -58,4 +59,6 @@ def setup(self, *args, **kwargs): import dask.array as da - super().setup(xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe") + super().setup( + xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe", nvars=500 + ) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 1d3713f19bf..579f4f00fbc 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -5,10 +5,10 @@ from . import parameterized, randn, requires_dask -nx = 300 +nx = 3000 long_nx = 30000 ny = 200 -nt = 100 +nt = 1000 window = 20 randn_xy = randn((nx, ny), frac_nan=0.1) @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): roll = self.ds.var3.rolling(t=100) getattr(roll, func)() + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.var2.rolling(t=100).construct("w", stride=stride) + self.ds.var3.rolling(t=100).construct("w", stride=stride) + class DatasetRollingMemory(RollingMemory): @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): with xr.set_options(use_bottleneck=use_bottleneck): roll = self.ds.rolling(t=100) getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.rolling(t=100).construct("w", stride=stride) diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 171ba3bf55f..d9c797e27cd 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -1,30 +1,30 @@ #!/usr/bin/env bash +# install cython for building cftime without build isolation +micromamba install "cython>=0.29.20" py-cpuinfo # temporarily (?) remove numbagg and numba -pip uninstall -y numbagg -conda uninstall -y numba +micromamba remove -y numba numbagg sparse +# temporarily remove numexpr +micromamba remove -y numexpr +# temporarily remove backends +micromamba remove -y cf_units hdf5 h5py netcdf4 # forcibly remove packages to avoid artifacts -conda uninstall -y --force \ +micromamba remove -y --force \ numpy \ scipy \ pandas \ - matplotlib \ - dask \ distributed \ fsspec \ zarr \ cftime \ - rasterio \ packaging \ pint \ bottleneck \ - sparse \ flox \ - h5netcdf \ - xarray + numcodecs # to limit the runtime of Upstream CI python -m pip install \ - -i https://pypi.anaconda.org/scipy-wheels-nightly/simple \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \ --no-deps \ --pre \ --upgrade \ @@ -32,19 +32,44 @@ python -m pip install \ scipy \ matplotlib \ pandas +# for some reason pandas depends on pyarrow already. +# Remove once a `pyarrow` version compiled with `numpy>=2.0` is on `conda-forge` +python -m pip install \ + -i https://pypi.fury.io/arrow-nightlies/ \ + --prefer-binary \ + --no-deps \ + --pre \ + --upgrade \ + pyarrow +# without build isolation for packages compiling against numpy +# TODO: remove once there are `numpy>=2.0` builds for these +python -m pip install \ + --no-deps \ + --upgrade \ + --no-build-isolation \ + git+https://github.com/Unidata/cftime +python -m pip install \ + --no-deps \ + --upgrade \ + --no-build-isolation \ + git+https://github.com/zarr-developers/numcodecs +python -m pip install \ + --no-deps \ + --upgrade \ + --no-build-isolation \ + git+https://github.com/pydata/bottleneck python -m pip install \ --no-deps \ --upgrade \ git+https://github.com/dask/dask \ + git+https://github.com/dask/dask-expr \ git+https://github.com/dask/distributed \ git+https://github.com/zarr-developers/zarr \ - git+https://github.com/Unidata/cftime \ - git+https://github.com/rasterio/rasterio \ git+https://github.com/pypa/packaging \ git+https://github.com/hgrecco/pint \ - git+https://github.com/pydata/bottleneck \ - git+https://github.com/pydata/sparse \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ git+https://github.com/xarray-contrib/flox \ - git+https://github.com/h5netcdf/h5netcdf + git+https://github.com/dgasmith/opt_einsum + # git+https://github.com/pydata/sparse + # git+https://github.com/h5netcdf/h5netcdf diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index 9631cb03162..48ea323ed81 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -1,8 +1,10 @@ #!/usr/bin/env python """Fetch from conda database all available versions of the xarray dependencies and their -publication date. Compare it against requirements/py37-min-all-deps.yml to verify the +publication date. Compare it against requirements/min-all-deps.yml to verify the policy on obsolete dependencies is being followed. Print a pretty report :) """ +from __future__ import annotations + import itertools import sys from collections.abc import Iterator @@ -29,7 +31,7 @@ "pytest-timeout", } -POLICY_MONTHS = {"python": 24, "numpy": 18} +POLICY_MONTHS = {"python": 30, "numpy": 18} POLICY_MONTHS_DEFAULT = 12 POLICY_OVERRIDE: dict[str, tuple[int, int]] = {} errors = [] @@ -46,7 +48,7 @@ def warning(msg: str) -> None: def parse_requirements(fname) -> Iterator[tuple[str, int, int, int | None]]: - """Load requirements/py37-min-all-deps.yml + """Load requirements/min-all-deps.yml Yield (package name, major version, minor version, [patch version]) """ @@ -109,6 +111,9 @@ def metadata(entry): (3, 6): datetime(2016, 12, 23), (3, 7): datetime(2018, 6, 27), (3, 8): datetime(2019, 10, 14), + (3, 9): datetime(2020, 10, 5), + (3, 10): datetime(2021, 10, 4), + (3, 11): datetime(2022, 10, 24), } ) diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 1387466b702..2f47643cc87 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -3,13 +3,12 @@ channels: - conda-forge - nodefaults dependencies: - - python=3.10 - black - aiobotocore + - array-api-strict - boto3 - bottleneck - cartopy - - cdms2 - cftime - coveralls - flox @@ -23,12 +22,11 @@ dependencies: - netcdf4 - numba - numbagg - - numpy<1.24 + - numpy - packaging - - pandas<2 - - pint + - pandas + - pint>=0.22 - pip - - pseudonetcdf - pydap - pytest - pytest-cov diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index 0a36493fa07..56af319f0bb 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -11,6 +11,6 @@ dependencies: - pytest-env - pytest-xdist - pytest-timeout - - numpy=1.21 - - packaging=21.3 - - pandas=1.4 + - numpy=1.23 + - packaging=22.0 + - pandas=1.5 diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 2d35ab8724b..2669224748e 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -7,9 +7,13 @@ dependencies: - python=3.10 - bottleneck - cartopy + - cfgrib - dask-core>=2022.1 + - dask-expr + - hypothesis>=6.75.8 - h5netcdf>=0.13 - ipykernel + - ipywidgets # silence nbsphinx warning - ipython - iris>=2.3 - jupyter_client @@ -17,22 +21,22 @@ dependencies: - nbsphinx - netcdf4>=1.5 - numba - - numpy>=1.21,<1.24 + - numpy>=1.21 - packaging>=21.3 - - pandas>=1.4,<2 + - pandas>=1.4,!=2.1.0 - pooch - pip - pre-commit - pyproj - - rasterio>=1.1 - scipy!=1.10.0 - seaborn - setuptools - sparse - sphinx-autosummary-accessors - - sphinx-book-theme >= 0.3.0 + - sphinx-book-theme<=1.0.1 - sphinx-copybutton - sphinx-design + - sphinx-inline-tabs - sphinx>=5.0 - zarr>=2.10 - pip: diff --git a/ci/requirements/environment-py311.yml b/ci/requirements/environment-3.12.yml similarity index 81% rename from ci/requirements/environment-py311.yml rename to ci/requirements/environment-3.12.yml index cd9edbb5052..dbb446f4454 100644 --- a/ci/requirements/environment-py311.yml +++ b/ci/requirements/environment-3.12.yml @@ -4,21 +4,22 @@ channels: - nodefaults dependencies: - aiobotocore + - array-api-strict - boto3 - bottleneck - cartopy - # - cdms2 - cftime - dask-core + - dask-expr - distributed - flox - - fsspec!=2021.7.0 + - fsspec - h5netcdf - h5py - hdf5 - hypothesis - iris - - lxml # Optional dep of pydap + - lxml # Optional dep of pydap - matplotlib-base - nc-time-axis - netcdf4 @@ -26,13 +27,13 @@ dependencies: # - numbagg - numexpr - numpy + - opt_einsum - packaging - - pandas<2 - - pint + - pandas + # - pint>=0.22 - pip - pooch - pre-commit - - pseudonetcdf - pydap - pytest - pytest-cov diff --git a/ci/requirements/environment-windows-py311.yml b/ci/requirements/environment-windows-3.12.yml similarity index 78% rename from ci/requirements/environment-windows-py311.yml rename to ci/requirements/environment-windows-3.12.yml index effef0d7961..448e3f70c0c 100644 --- a/ci/requirements/environment-windows-py311.yml +++ b/ci/requirements/environment-windows-3.12.yml @@ -2,21 +2,22 @@ name: xarray-tests channels: - conda-forge dependencies: + - array-api-strict - boto3 - bottleneck - cartopy - # - cdms2 # Not available on Windows - cftime - dask-core + - dask-expr - distributed - flox - - fsspec!=2021.7.0 + - fsspec - h5netcdf - h5py - hdf5 - hypothesis - iris - - lxml # Optional dep of pydap + - lxml # Optional dep of pydap - matplotlib-base - nc-time-axis - netcdf4 @@ -24,11 +25,10 @@ dependencies: # - numbagg - numpy - packaging - - pandas<2 - - pint + - pandas + # - pint>=0.22 - pip - pre-commit - - pseudonetcdf - pydap - pytest - pytest-cov diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index c02907b24ac..c1027b525d0 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -2,33 +2,33 @@ name: xarray-tests channels: - conda-forge dependencies: + - array-api-strict - boto3 - bottleneck - cartopy - # - cdms2 # Not available on Windows - cftime - dask-core + - dask-expr - distributed - flox - - fsspec!=2021.7.0 + - fsspec - h5netcdf - h5py - hdf5 - hypothesis - iris - - lxml # Optional dep of pydap + - lxml # Optional dep of pydap - matplotlib-base - nc-time-axis - netcdf4 - numba - numbagg - - numpy<1.24 + - numpy - packaging - - pandas<2 - - pint + - pandas + - pint>=0.22 - pip - pre-commit - - pseudonetcdf - pydap - pytest - pytest-cov diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 9abe1b295a2..d3dbc088867 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -4,35 +4,37 @@ channels: - nodefaults dependencies: - aiobotocore + - array-api-strict - boto3 - bottleneck - cartopy - - cdms2 - cftime - dask-core + - dask-expr # dask raises a deprecation warning without this, breaking doctests - distributed - flox - - fsspec!=2021.7.0 + - fsspec - h5netcdf - h5py - hdf5 - hypothesis - iris - - lxml # Optional dep of pydap + - lxml # Optional dep of pydap - matplotlib-base - nc-time-axis - netcdf4 - numba - numbagg - numexpr - - numpy<1.24 + - numpy + - opt_einsum - packaging - - pandas<2 - - pint + - pandas + - pint>=0.22 - pip - pooch - pre-commit - - pseudonetcdf + - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest - pytest-cov diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index e50d08264b8..d2965fb3fc5 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -8,48 +8,50 @@ dependencies: # When upgrading python, numpy, or pandas, must also change # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. - python=3.9 - - boto3=1.20 + - array-api-strict=1.0 # dependency for testing the array api compat + - boto3=1.24 - bottleneck=1.3 - - cartopy=0.20 - - cdms2=3.1 - - cftime=1.5 + - cartopy=0.21 + - cftime=1.6 - coveralls - - dask-core=2022.1 - - distributed=2022.1 - - flox=0.5 - - h5netcdf=0.13 + - dask-core=2022.12 + - distributed=2022.12 + # Flox > 0.8 has a bug with numbagg versions + # It will require numbagg > 0.6 + # so we should just skip that series eventually + # or keep flox pinned for longer than necessary + - flox=0.7 + - h5netcdf=1.1 # h5py and hdf5 tend to cause conflicts # for e.g. hdf5 1.12 conflicts with h5py=3.1 # prioritize bumping other packages instead - - h5py=3.6 + - h5py=3.7 - hdf5=1.12 - hypothesis - - iris=3.1 - - lxml=4.7 # Optional dep of pydap - - matplotlib-base=3.5 + - iris=3.4 + - lxml=4.9 # Optional dep of pydap + - matplotlib-base=3.6 - nc-time-axis=1.4 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) - - netcdf4=1.5.7 - - numba=0.55 - - numpy=1.21 - - packaging=21.3 - - pandas=1.4 - - pint=0.18 + - netcdf4=1.6.0 + - numba=0.56 + - numbagg=0.2.1 + - numpy=1.23 + - packaging=22.0 + - pandas=1.5 + - pint=0.22 - pip - - pseudonetcdf=3.2 - - pydap=3.2 + - pydap=3.3 - pytest - pytest-cov - pytest-env - pytest-xdist - pytest-timeout - - rasterio=1.2 - - scipy=1.7 - - seaborn=0.11 + - rasterio=1.3 + - scipy=1.10 + - seaborn=0.12 - sparse=0.13 - - toolz=0.11 - - typing_extensions=4.0 - - zarr=2.10 - - pip: - - numbagg==0.1 + - toolz=0.12 + - typing_extensions=4.4 + - zarr=2.13 diff --git a/conftest.py b/conftest.py index 862a1a1d0bc..24b7530b220 100644 --- a/conftest.py +++ b/conftest.py @@ -39,3 +39,11 @@ def add_standard_imports(doctest_namespace, tmpdir): # always switch to the temporary directory, so files get written there tmpdir.chdir() + + # Avoid the dask deprecation warning, can remove if CI passes without this. + try: + import dask + except ImportError: + pass + else: + dask.config.set({"dataframe.query-planning": True}) diff --git a/design_notes/grouper_objects.md b/design_notes/grouper_objects.md new file mode 100644 index 00000000000..af42ef2f493 --- /dev/null +++ b/design_notes/grouper_objects.md @@ -0,0 +1,240 @@ +# Grouper Objects +**Author**: Deepak Cherian +**Created**: Nov 21, 2023 + +## Abstract + +I propose the addition of Grouper objects to Xarray's public API so that +```python +Dataset.groupby(x=BinGrouper(bins=np.arange(10, 2)))) +``` +is identical to today's syntax: +```python +Dataset.groupby_bins("x", bins=np.arange(10, 2)) +``` + +## Motivation and scope + +Xarray's GroupBy API implements the split-apply-combine pattern (Wickham, 2011)[^1], which applies to a very large number of problems: histogramming, compositing, climatological averaging, resampling to a different time frequency, etc. +The pattern abstracts the following pseudocode: +```python +results = [] +for element in unique_labels: + subset = ds.sel(x=(ds.x == element)) # split + # subset = ds.where(ds.x == element, drop=True) # alternative + result = subset.mean() # apply + results.append(result) + +xr.concat(results) # combine +``` + +to +```python +ds.groupby('x').mean() # splits, applies, and combines +``` + +Efficient vectorized implementations of this pattern are implemented in numpy's [`ufunc.at`](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.at.html), [`ufunc.reduceat`](https://numpy.org/doc/stable/reference/generated/numpy.ufunc.reduceat.html), [`numbagg.grouped`](https://github.com/numbagg/numbagg/blob/main/numbagg/grouped.py), [`numpy_groupies`](https://github.com/ml31415/numpy-groupies), and probably more. +These vectorized implementations *all* require, as input, an array of integer codes or labels that identify unique elements in the array being grouped over (`'x'` in the example above). +```python +import numpy as np + +# array to reduce +a = np.array([1, 1, 1, 1, 2]) + +# initial value for result +out = np.zeros((3,), dtype=int) + +# integer codes +labels = np.array([0, 0, 1, 2, 1]) + +# groupby-reduction +np.add.at(out, labels, a) +out # array([2, 3, 1]) +``` + +One can 'factorize' or construct such an array of integer codes using `pandas.factorize` or `numpy.unique(..., return_inverse=True)` for categorical arrays; `pandas.cut`, `pandas.qcut`, or `np.digitize` for discretizing continuous variables. +In practice, since `GroupBy` objects exist, much of complexity in applying the groupby paradigm stems from appropriately factorizing or generating labels for the operation. +Consider these two examples: +1. [Bins that vary in a dimension](https://flox.readthedocs.io/en/latest/user-stories/nD-bins.html) +2. [Overlapping groups](https://flox.readthedocs.io/en/latest/user-stories/overlaps.html) +3. [Rolling resampling](https://github.com/pydata/xarray/discussions/8361) + +Anecdotally, less experienced users commonly resort to the for-loopy implementation illustrated by the pseudocode above when the analysis at hand is not easily expressed using the API presented by Xarray's GroupBy object. +Xarray's GroupBy API today abstracts away the split, apply, and combine stages but not the "factorize" stage. +Grouper objects will close the gap. + +## Usage and impact + + +Grouper objects +1. Will abstract useful factorization algorithms, and +2. Present a natural way to extend GroupBy to grouping by multiple variables: `ds.groupby(x=BinGrouper(...), t=Resampler(freq="M", ...)).mean()`. + +In addition, Grouper objects provide a nice interface to add often-requested grouping functionality +1. A new `SpaceResampler` would allow specifying resampling spatial dimensions. ([issue](https://github.com/pydata/xarray/issues/4008)) +2. `RollingTimeResampler` would allow rolling-like functionality that understands timestamps ([issue](https://github.com/pydata/xarray/issues/3216)) +3. A `QuantileBinGrouper` to abstract away `pd.cut` ([issue](https://github.com/pydata/xarray/discussions/7110)) +4. A `SeasonGrouper` and `SeasonResampler` would abstract away common annoyances with such calculations today + 1. Support seasons that span a year-end. + 2. Only include seasons with complete data coverage. + 3. Allow grouping over seasons of unequal length + 4. See [this xcdat discussion](https://github.com/xCDAT/xcdat/issues/416) for a `SeasonGrouper` like functionality: + 5. Return results with seasons in a sensible order +5. Weighted grouping ([issue](https://github.com/pydata/xarray/issues/3937)) + 1. Once `IntervalIndex` like objects are supported, `Resampler` groupers can account for interval lengths when resampling. + +## Backward Compatibility + +Xarray's existing grouping functionality will be exposed using two new Groupers: +1. `UniqueGrouper` which uses `pandas.factorize` +2. `BinGrouper` which uses `pandas.cut` +3. `TimeResampler` which mimics pandas' `.resample` + +Grouping by single variables will be unaffected so that `ds.groupby('x')` will be identical to `ds.groupby(x=UniqueGrouper())`. +Similarly, `ds.groupby_bins('x', bins=np.arange(10, 2))` will be unchanged and identical to `ds.groupby(x=BinGrouper(bins=np.arange(10, 2)))`. + +## Detailed description + +All Grouper objects will subclass from a Grouper object +```python +import abc + +class Grouper(abc.ABC): + @abc.abstractmethod + def factorize(self, by: DataArray): + raise NotImplementedError + +class CustomGrouper(Grouper): + def factorize(self, by: DataArray): + ... + return codes, group_indices, unique_coord, full_index + + def weights(self, by: DataArray) -> DataArray: + ... + return weights +``` + +### The `factorize` method +Today, the `factorize` method takes as input the group variable and returns 4 variables (I propose to clean this up below): +1. `codes`: An array of same shape as the `group` with int dtype. NaNs in `group` are coded by `-1` and ignored later. +2. `group_indices` is a list of index location of `group` elements that belong to a single group. +3. `unique_coord` is (usually) a `pandas.Index` object of all unique `group` members present in `group`. +4. `full_index` is a `pandas.Index` of all `group` members. This is different from `unique_coord` for binning and resampling, where not all groups in the output may be represented in the input `group`. For grouping by a categorical variable e.g. `['a', 'b', 'a', 'c']`, `full_index` and `unique_coord` are identical. +There is some redundancy here since `unique_coord` is always equal to or a subset of `full_index`. +We can clean this up (see Implementation below). + +### The `weights` method (?) + +The proposed `weights` method is optional and unimplemented today. +Groupers with `weights` will allow composing `weighted` and `groupby` ([issue](https://github.com/pydata/xarray/issues/3937)). +The `weights` method should return an appropriate array of weights such that the following property is satisfied +```python +gb_sum = ds.groupby(by).sum() + +weights = CustomGrouper.weights(by) +weighted_sum = xr.dot(ds, weights) + +assert_identical(gb_sum, weighted_sum) +``` +For example, the boolean weights for `group=np.array(['a', 'b', 'c', 'a', 'a'])` should be +``` +[[1, 0, 0, 1, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0]] +``` +This is the boolean "summarization matrix" referred to in the classic Iverson (1980, Section 4.3)[^2] and "nub sieve" in [various APLs](https://aplwiki.com/wiki/Nub_Sieve). + +> [!NOTE] +> We can always construct `weights` automatically using `group_indices` from `factorize`, so this is not a required method. + +For a rolling resampling, windowed weights are possible +``` +[[0.5, 1, 0.5, 0, 0], + [0, 0.25, 1, 1, 0], + [0, 0, 0, 1, 1]] +``` + +### The `preferred_chunks` method (?) + +Rechunking support is another optional extension point. +In `flox` I experimented some with automatically rechunking to make a groupby more parallel-friendly ([example 1](https://flox.readthedocs.io/en/latest/generated/flox.rechunk_for_blockwise.html), [example 2](https://flox.readthedocs.io/en/latest/generated/flox.rechunk_for_cohorts.html)). +A great example is for resampling-style groupby reductions, for which `codes` might look like +``` +0001|11122|3333 +``` +where `|` represents chunk boundaries. A simple rechunking to +``` +000|111122|3333 +``` +would make this resampling reduction an embarassingly parallel blockwise problem. + +Similarly consider monthly-mean climatologies for which the month numbers might be +``` +1 2 3 4 5 | 6 7 8 9 10 | 11 12 1 2 3 | 4 5 6 7 8 | 9 10 11 12 | +``` +A slight rechunking to +``` +1 2 3 4 | 5 6 7 8 | 9 10 11 12 | 1 2 3 4 | 5 6 7 8 | 9 10 11 12 | +``` +allows us to reduce `1, 2, 3, 4` separately from `5,6,7,8` and `9, 10, 11, 12` while still being parallel friendly (see the [flox documentation](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts) for more). + +We could attempt to detect these patterns, or we could just have the Grouper take as input `chunks` and return a tuple of "nice" chunk sizes to rechunk to. +```python +def preferred_chunks(self, chunks: ChunksTuple) -> ChunksTuple: + pass +``` +For monthly means, since the period of repetition of labels is 12, the Grouper might choose possible chunk sizes of `((2,),(3,),(4,),(6,))`. +For resampling, the Grouper could choose to resample to a multiple or an even fraction of the resampling frequency. + +## Related work + +Pandas has [Grouper objects](https://pandas.pydata.org/docs/reference/api/pandas.Grouper.html#pandas-grouper) that represent the GroupBy instruction. +However, these objects do not appear to be extension points, unlike the Grouper objects proposed here. +Instead, Pandas' `ExtensionArray` has a [`factorize`](https://pandas.pydata.org/docs/reference/api/pandas.api.extensions.ExtensionArray.factorize.html) method. + +Composing rolling with time resampling is a common workload: +1. Polars has [`group_by_dynamic`](https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html) which appears to be like the proposed `RollingResampler`. +2. scikit-downscale provides [`PaddedDOYGrouper`]( +https://github.com/pangeo-data/scikit-downscale/blob/e16944a32b44f774980fa953ea18e29a628c71b8/skdownscale/pointwise_models/groupers.py#L19) + +## Implementation Proposal + +1. Get rid of `squeeze` [issue](https://github.com/pydata/xarray/issues/2157): [PR](https://github.com/pydata/xarray/pull/8506) +2. Merge existing two class implementation to a single Grouper class + 1. This design was implemented in [this PR](https://github.com/pydata/xarray/pull/7206) to account for some annoying data dependencies. + 2. See [PR](https://github.com/pydata/xarray/pull/8509) +3. Clean up what's returned by `factorize` methods. + 1. A solution here might be to have `group_indices: Mapping[int, Sequence[int]]` be a mapping from group index in `full_index` to a sequence of integers. + 2. Return a `namedtuple` or `dataclass` from existing Grouper factorize methods to facilitate API changes in the future. +4. Figure out what to pass to `factorize` + 1. Xarray eagerly reshapes nD variables to 1D. This is an implementation detail we need not expose. + 2. When grouping by an unindexed variable Xarray passes a `_DummyGroup` object. This seems like something we don't want in the public interface. We could special case "internal" Groupers to preserve the optimizations in `UniqueGrouper`. +5. Grouper objects will exposed under the `xr.groupers` Namespace. At first these will include `UniqueGrouper`, `BinGrouper`, and `TimeResampler`. + +## Alternatives + +One major design choice made here was to adopt the syntax `ds.groupby(x=BinGrouper(...))` instead of `ds.groupby(BinGrouper('x', ...))`. +This allows reuse of Grouper objects, example +```python +grouper = BinGrouper(...) +ds.groupby(x=grouper, y=grouper) +``` +but requires that all variables being grouped by (`x` and `y` above) are present in Dataset `ds`. This does not seem like a bad requirement. +Importantly `Grouper` instances will be copied internally so that they can safely cache state that might be shared between `factorize` and `weights`. + +Today, it is possible to `ds.groupby(DataArray, ...)`. This syntax will still be supported. + +## Discussion + +This proposal builds on these discussions: +1. https://github.com/xarray-contrib/flox/issues/191#issuecomment-1328898836 +2. https://github.com/pydata/xarray/issues/6610 + +## Copyright + +This document has been placed in the public domain. + +## References and footnotes + +[^1]: Wickham, H. (2011). The split-apply-combine strategy for data analysis. https://vita.had.co.nz/papers/plyr.html +[^2]: Iverson, K.E. (1980). Notation as a tool of thought. Commun. ACM 23, 8 (Aug. 1980), 444–465. https://doi.org/10.1145/358896.358899 diff --git a/design_notes/named_array_design_doc.md b/design_notes/named_array_design_doc.md new file mode 100644 index 00000000000..074f8cf17e7 --- /dev/null +++ b/design_notes/named_array_design_doc.md @@ -0,0 +1,371 @@ +# named-array Design Document + +## Abstract + +Despite the wealth of scientific libraries in the Python ecosystem, there is a gap for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Existing solutions like Xarray's Variable, [Pytorch Named Tensor](https://github.com/pytorch/pytorch/issues/60832), [Levanter](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html), and [Larray](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) have their own strengths and weaknesses. Xarray's Variable is an efficient data structure, but it depends on the relatively heavy-weight library Pandas, which limits its use in other projects. Pytorch Named Tensor offers named dimensions, but it lacks support for many operations, making it less user-friendly. Levanter is a powerful tool with a named tensor module (Haliax) that makes deep learning code easier to read, understand, and write, but it is not as lightweight or generic as desired. Larry offers labeled N-dimensional arrays, but it may not provide the level of seamless interoperability with other scientific Python libraries that some users need. + +named-array aims to solve these issues by exposing the core functionality of Xarray's Variable class as a standalone package. + +## Motivation and Scope + +The Python ecosystem boasts a wealth of scientific libraries that enable efficient computations on large, multi-dimensional arrays. Libraries like PyTorch, Xarray, and NumPy have revolutionized scientific computing by offering robust data structures for array manipulations. Despite this wealth of tools, a gap exists in the Python landscape for a lightweight, efficient array structure with named dimensions that can provide convenient broadcasting and indexing. + +Xarray internally maintains a data structure that meets this need, referred to as [`xarray.Variable`](https://docs.xarray.dev/en/latest/generated/xarray.Variable.html) . However, Xarray's dependency on Pandas, a relatively heavy-weight library, restricts other projects from leveraging this efficient data structure (, , ). + +We propose the creation of a standalone Python package, "named-array". This package is envisioned to be a version of the `xarray.Variable` data structure, cleanly separated from the heavier dependencies of Xarray. named-array will provide a lightweight, user-friendly array-like data structure with named dimensions, facilitating convenient indexing and broadcasting. The package will use existing scientific Python community standards such as established array protocols and the new [Python array API standard](https://data-apis.org/array-api/latest), allowing users to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +The development of named-array is projected to meet a key community need and expected to broaden Xarray's user base. By making the core `xarray.Variable` more accessible, we anticipate an increase in contributors and a reduction in the developer burden on current Xarray maintainers. + +### Goals + +1. **Simple and minimal**: named-array will expose Xarray's [Variable class](https://docs.xarray.dev/en/stable/internals/variable-objects.html) as a standalone object (`NamedArray`) with named axes (dimensions) and arbitrary metadata (attributes) but without coordinate labels. This will make it a lightweight, efficient array data structure that allows convenient broadcasting and indexing. + +2. **Interoperability**: named-array will follow established scientific Python community standards and in doing so, will allow it to wrap multiple duck-array objects, including but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +3. **Community Engagement**: By making the core `xarray.Variable` more accessible, we open the door to increased adoption of this fundamental data structure. As such, we hope to see an increase in contributors and reduction in the developer burden on current Xarray maintainers. + +### Non-Goals + +1. **Extensive Data Analysis**: named-array will not provide extensive data analysis features like statistical functions, data cleaning, or visualization. Its primary focus is on providing a data structure that allows users to use dimension names for descriptive array manipulations. + +2. **Support for I/O**: named-array will not bundle file reading functions. Instead users will be expected to handle I/O and then wrap those arrays with the new named-array data structure. + +## Backward Compatibility + +The creation of named-array is intended to separate the `xarray.Variable` from Xarray into a standalone package. This allows it to be used independently, without the need for Xarray's dependencies, like Pandas. This separation has implications for backward compatibility. + +Since the new named-array is envisioned to contain the core features of Xarray's variable, existing code using Variable from Xarray should be able to switch to named-array with minimal changes. However, there are several potential issues related to backward compatibility: + +* **API Changes**: as the Variable is decoupled from Xarray and moved into named-array, some changes to the API may be necessary. These changes might include differences in function signature, etc. These changes could break existing code that relies on the current API and associated utility functions (e.g. `as_variable()`). The `xarray.Variable` object will subclass `NamedArray`, and provide the existing interface for compatibility. + +## Detailed Description + +named-array aims to provide a lightweight, efficient array structure with named dimensions, or axes, that enables convenient broadcasting and indexing. The primary component of named-array is a standalone version of the xarray.Variable data structure, which was previously a part of the Xarray library. +The xarray.Variable data structure in named-array will maintain the core features of its counterpart in Xarray, including: + +* **Named Axes (Dimensions)**: Each axis of the array can be given a name, providing a descriptive and intuitive way to reference the dimensions of the array. + +* **Arbitrary Metadata (Attributes)**: named-array will support the attachment of arbitrary metadata to arrays as a dict, providing a mechanism to store additional information about the data that the array represents. + +* **Convenient Broadcasting and Indexing**: With named dimensions, broadcasting and indexing operations become more intuitive and less error-prone. + +The named-array package is designed to be interoperable with other scientific Python libraries. It will follow established scientific Python community standards and use standard array protocols, as well as the new data-apis standard. This allows named-array to wrap multiple duck-array objects, including, but not limited to, NumPy, Dask, Sparse, Pint, CuPy, and Pytorch. + +## Implementation + +* **Decoupling**: making `variable.py` agnostic to Xarray internals by decoupling it from the rest of the library. This will make the code more modular and easier to maintain. However, this will also make the code more complex, as we will need to define a clear interface for how the functionality in `variable.py` interacts with the rest of the library, particularly the ExplicitlyIndexed subclasses used to enable lazy indexing of data on disk. +* **Move Xarray's internal lazy indexing classes to follow standard Array Protocols**: moving the lazy indexing classes like `ExplicitlyIndexed` to use standard array protocols will be a key step in decoupling. It will also potentially improve interoperability with other libraries that use these protocols, and prepare these classes [for eventual movement out](https://github.com/pydata/xarray/issues/5081) of the Xarray code base. However, this will also require significant changes to the code, and we will need to ensure that all existing functionality is preserved. + * Use [https://data-apis.org/array-api-compat/](https://data-apis.org/array-api-compat/) to handle compatibility issues? +* **Leave lazy indexing classes in Xarray for now** +* **Preserve support for Dask collection protocols**: named-array will preserve existing support for the dask collections protocol namely the __dask_***__ methods +* **Preserve support for ChunkManagerEntrypoint?** Opening variables backed by dask vs cubed arrays currently is [handled within Variable.chunk](https://github.com/pydata/xarray/blob/92c8b33eb464b09d6f8277265b16cae039ab57ee/xarray/core/variable.py#L1272C15-L1272C15). If we are preserving dask support it would be nice to preserve general chunked array type support, but this currently requires an entrypoint. + +### Plan + +1. Create a new baseclass for `xarray.Variable` to its own module e.g. `xarray.core.base_variable` +2. Remove all imports of internal Xarray classes and utils from `base_variable.py`. `base_variable.Variable` should not depend on anything in xarray.core + * Will require moving the lazy indexing classes (subclasses of ExplicitlyIndexed) to be standards compliant containers.` + * an array-api compliant container that provides **array_namespace**` + * Support `.oindex` and `.vindex` for explicit indexing + * Potentially implement this by introducing a new compliant wrapper object? + * Delete the `NON_NUMPY_SUPPORTED_ARRAY_TYPES` variable which special-cases ExplicitlyIndexed and `pd.Index.` + * `ExplicitlyIndexed` class and subclasses should provide `.oindex` and `.vindex` for indexing by `Variable.__getitem__.`: `oindex` and `vindex` were proposed in [NEP21](https://numpy.org/neps/nep-0021-advanced-indexing.html), but have not been implemented yet + * Delete the ExplicitIndexer objects (`BasicIndexer`, `VectorizedIndexer`, `OuterIndexer`) + * Remove explicit support for `pd.Index`. When provided with a `pd.Index` object, Variable will coerce to an array using `np.array(pd.Index)`. For Xarray's purposes, Xarray can use `as_variable` to explicitly wrap these in PandasIndexingAdapter and pass them to `Variable.__init__`. +3. Define a minimal variable interface that the rest of Xarray can use: + 1. `dims`: tuple of dimension names + 2. `data`: numpy/dask/duck arrays` + 3. `attrs``: dictionary of attributes + +4. Implement basic functions & methods for manipulating these objects. These methods will be a cleaned-up subset (for now) of functionality on xarray.Variable, with adaptations inspired by the [Python array API](https://data-apis.org/array-api/2022.12/API_specification/index.html). +5. Existing Variable structures + 1. Keep Variable object which subclasses the new structure that adds the `.encoding` attribute and potentially other methods needed for easy refactoring. + 2. IndexVariable will remain in xarray.core.variable and subclass the new named-array data structure pending future deletion. +6. Docstrings and user-facing APIs will need to be updated to reflect the changed methods on Variable objects. + +Further implementation details are in Appendix: [Implementation Details](#appendix-implementation-details). + +## Plan for decoupling lazy indexing functionality from NamedArray + +Today's implementation Xarray's lazy indexing functionality uses three private objects: `*Indexer`, `*IndexingAdapter`, `*Array`. +These objects are needed for two reason: +1. We need to translate from Xarray (NamedArray) indexing rules to bare array indexing rules. + - `*Indexer` objects track the type of indexing - basic, orthogonal, vectorized +2. Not all arrays support the same indexing rules, so we need `*Indexing` adapters + 1. Indexing Adapters today implement `__getitem__` and use type of `*Indexer` object to do appropriate conversions. +3. We also want to support lazy indexing of on-disk arrays. + 1. These again support different types of indexing, so we have `explicit_indexing_adapter` that understands `*Indexer` objects. + +### Goals +1. We would like to keep the lazy indexing array objects, and backend array objects within Xarray. Thus NamedArray cannot treat these objects specially. +2. A key source of confusion (and coupling) is that both lazy indexing arrays and indexing adapters, both handle Indexer objects, and both subclass `ExplicitlyIndexedNDArrayMixin`. These are however conceptually different. + +### Proposal + +1. The `NumpyIndexingAdapter`, `DaskIndexingAdapter`, and `ArrayApiIndexingAdapter` classes will need to migrate to Named Array project since we will want to support indexing of numpy, dask, and array-API arrays appropriately. +2. The `as_indexable` function which wraps an array with the appropriate adapter will also migrate over to named array. +3. Lazy indexing arrays will implement `__getitem__` for basic indexing, `.oindex` for orthogonal indexing, and `.vindex` for vectorized indexing. +4. IndexingAdapter classes will similarly implement `__getitem__`, `oindex`, and `vindex`. +5. `NamedArray.__getitem__` (and `__setitem__`) will still use `*Indexer` objects internally (for e.g. in `NamedArray._broadcast_indexes`), but use `.oindex`, `.vindex` on the underlying indexing adapters. +6. We will move the `*Indexer` and `*IndexingAdapter` classes to Named Array. These will be considered private in the long-term. +7. `as_indexable` will no longer special case `ExplicitlyIndexed` objects (we can special case a new `IndexingAdapter` mixin class that will be private to NamedArray). To handle Xarray's lazy indexing arrays, we will introduce a new `ExplicitIndexingAdapter` which will wrap any array with any of `.oindex` of `.vindex` implemented. + 1. This will be the last case in the if-chain that is, we will try to wrap with all other `IndexingAdapter` objects before using `ExplicitIndexingAdapter` as a fallback. This Adapter will be used for the lazy indexing arrays, and backend arrays. + 2. As with other indexing adapters (point 4 above), this `ExplicitIndexingAdapter` will only implement `__getitem__` and will understand `*Indexer` objects. +8. For backwards compatibility with external backends, we will have to gracefully deprecate `indexing.explicit_indexing_adapter` which translates from Xarray's indexing rules to the indexing supported by the backend. + 1. We could split `explicit_indexing_adapter` in to 3: + - `basic_indexing_adapter`, `outer_indexing_adapter` and `vectorized_indexing_adapter` for public use. + 2. Implement fall back `.oindex`, `.vindex` properties on `BackendArray` base class. These will simply rewrap the `key` tuple with the appropriate `*Indexer` object, and pass it on to `__getitem__` or `__setitem__`. These methods will also raise DeprecationWarning so that external backends will know to migrate to `.oindex`, and `.vindex` over the next year. + +THe most uncertain piece here is maintaining backward compatibility with external backends. We should first migrate a single internal backend, and test out the proposed approach. + +## Project Timeline and Milestones + +We have identified the following milestones for the completion of this project: + +1. **Write and publish a design document**: this document will explain the purpose of named-array, the intended audience, and the features it will provide. It will also describe the architecture of named-array and how it will be implemented. This will ensure early community awareness and engagement in the project to promote subsequent uptake. +2. **Refactor `variable.py` to `base_variable.py`** and remove internal Xarray imports. +3. **Break out the package and create continuous integration infrastructure**: this will entail breaking out the named-array project into a Python package and creating a continuous integration (CI) system. This will help to modularize the code and make it easier to manage. Building a CI system will help ensure that codebase changes do not break existing functionality. +4. Incrementally add new functions & methods to the new package, ported from xarray. This will start to make named-array useful on its own. +5. Refactor the existing Xarray codebase to rely on the newly created package (named-array): This will help to demonstrate the usefulness of the new package, and also provide an example for others who may want to use it. +6. Expand tests, add documentation, and write a blog post: expanding the test suite will help to ensure that the code is reliable and that changes do not introduce bugs. Adding documentation will make it easier for others to understand and use the project. +7. Finally, we will write a series of blog posts on [xarray.dev](https://xarray.dev/) to promote the project and attract more contributors. + * Toward the end of the process, write a few blog posts that demonstrate the use of the newly available data structure + * pick the same example applications used by other implementations/applications (e.g. Pytorch, sklearn, and Levanter) to show how it can work. + +## Related Work + +1. [GitHub - deepmind/graphcast](https://github.com/deepmind/graphcast) +2. [Getting Started — LArray 0.34 documentation](https://larray.readthedocs.io/en/stable/tutorial/getting_started.html) +3. [Levanter — Legible, Scalable, Reproducible Foundation Models with JAX](https://crfm.stanford.edu/2023/06/16/levanter-1_0-release.html) +4. [google/xarray-tensorstore](https://github.com/google/xarray-tensorstore) +5. [State of Torch Named Tensors · Issue #60832 · pytorch/pytorch · GitHub](https://github.com/pytorch/pytorch/issues/60832) + * Incomplete support: Many primitive operations result in errors, making it difficult to use NamedTensors in Practice. Users often have to resort to removing the names from tensors to avoid these errors. + * Lack of active development: the development of the NamedTensor feature in PyTorch is not currently active due a lack of bandwidth for resolving ambiguities in the design. + * Usability issues: the current form of NamedTensor is not user-friendly and sometimes raises errors, making it difficult for users to incorporate NamedTensors into their workflows. +6. [Scikit-learn Enhancement Proposals (SLEPs) 8, 12, 14](https://github.com/scikit-learn/enhancement_proposals/pull/18) + * Some of the key points and limitations discussed in these proposals are: + * Inconsistency in feature name handling: Scikit-learn currently lacks a consistent and comprehensive way to handle and propagate feature names through its pipelines and estimators ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html)). + * Memory intensive for large feature sets: storing and propagating feature names can be memory intensive, particularly in cases where the entire "dictionary" becomes the features, such as in NLP use cases ([SLEP 8](https://github.com/scikit-learn/enhancement_proposals/pull/18),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)) + * Sparse matrices: sparse data structures present a challenge for feature name propagation. For instance, the sparse data structure functionality in Pandas 1.0 only supports converting directly to the coordinate format (COO), which can be an issue with transformers such as the OneHotEncoder.transform that has been optimized to construct a CSR matrix ([SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html)) + * New Data structures: the introduction of new data structures, such as "InputArray" or "DataArray" could lead to more burden for third-party estimator maintainers and increase the learning curve for users. Xarray's "DataArray" is mentioned as a potential alternative, but the proposal mentions that the conversion from a Pandas dataframe to a Dataset is not lossless ([SLEP 12](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep012/proposal.html),[SLEP 14](https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep014/proposal.html),[GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + * Dependency on other libraries: solutions that involve using Xarray and/or Pandas to handle feature names come with the challenge of managing dependencies. While a soft dependency approach is suggested, this means users would be able to have/enable the feature only if they have the dependency installed. Xarra-lite's integration with other scientific Python libraries could potentially help with this issue ([GitHub issue #35](https://github.com/scikit-learn/enhancement_proposals/issues/35)). + +## References and Previous Discussion + +* [[Proposal] Expose Variable without Pandas dependency · Issue #3981 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/3981) +* [https://github.com/pydata/xarray/issues/3981#issuecomment-985051449](https://github.com/pydata/xarray/issues/3981#issuecomment-985051449) +* [Lazy indexing arrays as a stand-alone package · Issue #5081 · pydata/xarray · GitHub](https://github.com/pydata/xarray/issues/5081) + +### Appendix: Engagement with the Community + +We plan to publicize this document on : + +* [x] `Xarray dev call` +* [ ] `Scientific Python discourse` +* [ ] `Xarray Github` +* [ ] `Twitter` +* [ ] `Respond to NamedTensor and Scikit-Learn issues?` +* [ ] `Pangeo Discourse` +* [ ] `Numpy, SciPy email lists?` +* [ ] `Xarray blog` + +Additionally, We plan on writing a series of blog posts to effectively showcase the implementation and potential of the newly available functionality. To illustrate this, we will use the same example applications as other established libraries (such as Pytorch, sklearn), providing practical demonstrations of how these new data structures can be leveraged. + +### Appendix: API Surface + +Questions: + +1. Document Xarray indexing rules +2. Document use of .oindex and .vindex protocols +3. Do we use `.mean` and `.nanmean` or `.mean(skipna=...)`? + * Default behavior in named-array should mirror NumPy / the array API standard, not pandas. + * nanmean is not (yet) in the [array API](https://github.com/pydata/xarray/pull/7424#issuecomment-1373979208). There are a handful of other key functions (e.g., median) that are are also missing. I think that should be OK, as long as what we support is a strict superset of the array API. +4. What methods need to be exposed on Variable? + * `Variable.concat` classmethod: create two functions, one as the equivalent of `np.stack` and other for `np.concat` + * `.rolling_window` and `.coarsen_reshape` ? + * `named-array.apply_ufunc`: used in astype, clip, quantile, isnull, notnull` + +#### methods to be preserved from xarray.Variable + +```python +# Sorting + Variable.argsort + Variable.searchsorted + +# NaN handling + Variable.fillna + Variable.isnull + Variable.notnull + +# Lazy data handling + Variable.chunk # Could instead have accessor interface and recommend users use `Variable.dask.chunk` and `Variable.cubed.chunk`? + Variable.to_numpy() + Variable.as_numpy() + +# Xarray-specific + Variable.get_axis_num + Variable.isel + Variable.to_dict + +# Reductions + Variable.reduce + Variable.all + Variable.any + Variable.argmax + Variable.argmin + Variable.count + Variable.max + Variable.mean + Variable.median + Variable.min + Variable.prod + Variable.quantile + Variable.std + Variable.sum + Variable.var + +# Accumulate + Variable.cumprod + Variable.cumsum + +# numpy-like Methods + Variable.astype + Variable.copy + Variable.clip + Variable.round + Variable.item + Variable.where + +# Reordering/Reshaping + Variable.squeeze + Variable.pad + Variable.roll + Variable.shift + +``` + +#### methods to be renamed from xarray.Variable + +```python +# Xarray-specific + Variable.concat # create two functions, one as the equivalent of `np.stack` and other for `np.concat` + + # Given how niche these are, these would be better as functions than methods. + # We could also keep these in Xarray, at least for now. If we don't think people will use functionality outside of Xarray it probably is not worth the trouble of porting it (including documentation, etc). + Variable.coarsen # This should probably be called something like coarsen_reduce. + Variable.coarsen_reshape + Variable.rolling_window + + Variable.set_dims # split this into broadcas_to and expand_dims + + +# Reordering/Reshaping + Variable.stack # To avoid confusion with np.stack, let's call this stack_dims. + Variable.transpose # Could consider calling this permute_dims, like the [array API standard](https://data-apis.org/array-api/2022.12/API_specification/manipulation_functions.html#objects-in-api) + Variable.unstack # Likewise, maybe call this unstack_dims? +``` + +#### methods to be removed from xarray.Variable + +```python +# Testing + Variable.broadcast_equals + Variable.equals + Variable.identical + Variable.no_conflicts + +# Lazy data handling + Variable.compute # We can probably omit this method for now, too, given that dask.compute() uses a protocol. The other concern is that different array libraries have different notions of "compute" and this one is rather Dask specific, including conversion from Dask to NumPy arrays. For example, in JAX every operation executes eagerly, but in a non-blocking fashion, and you need to call jax.block_until_ready() to ensure computation is finished. + Variable.load # Could remove? compute vs load is a common source of confusion. + +# Xarray-specific + Variable.to_index + Variable.to_index_variable + Variable.to_variable + Variable.to_base_variable + Variable.to_coord + + Variable.rank # Uses bottleneck. Delete? Could use https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.rankdata.html instead + + +# numpy-like Methods + Variable.conjugate # .conj is enough + Variable.__array_wrap__ # This is a very old NumPy protocol for duck arrays. We don't need it now that we have `__array_ufunc__` and `__array_function__` + +# Encoding + Variable.reset_encoding + +``` + +#### Attributes to be preserved from xarray.Variable + +```python +# Properties + Variable.attrs + Variable.chunks + Variable.data + Variable.dims + Variable.dtype + + Variable.nbytes + Variable.ndim + Variable.shape + Variable.size + Variable.sizes + + Variable.T + Variable.real + Variable.imag + Variable.conj +``` + +#### Attributes to be renamed from xarray.Variable + +```python +``` + +#### Attributes to be removed from xarray.Variable + +```python + + Variable.values # Probably also remove -- this is a legacy from before Xarray supported dask arrays. ".data" is enough. + +# Encoding + Variable.encoding + +``` + +### Appendix: Implementation Details + +* Merge in VariableArithmetic's parent classes: AbstractArray, NdimSizeLenMixin with the new data structure.. + +```python +class VariableArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + IncludeCumMethods, + IncludeNumpySameMethods, + SupportsArithmetic, + VariableOpsMixin, +): + __slots__ = () + # prioritize our operations over those of numpy.ndarray (priority=0) + __array_priority__ = 50 + +``` + +* Move over `_typed_ops.VariableOpsMixin` +* Build a list of utility functions used elsewhere : Which of these should become public API? + * `broadcast_variables`: `dataset.py`, `dataarray.py`,`missing.py` + * This could be just called "broadcast" in named-array. + * `Variable._getitem_with_mask` : `alignment.py` + * keep this method/function as private and inside Xarray. +* The Variable constructor will need to be rewritten to no longer accept tuples, encodings, etc. These details should be handled at the Xarray data structure level. +* What happens to `duck_array_ops?` +* What about Variable.chunk and "chunk managers"? + * Could this functionality be left in Xarray proper for now? Alternative array types like JAX also have some notion of "chunks" for parallel arrays, but the details differ in a number of ways from the Dask/Cubed. + * Perhaps variable.chunk/load methods should become functions defined in xarray that convert Variable objects. This is easy so long as xarray can reach in and replace .data + +* Utility functions like `as_variable` should be moved out of `base_variable.py` so they can convert BaseVariable objects to/from DataArray or Dataset containing explicitly indexed arrays. diff --git a/doc/_static/dataset-diagram-build.sh b/doc/_static/dataset-diagram-build.sh deleted file mode 100755 index 1e69d454ff6..00000000000 --- a/doc/_static/dataset-diagram-build.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/env bash -pdflatex -shell-escape dataset-diagram.tex diff --git a/doc/_static/dataset-diagram-logo.pdf b/doc/_static/dataset-diagram-logo.pdf deleted file mode 100644 index 0ef2b1247eb..00000000000 Binary files a/doc/_static/dataset-diagram-logo.pdf and /dev/null differ diff --git a/doc/_static/dataset-diagram-logo.png b/doc/_static/dataset-diagram-logo.png deleted file mode 100644 index 23c413d3414..00000000000 Binary files a/doc/_static/dataset-diagram-logo.png and /dev/null differ diff --git a/doc/_static/dataset-diagram-logo.svg b/doc/_static/dataset-diagram-logo.svg deleted file mode 100644 index 2809bf2f5a1..00000000000 --- a/doc/_static/dataset-diagram-logo.svg +++ /dev/null @@ -1,484 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/doc/_static/dataset-diagram-logo.tex b/doc/_static/dataset-diagram-logo.tex deleted file mode 100644 index 7c96c47dfc4..00000000000 --- a/doc/_static/dataset-diagram-logo.tex +++ /dev/null @@ -1,283 +0,0 @@ -% For svg output: -% 1. Use second \documentclass line. -% 2. Install potrace : mamba install -c bioconda potrace -% As of Feb 14, 2023: svg does not work. PDF does work -% \documentclass[class=minimal,border=0pt,convert={size=600,outext=.png}]{standalone} -\documentclass[class=minimal,border=0pt,convert={size=600,outext=.svg}]{standalone} -% \documentclass[class=minimal,border=0pt]{standalone} -\usepackage[scaled]{helvet} -\usepackage[T1]{fontenc} -\renewcommand*\familydefault{\sfdefault} - -% =========================================================================== -% The code below (used to define the \tikzcuboid command) is copied, -% unmodified, from a tex.stackexchange.com answer by the user "Tom Bombadil": -% http://tex.stackexchange.com/a/29882/8335 -% -% It is licensed under the Creative Commons Attribution-ShareAlike 3.0 -% Unported license: http://creativecommons.org/licenses/by-sa/3.0/ -% =========================================================================== - -\usepackage[usenames,dvipsnames]{color} -\usepackage{tikz} -\usepackage{keyval} -\usepackage{ifthen} - -%==================================== -%emphasize vertices --> switch and emph style (e.g. thick,black) -%==================================== -\makeatletter -% Standard Values for Parameters -\newcommand{\tikzcuboid@shiftx}{0} -\newcommand{\tikzcuboid@shifty}{0} -\newcommand{\tikzcuboid@dimx}{3} -\newcommand{\tikzcuboid@dimy}{3} -\newcommand{\tikzcuboid@dimz}{3} -\newcommand{\tikzcuboid@scale}{1} -\newcommand{\tikzcuboid@densityx}{1} -\newcommand{\tikzcuboid@densityy}{1} -\newcommand{\tikzcuboid@densityz}{1} -\newcommand{\tikzcuboid@rotation}{0} -\newcommand{\tikzcuboid@anglex}{0} -\newcommand{\tikzcuboid@angley}{90} -\newcommand{\tikzcuboid@anglez}{225} -\newcommand{\tikzcuboid@scalex}{1} -\newcommand{\tikzcuboid@scaley}{1} -\newcommand{\tikzcuboid@scalez}{sqrt(0.5)} -\newcommand{\tikzcuboid@linefront}{black} -\newcommand{\tikzcuboid@linetop}{black} -\newcommand{\tikzcuboid@lineright}{black} -\newcommand{\tikzcuboid@fillfront}{white} -\newcommand{\tikzcuboid@filltop}{white} -\newcommand{\tikzcuboid@fillright}{white} -\newcommand{\tikzcuboid@shaded}{N} -\newcommand{\tikzcuboid@shadecolor}{black} -\newcommand{\tikzcuboid@shadeperc}{25} -\newcommand{\tikzcuboid@emphedge}{N} -\newcommand{\tikzcuboid@emphstyle}{thick} - -% Definition of Keys -\define@key{tikzcuboid}{shiftx}[\tikzcuboid@shiftx]{\renewcommand{\tikzcuboid@shiftx}{#1}} -\define@key{tikzcuboid}{shifty}[\tikzcuboid@shifty]{\renewcommand{\tikzcuboid@shifty}{#1}} -\define@key{tikzcuboid}{dimx}[\tikzcuboid@dimx]{\renewcommand{\tikzcuboid@dimx}{#1}} -\define@key{tikzcuboid}{dimy}[\tikzcuboid@dimy]{\renewcommand{\tikzcuboid@dimy}{#1}} -\define@key{tikzcuboid}{dimz}[\tikzcuboid@dimz]{\renewcommand{\tikzcuboid@dimz}{#1}} -\define@key{tikzcuboid}{scale}[\tikzcuboid@scale]{\renewcommand{\tikzcuboid@scale}{#1}} -\define@key{tikzcuboid}{densityx}[\tikzcuboid@densityx]{\renewcommand{\tikzcuboid@densityx}{#1}} -\define@key{tikzcuboid}{densityy}[\tikzcuboid@densityy]{\renewcommand{\tikzcuboid@densityy}{#1}} -\define@key{tikzcuboid}{densityz}[\tikzcuboid@densityz]{\renewcommand{\tikzcuboid@densityz}{#1}} -\define@key{tikzcuboid}{rotation}[\tikzcuboid@rotation]{\renewcommand{\tikzcuboid@rotation}{#1}} -\define@key{tikzcuboid}{anglex}[\tikzcuboid@anglex]{\renewcommand{\tikzcuboid@anglex}{#1}} -\define@key{tikzcuboid}{angley}[\tikzcuboid@angley]{\renewcommand{\tikzcuboid@angley}{#1}} -\define@key{tikzcuboid}{anglez}[\tikzcuboid@anglez]{\renewcommand{\tikzcuboid@anglez}{#1}} -\define@key{tikzcuboid}{scalex}[\tikzcuboid@scalex]{\renewcommand{\tikzcuboid@scalex}{#1}} -\define@key{tikzcuboid}{scaley}[\tikzcuboid@scaley]{\renewcommand{\tikzcuboid@scaley}{#1}} -\define@key{tikzcuboid}{scalez}[\tikzcuboid@scalez]{\renewcommand{\tikzcuboid@scalez}{#1}} -\define@key{tikzcuboid}{linefront}[\tikzcuboid@linefront]{\renewcommand{\tikzcuboid@linefront}{#1}} -\define@key{tikzcuboid}{linetop}[\tikzcuboid@linetop]{\renewcommand{\tikzcuboid@linetop}{#1}} -\define@key{tikzcuboid}{lineright}[\tikzcuboid@lineright]{\renewcommand{\tikzcuboid@lineright}{#1}} -\define@key{tikzcuboid}{fillfront}[\tikzcuboid@fillfront]{\renewcommand{\tikzcuboid@fillfront}{#1}} -\define@key{tikzcuboid}{filltop}[\tikzcuboid@filltop]{\renewcommand{\tikzcuboid@filltop}{#1}} -\define@key{tikzcuboid}{fillright}[\tikzcuboid@fillright]{\renewcommand{\tikzcuboid@fillright}{#1}} -\define@key{tikzcuboid}{shaded}[\tikzcuboid@shaded]{\renewcommand{\tikzcuboid@shaded}{#1}} -\define@key{tikzcuboid}{shadecolor}[\tikzcuboid@shadecolor]{\renewcommand{\tikzcuboid@shadecolor}{#1}} -\define@key{tikzcuboid}{shadeperc}[\tikzcuboid@shadeperc]{\renewcommand{\tikzcuboid@shadeperc}{#1}} -\define@key{tikzcuboid}{emphedge}[\tikzcuboid@emphedge]{\renewcommand{\tikzcuboid@emphedge}{#1}} -\define@key{tikzcuboid}{emphstyle}[\tikzcuboid@emphstyle]{\renewcommand{\tikzcuboid@emphstyle}{#1}} -% Commands -\newcommand{\tikzcuboid}[1]{ - \setkeys{tikzcuboid}{#1} % Process Keys passed to command - \pgfmathsetmacro{\vectorxx}{\tikzcuboid@scalex*cos(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectorxy}{\tikzcuboid@scalex*sin(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectoryx}{\tikzcuboid@scaley*cos(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectoryy}{\tikzcuboid@scaley*sin(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectorzx}{\tikzcuboid@scalez*cos(\tikzcuboid@anglez)} - \pgfmathsetmacro{\vectorzy}{\tikzcuboid@scalez*sin(\tikzcuboid@anglez)} - \begin{scope}[xshift=\tikzcuboid@shiftx, yshift=\tikzcuboid@shifty, scale=\tikzcuboid@scale, rotate=\tikzcuboid@rotation, x={(\vectorxx,\vectorxy)}, y={(\vectoryx,\vectoryy)}, z={(\vectorzx,\vectorzy)}] - \pgfmathsetmacro{\steppingx}{1/\tikzcuboid@densityx} - \pgfmathsetmacro{\steppingy}{1/\tikzcuboid@densityy} - \pgfmathsetmacro{\steppingz}{1/\tikzcuboid@densityz} - \newcommand{\dimx}{\tikzcuboid@dimx} - \newcommand{\dimy}{\tikzcuboid@dimy} - \newcommand{\dimz}{\tikzcuboid@dimz} - \pgfmathsetmacro{\secondx}{2*\steppingx} - \pgfmathsetmacro{\secondy}{2*\steppingy} - \pgfmathsetmacro{\secondz}{2*\steppingz} - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \y in {\steppingy,\secondy,...,\dimy} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \filldraw[fill=\tikzcuboid@fillfront,draw=\tikzcuboid@linefront] (\lowx,\lowy,\dimz) -- (\lowx,\y,\dimz) -- (\x,\y,\dimz) -- (\x,\lowy,\dimz) -- cycle; - - } - } - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@filltop,draw=\tikzcuboid@linetop] (\lowx,\dimy,\lowz) -- (\lowx,\dimy,\z) -- (\x,\dimy,\z) -- (\x,\dimy,\lowz) -- cycle; - } - } - \foreach \y in {\steppingy,\secondy,...,\dimy} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@fillright,draw=\tikzcuboid@lineright] (\dimx,\lowy,\lowz) -- (\dimx,\lowy,\z) -- (\dimx,\y,\z) -- (\dimx,\y,\lowz) -- cycle; - } - } - \ifthenelse{\equal{\tikzcuboid@emphedge}{Y}}% - {\draw[\tikzcuboid@emphstyle](0,\dimy,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (0,\dimy,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle] (0,0,\dimz) -- (0,\dimy,\dimz) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle](\dimx,0,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - }% - {} - \end{scope} -} - -\makeatother - -\begin{document} - -\begin{tikzpicture} - \tikzcuboid{% - shiftx=21cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=purple!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=21cm,% - shifty=11.6cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=teal!75!black,% - linetop=teal!50!black,% - lineright=teal!25!black,% - fillfront=teal!25!white,% - filltop=teal!50!white,% - fillright=teal!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=26.8cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=orange!75!black,% - linetop=orange!50!black,% - lineright=orange!25!black,% - fillfront=orange!25!white,% - filltop=orange!50!white,% - fillright=orange!100!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=28.6cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=red!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - % \tikzcuboid{% - % shiftx=27.1cm,% - % shifty=10.1cm,% - % scale=1.00,% - % rotation=0,% - % densityx=100,% - % densityy=2,% - % densityz=100,% - % dimx=0,% - % dimy=3,% - % dimz=0,% - % emphedge=Y,% - % emphstyle=ultra thick, - % } - % \tikzcuboid{% - % shiftx=27.1cm,% - % shifty=10.1cm,% - % scale=1.00,% - % rotation=180,% - % densityx=100,% - % densityy=100,% - % densityz=2,% - % dimx=0,% - % dimy=0,% - % dimz=3,% - % emphedge=Y,% - % emphstyle=ultra thick, - % } - \tikzcuboid{% - shiftx=26.8cm,% - shifty=11.4cm,% - scale=1.00,% - rotation=0,% - densityx=100,% - densityy=2,% - densityz=100,% - dimx=0,% - dimy=3,% - dimz=0,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=25.3cm,% - shifty=12.9cm,% - scale=1.00,% - rotation=180,% - densityx=100,% - densityy=100,% - densityz=2,% - dimx=0,% - dimy=0,% - dimz=3,% - emphedge=Y,% - emphstyle=ultra thick, - } - % \fill (27.1,10.1) circle[radius=2pt]; - \node [font=\fontsize{100}{100}\fontfamily{phv}\selectfont, anchor=west, text width=9cm, color=white!50!black] at (30,10.6) {\textbf{\emph{x}}}; - \node [font=\fontsize{100}{100}\fontfamily{phv}\selectfont, anchor=west, text width=9cm] at (32,10.25) {{array}}; -\end{tikzpicture} - -\end{document} diff --git a/doc/_static/dataset-diagram-square-logo.png b/doc/_static/dataset-diagram-square-logo.png deleted file mode 100644 index d1eeda092c4..00000000000 Binary files a/doc/_static/dataset-diagram-square-logo.png and /dev/null differ diff --git a/doc/_static/dataset-diagram-square-logo.tex b/doc/_static/dataset-diagram-square-logo.tex deleted file mode 100644 index 0a784770b50..00000000000 --- a/doc/_static/dataset-diagram-square-logo.tex +++ /dev/null @@ -1,277 +0,0 @@ -\documentclass[class=minimal,border=0pt,convert={size=600,outext=.png}]{standalone} -% \documentclass[class=minimal,border=0pt]{standalone} -\usepackage[scaled]{helvet} -\renewcommand*\familydefault{\sfdefault} - -% =========================================================================== -% The code below (used to define the \tikzcuboid command) is copied, -% unmodified, from a tex.stackexchange.com answer by the user "Tom Bombadil": -% http://tex.stackexchange.com/a/29882/8335 -% -% It is licensed under the Creative Commons Attribution-ShareAlike 3.0 -% Unported license: http://creativecommons.org/licenses/by-sa/3.0/ -% =========================================================================== - -\usepackage[usenames,dvipsnames]{color} -\usepackage{tikz} -\usepackage{keyval} -\usepackage{ifthen} - -%==================================== -%emphasize vertices --> switch and emph style (e.g. thick,black) -%==================================== -\makeatletter -% Standard Values for Parameters -\newcommand{\tikzcuboid@shiftx}{0} -\newcommand{\tikzcuboid@shifty}{0} -\newcommand{\tikzcuboid@dimx}{3} -\newcommand{\tikzcuboid@dimy}{3} -\newcommand{\tikzcuboid@dimz}{3} -\newcommand{\tikzcuboid@scale}{1} -\newcommand{\tikzcuboid@densityx}{1} -\newcommand{\tikzcuboid@densityy}{1} -\newcommand{\tikzcuboid@densityz}{1} -\newcommand{\tikzcuboid@rotation}{0} -\newcommand{\tikzcuboid@anglex}{0} -\newcommand{\tikzcuboid@angley}{90} -\newcommand{\tikzcuboid@anglez}{225} -\newcommand{\tikzcuboid@scalex}{1} -\newcommand{\tikzcuboid@scaley}{1} -\newcommand{\tikzcuboid@scalez}{sqrt(0.5)} -\newcommand{\tikzcuboid@linefront}{black} -\newcommand{\tikzcuboid@linetop}{black} -\newcommand{\tikzcuboid@lineright}{black} -\newcommand{\tikzcuboid@fillfront}{white} -\newcommand{\tikzcuboid@filltop}{white} -\newcommand{\tikzcuboid@fillright}{white} -\newcommand{\tikzcuboid@shaded}{N} -\newcommand{\tikzcuboid@shadecolor}{black} -\newcommand{\tikzcuboid@shadeperc}{25} -\newcommand{\tikzcuboid@emphedge}{N} -\newcommand{\tikzcuboid@emphstyle}{thick} - -% Definition of Keys -\define@key{tikzcuboid}{shiftx}[\tikzcuboid@shiftx]{\renewcommand{\tikzcuboid@shiftx}{#1}} -\define@key{tikzcuboid}{shifty}[\tikzcuboid@shifty]{\renewcommand{\tikzcuboid@shifty}{#1}} -\define@key{tikzcuboid}{dimx}[\tikzcuboid@dimx]{\renewcommand{\tikzcuboid@dimx}{#1}} -\define@key{tikzcuboid}{dimy}[\tikzcuboid@dimy]{\renewcommand{\tikzcuboid@dimy}{#1}} -\define@key{tikzcuboid}{dimz}[\tikzcuboid@dimz]{\renewcommand{\tikzcuboid@dimz}{#1}} -\define@key{tikzcuboid}{scale}[\tikzcuboid@scale]{\renewcommand{\tikzcuboid@scale}{#1}} -\define@key{tikzcuboid}{densityx}[\tikzcuboid@densityx]{\renewcommand{\tikzcuboid@densityx}{#1}} -\define@key{tikzcuboid}{densityy}[\tikzcuboid@densityy]{\renewcommand{\tikzcuboid@densityy}{#1}} -\define@key{tikzcuboid}{densityz}[\tikzcuboid@densityz]{\renewcommand{\tikzcuboid@densityz}{#1}} -\define@key{tikzcuboid}{rotation}[\tikzcuboid@rotation]{\renewcommand{\tikzcuboid@rotation}{#1}} -\define@key{tikzcuboid}{anglex}[\tikzcuboid@anglex]{\renewcommand{\tikzcuboid@anglex}{#1}} -\define@key{tikzcuboid}{angley}[\tikzcuboid@angley]{\renewcommand{\tikzcuboid@angley}{#1}} -\define@key{tikzcuboid}{anglez}[\tikzcuboid@anglez]{\renewcommand{\tikzcuboid@anglez}{#1}} -\define@key{tikzcuboid}{scalex}[\tikzcuboid@scalex]{\renewcommand{\tikzcuboid@scalex}{#1}} -\define@key{tikzcuboid}{scaley}[\tikzcuboid@scaley]{\renewcommand{\tikzcuboid@scaley}{#1}} -\define@key{tikzcuboid}{scalez}[\tikzcuboid@scalez]{\renewcommand{\tikzcuboid@scalez}{#1}} -\define@key{tikzcuboid}{linefront}[\tikzcuboid@linefront]{\renewcommand{\tikzcuboid@linefront}{#1}} -\define@key{tikzcuboid}{linetop}[\tikzcuboid@linetop]{\renewcommand{\tikzcuboid@linetop}{#1}} -\define@key{tikzcuboid}{lineright}[\tikzcuboid@lineright]{\renewcommand{\tikzcuboid@lineright}{#1}} -\define@key{tikzcuboid}{fillfront}[\tikzcuboid@fillfront]{\renewcommand{\tikzcuboid@fillfront}{#1}} -\define@key{tikzcuboid}{filltop}[\tikzcuboid@filltop]{\renewcommand{\tikzcuboid@filltop}{#1}} -\define@key{tikzcuboid}{fillright}[\tikzcuboid@fillright]{\renewcommand{\tikzcuboid@fillright}{#1}} -\define@key{tikzcuboid}{shaded}[\tikzcuboid@shaded]{\renewcommand{\tikzcuboid@shaded}{#1}} -\define@key{tikzcuboid}{shadecolor}[\tikzcuboid@shadecolor]{\renewcommand{\tikzcuboid@shadecolor}{#1}} -\define@key{tikzcuboid}{shadeperc}[\tikzcuboid@shadeperc]{\renewcommand{\tikzcuboid@shadeperc}{#1}} -\define@key{tikzcuboid}{emphedge}[\tikzcuboid@emphedge]{\renewcommand{\tikzcuboid@emphedge}{#1}} -\define@key{tikzcuboid}{emphstyle}[\tikzcuboid@emphstyle]{\renewcommand{\tikzcuboid@emphstyle}{#1}} -% Commands -\newcommand{\tikzcuboid}[1]{ - \setkeys{tikzcuboid}{#1} % Process Keys passed to command - \pgfmathsetmacro{\vectorxx}{\tikzcuboid@scalex*cos(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectorxy}{\tikzcuboid@scalex*sin(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectoryx}{\tikzcuboid@scaley*cos(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectoryy}{\tikzcuboid@scaley*sin(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectorzx}{\tikzcuboid@scalez*cos(\tikzcuboid@anglez)} - \pgfmathsetmacro{\vectorzy}{\tikzcuboid@scalez*sin(\tikzcuboid@anglez)} - \begin{scope}[xshift=\tikzcuboid@shiftx, yshift=\tikzcuboid@shifty, scale=\tikzcuboid@scale, rotate=\tikzcuboid@rotation, x={(\vectorxx,\vectorxy)}, y={(\vectoryx,\vectoryy)}, z={(\vectorzx,\vectorzy)}] - \pgfmathsetmacro{\steppingx}{1/\tikzcuboid@densityx} - \pgfmathsetmacro{\steppingy}{1/\tikzcuboid@densityy} - \pgfmathsetmacro{\steppingz}{1/\tikzcuboid@densityz} - \newcommand{\dimx}{\tikzcuboid@dimx} - \newcommand{\dimy}{\tikzcuboid@dimy} - \newcommand{\dimz}{\tikzcuboid@dimz} - \pgfmathsetmacro{\secondx}{2*\steppingx} - \pgfmathsetmacro{\secondy}{2*\steppingy} - \pgfmathsetmacro{\secondz}{2*\steppingz} - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \y in {\steppingy,\secondy,...,\dimy} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \filldraw[fill=\tikzcuboid@fillfront,draw=\tikzcuboid@linefront] (\lowx,\lowy,\dimz) -- (\lowx,\y,\dimz) -- (\x,\y,\dimz) -- (\x,\lowy,\dimz) -- cycle; - - } - } - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@filltop,draw=\tikzcuboid@linetop] (\lowx,\dimy,\lowz) -- (\lowx,\dimy,\z) -- (\x,\dimy,\z) -- (\x,\dimy,\lowz) -- cycle; - } - } - \foreach \y in {\steppingy,\secondy,...,\dimy} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@fillright,draw=\tikzcuboid@lineright] (\dimx,\lowy,\lowz) -- (\dimx,\lowy,\z) -- (\dimx,\y,\z) -- (\dimx,\y,\lowz) -- cycle; - } - } - \ifthenelse{\equal{\tikzcuboid@emphedge}{Y}}% - {\draw[\tikzcuboid@emphstyle](0,\dimy,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (0,\dimy,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle] (0,0,\dimz) -- (0,\dimy,\dimz) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle](\dimx,0,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - }% - {} - \end{scope} -} - -\makeatother - -\begin{document} - -\begin{tikzpicture} - \tikzcuboid{% - shiftx=21cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=purple!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=21cm,% - shifty=11.6cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=teal!75!black,% - linetop=teal!50!black,% - lineright=teal!25!black,% - fillfront=teal!25!white,% - filltop=teal!50!white,% - fillright=teal!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=26.8cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=orange!75!black,% - linetop=orange!50!black,% - lineright=orange!25!black,% - fillfront=orange!25!white,% - filltop=orange!50!white,% - fillright=orange!100!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=28.6cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=red!75!white,% - emphedge=Y,% - emphstyle=ultra thick, - } - % \tikzcuboid{% - % shiftx=27.1cm,% - % shifty=10.1cm,% - % scale=1.00,% - % rotation=0,% - % densityx=100,% - % densityy=2,% - % densityz=100,% - % dimx=0,% - % dimy=3,% - % dimz=0,% - % emphedge=Y,% - % emphstyle=ultra thick, - % } - % \tikzcuboid{% - % shiftx=27.1cm,% - % shifty=10.1cm,% - % scale=1.00,% - % rotation=180,% - % densityx=100,% - % densityy=100,% - % densityz=2,% - % dimx=0,% - % dimy=0,% - % dimz=3,% - % emphedge=Y,% - % emphstyle=ultra thick, - % } - \tikzcuboid{% - shiftx=26.8cm,% - shifty=11.4cm,% - scale=1.00,% - rotation=0,% - densityx=100,% - densityy=2,% - densityz=100,% - dimx=0,% - dimy=3,% - dimz=0,% - emphedge=Y,% - emphstyle=ultra thick, - } - \tikzcuboid{% - shiftx=25.3cm,% - shifty=12.9cm,% - scale=1.00,% - rotation=180,% - densityx=100,% - densityy=100,% - densityz=2,% - dimx=0,% - dimy=0,% - dimz=3,% - emphedge=Y,% - emphstyle=ultra thick, - } - % \fill (27.1,10.1) circle[radius=2pt]; - \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=east, text width=2cm, align=right, color=white!50!black] at (19.8,4.4) {\textbf{\emph{x}}}; - \node [font=\fontsize{130}{100}\fontfamily{phv}\selectfont, anchor=west, text width=10cm, align=left] at (20.3,4) {{array}}; -\end{tikzpicture} - -\end{document} diff --git a/doc/_static/dataset-diagram.tex b/doc/_static/dataset-diagram.tex deleted file mode 100644 index fbc063b2dad..00000000000 --- a/doc/_static/dataset-diagram.tex +++ /dev/null @@ -1,270 +0,0 @@ -\documentclass[class=minimal,border=0pt,convert={density=300,outext=.png}]{standalone} -% \documentclass[class=minimal,border=0pt]{standalone} -\usepackage[scaled]{helvet} -\renewcommand*\familydefault{\sfdefault} - -% =========================================================================== -% The code below (used to define the \tikzcuboid command) is copied, -% unmodified, from a tex.stackexchange.com answer by the user "Tom Bombadil": -% http://tex.stackexchange.com/a/29882/8335 -% -% It is licensed under the Creative Commons Attribution-ShareAlike 3.0 -% Unported license: http://creativecommons.org/licenses/by-sa/3.0/ -% =========================================================================== - -\usepackage[usenames,dvipsnames]{color} -\usepackage{tikz} -\usepackage{keyval} -\usepackage{ifthen} - -%==================================== -%emphasize vertices --> switch and emph style (e.g. thick,black) -%==================================== -\makeatletter -% Standard Values for Parameters -\newcommand{\tikzcuboid@shiftx}{0} -\newcommand{\tikzcuboid@shifty}{0} -\newcommand{\tikzcuboid@dimx}{3} -\newcommand{\tikzcuboid@dimy}{3} -\newcommand{\tikzcuboid@dimz}{3} -\newcommand{\tikzcuboid@scale}{1} -\newcommand{\tikzcuboid@densityx}{1} -\newcommand{\tikzcuboid@densityy}{1} -\newcommand{\tikzcuboid@densityz}{1} -\newcommand{\tikzcuboid@rotation}{0} -\newcommand{\tikzcuboid@anglex}{0} -\newcommand{\tikzcuboid@angley}{90} -\newcommand{\tikzcuboid@anglez}{225} -\newcommand{\tikzcuboid@scalex}{1} -\newcommand{\tikzcuboid@scaley}{1} -\newcommand{\tikzcuboid@scalez}{sqrt(0.5)} -\newcommand{\tikzcuboid@linefront}{black} -\newcommand{\tikzcuboid@linetop}{black} -\newcommand{\tikzcuboid@lineright}{black} -\newcommand{\tikzcuboid@fillfront}{white} -\newcommand{\tikzcuboid@filltop}{white} -\newcommand{\tikzcuboid@fillright}{white} -\newcommand{\tikzcuboid@shaded}{N} -\newcommand{\tikzcuboid@shadecolor}{black} -\newcommand{\tikzcuboid@shadeperc}{25} -\newcommand{\tikzcuboid@emphedge}{N} -\newcommand{\tikzcuboid@emphstyle}{thick} - -% Definition of Keys -\define@key{tikzcuboid}{shiftx}[\tikzcuboid@shiftx]{\renewcommand{\tikzcuboid@shiftx}{#1}} -\define@key{tikzcuboid}{shifty}[\tikzcuboid@shifty]{\renewcommand{\tikzcuboid@shifty}{#1}} -\define@key{tikzcuboid}{dimx}[\tikzcuboid@dimx]{\renewcommand{\tikzcuboid@dimx}{#1}} -\define@key{tikzcuboid}{dimy}[\tikzcuboid@dimy]{\renewcommand{\tikzcuboid@dimy}{#1}} -\define@key{tikzcuboid}{dimz}[\tikzcuboid@dimz]{\renewcommand{\tikzcuboid@dimz}{#1}} -\define@key{tikzcuboid}{scale}[\tikzcuboid@scale]{\renewcommand{\tikzcuboid@scale}{#1}} -\define@key{tikzcuboid}{densityx}[\tikzcuboid@densityx]{\renewcommand{\tikzcuboid@densityx}{#1}} -\define@key{tikzcuboid}{densityy}[\tikzcuboid@densityy]{\renewcommand{\tikzcuboid@densityy}{#1}} -\define@key{tikzcuboid}{densityz}[\tikzcuboid@densityz]{\renewcommand{\tikzcuboid@densityz}{#1}} -\define@key{tikzcuboid}{rotation}[\tikzcuboid@rotation]{\renewcommand{\tikzcuboid@rotation}{#1}} -\define@key{tikzcuboid}{anglex}[\tikzcuboid@anglex]{\renewcommand{\tikzcuboid@anglex}{#1}} -\define@key{tikzcuboid}{angley}[\tikzcuboid@angley]{\renewcommand{\tikzcuboid@angley}{#1}} -\define@key{tikzcuboid}{anglez}[\tikzcuboid@anglez]{\renewcommand{\tikzcuboid@anglez}{#1}} -\define@key{tikzcuboid}{scalex}[\tikzcuboid@scalex]{\renewcommand{\tikzcuboid@scalex}{#1}} -\define@key{tikzcuboid}{scaley}[\tikzcuboid@scaley]{\renewcommand{\tikzcuboid@scaley}{#1}} -\define@key{tikzcuboid}{scalez}[\tikzcuboid@scalez]{\renewcommand{\tikzcuboid@scalez}{#1}} -\define@key{tikzcuboid}{linefront}[\tikzcuboid@linefront]{\renewcommand{\tikzcuboid@linefront}{#1}} -\define@key{tikzcuboid}{linetop}[\tikzcuboid@linetop]{\renewcommand{\tikzcuboid@linetop}{#1}} -\define@key{tikzcuboid}{lineright}[\tikzcuboid@lineright]{\renewcommand{\tikzcuboid@lineright}{#1}} -\define@key{tikzcuboid}{fillfront}[\tikzcuboid@fillfront]{\renewcommand{\tikzcuboid@fillfront}{#1}} -\define@key{tikzcuboid}{filltop}[\tikzcuboid@filltop]{\renewcommand{\tikzcuboid@filltop}{#1}} -\define@key{tikzcuboid}{fillright}[\tikzcuboid@fillright]{\renewcommand{\tikzcuboid@fillright}{#1}} -\define@key{tikzcuboid}{shaded}[\tikzcuboid@shaded]{\renewcommand{\tikzcuboid@shaded}{#1}} -\define@key{tikzcuboid}{shadecolor}[\tikzcuboid@shadecolor]{\renewcommand{\tikzcuboid@shadecolor}{#1}} -\define@key{tikzcuboid}{shadeperc}[\tikzcuboid@shadeperc]{\renewcommand{\tikzcuboid@shadeperc}{#1}} -\define@key{tikzcuboid}{emphedge}[\tikzcuboid@emphedge]{\renewcommand{\tikzcuboid@emphedge}{#1}} -\define@key{tikzcuboid}{emphstyle}[\tikzcuboid@emphstyle]{\renewcommand{\tikzcuboid@emphstyle}{#1}} -% Commands -\newcommand{\tikzcuboid}[1]{ - \setkeys{tikzcuboid}{#1} % Process Keys passed to command - \pgfmathsetmacro{\vectorxx}{\tikzcuboid@scalex*cos(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectorxy}{\tikzcuboid@scalex*sin(\tikzcuboid@anglex)} - \pgfmathsetmacro{\vectoryx}{\tikzcuboid@scaley*cos(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectoryy}{\tikzcuboid@scaley*sin(\tikzcuboid@angley)} - \pgfmathsetmacro{\vectorzx}{\tikzcuboid@scalez*cos(\tikzcuboid@anglez)} - \pgfmathsetmacro{\vectorzy}{\tikzcuboid@scalez*sin(\tikzcuboid@anglez)} - \begin{scope}[xshift=\tikzcuboid@shiftx, yshift=\tikzcuboid@shifty, scale=\tikzcuboid@scale, rotate=\tikzcuboid@rotation, x={(\vectorxx,\vectorxy)}, y={(\vectoryx,\vectoryy)}, z={(\vectorzx,\vectorzy)}] - \pgfmathsetmacro{\steppingx}{1/\tikzcuboid@densityx} - \pgfmathsetmacro{\steppingy}{1/\tikzcuboid@densityy} - \pgfmathsetmacro{\steppingz}{1/\tikzcuboid@densityz} - \newcommand{\dimx}{\tikzcuboid@dimx} - \newcommand{\dimy}{\tikzcuboid@dimy} - \newcommand{\dimz}{\tikzcuboid@dimz} - \pgfmathsetmacro{\secondx}{2*\steppingx} - \pgfmathsetmacro{\secondy}{2*\steppingy} - \pgfmathsetmacro{\secondz}{2*\steppingz} - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \y in {\steppingy,\secondy,...,\dimy} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \filldraw[fill=\tikzcuboid@fillfront,draw=\tikzcuboid@linefront] (\lowx,\lowy,\dimz) -- (\lowx,\y,\dimz) -- (\x,\y,\dimz) -- (\x,\lowy,\dimz) -- cycle; - - } - } - \foreach \x in {\steppingx,\secondx,...,\dimx} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowx}{(\x-\steppingx)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@filltop,draw=\tikzcuboid@linetop] (\lowx,\dimy,\lowz) -- (\lowx,\dimy,\z) -- (\x,\dimy,\z) -- (\x,\dimy,\lowz) -- cycle; - } - } - \foreach \y in {\steppingy,\secondy,...,\dimy} - { \foreach \z in {\steppingz,\secondz,...,\dimz} - { \pgfmathsetmacro{\lowy}{(\y-\steppingy)} - \pgfmathsetmacro{\lowz}{(\z-\steppingz)} - \filldraw[fill=\tikzcuboid@fillright,draw=\tikzcuboid@lineright] (\dimx,\lowy,\lowz) -- (\dimx,\lowy,\z) -- (\dimx,\y,\z) -- (\dimx,\y,\lowz) -- cycle; - } - } - \ifthenelse{\equal{\tikzcuboid@emphedge}{Y}}% - {\draw[\tikzcuboid@emphstyle](0,\dimy,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (0,\dimy,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle] (0,0,\dimz) -- (0,\dimy,\dimz) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - \draw[\tikzcuboid@emphstyle](\dimx,0,0) -- (\dimx,\dimy,0) -- (\dimx,\dimy,\dimz) -- (\dimx,0,\dimz) -- cycle;% - }% - {} - \end{scope} -} - -\makeatother - -\begin{document} - -\begin{tikzpicture} - \tikzcuboid{% - shiftx=16cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=teal!75!black,% - linetop=teal!50!black,% - lineright=teal!25!black,% - fillfront=teal!25!white,% - filltop=teal!50!white,% - fillright=teal!75!white,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=21cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityy=2,% - densityz=2,% - dimx=4,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=purple!75!white,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=26.2cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=orange!75!black,% - linetop=orange!50!black,% - lineright=orange!25!black,% - fillfront=orange!25!white,% - filltop=orange!50!white,% - fillright=orange!100!white,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=27.6cm,% - shifty=8cm,% - scale=1.00,% - rotation=0,% - densityx=10000,% - densityy=2,% - densityz=2,% - dimx=0,% - dimy=3,% - dimz=3,% - linefront=purple!75!black,% - linetop=purple!50!black,% - lineright=purple!25!black,% - fillfront=purple!25!white,% - filltop=purple!50!white,% - fillright=red!75!white,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=28cm,% - shifty=6.5cm,% - scale=1.00,% - rotation=0,% - densityx=2,% - densityx=2,% - densityy=100,% - densityz=100,% - dimx=4,% - dimy=0,% - dimz=0,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=28cm,% - shifty=6.5cm,% - scale=1.00,% - rotation=0,% - densityx=100,% - densityy=2,% - densityz=100,% - dimx=0,% - dimy=3,% - dimz=0,% - emphedge=Y,% - emphstyle=very thick, - } - \tikzcuboid{% - shiftx=28cm,% - shifty=6.5cm,% - scale=1.00,% - rotation=180,% - densityx=100,% - densityy=100,% - densityz=2,% - dimx=0,% - dimy=0,% - dimz=3,% - emphedge=Y,% - emphstyle=very thick, - } - \node [font=\fontsize{11}{11}\selectfont] at (18,11.5) {temperature}; - \node [font=\fontsize{11}{11}\selectfont] at (23,11.5) {precipitation}; - \node [font=\fontsize{11}{11}\selectfont] at (25.8,11.5) {latitude}; - \node [font=\fontsize{11}{11}\selectfont] at (27.5,11.47) {longitude}; - \node [font=\fontsize{11}{11}\selectfont] at (28,10) {x}; - \node [font=\fontsize{11}{11}\selectfont] at (29.5,8.5) {y}; - \node [font=\fontsize{11}{11}\selectfont] at (32,7) {t}; - \node [font=\fontsize{11}{11}\selectfont] at (31,10) {reference\_time}; - \fill (31,9.5) circle[radius=2pt]; -\end{tikzpicture} - -\end{document} diff --git a/doc/_static/favicon.ico b/doc/_static/favicon.ico deleted file mode 100644 index a1536e3ef76..00000000000 Binary files a/doc/_static/favicon.ico and /dev/null differ diff --git a/doc/_static/logos/Xarray_Icon_Final.png b/doc/_static/logos/Xarray_Icon_Final.png new file mode 100644 index 00000000000..6c0bae41829 Binary files /dev/null and b/doc/_static/logos/Xarray_Icon_Final.png differ diff --git a/doc/_static/logos/Xarray_Icon_Final.svg b/doc/_static/logos/Xarray_Icon_Final.svg new file mode 100644 index 00000000000..689b2079834 --- /dev/null +++ b/doc/_static/logos/Xarray_Icon_Final.svg @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + diff --git a/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png b/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png new file mode 100644 index 00000000000..68701eea116 Binary files /dev/null and b/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.png differ diff --git a/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg b/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg new file mode 100644 index 00000000000..a803e93ea1a --- /dev/null +++ b/doc/_static/logos/Xarray_Logo_FullColor_InverseRGB_Final.svg @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/doc/_static/logos/Xarray_Logo_RGB_Final.png b/doc/_static/logos/Xarray_Logo_RGB_Final.png new file mode 100644 index 00000000000..823ff8db961 Binary files /dev/null and b/doc/_static/logos/Xarray_Logo_RGB_Final.png differ diff --git a/doc/_static/logos/Xarray_Logo_RGB_Final.svg b/doc/_static/logos/Xarray_Logo_RGB_Final.svg new file mode 100644 index 00000000000..86e1b4841ef --- /dev/null +++ b/doc/_static/logos/Xarray_Logo_RGB_Final.svg @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/doc/_static/style.css b/doc/_static/style.css index a097398d1e9..ea42511c51e 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -7,6 +7,11 @@ table.docutils td { word-wrap: break-word; } +div.bd-header-announcement { + background-color: unset; + color: #000; +} + /* Reduce left and right margins */ .container, .container-lg, .container-md, .container-sm, .container-xl { diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 04013d545c3..d9c89649358 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -9,17 +9,42 @@ .. autosummary:: :toctree: generated/ + Coordinates.from_pandas_multiindex + Coordinates.get + Coordinates.items + Coordinates.keys + Coordinates.values + Coordinates.dims + Coordinates.dtypes + Coordinates.variables + Coordinates.xindexes + Coordinates.indexes + Coordinates.to_dataset + Coordinates.to_index + Coordinates.update + Coordinates.assign + Coordinates.merge + Coordinates.copy + Coordinates.equals + Coordinates.identical + core.coordinates.DatasetCoordinates.get core.coordinates.DatasetCoordinates.items core.coordinates.DatasetCoordinates.keys - core.coordinates.DatasetCoordinates.merge - core.coordinates.DatasetCoordinates.to_dataset - core.coordinates.DatasetCoordinates.to_index - core.coordinates.DatasetCoordinates.update core.coordinates.DatasetCoordinates.values core.coordinates.DatasetCoordinates.dims - core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.dtypes core.coordinates.DatasetCoordinates.variables + core.coordinates.DatasetCoordinates.xindexes + core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.to_dataset + core.coordinates.DatasetCoordinates.to_index + core.coordinates.DatasetCoordinates.update + core.coordinates.DatasetCoordinates.assign + core.coordinates.DatasetCoordinates.merge + core.coordinates.DatasetCoordinates.copy + core.coordinates.DatasetCoordinates.equals + core.coordinates.DatasetCoordinates.identical core.rolling.DatasetCoarsen.boundary core.rolling.DatasetCoarsen.coord_func @@ -47,14 +72,20 @@ core.coordinates.DataArrayCoordinates.get core.coordinates.DataArrayCoordinates.items core.coordinates.DataArrayCoordinates.keys - core.coordinates.DataArrayCoordinates.merge - core.coordinates.DataArrayCoordinates.to_dataset - core.coordinates.DataArrayCoordinates.to_index - core.coordinates.DataArrayCoordinates.update core.coordinates.DataArrayCoordinates.values core.coordinates.DataArrayCoordinates.dims - core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.dtypes core.coordinates.DataArrayCoordinates.variables + core.coordinates.DataArrayCoordinates.xindexes + core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.to_dataset + core.coordinates.DataArrayCoordinates.to_index + core.coordinates.DataArrayCoordinates.update + core.coordinates.DataArrayCoordinates.assign + core.coordinates.DataArrayCoordinates.merge + core.coordinates.DataArrayCoordinates.copy + core.coordinates.DataArrayCoordinates.equals + core.coordinates.DataArrayCoordinates.identical core.rolling.DataArrayCoarsen.boundary core.rolling.DataArrayCoarsen.coord_func @@ -103,7 +134,6 @@ core.accessor_dt.DatetimeAccessor.time core.accessor_dt.DatetimeAccessor.week core.accessor_dt.DatetimeAccessor.weekday - core.accessor_dt.DatetimeAccessor.weekday_name core.accessor_dt.DatetimeAccessor.weekofyear core.accessor_dt.DatetimeAccessor.year @@ -234,6 +264,7 @@ Variable.dims Variable.dtype Variable.encoding + Variable.drop_encoding Variable.imag Variable.nbytes Variable.ndim @@ -319,6 +350,38 @@ IndexVariable.sizes IndexVariable.values + + NamedArray.all + NamedArray.any + NamedArray.attrs + NamedArray.broadcast_to + NamedArray.chunks + NamedArray.chunksizes + NamedArray.copy + NamedArray.count + NamedArray.cumprod + NamedArray.cumsum + NamedArray.data + NamedArray.dims + NamedArray.dtype + NamedArray.expand_dims + NamedArray.get_axis_num + NamedArray.max + NamedArray.mean + NamedArray.median + NamedArray.min + NamedArray.nbytes + NamedArray.ndim + NamedArray.prod + NamedArray.reduce + NamedArray.shape + NamedArray.size + NamedArray.sizes + NamedArray.std + NamedArray.sum + NamedArray.var + + plot.plot plot.line plot.step @@ -374,10 +437,8 @@ CFTimeIndex.is_floating CFTimeIndex.is_integer CFTimeIndex.is_interval - CFTimeIndex.is_mixed CFTimeIndex.is_numeric CFTimeIndex.is_object - CFTimeIndex.is_type_compatible CFTimeIndex.isin CFTimeIndex.isna CFTimeIndex.isnull @@ -398,7 +459,6 @@ CFTimeIndex.round CFTimeIndex.searchsorted CFTimeIndex.set_names - CFTimeIndex.set_value CFTimeIndex.shift CFTimeIndex.slice_indexer CFTimeIndex.slice_locs @@ -412,7 +472,6 @@ CFTimeIndex.to_flat_index CFTimeIndex.to_frame CFTimeIndex.to_list - CFTimeIndex.to_native_types CFTimeIndex.to_numpy CFTimeIndex.to_series CFTimeIndex.tolist @@ -437,8 +496,6 @@ CFTimeIndex.hasnans CFTimeIndex.hour CFTimeIndex.inferred_type - CFTimeIndex.is_all_dates - CFTimeIndex.is_monotonic CFTimeIndex.is_monotonic_increasing CFTimeIndex.is_monotonic_decreasing CFTimeIndex.is_unique @@ -456,6 +513,21 @@ CFTimeIndex.values CFTimeIndex.year + Index.from_variables + Index.concat + Index.stack + Index.unstack + Index.create_variables + Index.to_pandas_index + Index.isel + Index.sel + Index.join + Index.reindex_like + Index.equals + Index.roll + Index.rename + Index.copy + backends.NetCDF4DataStore.close backends.NetCDF4DataStore.encode backends.NetCDF4DataStore.encode_attribute @@ -483,7 +555,6 @@ backends.NetCDF4DataStore.is_remote backends.NetCDF4DataStore.lock - backends.NetCDF4BackendEntrypoint.available backends.NetCDF4BackendEntrypoint.description backends.NetCDF4BackendEntrypoint.url backends.NetCDF4BackendEntrypoint.guess_can_open @@ -516,27 +587,11 @@ backends.H5NetCDFStore.sync backends.H5NetCDFStore.ds - backends.H5netcdfBackendEntrypoint.available backends.H5netcdfBackendEntrypoint.description backends.H5netcdfBackendEntrypoint.url backends.H5netcdfBackendEntrypoint.guess_can_open backends.H5netcdfBackendEntrypoint.open_dataset - backends.PseudoNetCDFDataStore.close - backends.PseudoNetCDFDataStore.get_attrs - backends.PseudoNetCDFDataStore.get_dimensions - backends.PseudoNetCDFDataStore.get_encoding - backends.PseudoNetCDFDataStore.get_variables - backends.PseudoNetCDFDataStore.open - backends.PseudoNetCDFDataStore.open_store_variable - backends.PseudoNetCDFDataStore.ds - - backends.PseudoNetCDFBackendEntrypoint.available - backends.PseudoNetCDFBackendEntrypoint.description - backends.PseudoNetCDFBackendEntrypoint.url - backends.PseudoNetCDFBackendEntrypoint.guess_can_open - backends.PseudoNetCDFBackendEntrypoint.open_dataset - backends.PydapDataStore.close backends.PydapDataStore.get_attrs backends.PydapDataStore.get_dimensions @@ -546,7 +601,6 @@ backends.PydapDataStore.open backends.PydapDataStore.open_store_variable - backends.PydapBackendEntrypoint.available backends.PydapBackendEntrypoint.description backends.PydapBackendEntrypoint.url backends.PydapBackendEntrypoint.guess_can_open @@ -574,7 +628,6 @@ backends.ScipyDataStore.sync backends.ScipyDataStore.ds - backends.ScipyBackendEntrypoint.available backends.ScipyBackendEntrypoint.description backends.ScipyBackendEntrypoint.url backends.ScipyBackendEntrypoint.guess_can_open @@ -595,13 +648,11 @@ backends.ZarrStore.sync backends.ZarrStore.ds - backends.ZarrBackendEntrypoint.available backends.ZarrBackendEntrypoint.description backends.ZarrBackendEntrypoint.url backends.ZarrBackendEntrypoint.guess_can_open backends.ZarrBackendEntrypoint.open_dataset - backends.StoreBackendEntrypoint.available backends.StoreBackendEntrypoint.description backends.StoreBackendEntrypoint.url backends.StoreBackendEntrypoint.guess_can_open diff --git a/doc/api.rst b/doc/api.rst index 0d56fc73997..a8f8ea7dd1c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -110,6 +110,7 @@ Dataset contents Dataset.drop_indexes Dataset.drop_duplicates Dataset.drop_dims + Dataset.drop_encoding Dataset.set_coords Dataset.reset_coords Dataset.convert_calendar @@ -181,6 +182,7 @@ Computation Dataset.groupby_bins Dataset.rolling Dataset.rolling_exp + Dataset.cumulative Dataset.weighted Dataset.coarsen Dataset.resample @@ -191,6 +193,7 @@ Computation Dataset.map_blocks Dataset.polyfit Dataset.curvefit + Dataset.eval Aggregation ----------- @@ -302,6 +305,7 @@ DataArray contents DataArray.drop_vars DataArray.drop_indexes DataArray.drop_duplicates + DataArray.drop_encoding DataArray.reset_coords DataArray.copy DataArray.convert_calendar @@ -376,6 +380,7 @@ Computation DataArray.groupby_bins DataArray.rolling DataArray.rolling_exp + DataArray.cumulative DataArray.weighted DataArray.coarsen DataArray.resample @@ -518,7 +523,6 @@ Datetimelike properties DataArray.dt.nanosecond DataArray.dt.dayofweek DataArray.dt.weekday - DataArray.dt.weekday_name DataArray.dt.dayofyear DataArray.dt.quarter DataArray.dt.days_in_month @@ -555,6 +559,7 @@ Datetimelike properties DataArray.dt.seconds DataArray.dt.microseconds DataArray.dt.nanoseconds + DataArray.dt.total_seconds **Timedelta methods**: @@ -595,13 +600,12 @@ Dataset methods load_dataset open_dataset open_mfdataset - open_rasterio open_zarr save_mfdataset Dataset.as_numpy Dataset.from_dataframe Dataset.from_dict - Dataset.to_array + Dataset.to_dataarray Dataset.to_dataframe Dataset.to_dask_dataframe Dataset.to_dict @@ -626,11 +630,10 @@ DataArray methods load_dataarray open_dataarray DataArray.as_numpy - DataArray.from_cdms2 DataArray.from_dict DataArray.from_iris DataArray.from_series - DataArray.to_cdms2 + DataArray.to_dask_dataframe DataArray.to_dataframe DataArray.to_dataset DataArray.to_dict @@ -641,6 +644,7 @@ DataArray methods DataArray.to_numpy DataArray.to_pandas DataArray.to_series + DataArray.to_zarr DataArray.chunk DataArray.close DataArray.compute @@ -1054,7 +1058,6 @@ Tutorial :toctree: generated/ tutorial.open_dataset - tutorial.open_rasterio tutorial.load_dataset Testing @@ -1068,6 +1071,27 @@ Testing testing.assert_allclose testing.assert_chunks_equal +Hypothesis Testing Strategies +============================= + +.. currentmodule:: xarray + +See the :ref:`documentation page on testing ` for a guide on how to use these strategies. + +.. warning:: + These strategies should be considered highly experimental, and liable to change at any time. + +.. autosummary:: + :toctree: generated/ + + testing.strategies.supported_dtypes + testing.strategies.names + testing.strategies.dimension_names + testing.strategies.dimension_sizes + testing.strategies.attrs + testing.strategies.variables + testing.strategies.unique_subset_of + Exceptions ========== @@ -1083,12 +1107,14 @@ Advanced API .. autosummary:: :toctree: generated/ + Coordinates Dataset.variables DataArray.variable Variable IndexVariable as_variable - indexes.Index + Index + IndexSelResult Context register_dataset_accessor register_dataarray_accessor @@ -1096,6 +1122,7 @@ Advanced API backends.BackendArray backends.BackendEntrypoint backends.list_engines + backends.refresh_engines Default, pandas-backed indexes built-in Xarray: @@ -1111,7 +1138,6 @@ arguments for the ``load_store`` and ``dump_to_store`` Dataset methods: backends.NetCDF4DataStore backends.H5NetCDFStore - backends.PseudoNetCDFDataStore backends.PydapDataStore backends.ScipyDataStore backends.ZarrStore @@ -1127,7 +1153,6 @@ used filetypes in the xarray universe. backends.NetCDF4BackendEntrypoint backends.H5netcdfBackendEntrypoint - backends.PseudoNetCDFBackendEntrypoint backends.PydapBackendEntrypoint backends.ScipyBackendEntrypoint backends.StoreBackendEntrypoint diff --git a/doc/conf.py b/doc/conf.py index 0b6c6766c3b..152eb6794b4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -49,25 +49,16 @@ matplotlib.use("Agg") -try: - import rasterio # noqa: F401 -except ImportError: - allowed_failures.update( - ["gallery/plot_rasterio_rgb.py", "gallery/plot_rasterio.py"] - ) - try: import cartopy # noqa: F401 except ImportError: allowed_failures.update( [ "gallery/plot_cartopy_facetgrid.py", - "gallery/plot_rasterio_rgb.py", - "gallery/plot_rasterio.py", ] ) -nbsphinx_allow_errors = True +nbsphinx_allow_errors = False # -- General configuration ------------------------------------------------ @@ -93,6 +84,7 @@ "sphinx_copybutton", "sphinxext.rediraffe", "sphinx_design", + "sphinx_inline_tabs", ] @@ -239,6 +231,7 @@ # canonical_url="", repository_url="https://github.com/pydata/xarray", repository_branch="main", + navigation_with_keys=False, # pydata/pydata-sphinx-theme#1492 path_to_docs="doc", use_edit_page_button=True, use_repository_button=True, @@ -247,19 +240,20 @@ extra_footer="""

Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated to supporting the open-source scientific computing community.
Theme by the Executable Book Project

""", - twitter_url="https://twitter.com/xarray_devs", + twitter_url="https://twitter.com/xarray_dev", icon_links=[], # workaround for pydata/pydata-sphinx-theme#1220 + announcement="🍾 Xarray is now 10 years old! 🎉", ) # The name of an image file (relative to this directory) to place at the top # of the sidebar. -html_logo = "_static/dataset-diagram-logo.png" +html_logo = "_static/logos/Xarray_Logo_RGB_Final.svg" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -html_favicon = "_static/favicon.ico" +html_favicon = "_static/logos/Xarray_Icon_Final.svg" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -270,11 +264,11 @@ # configuration for sphinxext.opengraph ogp_site_url = "https://docs.xarray.dev/en/latest/" -ogp_image = "https://docs.xarray.dev/en/stable/_static/dataset-diagram-logo.png" +ogp_image = "https://docs.xarray.dev/en/stable/_static/logos/Xarray_Logo_RGB_Final.png" ogp_custom_meta_tags = [ '', '', - '', + '', ] # Redirects for pages that were moved to new locations @@ -322,17 +316,22 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - "python": ("https://docs.python.org/3/", None), - "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "cftime": ("https://unidata.github.io/cftime", None), + "cubed": ("https://cubed-dev.github.io/cubed/", None), + "dask": ("https://docs.dask.org/en/latest", None), + "datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None), + "flox": ("https://flox.readthedocs.io/en/latest/", None), + "hypothesis": ("https://hypothesis.readthedocs.io/en/latest/", None), "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), + "matplotlib": ("https://matplotlib.org/stable/", None), + "numba": ("https://numba.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "python": ("https://docs.python.org/3/", None), "scipy": ("https://docs.scipy.org/doc/scipy", None), - "numba": ("https://numba.readthedocs.io/en/stable/", None), - "matplotlib": ("https://matplotlib.org/stable/", None), - "dask": ("https://docs.dask.org/en/latest", None), - "cftime": ("https://unidata.github.io/cftime", None), - "rasterio": ("https://rasterio.readthedocs.io/en/latest", None), "sparse": ("https://sparse.pydata.org/en/latest/", None), + "xarray-tutorial": ("https://tutorial.xarray.dev/", None), + "zarr": ("https://zarr.readthedocs.io/en/latest/", None), } diff --git a/doc/contributing.rst b/doc/contributing.rst index 07938f23c9f..c3dc484f4c1 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -4,29 +4,46 @@ Contributing to xarray ********************** - .. note:: Large parts of this document came from the `Pandas Contributing Guide `_. +Overview +======== + +We welcome your skills and enthusiasm at the xarray project!. There are numerous opportunities to +contribute beyond just writing code. +All contributions, including bug reports, bug fixes, documentation improvements, enhancement suggestions, +and other ideas are welcome. + +If you have any questions on the process or how to fix something feel free to ask us! +The recommended place to ask a question is on `GitHub Discussions `_ +, but we also have a `Discord `_ and a +`mailing list `_. There is also a +`"python-xarray" tag on Stack Overflow `_ which we monitor for questions. + +We also have a biweekly community call, details of which are announced on the +`Developers meeting `_. +You are very welcome to join! Though we would love to hear from you, there is no expectation to +contribute during the meeting either - you are always welcome to just sit in and listen. + +This project is a community effort, and everyone is welcome to contribute. Everyone within the community +is expected to abide by our `code of conduct `_. + Where to start? =============== -All contributions, bug reports, bug fixes, documentation improvements, -enhancements, and ideas are welcome. - If you are brand new to *xarray* or open-source development, we recommend going through the `GitHub "issues" tab `_ -to find issues that interest you. There are a number of issues listed under -`Documentation `_ +to find issues that interest you. +Some issues are particularly suited for new contributors by the label `Documentation `_ and `good first issue -`_ -where you could start out. Once you've found an interesting issue, you can -return here to get your development environment setup. +`_ where you could start out. +These are well documented issues, that do not require a deep understanding of the internals of xarray. -Feel free to ask questions on the `mailing list -`_. +Once you've found an interesting issue, you can return here to get your development environment setup. +The xarray project does not assign issues. Issues are "assigned" by opening a Pull Request(PR). .. _contributing.bug_reports: @@ -34,15 +51,20 @@ Bug reports and enhancement requests ==================================== Bug reports are an important part of making *xarray* more stable. Having a complete bug -report will allow others to reproduce the bug and provide insight into fixing. See -this `stackoverflow article for tips on -writing a good bug report `_ . +report will allow others to reproduce the bug and provide insight into fixing. Trying out the bug-producing code on the *main* branch is often a worthwhile exercise to confirm that the bug still exists. It is also worth searching existing bug reports and pull requests to see if the issue has already been reported and/or fixed. -Bug reports must: +Submitting a bug report +----------------------- + +If you find a bug in the code or documentation, do not hesitate to submit a ticket to the +`Issue Tracker `_. +You are also welcome to post feature requests or pull requests. + +If you are reporting a bug, please use the provided template which includes the following: #. Include a short, self-contained Python snippet reproducing the problem. You can format the code nicely by using `GitHub Flavored Markdown @@ -67,13 +89,12 @@ Bug reports must: #. Explain why the current behavior is wrong/not desired and what you expect instead. -The issue will then show up to the *xarray* community and be open to comments/ideas -from others. +The issue will then show up to the *xarray* community and be open to comments/ideas from others. -.. _contributing.github: +See this `stackoverflow article for tips on writing a good bug report `_ . -Working with the code -===================== + +.. _contributing.github: Now that you have an issue you want to fix, enhancement to add, or documentation to improve, you need to learn how to work with GitHub and the *xarray* code base. @@ -81,12 +102,7 @@ to improve, you need to learn how to work with GitHub and the *xarray* code base .. _contributing.version_control: Version control, Git, and GitHub --------------------------------- - -To the new user, working with Git is one of the more daunting aspects of contributing -to *xarray*. It can very quickly become overwhelming, but sticking to the guidelines -below will help keep the process straightforward and mostly trouble free. As always, -if you are having difficulties please feel free to ask for help. +================================ The code is hosted on `GitHub `_. To contribute you will need to sign up for a `free GitHub account @@ -112,41 +128,41 @@ you can work seamlessly between your local repository and GitHub. but contributors who are new to git may find it easier to use other tools instead such as `Github Desktop `_. -.. _contributing.forking: +Development workflow +==================== + +To keep your work well organized, with readable history, and in turn make it easier for project +maintainers to see what you've done, and why you did it, we recommend you to follow workflow: -Forking -------- +1. `Create an account `_ on GitHub if you do not already have one. -You will need your own fork to work on the code. Go to the `xarray project -page `_ and hit the ``Fork`` button. You will -want to clone your fork to your machine:: +2. You will need your own fork to work on the code. Go to the `xarray project + page `_ and hit the ``Fork`` button near the top of the page. + This creates a copy of the code under your account on the GitHub server. + +3. Clone your fork to your machine:: git clone https://github.com/your-user-name/xarray.git cd xarray git remote add upstream https://github.com/pydata/xarray.git -This creates the directory `xarray` and connects your repository to -the upstream (main project) *xarray* repository. - -Creating a branch ------------------ - -You want your ``main`` branch to reflect only production-ready code, so create a -feature branch before making your changes. For example:: + This creates the directory `xarray` and connects your repository to + the upstream (main project) *xarray* repository. - git branch shiny-new-feature - git checkout shiny-new-feature +Creating a development environment +---------------------------------- -The above can be simplified to:: +To test out code changes locally, you'll need to build *xarray* from source, which requires you to +`create a local development environment `_. - git checkout -b shiny-new-feature +Update the ``main`` branch +-------------------------- -This changes your working directory to the shiny-new-feature branch. Keep any -changes in this branch specific to one bug or feature so it is clear -what the branch brings to *xarray*. You can have many "shiny-new-features" -and switch in between them using the ``git checkout`` command. +First make sure you have followed `Setting up xarray for development +`_ -To update this branch, you need to retrieve the changes from the ``main`` branch:: +Before starting a new set of changes, fetch all changes from ``upstream/main``, and start a new +feature branch from that. From time to time you should fetch the upstream changes from GitHub: :: git fetch upstream git merge upstream/main @@ -157,10 +173,83 @@ request. If you have uncommitted changes, you will need to ``git stash`` them prior to updating. This will effectively store your changes, which can be reapplied after updating. +Create a new feature branch +--------------------------- + +Create a branch to save your changes, even before you start making changes. You want your +``main branch`` to contain only production-ready code:: + + git checkout -b shiny-new-feature + +This changes your working directory to the ``shiny-new-feature`` branch. Keep any changes in this +branch specific to one bug or feature so it is clear what the branch brings to *xarray*. You can have +many "shiny-new-features" and switch in between them using the ``git checkout`` command. + +Generally, you will want to keep your feature branches on your public GitHub fork of xarray. To do this, +you ``git push`` this new branch up to your GitHub repo. Generally (if you followed the instructions in +these pages, and by default), git will have a link to your fork of the GitHub repo, called ``origin``. +You push up to your own fork with: :: + + git push origin shiny-new-feature + +In git >= 1.7 you can ensure that the link is correctly set by using the ``--set-upstream`` option: :: + + git push --set-upstream origin shiny-new-feature + +From now on git will know that ``shiny-new-feature`` is related to the ``shiny-new-feature branch`` in the GitHub repo. + +The editing workflow +-------------------- + +1. Make some changes + +2. See which files have changed with ``git status``. You'll see a listing like this one: :: + + # On branch shiny-new-feature + # Changed but not updated: + # (use "git add ..." to update what will be committed) + # (use "git checkout -- ..." to discard changes in working directory) + # + # modified: README + +3. Check what the actual changes are with ``git diff``. + +4. Build the `documentation run `_ +for the documentation changes. + +`Run the test suite `_ for code changes. + +Commit and push your changes +---------------------------- + +1. To commit all modified files into the local copy of your repo, do ``git commit -am 'A commit message'``. + +2. To push the changes up to your forked repo on GitHub, do a ``git push``. + +Open a pull request +------------------- + +When you're ready or need feedback on your code, open a Pull Request (PR) so that the xarray developers can +give feedback and eventually include your suggested code into the ``main`` branch. +`Pull requests (PRs) on GitHub `_ +are the mechanism for contributing to xarray's code and documentation. + +Enter a title for the set of changes with some explanation of what you've done. +Follow the PR template, which looks like this. :: + + [ ]Closes #xxxx + [ ]Tests added + [ ]User visible changes (including notable bug fixes) are documented in whats-new.rst + [ ]New functions/methods are listed in api.rst + +Mention anything you'd like particular attention for - such as a complicated change or some code you are not happy with. +If you don't think your request is ready to be merged, just say so in your pull request message and use +the "Draft PR" feature of GitHub. This is a good way of getting some preliminary code review. + .. _contributing.dev_env: Creating a development environment ----------------------------------- +================================== To test out code changes locally, you'll need to build *xarray* from source, which requires a Python environment. If you're making documentation changes, you can @@ -182,7 +271,7 @@ documentation locally before pushing your changes. .. _contributing.dev_python: Creating a Python Environment -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +----------------------------- Before starting any development, you'll need to create an isolated xarray development environment: @@ -240,6 +329,22 @@ To return to your root environment:: See the full `conda docs here `__. +Install pre-commit hooks +------------------------ + +We highly recommend that you setup `pre-commit `_ hooks to automatically +run all the above tools every time you make a git commit. To install the hooks:: + + python -m pip install pre-commit + pre-commit install + +This can be done by running: :: + + pre-commit run + +from the root of the xarray repository. You can skip the pre-commit checks with +``git commit --no-verify``. + .. _contributing.documentation: Contributing to the documentation @@ -363,6 +468,60 @@ If you want to do a full clean build, do:: make clean make html +Writing ReST pages +------------------ + +Most documentation is either in the docstrings of individual classes and methods, in explicit +``.rst`` files, or in examples and tutorials. All of these use the +`ReST `_ syntax and are processed by +`Sphinx `_. + +This section contains additional information and conventions how ReST is used in the +xarray documentation. + +Section formatting +~~~~~~~~~~~~~~~~~~ + +We aim to follow the recommendations from the +`Python documentation `_ +and the `Sphinx reStructuredText documentation `_ +for section markup characters, + +- ``*`` with overline, for chapters + +- ``=``, for heading + +- ``-``, for sections + +- ``~``, for subsections + +- ``**`` text ``**``, for **bold** text + +Referring to other documents and sections +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`Sphinx `_ allows internal +`references `_ between documents. + +Documents can be linked with the ``:doc:`` directive: + +:: + + See the :doc:`/getting-started-guide/installing` + + See the :doc:`/getting-started-guide/quick-overview` + +will render as: + +See the `Installation `_ + +See the `Quick Overview `_ + +Including figures and files +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Image files can be directly included in pages with the ``image::`` directive. + .. _contributing.code: Contributing to the code base @@ -490,9 +649,7 @@ Writing tests All tests should go into the ``tests`` subdirectory of the specific package. This folder contains many current examples of tests, and we suggest looking to these for -inspiration. If your test requires working with files or -network connectivity, there is more information on the `testing page -`_ of the wiki. +inspiration. The ``xarray.testing`` module has many special ``assert`` functions that make it easier to make statements about whether DataArray or Dataset objects are @@ -513,8 +670,7 @@ typically find tests wrapped in a class. .. code-block:: python - class TestReallyCoolFeature: - ... + class TestReallyCoolFeature: ... Going forward, we are moving to a more *functional* style using the `pytest `__ framework, which offers a richer @@ -523,8 +679,7 @@ writing test classes, we will write test functions like this: .. code-block:: python - def test_really_cool_feature(): - ... + def test_really_cool_feature(): ... Using ``pytest`` ~~~~~~~~~~~~~~~~ @@ -672,17 +827,17 @@ Running the performance test suite Performance matters and it is worth considering whether your code has introduced performance regressions. *xarray* is starting to write a suite of benchmarking tests -using `asv `__ +using `asv `__ to enable easy monitoring of the performance of critical *xarray* operations. These benchmarks are all found in the ``xarray/asv_bench`` directory. To use all features of asv, you will need either ``conda`` or ``virtualenv``. For more details please check the `asv installation -webpage `_. +webpage `_. To install asv:: - pip install git+https://github.com/spacetelescope/asv + python -m pip install asv If you need to run a benchmark, change your directory to ``asv_bench/`` and run:: @@ -912,7 +1067,7 @@ PR checklist - Write new tests if needed. See `"Test-driven development/code writing" `_. - Test the code using `Pytest `_. Running all tests (type ``pytest`` in the root directory) takes a while, so feel free to only run the tests you think are needed based on your PR (example: ``pytest xarray/tests/test_dataarray.py``). CI will catch any failing tests. - - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a "[skip-ci]" tag to the first line of the commit message. + - By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a ``[test-upstream]`` tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a ``[skip-ci]`` tag to the first line of the commit message. - **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit. diff --git a/doc/developers-meeting.rst b/doc/developers-meeting.rst index 1c49a900f66..153f3520f26 100644 --- a/doc/developers-meeting.rst +++ b/doc/developers-meeting.rst @@ -3,18 +3,18 @@ Developers meeting Xarray developers meet bi-weekly every other Wednesday. -The meeting occurs on `Zoom `__. +The meeting occurs on `Zoom `__. -Find the `notes for the meeting here `__. +Find the `notes for the meeting here `__. There is a :issue:`GitHub issue for changes to the meeting<4001>`. You can subscribe to this calendar to be notified of changes: -* `Google Calendar `__ -* `iCal `__ +* `Google Calendar `__ +* `iCal `__ .. raw:: html - + diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index e6e970c6239..076874d82f3 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -36,11 +36,13 @@ Geosciences - `rioxarray `_: geospatial xarray extension powered by rasterio - `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors. - `SatPy `_ : Library for reading and manipulating meteorological remote sensing data and writing it to various image and data file formats. +- `SARXarray `_: xarray extension for reading and processing large Synthetic Aperture Radar (SAR) data stacks. - `Spyfit `_: FTIR spectroscopy of the atmosphere - `windspharm `_: Spherical harmonic wind analysis in Python. - `wradlib `_: An Open Source Library for Weather Radar Data Processing. - `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-regrid `_: xarray extension for regridding rectilinear data. - `xarray-simlab `_: xarray extension for computer model simulations. - `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) - `xarray-topo `_: xarray extension for topographic analysis and modelling. @@ -77,6 +79,7 @@ Extend xarray capabilities - `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation. - `xarray_einstats `_: Statistics, linear algebra and einops for xarray - `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). +- `xeofs `_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data. - `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. @@ -96,7 +99,6 @@ Visualization Non-Python projects ~~~~~~~~~~~~~~~~~~~ - `xframe `_: C++ data structures inspired by xarray. -- `AxisArrays `_ and - `NamedArrays `_: similar data structures for Julia. +- `AxisArrays `_, `NamedArrays `_ and `YAXArrays.jl `_: similar data structures for Julia. More projects can be found at the `"xarray" Github topic `_. diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb index 68d011d0725..c2ab7271873 100644 --- a/doc/examples/apply_ufunc_vectorize_1d.ipynb +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to coveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", + "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to conveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", "\n", "We will illustrate this using `np.interp`: \n", "\n", diff --git a/doc/examples/multidimensional-coords.ipynb b/doc/examples/multidimensional-coords.ipynb index f7471f05e5d..ce8a091a5da 100644 --- a/doc/examples/multidimensional-coords.ipynb +++ b/doc/examples/multidimensional-coords.ipynb @@ -56,7 +56,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this example, the _logical coordinates_ are `x` and `y`, while the _physical coordinates_ are `xc` and `yc`, which represent the latitudes and longitude of the data." + "In this example, the _logical coordinates_ are `x` and `y`, while the _physical coordinates_ are `xc` and `yc`, which represent the longitudes and latitudes of the data." ] }, { diff --git a/doc/examples/visualization_gallery.ipynb b/doc/examples/visualization_gallery.ipynb index e6fa564db0d..e7e9196a6f6 100644 --- a/doc/examples/visualization_gallery.ipynb +++ b/doc/examples/visualization_gallery.ipynb @@ -193,90 +193,6 @@ "# Show\n", "plt.tight_layout()" ] - }, - { - "cell_type": "markdown", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "## `imshow()` and rasterio map projections\n", - "\n", - "\n", - "Using rasterio's projection information for more accurate plots.\n", - "\n", - "This example extends `recipes.rasterio` and plots the image in the\n", - "original map projection instead of relying on pcolormesh and a map\n", - "transformation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "da = xr.tutorial.open_rasterio(\"RGB.byte\")\n", - "\n", - "# The data is in UTM projection. We have to set it manually until\n", - "# https://github.com/SciTools/cartopy/issues/813 is implemented\n", - "crs = ccrs.UTM(\"18\")\n", - "\n", - "# Plot on a map\n", - "ax = plt.subplot(projection=crs)\n", - "da.plot.imshow(ax=ax, rgb=\"band\", transform=crs)\n", - "ax.coastlines(\"10m\", color=\"r\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Parsing rasterio geocoordinates\n", - "\n", - "Converting a projection's cartesian coordinates into 2D longitudes and\n", - "latitudes.\n", - "\n", - "These new coordinates might be handy for plotting and indexing, but it should\n", - "be kept in mind that a grid which is regular in projection coordinates will\n", - "likely be irregular in lon/lat. It is often recommended to work in the data's\n", - "original map projection (see `recipes.rasterio_rgb`)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pyproj import Transformer\n", - "import numpy as np\n", - "\n", - "da = xr.tutorial.open_rasterio(\"RGB.byte\")\n", - "\n", - "x, y = np.meshgrid(da[\"x\"], da[\"y\"])\n", - "transformer = Transformer.from_crs(da.crs, \"EPSG:4326\", always_xy=True)\n", - "lon, lat = transformer.transform(x, y)\n", - "da.coords[\"lon\"] = ((\"y\", \"x\"), lon)\n", - "da.coords[\"lat\"] = ((\"y\", \"x\"), lat)\n", - "\n", - "# Compute a greyscale out of the rgb image\n", - "greyscale = da.mean(dim=\"band\")\n", - "\n", - "# Plot on a map\n", - "ax = plt.subplot(projection=ccrs.PlateCarree())\n", - "greyscale.plot(\n", - " ax=ax,\n", - " x=\"lon\",\n", - " y=\"lat\",\n", - " transform=ccrs.PlateCarree(),\n", - " cmap=\"Greys_r\",\n", - " shading=\"auto\",\n", - " add_colorbar=False,\n", - ")\n", - "ax.coastlines(\"10m\", color=\"r\")" - ] } ], "metadata": { @@ -296,6 +212,13 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } } }, "nbformat": 4, diff --git a/doc/gallery.yml b/doc/gallery.yml index f1a147dae87..f8316017d8c 100644 --- a/doc/gallery.yml +++ b/doc/gallery.yml @@ -25,12 +25,12 @@ notebooks-examples: - title: Applying unvectorized functions with apply_ufunc path: examples/apply_ufunc_vectorize_1d.html - thumbnail: _static/dataset-diagram-square-logo.png + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg external-examples: - title: Managing raster data with rioxarray path: https://corteva.github.io/rioxarray/stable/examples/examples.html - thumbnail: _static/dataset-diagram-square-logo.png + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg - title: Xarray and dask on the cloud with Pangeo path: https://gallery.pangeo.io/ @@ -38,7 +38,7 @@ external-examples: - title: Xarray with Dask Arrays path: https://examples.dask.org/xarray.html_ - thumbnail: _static/dataset-diagram-square-logo.png + thumbnail: _static/logos/Xarray_Logo_RGB_Final.svg - title: Project Pythia Foundations Book path: https://foundations.projectpythia.org/core/xarray.html diff --git a/doc/gallery/plot_cartopy_facetgrid.py b/doc/gallery/plot_cartopy_facetgrid.py index d8f5e73ee56..faa148938d6 100644 --- a/doc/gallery/plot_cartopy_facetgrid.py +++ b/doc/gallery/plot_cartopy_facetgrid.py @@ -13,7 +13,6 @@ .. _this discussion: https://github.com/pydata/xarray/issues/1397#issuecomment-299190567 """ - import cartopy.crs as ccrs import matplotlib.pyplot as plt @@ -30,7 +29,7 @@ transform=ccrs.PlateCarree(), # the data's projection col="time", col_wrap=1, # multiplot settings - aspect=ds.dims["lon"] / ds.dims["lat"], # for a sensible figsize + aspect=ds.sizes["lon"] / ds.sizes["lat"], # for a sensible figsize subplot_kws={"projection": map_proj}, # the plot's projection ) diff --git a/doc/gallery/plot_control_colorbar.py b/doc/gallery/plot_control_colorbar.py index 8fb8d7f8be6..280e753db9a 100644 --- a/doc/gallery/plot_control_colorbar.py +++ b/doc/gallery/plot_control_colorbar.py @@ -6,6 +6,7 @@ Use ``cbar_kwargs`` keyword to specify the number of ticks. The ``spacing`` kwarg can be used to draw proportional ticks. """ + import matplotlib.pyplot as plt import xarray as xr diff --git a/doc/gallery/plot_rasterio.py b/doc/gallery/plot_rasterio.py deleted file mode 100644 index 853923a38bd..00000000000 --- a/doc/gallery/plot_rasterio.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -.. _recipes.rasterio: - -================================= -Parsing rasterio's geocoordinates -================================= - - -Converting a projection's cartesian coordinates into 2D longitudes and -latitudes. - -These new coordinates might be handy for plotting and indexing, but it should -be kept in mind that a grid which is regular in projection coordinates will -likely be irregular in lon/lat. It is often recommended to work in the data's -original map projection (see :ref:`recipes.rasterio_rgb`). -""" - -import cartopy.crs as ccrs -import matplotlib.pyplot as plt -import numpy as np -from pyproj import Transformer - -import xarray as xr - -# Read the data -url = "https://github.com/rasterio/rasterio/raw/master/tests/data/RGB.byte.tif" -da = xr.open_rasterio(url) - -# Compute the lon/lat coordinates with pyproj -transformer = Transformer.from_crs(da.crs, "EPSG:4326", always_xy=True) -lon, lat = transformer.transform(*np.meshgrid(da["x"], da["y"])) -da.coords["lon"] = (("y", "x"), lon) -da.coords["lat"] = (("y", "x"), lat) - -# Compute a greyscale out of the rgb image -greyscale = da.mean(dim="band") - -# Plot on a map -ax = plt.subplot(projection=ccrs.PlateCarree()) -greyscale.plot( - ax=ax, - x="lon", - y="lat", - transform=ccrs.PlateCarree(), - cmap="Greys_r", - add_colorbar=False, -) -ax.coastlines("10m", color="r") -plt.show() diff --git a/doc/gallery/plot_rasterio_rgb.py b/doc/gallery/plot_rasterio_rgb.py deleted file mode 100644 index 912224ac132..00000000000 --- a/doc/gallery/plot_rasterio_rgb.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -.. _recipes.rasterio_rgb: - -============================ -imshow() and map projections -============================ - -Using rasterio's projection information for more accurate plots. - -This example extends :ref:`recipes.rasterio` and plots the image in the -original map projection instead of relying on pcolormesh and a map -transformation. -""" - -import cartopy.crs as ccrs -import matplotlib.pyplot as plt - -import xarray as xr - -# Read the data -url = "https://github.com/rasterio/rasterio/raw/master/tests/data/RGB.byte.tif" -da = xr.open_rasterio(url) - -# The data is in UTM projection. We have to set it manually until -# https://github.com/SciTools/cartopy/issues/813 is implemented -crs = ccrs.UTM("18N") - -# Plot on a map -ax = plt.subplot(projection=crs) -da.plot.imshow(ax=ax, rgb="band", transform=crs) -ax.coastlines("10m", color="r") -plt.show() diff --git a/doc/getting-started-guide/faq.rst b/doc/getting-started-guide/faq.rst index 08cb9646f94..7f99fa77e3a 100644 --- a/doc/getting-started-guide/faq.rst +++ b/doc/getting-started-guide/faq.rst @@ -168,18 +168,11 @@ integration with Cartopy_. .. _Iris: https://scitools-iris.readthedocs.io/en/stable/ .. _Cartopy: https://scitools.org.uk/cartopy/docs/latest/ -`UV-CDAT`__ is another Python library that implements in-memory netCDF-like -variables and `tools for working with climate data`__. - -__ https://uvcdat.llnl.gov/ -__ https://drclimate.wordpress.com/2014/01/02/a-beginners-guide-to-scripting-with-uv-cdat/ - We think the design decisions we have made for xarray (namely, basing it on pandas) make it a faster and more flexible data analysis tool. That said, Iris -and CDAT have some great domain specific functionality, and xarray includes -methods for converting back and forth between xarray and these libraries. See -:py:meth:`~xarray.DataArray.to_iris` and :py:meth:`~xarray.DataArray.to_cdms2` -for more details. +has some great domain specific functionality, and xarray includes +methods for converting back and forth between xarray and Iris. See +:py:meth:`~xarray.DataArray.to_iris` for more details. What other projects leverage xarray? ------------------------------------ @@ -356,6 +349,25 @@ There may be situations where you need to specify the engine manually using the Some packages may have additional functionality beyond what is shown here. You can refer to the documentation for each package for more information. +How does xarray handle missing values? +-------------------------------------- + +**xarray can handle missing values using ``np.NaN``** + +- ``np.NaN`` is used to represent missing values in labeled arrays and datasets. It is a commonly used standard for representing missing or undefined numerical data in scientific computing. ``np.NaN`` is a constant value in NumPy that represents "Not a Number" or missing values. + +- Most of xarray's computation methods are designed to automatically handle missing values appropriately. + + For example, when performing operations like addition or multiplication on arrays that contain missing values, xarray will automatically ignore the missing values and only perform the operation on the valid data. This makes it easy to work with data that may contain missing or undefined values without having to worry about handling them explicitly. + +- Many of xarray's `aggregation methods `_, such as ``sum()``, ``mean()``, ``min()``, ``max()``, and others, have a skipna argument that controls whether missing values (represented by NaN) should be skipped (True) or treated as NaN (False) when performing the calculation. + + By default, ``skipna`` is set to `True`, so missing values are ignored when computing the result. However, you can set ``skipna`` to `False` if you want missing values to be treated as NaN and included in the calculation. + +- On `plotting `_ an xarray dataset or array that contains missing values, xarray will simply leave the missing values as blank spaces in the plot. + +- We have a set of `methods `_ for manipulating missing and filling values. + How should I cite xarray? ------------------------- diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index 6b3283adcbd..f7eaf92f9cf 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -7,9 +7,9 @@ Required dependencies --------------------- - Python (3.9 or later) -- `numpy `__ (1.21 or later) -- `packaging `__ (21.3 or later) -- `pandas `__ (1.4 or later) +- `numpy `__ (1.23 or later) +- `packaging `__ (22 or later) +- `pandas `__ (1.5 or later) .. _optional-dependencies: @@ -38,11 +38,6 @@ For netCDF and IO - `cftime `__: recommended if you want to encode/decode datetimes for non-standard calendars or dates before year 1678 or after year 2262. -- `PseudoNetCDF `__: recommended - for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files - (ffi1001) and many other. -- `rasterio `__: for reading GeoTiffs and - other gridded raster datasets. - `iris `__: for conversion to and from iris' Cube objects @@ -88,7 +83,7 @@ Minimum dependency versions Xarray adopts a rolling policy regarding the minimum supported version of its dependencies: -- **Python:** 24 months +- **Python:** 30 months (`NEP-29 `_) - **numpy:** 18 months (`NEP-29 `_) @@ -137,13 +132,13 @@ We also maintain other dependency sets for different subsets of functionality:: The above commands should install most of the `optional dependencies`_. However, some packages which are either not listed on PyPI or require extra installation steps are excluded. To know which dependencies would be -installed, take a look at the ``[options.extras_require]`` section in -``setup.cfg``: +installed, take a look at the ``[project.optional-dependencies]`` section in +``pyproject.toml``: -.. literalinclude:: ../../setup.cfg - :language: ini - :start-at: [options.extras_require] - :end-before: [options.package_data] +.. literalinclude:: ../../pyproject.toml + :language: toml + :start-at: [project.optional-dependencies] + :end-before: [build-system] Development versions -------------------- diff --git a/doc/howdoi.rst b/doc/howdoi.rst index b6374cc5100..97b0872fdc4 100644 --- a/doc/howdoi.rst +++ b/doc/howdoi.rst @@ -36,13 +36,13 @@ How do I ... * - rename a variable, dimension or coordinate - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename`, :py:meth:`Dataset.rename_vars`, :py:meth:`Dataset.rename_dims`, * - convert a DataArray to Dataset or vice versa - - :py:meth:`DataArray.to_dataset`, :py:meth:`Dataset.to_array`, :py:meth:`Dataset.to_stacked_array`, :py:meth:`DataArray.to_unstacked_dataset` + - :py:meth:`DataArray.to_dataset`, :py:meth:`Dataset.to_dataarray`, :py:meth:`Dataset.to_stacked_array`, :py:meth:`DataArray.to_unstacked_dataset` * - extract variables that have certain attributes - :py:meth:`Dataset.filter_by_attrs` * - extract the underlying array (e.g. NumPy or Dask arrays) - :py:attr:`DataArray.data` * - convert to and extract the underlying NumPy array - - :py:attr:`DataArray.values` + - :py:attr:`DataArray.to_numpy` * - convert to a pandas DataFrame - :py:attr:`Dataset.to_dataframe` * - sort values diff --git a/doc/internals/chunked-arrays.rst b/doc/internals/chunked-arrays.rst new file mode 100644 index 00000000000..ba7ce72c834 --- /dev/null +++ b/doc/internals/chunked-arrays.rst @@ -0,0 +1,102 @@ +.. currentmodule:: xarray + +.. _internals.chunkedarrays: + +Alternative chunked array types +=============================== + +.. warning:: + + This is a *highly* experimental feature. Please report any bugs or other difficulties on `xarray's issue tracker `_. + In particular see discussion on `xarray issue #6807 `_ + +Xarray can wrap chunked dask arrays (see :ref:`dask`), but can also wrap any other chunked array type that exposes the correct interface. +This allows us to support using other frameworks for distributed and out-of-core processing, with user code still written as xarray commands. +In particular xarray also supports wrapping :py:class:`cubed.Array` objects +(see `Cubed's documentation `_ and the `cubed-xarray package `_). + +The basic idea is that by wrapping an array that has an explicit notion of ``.chunks``, xarray can expose control over +the choice of chunking scheme to users via methods like :py:meth:`DataArray.chunk` whilst the wrapped array actually +implements the handling of processing all of the chunks. + +Chunked array methods and "core operations" +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A chunked array needs to meet all the :ref:`requirements for normal duck arrays `, but must also +implement additional features. + +Chunked arrays have additional attributes and methods, such as ``.chunks`` and ``.rechunk``. +Furthermore, Xarray dispatches chunk-aware computations across one or more chunked arrays using special functions known +as "core operations". Examples include ``map_blocks``, ``blockwise``, and ``apply_gufunc``. + +The core operations are generalizations of functions first implemented in :py:mod:`dask.array`. +The implementation of these functions is specific to the type of arrays passed to them. For example, when applying the +``map_blocks`` core operation, :py:class:`dask.array.Array` objects must be processed by :py:func:`dask.array.map_blocks`, +whereas :py:class:`cubed.Array` objects must be processed by :py:func:`cubed.map_blocks`. + +In order to use the correct implementation of a core operation for the array type encountered, xarray dispatches to the +corresponding subclass of :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint`, +also known as a "Chunk Manager". Therefore **a full list of the operations that need to be defined is set by the +API of the** :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint` **abstract base class**. Note that chunked array +methods are also currently dispatched using this class. + +Chunked array creation is also handled by this class. As chunked array objects have a one-to-one correspondence with +in-memory numpy arrays, it should be possible to create a chunked array from a numpy array by passing the desired +chunking pattern to an implementation of :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint.from_array``. + +.. note:: + + The :py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint` abstract base class is mostly just acting as a + namespace for containing the chunked-aware function primitives. Ideally in the future we would have an API standard + for chunked array types which codified this structure, making the entrypoint system unnecessary. + +.. currentmodule:: xarray.namedarray.parallelcompat + +.. autoclass:: xarray.namedarray.parallelcompat.ChunkManagerEntrypoint + :members: + +Registering a new ChunkManagerEntrypoint subclass +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Rather than hard-coding various chunk managers to deal with specific chunked array implementations, xarray uses an +entrypoint system to allow developers of new chunked array implementations to register their corresponding subclass of +:py:class:`~xarray.namedarray.parallelcompat.ChunkManagerEntrypoint`. + + +To register a new entrypoint you need to add an entry to the ``setup.cfg`` like this:: + + [options.entry_points] + xarray.chunkmanagers = + dask = xarray.namedarray.daskmanager:DaskManager + +See also `cubed-xarray `_ for another example. + +To check that the entrypoint has worked correctly, you may find it useful to display the available chunkmanagers using +the internal function :py:func:`~xarray.namedarray.parallelcompat.list_chunkmanagers`. + +.. autofunction:: list_chunkmanagers + + +User interface +~~~~~~~~~~~~~~ + +Once the chunkmanager subclass has been registered, xarray objects wrapping the desired array type can be created in 3 ways: + +#. By manually passing the array type to the :py:class:`~xarray.DataArray` constructor, see the examples for :ref:`numpy-like arrays `, + +#. Calling :py:meth:`~xarray.DataArray.chunk`, passing the keyword arguments ``chunked_array_type`` and ``from_array_kwargs``, + +#. Calling :py:func:`~xarray.open_dataset`, passing the keyword arguments ``chunked_array_type`` and ``from_array_kwargs``. + +The latter two methods ultimately call the chunkmanager's implementation of ``.from_array``, to which they pass the ``from_array_kwargs`` dict. +The ``chunked_array_type`` kwarg selects which registered chunkmanager subclass to dispatch to. It defaults to ``'dask'`` +if Dask is installed, otherwise it defaults to whichever chunkmanager is registered if only one is registered. +If multiple chunkmanagers are registered it will raise an error by default. + +Parallel processing without chunks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To use a parallel array type that does not expose a concept of chunks explicitly, none of the information on this page +is theoretically required. Such an array type (e.g. `Ramba `_ or +`Arkouda `_) could be wrapped using xarray's existing support for +:ref:`numpy-like "duck" arrays `. diff --git a/doc/internals/duck-arrays-integration.rst b/doc/internals/duck-arrays-integration.rst index d403328aa2f..43b17be8bb8 100644 --- a/doc/internals/duck-arrays-integration.rst +++ b/doc/internals/duck-arrays-integration.rst @@ -1,23 +1,59 @@ -.. _internals.duck_arrays: +.. _internals.duckarrays: Integrating with duck arrays ============================= .. warning:: - This is a experimental feature. + This is an experimental feature. Please report any bugs or other difficulties on `xarray's issue tracker `_. -Xarray can wrap custom :term:`duck array` objects as long as they define numpy's -``shape``, ``dtype`` and ``ndim`` properties and the ``__array__``, -``__array_ufunc__`` and ``__array_function__`` methods. +Xarray can wrap custom numpy-like arrays (":term:`duck array`\s") - see the :ref:`user guide documentation `. +This page is intended for developers who are interested in wrapping a new custom array type with xarray. + +.. _internals.duckarrays.requirements: + +Duck array requirements +~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray does not explicitly check that required methods are defined by the underlying duck array object before +attempting to wrap the given array. However, a wrapped array type should at a minimum define these attributes: + +* ``shape`` property, +* ``dtype`` property, +* ``ndim`` property, +* ``__array__`` method, +* ``__array_ufunc__`` method, +* ``__array_function__`` method. + +These need to be defined consistently with :py:class:`numpy.ndarray`, for example the array ``shape`` +property needs to obey `numpy's broadcasting rules `_ +(see also the `Python Array API standard's explanation `_ +of these same rules). + +.. _internals.duckarrays.array_api_standard: + +Python Array API standard support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As an integration library xarray benefits greatly from the standardization of duck-array libraries' APIs, and so is a +big supporter of the `Python Array API Standard `_. + +We aim to support any array libraries that follow the Array API standard out-of-the-box. However, xarray does occasionally +call some numpy functions which are not (yet) part of the standard (e.g. :py:meth:`xarray.DataArray.pad` calls :py:func:`numpy.pad`). +See `xarray issue #7848 `_ for a list of such functions. We can still support dispatching on these functions through +the array protocols above, it just means that if you exclusively implement the methods in the Python Array API standard +then some features in xarray will not work. + +Custom inline reprs +~~~~~~~~~~~~~~~~~~~ In certain situations (e.g. when printing the collapsed preview of variables of a ``Dataset``), xarray will display the repr of a :term:`duck array` in a single line, truncating it to a certain number of characters. If that would drop too much information, the :term:`duck array` may define a ``_repr_inline_`` method that takes ``max_width`` (number of characters) as an -argument: +argument .. code:: python diff --git a/doc/internals/extending-xarray.rst b/doc/internals/extending-xarray.rst index f8b61d12a2f..0537ae85389 100644 --- a/doc/internals/extending-xarray.rst +++ b/doc/internals/extending-xarray.rst @@ -1,6 +1,8 @@ -Extending xarray -================ +.. _internals.accessors: + +Extending xarray using accessors +================================ .. ipython:: python :suppress: @@ -8,11 +10,16 @@ Extending xarray import xarray as xr -Xarray is designed as a general purpose library, and hence tries to avoid +Xarray is designed as a general purpose library and hence tries to avoid including overly domain specific functionality. But inevitably, the need for more domain specific logic arises. -One standard solution to this problem is to subclass Dataset and/or DataArray to +.. _internals.accessors.composition: + +Composition over Inheritance +---------------------------- + +One potential solution to this problem is to subclass Dataset and/or DataArray to add domain specific functionality. However, inheritance is not very robust. It's easy to inadvertently use internal APIs when subclassing, which means that your code may break when xarray upgrades. Furthermore, many builtin methods will @@ -21,15 +28,23 @@ only return native xarray objects. The standard advice is to use :issue:`composition over inheritance <706>`, but reimplementing an API as large as xarray's on your own objects can be an onerous task, even if most methods are only forwarding to xarray implementations. +(For an example of a project which took this approach of subclassing see `UXarray `_). If you simply want the ability to call a function with the syntax of a method call, then the builtin :py:meth:`~xarray.DataArray.pipe` method (copied from pandas) may suffice. +.. _internals.accessors.writing accessors: + +Writing Custom Accessors +------------------------ + To resolve this issue for more complex cases, xarray has the :py:func:`~xarray.register_dataset_accessor` and :py:func:`~xarray.register_dataarray_accessor` decorators for adding custom -"accessors" on xarray objects. Here's how you might use these decorators to +"accessors" on xarray objects, thereby "extending" the functionality of your xarray object. + +Here's how you might use these decorators to write a custom "geo" accessor implementing a geography specific extension to xarray: @@ -88,7 +103,7 @@ The intent here is that libraries that extend xarray could add such an accessor to implement subclass specific functionality rather than using actual subclasses or patching in a large number of domain specific methods. For further reading on ways to write new accessors and the philosophy behind the approach, see -:issue:`1080`. +https://github.com/pydata/xarray/issues/1080. To help users keep things straight, please `let us know `_ if you plan to write a new accessor diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index a106232958e..4352dd3df5b 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -9,7 +9,8 @@ to integrate any code in Xarray; all you need to do is: - Create a class that inherits from Xarray :py:class:`~xarray.backends.BackendEntrypoint` and implements the method ``open_dataset`` see :ref:`RST backend_entrypoint` -- Declare this class as an external plugin in your ``setup.py``, see :ref:`RST backend_registration` +- Declare this class as an external plugin in your project configuration, see :ref:`RST + backend_registration` If you also want to support lazy loading and dask see :ref:`RST lazy_loading`. @@ -267,42 +268,57 @@ interface only the boolean keywords related to the supported decoders. How to register a backend +++++++++++++++++++++++++ -Define a new entrypoint in your ``setup.py`` (or ``setup.cfg``) with: +Define a new entrypoint in your ``pyproject.toml`` (or ``setup.cfg/setup.py`` for older +configurations), with: - group: ``xarray.backends`` - name: the name to be passed to :py:meth:`~xarray.open_dataset` as ``engine`` - object reference: the reference of the class that you have implemented. -You can declare the entrypoint in ``setup.py`` using the following syntax: +You can declare the entrypoint in your project configuration like so: -.. code-block:: +.. tab:: pyproject.toml - setuptools.setup( - entry_points={ - "xarray.backends": ["my_engine=my_package.my_module:MyBackendEntryClass"], - }, - ) + .. code:: toml + + [project.entry-points."xarray.backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" + +.. tab:: pyproject.toml [Poetry] + + .. code-block:: toml + + [tool.poetry.plugins."xarray.backends"] + my_engine = "my_package.my_module:MyBackendEntrypoint" -in ``setup.cfg``: +.. tab:: setup.cfg -.. code-block:: cfg + .. code-block:: cfg - [options.entry_points] - xarray.backends = - my_engine = my_package.my_module:MyBackendEntryClass + [options.entry_points] + xarray.backends = + my_engine = my_package.my_module:MyBackendEntrypoint +.. tab:: setup.py -See https://packaging.python.org/specifications/entry-points/#data-model -for more information + .. code-block:: -If you are using `Poetry `_ for your build system, you can accomplish the same thing using "plugins". In this case you would need to add the following to your ``pyproject.toml`` file: + setuptools.setup( + entry_points={ + "xarray.backends": [ + "my_engine=my_package.my_module:MyBackendEntrypoint" + ], + }, + ) -.. code-block:: toml - [tool.poetry.plugins."xarray.backends"] - "my_engine" = "my_package.my_module:MyBackendEntryClass" +See the `Python Packaging User Guide +`_ for more +information on entrypoints and details of the syntax. -See https://python-poetry.org/docs/pyproject/#plugins for more information on Poetry plugins. +If you're using Poetry, note that table name in ``pyproject.toml`` is slightly different. +See `the Poetry docs `_ for more +information on plugins. .. _RST lazy_loading: diff --git a/doc/internals/how-to-create-custom-index.rst b/doc/internals/how-to-create-custom-index.rst new file mode 100644 index 00000000000..90b3412c2cb --- /dev/null +++ b/doc/internals/how-to-create-custom-index.rst @@ -0,0 +1,235 @@ +.. currentmodule:: xarray + +.. _internals.custom indexes: + +How to create a custom index +============================ + +.. warning:: + + This feature is highly experimental. Support for custom indexes has been + introduced in v2022.06.0 and is still incomplete. API is subject to change + without deprecation notice. However we encourage you to experiment and report issues that arise. + +Xarray's built-in support for label-based indexing (e.g. `ds.sel(latitude=40, method="nearest")`) and alignment operations +relies on :py:class:`pandas.Index` objects. Pandas Indexes are powerful and suitable for many +applications but also have some limitations: + +- it only works with 1-dimensional coordinates where explicit labels + are fully loaded in memory +- it is hard to reuse it with irregular data for which there exist more + efficient, tree-based structures to perform data selection +- it doesn't support extra metadata that may be required for indexing and + alignment (e.g., a coordinate reference system) + +Fortunately, Xarray now allows extending this functionality with custom indexes, +which can be implemented in 3rd-party libraries. + +The Index base class +-------------------- + +Every Xarray index must inherit from the :py:class:`Index` base class. It is for +example the case of Xarray built-in ``PandasIndex`` and ``PandasMultiIndex`` +subclasses, which wrap :py:class:`pandas.Index` and +:py:class:`pandas.MultiIndex` respectively. + +The ``Index`` API closely follows the :py:class:`Dataset` and +:py:class:`DataArray` API, e.g., for an index to support :py:meth:`DataArray.sel` it needs to +implement :py:meth:`Index.sel`, to support :py:meth:`DataArray.stack` and :py:meth:`DataArray.unstack` it +needs to implement :py:meth:`Index.stack` and :py:meth:`Index.unstack`, etc. + +Some guidelines and examples are given below. More details can be found in the +documented :py:class:`Index` API. + +Minimal requirements +-------------------- + +Every index must at least implement the :py:meth:`Index.from_variables` class +method, which is used by Xarray to build a new index instance from one or more +existing coordinates in a Dataset or DataArray. + +Since any collection of coordinates can be passed to that method (i.e., the +number, order and dimensions of the coordinates are all arbitrary), it is the +responsibility of the index to check the consistency and validity of those input +coordinates. + +For example, :py:class:`~xarray.core.indexes.PandasIndex` accepts only one coordinate and +:py:class:`~xarray.core.indexes.PandasMultiIndex` accepts one or more 1-dimensional coordinates that must all +share the same dimension. Other, custom indexes need not have the same +constraints, e.g., + +- a georeferenced raster index which only accepts two 1-d coordinates with + distinct dimensions +- a staggered grid index which takes coordinates with different dimension name + suffixes (e.g., "_c" and "_l" for center and left) + +Optional requirements +--------------------- + +Pretty much everything else is optional. Depending on the method, in the absence +of a (re)implementation, an index will either raise a `NotImplementedError` +or won't do anything specific (just drop, pass or copy itself +from/to the resulting Dataset or DataArray). + +For example, you can just skip re-implementing :py:meth:`Index.rename` if there +is no internal attribute or object to rename according to the new desired +coordinate or dimension names. In the case of ``PandasIndex``, we rename the +underlying ``pandas.Index`` object and/or update the ``PandasIndex.dim`` +attribute since the associated dimension name has been changed. + +Wrap index data as coordinate data +---------------------------------- + +In some cases it is possible to reuse the index's underlying object or structure +as coordinate data and hence avoid data duplication. + +For ``PandasIndex`` and ``PandasMultiIndex``, we +leverage the fact that ``pandas.Index`` objects expose some array-like API. In +Xarray we use some wrappers around those underlying objects as a thin +compatibility layer to preserve dtypes, handle explicit and n-dimensional +indexing, etc. + +Other structures like tree-based indexes (e.g., kd-tree) may differ too much +from arrays to reuse it as coordinate data. + +If the index data can be reused as coordinate data, the ``Index`` subclass +should implement :py:meth:`Index.create_variables`. This method accepts a +dictionary of variable names as keys and :py:class:`Variable` objects as values (used for propagating +variable metadata) and should return a dictionary of new :py:class:`Variable` or +:py:class:`IndexVariable` objects. + +Data selection +-------------- + +For an index to support label-based selection, it needs to at least implement +:py:meth:`Index.sel`. This method accepts a dictionary of labels where the keys +are coordinate names (already filtered for the current index) and the values can +be pretty much anything (e.g., a slice, a tuple, a list, a numpy array, a +:py:class:`Variable` or a :py:class:`DataArray`). It is the responsibility of +the index to properly handle those input labels. + +:py:meth:`Index.sel` must return an instance of :py:class:`IndexSelResult`. The +latter is a small data class that holds positional indexers (indices) and that +may also hold new variables, new indexes, names of variables or indexes to drop, +names of dimensions to rename, etc. For example, this is useful in the case of +``PandasMultiIndex`` as it allows Xarray to convert it into a single ``PandasIndex`` +when only one level remains after the selection. + +The :py:class:`IndexSelResult` class is also used to merge results from label-based +selection performed by different indexes. Note that it is now possible to have +two distinct indexes for two 1-d coordinates sharing the same dimension, but it +is not currently possible to use those two indexes in the same call to +:py:meth:`Dataset.sel`. + +Optionally, the index may also implement :py:meth:`Index.isel`. In the case of +``PandasIndex`` we use it to create a new index object by just indexing the +underlying ``pandas.Index`` object. In other cases this may not be possible, +e.g., a kd-tree object may not be easily indexed. If ``Index.isel()`` is not +implemented, the index in just dropped in the DataArray or Dataset resulting +from the selection. + +Alignment +--------- + +For an index to support alignment, it needs to implement: + +- :py:meth:`Index.equals`, which compares the index with another index and + returns either ``True`` or ``False`` +- :py:meth:`Index.join`, which combines the index with another index and returns + a new Index object +- :py:meth:`Index.reindex_like`, which queries the index with another index and + returns positional indexers that are used to re-index Dataset or DataArray + variables along one or more dimensions + +Xarray ensures that those three methods are called with an index of the same +type as argument. + +Meta-indexes +------------ + +Nothing prevents writing a custom Xarray index that itself encapsulates other +Xarray index(es). We call such index a "meta-index". + +Here is a small example of a meta-index for geospatial, raster datasets (i.e., +regularly spaced 2-dimensional data) that internally relies on two +``PandasIndex`` instances for the x and y dimensions respectively: + +.. code-block:: python + + from xarray import Index + from xarray.core.indexes import PandasIndex + from xarray.core.indexing import merge_sel_results + + + class RasterIndex(Index): + def __init__(self, xy_indexes): + assert len(xy_indexes) == 2 + + # must have two distinct dimensions + dim = [idx.dim for idx in xy_indexes.values()] + assert dim[0] != dim[1] + + self._xy_indexes = xy_indexes + + @classmethod + def from_variables(cls, variables): + assert len(variables) == 2 + + xy_indexes = { + k: PandasIndex.from_variables({k: v}) for k, v in variables.items() + } + + return cls(xy_indexes) + + def create_variables(self, variables): + idx_variables = {} + + for index in self._xy_indexes.values(): + idx_variables.update(index.create_variables(variables)) + + return idx_variables + + def sel(self, labels): + results = [] + + for k, index in self._xy_indexes.items(): + if k in labels: + results.append(index.sel({k: labels[k]})) + + return merge_sel_results(results) + + +This basic index only supports label-based selection. Providing a full-featured +index by implementing the other ``Index`` methods should be pretty +straightforward for this example, though. + +This example is also not very useful unless we add some extra functionality on +top of the two encapsulated ``PandasIndex`` objects, such as a coordinate +reference system. + +How to use a custom index +------------------------- + +You can use :py:meth:`Dataset.set_xindex` or :py:meth:`DataArray.set_xindex` to assign a +custom index to a Dataset or DataArray, e.g., using the ``RasterIndex`` above: + +.. code-block:: python + + import numpy as np + import xarray as xr + + da = xr.DataArray( + np.random.uniform(size=(100, 50)), + coords={"x": ("x", np.arange(50)), "y": ("y", np.arange(100))}, + dims=("y", "x"), + ) + + # Xarray create default indexes for the 'x' and 'y' coordinates + # we first need to explicitly drop it + da = da.drop_indexes(["x", "y"]) + + # Build a RasterIndex from the 'x' and 'y' coordinates + da_raster = da.set_xindex(["x", "y"], RasterIndex) + + # RasterIndex now takes care of label-based selection + selected = da_raster.sel(x=10, y=slice(20, 50)) diff --git a/doc/internals/index.rst b/doc/internals/index.rst index e4ca9779dd7..b2a37900338 100644 --- a/doc/internals/index.rst +++ b/doc/internals/index.rst @@ -1,6 +1,6 @@ .. _internals: -xarray Internals +Xarray Internals ================ Xarray builds upon two of the foundational libraries of the scientific Python @@ -8,13 +8,21 @@ stack, NumPy and pandas. It is written in pure Python (no C or Cython extensions), which makes it easy to develop and extend. Instead, we push compiled code to :ref:`optional dependencies`. +The pages in this section are intended for: + +* Contributors to xarray who wish to better understand some of the internals, +* Developers from other fields who wish to extend xarray with domain-specific logic, perhaps to support a new scientific community of users, +* Developers of other packages who wish to interface xarray with their existing tools, e.g. by creating a backend for reading a new file format, or wrapping a custom array type. .. toctree:: :maxdepth: 2 :hidden: - variable-objects + internal-design + interoperability duck-arrays-integration + chunked-arrays extending-xarray - zarr-encoding-spec how-to-add-new-backend + how-to-create-custom-index + zarr-encoding-spec diff --git a/doc/internals/internal-design.rst b/doc/internals/internal-design.rst new file mode 100644 index 00000000000..55ab2d79dbe --- /dev/null +++ b/doc/internals/internal-design.rst @@ -0,0 +1,224 @@ +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + np.set_printoptions(threshold=20) + +.. _internal design: + +Internal Design +=============== + +This page gives an overview of the internal design of xarray. + +In totality, the Xarray project defines 4 key data structures. +In order of increasing complexity, they are: + +- :py:class:`xarray.Variable`, +- :py:class:`xarray.DataArray`, +- :py:class:`xarray.Dataset`, +- :py:class:`datatree.DataTree`. + +The user guide lists only :py:class:`xarray.DataArray` and :py:class:`xarray.Dataset`, +but :py:class:`~xarray.Variable` is the fundamental object internally, +and :py:class:`~datatree.DataTree` is a natural generalisation of :py:class:`xarray.Dataset`. + +.. note:: + + Our :ref:`roadmap` includes plans both to document :py:class:`~xarray.Variable` as fully public API, + and to merge the `xarray-datatree `_ package into xarray's main repository. + +Internally private :ref:`lazy indexing classes ` are used to avoid loading more data than necessary, +and flexible indexes classes (derived from :py:class:`~xarray.indexes.Index`) provide performant label-based lookups. + + +.. _internal design.data structures: + +Data Structures +--------------- + +The :ref:`data structures` page in the user guide explains the basics and concentrates on user-facing behavior, +whereas this section explains how xarray's data structure classes actually work internally. + + +.. _internal design.data structures.variable: + +Variable Objects +~~~~~~~~~~~~~~~~ + +The core internal data structure in xarray is the :py:class:`~xarray.Variable`, +which is used as the basic building block behind xarray's +:py:class:`~xarray.Dataset`, :py:class:`~xarray.DataArray` types. A +:py:class:`~xarray.Variable` consists of: + +- ``dims``: A tuple of dimension names. +- ``data``: The N-dimensional array (typically a NumPy or Dask array) storing + the Variable's data. It must have the same number of dimensions as the length + of ``dims``. +- ``attrs``: A dictionary of metadata associated with this array. By + convention, xarray's built-in operations never use this metadata. +- ``encoding``: Another dictionary used to store information about how + these variable's data is represented on disk. See :ref:`io.encoding` for more + details. + +:py:class:`~xarray.Variable` has an interface similar to NumPy arrays, but extended to make use +of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` +argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. + +However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not +include coordinate labels along each axis. + +:py:class:`~xarray.Variable` is public API, but because of its incomplete support for labeled +data, it is mostly intended for advanced uses, such as in xarray itself, for +writing new backends, or when creating custom indexes. +You can access the variable objects that correspond to xarray objects via the (readonly) +:py:attr:`Dataset.variables ` and +:py:attr:`DataArray.variable ` attributes. + + +.. _internal design.dataarray: + +DataArray Objects +~~~~~~~~~~~~~~~~~ + +The simplest data structure used by most users is :py:class:`~xarray.DataArray`. +A :py:class:`~xarray.DataArray` is a composite object consisting of multiple +:py:class:`~xarray.core.variable.Variable` objects which store related data. + +A single :py:class:`~xarray.core.Variable` is referred to as the "data variable", and stored under the :py:attr:`~xarray.DataArray.variable`` attribute. +A :py:class:`~xarray.DataArray` inherits all of the properties of this data variable, i.e. ``dims``, ``data``, ``attrs`` and ``encoding``, +all of which are implemented by forwarding on to the underlying ``Variable`` object. + +In addition, a :py:class:`~xarray.DataArray` stores additional ``Variable`` objects stored in a dict under the private ``_coords`` attribute, +each of which is referred to as a "Coordinate Variable". These coordinate variable objects are only allowed to have ``dims`` that are a subset of the data variable's ``dims``, +and each dim has a specific length. This means that the full :py:attr:`~xarray.DataArray.size` of the dataarray can be represented by a dictionary mapping dimension names to integer sizes. +The underlying data variable has this exact same size, and the attached coordinate variables have sizes which are some subset of the size of the data variable. +Another way of saying this is that all coordinate variables must be "alignable" with the data variable. + +When a coordinate is accessed by the user (e.g. via the dict-like :py:class:`~xarray.DataArray.__getitem__` syntax), +then a new ``DataArray`` is constructed by finding all coordinate variables that have compatible dimensions and re-attaching them before the result is returned. +This is why most users never see the ``Variable`` class underlying each coordinate variable - it is always promoted to a ``DataArray`` before returning. + +Lookups are performed by special :py:class:`~xarray.indexes.Index` objects, which are stored in a dict under the private ``_indexes`` attribute. +Indexes must be associated with one or more coordinates, and essentially act by translating a query given in physical coordinate space +(typically via the :py:meth:`~xarray.DataArray.sel` method) into a set of integer indices in array index space that can be used to index the underlying n-dimensional array-like ``data``. +Indexing in array index space (typically performed via the :py:meth:`~xarray.DataArray.isel` method) does not require consulting an ``Index`` object. + +Finally a :py:class:`~xarray.DataArray` defines a :py:attr:`~xarray.DataArray.name` attribute, which refers to its data +variable but is stored on the wrapping ``DataArray`` class. +The ``name`` attribute is primarily used when one or more :py:class:`~xarray.DataArray` objects are promoted into a :py:class:`~xarray.Dataset` +(e.g. via :py:meth:`~xarray.DataArray.to_dataset`). +Note that the underlying :py:class:`~xarray.core.Variable` objects are all unnamed, so they can always be referred to uniquely via a +dict-like mapping. + +.. _internal design.dataset: + +Dataset Objects +~~~~~~~~~~~~~~~ + +The :py:class:`~xarray.Dataset` class is a generalization of the :py:class:`~xarray.DataArray` class that can hold multiple data variables. +Internally all data variables and coordinate variables are stored under a single ``variables`` dict, and coordinates are +specified by storing their names in a private ``_coord_names`` dict. + +The dataset's ``dims`` are the set of all dims present across any variable, but (similar to in dataarrays) coordinate +variables cannot have a dimension that is not present on any data variable. + +When a data variable or coordinate variable is accessed, a new ``DataArray`` is again constructed from all compatible +coordinates before returning. + +.. _internal design.subclassing: + +.. note:: + + The way that selecting a variable from a ``DataArray`` or ``Dataset`` actually involves internally wrapping the + ``Variable`` object back up into a ``DataArray``/``Dataset`` is the primary reason :ref:`we recommend against subclassing ` + Xarray objects. The main problem it creates is that we currently cannot easily guarantee that for example selecting + a coordinate variable from your ``SubclassedDataArray`` would return an instance of ``SubclassedDataArray`` instead + of just an :py:class:`xarray.DataArray`. See `GH issue `_ for more details. + +.. _internal design.lazy indexing: + +Lazy Indexing Classes +--------------------- + +Lazy Loading +~~~~~~~~~~~~ + +If we open a ``Variable`` object from disk using :py:func:`~xarray.open_dataset` we can see that the actual values of +the array wrapped by the data variable are not displayed. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + var + +We can see the size, and the dtype of the underlying array, but not the actual values. +This is because the values have not yet been loaded. + +If we look at the private attribute :py:meth:`~xarray.Variable._data` containing the underlying array object, we see +something interesting: + +.. ipython:: python + + var._data + +You're looking at one of xarray's internal `Lazy Indexing Classes`. These powerful classes are hidden from the user, +but provide important functionality. + +Calling the public :py:attr:`~xarray.Variable.data` property loads the underlying array into memory. + +.. ipython:: python + + var.data + +This array is now cached, which we can see by accessing the private attribute again: + +.. ipython:: python + + var._data + +Lazy Indexing +~~~~~~~~~~~~~ + +The purpose of these lazy indexing classes is to prevent more data being loaded into memory than is necessary for the +subsequent analysis, by deferring loading data until after indexing is performed. + +Let's open the data from disk again. + +.. ipython:: python + + da = xr.tutorial.open_dataset("air_temperature")["air"] + var = da.variable + +Now, notice how even after subsetting the data has does not get loaded: + +.. ipython:: python + + var.isel(time=0) + +The shape has changed, but the values are still not shown. + +Looking at the private attribute again shows how this indexing information was propagated via the hidden lazy indexing classes: + +.. ipython:: python + + var.isel(time=0)._data + +.. note:: + + Currently only certain indexing operations are lazy, not all array operations. For discussion of making all array + operations lazy see `GH issue #5081 `_. + + +Lazy Dask Arrays +~~~~~~~~~~~~~~~~ + +Note that xarray's implementation of Lazy Indexing classes is completely separate from how :py:class:`dask.array.Array` +objects evaluate lazily. Dask-backed xarray objects delay almost all operations until :py:meth:`~xarray.DataArray.compute` +is called (either explicitly or implicitly via :py:meth:`~xarray.DataArray.plot` for example). The exceptions to this +laziness are operations whose output shape is data-dependent, such as when calling :py:meth:`~xarray.DataArray.where`. diff --git a/doc/internals/interoperability.rst b/doc/internals/interoperability.rst new file mode 100644 index 00000000000..a45363bcab7 --- /dev/null +++ b/doc/internals/interoperability.rst @@ -0,0 +1,45 @@ +.. _interoperability: + +Interoperability of Xarray +========================== + +Xarray is designed to be extremely interoperable, in many orthogonal ways. +Making xarray as flexible as possible is the common theme of most of the goals on our :ref:`roadmap`. + +This interoperability comes via a set of flexible abstractions into which the user can plug in. The current full list is: + +- :ref:`Custom file backends ` via the :py:class:`~xarray.backends.BackendEntrypoint` system, +- Numpy-like :ref:`"duck" array wrapping `, which supports the `Python Array API Standard `_, +- :ref:`Chunked distributed array computation ` via the :py:class:`~xarray.core.parallelcompat.ChunkManagerEntrypoint` system, +- Custom :py:class:`~xarray.Index` objects for :ref:`flexible label-based lookups `, +- Extending xarray objects with domain-specific methods via :ref:`custom accessors `. + +.. warning:: + + One obvious way in which xarray could be more flexible is that whilst subclassing xarray objects is possible, we + currently don't support it in most transformations, instead recommending composition over inheritance. See the + :ref:`internal design page ` for the rationale and look at the corresponding `GH issue `_ + if you're interested in improving support for subclassing! + +.. note:: + + If you think there is another way in which xarray could become more generically flexible then please + tell us your ideas by `raising an issue to request the feature `_! + + +Whilst xarray was originally designed specifically to open ``netCDF4`` files as :py:class:`numpy.ndarray` objects labelled by :py:class:`pandas.Index` objects, +it is entirely possible today to: + +- lazily open an xarray object directly from a custom binary file format (e.g. using ``xarray.open_dataset(path, engine='my_custom_format')``, +- handle the data as any API-compliant numpy-like array type (e.g. sparse or GPU-backed), +- distribute out-of-core computation across that array type in parallel (e.g. via :ref:`dask`), +- track the physical units of the data through computations (e.g via `pint-xarray `_), +- query the data via custom index logic optimized for specific applications (e.g. an :py:class:`~xarray.Index` object backed by a KDTree structure), +- attach domain-specific logic via accessor methods (e.g. to understand geographic Coordinate Reference System metadata), +- organize hierarchical groups of xarray data in a :py:class:`~datatree.DataTree` (e.g. to treat heterogeneous simulation and observational data together during analysis). + +All of these features can be provided simultaneously, using libraries compatible with the rest of the scientific python ecosystem. +In this situation xarray would be essentially a thin wrapper acting as pure-python framework, providing a common interface and +separation of concerns via various domain-agnostic abstractions. + +Most of the remaining pages in the documentation of xarray's internals describe these various types of interoperability in more detail. diff --git a/doc/internals/variable-objects.rst b/doc/internals/variable-objects.rst deleted file mode 100644 index 6ae3c2f7e6d..00000000000 --- a/doc/internals/variable-objects.rst +++ /dev/null @@ -1,31 +0,0 @@ -Variable objects -================ - -The core internal data structure in xarray is the :py:class:`~xarray.Variable`, -which is used as the basic building block behind xarray's -:py:class:`~xarray.Dataset` and :py:class:`~xarray.DataArray` types. A -``Variable`` consists of: - -- ``dims``: A tuple of dimension names. -- ``data``: The N-dimensional array (typically, a NumPy or Dask array) storing - the Variable's data. It must have the same number of dimensions as the length - of ``dims``. -- ``attrs``: An ordered dictionary of metadata associated with this array. By - convention, xarray's built-in operations never use this metadata. -- ``encoding``: Another ordered dictionary used to store information about how - these variable's data is represented on disk. See :ref:`io.encoding` for more - details. - -``Variable`` has an interface similar to NumPy arrays, but extended to make use -of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` -argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. - -However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not -include coordinate labels along each axis. - -``Variable`` is public API, but because of its incomplete support for labeled -data, it is mostly intended for advanced uses, such as in xarray itself or for -writing new backends. You can access the variable objects that correspond to -xarray objects via the (readonly) :py:attr:`Dataset.variables -` and -:py:attr:`DataArray.variable ` attributes. diff --git a/doc/roadmap.rst b/doc/roadmap.rst index eeaaf10813b..820ff82151c 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -156,7 +156,7 @@ types would also be highly useful for xarray users. By pursuing these improvements in NumPy we hope to extend the benefits to the full scientific Python community, and avoid tight coupling between xarray and specific third-party libraries (e.g., for -implementing untis). This will allow xarray to maintain its domain +implementing units). This will allow xarray to maintain its domain agnostic strengths. We expect that we may eventually add some minimal interfaces in xarray diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index f913ea41a91..f8141f40321 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -63,33 +63,121 @@ Data arrays also implement many :py:class:`numpy.ndarray` methods: arr.round(2) arr.T + intarr = xr.DataArray([0, 1, 2, 3, 4, 5]) + intarr << 2 # only supported for int types + intarr >> 1 + .. _missing_values: Missing values ============== +Xarray represents missing values using the "NaN" (Not a Number) value from NumPy, which is a +special floating-point value that indicates a value that is undefined or unrepresentable. +There are several methods for handling missing values in xarray: + Xarray objects borrow the :py:meth:`~xarray.DataArray.isnull`, :py:meth:`~xarray.DataArray.notnull`, :py:meth:`~xarray.DataArray.count`, :py:meth:`~xarray.DataArray.dropna`, :py:meth:`~xarray.DataArray.fillna`, :py:meth:`~xarray.DataArray.ffill`, and :py:meth:`~xarray.DataArray.bfill` methods for working with missing data from pandas: +:py:meth:`~xarray.DataArray.isnull` is a method in xarray that can be used to check for missing or null values in an xarray object. +It returns a new xarray object with the same dimensions as the original object, but with boolean values +indicating where **missing values** are present. + .. ipython:: python x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.isnull() + +In this example, the third and fourth elements of 'x' are NaN, so the resulting :py:class:`~xarray.DataArray` +object has 'True' values in the third and fourth positions and 'False' values in the other positions. + +:py:meth:`~xarray.DataArray.notnull` is a method in xarray that can be used to check for non-missing or non-null values in an xarray +object. It returns a new xarray object with the same dimensions as the original object, but with boolean +values indicating where **non-missing values** are present. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.notnull() + +In this example, the first two and the last elements of x are not NaN, so the resulting +:py:class:`~xarray.DataArray` object has 'True' values in these positions, and 'False' values in the +third and fourth positions where NaN is located. + +:py:meth:`~xarray.DataArray.count` is a method in xarray that can be used to count the number of +non-missing values along one or more dimensions of an xarray object. It returns a new xarray object with +the same dimensions as the original object, but with each element replaced by the count of non-missing +values along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.count() + +In this example, 'x' has five elements, but two of them are NaN, so the resulting +:py:class:`~xarray.DataArray` object having a single element containing the value '3', which represents +the number of non-null elements in x. + +:py:meth:`~xarray.DataArray.dropna` is a method in xarray that can be used to remove missing or null values from an xarray object. +It returns a new xarray object with the same dimensions as the original object, but with missing values +removed. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.dropna(dim="x") + +In this example, on calling x.dropna(dim="x") removes any missing values and returns a new +:py:class:`~xarray.DataArray` object with only the non-null elements [0, 1, 2] of 'x', in the +original order. + +:py:meth:`~xarray.DataArray.fillna` is a method in xarray that can be used to fill missing or null values in an xarray object with a +specified value or method. It returns a new xarray object with the same dimensions as the original object, but with missing values filled. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.fillna(-1) + +In this example, there are two NaN values in 'x', so calling x.fillna(-1) replaces these values with -1 and +returns a new :py:class:`~xarray.DataArray` object with five elements, containing the values +[0, 1, -1, -1, 2] in the original order. + +:py:meth:`~xarray.DataArray.ffill` is a method in xarray that can be used to forward fill (or fill forward) missing values in an +xarray object along one or more dimensions. It returns a new xarray object with the same dimensions as the +original object, but with missing values replaced by the last non-missing value along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.ffill("x") + +In this example, there are two NaN values in 'x', so calling x.ffill("x") fills these values with the last +non-null value in the same dimension, which are 0 and 1, respectively. The resulting :py:class:`~xarray.DataArray` object has +five elements, containing the values [0, 1, 1, 1, 2] in the original order. + +:py:meth:`~xarray.DataArray.bfill` is a method in xarray that can be used to backward fill (or fill backward) missing values in an +xarray object along one or more dimensions. It returns a new xarray object with the same dimensions as the original object, but +with missing values replaced by the next non-missing value along the specified dimensions. + +.. ipython:: python + + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.bfill("x") +In this example, there are two NaN values in 'x', so calling x.bfill("x") fills these values with the next +non-null value in the same dimension, which are 2 and 2, respectively. The resulting :py:class:`~xarray.DataArray` object has +five elements, containing the values [0, 1, 2, 2, 2] in the original order. + Like pandas, xarray uses the float value ``np.nan`` (not-a-number) to represent missing values. Xarray objects also have an :py:meth:`~xarray.DataArray.interpolate_na` method -for filling missing values via 1D interpolation. +for filling missing values via 1D interpolation. It returns a new xarray object with the same dimensions +as the original object, but with missing values interpolated. .. ipython:: python @@ -100,6 +188,13 @@ for filling missing values via 1D interpolation. ) x.interpolate_na(dim="x", method="linear", use_coordinate="xx") +In this example, there are two NaN values in 'x', so calling x.interpolate_na(dim="x", method="linear", +use_coordinate="xx") fills these values with interpolated values along the "x" dimension using linear +interpolation based on the values of the xx coordinate. The resulting :py:class:`~xarray.DataArray` object has five elements, +containing the values [0., 1., 1.05, 1.45, 2.] in the original order. Note that the interpolated values +are calculated based on the values of the 'xx' coordinate, which has non-integer values, resulting in +non-integer interpolated values. + Note that xarray slightly diverges from the pandas ``interpolate`` syntax by providing the ``use_coordinate`` keyword which facilitates a clear specification of which values to use as the index in the interpolation. diff --git a/doc/user-guide/data-structures.rst b/doc/user-guide/data-structures.rst index e0fd4bd0d25..64e7b3625ac 100644 --- a/doc/user-guide/data-structures.rst +++ b/doc/user-guide/data-structures.rst @@ -19,7 +19,8 @@ DataArray :py:class:`xarray.DataArray` is xarray's implementation of a labeled, multi-dimensional array. It has several key properties: -- ``values``: a :py:class:`numpy.ndarray` holding the array's values +- ``values``: a :py:class:`numpy.ndarray` or + :ref:`numpy-like array ` holding the array's values - ``dims``: dimension names for each axis (e.g., ``('x', 'y', 'z')``) - ``coords``: a dict-like container of arrays (*coordinates*) that label each point (e.g., 1-dimensional arrays of numbers, datetime objects or @@ -46,7 +47,8 @@ Creating a DataArray The :py:class:`~xarray.DataArray` constructor takes: - ``data``: a multi-dimensional array of values (e.g., a numpy ndarray, - :py:class:`~pandas.Series`, :py:class:`~pandas.DataFrame` or ``pandas.Panel``) + a :ref:`numpy-like array `, :py:class:`~pandas.Series`, + :py:class:`~pandas.DataFrame` or ``pandas.Panel``) - ``coords``: a list or dictionary of coordinates. If a list, it should be a list of tuples where the first element is the dimension name and the second element is the corresponding coordinate array_like object. diff --git a/doc/user-guide/duckarrays.rst b/doc/user-guide/duckarrays.rst index 78c7d1e572a..f0650ac61b5 100644 --- a/doc/user-guide/duckarrays.rst +++ b/doc/user-guide/duckarrays.rst @@ -1,30 +1,183 @@ .. currentmodule:: xarray +.. _userguide.duckarrays: + Working with numpy-like arrays ============================== +NumPy-like arrays (often known as :term:`duck array`\s) are drop-in replacements for the :py:class:`numpy.ndarray` +class but with different features, such as propagating physical units or a different layout in memory. +Xarray can often wrap these array types, allowing you to use labelled dimensions and indexes whilst benefiting from the +additional features of these array libraries. + +Some numpy-like array types that xarray already has some support for: + +* `Cupy `_ - GPU support (see `cupy-xarray `_), +* `Sparse `_ - for performant arrays with many zero elements, +* `Pint `_ - for tracking the physical units of your data (see `pint-xarray `_), +* `Dask `_ - parallel computing on larger-than-memory arrays (see :ref:`using dask with xarray `), +* `Cubed `_ - another parallel computing framework that emphasises reliability (see `cubed-xarray `_). + .. warning:: - This feature should be considered experimental. Please report any bug you may find on - xarray’s github repository. + This feature should be considered somewhat experimental. Please report any bugs you find on + `xarray’s issue tracker `_. + +.. note:: + + For information on wrapping dask arrays see :ref:`dask`. Whilst xarray wraps dask arrays in a similar way to that + described on this page, chunked array types like :py:class:`dask.array.Array` implement additional methods that require + slightly different user code (e.g. calling ``.chunk`` or ``.compute``). See the docs on :ref:`wrapping chunked arrays `. + +Why "duck"? +----------- + +Why is it also called a "duck" array? This comes from a common statement of object-oriented programming - +"If it walks like a duck, and quacks like a duck, treat it like a duck". In other words, a library like xarray that +is capable of using multiple different types of arrays does not have to explicitly check that each one it encounters is +permitted (e.g. ``if dask``, ``if numpy``, ``if sparse`` etc.). Instead xarray can take the more permissive approach of simply +treating the wrapped array as valid, attempting to call the relevant methods (e.g. ``.mean()``) and only raising an +error if a problem occurs (e.g. the method is not found on the wrapped class). This is much more flexible, and allows +objects and classes from different libraries to work together more easily. + +What is a numpy-like array? +--------------------------- + +A "numpy-like array" (also known as a "duck array") is a class that contains array-like data, and implements key +numpy-like functionality such as indexing, broadcasting, and computation methods. + +For example, the `sparse `_ library provides a sparse array type which is useful for representing nD array objects like sparse matrices +in a memory-efficient manner. We can create a sparse array object (of the :py:class:`sparse.COO` type) from a numpy array like this: + +.. ipython:: python + + from sparse import COO + + x = np.eye(4, dtype=np.uint8) # create diagonal identity matrix + s = COO.from_numpy(x) + s -NumPy-like arrays (:term:`duck array`) extend the :py:class:`numpy.ndarray` with -additional features, like propagating physical units or a different layout in memory. +This sparse object does not attempt to explicitly store every element in the array, only the non-zero elements. +This approach is much more efficient for large arrays with only a few non-zero elements (such as tri-diagonal matrices). +Sparse array objects can be converted back to a "dense" numpy array by calling :py:meth:`sparse.COO.todense`. -:py:class:`DataArray` and :py:class:`Dataset` objects can wrap these duck arrays, as -long as they satisfy certain conditions (see :ref:`internals.duck_arrays`). +Just like :py:class:`numpy.ndarray` objects, :py:class:`sparse.COO` arrays support indexing + +.. ipython:: python + + s[1, 1] # diagonal elements should be ones + s[2, 3] # off-diagonal elements should be zero + +broadcasting, + +.. ipython:: python + + x2 = np.zeros( + (4, 1), dtype=np.uint8 + ) # create second sparse array of different shape + s2 = COO.from_numpy(x2) + (s * s2) # multiplication requires broadcasting + +and various computation methods + +.. ipython:: python + + s.sum(axis=1) + +This numpy-like array also supports calling so-called `numpy ufuncs `_ +("universal functions") on it directly: + +.. ipython:: python + + np.sum(s, axis=1) + + +Notice that in each case the API for calling the operation on the sparse array is identical to that of calling it on the +equivalent numpy array - this is the sense in which the sparse array is "numpy-like". .. note:: - For ``dask`` support see :ref:`dask`. + For discussion on exactly which methods a class needs to implement to be considered "numpy-like", see :ref:`internals.duckarrays`. + +Wrapping numpy-like arrays in xarray +------------------------------------ + +:py:class:`DataArray`, :py:class:`Dataset`, and :py:class:`Variable` objects can wrap these numpy-like arrays. +Constructing xarray objects which wrap numpy-like arrays +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Missing features ----------------- -Most of the API does support :term:`duck array` objects, but there are a few areas where -the code will still cast to ``numpy`` arrays: +The primary way to create an xarray object which wraps a numpy-like array is to pass that numpy-like array instance directly +to the constructor of the xarray class. The :ref:`page on xarray data structures ` shows how :py:class:`DataArray` and :py:class:`Dataset` +both accept data in various forms through their ``data`` argument, but in fact this data can also be any wrappable numpy-like array. -- dimension coordinates, and thus all indexing operations: +For example, we can wrap the sparse array we created earlier inside a new DataArray object: + +.. ipython:: python + + s_da = xr.DataArray(s, dims=["i", "j"]) + s_da + +We can see what's inside - the printable representation of our xarray object (the repr) automatically uses the printable +representation of the underlying wrapped array. + +Of course our sparse array object is still there underneath - it's stored under the ``.data`` attribute of the dataarray: + +.. ipython:: python + + s_da.data + +Array methods +~~~~~~~~~~~~~ + +We saw above that numpy-like arrays provide numpy methods. Xarray automatically uses these when you call the corresponding xarray method: + +.. ipython:: python + + s_da.sum(dim="j") + +Converting wrapped types +~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to change the type inside your xarray object you can use :py:meth:`DataArray.as_numpy`: + +.. ipython:: python + + s_da.as_numpy() + +This returns a new :py:class:`DataArray` object, but now wrapping a normal numpy array. + +If instead you want to convert to numpy and return that numpy array you can use either :py:meth:`DataArray.to_numpy` or +:py:meth:`DataArray.values`, where the former is strongly preferred. The difference is in the way they coerce to numpy - :py:meth:`~DataArray.values` +always uses :py:func:`numpy.asarray` which will fail for some array types (e.g. ``cupy``), whereas :py:meth:`~DataArray.to_numpy` +uses the correct method depending on the array type. + +.. ipython:: python + + s_da.to_numpy() + +.. ipython:: python + :okexcept: + + s_da.values + +This illustrates the difference between :py:meth:`~DataArray.data` and :py:meth:`~DataArray.values`, +which is sometimes a point of confusion for new xarray users. +Explicitly: :py:meth:`DataArray.data` returns the underlying numpy-like array, regardless of type, whereas +:py:meth:`DataArray.values` converts the underlying array to a numpy array before returning it. +(This is another reason to use :py:meth:`~DataArray.to_numpy` over :py:meth:`~DataArray.values` - the intention is clearer.) + +Conversion to numpy as a fallback +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If a wrapped array does not implement the corresponding array method then xarray will often attempt to convert the +underlying array to a numpy array so that the operation can be performed. You may want to watch out for this behavior, +and report any instances in which it causes problems. + +Most of xarray's API does support using :term:`duck array` objects, but there are a few areas where +the code will still convert to ``numpy`` arrays: + +- Dimension coordinates, and thus all indexing operations: * :py:meth:`Dataset.sel` and :py:meth:`DataArray.sel` * :py:meth:`Dataset.loc` and :py:meth:`DataArray.loc` @@ -33,7 +186,7 @@ the code will still cast to ``numpy`` arrays: :py:meth:`DataArray.reindex` and :py:meth:`DataArray.reindex_like`: duck arrays in data variables and non-dimension coordinates won't be casted -- functions and methods that depend on external libraries or features of ``numpy`` not +- Functions and methods that depend on external libraries or features of ``numpy`` not covered by ``__array_function__`` / ``__array_ufunc__``: * :py:meth:`Dataset.ffill` and :py:meth:`DataArray.ffill` (uses ``bottleneck``) @@ -49,17 +202,25 @@ the code will still cast to ``numpy`` arrays: :py:class:`numpy.vectorize`) * :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`) -- incompatibilities between different :term:`duck array` libraries: +- Incompatibilities between different :term:`duck array` libraries: * :py:meth:`Dataset.chunk` and :py:meth:`DataArray.chunk`: this fails if the data was not already chunked and the :term:`duck array` (e.g. a ``pint`` quantity) should - wrap the new ``dask`` array; changing the chunk sizes works. - + wrap the new ``dask`` array; changing the chunk sizes works however. Extensions using duck arrays ---------------------------- -Here's a list of libraries extending ``xarray`` to make working with wrapped duck arrays -easier: + +Whilst the features above allow many numpy-like array libraries to be used pretty seamlessly with xarray, it often also +makes sense to use an interfacing package to make certain tasks easier. + +For example the `pint-xarray package `_ offers a custom ``.pint`` accessor (see :ref:`internals.accessors`) which provides +convenient access to information stored within the wrapped array (e.g. ``.units`` and ``.magnitude``), and makes makes +creating wrapped pint arrays (and especially xarray-wrapping-pint-wrapping-dask arrays) simpler for the user. + +We maintain a list of libraries extending ``xarray`` to make working with particular wrapped duck arrays +easier. If you know of more that aren't on this list please raise an issue to add them! - `pint-xarray `_ - `cupy-xarray `_ +- `cubed-xarray `_ diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index dce20dce228..1ad2d52fc00 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -177,28 +177,18 @@ This last line is roughly equivalent to the following:: results.append(group - alt.sel(letters=label)) xr.concat(results, dim='x') -Squeezing -~~~~~~~~~ +Iterating and Squeezing +~~~~~~~~~~~~~~~~~~~~~~~ -When grouping over a dimension, you can control whether the dimension is -squeezed out or if it should remain with length one on each group by using -the ``squeeze`` parameter: - -.. ipython:: python - - next(iter(arr.groupby("x"))) +Previously, Xarray defaulted to squeezing out dimensions of size one when iterating over +a GroupBy object. This behaviour is being removed. +You can always squeeze explicitly later with the Dataset or DataArray +:py:meth:`~xarray.DataArray.squeeze` methods. .. ipython:: python next(iter(arr.groupby("x", squeeze=False))) -Although xarray will attempt to automatically -:py:attr:`~xarray.DataArray.transpose` dimensions back into their original order -when you use apply, it is sometimes useful to set ``squeeze=False`` to -guarantee that all original dimensions remain unchanged. - -You can always squeeze explicitly later with the Dataset or DataArray -:py:meth:`~xarray.DataArray.squeeze` methods. .. _groupby.multidim: diff --git a/doc/user-guide/index.rst b/doc/user-guide/index.rst index 0ac25d68930..45f0ce352de 100644 --- a/doc/user-guide/index.rst +++ b/doc/user-guide/index.rst @@ -25,4 +25,5 @@ examples that describe many common tasks that you can accomplish with xarray. dask plotting options + testing duckarrays diff --git a/doc/user-guide/indexing.rst b/doc/user-guide/indexing.rst index 492316f898f..fba9dd585ab 100644 --- a/doc/user-guide/indexing.rst +++ b/doc/user-guide/indexing.rst @@ -352,7 +352,6 @@ dimensions: ind_x = xr.DataArray([0, 1], dims=["x"]) ind_y = xr.DataArray([0, 1], dims=["y"]) da[ind_x, ind_y] # orthogonal indexing - da[ind_x, ind_x] # vectorized indexing Slices or sequences/arrays without named-dimensions are treated as if they have the same dimension which is indexed along: @@ -399,6 +398,12 @@ These methods may also be applied to ``Dataset`` objects Vectorized indexing may be used to extract information from the nearest grid cells of interest, for example, the nearest climate model grid cells to a collection specified weather station latitudes and longitudes. +To trigger vectorized indexing behavior +you will need to provide the selection dimensions with a new +shared output dimension name. In the example below, the selections +of the closest latitude and longitude are renamed to an output +dimension named "points": + .. ipython:: python @@ -544,6 +549,7 @@ __ https://numpy.org/doc/stable/user/basics.indexing.html#assigning-values-to-in You can also assign values to all variables of a :py:class:`Dataset` at once: .. ipython:: python + :okwarning: ds_org = xr.tutorial.open_dataset("eraint_uvz").isel( latitude=slice(56, 59), longitude=slice(255, 258), level=0 diff --git a/doc/user-guide/interpolation.rst b/doc/user-guide/interpolation.rst index 7b40962e826..311e1bf0129 100644 --- a/doc/user-guide/interpolation.rst +++ b/doc/user-guide/interpolation.rst @@ -292,8 +292,8 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. axes[0].set_title("Raw data") # Interpolated data - new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims["lon"] * 4) - new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims["lat"] * 4) + new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4) + new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4) dsi = ds.interp(lat=new_lat, lon=new_lon) dsi.air.plot(ax=axes[1]) @savefig interpolation_sample3.png width=8in diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index d5de181f562..48751c5f299 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -44,9 +44,9 @@ __ https://www.unidata.ucar.edu/software/netcdf/ .. _netCDF FAQ: https://www.unidata.ucar.edu/software/netcdf/docs/faq.html#What-Is-netCDF -Reading and writing netCDF files with xarray requires scipy or the -`netCDF4-Python`__ library to be installed (the latter is required to -read/write netCDF V4 files and use the compression options described below). +Reading and writing netCDF files with xarray requires scipy, h5netcdf, or the +`netCDF4-Python`__ library to be installed. SciPy only supports reading and writing +of netCDF V3 files. __ https://github.com/Unidata/netcdf4-python @@ -115,10 +115,7 @@ you try to perform some sort of actual computation. For an example of how these lazy arrays work, see the OPeNDAP section below. There may be minor differences in the :py:class:`Dataset` object returned -when reading a NetCDF file with different engines. For example, -single-valued attributes are returned as scalars by the default -``engine=netcdf4``, but as arrays of size ``(1,)`` when reading with -``engine=h5netcdf``. +when reading a NetCDF file with different engines. It is important to note that when you modify values of a Dataset, even one linked to files on disk, only the in-memory copy you are manipulating in xarray @@ -254,31 +251,22 @@ You can view this encoding information (among others) in the :py:attr:`DataArray.encoding` and :py:attr:`DataArray.encoding` attributes: -.. ipython:: - :verbatim: +.. ipython:: python - In [1]: ds_disk["y"].encoding - Out[1]: - {'zlib': False, - 'shuffle': False, - 'complevel': 0, - 'fletcher32': False, - 'contiguous': True, - 'chunksizes': None, - 'source': 'saved_on_disk.nc', - 'original_shape': (5,), - 'dtype': dtype('int64'), - 'units': 'days since 2000-01-01 00:00:00', - 'calendar': 'proleptic_gregorian'} - - In [9]: ds_disk.encoding - Out[9]: - {'unlimited_dims': set(), - 'source': 'saved_on_disk.nc'} + ds_disk["y"].encoding + ds_disk.encoding Note that all operations that manipulate variables other than indexing will remove encoding information. +In some cases it is useful to intentionally reset a dataset's original encoding values. +This can be done with either the :py:meth:`Dataset.drop_encoding` or +:py:meth:`DataArray.drop_encoding` methods. + +.. ipython:: python + + ds_no_encoding = ds_disk.drop_encoding() + ds_no_encoding.encoding .. _combining multiple files: @@ -568,6 +556,67 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: Note that this produces a file that is likely to be not readable by other netCDF libraries! +.. _io.hdf5: + +HDF5 +---- +`HDF5`_ is both a file format and a data model for storing information. HDF5 stores +data hierarchically, using groups to create a nested structure. HDF5 is a more +general version of the netCDF4 data model, so the nested structure is one of many +similarities between the two data formats. + +Reading HDF5 files in xarray requires the ``h5netcdf`` engine, which can be installed +with ``conda install h5netcdf``. Once installed we can use xarray to open HDF5 files: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5") + +The similarities between HDF5 and netCDF4 mean that HDF5 data can be written with the +same :py:meth:`Dataset.to_netcdf` method as used for netCDF4 data: + +.. ipython:: python + + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + + ds.to_netcdf("saved_on_disk.h5") + +Groups +~~~~~~ + +If you have multiple or highly nested groups, xarray by default may not read the group +that you want. A particular group of an HDF5 file can be specified using the ``group`` +argument: + +.. code:: python + + xr.open_dataset("/path/to/my/file.h5", group="/my/group") + +While xarray cannot interrogate an HDF5 file to determine which groups are available, +the HDF5 Python reader `h5py`_ can be used instead. + +Natively the xarray data structures can only handle one level of nesting, organized as +DataArrays inside of Datasets. If your HDF5 file has additional levels of hierarchy you +can only access one group and a time and will need to specify group names. + +.. note:: + + For native handling of multiple HDF5 groups with xarray, including I/O, you might be + interested in the experimental + `xarray-datatree `_ package. + + +.. _HDF5: https://hdfgroup.github.io/hdf5/index.html +.. _h5py: https://www.h5py.org/ + + .. _io.zarr: Zarr @@ -617,10 +666,17 @@ store is already present at that path, an error will be raised, preventing it from being overwritten. To override this behavior and overwrite an existing store, add ``mode='w'`` when invoking :py:meth:`~Dataset.to_zarr`. +DataArrays can also be saved to disk using the :py:meth:`DataArray.to_zarr` method, +and loaded from disk using the :py:func:`open_dataarray` function with `engine='zarr'`. +Similar to :py:meth:`DataArray.to_netcdf`, :py:meth:`DataArray.to_zarr` will +convert the ``DataArray`` to a ``Dataset`` before saving, and then convert back +when loading, ensuring that the ``DataArray`` that is loaded is always exactly +the same as the one that was saved. + .. note:: - xarray does not write NCZarr attributes. Therefore, NCZarr data must be - opened in read-only mode. + xarray does not write `NCZarr `_ attributes. + Therefore, NCZarr data must be opened in read-only mode. To store variable length strings, convert them to object arrays first with ``dtype=object``. @@ -640,10 +696,10 @@ It is possible to read and write xarray datasets directly from / to cloud storage buckets using zarr. This example uses the `gcsfs`_ package to provide an interface to `Google Cloud Storage`_. -From v0.16.2: general `fsspec`_ URLs are parsed and the store set up for you -automatically when reading, such that you can open a dataset in a single -call. You should include any arguments to the storage backend as the -key ``storage_options``, part of ``backend_kwargs``. +General `fsspec`_ URLs, those that begin with ``s3://`` or ``gcs://`` for example, +are parsed and the store set up for you automatically when reading. +You should include any arguments to the storage backend as the +key ```storage_options``, part of ``backend_kwargs``. .. code:: python @@ -659,7 +715,7 @@ key ``storage_options``, part of ``backend_kwargs``. This also works with ``open_mfdataset``, allowing you to pass a list of paths or a URL to be interpreted as a glob string. -For older versions, and for writing, you must explicitly set up a ``MutableMapping`` +For writing, you must explicitly set up a ``MutableMapping`` instance and pass this, as follows: .. code:: python @@ -713,10 +769,10 @@ Consolidated Metadata ~~~~~~~~~~~~~~~~~~~~~ Xarray needs to read all of the zarr metadata when it opens a dataset. -In some storage mediums, such as with cloud object storage (e.g. amazon S3), +In some storage mediums, such as with cloud object storage (e.g. `Amazon S3`_), this can introduce significant overhead, because two separate HTTP calls to the object store must be made for each variable in the dataset. -As of xarray version 0.18, xarray by default uses a feature called +By default Xarray uses a feature called *consolidated metadata*, storing all metadata for the entire dataset with a single key (by default called ``.zmetadata``). This typically drastically speeds up opening the store. (For more information on this feature, consult the @@ -740,16 +796,20 @@ reads. Because this fall-back option is so much slower, xarray issues a .. _io.zarr.appending: -Appending to existing Zarr stores -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Modifying existing Zarr stores +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Xarray supports several ways of incrementally writing variables to a Zarr store. These options are useful for scenarios when it is infeasible or undesirable to write your entire dataset at once. +1. Use ``mode='a'`` to add or overwrite entire variables, +2. Use ``append_dim`` to resize and append to existing variables, and +3. Use ``region`` to write to limited regions of existing arrays. + .. tip:: - If you can load all of your data into a single ``Dataset`` using dask, a + For ``Dataset`` objects containing dask arrays, a single call to ``to_zarr()`` will write all of your data in parallel. .. warning:: @@ -763,7 +823,7 @@ with ``mode='a'`` on a Dataset containing the new variables, passing in an existing Zarr store or path to a Zarr store. To resize and then append values along an existing dimension in a store, set -``append_dim``. This is a good option if data always arives in a particular +``append_dim``. This is a good option if data always arrives in a particular order, e.g., for time-stepping a simulation: .. ipython:: python @@ -820,17 +880,20 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata ds.to_zarr(path, compute=False) Now, a Zarr store with the correct variable shapes and attributes exists that -can be filled out by subsequent calls to ``to_zarr``. The ``region`` provides a -mapping from dimension names to Python ``slice`` objects indicating where the -data should be written (in index space, not coordinate space), e.g., +can be filled out by subsequent calls to ``to_zarr``. +Setting ``region="auto"`` will open the existing store and determine the +correct alignment of the new data with the existing coordinates, or as an +explicit mapping from dimension names to Python ``slice`` objects indicating +where the data should be written (in index space, not label space), e.g., .. ipython:: python # For convenience, we'll slice a single dataset, but in the real use-case # we would create them separately possibly even from separate processes. ds = xr.Dataset({"foo": ("x", np.arange(30))}) - ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)}) - ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)}) + # Any of the following region specifications are valid + ds.isel(x=slice(0, 10)).to_zarr(path, region="auto") + ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"}) ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)}) Concurrent writes with ``region`` are safe as long as they modify distinct @@ -1157,46 +1220,7 @@ search indices or other automated data discovery tools. Rasterio -------- -GeoTIFFs and other gridded raster datasets can be opened using `rasterio`_, if -rasterio is installed. Here is an example of how to use -:py:func:`open_rasterio` to read one of rasterio's `test files`_: - -.. deprecated:: 0.20.0 - - Deprecated in favor of rioxarray. - For information about transitioning, see: - `rioxarray getting started docs`` - -.. ipython:: - :verbatim: - - In [7]: rio = xr.open_rasterio("RGB.byte.tif") - - In [8]: rio - Out[8]: - - [1703814 values with dtype=uint8] - Coordinates: - * band (band) int64 1 2 3 - * y (y) float64 2.827e+06 2.826e+06 2.826e+06 2.826e+06 2.826e+06 ... - * x (x) float64 1.021e+05 1.024e+05 1.027e+05 1.03e+05 1.033e+05 ... - Attributes: - res: (300.0379266750948, 300.041782729805) - transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.041782729805, 28... - is_tiled: 0 - crs: +init=epsg:32618 - - -The ``x`` and ``y`` coordinates are generated out of the file's metadata -(``bounds``, ``width``, ``height``), and they can be understood as cartesian -coordinates defined in the file's projection provided by the ``crs`` attribute. -``crs`` is a PROJ4 string which can be parsed by e.g. `pyproj`_ or rasterio. -See :ref:`/examples/visualization_gallery.ipynb#Parsing-rasterio-geocoordinates` -for an example of how to convert these to longitudes and latitudes. - - -Additionally, you can use `rioxarray`_ for reading in GeoTiff, netCDF or other -GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIFF. +GDAL readable raster data using `rasterio`_ such as GeoTIFFs can be opened using the `rioxarray`_ extension. `rioxarray`_ can also handle geospatial related tasks such as re-projecting and clipping. .. ipython:: @@ -1291,27 +1315,6 @@ We recommend installing PyNIO via conda:: .. _PyNIO backend is deprecated: https://github.com/pydata/xarray/issues/4491 .. _PyNIO is no longer maintained: https://github.com/NCAR/pynio/issues/53 -.. _io.PseudoNetCDF: - -Formats supported by PseudoNetCDF ---------------------------------- - -Xarray can also read CAMx, BPCH, ARL PACKED BIT, and many other file -formats supported by PseudoNetCDF_, if PseudoNetCDF is installed. -PseudoNetCDF can also provide Climate Forecasting Conventions to -CMAQ files. In addition, PseudoNetCDF can automatically register custom -readers that subclass PseudoNetCDF.PseudoNetCDFFile. PseudoNetCDF can -identify readers either heuristically, or by a format specified via a key in -`backend_kwargs`. - -To use PseudoNetCDF to read such files, supply -``engine='pseudonetcdf'`` to :py:func:`open_dataset`. - -Add ``backend_kwargs={'format': ''}`` where `` -options are listed on the PseudoNetCDF page. - -.. _PseudoNetCDF: https://github.com/barronh/PseudoNetCDF - CSV and other formats supported by pandas ----------------------------------------- diff --git a/doc/user-guide/reshaping.rst b/doc/user-guide/reshaping.rst index 95bf21a71b0..14b343549e2 100644 --- a/doc/user-guide/reshaping.rst +++ b/doc/user-guide/reshaping.rst @@ -4,7 +4,12 @@ Reshaping and reorganizing data ############################### -These methods allow you to reorganize your data by changing dimensions, array shape, order of values, or indexes. +Reshaping and reorganizing data refers to the process of changing the structure or organization of data by modifying dimensions, array shapes, order of values, or indexes. Xarray provides several methods to accomplish these tasks. + +These methods are particularly useful for reshaping xarray objects for use in machine learning packages, such as scikit-learn, that usually require two-dimensional numpy arrays as inputs. Reshaping can also be required before passing data to external visualization tools, for example geospatial data might expect input organized into a particular format corresponding to stacks of satellite images. + +Importing the library +--------------------- .. ipython:: python :suppress: @@ -54,11 +59,11 @@ use :py:meth:`~xarray.DataArray.squeeze` Converting between datasets and arrays -------------------------------------- -To convert from a Dataset to a DataArray, use :py:meth:`~xarray.Dataset.to_array`: +To convert from a Dataset to a DataArray, use :py:meth:`~xarray.Dataset.to_dataarray`: .. ipython:: python - arr = ds.to_array() + arr = ds.to_dataarray() arr This method broadcasts all data variables in the dataset against each other, @@ -72,7 +77,7 @@ To convert back from a DataArray to a Dataset, use arr.to_dataset(dim="variable") -The broadcasting behavior of ``to_array`` means that the resulting array +The broadcasting behavior of ``to_dataarray`` means that the resulting array includes the union of data variable dimensions: .. ipython:: python @@ -83,7 +88,7 @@ includes the union of data variable dimensions: ds2 # the resulting array has 6 elements - ds2.to_array() + ds2.to_dataarray() Otherwise, the result could not be represented as an orthogonal array. @@ -156,8 +161,8 @@ arrays as inputs. For datasets with only one variable, we only need ``stack`` and ``unstack``, but combining multiple variables in a :py:class:`xarray.Dataset` is more complicated. If the variables in the dataset have matching numbers of dimensions, we can call -:py:meth:`~xarray.Dataset.to_array` and then stack along the the new coordinate. -But :py:meth:`~xarray.Dataset.to_array` will broadcast the dataarrays together, +:py:meth:`~xarray.Dataset.to_dataarray` and then stack along the the new coordinate. +But :py:meth:`~xarray.Dataset.to_dataarray` will broadcast the dataarrays together, which will effectively tile the lower dimensional variable along the missing dimensions. The method :py:meth:`xarray.Dataset.to_stacked_array` allows combining variables of differing dimensions without this wasteful copying while @@ -269,7 +274,7 @@ Sort ---- One may sort a DataArray/Dataset via :py:meth:`~xarray.DataArray.sortby` and -:py:meth:`~xarray.DataArray.sortby`. The input can be an individual or list of +:py:meth:`~xarray.Dataset.sortby`. The input can be an individual or list of 1D ``DataArray`` objects: .. ipython:: python diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst index 24e6ab69927..55937310827 100644 --- a/doc/user-guide/terminology.rst +++ b/doc/user-guide/terminology.rst @@ -47,30 +47,28 @@ complete examples, please consult the relevant documentation.* all but one of these degrees of freedom is fixed. We can think of each dimension axis as having a name, for example the "x dimension". In xarray, a ``DataArray`` object's *dimensions* are its named dimension - axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an - array is created without dimension names, the default dimension names are - ``dim_0``, ``dim_1``, and so forth. + axes ``da.dims``, and the name of the ``i``-th dimension is ``da.dims[i]``. + If an array is created without specifying dimension names, the default dimension + names will be ``dim_0``, ``dim_1``, and so forth. Coordinate An array that labels a dimension or set of dimensions of another ``DataArray``. In the usual one-dimensional case, the coordinate array's - values can loosely be thought of as tick labels along a dimension. There - are two types of coordinate arrays: *dimension coordinates* and - *non-dimension coordinates* (see below). A coordinate named ``x`` can be - retrieved from ``arr.coords[x]``. A ``DataArray`` can have more - coordinates than dimensions because a single dimension can be labeled by - multiple coordinate arrays. However, only one coordinate array can be a - assigned as a particular dimension's dimension coordinate array. As a - consequence, ``len(arr.dims) <= len(arr.coords)`` in general. + values can loosely be thought of as tick labels along a dimension. We + distinguish :term:`Dimension coordinate` vs. :term:`Non-dimension + coordinate` and :term:`Indexed coordinate` vs. :term:`Non-indexed + coordinate`. A coordinate named ``x`` can be retrieved from + ``arr.coords[x]``. A ``DataArray`` can have more coordinates than + dimensions because a single dimension can be labeled by multiple + coordinate arrays. However, only one coordinate array can be a assigned + as a particular dimension's dimension coordinate array. Dimension coordinate A one-dimensional coordinate array assigned to ``arr`` with both a name - and dimension name in ``arr.dims``. Dimension coordinates are used for - label-based indexing and alignment, like the index found on a - :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. In fact, - dimension coordinates use :py:class:`pandas.Index` objects under the - hood for efficient computation. Dimension coordinates are marked by - ``*`` when printing a ``DataArray`` or ``Dataset``. + and dimension name in ``arr.dims``. Usually (but not always), a + dimension coordinate is also an :term:`Indexed coordinate` so that it can + be used for label-based indexing and alignment, like the index found on + a :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. Non-dimension coordinate A coordinate array assigned to ``arr`` with a name in ``arr.coords`` but @@ -79,20 +77,40 @@ complete examples, please consult the relevant documentation.* example, multidimensional coordinates are often used in geoscience datasets when :doc:`the data's physical coordinates (such as latitude and longitude) differ from their logical coordinates - <../examples/multidimensional-coords>`. However, non-dimension coordinates - are not indexed, and any operation on non-dimension coordinates that - leverages indexing will fail. Printing ``arr.coords`` will print all of - ``arr``'s coordinate names, with the corresponding dimension(s) in - parentheses. For example, ``coord_name (dim_name) 1 2 3 ...``. + <../examples/multidimensional-coords>`. Printing ``arr.coords`` will + print all of ``arr``'s coordinate names, with the corresponding + dimension(s) in parentheses. For example, ``coord_name (dim_name) 1 2 3 + ...``. + + Indexed coordinate + A coordinate which has an associated :term:`Index`. Generally this means + that the coordinate labels can be used for indexing (selection) and/or + alignment. An indexed coordinate may have one or more arbitrary + dimensions although in most cases it is also a :term:`Dimension + coordinate`. It may or may not be grouped with other indexed coordinates + depending on whether they share the same index. Indexed coordinates are + marked by an asterisk ``*`` when printing a ``DataArray`` or ``Dataset``. + + Non-indexed coordinate + A coordinate which has no associated :term:`Index`. It may still + represent fixed labels along one or more dimensions but it cannot be + used for label-based indexing and alignment. Index - An *index* is a data structure optimized for efficient selecting and - slicing of an associated array. Xarray creates indexes for dimension - coordinates so that operations along dimensions are fast, while - non-dimension coordinates are not indexed. Under the hood, indexes are - implemented as :py:class:`pandas.Index` objects. The index associated - with dimension name ``x`` can be retrieved by ``arr.indexes[x]``. By - construction, ``len(arr.dims) == len(arr.indexes)`` + An *index* is a data structure optimized for efficient data selection + and alignment within a discrete or continuous space that is defined by + coordinate labels (unless it is a functional index). By default, Xarray + creates a :py:class:`~xarray.indexes.PandasIndex` object (i.e., a + :py:class:`pandas.Index` wrapper) for each :term:`Dimension coordinate`. + For more advanced use cases (e.g., staggered or irregular grids, + geospatial indexes), Xarray also accepts any instance of a specialized + :py:class:`~xarray.indexes.Index` subclass that is associated to one or + more arbitrary coordinates. The index associated with the coordinate + ``x`` can be retrieved by ``arr.xindexes[x]`` (or ``arr.indexes["x"]`` + if the index is convertible to a :py:class:`pandas.Index` object). If + two coordinates ``x`` and ``y`` share the same index, + ``arr.xindexes[x]`` and ``arr.xindexes[y]`` both return the same + :py:class:`~xarray.indexes.Index` object. name The names of dimensions, coordinates, DataArray objects and data @@ -112,3 +130,128 @@ complete examples, please consult the relevant documentation.* ``__array_ufunc__`` and ``__array_function__`` protocols are also required. __ https://numpy.org/neps/nep-0022-ndarray-duck-typing-overview.html + + .. ipython:: python + :suppress: + + import numpy as np + import xarray as xr + + Aligning + Aligning refers to the process of ensuring that two or more DataArrays or Datasets + have the same dimensions and coordinates, so that they can be combined or compared properly. + + .. ipython:: python + + x = xr.DataArray( + [[25, 35], [10, 24]], + dims=("lat", "lon"), + coords={"lat": [35.0, 40.0], "lon": [100.0, 120.0]}, + ) + y = xr.DataArray( + [[20, 5], [7, 13]], + dims=("lat", "lon"), + coords={"lat": [35.0, 42.0], "lon": [100.0, 120.0]}, + ) + x + y + + Broadcasting + A technique that allows operations to be performed on arrays with different shapes and dimensions. + When performing operations on arrays with different shapes and dimensions, xarray will automatically attempt to broadcast the + arrays to a common shape before the operation is applied. + + .. ipython:: python + + # 'a' has shape (3,) and 'b' has shape (4,) + a = xr.DataArray(np.array([1, 2, 3]), dims=["x"]) + b = xr.DataArray(np.array([4, 5, 6, 7]), dims=["y"]) + + # 2D array with shape (3, 4) + a + b + + Merging + Merging is used to combine two or more Datasets or DataArrays that have different variables or coordinates along + the same dimensions. When merging, xarray aligns the variables and coordinates of the different datasets along + the specified dimensions and creates a new ``Dataset`` containing all the variables and coordinates. + + .. ipython:: python + + # create two 1D arrays with names + arr1 = xr.DataArray( + [1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]}, name="arr1" + ) + arr2 = xr.DataArray( + [4, 5, 6], dims=["x"], coords={"x": [20, 30, 40]}, name="arr2" + ) + + # merge the two arrays into a new dataset + merged_ds = xr.Dataset({"arr1": arr1, "arr2": arr2}) + merged_ds + + Concatenating + Concatenating is used to combine two or more Datasets or DataArrays along a dimension. When concatenating, + xarray arranges the datasets or dataarrays along a new dimension, and the resulting ``Dataset`` or ``Dataarray`` + will have the same variables and coordinates along the other dimensions. + + .. ipython:: python + + a = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y")) + b = xr.DataArray([[5, 6], [7, 8]], dims=("x", "y")) + c = xr.concat([a, b], dim="c") + c + + Combining + Combining is the process of arranging two or more DataArrays or Datasets into a single ``DataArray`` or + ``Dataset`` using some combination of merging and concatenation operations. + + .. ipython:: python + + ds1 = xr.Dataset( + {"data": xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"))}, + coords={"x": [1, 2], "y": [3, 4]}, + ) + ds2 = xr.Dataset( + {"data": xr.DataArray([[5, 6], [7, 8]], dims=("x", "y"))}, + coords={"x": [2, 3], "y": [4, 5]}, + ) + + # combine the datasets + combined_ds = xr.combine_by_coords([ds1, ds2]) + combined_ds + + lazy + Lazily-evaluated operations do not load data into memory until necessary.Instead of doing calculations + right away, xarray lets you plan what calculations you want to do, like finding the + average temperature in a dataset.This planning is called "lazy evaluation." Later, when + you're ready to see the final result, you tell xarray, "Okay, go ahead and do those calculations now!" + That's when xarray starts working through the steps you planned and gives you the answer you wanted.This + lazy approach helps save time and memory because xarray only does the work when you actually need the + results. + + labeled + Labeled data has metadata describing the context of the data, not just the raw data values. + This contextual information can be labels for array axes (i.e. dimension names) tick labels along axes (stored as Coordinate variables) or unique names for each array. These labels + provide context and meaning to the data, making it easier to understand and work with. If you have + temperature data for different cities over time. Using xarray, you can label the dimensions: one for + cities and another for time. + + serialization + Serialization is the process of converting your data into a format that makes it easy to save and share. + When you serialize data in xarray, you're taking all those temperature measurements, along with their + labels and other information, and turning them into a format that can be stored in a file or sent over + the internet. xarray objects can be serialized into formats which store the labels alongside the data. + Some supported serialization formats are files that can then be stored or transferred (e.g. netCDF), + whilst others are protocols that allow for data access over a network (e.g. Zarr). + + indexing + :ref:`Indexing` is how you select subsets of your data which you are interested in. + + - Label-based Indexing: Selecting data by passing a specific label and comparing it to the labels + stored in the associated coordinates. You can use labels to specify what you want like "Give me the + temperature for New York on July 15th." + + - Positional Indexing: You can use numbers to refer to positions in the data like "Give me the third temperature value" This is useful when you know the order of your data but don't need to remember the exact labels. + + - Slicing: You can take a "slice" of your data, like you might want all temperatures from July 1st + to July 10th. xarray supports slicing for both positional and label-based indexing. diff --git a/doc/user-guide/testing.rst b/doc/user-guide/testing.rst new file mode 100644 index 00000000000..13279eccb0b --- /dev/null +++ b/doc/user-guide/testing.rst @@ -0,0 +1,303 @@ +.. _testing: + +Testing your code +================= + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + + np.random.seed(123456) + +.. _testing.hypothesis: + +Hypothesis testing +------------------ + +.. note:: + + Testing with hypothesis is a fairly advanced topic. Before reading this section it is recommended that you take a look + at our guide to xarray's :ref:`data structures`, are familiar with conventional unit testing in + `pytest `_, and have seen the + `hypothesis library documentation `_. + +`The hypothesis library `_ is a powerful tool for property-based testing. +Instead of writing tests for one example at a time, it allows you to write tests parameterized by a source of many +dynamically generated examples. For example you might have written a test which you wish to be parameterized by the set +of all possible integers via :py:func:`hypothesis.strategies.integers()`. + +Property-based testing is extremely powerful, because (unlike more conventional example-based testing) it can find bugs +that you did not even think to look for! + +Strategies +~~~~~~~~~~ + +Each source of examples is called a "strategy", and xarray provides a range of custom strategies which produce xarray +data structures containing arbitrary data. You can use these to efficiently test downstream code, +quickly ensuring that your code can handle xarray objects of all possible structures and contents. + +These strategies are accessible in the :py:mod:`xarray.testing.strategies` module, which provides + +.. currentmodule:: xarray + +.. autosummary:: + + testing.strategies.supported_dtypes + testing.strategies.names + testing.strategies.dimension_names + testing.strategies.dimension_sizes + testing.strategies.attrs + testing.strategies.variables + testing.strategies.unique_subset_of + +These build upon the numpy and array API strategies offered in :py:mod:`hypothesis.extra.numpy` and :py:mod:`hypothesis.extra.array_api`: + +.. ipython:: python + + import hypothesis.extra.numpy as npst + +Generating Examples +~~~~~~~~~~~~~~~~~~~ + +To see an example of what each of these strategies might produce, you can call one followed by the ``.example()`` method, +which is a general hypothesis method valid for all strategies. + +.. ipython:: python + + import xarray.testing.strategies as xrst + + xrst.variables().example() + xrst.variables().example() + xrst.variables().example() + +You can see that calling ``.example()`` multiple times will generate different examples, giving you an idea of the wide +range of data that the xarray strategies can generate. + +In your tests however you should not use ``.example()`` - instead you should parameterize your tests with the +:py:func:`hypothesis.given` decorator: + +.. ipython:: python + + from hypothesis import given + +.. ipython:: python + + @given(xrst.variables()) + def test_function_that_acts_on_variables(var): + assert func(var) == ... + + +Chaining Strategies +~~~~~~~~~~~~~~~~~~~ + +Xarray's strategies can accept other strategies as arguments, allowing you to customise the contents of the generated +examples. + +.. ipython:: python + + # generate a Variable containing an array with a complex number dtype, but all other details still arbitrary + from hypothesis.extra.numpy import complex_number_dtypes + + xrst.variables(dtype=complex_number_dtypes()).example() + +This also works with custom strategies, or strategies defined in other packages. +For example you could imagine creating a ``chunks`` strategy to specify particular chunking patterns for a dask-backed array. + +Fixing Arguments +~~~~~~~~~~~~~~~~ + +If you want to fix one aspect of the data structure, whilst allowing variation in the generated examples +over all other aspects, then use :py:func:`hypothesis.strategies.just()`. + +.. ipython:: python + + import hypothesis.strategies as st + + # Generates only variable objects with dimensions ["x", "y"] + xrst.variables(dims=st.just(["x", "y"])).example() + +(This is technically another example of chaining strategies - :py:func:`hypothesis.strategies.just()` is simply a +special strategy that just contains a single example.) + +To fix the length of dimensions you can instead pass ``dims`` as a mapping of dimension names to lengths +(i.e. following xarray objects' ``.sizes()`` property), e.g. + +.. ipython:: python + + # Generates only variables with dimensions ["x", "y"], of lengths 2 & 3 respectively + xrst.variables(dims=st.just({"x": 2, "y": 3})).example() + +You can also use this to specify that you want examples which are missing some part of the data structure, for instance + +.. ipython:: python + + # Generates a Variable with no attributes + xrst.variables(attrs=st.just({})).example() + +Through a combination of chaining strategies and fixing arguments, you can specify quite complicated requirements on the +objects your chained strategy will generate. + +.. ipython:: python + + fixed_x_variable_y_maybe_z = st.fixed_dictionaries( + {"x": st.just(2), "y": st.integers(3, 4)}, optional={"z": st.just(2)} + ) + fixed_x_variable_y_maybe_z.example() + + special_variables = xrst.variables(dims=fixed_x_variable_y_maybe_z) + + special_variables.example() + special_variables.example() + +Here we have used one of hypothesis' built-in strategies :py:func:`hypothesis.strategies.fixed_dictionaries` to create a +strategy which generates mappings of dimension names to lengths (i.e. the ``size`` of the xarray object we want). +This particular strategy will always generate an ``x`` dimension of length 2, and a ``y`` dimension of +length either 3 or 4, and will sometimes also generate a ``z`` dimension of length 2. +By feeding this strategy for dictionaries into the ``dims`` argument of xarray's :py:func:`~st.variables` strategy, +we can generate arbitrary :py:class:`~xarray.Variable` objects whose dimensions will always match these specifications. + +Generating Duck-type Arrays +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray objects don't have to wrap numpy arrays, in fact they can wrap any array type which presents the same API as a +numpy array (so-called "duck array wrapping", see :ref:`wrapping numpy-like arrays `). + +Imagine we want to write a strategy which generates arbitrary ``Variable`` objects, each of which wraps a +:py:class:`sparse.COO` array instead of a ``numpy.ndarray``. How could we do that? There are two ways: + +1. Create a xarray object with numpy data and use the hypothesis' ``.map()`` method to convert the underlying array to a +different type: + +.. ipython:: python + + import sparse + +.. ipython:: python + + def convert_to_sparse(var): + return var.copy(data=sparse.COO.from_numpy(var.to_numpy())) + +.. ipython:: python + + sparse_variables = xrst.variables(dims=xrst.dimension_names(min_dims=1)).map( + convert_to_sparse + ) + + sparse_variables.example() + sparse_variables.example() + +2. Pass a function which returns a strategy which generates the duck-typed arrays directly to the ``array_strategy_fn`` argument of the xarray strategies: + +.. ipython:: python + + def sparse_random_arrays(shape: tuple[int]) -> sparse._coo.core.COO: + """Strategy which generates random sparse.COO arrays""" + if shape is None: + shape = npst.array_shapes() + else: + shape = st.just(shape) + density = st.integers(min_value=0, max_value=1) + # note sparse.random does not accept a dtype kwarg + return st.builds(sparse.random, shape=shape, density=density) + + + def sparse_random_arrays_fn( + *, shape: tuple[int, ...], dtype: np.dtype + ) -> st.SearchStrategy[sparse._coo.core.COO]: + return sparse_random_arrays(shape=shape) + + +.. ipython:: python + + sparse_random_variables = xrst.variables( + array_strategy_fn=sparse_random_arrays_fn, dtype=st.just(np.dtype("float64")) + ) + sparse_random_variables.example() + +Either approach is fine, but one may be more convenient than the other depending on the type of the duck array which you +want to wrap. + +Compatibility with the Python Array API Standard +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray aims to be compatible with any duck-array type that conforms to the `Python Array API Standard `_ +(see our :ref:`docs on Array API Standard support `). + +.. warning:: + + The strategies defined in :py:mod:`testing.strategies` are **not** guaranteed to use array API standard-compliant + dtypes by default. + For example arrays with the dtype ``np.dtype('float16')`` may be generated by :py:func:`testing.strategies.variables` + (assuming the ``dtype`` kwarg was not explicitly passed), despite ``np.dtype('float16')`` not being in the + array API standard. + +If the array type you want to generate has an array API-compliant top-level namespace +(e.g. that which is conventionally imported as ``xp`` or similar), +you can use this neat trick: + +.. ipython:: python + :okwarning: + + from numpy import array_api as xp # available in numpy 1.26.0 + + from hypothesis.extra.array_api import make_strategies_namespace + + xps = make_strategies_namespace(xp) + + xp_variables = xrst.variables( + array_strategy_fn=xps.arrays, + dtype=xps.scalar_dtypes(), + ) + xp_variables.example() + +Another array API-compliant duck array library would replace the import, e.g. ``import cupy as cp`` instead. + +Testing over Subsets of Dimensions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A common task when testing xarray user code is checking that your function works for all valid input dimensions. +We can chain strategies to achieve this, for which the helper strategy :py:func:`~testing.strategies.unique_subset_of` +is useful. + +It works for lists of dimension names + +.. ipython:: python + + dims = ["x", "y", "z"] + xrst.unique_subset_of(dims).example() + xrst.unique_subset_of(dims).example() + +as well as for mappings of dimension names to sizes + +.. ipython:: python + + dim_sizes = {"x": 2, "y": 3, "z": 4} + xrst.unique_subset_of(dim_sizes).example() + xrst.unique_subset_of(dim_sizes).example() + +This is useful because operations like reductions can be performed over any subset of the xarray object's dimensions. +For example we can write a pytest test that tests that a reduction gives the expected result when applying that reduction +along any possible valid subset of the Variable's dimensions. + +.. code-block:: python + + import numpy.testing as npt + + + @given(st.data(), xrst.variables(dims=xrst.dimension_names(min_dims=1))) + def test_mean(data, var): + """Test that the mean of an xarray Variable is always equal to the mean of the underlying array.""" + + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(xrst.unique_subset_of(var.dims, min_size=1)) + + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) + + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/doc/user-guide/time-series.rst b/doc/user-guide/time-series.rst index d2e15adeba7..82172aa8998 100644 --- a/doc/user-guide/time-series.rst +++ b/doc/user-guide/time-series.rst @@ -89,7 +89,7 @@ items and with the `slice` object: .. ipython:: python - time = pd.date_range("2000-01-01", freq="H", periods=365 * 24) + time = pd.date_range("2000-01-01", freq="h", periods=365 * 24) ds = xr.Dataset({"foo": ("time", np.arange(365 * 24)), "time": time}) ds.sel(time="2000-01") ds.sel(time=slice("2000-06-01", "2000-06-10")) @@ -115,7 +115,7 @@ given ``DataArray`` can be quickly computed using a special ``.dt`` accessor. .. ipython:: python - time = pd.date_range("2000-01-01", freq="6H", periods=365 * 4) + time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4) ds = xr.Dataset({"foo": ("time", np.arange(365 * 4)), "time": time}) ds.time.dt.hour ds.time.dt.dayofweek @@ -207,7 +207,7 @@ For example, we can downsample our dataset from hourly to 6-hourly: .. ipython:: python :okwarning: - ds.resample(time="6H") + ds.resample(time="6h") This will create a specialized ``Resample`` object which saves information necessary for resampling. All of the reduction methods which work with @@ -216,14 +216,21 @@ necessary for resampling. All of the reduction methods which work with .. ipython:: python :okwarning: - ds.resample(time="6H").mean() + ds.resample(time="6h").mean() You can also supply an arbitrary reduction function to aggregate over each resampling group: .. ipython:: python - ds.resample(time="6H").reduce(np.mean) + ds.resample(time="6h").reduce(np.mean) + +You can also resample on the time dimension while applying reducing along other dimensions at the same time +by specifying the `dim` keyword argument + +.. code-block:: python + + ds.resample(time="6h").mean(dim=["time", "latitude", "longitude"]) For upsampling, xarray provides six methods: ``asfreq``, ``ffill``, ``bfill``, ``pad``, ``nearest`` and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` @@ -236,8 +243,20 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. .. ipython:: python - ds.resample(time="1H").nearest(tolerance="1H") + ds.resample(time="1h").nearest(tolerance="1h") + +It is often desirable to center the time values after a resampling operation. +That can be accomplished by updating the resampled dataset time coordinate values +using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ function. + +.. _pandas.tseries.frequencies.to_offset: https://pandas.pydata.org/docs/reference/api/pandas.tseries.frequencies.to_offset.html + +.. ipython:: python + resampled_ds = ds.resample(time="6h").mean() + offset = pd.tseries.frequencies.to_offset("6h") / 2 + resampled_ds["time"] = resampled_ds.get_index("time") + offset + resampled_ds For more examples of using grouped operations on a time dimension, see :doc:`../examples/weather-data`. diff --git a/doc/user-guide/weather-climate.rst b/doc/user-guide/weather-climate.rst index 30876eb36bc..5014f5a8641 100644 --- a/doc/user-guide/weather-climate.rst +++ b/doc/user-guide/weather-climate.rst @@ -57,14 +57,14 @@ CF-compliant coordinate variables .. _CFTimeIndex: -Non-standard calendars and dates outside the Timestamp-valid range ------------------------------------------------------------------- +Non-standard calendars and dates outside the nanosecond-precision range +----------------------------------------------------------------------- Through the standalone ``cftime`` library and a custom subclass of :py:class:`pandas.Index`, xarray supports a subset of the indexing functionality enabled through the standard :py:class:`pandas.DatetimeIndex` for dates from non-standard calendars commonly used in climate science or dates -using a standard calendar, but outside the `Timestamp-valid range`_ +using a standard calendar, but outside the `nanosecond-precision range`_ (approximately between years 1678 and 2262). .. note:: @@ -75,13 +75,19 @@ using a standard calendar, but outside the `Timestamp-valid range`_ any of the following are true: - The dates are from a non-standard calendar - - Any dates are outside the Timestamp-valid range. + - Any dates are outside the nanosecond-precision range. Otherwise pandas-compatible dates from a standard calendar will be represented with the ``np.datetime64[ns]`` data type, enabling the use of a :py:class:`pandas.DatetimeIndex` or arrays with dtype ``np.datetime64[ns]`` and their full set of associated features. + As of pandas version 2.0.0, pandas supports non-nanosecond precision datetime + values. For the time being, xarray still automatically casts datetime values + to nanosecond-precision for backwards compatibility with older pandas + versions; however, this is something we would like to relax going forward. + See :issue:`7493` for more discussion. + For example, you can create a DataArray indexed by a time coordinate with dates from a no-leap calendar and a :py:class:`~xarray.CFTimeIndex` will automatically be used: @@ -233,8 +239,8 @@ For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: .. ipython:: python - da.resample(time="81T", closed="right", label="right", offset="3T").mean() + da.resample(time="81min", closed="right", label="right", offset="3min").mean() -.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations +.. _nanosecond-precision range: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations .. _ISO 8601 standard: https://en.wikipedia.org/wiki/ISO_8601 .. _partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#partial-string-indexing diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a7218ba11da..1ef6f86f20a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,9 +15,9 @@ What's New np.random.seed(123456) -.. _whats-new.2023.04.0: +.. _whats-new.2024.03.0: -v2023.04.0 (unreleased) +v2024.03.0 (unreleased) ----------------------- New Features @@ -25,14 +25,1048 @@ New Features - Allow control over padding in rolling. (:issue:`2007`, :pr:`5603`). By `Kevin Squire `_. +- Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` + (:issue:`6806`, :pull:`8784`). + By `Etienne Schalk `_ and `Deepak Cherian `_. +- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) + By `Anderson Banihirwe `_. + +- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`) + By `Anderson Banihirwe `_. + +- Expand use of ``.oindex`` and ``.vindex`` properties. (:pull: `8790`) + By `Anderson Banihirwe `_ and `Deepak Cherian `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ +- The default ``freq`` parameter in :py:meth:`xr.date_range` and :py:meth:`xr.cftime_range` is + set to ``'D'`` only if ``periods``, ``start``, or ``end`` are ``None`` (:issue:`8770`, :pull:`8774`). + By `Roberto Chang `_. +- Ensure that non-nanosecond precision :py:class:`numpy.datetime64` and + :py:class:`numpy.timedelta64` values are cast to nanosecond precision values + when used in :py:meth:`DataArray.expand_dims` and + ::py:meth:`Dataset.expand_dims` (:pull:`8781`). By `Spencer + Clark `_. +- CF conform handling of `_FillValue`/`missing_value` and `dtype` in + `CFMaskCoder`/`CFScaleOffsetCoder` (:issue:`2304`, :issue:`5597`, + :issue:`7691`, :pull:`8713`, see also discussion in :pull:`7654`). + By `Kai Mühlbauer `_. +- do not cast `_FillValue`/`missing_value` in `CFMaskCoder` if `_Unsigned` is provided + (:issue:`8844`, :pull:`8852`). +- Adapt handling of copy keyword argument in scipy backend for numpy >= 2.0dev + (:issue:`8844`, :pull:`8851`). + By `Kai Mühlbauer `_. + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ +- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) + By `Matt Savoie `_ and `Tom Nicholas + `_. + + +.. _whats-new.2024.02.0: + +v2024.02.0 (Feb 19, 2024) +------------------------- + +This release brings size information to the text ``repr``, changes to the accepted frequency +strings, and various bug fixes. + +Thanks to our 12 contributors: + +Anderson Banihirwe, Deepak Cherian, Eivind Jahren, Etienne Schalk, Justus Magin, Marco Wolsza, +Mathias Hauser, Matt Savoie, Maximilian Roos, Rambaud Pierrick, Tom Nicholas + +New Features +~~~~~~~~~~~~ + +- Added a simple ``nbytes`` representation in DataArrays and Dataset ``repr``. + (:issue:`8690`, :pull:`8702`). + By `Etienne Schalk `_. +- Allow negative frequency strings (e.g. ``"-1YE"``). These strings are for example used in + :py:func:`date_range`, and :py:func:`cftime_range` (:pull:`8651`). + By `Mathias Hauser `_. +- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and + :py:meth:`NamedArray.broadcast_to` (:pull:`8380`) + By `Anderson Banihirwe `_. +- Xarray now defers to `flox's heuristics `_ + to set the default `method` for groupby problems. This only applies to ``flox>=0.9``. + By `Deepak Cherian `_. +- All `quantile` methods (e.g. :py:meth:`DataArray.quantile`) now use `numbagg` + for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed. + This is currently limited to the linear interpolation method (`method='linear'`). + (:issue:`7377`, :pull:`8684`) + By `Marco Wolsza `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:func:`infer_freq` always returns the frequency strings as defined in pandas 2.2 + (:issue:`8612`, :pull:`8627`). + By `Mathias Hauser `_. + +Deprecations +~~~~~~~~~~~~ +- The `dt.weekday_name` parameter wasn't functional on modern pandas versions and has been + removed. (:issue:`8610`, :pull:`8664`) + By `Sam Coleman `_. + + +Bug fixes +~~~~~~~~~ + +- Fixed a regression that prevented multi-index level coordinates being serialized after resetting + or dropping the multi-index (:issue:`8628`, :pull:`8672`). + By `Benoit Bovy `_. +- Fix bug with broadcasting when wrapping array API-compliant classes. (:issue:`8665`, :pull:`8669`) + By `Tom Nicholas `_. +- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant + classes. (:issue:`8666`, :pull:`8668`) + By `Tom Nicholas `_. +- Fix negative slicing of Zarr arrays without dask installed. (:issue:`8252`) + By `Deepak Cherian `_. +- Preserve chunks when writing time-like variables to zarr by enabling lazy CF encoding of time-like + variables (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8575`). + By `Spencer Clark `_ and `Mattia Almansi `_. +- Preserve chunks when writing time-like variables to zarr by enabling their lazy encoding + (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8253`, :pull:`8575`; see also discussion in + :pull:`8253`). + By `Spencer Clark `_ and `Mattia Almansi `_. +- Raise an informative error if dtype encoding of time-like variables would lead to integer overflow + or unsafe conversion from floating point to integer values (:issue:`8542`, :pull:`8575`). + By `Spencer Clark `_. +- Raise an error when unstacking a MultiIndex that has duplicates as this would lead to silent data + loss (:issue:`7104`, :pull:`8737`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ +- Fix `variables` arg typo in `Dataset.sortby()` docstring (:issue:`8663`, :pull:`8670`) + By `Tom Vo `_. +- Fixed documentation where the use of the depreciated pandas frequency string prevented the + documentation from being built. (:pull:`8638`) + By `Sam Coleman `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- ``DataArray.dt`` now raises an ``AttributeError`` rather than a ``TypeError`` when the data isn't + datetime-like. (:issue:`8718`, :pull:`8724`) + By `Maximilian Roos `_. +- Move ``parallelcompat`` and ``chunk managers`` modules from ``xarray/core`` to + ``xarray/namedarray``. (:pull:`8319`) + By `Tom Nicholas `_ and `Anderson Banihirwe `_. +- Imports ``datatree`` repository and history into internal location. (:pull:`8688`) + By `Matt Savoie `_, `Justus Magin `_ + and `Tom Nicholas `_. +- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`) + By `Matt Savoie `_ and `Tom Nicholas + `_. +- Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary + rewrite of the indexer key (:issue: `8377`, :pull:`8758`) + By `Anderson Banihirwe `_. + +.. _whats-new.2024.01.1: + +v2024.01.1 (23 Jan, 2024) +------------------------- + +This release is to fix a bug with the rendering of the documentation, but it also includes changes to the handling of pandas frequency strings. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Following pandas, :py:meth:`infer_freq` will return ``"YE"``, instead of ``"Y"`` (formerly ``"A"``). + This is to be consistent with the deprecation of the latter frequency string in pandas 2.2. + This is a follow up to :pull:`8415` (:issue:`8612`, :pull:`8642`). + By `Mathias Hauser `_. + +Deprecations +~~~~~~~~~~~~ + +- Following pandas, the frequency string ``"Y"`` (formerly ``"A"``) is deprecated in + favor of ``"YE"``. These strings are used, for example, in :py:func:`date_range`, + :py:func:`cftime_range`, :py:meth:`DataArray.resample`, and :py:meth:`Dataset.resample` + among others (:issue:`8612`, :pull:`8629`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ + +- Pin ``sphinx-book-theme`` to ``1.0.1`` to fix a rendering issue with the sidebar in the docs. (:issue:`8619`, :pull:`8632`) + By `Tom Nicholas `_. + +.. _whats-new.2024.01.0: + +v2024.01.0 (17 Jan, 2024) +------------------------- + +This release brings support for weights in correlation and covariance functions, +a new `DataArray.cumulative` aggregation, improvements to `xr.map_blocks`, +an update to our minimum dependencies, and various bugfixes. + +Thanks to our 17 contributors to this release: + +Abel Aoun, Deepak Cherian, Illviljan, Johan Mathe, Justus Magin, Kai Mühlbauer, +Llorenç Lledó, Mark Harfouche, Markel, Mathias Hauser, Maximilian Roos, Michael Niklas, +Niclas Rieger, Sébastien Celles, Tom Nicholas, Trinh Quoc Anh, and crusaderky. + +New Features +~~~~~~~~~~~~ + +- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`). + By `Llorenç Lledó `_. +- Accept the compression arguments new in netCDF 1.6.0 in the netCDF4 backend. + See `netCDF4 documentation `_ for details. + Note that some new compression filters needs plugins to be installed which may not be available in all netCDF distributions. + By `Markel García-Díez `_. (:issue:`6929`, :pull:`7551`) +- Add :py:meth:`DataArray.cumulative` & :py:meth:`Dataset.cumulative` to compute + cumulative aggregations, such as ``sum``, along a dimension — for example + ``da.cumulative('time').sum()``. This is similar to pandas' ``.expanding``, + and mostly equivalent to ``.cumsum`` methods, or to + :py:meth:`DataArray.rolling` with a window length equal to the dimension size. + By `Maximilian Roos `_. (:pull:`8512`) +- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata. + If multiple variables share the same enum in netCDF4, each dataarray will have its own + enum definition in their respective dtype metadata. + By `Abel Aoun `_. (:issue:`8144`, :pull:`8147`) + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`8586`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + cartopy 0.20 0.21 + dask-core 2022.7 2022.12 + distributed 2022.7 2022.12 + flox 0.5 0.7 + iris 3.2 3.4 + matplotlib-base 3.5 3.6 + numpy 1.22 1.23 + numba 0.55 0.56 + packaging 21.3 22.0 + seaborn 0.11 0.12 + scipy 1.8 1.10 + typing_extensions 4.3 4.4 + zarr 2.12 2.13 + ===================== ========= ======== + +Deprecations +~~~~~~~~~~~~ + +- The `squeeze` kwarg to GroupBy is now deprecated. (:issue:`2157`, :pull:`8507`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`). + By `Michael Niklas `_. +- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`). + By `Kai Mühlbauer `_. +- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`). + By `Kai Mühlbauer `_. +- Add tests and fixes for empty :py:class:`CFTimeIndex`, including broken html repr (:issue:`7298`, :pull:`8600`). + By `Mathias Hauser `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- The implementation of :py:func:`map_blocks` has changed to minimize graph size and duplication of data. + This should be a strict improvement even though the graphs are not always embarassingly parallel any more. + Please open an issue if you spot a regression. (:pull:`8412`, :issue:`8409`). + By `Deepak Cherian `_. +- Remove null values before plotting. (:pull:`8535`). + By `Jimmy Westling `_. +- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to + use non-dask chunked array types. + (:pull:`8019`) By `Tom Nicholas `_. + +.. _whats-new.2023.12.0: + +v2023.12.0 (2023 Dec 08) +------------------------ + +This release brings new `hypothesis `_ strategies for testing, significantly faster rolling aggregations as well as +``ffill`` and ``bfill`` with ``numbagg``, a new :py:meth:`Dataset.eval` method, and improvements to +reading and writing Zarr arrays (including a new ``"a-"`` mode). + +Thanks to our 16 contributors: + +Anderson Banihirwe, Ben Mares, Carl Andersson, Deepak Cherian, Doug Latornell, Gregorio L. Trevisan, Illviljan, Jens Hedegaard Nielsen, Justus Magin, Mathias Hauser, Max Jones, Maximilian Roos, Michael Niklas, Patrick Hoefler, Ryan Abernathey, Tom Nicholas + +New Features +~~~~~~~~~~~~ + +- Added hypothesis strategies for generating :py:class:`xarray.Variable` objects containing arbitrary data, useful for parametrizing downstream tests. + Accessible under :py:mod:`testing.strategies`, and documented in a new page on testing in the User Guide. + (:issue:`6911`, :pull:`8404`) + By `Tom Nicholas `_. +- :py:meth:`rolling` uses `numbagg `_ for + most of its computations by default. Numbagg is up to 5x faster than bottleneck + where parallelization is possible. Where parallelization isn't possible — for + example a 1D array — it's about the same speed as bottleneck, and 2-5x faster + than pandas' default functions. (:pull:`8493`). numbagg is an optional + dependency, so requires installing separately. +- Use a concise format when plotting datetime arrays. (:pull:`8449`). + By `Jimmy Westling `_. +- Avoid overwriting unchanged existing coordinate variables when appending with :py:meth:`Dataset.to_zarr` by setting ``mode='a-'``. + By `Ryan Abernathey `_ and `Deepak Cherian `_. +- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming + the core dim has exactly one chunk. (:pull:`8475`). + By `Maximilian Roos `_. +- Add a :py:meth:`Dataset.eval` method, similar to the pandas' method of the + same name. (:pull:`7163`). This is currently marked as experimental and + doesn't yet support the ``numexpr`` engine. +- :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` allow passing a + callable, similar to :py:meth:`Dataset.where` & :py:meth:`Dataset.sortby` & others. + (:pull:`8511`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Explicitly warn when creating xarray objects with repeated dimension names. + Such objects will also now raise when :py:meth:`DataArray.get_axis_num` is called, + which means many functions will raise. + This latter change is technically a breaking change, but whilst allowed, + this behaviour was never actually supported! (:issue:`3731`, :pull:`8491`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ +- As part of an effort to standardize the API, we're renaming the ``dims`` + keyword arg to ``dim`` for the minority of functions which current use + ``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot` + and we'll gradually roll this out across all functions. The warnings are + currently ``PendingDeprecationWarning``, which are silenced by default. We'll + convert these to ``DeprecationWarning`` in a future release. + By `Maximilian Roos `_. +- Raise a ``FutureWarning`` warning that the type of :py:meth:`Dataset.dims` will be changed + from a mapping of dimension names to lengths to a set of dimension names. + This is to increase consistency with :py:meth:`DataArray.dims`. + To access a mapping of dimension names to lengths please use :py:meth:`Dataset.sizes`. + The same change also applies to `DatasetGroupBy.dims`. + (:issue:`8496`, :pull:`8500`) + By `Tom Nicholas `_. +- :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` are now deprecated, since pending deprecation for + several years. :py:meth:`DataArray.drop_sel` & :py:meth:`DataArray.drop_var` + replace them for labels & variables respectively. (:pull:`8497`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) +- Fix writing a variable that requires transposing when not writing to a region (:pull:`8484`) + By `Maximilian Roos `_. +- Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit` + was changed to ``Mapping`` (:pull:`8502`). + By `Michael Niklas `_. +- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf` + when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`). + By `Jens Hedegaard Nielsen `_. + +Documentation +~~~~~~~~~~~~~ + +- Added illustration of updating the time coordinate values of a resampled dataset using + time offset arithmetic. + This is the recommended technique to replace the use of the deprecated ``loffset`` parameter + in ``resample`` (:pull:`8479`). + By `Doug Latornell `_. +- Improved error message when attempting to get a variable which doesn't exist from a Dataset. + (:pull:`8474`) + By `Maximilian Roos `_. +- Fix default value of ``combine_attrs`` in :py:func:`xarray.combine_by_coords` (:pull:`8471`) + By `Gregorio L. Trevisan `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg `_ by + default, which is up to 5x faster where parallelization is possible. (:pull:`8339`) + By `Maximilian Roos `_. +- Update mypy version to 1.7 (:issue:`8448`, :pull:`8501`). + By `Michael Niklas `_. + +.. _whats-new.2023.11.0: + +v2023.11.0 (Nov 16, 2023) +------------------------- + + +.. tip:: + + `This is our 10th year anniversary release! `_ Thank you for your love and support. + + +This release brings the ability to use ``opt_einsum`` for :py:func:`xarray.dot` by default, +support for auto-detecting ``region`` when writing partial datasets to Zarr, and the use of h5py +drivers with ``h5netcdf``. + +Thanks to the 19 contributors to this release: +Aman Bagrecha, Anderson Banihirwe, Ben Mares, Deepak Cherian, Dimitri Papadopoulos Orfanos, Ezequiel Cimadevilla Alvarez, +Illviljan, Justus Magin, Katelyn FitzGerald, Kai Muehlbauer, Martin Durant, Maximilian Roos, Metamess, Sam Levang, Spencer Clark, Tom Nicholas, mgunyho, templiert + +New Features +~~~~~~~~~~~~ + +- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. + By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). +- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`). + By `Ben Mares `_. +- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the + region to write in the original store. Also implement automatic transpose when dimension + order does not match the original store. (:issue:`7702`, :issue:`8421`, :pull:`8434`). + By `Sam Levang `_. +- Allow the usage of h5py drivers (eg: ros3) via h5netcdf (:pull:`8360`). + By `Ezequiel Cimadevilla `_. +- Enable VLEN string fill_values, preserve VLEN string dtypes (:issue:`1647`, :issue:`7652`, :issue:`7868`, :pull:`7869`). + By `Kai Mühlbauer `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- drop support for `cdms2 `_. Please use + `xcdat `_ instead (:pull:`8441`). + By `Justus Magin `_. +- Following pandas, :py:meth:`infer_freq` will return ``"Y"``, ``"YS"``, + ``"QE"``, ``"ME"``, ``"h"``, ``"min"``, ``"s"``, ``"ms"``, ``"us"``, or + ``"ns"`` instead of ``"A"``, ``"AS"``, ``"Q"``, ``"M"``, ``"H"``, ``"T"``, + ``"S"``, ``"L"``, ``"U"``, or ``"N"``. This is to be consistent with the + deprecation of the latter frequency strings (:issue:`8394`, :pull:`8415`). By + `Spencer Clark `_. +- Bump minimum tested pint version to ``>=0.22``. By `Deepak Cherian `_. +- Minimum supported versions for the following packages have changed: ``h5py >=3.7``, ``h5netcdf>=1.1``. + By `Kai Mühlbauer `_. + +Deprecations +~~~~~~~~~~~~ +- The PseudoNetCDF backend has been removed. By `Deepak Cherian `_. +- Supplying dimension-ordered sequences to :py:meth:`DataArray.chunk` & + :py:meth:`Dataset.chunk` is deprecated in favor of supplying a dictionary of + dimensions, or a single ``int`` or ``"auto"`` argument covering all + dimensions. Xarray favors using dimensions names rather than positions, and + this was one place in the API where dimension positions were used. + (:pull:`8341`) + By `Maximilian Roos `_. +- Following pandas, the frequency strings ``"A"``, ``"AS"``, ``"Q"``, ``"M"``, + ``"H"``, ``"T"``, ``"S"``, ``"L"``, ``"U"``, and ``"N"`` are deprecated in + favor of ``"Y"``, ``"YS"``, ``"QE"``, ``"ME"``, ``"h"``, ``"min"``, ``"s"``, + ``"ms"``, ``"us"``, and ``"ns"``, respectively. These strings are used, for + example, in :py:func:`date_range`, :py:func:`cftime_range`, + :py:meth:`DataArray.resample`, and :py:meth:`Dataset.resample` among others + (:issue:`8394`, :pull:`8415`). By `Spencer Clark + `_. +- Rename :py:meth:`Dataset.to_array` to :py:meth:`Dataset.to_dataarray` for + consistency with :py:meth:`DataArray.to_dataset` & + :py:func:`open_dataarray` functions. This is a "soft" deprecation — the + existing methods work and don't raise any warnings, given the relatively small + benefits of the change. + By `Maximilian Roos `_. +- Finally remove ``keep_attrs`` kwarg from :py:meth:`DataArray.resample` and + :py:meth:`Dataset.resample`. These were deprecated a long time ago. + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ + +- Port `bug fix from pandas `_ + to eliminate the adjustment of resample bin edges in the case that the + resampling frequency has units of days and is greater than one day + (e.g. ``"2D"``, ``"3D"`` etc.) and the ``closed`` argument is set to + ``"right"`` to xarray's implementation of resample for data indexed by a + :py:class:`CFTimeIndex` (:pull:`8393`). + By `Spencer Clark `_. +- Fix to once again support date offset strings as input to the loffset + parameter of resample and test this functionality (:pull:`8422`, :issue:`8399`). + By `Katelyn FitzGerald `_. +- Fix a bug where :py:meth:`DataArray.to_dataset` silently drops a variable + if a coordinate with the same name already exists (:pull:`8433`, :issue:`7823`). + By `András Gunyhó `_. +- Fix for :py:meth:`DataArray.to_zarr` & :py:meth:`Dataset.to_zarr` to close + the created zarr store when passing a path with `.zip` extension (:pull:`8425`). + By `Carl Andersson _`. + +Documentation +~~~~~~~~~~~~~ +- Small updates to documentation on distributed writes: See :ref:`io.zarr.appending` to Zarr. + By `Deepak Cherian `_. + +.. _whats-new.2023.10.1: + +v2023.10.1 (19 Oct, 2023) +------------------------- + +This release updates our minimum numpy version in ``pyproject.toml`` to 1.22, +consistent with our documentation below. + +.. _whats-new.2023.10.0: + +v2023.10.0 (19 Oct, 2023) +------------------------- + +This release brings performance enhancements to reading Zarr datasets, the ability to use `numbagg `_ for reductions, +an expansion in API for ``rolling_exp``, fixes two regressions with datetime decoding, +and many other bugfixes and improvements. Groupby reductions will also use ``numbagg`` if ``flox>=0.8.1`` and ``numbagg`` are both installed. + +Thanks to our 13 contributors: +Anderson Banihirwe, Bart Schilperoort, Deepak Cherian, Illviljan, Kai Mühlbauer, Mathias Hauser, Maximilian Roos, Michael Niklas, Pieter Eendebak, Simon Høxbro Hansen, Spencer Clark, Tom White, olimcc + +New Features +~~~~~~~~~~~~ +- Support high-performance reductions with `numbagg `_. + This is enabled by default if ``numbagg`` is installed. + By `Deepak Cherian `_. (:pull:`8316`) +- Add ``corr``, ``cov``, ``std`` & ``var`` to ``.rolling_exp``. + By `Maximilian Roos `_. (:pull:`8307`) +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the only argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only + output values when there are sufficient recent non-nan values. + ``numbagg>=0.3.1`` is required. (:pull:`8285`) + By `Maximilian Roos `_. +- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for + the ``variables`` parameter, passing the object as the only argument. + By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and + :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_. +- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword + arguments after the initial 7 positional arguments. + By `Maximilian Roos `_. + + +Deprecations +~~~~~~~~~~~~ +- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding` + to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for + consistency with other ``drop`` & ``reset`` methods — ``drop`` generally + removes something, while ``reset`` generally resets to some default or + standard value. (:pull:`8287`, :issue:`8259`) + By `Maximilian Roos `_. + +Bug fixes +~~~~~~~~~ + +- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning + when the operation was a no-op. (:issue:`8266`) + By `Simon Hansen `_. +- Fixed a regression introduced in the previous release checking time-like units + when encoding/decoding masked data (:issue:`8269`, :pull:`8277`). + By `Kai Mühlbauer `_. + +- Fix datetime encoding precision loss regression introduced in the previous + release for datetimes encoded with units requiring floating point values, and + a reference date not equal to the first value of the datetime array + (:issue:`8271`, :pull:`8272`). By `Spencer Clark + `_. + +- Fix excess metadata requests when using a Zarr store. Prior to this, metadata + was re-read every time data was retrieved from the array, now metadata is retrieved only once + when they array is initialized. + (:issue:`8290`, :pull:`8297`). + By `Oliver McCormack `_. + +- Fix to_zarr ending in a ReadOnlyError when consolidated metadata was used and the + write_empty_chunks was provided. + (:issue:`8323`, :pull:`8326`) + By `Matthijs Amesz `_. + + +Documentation +~~~~~~~~~~~~~ + +- Added page on the interoperability of xarray objects. + (:pull:`7992`) By `Tom Nicholas `_. +- Added xarray-regrid to the list of xarray related projects (:pull:`8272`). + By `Bart Schilperoort `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- More improvements to support the Python `array API standard `_ + by using duck array ops in more places in the codebase. (:pull:`8267`) + By `Tom White `_. + + +.. _whats-new.2023.09.0: + +v2023.09.0 (Sep 26, 2023) +------------------------- + +This release continues work on the new :py:class:`xarray.Coordinates` object, allows to provide `preferred_chunks` when +reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions and fixes several bugs. + +Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian, +Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden, +Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen, +Tom Nicholas and Wiktor Kraśnicki. + +We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. Trevisan, +Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki. + +New Features +~~~~~~~~~~~~ + +- Added the :py:meth:`Coordinates.assign` method that can be used to combine + different collections of coordinates prior to assign them to a Dataset or + DataArray (:pull:`8102`) at once. + By `Benoît Bovy `_. +- Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`). + By `Martin Raspaud `_. +- Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or + dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`). + By `Maximilian Roos `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The :py:class:`Coordinates` constructor now creates a (pandas) index by + default for each dimension coordinate. To keep the previous behavior (no index + created), pass an empty dictionary to ``indexes``. The constructor now also + extracts and add the indexes from another :py:class:`Coordinates` object + passed via ``coords`` (:pull:`8107`). + By `Benoît Bovy `_. +- Static typing of ``xlim`` and ``ylim`` arguments in plotting functions now must + be ``tuple[float, float]`` to align with matplotlib requirements. (:issue:`7802`, :pull:`8030`). + By `Michael Niklas `_. + +Deprecations +~~~~~~~~~~~~ + +- Deprecate passing a :py:class:`pandas.MultiIndex` object directly to the + :py:class:`Dataset` and :py:class:`DataArray` constructors as well as to + :py:meth:`Dataset.assign` and :py:meth:`Dataset.assign_coords`. + A new Xarray :py:class:`Coordinates` object has to be created first using + :py:meth:`Coordinates.from_pandas_multiindex` (:pull:`8094`). + By `Benoît Bovy `_. + +Bug fixes +~~~~~~~~~ + +- Improved static typing of reduction methods (:pull:`6746`). + By `Richard Kleijn `_. +- Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`). + By `Mattia Almansi `_. +- Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes + (and improved warnings for deprecated features) for pandas multi-indexes (:pull:`8094`). + By `Benoît Bovy `_. +- Fixed a bug in :py:func:`merge` with ``compat='minimal'`` where the coordinate + names were not updated properly internally (:issue:`7405`, :issue:`7588`, + :pull:`8104`). + By `Benoît Bovy `_. +- Fix bug where :py:class:`DataArray` instances on the right-hand side + of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`). + By `Darsh Ranjan `_. +- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and + special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` + (:issue:`7928`, :pull:`8084`). + By `Kai Mühlbauer `_. +- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes. + (:issue:`7021`, :pull:`7578`). + By `Amrest Chinkamol `_ and `Michael Niklas `_. +- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments + (:issue:`7552`, :pull:`8174`). + By `Wiktor Kraśnicki `_. +- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, + :issue:`1064`, :pull:`7827`). + By `Kai Mühlbauer `_. +- Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). + By `Kai Mühlbauer `_ +- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords + (:issue:`6528`, :pull:`8114`). + By `Maximilian Roos `_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. +- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`). + By `Pieter Eendebak `_. + +Documentation +~~~~~~~~~~~~~ + +- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`). + By `Riulinchen `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Many error messages related to invalid dimensions or coordinates now always show the list of valid dims/coords (:pull:`8079`). + By `András Gunyhó `_. +- Refactor of encoding and decoding times/timedeltas to preserve nanosecond resolution in arrays that contain missing values (:pull:`7827`). + By `Kai Mühlbauer `_. +- Transition ``.rolling_exp`` functions to use `.apply_ufunc` internally rather + than `.reduce`, as the start of a broader effort to move non-reducing + functions away from ```.reduce``, (:pull:`8114`). + By `Maximilian Roos `_. +- Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). + By `Kai Mühlbauer `_. + +.. _whats-new.2023.08.0: + +v2023.08.0 (Aug 18, 2023) +------------------------- + +This release brings changes to minimum dependencies, allows reading of datasets where a dimension name is +associated with a multidimensional variable (e.g. finite volume ocean model output), and introduces +a new :py:class:`xarray.Coordinates` object. + +Thanks to the 16 contributors to this release: Anderson Banihirwe, Articoking, Benoit Bovy, Deepak Cherian, Harshitha, Ian Carroll, +Joe Hamman, Justus Magin, Peter Hill, Rachel Wegener, Riley Kuttruff, Thomas Nicholas, Tom Nicholas, ilgast, quantsnus, vallirep + +Announcements +~~~~~~~~~~~~~ + +The :py:class:`xarray.Variable` class is being refactored out to a new project title 'namedarray'. +See the `design doc `_ for more +details. Reach out to us on this [discussion topic](https://github.com/pydata/xarray/discussions/8080) if you have any thoughts. + +New Features +~~~~~~~~~~~~ + +- :py:class:`Coordinates` can now be constructed independently of any Dataset or + DataArray (it is also returned by the :py:attr:`Dataset.coords` and + :py:attr:`DataArray.coords` properties). ``Coordinates`` objects are useful for + passing both coordinate variables and indexes to new Dataset / DataArray objects, + e.g., via their constructor or via :py:meth:`Dataset.assign_coords`. We may also + wrap coordinate variables in a ``Coordinates`` object in order to skip + the automatic creation of (pandas) indexes for dimension coordinates. + The :py:class:`Coordinates.from_pandas_multiindex` constructor may be used to + create coordinates directly from a :py:class:`pandas.MultiIndex` object (it is + preferred over passing it directly as coordinate data, which may be deprecated soon). + Like Dataset and DataArray objects, ``Coordinates`` objects may now be used in + :py:func:`align` and :py:func:`merge`. + (:issue:`6392`, :pull:`7368`). + By `Benoît Bovy `_. +- Visually group together coordinates with the same indexes in the index section of the text repr (:pull:`7225`). + By `Justus Magin `_. +- Allow creating Xarray objects where a multidimensional variable shares its name + with a dimension. Examples include output from finite volume models like FVCOM. + (:issue:`2233`, :pull:`7989`) + By `Deepak Cherian `_ and `Benoit Bovy `_. +- When outputting :py:class:`Dataset` objects as Zarr via :py:meth:`Dataset.to_zarr`, + user can now specify that chunks that will contain no valid data will not be written. + Originally, this could be done by specifying ``"write_empty_chunks": True`` in the + ``encoding`` parameter; however, this setting would not carry over when appending new + data to an existing dataset. (:issue:`8009`) Requires ``zarr>=2.11``. + + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The minimum versions of some dependencies were changed (:pull:`8022`): + + ===================== ========= ======== + Package Old New + ===================== ========= ======== + boto3 1.20 1.24 + cftime 1.5 1.6 + dask-core 2022.1 2022.7 + distributed 2022.1 2022.7 + hfnetcdf 0.13 1.0 + iris 3.1 3.2 + lxml 4.7 4.9 + netcdf4 1.5.7 1.6.0 + numpy 1.21 1.22 + pint 0.18 0.19 + pydap 3.2 3.3 + rasterio 1.2 1.3 + scipy 1.7 1.8 + toolz 0.11 0.12 + typing_extensions 4.0 4.3 + zarr 2.10 2.12 + numbagg 0.1 0.2.1 + ===================== ========= ======== + +Documentation +~~~~~~~~~~~~~ + +- Added page on the internal design of xarray objects. + (:pull:`7991`) By `Tom Nicholas `_. +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. +- Add docstrings for the :py:class:`Index` base class and add some documentation on how to + create custom, Xarray-compatible indexes (:pull:`6975`) + By `Benoît Bovy `_. +- Added a page clarifying the role of Xarray core team members. + (:pull:`7999`) By `Tom Nicholas `_. +- Fixed broken links in "See also" section of :py:meth:`Dataset.count` (:issue:`8055`, :pull:`8057`) + By `Articoking `_. +- Extended the glossary by adding terms Aligning, Broadcasting, Merging, Concatenating, Combining, lazy, + labeled, serialization, indexing (:issue:`3355`, :pull:`7732`) + By `Harshitha `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- :py:func:`as_variable` now consistently includes the variable name in any exceptions + raised. (:pull:`7995`). By `Peter Hill `_ +- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to + `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`). + `By Ian Carroll `_. + +.. _whats-new.2023.07.0: + +v2023.07.0 (July 17, 2023) +-------------------------- + +This release brings improvements to the documentation on wrapping numpy-like arrays, improved docstrings, and bug fixes. + +Deprecations +~~~~~~~~~~~~ + +- `hue_style` is being deprecated for scatter plots. (:issue:`7907`, :pull:`7925`). + By `Jimmy Westling `_. + +Bug fixes +~~~~~~~~~ + +- Ensure no forward slashes in variable and dimension names for HDF5-based engines. + (:issue:`7943`, :pull:`7953`) By `Kai Mühlbauer `_. + +Documentation +~~~~~~~~~~~~~ + +- Added examples to docstrings of :py:meth:`Dataset.assign_attrs`, :py:meth:`Dataset.broadcast_equals`, + :py:meth:`Dataset.equals`, :py:meth:`Dataset.identical`, :py:meth:`Dataset.expand_dims`,:py:meth:`Dataset.drop_vars` + (:issue:`6793`, :pull:`7937`) By `Harshitha `_. +- Added page on wrapping chunked numpy-like arrays as alternatives to dask arrays. + (:pull:`7951`) By `Tom Nicholas `_. +- Expanded the page on wrapping numpy-like "duck" arrays. + (:pull:`7911`) By `Tom Nicholas `_. +- Added examples to docstrings of :py:meth:`Dataset.isel`, :py:meth:`Dataset.reduce`, :py:meth:`Dataset.argmin`, + :py:meth:`Dataset.argmax` (:issue:`6793`, :pull:`7881`) + By `Harshitha `_ . + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Allow chunked non-dask arrays (i.e. Cubed arrays) in groupby operations. (:pull:`7941`) + By `Tom Nicholas `_. + + +.. _whats-new.2023.06.0: + +v2023.06.0 (June 21, 2023) +-------------------------- + +This release adds features to ``curvefit``, improves the performance of concatenation, and fixes various bugs. + +Thank to our 13 contributors to this release: +Anderson Banihirwe, Deepak Cherian, dependabot[bot], Illviljan, Juniper Tyree, Justus Magin, Martin Fleischmann, +Mattia Almansi, mgunyho, Rutger van Haasteren, Thomas Nicholas, Tom Nicholas, Tom White. + + +New Features +~~~~~~~~~~~~ + +- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`). + By `András Gunyhó `_. +- Add an ``errors`` option to :py:meth:`Dataset.curve_fit` that allows + returning NaN for the parameters and covariances of failed fits, rather than + failing the whole series of fits (:issue:`6317`, :pull:`7891`). + By `Dominik Stańczak `_ and `András Gunyhó `_. Breaking changes ~~~~~~~~~~~~~~~~ +Deprecations +~~~~~~~~~~~~ +- Deprecate the `cdms2 `_ conversion methods (:pull:`7876`) + By `Justus Magin `_. + +Performance +~~~~~~~~~~~ +- Improve concatenation performance (:issue:`7833`, :pull:`7824`). + By `Jimmy Westling `_. + +Bug fixes +~~~~~~~~~ +- Fix bug where weighted ``polyfit`` were changing the original object (:issue:`5644`, :pull:`7900`). + By `Mattia Almansi `_. +- Don't call ``CachingFileManager.__del__`` on interpreter shutdown (:issue:`7814`, :pull:`7880`). + By `Justus Magin `_. +- Preserve vlen dtype for empty string arrays (:issue:`7328`, :pull:`7862`). + By `Tom White `_ and `Kai Mühlbauer `_. +- Ensure dtype of reindex result matches dtype of the original DataArray (:issue:`7299`, :pull:`7917`) + By `Anderson Banihirwe `_. +- Fix bug where a zero-length zarr ``chunk_store`` was ignored as if it was ``None`` (:pull:`7923`) + By `Juniper Tyree `_. + +Documentation +~~~~~~~~~~~~~ + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Minor improvements to support of the python `array api standard `_, + internally using the function ``xp.astype()`` instead of the method ``arr.astype()``, as the latter is not in the standard. + (:pull:`7847`) By `Tom Nicholas `_. +- Xarray now uploads nightly wheels to https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ (:issue:`7863`, :pull:`7865`). + By `Martin Fleischmann `_. +- Stop uploading development wheels to TestPyPI (:pull:`7889`) + By `Justus Magin `_. +- Added an exception catch for ``AttributeError`` along with ``ImportError`` when duck typing the dynamic imports in pycompat.py. This catches some name collisions between packages. (:issue:`7870`, :pull:`7874`) + +.. _whats-new.2023.05.0: + +v2023.05.0 (May 18, 2023) +------------------------- + +This release adds some new methods and operators, updates our deprecation policy for python versions, fixes some bugs with groupby, +and introduces experimental support for alternative chunked parallel array computation backends via a new plugin system! + +**Note:** If you are using a locally-installed development version of xarray then pulling the changes from this release may require you to re-install. +This avoids an error where xarray cannot detect dask via the new entrypoints system introduced in :pull:`7019`. See :issue:`7856` for details. + +Thanks to our 14 contributors: +Alan Brammer, crusaderky, David Stansby, dcherian, Deeksha, Deepak Cherian, Illviljan, James McCreight, +Joe Hamman, Justus Magin, Kyle Sunden, Max Hollmann, mgunyho, and Tom Nicholas + + +New Features +~~~~~~~~~~~~ +- Added new method :py:meth:`DataArray.to_dask_dataframe`, convert a dataarray into a dask dataframe (:issue:`7409`). + By `Deeksha `_. +- Add support for lshift and rshift binary operators (``<<``, ``>>``) on + :py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727` , :pull:`7741`). + By `Alan Brammer `_. +- Keyword argument `data='array'` to both :py:meth:`xarray.Dataset.to_dict` and + :py:meth:`xarray.DataArray.to_dict` will now return data as the underlying array type. + Python lists are returned for `data='list'` or `data=True`. Supplying `data=False` only returns the schema without data. + ``encoding=True`` returns the encoding dictionary for the underlying variable also. (:issue:`1599`, :pull:`7739`) . + By `James McCreight `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- adjust the deprecation policy for python to once again align with NEP-29 (:issue:`7765`, :pull:`7793`) + By `Justus Magin `_. + +Performance +~~~~~~~~~~~ +- Optimize ``.dt `` accessor performance with ``CFTimeIndex``. (:pull:`7796`) + By `Deepak Cherian `_. + +Bug fixes +~~~~~~~~~ +- Fix `as_compatible_data` for masked float arrays, now always creates a copy when mask is present (:issue:`2377`, :pull:`7788`). + By `Max Hollmann `_. +- Fix groupby binary ops when grouped array is subset relative to other. (:issue:`7797`). + By `Deepak Cherian `_. +- Fix groupby sum, prod for all-NaN groups with ``flox``. (:issue:`7808`). + By `Deepak Cherian `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Experimental support for wrapping chunked array libraries other than dask. + A new ABC is defined - :py:class:`xr.core.parallelcompat.ChunkManagerEntrypoint` - which can be subclassed and then + registered by alternative chunked array implementations. (:issue:`6807`, :pull:`7019`) + By `Tom Nicholas `_. + + +.. _whats-new.2023.04.2: + +v2023.04.2 (April 20, 2023) +--------------------------- + +This is a patch release to fix a bug with binning (:issue:`7766`) + +Bug fixes +~~~~~~~~~ + +- Fix binning when ``labels`` is specified. (:issue:`7766`). + By `Deepak Cherian `_. + + +Documentation +~~~~~~~~~~~~~ +- Added examples to docstrings for :py:meth:`xarray.core.accessor_str.StringAccessor` methods. + (:pull:`7669`) . + By `Mary Gathoni `_. + + +.. _whats-new.2023.04.1: + +v2023.04.1 (April 18, 2023) +--------------------------- + +This is a patch release to fix a bug with binning (:issue:`7759`) + +Bug fixes +~~~~~~~~~ + +- Fix binning by unsorted arrays. (:issue:`7759`) + + +.. _whats-new.2023.04.0: + +v2023.04.0 (April 14, 2023) +--------------------------- + +This release includes support for pandas v2, allows refreshing of backend engines in a session, and removes deprecated backends +for ``rasterio`` and ``cfgrib``. + +Thanks to our 19 contributors: +Chinemere, Tom Coleman, Deepak Cherian, Harshitha, Illviljan, Jessica Scheick, Joe Hamman, Justus Magin, Kai Mühlbauer, Kwonil-Kim, Mary Gathoni, Michael Niklas, Pierre, Scott Henderson, Shreyal Gupta, Spencer Clark, mccloskey, nishtha981, veenstrajelmer + +We welcome the following new contributors to Xarray!: +Mary Gathoni, Harshitha, veenstrajelmer, Chinemere, nishtha981, Shreyal Gupta, Kwonil-Kim, mccloskey. + +New Features +~~~~~~~~~~~~ +- New methods to reset an objects encoding (:py:meth:`Dataset.reset_encoding`, :py:meth:`DataArray.reset_encoding`). + (:issue:`7686`, :pull:`7689`). + By `Joe Hamman `_. +- Allow refreshing backend engines with :py:meth:`xarray.backends.refresh_engines` (:issue:`7478`, :pull:`7523`). + By `Michael Niklas `_. +- Added ability to save ``DataArray`` objects directly to Zarr using :py:meth:`~xarray.DataArray.to_zarr`. + (:issue:`7692`, :pull:`7693`) . + By `Joe Hamman `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- Remove deprecated rasterio backend in favor of rioxarray (:pull:`7392`). + By `Scott Henderson `_. + Deprecations ~~~~~~~~~~~~ +Performance +~~~~~~~~~~~ +- Optimize alignment with ``join="exact", copy=False`` by avoiding copies. (:pull:`7736`) + By `Deepak Cherian `_. +- Avoid unnecessary copies of ``CFTimeIndex``. (:pull:`7735`) + By `Deepak Cherian `_. Bug fixes ~~~~~~~~~ @@ -43,6 +1077,15 @@ Bug fixes By `Thomas Coleman `_. - Proper plotting when passing :py:class:`~matplotlib.colors.BoundaryNorm` type argument in :py:meth:`DataArray.plot`. (:issue:`4061`, :issue:`7014`,:pull:`7553`) By `Jelmer Veenstra `_. +- Ensure the formatting of time encoding reference dates outside the range of + nanosecond-precision datetimes remains the same under pandas version 2.0.0 + (:issue:`7420`, :pull:`7441`). + By `Justus Magin `_ and + `Spencer Clark `_. +- Various `dtype` related fixes needed to support `pandas>=2.0` (:pull:`7724`) + By `Justus Magin `_. +- Preserve boolean dtype within encoding (:issue:`7652`, :pull:`7720`). + By `Kai Mühlbauer `_ Documentation ~~~~~~~~~~~~~ @@ -53,9 +1096,25 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Don't assume that arrays read from disk will be Numpy arrays. This is a step toward + enabling reads from a Zarr store using the `Kvikio `_ + or `TensorStore `_ libraries. + (:pull:`6874`). By `Deepak Cherian `_. + - Remove internal support for reading GRIB files through the ``cfgrib`` backend. ``cfgrib`` now uses the external backend interface, so no existing code should break. By `Deepak Cherian `_. +- Implement CF coding functions in ``VariableCoders`` (:pull:`7719`). + By `Kai Mühlbauer `_ + +- Added a config.yml file with messages for the welcome bot when a Github user creates their first ever issue or pull request or has their first PR merged. (:issue:`7685`, :pull:`7685`) + By `Nishtha P `_. + +- Ensure that only nanosecond-precision :py:class:`pd.Timestamp` objects + continue to be used internally under pandas version 2.0.0. This is mainly to + ease the transition to this latest version of pandas. It should be relaxed + when addressing :issue:`7493`. By `Spencer Clark + `_ (:issue:`7707`, :pull:`7731`). .. _whats-new.2023.03.0: @@ -489,7 +1548,7 @@ Bug fixes By `Michael Niklas `_. - Fix side effects on index coordinate metadata after aligning objects. (:issue:`6852`, :pull:`6857`) By `Benoît Bovy `_. -- Make FacetGrid.set_titles send kwargs correctly using `handle.udpate(kwargs)`. (:issue:`6839`, :pull:`6843`) +- Make FacetGrid.set_titles send kwargs correctly using `handle.update(kwargs)`. (:issue:`6839`, :pull:`6843`) By `Oliver Lopez `_. - Fix bug where index variables would be changed inplace. (:issue:`6931`, :pull:`6938`) By `Michael Niklas `_. @@ -525,6 +1584,7 @@ Bug fixes Documentation ~~~~~~~~~~~~~ + - Update merge docstrings. (:issue:`6935`, :pull:`7033`) By `Zach Moon `_. - Raise a more informative error when trying to open a non-existent zarr store. (:issue:`6484`, :pull:`7060`) @@ -3905,7 +4965,7 @@ Enhancements - New PseudoNetCDF backend for many Atmospheric data formats including GEOS-Chem, CAMx, NOAA arlpacked bit and many others. See - :ref:`io.PseudoNetCDF` for more details. + ``io.PseudoNetCDF`` for more details. By `Barron Henderson `_. - The :py:class:`Dataset` constructor now aligns :py:class:`DataArray` @@ -4351,7 +5411,7 @@ Bug fixes - Corrected a bug with incorrect coordinates for non-georeferenced geotiff files (:issue:`1686`). Internally, we now use the rasterio coordinate transform tool instead of doing the computations ourselves. A - ``parse_coordinates`` kwarg has beed added to :py:func:`~open_rasterio` + ``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio` (set to ``True`` per default). By `Fabien Maussion `_. - The colors of discrete colormaps are now the same regardless if `seaborn` @@ -6094,7 +7154,7 @@ Backwards incompatible changes Enhancements ~~~~~~~~~~~~ -- New ``xray.Dataset.to_array`` and enhanced +- New ``xray.Dataset.to_dataarray`` and enhanced ``xray.DataArray.to_dataset`` methods make it easy to switch back and forth between arrays and datasets: @@ -6105,8 +7165,8 @@ Enhancements coords={"c": 42}, attrs={"Conventions": "None"}, ) - ds.to_array() - ds.to_array().to_dataset(dim="variable") + ds.to_dataarray() + ds.to_dataarray().to_dataset(dim="variable") - New ``xray.Dataset.fillna`` method to fill missing values, modeled off the pandas method of the same name: @@ -6367,10 +7427,16 @@ Breaking changes - The ``season`` datetime shortcut now returns an array of string labels such `'DJF'`: - .. ipython:: python + .. code-block:: ipython + + In[92]: ds = xray.Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) - ds = xray.Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) - ds["t.season"] + In[93]: ds["t.season"] + Out[93]: + + array(['DJF', 'DJF', 'MAM', ..., 'SON', 'SON', 'DJF'], dtype='=1.23", + "packaging>=22", + "pandas>=1.5", +] + +[project.optional-dependencies] +accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] +complete = ["xarray[accel,io,parallel,viz,dev]"] +dev = [ + "hypothesis", + "pre-commit", + "pytest", + "pytest-cov", + "pytest-env", + "pytest-xdist", + "pytest-timeout", + "ruff", + "xarray[complete]", +] +io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] +parallel = ["dask[complete]"] +viz = ["matplotlib", "seaborn", "nc-time-axis"] + +[project.urls] +Documentation = "https://docs.xarray.dev" +SciPy2015-talk = "https://www.youtube.com/watch?v=X0pAhJgySxk" +homepage = "https://xarray.dev/" +issue-tracker = "https://github.com/pydata/xarray/issues" +source-code = "https://github.com/pydata/xarray" + +[project.entry-points."xarray.chunkmanagers"] +dask = "xarray.namedarray.daskmanager:DaskManager" + [build-system] build-backend = "setuptools.build_meta" requires = [ @@ -5,8 +63,11 @@ requires = [ "setuptools-scm>=7", ] +[tool.setuptools] +packages = ["xarray"] + [tool.setuptools_scm] -fallback_version = "999" +fallback_version = "9999" [tool.coverage.run] omit = [ @@ -23,27 +84,40 @@ source = ["xarray"] exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] -exclude = 'xarray/util/generate_.*\.py' +enable_error_code = "redundant-self" +exclude = [ + 'xarray/util/generate_.*\.py', + 'xarray/datatree_/.*\.py', +] files = "xarray" show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_unused_configs = true warn_unused_ignores = true -# Most of the numerical computing stack doesn't have type annotations yet. +# Ignore mypy errors for modules imported from datatree_. +[[tool.mypy.overrides]] +module = "xarray.datatree_.*" +ignore_errors = true + +# Much of the numerical computing stack doesn't have type annotations yet. [[tool.mypy.overrides]] ignore_missing_imports = true module = [ "affine.*", "bottleneck.*", "cartopy.*", - "cdms2.*", "cf_units.*", "cfgrib.*", "cftime.*", + "cloudpickle.*", + "cubed.*", "cupy.*", + "dask.types.*", "fsspec.*", "h5netcdf.*", "h5py.*", - "importlib_metadata.*", "iris.*", "matplotlib.*", "mpl_toolkits.*", @@ -52,51 +126,184 @@ module = [ "numbagg.*", "netCDF4.*", "netcdftime.*", + "opt_einsum.*", "pandas.*", "pooch.*", - "PseudoNetCDF.*", "pydap.*", "pytest.*", - "rasterio.*", "scipy.*", "seaborn.*", "setuptools", "sparse.*", "toolz.*", "zarr.*", + "numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped + "array_api_strict.*", ] +# Gradually we want to add more modules to this list, ratcheting up our total +# coverage. Once a module is here, functions are checked by mypy regardless of +# whether they have type annotations. It would be especially useful to have test +# files listed here, because without them being checked, we don't have a great +# way of testing our annotations. [[tool.mypy.overrides]] -ignore_errors = true -module = [] +check_untyped_defs = true +module = [ + "xarray.core.accessor_dt", + "xarray.core.accessor_str", + "xarray.core.alignment", + "xarray.core.computation", + "xarray.core.rolling_exp", + "xarray.indexes.*", + "xarray.tests.*", +] +# This then excludes some modules from the above list. (So ideally we remove +# from here in time...) +[[tool.mypy.overrides]] +check_untyped_defs = false +module = [ + "xarray.tests.test_coarsen", + "xarray.tests.test_coding_times", + "xarray.tests.test_combine", + "xarray.tests.test_computation", + "xarray.tests.test_concat", + "xarray.tests.test_coordinates", + "xarray.tests.test_dask", + "xarray.tests.test_dataarray", + "xarray.tests.test_duck_array_ops", + "xarray.tests.test_groupby", + "xarray.tests.test_indexing", + "xarray.tests.test_merge", + "xarray.tests.test_missing", + "xarray.tests.test_parallelcompat", + "xarray.tests.test_plot", + "xarray.tests.test_sparse", + "xarray.tests.test_ufuncs", + "xarray.tests.test_units", + "xarray.tests.test_utils", + "xarray.tests.test_variable", + "xarray.tests.test_weighted", +] + +# Use strict = true whenever namedarray has become standalone. In the meantime +# don't forget to add all new files related to namedarray here: +# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +[[tool.mypy.overrides]] +# Start off with these +warn_unused_ignores = true + +# Getting these passing should be easy +strict_concatenate = true +strict_equality = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_any_generics = true +disallow_subclassing_any = true +disallow_untyped_decorators = true + +# These next few are various gradations of forcing use of type annotations +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] + +[tool.pyright] +# include = ["src"] +# exclude = ["**/node_modules", +# "**/__pycache__", +# "src/experimental", +# "src/typestubs" +# ] +# ignore = ["src/oldstuff"] +defineConstant = {DEBUG = true} +# stubPath = "src/stubs" +# venv = "env367" + +# Enabling this means that developers who have disabled the warning locally — +# because not all dependencies are installable — are overridden +# reportMissingImports = true +reportMissingTypeStubs = false + +# pythonVersion = "3.6" +# pythonPlatform = "Linux" + +# executionEnvironments = [ +# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, +# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, +# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, +# { root = "src" } +# ] [tool.ruff] -target-version = "py39" builtins = ["ellipsis"] -exclude = [ - ".eggs", - "doc", - "_typed_ops.pyi", +extend-exclude = [ + "doc", + "_typed_ops.pyi", ] +target-version = "py39" + +[tool.ruff.lint] # E402: module level import not at top of file # E501: line too long - let black worry about that # E731: do not assign a lambda expression, use a def ignore = [ - "E402", - "E501", - "E731", + "E402", + "E501", + "E731", ] select = [ - # Pyflakes - "F", - # Pycodestyle - "E", - "W", - # isort - "I", - # Pyupgrade - "UP", -] - -[tool.ruff.isort] + "F", # Pyflakes + "E", # Pycodestyle + "W", + "TID", # flake8-tidy-imports (absolute imports) + "I", # isort + "UP", # Pyupgrade +] +extend-safe-fixes = [ + "TID252", # absolute imports +] + +[tool.ruff.lint.per-file-ignores] +# don't enforce absolute imports +"asv_bench/**" = ["TID252"] + +[tool.ruff.lint.isort] known-first-party = ["xarray"] + +[tool.ruff.lint.flake8-tidy-imports] +# Disallow all relative imports. +ban-relative-imports = "all" + +[tool.pytest.ini_options] +addopts = ["--strict-config", "--strict-markers"] +filterwarnings = [ + "ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", +] +log_cli_level = "INFO" +markers = [ + "flaky: flaky tests", + "network: tests requiring a network connection", + "slow: slow tests", +] +minversion = "7" +python_files = "test_*.py" +testpaths = ["xarray/tests", "properties"] + +[tool.aliases] +test = "pytest" + +[tool.repo-review] +ignore = [ + "PP308", # This option creates a large amount of log lines. +] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 20638d267a7..00000000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# This file is redundant with setup.cfg; -# it exists to let GitHub build the repository dependency graph -# https://help.github.com/en/github/visualizing-repository-data-with-graphs/listing-the-packages-that-a-repository-depends-on - -numpy >= 1.21 -packaging >= 21.3 -pandas >= 1.4, <2 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 5d5cf161195..00000000000 --- a/setup.cfg +++ /dev/null @@ -1,151 +0,0 @@ -[metadata] -name = xarray -author = xarray Developers -author_email = xarray@googlegroups.com -license = Apache-2.0 -description = N-D labeled arrays and datasets in Python -long_description_content_type=text/x-rst -long_description = - **xarray** (formerly **xray**) is an open source project and Python package - that makes working with labelled multi-dimensional arrays simple, - efficient, and fun! - - xarray introduces labels in the form of dimensions, coordinates and - attributes on top of raw NumPy_-like arrays, which allows for a more - intuitive, more concise, and less error-prone developer experience. - The package includes a large and growing library of domain-agnostic functions - for advanced analytics and visualization with these data structures. - - xarray was inspired by and borrows heavily from pandas_, the popular data - analysis package focused on labelled tabular data. - It is particularly tailored to working with netCDF_ files, which were the - source of xarray's data model, and integrates tightly with dask_ for parallel - computing. - - .. _NumPy: https://www.numpy.org - .. _pandas: https://pandas.pydata.org - .. _dask: https://dask.org - .. _netCDF: https://www.unidata.ucar.edu/software/netcdf - - Why xarray? - ----------- - Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called - "tensors") are an essential part of computational science. - They are encountered in a wide range of fields, including physics, astronomy, - geoscience, bioinformatics, engineering, finance, and deep learning. - In Python, NumPy_ provides the fundamental data structure and API for - working with raw ND arrays. - However, real-world datasets are usually more than just raw numbers; - they have labels which encode information about how the array values map - to locations in space, time, etc. - - xarray doesn't just keep track of labels on arrays -- it uses them to provide a - powerful and concise interface. For example: - - - Apply operations over dimensions by name: ``x.sum('time')``. - - Select values by label instead of integer location: ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. - - Mathematical operations (e.g., ``x - y``) vectorize across multiple dimensions (array broadcasting) based on dimension names, not shape. - - Flexible split-apply-combine operations with groupby: ``x.groupby('time.dayofyear').mean()``. - - Database like alignment based on coordinate labels that smoothly handles missing values: ``x, y = xr.align(x, y, join='outer')``. - - Keep track of arbitrary metadata in the form of a Python dictionary: ``x.attrs``. - - Learn more - ---------- - - Documentation: ``_ - - Issue tracker: ``_ - - Source code: ``_ - - SciPy2015 talk: ``_ - -url = https://github.com/pydata/xarray -classifiers = - Development Status :: 5 - Production/Stable - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Intended Audience :: Science/Research - Programming Language :: Python - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Topic :: Scientific/Engineering - -[options] -packages = find: -zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html -include_package_data = True -python_requires = >=3.9 -install_requires = - numpy >= 1.21 # recommended to use >= 1.22 for full quantile method support - pandas >= 1.4, <2 - packaging >= 21.3 - -[options.extras_require] -io = - netCDF4 - h5netcdf - scipy - pydap; python_version<"3.10" # see https://github.com/pydap/pydap/issues/268 - zarr - fsspec - cftime - rasterio - pooch - ## Scitools packages & dependencies (e.g: cartopy, cf-units) can be hard to install - # scitools-iris - -accel = - scipy - bottleneck - numbagg - flox - -parallel = - dask[complete] - -viz = - matplotlib - seaborn - nc-time-axis - ## Cartopy requires 3rd party libraries and only provides source distributions - ## See: https://github.com/SciTools/cartopy/issues/805 - # cartopy - -complete = - %(io)s - %(accel)s - %(parallel)s - %(viz)s - -docs = - %(complete)s - sphinx-autosummary-accessors - sphinx_rtd_theme - ipython - ipykernel - jupyter-client - nbsphinx - scanpydoc - -[options.package_data] -xarray = - py.typed - tests/data/* - static/css/* - static/html/* - -[tool:pytest] -python_files = test_*.py -testpaths = xarray/tests properties -# Fixed upstream in https://github.com/pydata/bottleneck/pull/199 -filterwarnings = - ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning -markers = - flaky: flaky tests - network: tests requiring a network connection - slow: slow tests - -[aliases] -test = pytest - -[pytest-watch] -nobeep = True diff --git a/setup.py b/setup.py index 088d7e4eac6..69343515fd5 100755 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ #!/usr/bin/env python from setuptools import setup -setup(use_scm_version={"fallback_version": "999"}) +setup(use_scm_version={"fallback_version": "9999"}) diff --git a/xarray/__init__.py b/xarray/__init__.py index d064502c20b..0c0d5995f72 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,3 +1,5 @@ +from importlib.metadata import version as _version + from xarray import testing, tutorial from xarray.backends.api import ( load_dataarray, @@ -7,7 +9,6 @@ open_mfdataset, save_mfdataset, ) -from xarray.backends.rasterio_ import open_rasterio from xarray.backends.zarr import open_zarr from xarray.coding.cftime_offsets import cftime_range, date_range, date_range_like from xarray.coding.cftimeindex import CFTimeIndex @@ -27,30 +28,28 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.extensions import ( register_dataarray_accessor, register_dataset_accessor, ) +from xarray.core.indexes import Index +from xarray.core.indexing import IndexSelResult from xarray.core.merge import Context, MergeError, merge from xarray.core.options import get_options, set_options from xarray.core.parallel import map_blocks -from xarray.core.variable import Coordinate, IndexVariable, Variable, as_variable +from xarray.core.variable import IndexVariable, Variable, as_variable +from xarray.namedarray.core import NamedArray from xarray.util.print_versions import show_versions -try: - from importlib.metadata import version as _version -except ImportError: - # if the fallback library is missing, we are doomed. - from importlib_metadata import version as _version - try: __version__ = _version("xarray") except Exception: # Local copy or not installed with setuptools. # Disable minimum version checks on downstream libraries. - __version__ = "999" + __version__ = "9999" # A hardcoded __all__ variable is necessary to appease # `mypy --strict` running in projects that import xarray. @@ -85,7 +84,6 @@ "open_dataarray", "open_dataset", "open_mfdataset", - "open_rasterio", "open_zarr", "polyval", "register_dataarray_accessor", @@ -99,11 +97,14 @@ # Classes "CFTimeIndex", "Context", - "Coordinate", + "Coordinates", "DataArray", "Dataset", + "Index", + "IndexSelResult", "IndexVariable", "Variable", + "NamedArray", # Exceptions "MergeError", "SerializationWarning", diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index ca0b8fe4e6b..1c8d2d3a659 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -3,6 +3,7 @@ DataStores provide a uniform interface for saving and loading data in different formats. They should not be used directly, but rather through Dataset objects. """ + from xarray.backends.common import AbstractDataStore, BackendArray, BackendEntrypoint from xarray.backends.file_manager import ( CachingFileManager, @@ -12,11 +13,7 @@ from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint, H5NetCDFStore from xarray.backends.memory import InMemoryDataStore from xarray.backends.netCDF4_ import NetCDF4BackendEntrypoint, NetCDF4DataStore -from xarray.backends.plugins import list_engines -from xarray.backends.pseudonetcdf_ import ( - PseudoNetCDFBackendEntrypoint, - PseudoNetCDFDataStore, -) +from xarray.backends.plugins import list_engines, refresh_engines from xarray.backends.pydap_ import PydapBackendEntrypoint, PydapDataStore from xarray.backends.pynio_ import NioDataStore from xarray.backends.scipy_ import ScipyBackendEntrypoint, ScipyDataStore @@ -37,13 +34,12 @@ "ScipyDataStore", "H5NetCDFStore", "ZarrStore", - "PseudoNetCDFDataStore", "H5netcdfBackendEntrypoint", "NetCDF4BackendEntrypoint", - "PseudoNetCDFBackendEntrypoint", "PydapBackendEntrypoint", "ScipyBackendEntrypoint", "StoreBackendEntrypoint", "ZarrBackendEntrypoint", "list_engines", + "refresh_engines", ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index e5adedbb576..d3026a535e2 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -3,17 +3,31 @@ import os from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from functools import partial -from glob import glob from io import BytesIO from numbers import Number -from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Union, + cast, + overload, +) import numpy as np from xarray import backends, conventions from xarray.backends import plugins -from xarray.backends.common import AbstractDataStore, ArrayWriter, _normalize_path +from xarray.backends.common import ( + AbstractDataStore, + ArrayWriter, + _find_absolute_paths, + _normalize_path, +) from xarray.backends.locks import _get_scheduler +from xarray.backends.zarr import open_zarr from xarray.core import indexing from xarray.core.combine import ( _infer_concat_order_from_positions, @@ -23,7 +37,10 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk from xarray.core.indexes import Index +from xarray.core.types import ZarrWriteModes from xarray.core.utils import is_remote_uri +from xarray.namedarray.daskmanager import DaskManager +from xarray.namedarray.parallelcompat import guess_chunkmanager if TYPE_CHECKING: try: @@ -38,21 +55,21 @@ CompatOptions, JoinOptions, NestedSequence, + T_Chunks, ) T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ T_NetcdfEngine, - Literal["pydap", "pynio", "pseudonetcdf", "zarr"], + Literal["pydap", "pynio", "zarr"], type[BackendEntrypoint], str, # no nice typing support for custom backends None, ] - T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - + from xarray.datatree_.datatree import DataTree DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -63,7 +80,6 @@ "pydap": backends.PydapDataStore.open, "h5netcdf": backends.H5NetCDFStore.open, "pynio": backends.NioDataStore, - "pseudonetcdf": backends.PseudoNetCDFDataStore.open, "zarr": backends.ZarrStore.open_group, } @@ -297,17 +313,27 @@ def _chunk_ds( chunks, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ): - from dask.base import tokenize + chunkmanager = guess_chunkmanager(chunked_array_type) - mtime = _get_mtime(filename_or_obj) - token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) - name_prefix = f"open_dataset-{token}" + # TODO refactor to move this dask-specific logic inside the DaskManager class + if isinstance(chunkmanager, DaskManager): + from dask.base import tokenize + + mtime = _get_mtime(filename_or_obj) + token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + name_prefix = "open_dataset-" + else: + # not used + token = (None,) + name_prefix = None variables = {} for name, var in backend_ds.variables.items(): - var_chunks = _get_chunk(var, chunks) + var_chunks = _get_chunk(var, chunks, chunkmanager) variables[name] = _maybe_chunk( name, var, @@ -316,6 +342,8 @@ def _chunk_ds( name_prefix=name_prefix, token=token, inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), ) return backend_ds._replace(variables) @@ -328,6 +356,8 @@ def _dataset_from_backend_dataset( cache, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ): if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: @@ -346,6 +376,8 @@ def _dataset_from_backend_dataset( chunks, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, **extra_tokens, ) @@ -373,6 +405,8 @@ def open_dataset( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> Dataset: @@ -387,7 +421,7 @@ def open_dataset( scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "pseudonetcdf", "zarr", None}, installed backend \ + "zarr", None}, installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for @@ -419,8 +453,7 @@ def open_dataset( taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the - pseudonetcdf backend. This keyword may not be supported by all the backends. + be replaced by NA. This keyword may not be supported by all the backends. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. @@ -455,6 +488,9 @@ def open_dataset( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -465,6 +501,15 @@ def open_dataset( itself, and each chunk refers to that task by its key. With ``inline_array=True``, Dask will instead inline the array directly in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -478,7 +523,7 @@ def open_dataset( relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently active dask scheduler. Supported by "netcdf4", "h5netcdf", - "scipy", "pynio", "pseudonetcdf". + "scipy", "pynio". See engine open function for kwargs accepted by each specific engine. @@ -508,6 +553,9 @@ def open_dataset( if engine is None: engine = plugins.guess_engine(filename_or_obj) + if from_array_kwargs is None: + from_array_kwargs = {} + backend = plugins.get_backend(engine) decoders = _resolve_decoders_kwargs( @@ -536,6 +584,8 @@ def open_dataset( cache, overwrite_encoded_chunks, inline_array, + chunked_array_type, + from_array_kwargs, drop_variables=drop_variables, **decoders, **kwargs, @@ -546,8 +596,8 @@ def open_dataset( def open_dataarray( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, - engine: T_Engine = None, - chunks: T_Chunks = None, + engine: T_Engine | None = None, + chunks: T_Chunks | None = None, cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, @@ -558,6 +608,8 @@ def open_dataarray( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> DataArray: @@ -576,7 +628,7 @@ def open_dataarray( scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "pseudonetcdf", "zarr", None}, installed backend \ + "zarr", None}, installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for @@ -606,8 +658,7 @@ def open_dataarray( taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the - pseudonetcdf backend. This keyword may not be supported by all the backends. + be replaced by NA. This keyword may not be supported by all the backends. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. @@ -642,6 +693,9 @@ def open_dataarray( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -652,6 +706,15 @@ def open_dataarray( itself, and each chunk refers to that task by its key. With ``inline_array=True``, Dask will instead inline the array directly in the values of the task graph. See :py:func:`dask.array.from_array`. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example if :py:func:`dask.array.Array` objects are used for chunking, additional kwargs will be passed + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -665,7 +728,7 @@ def open_dataarray( relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently active dask scheduler. Supported by "netcdf4", "h5netcdf", - "scipy", "pynio", "pseudonetcdf". + "scipy", "pynio". See engine open function for kwargs accepted by each specific engine. @@ -695,6 +758,8 @@ def open_dataarray( cache=cache, drop_variables=drop_variables, inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, backend_kwargs=backend_kwargs, use_cftime=use_cftime, decode_timedelta=decode_timedelta, @@ -724,19 +789,49 @@ def open_dataarray( return data_array +def open_datatree( + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + engine: T_Engine = None, + **kwargs, +) -> DataTree: + """ + Open and decode a DataTree from a file or file-like object, creating one tree node for each group in the file. + + Parameters + ---------- + filename_or_obj : str, Path, file-like, or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + engine : str, optional + Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`. + **kwargs : dict + Additional keyword arguments passed to :py:func:`~xarray.open_dataset` for each group. + Returns + ------- + xarray.DataTree + """ + if engine is None: + engine = plugins.guess_engine(filename_or_obj) + + backend = plugins.get_backend(engine) + + return backend.open_datatree(filename_or_obj, **kwargs) + + def open_mfdataset( paths: str | NestedSequence[str | os.PathLike], - chunks: T_Chunks = None, - concat_dim: str - | DataArray - | Index - | Sequence[str] - | Sequence[DataArray] - | Sequence[Index] - | None = None, + chunks: T_Chunks | None = None, + concat_dim: ( + str + | DataArray + | Index + | Sequence[str] + | Sequence[DataArray] + | Sequence[Index] + | None + ) = None, compat: CompatOptions = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, - engine: T_Engine = None, + engine: T_Engine | None = None, data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords="different", combine: Literal["by_coords", "nested"] = "by_coords", @@ -803,7 +898,7 @@ def open_mfdataset( You can find the file-name from which each dataset was loaded in ``ds.encoding["source"]``. engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", \ - "pseudonetcdf", "zarr", None}, installed backend \ + "zarr", None}, installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for @@ -870,7 +965,9 @@ def open_mfdataset( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. **kwargs : optional - Additional arguments passed on to :py:func:`xarray.open_dataset`. + Additional arguments passed on to :py:func:`xarray.open_dataset`. For an + overview of some of the possible options, see the documentation of + :py:func:`xarray.open_dataset` Returns ------- @@ -905,43 +1002,20 @@ def open_mfdataset( ... "file_*.nc", concat_dim="time", preprocess=partial_func ... ) # doctest: +SKIP + It is also possible to use any argument to ``open_dataset`` together + with ``open_mfdataset``, such as for example ``drop_variables``: + + >>> ds = xr.open_mfdataset( + ... "file.nc", drop_variables=["varname_1", "varname_2"] # any list of vars + ... ) # doctest: +SKIP + References ---------- .. [1] https://docs.xarray.dev/en/stable/dask.html .. [2] https://docs.xarray.dev/en/stable/dask.html#chunking-and-performance """ - if isinstance(paths, str): - if is_remote_uri(paths) and engine == "zarr": - try: - from fsspec.core import get_fs_token_paths - except ImportError as e: - raise ImportError( - "The use of remote URLs for opening zarr requires the package fsspec" - ) from e - - fs, _, _ = get_fs_token_paths( - paths, - mode="rb", - storage_options=kwargs.get("backend_kwargs", {}).get( - "storage_options", {} - ), - expand=False, - ) - tmp_paths = fs.glob(fs._strip_protocol(paths)) # finds directories - paths = [fs.get_mapper(path) for path in tmp_paths] - elif is_remote_uri(paths): - raise ValueError( - "cannot do wild-card matching for paths that are remote URLs " - f"unless engine='zarr' is specified. Got paths: {paths}. " - "Instead, supply paths as an explicit list of strings." - ) - else: - paths = sorted(glob(_normalize_path(paths))) - elif isinstance(paths, os.PathLike): - paths = [os.fspath(paths)] - else: - paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in paths] + paths = _find_absolute_paths(paths, engine=engine, **kwargs) if not paths: raise OSError("no files to open") @@ -1017,8 +1091,8 @@ def open_mfdataset( ) else: raise ValueError( - "{} is an invalid option for the keyword argument" - " ``combine``".format(combine) + f"{combine} is an invalid option for the keyword argument" + " ``combine``" ) except ValueError: for ds in datasets: @@ -1058,8 +1132,7 @@ def to_netcdf( *, multifile: Literal[True], invalid_netcdf: bool = False, -) -> tuple[ArrayWriter, AbstractDataStore]: - ... +) -> tuple[ArrayWriter, AbstractDataStore]: ... # path=None writes to bytes @@ -1076,8 +1149,7 @@ def to_netcdf( compute: bool = True, multifile: Literal[False] = False, invalid_netcdf: bool = False, -) -> bytes: - ... +) -> bytes: ... # compute=False returns dask.Delayed @@ -1095,8 +1167,7 @@ def to_netcdf( compute: Literal[False], multifile: Literal[False] = False, invalid_netcdf: bool = False, -) -> Delayed: - ... +) -> Delayed: ... # default return None @@ -1113,8 +1184,60 @@ def to_netcdf( compute: Literal[True] = True, multifile: Literal[False] = False, invalid_netcdf: bool = False, -) -> None: - ... +) -> None: ... + + +# if compute cannot be evaluated at type check time +# we may get back either Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> Delayed | None: ... + + +# if multifile cannot be evaluated at type check time +# we may get back either writer and datastore or Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: ... + + +# Any +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: ... def to_netcdf( @@ -1334,15 +1457,15 @@ def save_mfdataset( >>> ds = xr.Dataset( ... {"a": ("time", np.linspace(0, 1, 48))}, - ... coords={"time": pd.date_range("2010-01-01", freq="M", periods=48)}, + ... coords={"time": pd.date_range("2010-01-01", freq="ME", periods=48)}, ... ) >>> ds - + Size: 768B Dimensions: (time: 48) Coordinates: - * time (time) datetime64[ns] 2010-01-31 2010-02-28 ... 2013-12-31 + * time (time) datetime64[ns] 384B 2010-01-31 2010-02-28 ... 2013-12-31 Data variables: - a (time) float64 0.0 0.02128 0.04255 0.06383 ... 0.9574 0.9787 1.0 + a (time) float64 384B 0.0 0.02128 0.04255 ... 0.9574 0.9787 1.0 >>> years, datasets = zip(*ds.groupby("time.year")) >>> paths = [f"{y}.nc" for y in years] >>> xr.save_mfdataset(datasets, paths) @@ -1401,10 +1524,63 @@ def save_mfdataset( ) -def _validate_region(ds, region): +def _auto_detect_region(ds_new, ds_orig, dim): + # Create a mapping array of coordinates to indices on the original array + coord = ds_orig[dim] + da_map = DataArray(np.arange(coord.size), coords={dim: coord}) + + try: + da_idxs = da_map.sel({dim: ds_new[dim]}) + except KeyError as e: + if "not all values found" in str(e): + raise KeyError( + f"Not all values of coordinate '{dim}' in the new array were" + " found in the original store. Writing to a zarr region slice" + " requires that no dimensions or metadata are changed by the write." + ) + else: + raise e + + if (da_idxs.diff(dim) != 1).any(): + raise ValueError( + f"The auto-detected region of coordinate '{dim}' for writing new data" + " to the original store had non-contiguous indices. Writing to a zarr" + " region slice requires that the new data constitute a contiguous subset" + " of the original store." + ) + + dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1) + + return dim_slice + + +def _auto_detect_regions(ds, region, open_kwargs): + ds_original = open_zarr(**open_kwargs) + for key, val in region.items(): + if val == "auto": + region[key] = _auto_detect_region(ds, ds_original, key) + return region + + +def _validate_and_autodetect_region( + ds, region, mode, open_kwargs +) -> tuple[dict[str, slice], bool]: + if region == "auto": + region = {dim: "auto" for dim in ds.dims} + if not isinstance(region, dict): raise TypeError(f"``region`` must be a dict, got {type(region)}") + if any(v == "auto" for v in region.values()): + region_was_autodetected = True + if mode != "r+": + raise ValueError( + f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" + ) + region = _auto_detect_regions(ds, region, open_kwargs) + else: + region_was_autodetected = False + for k, v in region.items(): if k not in ds.dims: raise ValueError( @@ -1436,6 +1612,8 @@ def _validate_region(ds, region): f".drop_vars({non_matching_vars!r})" ) + return region, region_was_autodetected + def _validate_datatypes_for_zarr_append(zstore, dataset): """If variable exists in the store, confirm dtype of the data to append is compatible with @@ -1479,19 +1657,21 @@ def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, -) -> backends.ZarrStore: - ... + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> backends.ZarrStore: ... # compute=False returns dask.Delayed @@ -1500,7 +1680,7 @@ def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, @@ -1508,29 +1688,33 @@ def to_zarr( compute: Literal[False], consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, -) -> Delayed: - ... + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, +) -> Delayed: ... def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> backends.ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr ztore @@ -1574,23 +1758,38 @@ def to_zarr( else: mode = "w-" - if mode != "a" and append_dim is not None: + if mode not in ["a", "a-"] and append_dim is not None: raise ValueError("cannot set append_dim unless mode='a' or mode=None") - if mode not in ["a", "r+"] and region is not None: - raise ValueError("cannot set region unless mode='a', mode='r+' or mode=None") + if mode not in ["a", "a-", "r+"] and region is not None: + raise ValueError( + "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" + ) - if mode not in ["w", "w-", "a", "r+"]: + if mode not in ["w", "w-", "a", "a-", "r+"]: raise ValueError( "The only supported options for mode are 'w', " - f"'w-', 'a' and 'r+', but mode={mode!r}" + f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}" ) # validate Dataset keys, DataArray names _validate_dataset_names(dataset) if region is not None: - _validate_region(dataset, region) + open_kwargs = dict( + store=store, + synchronizer=synchronizer, + group=group, + consolidated=consolidated, + storage_options=storage_options, + zarr_version=zarr_version, + ) + region, region_was_autodetected = _validate_and_autodetect_region( + dataset, region, mode, open_kwargs + ) + # drop indices to avoid potential race condition with auto region + if region_was_autodetected: + dataset = dataset.drop_vars(dataset.indexes) if append_dim is not None and append_dim in region: raise ValueError( f"cannot list the same dimension in both ``append_dim`` and " @@ -1623,9 +1822,10 @@ def to_zarr( safe_chunks=safe_chunks, stacklevel=4, # for Dataset.to_zarr() zarr_version=zarr_version, + write_empty=write_empty_chunks, ) - if mode in ["a", "r+"]: + if mode in ["a", "a-", "r+"]: _validate_datatypes_for_zarr_append(zstore, dataset) if append_dim is not None: existing_dims = zstore.get_dimensions() @@ -1652,7 +1852,9 @@ def to_zarr( writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims dump_to_store(dataset, zstore, writer, encoding=encoding) - writes = writer.sync(compute=compute) + writes = writer.sync( + compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) if compute: _finalize_store(writes, zstore) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 050493e3034..7d3cc00a52d 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -5,18 +5,27 @@ import time import traceback from collections.abc import Iterable +from glob import glob from typing import TYPE_CHECKING, Any, ClassVar import numpy as np from xarray.conventions import cf_encoder from xarray.core import indexing -from xarray.core.pycompat import is_duck_dask_array from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array if TYPE_CHECKING: from io import BufferedIOBase + from h5netcdf.legacyapi import Dataset as ncDatasetLegacyH5 + from netCDF4 import Dataset as ncDataset + + from xarray.core.dataset import Dataset + from xarray.core.types import NestedSequence + from xarray.datatree_.datatree import DataTree + # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -25,6 +34,24 @@ def _normalize_path(path): + """ + Normalize pathlikes to string. + + Parameters + ---------- + path : + Path to file. + + Examples + -------- + >>> from pathlib import Path + + >>> directory = Path(xr.backends.common.__file__).parent + >>> paths_path = Path(directory).joinpath("comm*n.py") + >>> paths_str = xr.backends.common._normalize_path(paths_path) + >>> print([type(p) for p in (paths_str,)]) + [] + """ if isinstance(path, os.PathLike): path = os.fspath(path) @@ -34,6 +61,64 @@ def _normalize_path(path): return path +def _find_absolute_paths( + paths: str | os.PathLike | NestedSequence[str | os.PathLike], **kwargs +) -> list[str]: + """ + Find absolute paths from the pattern. + + Parameters + ---------- + paths : + Path(s) to file(s). Can include wildcards like * . + **kwargs : + Extra kwargs. Mainly for fsspec. + + Examples + -------- + >>> from pathlib import Path + + >>> directory = Path(xr.backends.common.__file__).parent + >>> paths = str(Path(directory).joinpath("comm*n.py")) # Find common with wildcard + >>> paths = xr.backends.common._find_absolute_paths(paths) + >>> [Path(p).name for p in paths] + ['common.py'] + """ + if isinstance(paths, str): + if is_remote_uri(paths) and kwargs.get("engine", None) == "zarr": + try: + from fsspec.core import get_fs_token_paths + except ImportError as e: + raise ImportError( + "The use of remote URLs for opening zarr requires the package fsspec" + ) from e + + fs, _, _ = get_fs_token_paths( + paths, + mode="rb", + storage_options=kwargs.get("backend_kwargs", {}).get( + "storage_options", {} + ), + expand=False, + ) + tmp_paths = fs.glob(fs._strip_protocol(paths)) # finds directories + paths = [fs.get_mapper(path) for path in tmp_paths] + elif is_remote_uri(paths): + raise ValueError( + "cannot do wild-card matching for paths that are remote URLs " + f"unless engine='zarr' is specified. Got paths: {paths}. " + "Instead, supply paths as an explicit list of strings." + ) + else: + paths = sorted(glob(_normalize_path(paths))) + elif isinstance(paths, os.PathLike): + paths = [os.fspath(paths)] + else: + paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in paths] + + return paths + + def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -46,6 +131,43 @@ def _decode_variable_name(name): return name +def _open_datatree_netcdf( + ncDataset: ncDataset | ncDatasetLegacyH5, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs, +) -> DataTree: + from xarray.backends.api import open_dataset + from xarray.core.treenode import NodePath + from xarray.datatree_.datatree import DataTree + + ds = open_dataset(filename_or_obj, **kwargs) + tree_root = DataTree.from_dict({"/": ds}) + with ncDataset(filename_or_obj, mode="r") as ncds: + for path in _iter_nc_groups(ncds): + subgroup_ds = open_dataset(filename_or_obj, group=path, **kwargs) + + # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again + node_name = NodePath(path).name + new_node: DataTree = DataTree(name=node_name, data=subgroup_ds) + tree_root._set_item( + path, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + return tree_root + + +def _iter_nc_groups(root, parent="/"): + from xarray.core.treenode import NodePath + + parent = NodePath(parent) + for path, group in root.groups.items(): + gpath = parent / path + yield str(gpath) + yield from _iter_nc_groups(group, parent=gpath) + + def find_root_and_group(ds): """Find the root and group name of a netCDF4/h5netcdf dataset.""" hierarchy = () @@ -84,9 +206,9 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): __slots__ = () - def __array__(self, dtype=None): + def get_duck_array(self, dtype: np.typing.DTypeLike = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) - return np.asarray(self[key], dtype=dtype) + return self[key] # type: ignore [index] class AbstractDataStore: @@ -151,7 +273,7 @@ def __init__(self, lock=None): self.lock = lock def add(self, source, target, region=None): - if is_duck_dask_array(source): + if is_chunked_array(source): self.sources.append(source) self.targets.append(target) self.regions.append(region) @@ -161,21 +283,25 @@ def add(self, source, target, region=None): else: target[...] = source - def sync(self, compute=True): + def sync(self, compute=True, chunkmanager_store_kwargs=None): if self.sources: - import dask.array as da + chunkmanager = get_chunked_array_type(*self.sources) # TODO: consider wrapping targets with dask.delayed, if this makes - # for any discernible difference in perforance, e.g., + # for any discernible difference in performance, e.g., # targets = [dask.delayed(t) for t in self.targets] - delayed_store = da.store( + if chunkmanager_store_kwargs is None: + chunkmanager_store_kwargs = {} + + delayed_store = chunkmanager.store( self.sources, self.targets, lock=self.lock, compute=compute, flush=True, regions=self.regions, + **chunkmanager_store_kwargs, ) self.sources = [] self.targets = [] @@ -373,13 +499,15 @@ class BackendEntrypoint: - ``guess_can_open`` method: it shall return ``True`` if the backend is able to open ``filename_or_obj``, ``False`` otherwise. The implementation of this method is not mandatory. + - ``open_datatree`` method: it shall implement reading from file, variables + decoding and it returns an instance of :py:class:`~datatree.DataTree`. + It shall take in input at least ``filename_or_obj`` argument. The + implementation of this method is not mandatory. For more details see + . Attributes ---------- - available : bool, default: True - Indicate wether this backend is available given the installed packages. - The setting of this attribute is not mandatory. open_dataset_parameters : tuple, default: None A list of ``open_dataset`` method parameters. The setting of this attribute is not mandatory. @@ -391,8 +519,6 @@ class BackendEntrypoint: The setting of this attribute is not mandatory. """ - available: ClassVar[bool] = True - open_dataset_parameters: ClassVar[tuple | None] = None description: ClassVar[str] = "" url: ClassVar[str] = "" @@ -408,24 +534,37 @@ def __repr__(self) -> str: def open_dataset( self, filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, drop_variables: str | Iterable[str] | None = None, **kwargs: Any, - ): + ) -> Dataset: """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. """ - raise NotImplementedError + raise NotImplementedError() def guess_can_open( self, filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, - ): + ) -> bool: """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. """ return False + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs: Any, + ) -> DataTree: + """ + Backend open_datatree method used by Xarray in :py:func:`~xarray.open_datatree`. + """ + + raise NotImplementedError() + -BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {} +# mapping of engine name to (module name, BackendEntrypoint Class) +BACKEND_ENTRYPOINTS: dict[str, tuple[str | None, type[BackendEntrypoint]]] = {} diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 91fd15fcaa4..df901f9a1d9 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,5 +1,6 @@ from __future__ import annotations +import atexit import contextlib import io import threading @@ -289,6 +290,13 @@ def __repr__(self) -> str: ) +@atexit.register +def _remove_del_method(): + # We don't need to close unclosed files at program exit, and may not be able + # to, because Python is cleaning up imports / globals. + del CachingFileManager.__del__ + + class _RefCounter: """Class for keeping track of reference counts.""" diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index c4f75672173..b7c1b2a5f03 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -3,14 +3,15 @@ import functools import io import os - -from packaging.version import Version +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any from xarray.backends.common import ( BACKEND_ENTRYPOINTS, BackendEntrypoint, WritableCFDataStore, _normalize_path, + _open_datatree_netcdf, find_root_and_group, ) from xarray.backends.file_manager import CachingFileManager, DummyFileManager @@ -18,6 +19,7 @@ from xarray.backends.netCDF4_ import ( BaseNetCDF4Array, _encode_nc4_variable, + _ensure_no_forward_slash_in_name, _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group, @@ -27,12 +29,18 @@ from xarray.core.utils import ( FrozenDict, is_remote_uri, - module_available, read_magic_number_from_file, try_read_magic_number_from_file_or_path, ) from xarray.core.variable import Variable +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.datatree_.datatree import DataTree + class H5NetCDFArrayWrapper(BaseNetCDF4Array): def get_array(self, needs_lock=True): @@ -134,6 +142,8 @@ def open( invalid_netcdf=None, phony_dims=None, decode_vlen_strings=True, + driver=None, + driver_kwds=None, ): import h5netcdf @@ -155,7 +165,10 @@ def open( kwargs = { "invalid_netcdf": invalid_netcdf, "decode_vlen_strings": decode_vlen_strings, + "driver": driver, } + if driver_kwds is not None: + kwargs.update(driver_kwds) if phony_dims is not None: kwargs["phony_dims"] = phony_dims @@ -192,6 +205,8 @@ def open_store_variable(self, name, var): "fletcher32": var.fletcher32, "shuffle": var.shuffle, } + if var.chunks: + encoding["preferred_chunks"] = dict(zip(var.dimensions, var.chunks)) # Convert h5py-style compression options to NetCDF4-Python # style, if possible if var.compression == "gzip": @@ -225,30 +240,17 @@ def get_attrs(self): return FrozenDict(_read_attributes(self.ds)) def get_dimensions(self): - import h5netcdf - - if Version(h5netcdf.__version__) >= Version("0.14.0.dev0"): - return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) - else: - return self.ds.dimensions + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) def get_encoding(self): - import h5netcdf - - if Version(h5netcdf.__version__) >= Version("0.14.0.dev0"): - return { - "unlimited_dims": { - k for k, v in self.ds.dimensions.items() if v.isunlimited() - } - } - else: - return { - "unlimited_dims": { - k for k, v in self.ds.dimensions.items() if v is None - } + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() } + } def set_dimension(self, name, length, is_unlimited=False): + _ensure_no_forward_slash_in_name(name) if is_unlimited: self.ds.dimensions[name] = None self.ds.resize_dimension(name, length) @@ -266,19 +268,11 @@ def prepare_variable( ): import h5py + _ensure_no_forward_slash_in_name(name) attrs = variable.attrs.copy() dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding) fillvalue = attrs.pop("_FillValue", None) - if dtype is str and fillvalue is not None: - raise NotImplementedError( - "h5netcdf does not yet support setting a fill value for " - "variable-length strings " - "(https://github.com/h5netcdf/h5netcdf/issues/37). " - f"Either remove '_FillValue' from encoding on variable {name!r} " - "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." - ) if dtype is str: dtype = h5py.special_dtype(vlen=str) @@ -365,33 +359,34 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): backends.ScipyBackendEntrypoint """ - available = module_available("h5netcdf") description = ( "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray" ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None: return magic_number.startswith(b"\211HDF\r\n\032\n") - try: + if isinstance(filename_or_obj, (str, os.PathLike)): _, ext = os.path.splitext(filename_or_obj) - except TypeError: - return False + return ext in {".nc", ".nc4", ".cdf"} - return ext in {".nc", ".nc4", ".cdf"} + return False - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, format=None, @@ -400,7 +395,9 @@ def open_dataset( invalid_netcdf=None, phony_dims=None, decode_vlen_strings=True, - ): + driver=None, + driver_kwds=None, + ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = H5NetCDFStore.open( filename_or_obj, @@ -410,6 +407,8 @@ def open_dataset( invalid_netcdf=invalid_netcdf, phony_dims=phony_dims, decode_vlen_strings=decode_vlen_strings, + driver=driver, + driver_kwds=driver_kwds, ) store_entrypoint = StoreBackendEntrypoint() @@ -426,5 +425,14 @@ def open_dataset( ) return ds + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs, + ) -> DataTree: + from h5netcdf.legacyapi import Dataset as ncDataset + + return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs) + -BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint +BACKEND_ENTRYPOINTS["h5netcdf"] = ("h5netcdf", H5netcdfBackendEntrypoint) diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index bba12a29609..69cef309b45 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -2,15 +2,83 @@ import multiprocessing import threading +import uuid import weakref -from collections.abc import MutableMapping -from typing import Any - -try: - from dask.utils import SerializableLock -except ImportError: - # no need to worry about serializing the lock - SerializableLock = threading.Lock # type: ignore +from collections.abc import Hashable, MutableMapping +from typing import Any, ClassVar +from weakref import WeakValueDictionary + + +# SerializableLock is adapted from Dask: +# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224 +# Used under the terms of Dask's license, see licenses/DASK_LICENSE. +class SerializableLock: + """A Serializable per-process Lock + + This wraps a normal ``threading.Lock`` object and satisfies the same + interface. However, this lock can also be serialized and sent to different + processes. It will not block concurrent operations between processes (for + this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file`` + but will consistently deserialize into the same lock. + + So if we make a lock in one process:: + + lock = SerializableLock() + + And then send it over to another process multiple times:: + + bytes = pickle.dumps(lock) + a = pickle.loads(bytes) + b = pickle.loads(bytes) + + Then the deserialized objects will operate as though they were the same + lock, and collide as appropriate. + + This is useful for consistently protecting resources on a per-process + level. + + The creation of locks is itself not threadsafe. + """ + + _locks: ClassVar[WeakValueDictionary[Hashable, threading.Lock]] = ( + WeakValueDictionary() + ) + token: Hashable + lock: threading.Lock + + def __init__(self, token: Hashable | None = None): + self.token = token or str(uuid.uuid4()) + if self.token in SerializableLock._locks: + self.lock = SerializableLock._locks[self.token] + else: + self.lock = threading.Lock() + SerializableLock._locks[self.token] = self.lock + + def acquire(self, *args, **kwargs): + return self.lock.acquire(*args, **kwargs) + + def release(self, *args, **kwargs): + return self.lock.release(*args, **kwargs) + + def __enter__(self): + self.lock.__enter__() + + def __exit__(self, *args): + self.lock.__exit__(*args) + + def locked(self): + return self.lock.locked() + + def __getstate__(self): + return self.token + + def __setstate__(self, token): + self.__init__(token) + + def __str__(self): + return f"<{self.__class__.__name__}: {self.token}>" + + __repr__ = __str__ # Locks used by multiple backends. diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0c6e083158d..6720a67ae2f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -3,7 +3,9 @@ import functools import operator import os +from collections.abc import Iterable from contextlib import suppress +from typing import TYPE_CHECKING, Any import numpy as np @@ -14,6 +16,7 @@ BackendEntrypoint, WritableCFDataStore, _normalize_path, + _open_datatree_netcdf, find_root_and_group, robust_getitem, ) @@ -33,16 +36,21 @@ FrozenDict, close_on_error, is_remote_uri, - module_available, try_read_magic_number_from_path, ) from xarray.core.variable import Variable +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.datatree_.datatree import DataTree + # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. _endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} - NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) @@ -58,10 +66,12 @@ def __init__(self, variable_name, datastore): dtype = array.dtype if dtype is str: - # use object dtype because that's the only way in numpy to - # represent variable length strings; it also prevents automatic - # string concatenation via conventions.decode_cf_variable - dtype = np.dtype("O") + # use object dtype (with additional vlen string metadata) because that's + # the only way in numpy to represent variable length strings and to + # check vlen string dtype in further steps + # it also prevents automatic string concatenation via + # conventions.decode_cf_variable + dtype = coding.strings.create_vlen_dtype(str) self.dtype = dtype def __setitem__(self, key, value): @@ -132,7 +142,9 @@ def _check_encoding_dtype_is_vlen_string(dtype): ) -def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): +def _get_datatype( + var, nc_format="NETCDF4", raise_on_invalid_encoding=False +) -> np.dtype: if nc_format == "NETCDF4": return _nc4_dtype(var) if "dtype" in var.encoding: @@ -185,11 +197,20 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): return ds +def _ensure_no_forward_slash_in_name(name): + if "/" in name: + raise ValueError( + f"Forward slashes '/' are not allowed in variable and dimension names (got {name!r}). " + "Forward slashes are used as hierarchy-separators for " + "HDF5-based files ('netcdf4'/'h5netcdf')." + ) + + def _ensure_fill_value_valid(data, attributes): # work around for netCDF4/scipy issue where _FillValue has the wrong type: # https://github.com/Unidata/netcdf4-python/issues/271 if data.dtype.kind == "S" and "_FillValue" in attributes: - attributes["_FillValue"] = np.string_(attributes["_FillValue"]) + attributes["_FillValue"] = np.bytes_(attributes["_FillValue"]) def _force_native_endianness(var): @@ -216,13 +237,13 @@ def _force_native_endianness(var): def _extract_nc4_variable_encoding( - variable, + variable: Variable, raise_on_invalid=False, lsd_okay=True, h5py_okay=False, backend="netCDF4", unlimited_dims=None, -): +) -> dict[str, Any]: if unlimited_dims is None: unlimited_dims = () @@ -239,6 +260,12 @@ def _extract_nc4_variable_encoding( "_FillValue", "dtype", "compression", + "significant_digits", + "quantize_mode", + "blosc_shuffle", + "szip_coding", + "szip_pixels_per_block", + "endian", } if lsd_okay: valid_encodings.add("least_significant_digit") @@ -284,7 +311,7 @@ def _extract_nc4_variable_encoding( return encoding -def _is_list_of_strings(value): +def _is_list_of_strings(value) -> bool: arr = np.asarray(value) return arr.dtype.kind in ["U", "S"] and arr.size > 1 @@ -390,13 +417,25 @@ def _acquire(self, needs_lock=True): def ds(self): return self._acquire() - def open_store_variable(self, name, var): + def open_store_variable(self, name: str, var): + import netCDF4 + dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = {k: var.getncattr(k) for k in var.ncattrs()} + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + encoding: dict[str, Any] = {} + if isinstance(var.datatype, netCDF4.EnumType): + encoding["dtype"] = np.dtype( + data.dtype, + metadata={ + "enum": var.datatype.enum_dict, + "enum_name": var.datatype.name, + }, + ) + else: + encoding["dtype"] = var.dtype _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - encoding = {} filters = var.filters() if filters is not None: encoding.update(filters) @@ -408,6 +447,7 @@ def open_store_variable(self, name, var): else: encoding["contiguous"] = False encoding["chunksizes"] = tuple(chunking) + encoding["preferred_chunks"] = dict(zip(var.dimensions, chunking)) # TODO: figure out how to round-trip "endian-ness" without raising # warnings from netCDF4 # encoding['endian'] = var.endian() @@ -415,7 +455,6 @@ def open_store_variable(self, name, var): # save source so __repr__ can detect if it's local or not encoding["source"] = self._filename encoding["original_shape"] = var.shape - encoding["dtype"] = var.dtype return Variable(dimensions, data, attributes, encoding) @@ -438,6 +477,7 @@ def get_encoding(self): } def set_dimension(self, name, length, is_unlimited=False): + _ensure_no_forward_slash_in_name(name) dim_length = length if not is_unlimited else None self.ds.createDimension(name, size=dim_length) @@ -459,46 +499,44 @@ def encode_variable(self, variable): return variable def prepare_variable( - self, name, variable, check_encoding=False, unlimited_dims=None + self, name, variable: Variable, check_encoding=False, unlimited_dims=None ): + _ensure_no_forward_slash_in_name(name) + attrs = variable.attrs.copy() + fill_value = attrs.pop("_FillValue", None) datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding ) - attrs = variable.attrs.copy() - - fill_value = attrs.pop("_FillValue", None) - - if datatype is str and fill_value is not None: - raise NotImplementedError( - "netCDF4 does not yet support setting a fill value for " - "variable-length strings " - "(https://github.com/Unidata/netcdf4-python/issues/730). " - f"Either remove '_FillValue' from encoding on variable {name!r} " - "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." - ) - + # check enum metadata and use netCDF4.EnumType + if ( + (meta := np.dtype(datatype).metadata) + and (e_name := meta.get("enum_name")) + and (e_dict := meta.get("enum")) + ): + datatype = self._build_and_get_enum(name, datatype, e_name, e_dict) encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) - if name in self.ds.variables: nc4_var = self.ds.variables[name] else: - nc4_var = self.ds.createVariable( + default_args = dict( varname=name, datatype=datatype, dimensions=variable.dims, - zlib=encoding.get("zlib", False), - complevel=encoding.get("complevel", 4), - shuffle=encoding.get("shuffle", True), - fletcher32=encoding.get("fletcher32", False), - contiguous=encoding.get("contiguous", False), - chunksizes=encoding.get("chunksizes"), + zlib=False, + complevel=4, + shuffle=True, + fletcher32=False, + contiguous=False, + chunksizes=None, endian="native", - least_significant_digit=encoding.get("least_significant_digit"), + least_significant_digit=None, fill_value=fill_value, ) + default_args.update(encoding) + default_args.pop("_FillValue", None) + nc4_var = self.ds.createVariable(**default_args) nc4_var.setncatts(attrs) @@ -506,6 +544,33 @@ def prepare_variable( return target, variable.data + def _build_and_get_enum( + self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int] + ) -> Any: + """ + Add or get the netCDF4 Enum based on the dtype in encoding. + The return type should be ``netCDF4.EnumType``, + but we avoid importing netCDF4 globally for performances. + """ + if enum_name not in self.ds.enumtypes: + return self.ds.createEnumType( + dtype, + enum_name, + enum_dict, + ) + datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != enum_dict: + error_msg = ( + f"Cannot save variable `{var_name}` because an enum" + f" `{enum_name}` already exists in the Dataset but have" + " a different definition. To fix this error, make sure" + " each variable have a uniquely named enum in their" + " `encoding['dtype'].metadata` or, if they should share" + " the same enum type, make sure the enums are identical." + ) + raise ValueError(error_msg) + return datatype + def sync(self): self.ds.sync() @@ -517,7 +582,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the netCDF4 package. - It can open ".nc", ".nc4", ".cdf" files and will be choosen + It can open ".nc", ".nc4", ".cdf" files and will be chosen as default for these files. Additionally it can open valid HDF5 files, see @@ -535,33 +600,37 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): backends.ScipyBackendEntrypoint """ - available = module_available("netCDF4") description = ( "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray" ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True magic_number = try_read_magic_number_from_path(filename_or_obj) if magic_number is not None: # netcdf 3 or HDF5 return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) - try: + + if isinstance(filename_or_obj, (str, os.PathLike)): _, ext = os.path.splitext(filename_or_obj) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf"} + return ext in {".nc", ".nc4", ".cdf"} + + return False - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, group=None, @@ -572,7 +641,7 @@ def open_dataset( persist=False, lock=None, autoclose=False, - ): + ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( filename_or_obj, @@ -600,5 +669,14 @@ def open_dataset( ) return ds + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs, + ) -> DataTree: + from netCDF4 import Dataset as ncDataset + + return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs) + -BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint +BACKEND_ENTRYPOINTS["netcdf4"] = ("netCDF4", NetCDF4BackendEntrypoint) diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index ef389eefc90..70ddbdd1e01 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -42,6 +42,21 @@ # encode all strings as UTF-8 STRING_ENCODING = "utf-8" +COERCION_VALUE_ERROR = ( + "could not safely cast array from {dtype} to {new_dtype}. While it is not " + "always the case, a common reason for this is that xarray has deemed it " + "safest to encode np.datetime64[ns] or np.timedelta64[ns] values with " + "int64 values representing units of 'nanoseconds'. This is either due to " + "the fact that the times are known to require nanosecond precision for an " + "accurate round trip, or that the times are unknown prior to writing due " + "to being contained in a chunked array. Ways to work around this are " + "either to use a backend that supports writing int64 values, or to " + "manually specify the encoding['units'] and encoding['dtype'] (e.g. " + "'seconds since 1970-01-01' and np.dtype('int32')) on the time " + "variable(s) such that the times can be serialized in a netCDF3 file " + "(note that depending on the situation, however, this latter option may " + "result in an inaccurate round trip)." +) def coerce_nc3_dtype(arr): @@ -66,7 +81,7 @@ def coerce_nc3_dtype(arr): cast_arr = arr.astype(new_dtype) if not (cast_arr == arr).all(): raise ValueError( - f"could not safely cast array from dtype {dtype} to {new_dtype}" + COERCION_VALUE_ERROR.format(dtype=dtype, new_dtype=new_dtype) ) arr = cast_arr return arr @@ -88,13 +103,30 @@ def encode_nc3_attrs(attrs): return {k: encode_nc3_attr_value(v) for k, v in attrs.items()} +def _maybe_prepare_times(var): + # checks for integer-based time-like and + # replaces np.iinfo(np.int64).min with _FillValue or np.nan + # this keeps backwards compatibility + + data = var.data + if data.dtype.kind in "iu": + units = var.attrs.get("units", None) + if units is not None: + if coding.variables._is_time_like(units): + mask = data == np.iinfo(np.int64).min + if mask.any(): + data = np.where(mask, var.attrs.get("_FillValue", np.nan), data) + return data + + def encode_nc3_variable(var): for coder in [ coding.strings.EncodedStringCoder(allows_unicode=False), coding.strings.CharacterArrayCoder(), ]: var = coder.encode(var) - data = coerce_nc3_dtype(var.data) + data = _maybe_prepare_times(var) + data = coerce_nc3_dtype(data) attrs = encode_nc3_attrs(var.attrs) return Variable(var.dims, data, attrs, var.encoding) diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index d6ad6dfbe18..a62ca6c9862 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -6,12 +6,19 @@ import sys import warnings from importlib.metadata import entry_points -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint +from xarray.core.utils import module_available if TYPE_CHECKING: import os + from importlib.metadata import EntryPoint + + if sys.version_info >= (3, 10): + from importlib.metadata import EntryPoints + else: + EntryPoints = list[EntryPoint] from io import BufferedIOBase from xarray.backends.common import AbstractDataStore @@ -19,15 +26,15 @@ STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"] -def remove_duplicates(entrypoints): +def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]: # sort and group entrypoints by name - entrypoints = sorted(entrypoints, key=lambda ep: ep.name) - entrypoints_grouped = itertools.groupby(entrypoints, key=lambda ep: ep.name) + entrypoints_sorted = sorted(entrypoints, key=lambda ep: ep.name) + entrypoints_grouped = itertools.groupby(entrypoints_sorted, key=lambda ep: ep.name) # check if there are multiple entrypoints for the same name unique_entrypoints = [] - for name, matches in entrypoints_grouped: + for name, _matches in entrypoints_grouped: # remove equal entrypoints - matches = list(set(matches)) + matches = list(set(_matches)) unique_entrypoints.append(matches[0]) matches_len = len(matches) if matches_len > 1: @@ -42,7 +49,7 @@ def remove_duplicates(entrypoints): return unique_entrypoints -def detect_parameters(open_dataset): +def detect_parameters(open_dataset: Callable) -> tuple[str, ...]: signature = inspect.signature(open_dataset) parameters = signature.parameters parameters_list = [] @@ -60,7 +67,9 @@ def detect_parameters(open_dataset): return tuple(parameters_list) -def backends_dict_from_pkg(entrypoints): +def backends_dict_from_pkg( + entrypoints: list[EntryPoint], +) -> dict[str, type[BackendEntrypoint]]: backend_entrypoints = {} for entrypoint in entrypoints: name = entrypoint.name @@ -72,14 +81,18 @@ def backends_dict_from_pkg(entrypoints): return backend_entrypoints -def set_missing_parameters(backend_entrypoints): - for name, backend in backend_entrypoints.items(): +def set_missing_parameters( + backend_entrypoints: dict[str, type[BackendEntrypoint]] +) -> None: + for _, backend in backend_entrypoints.items(): if backend.open_dataset_parameters is None: open_dataset = backend.open_dataset backend.open_dataset_parameters = detect_parameters(open_dataset) -def sort_backends(backend_entrypoints): +def sort_backends( + backend_entrypoints: dict[str, type[BackendEntrypoint]] +) -> dict[str, type[BackendEntrypoint]]: ordered_backends_entrypoints = {} for be_name in STANDARD_BACKENDS_ORDER: if backend_entrypoints.get(be_name, None) is not None: @@ -90,13 +103,13 @@ def sort_backends(backend_entrypoints): return ordered_backends_entrypoints -def build_engines(entrypoints) -> dict[str, BackendEntrypoint]: - backend_entrypoints = {} - for backend_name, backend in BACKEND_ENTRYPOINTS.items(): - if backend.available: +def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]: + backend_entrypoints: dict[str, type[BackendEntrypoint]] = {} + for backend_name, (module_name, backend) in BACKEND_ENTRYPOINTS.items(): + if module_name is None or module_available(module_name): backend_entrypoints[backend_name] = backend - entrypoints = remove_duplicates(entrypoints) - external_backend_entrypoints = backends_dict_from_pkg(entrypoints) + entrypoints_unique = remove_duplicates(entrypoints) + external_backend_entrypoints = backends_dict_from_pkg(entrypoints_unique) backend_entrypoints.update(external_backend_entrypoints) backend_entrypoints = sort_backends(backend_entrypoints) set_missing_parameters(backend_entrypoints) @@ -122,13 +135,18 @@ def list_engines() -> dict[str, BackendEntrypoint]: if sys.version_info >= (3, 10): entrypoints = entry_points(group="xarray.backends") else: - entrypoints = entry_points().get("xarray.backends", ()) + entrypoints = entry_points().get("xarray.backends", []) return build_engines(entrypoints) +def refresh_engines() -> None: + """Refreshes the backend engines based on installed packages.""" + list_engines.cache_clear() + + def guess_engine( store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, -): +) -> str | type[BackendEntrypoint]: engines = list_engines() for engine, backend in engines.items(): @@ -141,7 +159,7 @@ def guess_engine( warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning) compatible_engines = [] - for engine, backend_cls in BACKEND_ENTRYPOINTS.items(): + for engine, (_, backend_cls) in BACKEND_ENTRYPOINTS.items(): try: backend = backend_cls() if backend.guess_can_open(store_spec): diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py deleted file mode 100644 index ae8f90e3a44..00000000000 --- a/xarray/backends/pseudonetcdf_.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from xarray.backends.common import ( - BACKEND_ENTRYPOINTS, - AbstractDataStore, - BackendArray, - BackendEntrypoint, - _normalize_path, -) -from xarray.backends.file_manager import CachingFileManager -from xarray.backends.locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock -from xarray.backends.store import StoreBackendEntrypoint -from xarray.core import indexing -from xarray.core.utils import Frozen, FrozenDict, close_on_error, module_available -from xarray.core.variable import Variable - -# psuedonetcdf can invoke netCDF libraries internally -PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) - - -class PncArrayWrapper(BackendArray): - def __init__(self, variable_name, datastore): - self.datastore = datastore - self.variable_name = variable_name - array = self.get_array() - self.shape = array.shape - self.dtype = np.dtype(array.dtype) - - def get_array(self, needs_lock=True): - ds = self.datastore._manager.acquire(needs_lock) - return ds.variables[self.variable_name] - - def __getitem__(self, key): - return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem - ) - - def _getitem(self, key): - with self.datastore.lock: - array = self.get_array(needs_lock=False) - return array[key] - - -class PseudoNetCDFDataStore(AbstractDataStore): - """Store for accessing datasets via PseudoNetCDF""" - - @classmethod - def open(cls, filename, lock=None, mode=None, **format_kwargs): - from PseudoNetCDF import pncopen - - keywords = {"kwargs": format_kwargs} - # only include mode if explicitly passed - if mode is not None: - keywords["mode"] = mode - - if lock is None: - lock = PNETCDF_LOCK - - manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) - return cls(manager, lock) - - def __init__(self, manager, lock=None): - self._manager = manager - self.lock = ensure_lock(lock) - - @property - def ds(self): - return self._manager.acquire() - - def open_store_variable(self, name, var): - data = indexing.LazilyIndexedArray(PncArrayWrapper(name, self)) - attrs = {k: getattr(var, k) for k in var.ncattrs()} - return Variable(var.dimensions, data, attrs) - - def get_variables(self): - return FrozenDict( - (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() - ) - - def get_attrs(self): - return Frozen({k: getattr(self.ds, k) for k in self.ds.ncattrs()}) - - def get_dimensions(self): - return Frozen(self.ds.dimensions) - - def get_encoding(self): - return { - "unlimited_dims": { - k for k in self.ds.dimensions if self.ds.dimensions[k].isunlimited() - } - } - - def close(self): - self._manager.close() - - -class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): - """ - Backend for netCDF-like data formats in the air quality field - based on the PseudoNetCDF package. - - It can open: - - CAMx - - RACM2 box-model outputs - - Kinetic Pre-Processor outputs - - ICARTT Data files (ffi1001) - - CMAQ Files - - GEOS-Chem Binary Punch/NetCDF files - - and many more - - This backend is not selected by default for any files, so make - sure to specify ``engine="pseudonetcdf"`` in ``open_dataset``. - - For more information about the underlying library, visit: - https://pseudonetcdf.readthedocs.io - - See Also - -------- - backends.PseudoNetCDFDataStore - """ - - available = module_available("PseudoNetCDF") - description = ( - "Open many atmospheric science data formats using PseudoNetCDF in Xarray" - ) - url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PseudoNetCDFBackendEntrypoint.html" - - # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, - # unless the open_dataset_parameters are explicitly defined like this: - open_dataset_parameters = ( - "filename_or_obj", - "mask_and_scale", - "decode_times", - "concat_characters", - "decode_coords", - "drop_variables", - "use_cftime", - "decode_timedelta", - "mode", - "lock", - ) - - def open_dataset( - self, - filename_or_obj, - mask_and_scale=False, - decode_times=True, - concat_characters=True, - decode_coords=True, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode=None, - lock=None, - **format_kwargs, - ): - filename_or_obj = _normalize_path(filename_or_obj) - store = PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, mode=mode, **format_kwargs - ) - - store_entrypoint = StoreBackendEntrypoint() - with close_on_error(store): - ds = store_entrypoint.open_dataset( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - return ds - - -BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index df26a03d790..5a475a7c3be 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,7 +1,9 @@ from __future__ import annotations +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + import numpy as np -from packaging.version import Version from xarray.backends.common import ( BACKEND_ENTRYPOINTS, @@ -12,16 +14,21 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing -from xarray.core.pycompat import integer_types from xarray.core.utils import ( Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri, - module_available, ) from xarray.core.variable import Variable +from xarray.namedarray.pycompat import integer_types + +if TYPE_CHECKING: + import os + from io import BufferedIOBase + + from xarray.core.dataset import Dataset class PydapArrayWrapper(BackendArray): @@ -46,6 +53,7 @@ def _getitem(self, key): # downloading coordinate data twice array = getattr(self.array, "array", self.array) result = robust_getitem(array, key, catch=ValueError) + result = np.asarray(result) # in some cases, pydap doesn't squeeze axes automatically like numpy axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) if result.ndim + len(axis) != array.ndim and axis: @@ -114,11 +122,10 @@ def open( "output_grid": output_grid or True, "timeout": timeout, } - if Version(pydap.lib.__version__) >= Version("3.3.0"): - if verify is not None: - kwargs.update({"verify": verify}) - if user_charset is not None: - kwargs.update({"user_charset": user_charset}) + if verify is not None: + kwargs.update({"verify": verify}) + if user_charset is not None: + kwargs.update({"user_charset": user_charset}) ds = pydap.client.open_url(**kwargs) return cls(ds) @@ -154,21 +161,24 @@ class PydapBackendEntrypoint(BackendEntrypoint): backends.PydapDataStore """ - available = module_available("pydap") description = "Open remote datasets via OPeNDAP using pydap in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj) - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, application=None, @@ -177,7 +187,7 @@ def open_dataset( timeout=None, verify=None, user_charset=None, - ): + ) -> Dataset: store = PydapDataStore.open( url=filename_or_obj, application=application, @@ -203,4 +213,4 @@ def open_dataset( return ds -BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint +BACKEND_ENTRYPOINTS["pydap"] = ("pydap", PydapBackendEntrypoint) diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 611ea978990..75e96ffdc0a 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,6 +1,8 @@ from __future__ import annotations import warnings +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np @@ -21,9 +23,15 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing -from xarray.core.utils import Frozen, FrozenDict, close_on_error, module_available +from xarray.core.utils import Frozen, FrozenDict, close_on_error from xarray.core.variable import Variable +if TYPE_CHECKING: + import os + from io import BufferedIOBase + + from xarray.core.dataset import Dataset + # PyNIO can invoke netCDF libraries internally # Add a dedicated lock just in case NCL as well isn't thread-safe. NCL_LOCK = SerializableLock() @@ -117,21 +125,20 @@ class PynioBackendEntrypoint(BackendEntrypoint): https://github.com/pydata/xarray/issues/4491 for more information """ - available = module_available("Nio") - - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, mode="r", lock=None, - ): + ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = NioDataStore( filename_or_obj, @@ -154,4 +161,4 @@ def open_dataset( return ds -BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint +BACKEND_ENTRYPOINTS["pynio"] = ("Nio", PynioBackendEntrypoint) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py deleted file mode 100644 index 15006dee5f1..00000000000 --- a/xarray/backends/rasterio_.py +++ /dev/null @@ -1,383 +0,0 @@ -from __future__ import annotations - -import os -import warnings - -import numpy as np - -from xarray.backends.common import BackendArray -from xarray.backends.file_manager import CachingFileManager -from xarray.backends.locks import SerializableLock -from xarray.core import indexing -from xarray.core.dataarray import DataArray -from xarray.core.utils import is_scalar - -# TODO: should this be GDAL_LOCK instead? -RASTERIO_LOCK = SerializableLock() - -_ERROR_MSG = ( - "The kind of indexing operation you are trying to do is not " - "valid on rasterio files. Try to load your data with ds.load()" - "first." -) - - -class RasterioArrayWrapper(BackendArray): - """A wrapper around rasterio dataset objects""" - - def __init__(self, manager, lock, vrt_params=None): - from rasterio.vrt import WarpedVRT - - self.manager = manager - self.lock = lock - - # cannot save riods as an attribute: this would break pickleability - riods = manager.acquire() - if vrt_params is not None: - riods = WarpedVRT(riods, **vrt_params) - self.vrt_params = vrt_params - self._shape = (riods.count, riods.height, riods.width) - - dtypes = riods.dtypes - if not np.all(np.asarray(dtypes) == dtypes[0]): - raise ValueError("All bands should have the same dtype") - self._dtype = np.dtype(dtypes[0]) - - @property - def dtype(self): - return self._dtype - - @property - def shape(self) -> tuple[int, ...]: - return self._shape - - def _get_indexer(self, key): - """Get indexer for rasterio array. - - Parameters - ---------- - key : tuple of int - - Returns - ------- - band_key: an indexer for the 1st dimension - window: two tuples. Each consists of (start, stop). - squeeze_axis: axes to be squeezed - np_ind: indexer for loaded numpy array - - See Also - -------- - indexing.decompose_indexer - """ - assert len(key) == 3, "rasterio datasets should always be 3D" - - # bands cannot be windowed but they can be listed - band_key = key[0] - np_inds = [] - # bands (axis=0) cannot be windowed but they can be listed - if isinstance(band_key, slice): - start, stop, step = band_key.indices(self.shape[0]) - band_key = np.arange(start, stop, step) - # be sure we give out a list - band_key = (np.asarray(band_key) + 1).tolist() - if isinstance(band_key, list): # if band_key is not a scalar - np_inds.append(slice(None)) - - # but other dims can only be windowed - window = [] - squeeze_axis = [] - for i, (k, n) in enumerate(zip(key[1:], self.shape[1:])): - if isinstance(k, slice): - # step is always positive. see indexing.decompose_indexer - start, stop, step = k.indices(n) - np_inds.append(slice(None, None, step)) - elif is_scalar(k): - # windowed operations will always return an array - # we will have to squeeze it later - squeeze_axis.append(-(2 - i)) - start = k - stop = k + 1 - else: - start, stop = np.min(k), np.max(k) + 1 - np_inds.append(k - start) - window.append((start, stop)) - - if isinstance(key[1], np.ndarray) and isinstance(key[2], np.ndarray): - # do outer-style indexing - np_inds[-2:] = np.ix_(*np_inds[-2:]) - - return band_key, tuple(window), tuple(squeeze_axis), tuple(np_inds) - - def _getitem(self, key): - from rasterio.vrt import WarpedVRT - - band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - - if not band_key or any(start == stop for (start, stop) in window): - # no need to do IO - shape = (len(band_key),) + tuple(stop - start for (start, stop) in window) - out = np.zeros(shape, dtype=self.dtype) - else: - with self.lock: - riods = self.manager.acquire(needs_lock=False) - if self.vrt_params is not None: - riods = WarpedVRT(riods, **self.vrt_params) - out = riods.read(band_key, window=window) - - if squeeze_axis: - out = np.squeeze(out, axis=squeeze_axis) - return out[np_inds] - - def __getitem__(self, key): - return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER, self._getitem - ) - - -def _parse_envi(meta): - """Parse ENVI metadata into Python data structures. - - See the link for information on the ENVI header file format: - http://www.harrisgeospatial.com/docs/enviheaderfiles.html - - Parameters - ---------- - meta : dict - Dictionary of keys and str values to parse, as returned by the rasterio - tags(ns='ENVI') call. - - Returns - ------- - parsed_meta : dict - Dictionary containing the original keys and the parsed values - - """ - - def parsevec(s): - return np.fromstring(s.strip("{}"), dtype="float", sep=",") - - def default(s): - return s.strip("{}") - - parse = {"wavelength": parsevec, "fwhm": parsevec} - parsed_meta = {k: parse.get(k, default)(v) for k, v in meta.items()} - return parsed_meta - - -def open_rasterio( - filename, - parse_coordinates=None, - chunks=None, - cache=None, - lock=None, - **kwargs, -): - """Open a file with rasterio. - - .. deprecated:: 0.20.0 - - Deprecated in favor of rioxarray. - For information about transitioning, see: - https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html - - This should work with any file that rasterio can open (most often: - geoTIFF). The x and y coordinates are generated automatically from the - file's geoinformation, shifted to the center of each pixel (see - `"PixelIsArea" Raster Space - `_ - for more information). - - Parameters - ---------- - filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT - Path to the file to open. Or already open rasterio dataset. - parse_coordinates : bool, optional - Whether to parse the x and y coordinates out of the file's - ``transform`` attribute or not. The default is to automatically - parse the coordinates only if they are rectilinear (1D). - It can be useful to set ``parse_coordinates=False`` - if your files are very large or if you don't need the coordinates. - chunks : int, tuple or dict, optional - Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or - ``{'x': 5, 'y': 5}``. If chunks is provided, it used to load the new - DataArray into a dask array. - cache : bool, optional - If True, cache data loaded from the underlying datastore in memory as - NumPy arrays when accessed to avoid reading from the underlying data- - store multiple times. Defaults to True unless you specify the `chunks` - argument to use dask, in which case it defaults to False. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used to avoid issues with concurrent access to the same file when using - dask's multithreaded backend. - - Returns - ------- - data : DataArray - The newly created DataArray. - """ - warnings.warn( - "open_rasterio is Deprecated in favor of rioxarray. " - "For information about transitioning, see: " - "https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html", - DeprecationWarning, - stacklevel=2, - ) - import rasterio - from rasterio.vrt import WarpedVRT - - vrt_params = None - if isinstance(filename, rasterio.io.DatasetReader): - filename = filename.name - elif isinstance(filename, rasterio.vrt.WarpedVRT): - vrt = filename - filename = vrt.src_dataset.name - vrt_params = dict( - src_crs=vrt.src_crs.to_string(), - crs=vrt.crs.to_string(), - resampling=vrt.resampling, - tolerance=vrt.tolerance, - src_nodata=vrt.src_nodata, - nodata=vrt.nodata, - width=vrt.width, - height=vrt.height, - src_transform=vrt.src_transform, - transform=vrt.transform, - dtype=vrt.working_dtype, - warp_extras=vrt.warp_extras, - ) - - if lock is None: - lock = RASTERIO_LOCK - - manager = CachingFileManager( - rasterio.open, - filename, - lock=lock, - mode="r", - kwargs=kwargs, - ) - riods = manager.acquire() - if vrt_params is not None: - riods = WarpedVRT(riods, **vrt_params) - - if cache is None: - cache = chunks is None - - coords = {} - - # Get bands - if riods.count < 1: - raise ValueError("Unknown dims") - coords["band"] = np.asarray(riods.indexes) - - # Get coordinates - if riods.transform.is_rectilinear: - # 1d coordinates - parse = True if parse_coordinates is None else parse_coordinates - if parse: - nx, ny = riods.width, riods.height - # xarray coordinates are pixel centered - x, _ = riods.transform * (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) - _, y = riods.transform * (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) - coords["y"] = y - coords["x"] = x - else: - # 2d coordinates - parse = False if (parse_coordinates is None) else parse_coordinates - if parse: - warnings.warn( - "The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, - stacklevel=3, - ) - - # Attributes - attrs = {} - # Affine transformation matrix (always available) - # This describes coefficients mapping pixel coordinates to CRS - # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see - # https://github.com/sgillies/affine) - attrs["transform"] = tuple(riods.transform)[:6] - if hasattr(riods, "crs") and riods.crs: - # CRS is a dict-like object specific to rasterio - # If CRS is not None, we convert it back to a PROJ4 string using - # rasterio itself - try: - attrs["crs"] = riods.crs.to_proj4() - except AttributeError: - attrs["crs"] = riods.crs.to_string() - if hasattr(riods, "res"): - # (width, height) tuple of pixels in units of CRS - attrs["res"] = riods.res - if hasattr(riods, "is_tiled"): - # Is the TIF tiled? (bool) - # We cast it to an int for netCDF compatibility - attrs["is_tiled"] = np.uint8(riods.is_tiled) - if hasattr(riods, "nodatavals"): - # The nodata values for the raster bands - attrs["nodatavals"] = tuple( - np.nan if nodataval is None else nodataval for nodataval in riods.nodatavals - ) - if hasattr(riods, "scales"): - # The scale values for the raster bands - attrs["scales"] = riods.scales - if hasattr(riods, "offsets"): - # The offset values for the raster bands - attrs["offsets"] = riods.offsets - if hasattr(riods, "descriptions") and any(riods.descriptions): - # Descriptions for each dataset band - attrs["descriptions"] = riods.descriptions - if hasattr(riods, "units") and any(riods.units): - # A list of units string for each dataset band - attrs["units"] = riods.units - - # Parse extra metadata from tags, if supported - parsers = {"ENVI": _parse_envi, "GTiff": lambda m: m} - - driver = riods.driver - if driver in parsers: - if driver == "GTiff": - meta = parsers[driver](riods.tags()) - else: - meta = parsers[driver](riods.tags(ns=driver)) - - for k, v in meta.items(): - # Add values as coordinates if they match the band count, - # as attributes otherwise - if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: - coords[k] = ("band", np.asarray(v)) - else: - attrs[k] = v - - data = indexing.LazilyIndexedArray(RasterioArrayWrapper(manager, lock, vrt_params)) - - # this lets you write arrays loaded with rasterio - data = indexing.CopyOnWriteArray(data) - if cache and chunks is None: - data = indexing.MemoryCachedArray(data) - - result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs) - - if chunks is not None: - from dask.base import tokenize - - # augment the token with the file modification time - try: - mtime = os.path.getmtime(os.path.expanduser(filename)) - except OSError: - # the filename is probably an s3 bucket rather than a regular file - mtime = None - token = tokenize(filename, mtime, chunks) - name_prefix = f"open_rasterio-{token}" - result = result.chunk(chunks, name_prefix=name_prefix, token=token) - - # Make the file closeable - result.set_close(manager.close) - - return result diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 651aebce2ce..f8c486e512c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -3,6 +3,8 @@ import gzip import io import os +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np @@ -21,7 +23,7 @@ is_valid_nc3_name, ) from xarray.backends.store import StoreBackendEntrypoint -from xarray.core.indexing import NumpyIndexingAdapter +from xarray.core import indexing from xarray.core.utils import ( Frozen, FrozenDict, @@ -31,6 +33,15 @@ ) from xarray.core.variable import Variable +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + + +HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0") + def _decode_string(s): if isinstance(s, bytes): @@ -56,12 +67,25 @@ def get_variable(self, needs_lock=True): ds = self.datastore._manager.acquire(needs_lock) return ds.variables[self.variable_name] + def _getitem(self, key): + with self.datastore.lock: + data = self.get_variable(needs_lock=False).data + return data[key] + def __getitem__(self, key): - data = NumpyIndexingAdapter(self.get_variable().data)[key] + data = indexing.explicit_indexing_adapter( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem + ) # Copy data if the source file is mmapped. This makes things consistent # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. copy = self.datastore.ds.use_mmap + + # adapt handling of copy-kwarg to numpy 2.0 + # see https://github.com/numpy/numpy/issues/25916 + # and https://github.com/numpy/numpy/pull/25922 + copy = None if HAS_NUMPY_2_0 and copy is False else copy + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): @@ -261,32 +285,35 @@ class ScipyBackendEntrypoint(BackendEntrypoint): backends.H5netcdfBackendEntrypoint """ - available = module_available("scipy") description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): - with gzip.open(filename_or_obj) as f: + with gzip.open(filename_or_obj) as f: # type: ignore[arg-type] magic_number = try_read_magic_number_from_file_or_path(f) if magic_number is not None: return magic_number.startswith(b"CDF") - try: + if isinstance(filename_or_obj, (str, os.PathLike)): _, ext = os.path.splitext(filename_or_obj) - except TypeError: - return False - return ext in {".nc", ".nc4", ".cdf", ".gz"} + return ext in {".nc", ".nc4", ".cdf", ".gz"} + + return False - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, mode="r", @@ -294,7 +321,7 @@ def open_dataset( group=None, mmap=None, lock=None, - ): + ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = ScipyDataStore( filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock @@ -315,4 +342,4 @@ def open_dataset( return ds -BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint +BACKEND_ENTRYPOINTS["scipy"] = ("scipy", ScipyBackendEntrypoint) diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 1f7a44bf4dc..a507ee37470 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -1,5 +1,8 @@ from __future__ import annotations +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + from xarray import conventions from xarray.backends.common import ( BACKEND_ENTRYPOINTS, @@ -8,29 +11,37 @@ ) from xarray.core.dataset import Dataset +if TYPE_CHECKING: + import os + from io import BufferedIOBase + class StoreBackendEntrypoint(BackendEntrypoint): - available = True description = "Open AbstractDataStore instances in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: return isinstance(filename_or_obj, AbstractDataStore) - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - store, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - ): - vars, attrs = store.load() - encoding = store.get_encoding() + ) -> Dataset: + assert isinstance(filename_or_obj, AbstractDataStore) + + vars, attrs = filename_or_obj.load() + encoding = filename_or_obj.get_encoding() vars, attrs, coord_names = conventions.decode_cf_variables( vars, @@ -46,10 +57,10 @@ def open_dataset( ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.intersection(vars)) - ds.set_close(store.close) + ds.set_close(filename_or_obj.close) ds.encoding = encoding return ds -BACKEND_ENTRYPOINTS["store"] = StoreBackendEntrypoint +BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index bc251d05631..e9465dc0ba0 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -3,6 +3,8 @@ import json import os import warnings +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np @@ -17,14 +19,23 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing -from xarray.core.pycompat import integer_types +from xarray.core.types import ZarrWriteModes from xarray.core.utils import ( FrozenDict, HiddenKeyDict, close_on_error, - module_available, ) from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import integer_types + +if TYPE_CHECKING: + from io import BufferedIOBase + + from xarray.backends.common import AbstractDataStore + from xarray.core.dataset import Dataset + from xarray.datatree_.datatree import DataTree + # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -52,37 +63,47 @@ def encode_zarr_attr_value(value): class ZarrArrayWrapper(BackendArray): - __slots__ = ("datastore", "dtype", "shape", "variable_name") - - def __init__(self, variable_name, datastore): - self.datastore = datastore - self.variable_name = variable_name - - array = self.get_array() - self.shape = array.shape + __slots__ = ("dtype", "shape", "_array") + + def __init__(self, zarr_array): + # some callers attempt to evaluate an array if an `array` property exists on the object. + # we prefix with _ to avoid this inference. + self._array = zarr_array + self.shape = self._array.shape + + # preserve vlen string object dtype (GH 7328) + if self._array.filters is not None and any( + [filt.codec_id == "vlen-utf8" for filt in self._array.filters] + ): + dtype = coding.strings.create_vlen_dtype(str) + else: + dtype = self._array.dtype - dtype = array.dtype self.dtype = dtype def get_array(self): - return self.datastore.zarr_group[self.variable_name] + return self._array def _oindex(self, key): - return self.get_array().oindex[key] + return self._array.oindex[key] + + def _vindex(self, key): + return self._array.vindex[key] + + def _getitem(self, key): + return self._array[key] def __getitem__(self, key): - array = self.get_array() + array = self._array if isinstance(key, indexing.BasicIndexer): - return array[key.tuple] + method = self._getitem elif isinstance(key, indexing.VectorizedIndexer): - return array.vindex[ - indexing._arrayize_vectorized_indexer(key, self.shape).tuple - ] - else: - assert isinstance(key, indexing.OuterIndexer) - return indexing.explicit_indexing_adapter( - key, array.shape, indexing.IndexingSupport.VECTORIZED, self._oindex - ) + method = self._vindex + elif isinstance(key, indexing.OuterIndexer): + method = self._oindex + return indexing.explicit_indexing_adapter( + key, array.shape, indexing.IndexingSupport.VECTORIZED, method + ) # if self.ndim == 0: # could possibly have a work-around for 0d data here @@ -159,8 +180,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk # this avoids the need to get involved in zarr synchronization / locking # From zarr docs: - # "If each worker in a parallel computation is writing to a separate - # region of the array, and if region boundaries are perfectly aligned + # "If each worker in a parallel computation is writing to a + # separate region of the array, and if region boundaries are perfectly aligned # with chunk boundaries, then no synchronization is required." # TODO: incorporate synchronizer to allow writes from multiple dask # threads @@ -302,14 +323,19 @@ def encode_zarr_variable(var, needs_copy=True, name=None): return var -def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim): +def _validate_and_transpose_existing_dims( + var_name, new_var, existing_var, region, append_dim +): if new_var.dims != existing_var.dims: - raise ValueError( - f"variable {var_name!r} already exists with different " - f"dimension names {existing_var.dims} != " - f"{new_var.dims}, but changing variable " - f"dimensions is not supported by to_zarr()." - ) + if set(existing_var.dims) == set(new_var.dims): + new_var = new_var.transpose(*existing_var.dims) + else: + raise ValueError( + f"variable {var_name!r} already exists with different " + f"dimension names {existing_var.dims} != " + f"{new_var.dims}, but changing variable " + f"dimensions is not supported by to_zarr()." + ) existing_sizes = {} for dim, size in existing_var.sizes.items(): @@ -326,9 +352,14 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim) f"variable {var_name!r} already exists with different " f"dimension sizes: {existing_sizes} != {new_sizes}. " f"to_zarr() only supports changing dimension sizes when " - f"explicitly appending, but append_dim={append_dim!r}." + f"explicitly appending, but append_dim={append_dim!r}. " + f"If you are attempting to write to a subset of the " + f"existing store without changing dimension sizes, " + f"consider using the region argument in to_zarr()." ) + return new_var + def _put_attrs(zarr_obj, attrs): """Raise a more informative error message for invalid attrs.""" @@ -352,13 +383,15 @@ class ZarrStore(AbstractWritableDataStore): "_synchronizer", "_write_region", "_safe_chunks", + "_write_empty", + "_close_store_on_close", ) @classmethod def open_group( cls, store, - mode="r", + mode: ZarrWriteModes = "r", synchronizer=None, group=None, consolidated=False, @@ -370,6 +403,7 @@ def open_group( safe_chunks=True, stacklevel=2, zarr_version=None, + write_empty: bool | None = None, ): import zarr @@ -382,7 +416,8 @@ def open_group( zarr_version = getattr(store, "_store_version", 2) open_kwargs = dict( - mode=mode, + # mode='a-' is a handcrafted xarray specialty + mode="a" if mode == "a-" else mode, synchronizer=synchronizer, path=group, ) @@ -401,7 +436,7 @@ def open_group( if consolidated is None: consolidated = False - if chunk_store: + if chunk_store is not None: open_kwargs["chunk_store"] = chunk_store if consolidated is None: consolidated = False @@ -434,6 +469,7 @@ def open_group( zarr_group = zarr.open_consolidated(store, **open_kwargs) else: zarr_group = zarr.open_group(store, **open_kwargs) + close_store_on_close = zarr_group.store is not store return cls( zarr_group, mode, @@ -441,6 +477,8 @@ def open_group( append_dim, write_region, safe_chunks, + write_empty, + close_store_on_close, ) def __init__( @@ -451,6 +489,8 @@ def __init__( append_dim=None, write_region=None, safe_chunks=True, + write_empty: bool | None = None, + close_store_on_close: bool = False, ): self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only @@ -461,6 +501,8 @@ def __init__( self._append_dim = append_dim self._write_region = write_region self._safe_chunks = safe_chunks + self._write_empty = write_empty + self._close_store_on_close = close_store_on_close @property def ds(self): @@ -468,7 +510,7 @@ def ds(self): return self.zarr_group def open_store_variable(self, name, zarr_array): - data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( zarr_array, DIMENSION_KEY, try_nczarr @@ -566,8 +608,9 @@ def store( """ import zarr + existing_keys = tuple(self.zarr_group.array_keys()) existing_variable_names = { - vn for vn in variables if _encode_variable_name(vn) in self.zarr_group + vn for vn in variables if _encode_variable_name(vn) in existing_keys } new_variables = set(variables) - existing_variable_names variables_without_encoding = {vn: variables[vn] for vn in new_variables} @@ -591,12 +634,10 @@ def store( variables_encoded.update(vars_with_encoding) for var_name in existing_variable_names: - new_var = variables_encoded[var_name] - existing_var = existing_vars[var_name] - _validate_existing_dims( + variables_encoded[var_name] = _validate_and_transpose_existing_dims( var_name, - new_var, - existing_var, + variables_encoded[var_name], + existing_vars[var_name], self._write_region, self._append_dim, ) @@ -605,8 +646,21 @@ def store( self.set_attributes(attributes) self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims) + # if we are appending to an append_dim, only write either + # - new variables not already present, OR + # - variables with the append_dim in their dimensions + # We do NOT overwrite other variables. + if self._mode == "a-" and self._append_dim is not None: + variables_to_set = { + k: v + for k, v in variables_encoded.items() + if (k not in existing_variable_names) or (self._append_dim in v.dims) + } + else: + variables_to_set = variables_encoded + self.set_variables( - variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims + variables_to_set, check_encoding_set, writer, unlimited_dims=unlimited_dims ) if self._consolidate_on_close: zarr.consolidate_metadata(self.zarr_group.store) @@ -632,6 +686,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No dimensions. """ + import zarr + + existing_keys = tuple(self.zarr_group.array_keys()) + for vn, v in variables.items(): name = _encode_variable_name(vn) check = vn in check_encoding_set @@ -644,12 +702,34 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} - if name in self.zarr_group: + if name in existing_keys: # existing variable # TODO: if mode="a", consider overriding the existing variable # metadata. This would need some case work properly with region # and append_dim. - zarr_array = self.zarr_group[name] + if self._write_empty is not None: + # Write to zarr_group.chunk_store instead of zarr_group.store + # See https://github.com/pydata/xarray/pull/8326#discussion_r1365311316 for a longer explanation + # The open_consolidated() enforces a mode of r or r+ + # (and to_zarr with region provided enforces a read mode of r+), + # and this function makes sure the resulting Group has a store of type ConsolidatedMetadataStore + # and a 'normal Store subtype for chunk_store. + # The exact type depends on if a local path was used, or a URL of some sort, + # but the point is that it's not a read-only ConsolidatedMetadataStore. + # It is safe to write chunk data to the chunk_store because no metadata would be changed by + # to_zarr with the region parameter: + # - Because the write mode is enforced to be r+, no new variables can be added to the store + # (this is also checked and enforced in xarray.backends.api.py::to_zarr()). + # - Existing variables already have their attrs included in the consolidated metadata file. + # - The size of dimensions can not be expanded, that would require a call using `append_dim` + # which is mutually exclusive with `region` + zarr_array = zarr.open( + store=self.zarr_group.chunk_store, + path=f"{self.zarr_group.name}/{name}", + write_empty_chunks=self._write_empty, + ) + else: + zarr_array = self.zarr_group[name] else: # new variable encoding = extract_zarr_variable_encoding( @@ -663,8 +743,25 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if coding.strings.check_vlen_dtype(dtype) == str: dtype = str + + if self._write_empty is not None: + if ( + "write_empty_chunks" in encoding + and encoding["write_empty_chunks"] != self._write_empty + ): + raise ValueError( + 'Differing "write_empty_chunks" values in encoding and parameters' + f'Got {encoding["write_empty_chunks"] = } and {self._write_empty = }' + ) + else: + encoding["write_empty_chunks"] = self._write_empty + zarr_array = self.zarr_group.create( - name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding + name, + shape=shape, + dtype=dtype, + fill_value=fill_value, + **encoding, ) zarr_array = _put_attrs(zarr_array, encoded_attrs) @@ -687,7 +784,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No writer.add(v.data, zarr_array, region) def close(self): - pass + if self._close_store_on_close: + self.zarr_group.store.close() def open_zarr( @@ -708,6 +806,8 @@ def open_zarr( decode_timedelta=None, use_cftime=None, zarr_version=None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, **kwargs, ): """Load and decode a dataset from a Zarr store. @@ -792,6 +892,15 @@ def open_zarr( The desired zarr spec version to target (currently 2 or 3). The default of None will attempt to determine the zarr version from ``store`` when possible, otherwise defaulting to 2. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + Defaults to {'manager': 'dask'}, meaning additional kwargs will be passed eventually to + :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. Returns ------- @@ -809,12 +918,17 @@ def open_zarr( """ from xarray.backends.api import open_dataset + if from_array_kwargs is None: + from_array_kwargs = {} + if chunks == "auto": try: - import dask.array # noqa + guess_chunkmanager( + chunked_array_type + ) # attempt to import that parallel backend chunks = {} - except ImportError: + except ValueError: chunks = None if kwargs: @@ -843,6 +957,8 @@ def open_zarr( engine="zarr", chunks=chunks, drop_variables=drop_variables, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, backend_kwargs=backend_kwargs, decode_timedelta=decode_timedelta, use_cftime=use_cftime, @@ -863,25 +979,28 @@ class ZarrBackendEntrypoint(BackendEntrypoint): backends.ZarrStore """ - available = module_available("zarr") description = "Open zarr files (.zarr) using zarr in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html" - def guess_can_open(self, filename_or_obj): - try: + def guess_can_open( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + ) -> bool: + if isinstance(filename_or_obj, (str, os.PathLike)): _, ext = os.path.splitext(filename_or_obj) - except TypeError: - return False - return ext in {".zarr"} + return ext in {".zarr"} + + return False - def open_dataset( + def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs self, - filename_or_obj, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, - drop_variables=None, + drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, group=None, @@ -892,7 +1011,7 @@ def open_dataset( storage_options=None, stacklevel=3, zarr_version=None, - ): + ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = ZarrStore.open_group( filename_or_obj, @@ -921,5 +1040,48 @@ def open_dataset( ) return ds + def open_datatree( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + **kwargs, + ) -> DataTree: + import zarr + + from xarray.backends.api import open_dataset + from xarray.core.treenode import NodePath + from xarray.datatree_.datatree import DataTree + + zds = zarr.open_group(filename_or_obj, mode="r") + ds = open_dataset(filename_or_obj, engine="zarr", **kwargs) + tree_root = DataTree.from_dict({"/": ds}) + for path in _iter_zarr_groups(zds): + try: + subgroup_ds = open_dataset( + filename_or_obj, engine="zarr", group=path, **kwargs + ) + except zarr.errors.PathNotFoundError: + subgroup_ds = Dataset() + + # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again + node_name = NodePath(path).name + new_node: DataTree = DataTree(name=node_name, data=subgroup_ds) + tree_root._set_item( + path, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + return tree_root + + +def _iter_zarr_groups(root, parent="/"): + from xarray.core.treenode import NodePath + + parent = NodePath(parent) + for path, group in root.groups(): + gpath = parent / path + yield str(gpath) + yield from _iter_zarr_groups(group, parent=gpath) + -BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint +BACKEND_ENTRYPOINTS["zarr"] = ("zarr", ZarrBackendEntrypoint) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 792724ecc79..2e594455874 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -1,4 +1,5 @@ """Time offset classes for use with cftime.datetime objects""" + # The offset classes and mechanisms for generating time ranges defined in # this module were copied/adapted from those defined in pandas. See in # particular the objects and methods defined in pandas.tseries.offsets @@ -48,6 +49,7 @@ import numpy as np import pandas as pd +from packaging.version import Version from xarray.coding.cftimeindex import CFTimeIndex, _parse_iso8601_with_reso from xarray.coding.times import ( @@ -57,7 +59,12 @@ format_cftime_datetime, ) from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like -from xarray.core.pdcompat import NoDefault, count_not_none, no_default +from xarray.core.pdcompat import ( + NoDefault, + count_not_none, + nanosecond_precision_timestamp, + no_default, +) from xarray.core.utils import emit_user_level_warning try: @@ -100,7 +107,7 @@ def __init__(self, n: int = 1): if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. " - "Instead a value of type {!r} was provided.".format(type(n)) + f"Instead a value of type {type(n)!r} was provided." ) self.n = n @@ -348,13 +355,13 @@ def _validate_month(month, default_month): raise TypeError( "'self.month' must be an integer value between 1 " "and 12. Instead, it was set to a value of " - "{!r}".format(result_month) + f"{result_month!r}" ) elif not (1 <= result_month <= 12): raise ValueError( "'self.month' must be an integer value between 1 " "and 12. Instead, it was set to a value of " - "{!r}".format(result_month) + f"{result_month!r}" ) return result_month @@ -373,7 +380,7 @@ def onOffset(self, date): class MonthEnd(BaseCFTimeOffset): - _freq = "M" + _freq = "ME" def __apply__(self, other): n = _adjust_n_months(other.day, self.n, _days_in_month(other)) @@ -485,7 +492,7 @@ class QuarterEnd(QuarterOffset): # from the constructor, however, the default month is March. # We follow that behavior here. _default_month = 3 - _freq = "Q" + _freq = "QE" _day_option = "end" def rollforward(self, date): @@ -542,7 +549,7 @@ def __str__(self): class YearBegin(YearOffset): - _freq = "AS" + _freq = "YS" _day_option = "start" _default_month = 1 @@ -567,7 +574,7 @@ def rollback(self, date): class YearEnd(YearOffset): - _freq = "A" + _freq = "YE" _day_option = "end" _default_month = 12 @@ -602,7 +609,7 @@ def __apply__(self, other): class Hour(Tick): - _freq = "H" + _freq = "h" def as_timedelta(self): return timedelta(hours=self.n) @@ -612,7 +619,7 @@ def __apply__(self, other): class Minute(Tick): - _freq = "T" + _freq = "min" def as_timedelta(self): return timedelta(minutes=self.n) @@ -622,7 +629,7 @@ def __apply__(self, other): class Second(Tick): - _freq = "S" + _freq = "s" def as_timedelta(self): return timedelta(seconds=self.n) @@ -632,7 +639,7 @@ def __apply__(self, other): class Millisecond(Tick): - _freq = "L" + _freq = "ms" def as_timedelta(self): return timedelta(milliseconds=self.n) @@ -642,7 +649,7 @@ def __apply__(self, other): class Microsecond(Tick): - _freq = "U" + _freq = "us" def as_timedelta(self): return timedelta(microseconds=self.n) @@ -651,77 +658,50 @@ def __apply__(self, other): return other + self.as_timedelta() +def _generate_anchored_offsets(base_freq, offset): + offsets = {} + for month, abbreviation in _MONTH_ABBREVIATIONS.items(): + anchored_freq = f"{base_freq}-{abbreviation}" + offsets[anchored_freq] = partial(offset, month=month) + return offsets + + _FREQUENCIES = { "A": YearEnd, "AS": YearBegin, "Y": YearEnd, + "YE": YearEnd, "YS": YearBegin, "Q": partial(QuarterEnd, month=12), + "QE": partial(QuarterEnd, month=12), "QS": partial(QuarterBegin, month=1), "M": MonthEnd, + "ME": MonthEnd, "MS": MonthBegin, "D": Day, "H": Hour, + "h": Hour, "T": Minute, "min": Minute, "S": Second, + "s": Second, "L": Millisecond, "ms": Millisecond, "U": Microsecond, "us": Microsecond, - "AS-JAN": partial(YearBegin, month=1), - "AS-FEB": partial(YearBegin, month=2), - "AS-MAR": partial(YearBegin, month=3), - "AS-APR": partial(YearBegin, month=4), - "AS-MAY": partial(YearBegin, month=5), - "AS-JUN": partial(YearBegin, month=6), - "AS-JUL": partial(YearBegin, month=7), - "AS-AUG": partial(YearBegin, month=8), - "AS-SEP": partial(YearBegin, month=9), - "AS-OCT": partial(YearBegin, month=10), - "AS-NOV": partial(YearBegin, month=11), - "AS-DEC": partial(YearBegin, month=12), - "A-JAN": partial(YearEnd, month=1), - "A-FEB": partial(YearEnd, month=2), - "A-MAR": partial(YearEnd, month=3), - "A-APR": partial(YearEnd, month=4), - "A-MAY": partial(YearEnd, month=5), - "A-JUN": partial(YearEnd, month=6), - "A-JUL": partial(YearEnd, month=7), - "A-AUG": partial(YearEnd, month=8), - "A-SEP": partial(YearEnd, month=9), - "A-OCT": partial(YearEnd, month=10), - "A-NOV": partial(YearEnd, month=11), - "A-DEC": partial(YearEnd, month=12), - "QS-JAN": partial(QuarterBegin, month=1), - "QS-FEB": partial(QuarterBegin, month=2), - "QS-MAR": partial(QuarterBegin, month=3), - "QS-APR": partial(QuarterBegin, month=4), - "QS-MAY": partial(QuarterBegin, month=5), - "QS-JUN": partial(QuarterBegin, month=6), - "QS-JUL": partial(QuarterBegin, month=7), - "QS-AUG": partial(QuarterBegin, month=8), - "QS-SEP": partial(QuarterBegin, month=9), - "QS-OCT": partial(QuarterBegin, month=10), - "QS-NOV": partial(QuarterBegin, month=11), - "QS-DEC": partial(QuarterBegin, month=12), - "Q-JAN": partial(QuarterEnd, month=1), - "Q-FEB": partial(QuarterEnd, month=2), - "Q-MAR": partial(QuarterEnd, month=3), - "Q-APR": partial(QuarterEnd, month=4), - "Q-MAY": partial(QuarterEnd, month=5), - "Q-JUN": partial(QuarterEnd, month=6), - "Q-JUL": partial(QuarterEnd, month=7), - "Q-AUG": partial(QuarterEnd, month=8), - "Q-SEP": partial(QuarterEnd, month=9), - "Q-OCT": partial(QuarterEnd, month=10), - "Q-NOV": partial(QuarterEnd, month=11), - "Q-DEC": partial(QuarterEnd, month=12), + **_generate_anchored_offsets("AS", YearBegin), + **_generate_anchored_offsets("A", YearEnd), + **_generate_anchored_offsets("YS", YearBegin), + **_generate_anchored_offsets("Y", YearEnd), + **_generate_anchored_offsets("YE", YearEnd), + **_generate_anchored_offsets("QS", QuarterBegin), + **_generate_anchored_offsets("Q", QuarterEnd), + **_generate_anchored_offsets("QE", QuarterEnd), } _FREQUENCY_CONDITION = "|".join(_FREQUENCIES.keys()) -_PATTERN = rf"^((?P\d+)|())(?P({_FREQUENCY_CONDITION}))$" +_PATTERN = rf"^((?P[+-]?\d+)|())(?P({_FREQUENCY_CONDITION}))$" # pandas defines these offsets as "Tick" objects, which for instance have @@ -729,7 +709,49 @@ def __apply__(self, other): CFTIME_TICKS = (Day, Hour, Minute, Second) -def to_offset(freq): +def _generate_anchored_deprecated_frequencies(deprecated, recommended): + pairs = {} + for abbreviation in _MONTH_ABBREVIATIONS.values(): + anchored_deprecated = f"{deprecated}-{abbreviation}" + anchored_recommended = f"{recommended}-{abbreviation}" + pairs[anchored_deprecated] = anchored_recommended + return pairs + + +_DEPRECATED_FREQUENICES = { + "A": "YE", + "Y": "YE", + "AS": "YS", + "Q": "QE", + "M": "ME", + "H": "h", + "T": "min", + "S": "s", + "L": "ms", + "U": "us", + **_generate_anchored_deprecated_frequencies("A", "YE"), + **_generate_anchored_deprecated_frequencies("Y", "YE"), + **_generate_anchored_deprecated_frequencies("AS", "YS"), + **_generate_anchored_deprecated_frequencies("Q", "QE"), +} + + +_DEPRECATION_MESSAGE = ( + "{deprecated_freq!r} is deprecated and will be removed in a future " + "version. Please use {recommended_freq!r} instead of " + "{deprecated_freq!r}." +) + + +def _emit_freq_deprecation_warning(deprecated_freq): + recommended_freq = _DEPRECATED_FREQUENICES[deprecated_freq] + message = _DEPRECATION_MESSAGE.format( + deprecated_freq=deprecated_freq, recommended_freq=recommended_freq + ) + emit_user_level_warning(message, FutureWarning) + + +def to_offset(freq, warn=True): """Convert a frequency string to the appropriate subclass of BaseCFTimeOffset.""" if isinstance(freq, BaseCFTimeOffset): @@ -741,6 +763,8 @@ def to_offset(freq): raise ValueError("Invalid frequency string provided") freq = freq_data["freq"] + if warn and freq in _DEPRECATED_FREQUENICES: + _emit_freq_deprecation_warning(freq) multiples = freq_data["multiple"] multiples = 1 if multiples is None else int(multiples) return _FREQUENCIES[freq](n=multiples) @@ -766,7 +790,7 @@ def to_cftime_datetime(date_str_or_date, calendar=None): raise TypeError( "date_str_or_date must be a string or a " "subclass of cftime.datetime. Instead got " - "{!r}.".format(date_str_or_date) + f"{date_str_or_date!r}." ) @@ -802,7 +826,8 @@ def _generate_range(start, end, periods, offset): """Generate a regular range of cftime.datetime objects with a given time offset. - Adapted from pandas.tseries.offsets.generate_range. + Adapted from pandas.tseries.offsets.generate_range (now at + pandas.core.arrays.datetimes._generate_range). Parameters ---------- @@ -822,10 +847,7 @@ def _generate_range(start, end, periods, offset): if start: start = offset.rollforward(start) - if end: - end = offset.rollback(end) - - if periods is None and end < start: + if periods is None and end < start and offset.n >= 0: end = None periods = 0 @@ -892,7 +914,7 @@ def cftime_range( start=None, end=None, periods=None, - freq="D", + freq=None, normalize=False, name=None, closed: NoDefault | SideOptions = no_default, @@ -910,7 +932,7 @@ def cftime_range( periods : int, optional Number of periods to generate. freq : str or None, default: "D" - Frequency strings can have multiples, e.g. "5H". + Frequency strings can have multiples, e.g. "5h" and negative values, e.g. "-1D". normalize : bool, default: False Normalize start/end dates to midnight before generating date range. name : str, default: None @@ -960,84 +982,84 @@ def cftime_range( +--------+--------------------------+ | Alias | Description | +========+==========================+ - | A, Y | Year-end frequency | + | YE | Year-end frequency | +--------+--------------------------+ - | AS, YS | Year-start frequency | + | YS | Year-start frequency | +--------+--------------------------+ - | Q | Quarter-end frequency | + | QE | Quarter-end frequency | +--------+--------------------------+ | QS | Quarter-start frequency | +--------+--------------------------+ - | M | Month-end frequency | + | ME | Month-end frequency | +--------+--------------------------+ | MS | Month-start frequency | +--------+--------------------------+ | D | Day frequency | +--------+--------------------------+ - | H | Hour frequency | + | h | Hour frequency | +--------+--------------------------+ - | T, min | Minute frequency | + | min | Minute frequency | +--------+--------------------------+ - | S | Second frequency | + | s | Second frequency | +--------+--------------------------+ - | L, ms | Millisecond frequency | + | ms | Millisecond frequency | +--------+--------------------------+ - | U, us | Microsecond frequency | + | us | Microsecond frequency | +--------+--------------------------+ Any multiples of the following anchored offsets are also supported. - +----------+--------------------------------------------------------------------+ - | Alias | Description | - +==========+====================================================================+ - | A(S)-JAN | Annual frequency, anchored at the end (or beginning) of January | - +----------+--------------------------------------------------------------------+ - | A(S)-FEB | Annual frequency, anchored at the end (or beginning) of February | - +----------+--------------------------------------------------------------------+ - | A(S)-MAR | Annual frequency, anchored at the end (or beginning) of March | - +----------+--------------------------------------------------------------------+ - | A(S)-APR | Annual frequency, anchored at the end (or beginning) of April | - +----------+--------------------------------------------------------------------+ - | A(S)-MAY | Annual frequency, anchored at the end (or beginning) of May | - +----------+--------------------------------------------------------------------+ - | A(S)-JUN | Annual frequency, anchored at the end (or beginning) of June | - +----------+--------------------------------------------------------------------+ - | A(S)-JUL | Annual frequency, anchored at the end (or beginning) of July | - +----------+--------------------------------------------------------------------+ - | A(S)-AUG | Annual frequency, anchored at the end (or beginning) of August | - +----------+--------------------------------------------------------------------+ - | A(S)-SEP | Annual frequency, anchored at the end (or beginning) of September | - +----------+--------------------------------------------------------------------+ - | A(S)-OCT | Annual frequency, anchored at the end (or beginning) of October | - +----------+--------------------------------------------------------------------+ - | A(S)-NOV | Annual frequency, anchored at the end (or beginning) of November | - +----------+--------------------------------------------------------------------+ - | A(S)-DEC | Annual frequency, anchored at the end (or beginning) of December | - +----------+--------------------------------------------------------------------+ - | Q(S)-JAN | Quarter frequency, anchored at the end (or beginning) of January | - +----------+--------------------------------------------------------------------+ - | Q(S)-FEB | Quarter frequency, anchored at the end (or beginning) of February | - +----------+--------------------------------------------------------------------+ - | Q(S)-MAR | Quarter frequency, anchored at the end (or beginning) of March | - +----------+--------------------------------------------------------------------+ - | Q(S)-APR | Quarter frequency, anchored at the end (or beginning) of April | - +----------+--------------------------------------------------------------------+ - | Q(S)-MAY | Quarter frequency, anchored at the end (or beginning) of May | - +----------+--------------------------------------------------------------------+ - | Q(S)-JUN | Quarter frequency, anchored at the end (or beginning) of June | - +----------+--------------------------------------------------------------------+ - | Q(S)-JUL | Quarter frequency, anchored at the end (or beginning) of July | - +----------+--------------------------------------------------------------------+ - | Q(S)-AUG | Quarter frequency, anchored at the end (or beginning) of August | - +----------+--------------------------------------------------------------------+ - | Q(S)-SEP | Quarter frequency, anchored at the end (or beginning) of September | - +----------+--------------------------------------------------------------------+ - | Q(S)-OCT | Quarter frequency, anchored at the end (or beginning) of October | - +----------+--------------------------------------------------------------------+ - | Q(S)-NOV | Quarter frequency, anchored at the end (or beginning) of November | - +----------+--------------------------------------------------------------------+ - | Q(S)-DEC | Quarter frequency, anchored at the end (or beginning) of December | - +----------+--------------------------------------------------------------------+ + +------------+--------------------------------------------------------------------+ + | Alias | Description | + +============+====================================================================+ + | Y(E,S)-JAN | Annual frequency, anchored at the (end, beginning) of January | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-FEB | Annual frequency, anchored at the (end, beginning) of February | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-MAR | Annual frequency, anchored at the (end, beginning) of March | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-APR | Annual frequency, anchored at the (end, beginning) of April | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-MAY | Annual frequency, anchored at the (end, beginning) of May | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-JUN | Annual frequency, anchored at the (end, beginning) of June | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-JUL | Annual frequency, anchored at the (end, beginning) of July | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-AUG | Annual frequency, anchored at the (end, beginning) of August | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-SEP | Annual frequency, anchored at the (end, beginning) of September | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-OCT | Annual frequency, anchored at the (end, beginning) of October | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-NOV | Annual frequency, anchored at the (end, beginning) of November | + +------------+--------------------------------------------------------------------+ + | Y(E,S)-DEC | Annual frequency, anchored at the (end, beginning) of December | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-JAN | Quarter frequency, anchored at the (end, beginning) of January | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-FEB | Quarter frequency, anchored at the (end, beginning) of February | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-MAR | Quarter frequency, anchored at the (end, beginning) of March | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-APR | Quarter frequency, anchored at the (end, beginning) of April | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-MAY | Quarter frequency, anchored at the (end, beginning) of May | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-JUN | Quarter frequency, anchored at the (end, beginning) of June | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-JUL | Quarter frequency, anchored at the (end, beginning) of July | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-AUG | Quarter frequency, anchored at the (end, beginning) of August | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-SEP | Quarter frequency, anchored at the (end, beginning) of September | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-OCT | Quarter frequency, anchored at the (end, beginning) of October | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-NOV | Quarter frequency, anchored at the (end, beginning) of November | + +------------+--------------------------------------------------------------------+ + | Q(E,S)-DEC | Quarter frequency, anchored at the (end, beginning) of December | + +------------+--------------------------------------------------------------------+ Finally, the following calendar aliases are supported. @@ -1078,6 +1100,10 @@ def cftime_range( -------- pandas.date_range """ + + if freq is None and any(arg is None for arg in [periods, start, end]): + freq = "D" + # Adapted from pandas.core.indexes.datetimes._generate_range. if count_not_none(start, end, periods, freq) != 3: raise ValueError( @@ -1130,7 +1156,7 @@ def date_range( start=None, end=None, periods=None, - freq="D", + freq=None, tz=None, normalize=False, name=None, @@ -1153,7 +1179,7 @@ def date_range( periods : int, optional Number of periods to generate. freq : str or None, default: "D" - Frequency strings can have multiples, e.g. "5H". + Frequency strings can have multiples, e.g. "5h" and negative values, e.g. "-1D". tz : str or tzinfo, optional Time zone name for returning localized DatetimeIndex, for example 'Asia/Hong_Kong'. By default, the resulting DatetimeIndex is @@ -1207,7 +1233,8 @@ def date_range( start=start, end=end, periods=periods, - freq=freq, + # TODO remove translation once requiring pandas >= 2.2 + freq=_new_to_legacy_freq(freq), tz=tz, normalize=normalize, name=name, @@ -1235,6 +1262,96 @@ def date_range( ) +def _new_to_legacy_freq(freq): + # xarray will now always return "ME" and "QE" for MonthEnd and QuarterEnd + # frequencies, but older versions of pandas do not support these as + # frequency strings. Until xarray's minimum pandas version is 2.2 or above, + # we add logic to continue using the deprecated "M" and "Q" frequency + # strings in these circumstances. + + # NOTE: other conversions ("h" -> "H", ..., "ns" -> "N") not required + + # TODO: remove once requiring pandas >= 2.2 + if not freq or Version(pd.__version__) >= Version("2.2"): + return freq + + try: + freq_as_offset = to_offset(freq) + except ValueError: + # freq may be valid in pandas but not in xarray + return freq + + if isinstance(freq_as_offset, MonthEnd) and "ME" in freq: + freq = freq.replace("ME", "M") + elif isinstance(freq_as_offset, QuarterEnd) and "QE" in freq: + freq = freq.replace("QE", "Q") + elif isinstance(freq_as_offset, YearBegin) and "YS" in freq: + freq = freq.replace("YS", "AS") + elif isinstance(freq_as_offset, YearEnd): + # testing for "Y" is required as this was valid in xarray 2023.11 - 2024.01 + if "Y-" in freq: + # Check for and replace "Y-" instead of just "Y" to prevent + # corrupting anchored offsets that contain "Y" in the month + # abbreviation, e.g. "Y-MAY" -> "A-MAY". + freq = freq.replace("Y-", "A-") + elif "YE-" in freq: + freq = freq.replace("YE-", "A-") + elif "A-" not in freq and freq.endswith("Y"): + freq = freq.replace("Y", "A") + elif freq.endswith("YE"): + freq = freq.replace("YE", "A") + + return freq + + +def _legacy_to_new_freq(freq): + # to avoid internal deprecation warnings when freq is determined using pandas < 2.2 + + # TODO: remove once requiring pandas >= 2.2 + + if not freq or Version(pd.__version__) >= Version("2.2"): + return freq + + try: + freq_as_offset = to_offset(freq, warn=False) + except ValueError: + # freq may be valid in pandas but not in xarray + return freq + + if isinstance(freq_as_offset, MonthEnd) and "ME" not in freq: + freq = freq.replace("M", "ME") + elif isinstance(freq_as_offset, QuarterEnd) and "QE" not in freq: + freq = freq.replace("Q", "QE") + elif isinstance(freq_as_offset, YearBegin) and "YS" not in freq: + freq = freq.replace("AS", "YS") + elif isinstance(freq_as_offset, YearEnd): + if "A-" in freq: + # Check for and replace "A-" instead of just "A" to prevent + # corrupting anchored offsets that contain "Y" in the month + # abbreviation, e.g. "A-MAY" -> "YE-MAY". + freq = freq.replace("A-", "YE-") + elif "Y-" in freq: + freq = freq.replace("Y-", "YE-") + elif freq.endswith("A"): + # the "A-MAY" case is already handled above + freq = freq.replace("A", "YE") + elif "YE" not in freq and freq.endswith("Y"): + # the "Y-MAY" case is already handled above + freq = freq.replace("Y", "YE") + elif isinstance(freq_as_offset, Hour): + freq = freq.replace("H", "h") + elif isinstance(freq_as_offset, Minute): + freq = freq.replace("T", "min") + elif isinstance(freq_as_offset, Second): + freq = freq.replace("S", "s") + elif isinstance(freq_as_offset, Millisecond): + freq = freq.replace("L", "ms") + elif isinstance(freq_as_offset, Microsecond): + freq = freq.replace("U", "us") + + return freq + + def date_range_like(source, calendar, use_cftime=None): """Generate a datetime array with the same frequency, start and end as another one, but in a different calendar. @@ -1279,15 +1396,25 @@ def date_range_like(source, calendar, use_cftime=None): "`date_range_like` was unable to generate a range as the source frequency was not inferable." ) + # TODO remove once requiring pandas >= 2.2 + freq = _legacy_to_new_freq(freq) + use_cftime = _should_cftime_be_used(source, calendar, use_cftime) source_start = source.values.min() source_end = source.values.max() + + freq_as_offset = to_offset(freq) + if freq_as_offset.n < 0: + source_start, source_end = source_end, source_start + if is_np_datetime_like(source.dtype): # We want to use datetime fields (datetime64 object don't have them) source_calendar = "standard" - source_start = pd.Timestamp(source_start) - source_end = pd.Timestamp(source_end) + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + source_start = nanosecond_precision_timestamp(source_start) + source_end = nanosecond_precision_timestamp(source_end) else: if isinstance(source, CFTimeIndex): source_calendar = source.calendar @@ -1303,7 +1430,7 @@ def date_range_like(source, calendar, use_cftime=None): # For the cases where the source ends on the end of the month, we expect the same in the new calendar. if source_end.day == source_end.daysinmonth and isinstance( - to_offset(freq), (YearEnd, QuarterEnd, MonthEnd, Day) + freq_as_offset, (YearEnd, QuarterEnd, MonthEnd, Day) ): end = end.replace(day=end.daysinmonth) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 7227ba9edb6..6898809e3b0 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -1,4 +1,5 @@ """DatetimeIndex analog for cftime.datetime objects""" + # The pandas.Index subclass defined here was copied and adapted for # use with cftime.datetime objects based on the source code defining # pandas.DatetimeIndex. @@ -187,7 +188,7 @@ def _parsed_string_to_bounds(date_type, resolution, parsed): def get_date_field(datetimes, field): """Adapted from pandas.tslib.get_date_field""" - return np.array([getattr(date, field) for date in datetimes]) + return np.array([getattr(date, field) for date in datetimes], dtype=np.int64) def _field_accessor(name, docstring=None, min_cftime_version="0.0"): @@ -228,12 +229,12 @@ def assert_all_valid_date_type(data): if not isinstance(sample, cftime.datetime): raise TypeError( "CFTimeIndex requires cftime.datetime " - "objects. Got object of {}.".format(date_type) + f"objects. Got object of {date_type}." ) if not all(isinstance(value, date_type) for value in data): raise TypeError( "CFTimeIndex requires using datetime " - "objects of all the same type. Got\n{}.".format(data) + f"objects of all the same type. Got\n{data}." ) @@ -272,8 +273,8 @@ def format_attrs(index, separator=", "): attrs = { "dtype": f"'{index.dtype}'", "length": f"{len(index)}", - "calendar": f"'{index.calendar}'", - "freq": f"'{index.freq}'" if len(index) >= 3 else None, + "calendar": f"{index.calendar!r}", + "freq": f"{index.freq!r}", } attrs_str = [f"{k}={v}" for k, v in attrs.items()] @@ -382,30 +383,30 @@ def _partial_date_slice(self, resolution, parsed): ... dims=["time"], ... ) >>> da.sel(time="2001-01-01") - + Size: 8B array([1]) Coordinates: - * time (time) object 2001-01-01 00:00:00 + * time (time) object 8B 2001-01-01 00:00:00 >>> da = xr.DataArray( ... [1, 2], ... coords=[[pd.Timestamp(2001, 1, 1), pd.Timestamp(2001, 2, 1)]], ... dims=["time"], ... ) >>> da.sel(time="2001-01-01") - + Size: 8B array(1) Coordinates: - time datetime64[ns] 2001-01-01 + time datetime64[ns] 8B 2001-01-01 >>> da = xr.DataArray( ... [1, 2], ... coords=[[pd.Timestamp(2001, 1, 1, 1), pd.Timestamp(2001, 2, 1)]], ... dims=["time"], ... ) >>> da.sel(time="2001-01-01") - + Size: 8B array([1]) Coordinates: - * time (time) datetime64[ns] 2001-01-01T01:00:00 + * time (time) datetime64[ns] 8B 2001-01-01T01:00:00 """ start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) @@ -470,13 +471,9 @@ def get_loc(self, key): else: return super().get_loc(key) - def _maybe_cast_slice_bound(self, label, side, kind=None): + def _maybe_cast_slice_bound(self, label, side): """Adapted from pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound - - Note that we have never used the kind argument in CFTimeIndex and it is - deprecated as of pandas version 1.3.0. It exists only for compatibility - reasons. We can remove it when our minimum version of pandas is 1.3.0. """ if not isinstance(label, str): return label @@ -538,11 +535,11 @@ def shift(self, n: int | float, freq: str | timedelta): Examples -------- - >>> index = xr.cftime_range("2000", periods=1, freq="M") + >>> index = xr.cftime_range("2000", periods=1, freq="ME") >>> index CFTimeIndex([2000-01-31 00:00:00], dtype='object', length=1, calendar='standard', freq=None) - >>> index.shift(1, "M") + >>> index.shift(1, "ME") CFTimeIndex([2000-02-29 00:00:00], dtype='object', length=1, calendar='standard', freq=None) >>> index.shift(1.5, "D") @@ -557,8 +554,7 @@ def shift(self, n: int | float, freq: str | timedelta): return self + n * to_offset(freq) else: raise TypeError( - "'freq' must be of type " - "str or datetime.timedelta, got {}.".format(freq) + f"'freq' must be of type str or datetime.timedelta, got {freq}." ) def __add__(self, other): @@ -613,7 +609,7 @@ def to_datetimeindex(self, unsafe=False): ------ ValueError If the CFTimeIndex contains dates that are not possible in the - standard calendar or outside the pandas.Timestamp-valid range. + standard calendar or outside the nanosecond-precision range. Warns ----- @@ -635,15 +631,19 @@ def to_datetimeindex(self, unsafe=False): >>> times.to_datetimeindex() DatetimeIndex(['2000-01-01', '2000-01-02'], dtype='datetime64[ns]', freq=None) """ + + if not self._data.size: + return pd.DatetimeIndex([]) + nptimes = cftime_to_nptime(self) calendar = infer_calendar_name(self) if calendar not in _STANDARD_CALENDARS and not unsafe: warnings.warn( "Converting a CFTimeIndex with dates from a non-standard " - "calendar, {!r}, to a pandas.DatetimeIndex, which uses dates " + f"calendar, {calendar!r}, to a pandas.DatetimeIndex, which uses dates " "from the standard calendar. This may lead to subtle errors " "in operations that depend on the length of time between " - "dates.".format(calendar), + "dates.", RuntimeWarning, stacklevel=2, ) @@ -684,6 +684,9 @@ def asi8(self): """Convert to integers with units of microseconds since 1970-01-01.""" from xarray.core.resample_cftime import exact_cftime_datetime_difference + if not self._data.size: + return np.array([], dtype=np.int64) + epoch = self.date_type(1970, 1, 1) return np.array( [ @@ -698,6 +701,9 @@ def calendar(self): """The calendar used by the datetimes in the index.""" from xarray.coding.times import infer_calendar_name + if not self._data.size: + return None + return infer_calendar_name(self) @property @@ -705,12 +711,19 @@ def freq(self): """The frequency used by the dates in the index.""" from xarray.coding.frequencies import infer_freq + # min 3 elemtents required to determine freq + if self._data.size < 3: + return None + return infer_freq(self) def _round_via_method(self, freq, method): """Round dates using a specified method.""" from xarray.coding.cftime_offsets import CFTIME_TICKS, to_offset + if not self._data.size: + return CFTimeIndex(np.array(self)) + offset = to_offset(freq) if not isinstance(offset, CFTIME_TICKS): raise ValueError(f"{offset} is a non-fixed frequency") diff --git a/xarray/coding/frequencies.py b/xarray/coding/frequencies.py index 4d24327aa2f..b912b9a1fca 100644 --- a/xarray/coding/frequencies.py +++ b/xarray/coding/frequencies.py @@ -1,4 +1,5 @@ """FrequencyInferer analog for cftime.datetime objects""" + # The infer_freq method and the _CFTimeFrequencyInferer # subclass defined here were copied and adapted for # use with cftime.datetime objects based on the source code in @@ -44,7 +45,7 @@ import numpy as np import pandas as pd -from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS +from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS, _legacy_to_new_freq from xarray.coding.cftimeindex import CFTimeIndex from xarray.core.common import _contains_datetime_like_objects @@ -98,7 +99,7 @@ def infer_freq(index): inferer = _CFTimeFrequencyInferer(index) return inferer.get_freq() - return pd.infer_freq(index) + return _legacy_to_new_freq(pd.infer_freq(index)) class _CFTimeFrequencyInferer: # (pd.tseries.frequencies._FrequencyInferer): @@ -138,15 +139,15 @@ def get_freq(self): return None if _is_multiple(delta, _ONE_HOUR): - return _maybe_add_count("H", delta / _ONE_HOUR) + return _maybe_add_count("h", delta / _ONE_HOUR) elif _is_multiple(delta, _ONE_MINUTE): - return _maybe_add_count("T", delta / _ONE_MINUTE) + return _maybe_add_count("min", delta / _ONE_MINUTE) elif _is_multiple(delta, _ONE_SECOND): - return _maybe_add_count("S", delta / _ONE_SECOND) + return _maybe_add_count("s", delta / _ONE_SECOND) elif _is_multiple(delta, _ONE_MILLI): - return _maybe_add_count("L", delta / _ONE_MILLI) + return _maybe_add_count("ms", delta / _ONE_MILLI) else: - return _maybe_add_count("U", delta / _ONE_MICRO) + return _maybe_add_count("us", delta / _ONE_MICRO) def _infer_daily_rule(self): annual_rule = self._get_annual_rule() @@ -183,7 +184,7 @@ def _get_annual_rule(self): if len(np.unique(self.index.month)) > 1: return None - return {"cs": "AS", "ce": "A"}.get(month_anchor_check(self.index)) + return {"cs": "YS", "ce": "YE"}.get(month_anchor_check(self.index)) def _get_quartely_rule(self): if len(self.month_deltas) > 1: @@ -192,13 +193,13 @@ def _get_quartely_rule(self): if self.month_deltas[0] % 3 != 0: return None - return {"cs": "QS", "ce": "Q"}.get(month_anchor_check(self.index)) + return {"cs": "QS", "ce": "QE"}.get(month_anchor_check(self.index)) def _get_monthly_rule(self): if len(self.month_deltas) > 1: return None - return {"cs": "MS", "ce": "M"}.get(month_anchor_check(self.index)) + return {"cs": "MS", "ce": "ME"}.get(month_anchor_check(self.index)) @property def deltas(self): diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 61b3ab7c46c..b3b9d8d1041 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -1,4 +1,5 @@ """Coders for strings.""" + from __future__ import annotations from functools import partial @@ -14,8 +15,9 @@ unpack_for_encoding, ) from xarray.core import indexing -from xarray.core.pycompat import is_duck_dask_array from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array def create_vlen_dtype(element_type): @@ -29,7 +31,8 @@ def check_vlen_dtype(dtype): if dtype.kind != "O" or dtype.metadata is None: return None else: - return dtype.metadata.get("element_type") + # check xarray (element_type) as well as h5py (vlen) + return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) def is_unicode_dtype(dtype): @@ -46,21 +49,20 @@ class EncodedStringCoder(VariableCoder): def __init__(self, allows_unicode=True): self.allows_unicode = allows_unicode - def encode(self, variable, name=None): + def encode(self, variable: Variable, name=None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) contains_unicode = is_unicode_dtype(data.dtype) encode_as_char = encoding.get("dtype") == "S1" - if encode_as_char: del encoding["dtype"] # no longer relevant if contains_unicode and (encode_as_char or not self.allows_unicode): if "_FillValue" in attrs: raise NotImplementedError( - "variable {!r} has a _FillValue specified, but " + f"variable {name!r} has a _FillValue specified, but " "_FillValue is not yet supported on unicode strings: " - "https://github.com/pydata/xarray/issues/1647".format(name) + "https://github.com/pydata/xarray/issues/1647" ) string_encoding = encoding.pop("_Encoding", "utf-8") @@ -68,9 +70,12 @@ def encode(self, variable, name=None): # TODO: figure out how to handle this in a lazy way with dask data = encode_string_array(data, string_encoding) - return Variable(dims, data, attrs, encoding) + return Variable(dims, data, attrs, encoding) + else: + variable.encoding = encoding + return variable - def decode(self, variable, name=None): + def decode(self, variable: Variable, name=None) -> Variable: dims, data, attrs, encoding = unpack_for_decoding(variable) if "_Encoding" in attrs: @@ -94,13 +99,15 @@ def encode_string_array(string_array, encoding="utf-8"): return np.array(encoded, dtype=bytes).reshape(string_array.shape) -def ensure_fixed_length_bytes(var): +def ensure_fixed_length_bytes(var: Variable) -> Variable: """Ensure that a variable with vlen bytes is converted to fixed width.""" - dims, data, attrs, encoding = unpack_for_encoding(var) - if check_vlen_dtype(data.dtype) == bytes: + if check_vlen_dtype(var.dtype) == bytes: + dims, data, attrs, encoding = unpack_for_encoding(var) # TODO: figure out how to handle this with dask - data = np.asarray(data, dtype=np.string_) - return Variable(dims, data, attrs, encoding) + data = np.asarray(data, dtype=np.bytes_) + return Variable(dims, data, attrs, encoding) + else: + return var class CharacterArrayCoder(VariableCoder): @@ -134,10 +141,10 @@ def bytes_to_char(arr): if arr.dtype.kind != "S": raise ValueError("argument must have a fixed-width bytes dtype") - if is_duck_dask_array(arr): - import dask.array as da + if is_chunked_array(arr): + chunkmanager = get_chunked_array_type(arr) - return da.map_blocks( + return chunkmanager.map_blocks( _numpy_bytes_to_char, arr, dtype="S1", @@ -150,7 +157,7 @@ def bytes_to_char(arr): def _numpy_bytes_to_char(arr): """Like netCDF4.stringtochar, but faster and more flexible.""" # ensure the array is contiguous - arr = np.array(arr, copy=False, order="C", dtype=np.string_) + arr = np.array(arr, copy=False, order="C", dtype=np.bytes_) return arr.reshape(arr.shape + (1,)).view("S1") @@ -167,19 +174,19 @@ def char_to_bytes(arr): if not size: # can't make an S0 dtype - return np.zeros(arr.shape[:-1], dtype=np.string_) + return np.zeros(arr.shape[:-1], dtype=np.bytes_) - if is_duck_dask_array(arr): - import dask.array as da + if is_chunked_array(arr): + chunkmanager = get_chunked_array_type(arr) if len(arr.chunks[-1]) > 1: raise ValueError( "cannot stacked dask character array with " - "multiple chunks in the last dimension: {}".format(arr) + f"multiple chunks in the last dimension: {arr}" ) dtype = np.dtype("S" + str(arr.shape[-1])) - return da.map_blocks( + return chunkmanager.map_blocks( _numpy_char_to_bytes, arr, dtype=dtype, @@ -231,6 +238,12 @@ def shape(self) -> tuple[int, ...]: def __repr__(self): return f"{type(self).__name__}({self.array!r})" + def _vindex_get(self, key): + return _numpy_char_to_bytes(self.array.vindex[key]) + + def _oindex_get(self, key): + return _numpy_char_to_bytes(self.array.oindex[key]) + def __getitem__(self, key): # require slicing the last dimension completely key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index f9e79863d46..92bce0abeaa 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,9 +22,14 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like +from xarray.core.duck_array_ops import asarray from xarray.core.formatting import first_n_items, format_timestamp, last_item -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.pdcompat import nanosecond_precision_timestamp +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.utils import is_duck_dask_array try: import cftime @@ -32,7 +37,7 @@ cftime = None if TYPE_CHECKING: - from xarray.core.types import CFCalendar + from xarray.core.types import CFCalendar, T_DuckArray T_Name = Union[Hashable, None] @@ -121,6 +126,18 @@ def _netcdf_to_numpy_timeunit(units: str) -> str: }[units] +def _numpy_to_netcdf_timeunit(units: str) -> str: + return { + "ns": "nanoseconds", + "us": "microseconds", + "ms": "milliseconds", + "s": "seconds", + "m": "minutes", + "h": "hours", + "D": "days", + }[units] + + def _ensure_padded_year(ref_date: str) -> str: # Reference dates without a padded year (e.g. since 1-1-1 or since 2-3-4) # are ambiguous (is it YMD or DMY?). This can lead to some very odd @@ -170,6 +187,20 @@ def _unpack_netcdf_time_units(units: str) -> tuple[str, str]: return delta_units, ref_date +def _unpack_time_units_and_ref_date(units: str) -> tuple[str, pd.Timestamp]: + # same us _unpack_netcdf_time_units but finalizes ref_date for + # processing in encode_cf_datetime + time_units, _ref_date = _unpack_netcdf_time_units(units) + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + ref_date = nanosecond_precision_timestamp(_ref_date) + # If the ref_date Timestamp is timezone-aware, convert to UTC and + # make it timezone-naive (GH 2649). + if ref_date.tz is not None: + ref_date = ref_date.tz_convert(None) + return time_units, ref_date + + def _decode_cf_datetime_dtype( data, units: str, calendar: str, use_cftime: bool | None ) -> np.dtype: @@ -217,14 +248,16 @@ def _decode_datetime_with_pandas( ) -> np.ndarray: if not _is_standard_calendar(calendar): raise OutOfBoundsDatetime( - "Cannot decode times from a non-standard calendar, {!r}, using " - "pandas.".format(calendar) + f"Cannot decode times from a non-standard calendar, {calendar!r}, using " + "pandas." ) - delta, ref_date = _unpack_netcdf_time_units(units) - delta = _netcdf_to_numpy_timeunit(delta) + time_units, ref_date = _unpack_netcdf_time_units(units) + time_units = _netcdf_to_numpy_timeunit(time_units) try: - ref_date = pd.Timestamp(ref_date) + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + ref_date = nanosecond_precision_timestamp(ref_date) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp # strings, in which case we fall back to using cftime @@ -234,8 +267,8 @@ def _decode_datetime_with_pandas( warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning) if flat_num_dates.size > 0: # avoid size 0 datetimes GH1329 - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + pd.to_timedelta(flat_num_dates.min(), time_units) + ref_date + pd.to_timedelta(flat_num_dates.max(), time_units) + ref_date # To avoid integer overflow when converting to nanosecond units for integer # dtypes smaller than np.int64 cast all integer and unsigned integer dtype @@ -248,9 +281,12 @@ def _decode_datetime_with_pandas( # Cast input ordinals to integers of nanoseconds because pd.to_timedelta # works much faster when dealing with integers (GH 1399). - flat_num_dates_ns_int = (flat_num_dates * _NS_PER_TIME_DELTA[delta]).astype( - np.int64 - ) + # properly handle NaN/NaT to prevent casting NaN to int + nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min) + flat_num_dates = flat_num_dates * _NS_PER_TIME_DELTA[time_units] + flat_num_dates_ns_int = np.zeros_like(flat_num_dates, dtype=np.int64) + flat_num_dates_ns_int[nan] = np.iinfo(np.int64).min + flat_num_dates_ns_int[~nan] = flat_num_dates[~nan].astype(np.int64) # Use pd.to_timedelta to safely cast integer values to timedeltas, # and add those to a Timestamp to safely produce a DatetimeIndex. This @@ -361,6 +397,10 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str: return "seconds" +def _time_units_to_timedelta64(units: str) -> np.timedelta64: + return np.timedelta64(1, _netcdf_to_numpy_timeunit(units)).astype("timedelta64[ns]") + + def infer_calendar_name(dates) -> CFCalendar: """Given an array of datetimes, infer the CF calendar name""" if is_np_datetime_like(dates.dtype): @@ -391,7 +431,9 @@ def infer_datetime_units(dates) -> str: dates = to_datetime_unboxed(dates) dates = dates[pd.notnull(dates)] reference_date = dates[0] if len(dates) > 0 else "1970-01-01" - reference_date = pd.Timestamp(reference_date) + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + reference_date = nanosecond_precision_timestamp(reference_date) else: reference_date = dates[0] if len(dates) > 0 else "1970-01-01" reference_date = format_cftime_datetime(reference_date) @@ -432,6 +474,8 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: If raise_on_invalid is True (default), invalid dates trigger a ValueError. Otherwise, the invalid element is replaced by np.NaT.""" times = np.asarray(times) + # TODO: the strict enforcement of nanosecond precision datetime values can + # be relaxed when addressing GitHub issue #7493. new = np.empty(times.shape, dtype="M8[ns]") for i, t in np.ndenumerate(times): try: @@ -439,14 +483,14 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: # NumPy casts it safely it np.datetime64[ns] for dates outside # 1678 to 2262 (this is not currently the case for # datetime.datetime). - dt = pd.Timestamp( + dt = nanosecond_precision_timestamp( t.year, t.month, t.day, t.hour, t.minute, t.second, t.microsecond ) except ValueError as e: if raise_on_invalid: raise ValueError( - "Cannot convert date {} to a date in the " - "standard calendar. Reason: {}.".format(t, e) + f"Cannot convert date {t} to a date in the " + f"standard calendar. Reason: {e}." ) else: dt = "NaT" @@ -460,7 +504,7 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray Useful to convert between calendars in numpy and cftime or between cftime calendars. If raise_on_valid is True (default), invalid dates trigger a ValueError. - Otherwise, the invalid element is replaced by np.NaN for cftime types and np.NaT for np.datetime64. + Otherwise, the invalid element is replaced by np.nan for cftime types and np.NaT for np.datetime64. """ if date_type in (pd.Timestamp, np.datetime64) and not is_np_datetime_like( times.dtype @@ -478,13 +522,11 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray except ValueError as e: if raise_on_invalid: raise ValueError( - "Cannot convert date {} to a date in the " - "{} calendar. Reason: {}.".format( - t, date_type(2000, 1, 1).calendar, e - ) + f"Cannot convert date {t} to a date in the " + f"{date_type(2000, 1, 1).calendar} calendar. Reason: {e}." ) else: - dt = np.NaN + dt = np.nan new[i] = dt return new @@ -498,6 +540,10 @@ def convert_time_or_go_back(date, date_type): This is meant to convert end-of-month dates into a new calendar. """ + # TODO: the strict enforcement of nanosecond precision Timestamps can be + # relaxed when addressing GitHub issue #7493. + if date_type == pd.Timestamp: + date_type = nanosecond_precision_timestamp try: return date_type( date.year, @@ -563,9 +609,12 @@ def _should_cftime_be_used( def _cleanup_netcdf_time_units(units: str) -> str: - delta, ref_date = _unpack_netcdf_time_units(units) + time_units, ref_date = _unpack_netcdf_time_units(units) + time_units = time_units.lower() + if not time_units.endswith("s"): + time_units = f"{time_units}s" try: - units = f"{delta} since {format_timestamp(ref_date)}" + units = f"{time_units} since {format_timestamp(ref_date)}" except (OutOfBoundsDatetime, ValueError): # don't worry about reifying the units if they're out of bounds or # formatted badly @@ -610,9 +659,59 @@ def cast_to_int_if_safe(num) -> np.ndarray: return num +def _division(deltas, delta, floor): + if floor: + # calculate int64 floor division + # to preserve integer dtype if possible (GH 4045, GH7817). + num = deltas // delta.astype(np.int64) + num = num.astype(np.int64, copy=False) + else: + num = deltas / delta + return num + + +def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="overflow") + cast_num = np.asarray(num, dtype=dtype) + + if np.issubdtype(dtype, np.integer): + if not (num == cast_num).all(): + if np.issubdtype(num.dtype, np.floating): + raise ValueError( + f"Not possible to cast all encoded times from " + f"{num.dtype!r} to {dtype!r} without losing precision. " + f"Consider modifying the units such that integer values " + f"can be used, or removing the units and dtype encoding, " + f"at which point xarray will make an appropriate choice." + ) + else: + raise OverflowError( + f"Not possible to cast encoded times from " + f"{num.dtype!r} to {dtype!r} without overflow. Consider " + f"removing the dtype encoding, at which point xarray will " + f"make an appropriate choice, or explicitly switching to " + "a larger integer dtype." + ) + else: + if np.isinf(cast_num).any(): + raise OverflowError( + f"Not possible to cast encoded times from {num.dtype!r} to " + f"{dtype!r} without overflow. Consider removing the dtype " + f"encoding, at which point xarray will make an appropriate " + f"choice, or explicitly switching to a larger floating point " + f"dtype." + ) + + return cast_num + + def encode_cf_datetime( - dates, units: str | None = None, calendar: str | None = None -) -> tuple[np.ndarray, str, str]: + dates: T_DuckArray, # type: ignore + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[T_DuckArray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -622,31 +721,40 @@ def encode_cf_datetime( -------- cftime.date2num """ - dates = np.asarray(dates) + dates = asarray(dates) + if is_chunked_array(dates): + return _lazily_encode_cf_datetime(dates, units, calendar, dtype) + else: + return _eagerly_encode_cf_datetime(dates, units, calendar, dtype) + + +def _eagerly_encode_cf_datetime( + dates: T_DuckArray, # type: ignore + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, + allow_units_modification: bool = True, +) -> tuple[T_DuckArray, str, str]: + dates = asarray(dates) + + data_units = infer_datetime_units(dates) if units is None: - units = infer_datetime_units(dates) + units = data_units else: units = _cleanup_netcdf_time_units(units) if calendar is None: calendar = infer_calendar_name(dates) - delta, _ref_date = _unpack_netcdf_time_units(units) try: if not _is_standard_calendar(calendar) or dates.dtype.kind == "O": # parse with cftime instead raise OutOfBoundsDatetime assert dates.dtype == "datetime64[ns]" - delta_units = _netcdf_to_numpy_timeunit(delta) - time_delta = np.timedelta64(1, delta_units).astype("timedelta64[ns]") - ref_date = pd.Timestamp(_ref_date) - - # If the ref_date Timestamp is timezone-aware, convert to UTC and - # make it timezone-naive (GH 2649). - if ref_date.tz is not None: - ref_date = ref_date.tz_convert(None) + time_units, ref_date = _unpack_time_units_and_ref_date(units) + time_delta = _time_units_to_timedelta64(time_units) # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from @@ -654,30 +762,205 @@ def encode_cf_datetime( dates_as_index = pd.DatetimeIndex(dates.ravel()) time_deltas = dates_as_index - ref_date - # Use floor division if time_delta evenly divides all differences - # to preserve integer dtype if possible (GH 4045). - if np.all(time_deltas % time_delta == np.timedelta64(0, "ns")): - num = time_deltas // time_delta - else: - num = time_deltas / time_delta + # retrieve needed units to faithfully encode to int64 + needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units) + if data_units != units: + # this accounts for differences in the reference times + ref_delta = abs(data_ref_date - ref_date).to_timedelta64() + data_delta = _time_units_to_timedelta64(needed_units) + if (ref_delta % data_delta) > np.timedelta64(0, "ns"): + needed_units = _infer_time_units_from_diff(ref_delta) + + # needed time delta to encode faithfully to int64 + needed_time_delta = _time_units_to_timedelta64(needed_units) + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer) and allow_units_modification: + new_units = f"{needed_units} since {format_timestamp(ref_date)}" + emit_user_level_warning( + f"Times can't be serialized faithfully to int64 with requested units {units!r}. " + f"Serializing with units {new_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {new_units!r} to silence this warning ." + ) + units = new_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError, ValueError): num = _encode_datetime_with_cftime(dates, units, calendar) + # do it now only for cftime-based flow + # we already covered for this in pandas-based flow + num = cast_to_int_if_safe(num) - num = cast_to_int_if_safe(num) - return (num, units, calendar) + if dtype is not None: + num = _cast_to_dtype_if_safe(num, dtype) + return num, units, calendar + + +def _encode_cf_datetime_within_map_blocks( + dates: T_DuckArray, # type: ignore + units: str, + calendar: str, + dtype: np.dtype, +) -> T_DuckArray: + num, *_ = _eagerly_encode_cf_datetime( + dates, units, calendar, dtype, allow_units_modification=False + ) + return num + + +def _lazily_encode_cf_datetime( + dates: T_ChunkedArray, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[T_ChunkedArray, str, str]: + if calendar is None: + # This will only trigger minor compute if dates is an object dtype array. + calendar = infer_calendar_name(dates) + + if units is None and dtype is None: + if dates.dtype == "O": + units = "microseconds since 1970-01-01" + dtype = np.dtype("int64") + else: + units = "nanoseconds since 1970-01-01" + dtype = np.dtype("int64") + + if units is None or dtype is None: + raise ValueError( + f"When encoding chunked arrays of datetime values, both the units " + f"and dtype must be prescribed or both must be unprescribed. " + f"Prescribing only one or the other is not currently supported. " + f"Got a units encoding of {units} and a dtype encoding of {dtype}." + ) + + chunkmanager = get_chunked_array_type(dates) + num = chunkmanager.map_blocks( + _encode_cf_datetime_within_map_blocks, + dates, + units, + calendar, + dtype, + dtype=dtype, + ) + return num, units, calendar + + +def encode_cf_timedelta( + timedeltas: T_DuckArray, # type: ignore + units: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[T_DuckArray, str]: + timedeltas = asarray(timedeltas) + if is_chunked_array(timedeltas): + return _lazily_encode_cf_timedelta(timedeltas, units, dtype) + else: + return _eagerly_encode_cf_timedelta(timedeltas, units, dtype) + + +def _eagerly_encode_cf_timedelta( + timedeltas: T_DuckArray, # type: ignore + units: str | None = None, + dtype: np.dtype | None = None, + allow_units_modification: bool = True, +) -> tuple[T_DuckArray, str]: + data_units = infer_timedelta_units(timedeltas) -def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]: if units is None: - units = infer_timedelta_units(timedeltas) + units = data_units + + time_delta = _time_units_to_timedelta64(units) + time_deltas = pd.TimedeltaIndex(timedeltas.ravel()) + + # retrieve needed units to faithfully encode to int64 + needed_units = data_units + if data_units != units: + needed_units = _infer_time_units_from_diff(np.unique(time_deltas.dropna())) + + # needed time delta to encode faithfully to int64 + needed_time_delta = _time_units_to_timedelta64(needed_units) + + floor_division = True + if time_delta > needed_time_delta: + floor_division = False + if dtype is None: + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. " + f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. " + f"Set encoding['dtype'] to integer dtype to serialize to int64. " + f"Set encoding['dtype'] to floating point dtype to silence this warning." + ) + elif np.issubdtype(dtype, np.integer) and allow_units_modification: + emit_user_level_warning( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead. " + f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. " + f"Set encoding['units'] to {needed_units!r} to silence this warning ." + ) + units = needed_units + time_delta = needed_time_delta + floor_division = True + + num = _division(time_deltas, time_delta, floor_division) + num = num.values.reshape(timedeltas.shape) + + if dtype is not None: + num = _cast_to_dtype_if_safe(num, dtype) - np_unit = _netcdf_to_numpy_timeunit(units) - num = 1.0 * timedeltas / np.timedelta64(1, np_unit) - num = np.where(pd.isnull(timedeltas), np.nan, num) - num = cast_to_int_if_safe(num) - return (num, units) + return num, units + + +def _encode_cf_timedelta_within_map_blocks( + timedeltas: T_DuckArray, # type:ignore + units: str, + dtype: np.dtype, +) -> T_DuckArray: + num, _ = _eagerly_encode_cf_timedelta( + timedeltas, units, dtype, allow_units_modification=False + ) + return num + + +def _lazily_encode_cf_timedelta( + timedeltas: T_ChunkedArray, units: str | None = None, dtype: np.dtype | None = None +) -> tuple[T_ChunkedArray, str]: + if units is None and dtype is None: + units = "nanoseconds" + dtype = np.dtype("int64") + + if units is None or dtype is None: + raise ValueError( + f"When encoding chunked arrays of timedelta values, both the " + f"units and dtype must be prescribed or both must be " + f"unprescribed. Prescribing only one or the other is not " + f"currently supported. Got a units encoding of {units} and a " + f"dtype encoding of {dtype}." + ) + + chunkmanager = get_chunked_array_type(timedeltas) + num = chunkmanager.map_blocks( + _encode_cf_timedelta_within_map_blocks, + timedeltas, + units, + dtype, + dtype=dtype, + ) + return num, units class CFDatetimeCoder(VariableCoder): @@ -690,9 +973,11 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: ) or contains_cftime_datetimes(variable): dims, data, attrs, encoding = unpack_for_encoding(variable) - (data, units, calendar) = encode_cf_datetime( - data, encoding.pop("units", None), encoding.pop("calendar", None) - ) + units = encoding.pop("units", None) + calendar = encoding.pop("calendar", None) + dtype = encoding.get("dtype", None) + (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype) + safe_setitem(attrs, "units", units, name=name) safe_setitem(attrs, "calendar", calendar, name=name) @@ -726,7 +1011,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - data, units = encode_cf_timedelta(data, encoding.pop("units", None)) + data, units = encode_cf_timedelta( + data, encoding.pop("units", None), encoding.get("dtype", None) + ) safe_setitem(attrs, "units", units, name=name) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index c290307b4b6..52cf0fc3656 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -1,4 +1,5 @@ """Coders for individual Variable objects.""" + from __future__ import annotations import warnings @@ -10,8 +11,9 @@ import pandas as pd from xarray.core import dtypes, duck_array_ops, indexing -from xarray.core.pycompat import is_duck_dask_array from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array if TYPE_CHECKING: T_VarTuple = tuple[tuple[Hashable, ...], Any, dict, dict] @@ -57,7 +59,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin): """ def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike): - assert not is_duck_dask_array(array) + assert not is_chunked_array(array) self.array = indexing.as_indexable(array) self.func = func self._dtype = dtype @@ -66,11 +68,17 @@ def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike): def dtype(self) -> np.dtype: return np.dtype(self._dtype) + def _oindex_get(self, key): + return type(self)(self.array.oindex[key], self.func, self.dtype) + + def _vindex_get(self, key): + return type(self)(self.array.vindex[key], self.func, self.dtype) + def __getitem__(self, key): return type(self)(self.array[key], self.func, self.dtype) - def __array__(self, dtype=None): - return self.func(self.array) + def get_duck_array(self): + return self.func(self.array.get_duck_array()) def __repr__(self) -> str: return "{}({!r}, func={!r}, dtype={!r})".format( @@ -78,6 +86,83 @@ def __repr__(self) -> str: ) +class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): + """Decode arrays on the fly from non-native to native endianness + + This is useful for decoding arrays from netCDF3 files (which are all + big endian) into native endianness, so they can be used with Cython + functions, such as those found in bottleneck and pandas. + + >>> x = np.arange(5, dtype=">i2") + + >>> x.dtype + dtype('>i2') + + >>> NativeEndiannessArray(x).dtype + dtype('int16') + + >>> indexer = indexing.BasicIndexer((slice(None),)) + >>> NativeEndiannessArray(x)[indexer].dtype + dtype('int16') + """ + + __slots__ = ("array",) + + def __init__(self, array) -> None: + self.array = indexing.as_indexable(array) + + @property + def dtype(self) -> np.dtype: + return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize)) + + def _oindex_get(self, key): + return np.asarray(self.array.oindex[key], dtype=self.dtype) + + def _vindex_get(self, key): + return np.asarray(self.array.vindex[key], dtype=self.dtype) + + def __getitem__(self, key) -> np.ndarray: + return np.asarray(self.array[key], dtype=self.dtype) + + +class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin): + """Decode arrays on the fly from integer to boolean datatype + + This is useful for decoding boolean arrays from integer typed netCDF + variables. + + >>> x = np.array([1, 0, 1, 1, 0], dtype="i1") + + >>> x.dtype + dtype('int8') + + >>> BoolTypeArray(x).dtype + dtype('bool') + + >>> indexer = indexing.BasicIndexer((slice(None),)) + >>> BoolTypeArray(x)[indexer].dtype + dtype('bool') + """ + + __slots__ = ("array",) + + def __init__(self, array) -> None: + self.array = indexing.as_indexable(array) + + @property + def dtype(self) -> np.dtype: + return np.dtype("bool") + + def _oindex_get(self, key): + return np.asarray(self.array.oindex[key], dtype=self.dtype) + + def _vindex_get(self, key): + return np.asarray(self.array.vindex[key], dtype=self.dtype) + + def __getitem__(self, key) -> np.ndarray: + return np.asarray(self.array[key], dtype=self.dtype) + + def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike): """Lazily apply an element-wise function to an array. Parameters @@ -93,10 +178,10 @@ def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike): ------- Either a dask.array.Array or _ElementwiseFunctionArray. """ - if is_duck_dask_array(array): - import dask.array as da + if is_chunked_array(array): + chunkmanager = get_chunked_array_type(array) - return da.map_blocks(func, array, dtype=dtype) + return chunkmanager.map_blocks(func, array, dtype=dtype) # type: ignore[arg-type] else: return _ElementwiseFunctionArray(array, func, dtype) @@ -113,10 +198,10 @@ def safe_setitem(dest, key: Hashable, value, name: T_Name = None): if key in dest: var_str = f" on variable {name!r}" if name else "" raise ValueError( - "failed to prevent overwriting existing key {} in attrs{}. " + f"failed to prevent overwriting existing key {key} in attrs{var_str}. " "This is probably an encoding field used by xarray to describe " "how a variable is serialized. To proceed, remove this key from " - "the variable's attributes manually.".format(key, var_str) + "the variable's attributes manually." ) dest[key] = value @@ -149,6 +234,72 @@ def _apply_mask( return np.where(condition, decoded_fill_value, data) +def _is_time_like(units): + # test for time-like + if units is None: + return False + time_strings = [ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ] + units = str(units) + # to prevent detecting units like `days accumulated` as time-like + # special casing for datetime-units and timedelta-units (GH-8269) + if "since" in units: + from xarray.coding.times import _unpack_netcdf_time_units + + try: + _unpack_netcdf_time_units(units) + except ValueError: + return False + return True + else: + return any(tstr == units for tstr in time_strings) + + +def _check_fill_values(attrs, name, dtype): + """ "Check _FillValue and missing_value if available. + + Return dictionary with raw fill values and set with encoded fill values. + + Issue SerializationWarning if appropriate. + """ + raw_fill_dict = {} + [ + pop_to(attrs, raw_fill_dict, attr, name=name) + for attr in ("missing_value", "_FillValue") + ] + encoded_fill_values = set() + for k in list(raw_fill_dict): + v = raw_fill_dict[k] + kfill = {fv for fv in np.ravel(v) if not pd.isnull(fv)} + if not kfill and np.issubdtype(dtype, np.integer): + warnings.warn( + f"variable {name!r} has non-conforming {k!r} " + f"{v!r} defined, dropping {k!r} entirely.", + SerializationWarning, + stacklevel=3, + ) + del raw_fill_dict[k] + else: + encoded_fill_values |= kfill + + if len(encoded_fill_values) > 1: + warnings.warn( + f"variable {name!r} has multiple fill values " + f"{encoded_fill_values} defined, decoding all values to NaN.", + SerializationWarning, + stacklevel=3, + ) + + return raw_fill_dict, encoded_fill_values + + class CFMaskCoder(VariableCoder): """Mask or unmask fill values according to CF conventions.""" @@ -158,58 +309,77 @@ def encode(self, variable: Variable, name: T_Name = None): dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") + # to properly handle _FillValue/missing_value below [a], [b] + # we need to check if unsigned data is written as signed data + unsigned = encoding.get("_Unsigned") is not None - if ( - fv is not None - and mv is not None - and not duck_array_ops.allclose_or_equiv(fv, mv) - ): + fv_exists = fv is not None + mv_exists = mv is not None + + if not fv_exists and not mv_exists: + return variable + + if fv_exists and mv_exists and not duck_array_ops.allclose_or_equiv(fv, mv): raise ValueError( f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data." ) - if fv is not None: + if fv_exists: # Ensure _FillValue is cast to same dtype as data's - encoding["_FillValue"] = dtype.type(fv) + # [a] need to skip this if _Unsigned is available + if not unsigned: + encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) - if not pd.isnull(fill_value): - data = duck_array_ops.fillna(data, fill_value) - if mv is not None: - # Ensure missing_value is cast to same dtype as data's - encoding["missing_value"] = dtype.type(mv) + if mv_exists: + # try to use _FillValue, if it exists to align both values + # or use missing_value and ensure it's cast to same dtype as data's + # [b] need to provide mv verbatim if _Unsigned is available + encoding["missing_value"] = attrs.get( + "_FillValue", + (dtype.type(mv) if not unsigned else mv), + ) fill_value = pop_to(encoding, attrs, "missing_value", name=name) - if not pd.isnull(fill_value) and fv is None: + + # apply fillna + if not pd.isnull(fill_value): + # special case DateTime to properly handle NaT + if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu": + data = duck_array_ops.where( + data != np.iinfo(np.int64).min, data, fill_value + ) + else: data = duck_array_ops.fillna(data, fill_value) return Variable(dims, data, attrs, encoding, fastpath=True) def decode(self, variable: Variable, name: T_Name = None): - dims, data, attrs, encoding = unpack_for_decoding(variable) - - raw_fill_values = [ - pop_to(attrs, encoding, attr, name=name) - for attr in ("missing_value", "_FillValue") - ] - if raw_fill_values: - encoded_fill_values = { - fv - for option in raw_fill_values - for fv in np.ravel(option) - if not pd.isnull(fv) - } - - if len(encoded_fill_values) > 1: - warnings.warn( - "variable {!r} has multiple fill values {}, " - "decoding all values to NaN.".format(name, encoded_fill_values), - SerializationWarning, - stacklevel=3, - ) + raw_fill_dict, encoded_fill_values = _check_fill_values( + variable.attrs, name, variable.dtype + ) - dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype) + if raw_fill_dict: + dims, data, attrs, encoding = unpack_for_decoding(variable) + [ + safe_setitem(encoding, attr, value, name=name) + for attr, value in raw_fill_dict.items() + ] if encoded_fill_values: + # special case DateTime to properly handle NaT + dtype: np.typing.DTypeLike + decoded_fill_value: Any + if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu": + dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min + else: + if "scale_factor" not in attrs and "add_offset" not in attrs: + dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype) + else: + dtype, decoded_fill_value = ( + _choose_float_dtype(data.dtype, attrs), + np.nan, + ) + transform = partial( _apply_mask, encoded_fill_values=encoded_fill_values, @@ -224,7 +394,7 @@ def decode(self, variable: Variable, name: T_Name = None): def _scale_offset_decoding(data, scale_factor, add_offset, dtype: np.typing.DTypeLike): - data = np.array(data, dtype=dtype, copy=True) + data = data.astype(dtype=dtype, copy=True) if scale_factor is not None: data *= scale_factor if add_offset is not None: @@ -232,20 +402,51 @@ def _scale_offset_decoding(data, scale_factor, add_offset, dtype: np.typing.DTyp return data -def _choose_float_dtype(dtype: np.dtype, has_offset: bool) -> type[np.floating[Any]]: +def _choose_float_dtype( + dtype: np.dtype, mapping: MutableMapping +) -> type[np.floating[Any]]: """Return a float dtype that can losslessly represent `dtype` values.""" - # Keep float32 as-is. Upcast half-precision to single-precision, + # check scale/offset first to derive wanted float dtype + # see https://github.com/pydata/xarray/issues/5597#issuecomment-879561954 + scale_factor = mapping.get("scale_factor") + add_offset = mapping.get("add_offset") + if scale_factor is not None or add_offset is not None: + # get the type from scale_factor/add_offset to determine + # the needed floating point type + if scale_factor is not None: + scale_type = np.dtype(type(scale_factor)) + if add_offset is not None: + offset_type = np.dtype(type(add_offset)) + # CF conforming, both scale_factor and add-offset are given and + # of same floating point type (float32/64) + if ( + add_offset is not None + and scale_factor is not None + and offset_type == scale_type + and scale_type in [np.float32, np.float64] + ): + # in case of int32 -> we need upcast to float64 + # due to precision issues + if dtype.itemsize == 4 and np.issubdtype(dtype, np.integer): + return np.float64 + return scale_type.type + # Not CF conforming and add_offset given: + # A scale factor is entirely safe (vanishing into the mantissa), + # but a large integer offset could lead to loss of precision. + # Sensitivity analysis can be tricky, so we just use a float64 + # if there's any offset at all - better unoptimised than wrong! + if add_offset is not None: + return np.float64 + # return dtype depending on given scale_factor + return scale_type.type + # If no scale_factor or add_offset is given, use some general rules. + # Keep float32 as-is. Upcast half-precision to single-precision, # because float16 is "intended for storage but not computation" if dtype.itemsize <= 4 and np.issubdtype(dtype, np.floating): return np.float32 # float32 can exactly represent all integers up to 24 bits if dtype.itemsize <= 2 and np.issubdtype(dtype, np.integer): - # A scale factor is entirely safe (vanishing into the mantissa), - # but a large integer offset could lead to loss of precision. - # Sensitivity analysis can be tricky, so we just use a float64 - # if there's any offset at all - better unoptimised than wrong! - if not has_offset: - return np.float32 + return np.float32 # For all other types and circumstances, we just use float64. # (safe because eg. complex numbers are not supported in NetCDF) return np.float64 @@ -262,8 +463,13 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) if "scale_factor" in encoding or "add_offset" in encoding: - dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) - data = data.astype(dtype=dtype, copy=True) + # if we have a _FillValue/masked_value we do not want to cast now + # but leave that to CFMaskCoder + dtype = data.dtype + if "_FillValue" not in encoding and "missing_value" not in encoding: + dtype = _choose_float_dtype(data.dtype, encoding) + # but still we need a copy prevent changing original data + data = duck_array_ops.astype(data, dtype=dtype, copy=True) if "add_offset" in encoding: data -= pop_to(encoding, attrs, "add_offset", name=name) if "scale_factor" in encoding: @@ -278,11 +484,17 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: scale_factor = pop_to(attrs, encoding, "scale_factor", name=name) add_offset = pop_to(attrs, encoding, "add_offset", name=name) - dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) if np.ndim(scale_factor) > 0: scale_factor = np.asarray(scale_factor).item() if np.ndim(add_offset) > 0: add_offset = np.asarray(add_offset).item() + # if we have a _FillValue/masked_value we already have the wanted + # floating point dtype here (via CFMaskCoder), so no check is necessary + # only check in other cases + dtype = data.dtype + if "_FillValue" not in encoding and "missing_value" not in encoding: + dtype = _choose_float_dtype(dtype, encoding) + transform = partial( _scale_offset_decoding, scale_factor=scale_factor, @@ -299,7 +511,7 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: class UnsignedIntegerCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None) -> Variable: # from netCDF best practices - # https://www.unidata.ucar.edu/software/netcdf/docs/BestPractices.html + # https://docs.unidata.ucar.edu/nug/current/best_practices.html#bp_Unsigned-Data # "_Unsigned = "true" to indicate that # integer data should be treated as unsigned" if variable.encoding.get("_Unsigned", "false") == "true": @@ -310,7 +522,7 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: if "_FillValue" in attrs: new_fill = signed_dtype.type(attrs["_FillValue"]) attrs["_FillValue"] = new_fill - data = duck_array_ops.around(data).astype(signed_dtype) + data = duck_array_ops.astype(duck_array_ops.around(data), signed_dtype) return Variable(dims, data, attrs, encoding, fastpath=True) else: @@ -319,7 +531,6 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable: def decode(self, variable: Variable, name: T_Name = None) -> Variable: if "_Unsigned" in variable.attrs: dims, data, attrs, encoding = unpack_for_decoding(variable) - unsigned = pop_to(attrs, encoding, "_Unsigned") if data.dtype.kind == "i": @@ -349,3 +560,132 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: return Variable(dims, data, attrs, encoding, fastpath=True) else: return variable + + +class DefaultFillvalueCoder(VariableCoder): + """Encode default _FillValue if needed.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + dims, data, attrs, encoding = unpack_for_encoding(variable) + # make NaN the fill value for float types + if ( + "_FillValue" not in attrs + and "_FillValue" not in encoding + and np.issubdtype(variable.dtype, np.floating) + ): + attrs["_FillValue"] = variable.dtype.type(np.nan) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + raise NotImplementedError() + + +class BooleanCoder(VariableCoder): + """Code boolean values.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + if ( + (variable.dtype == bool) + and ("dtype" not in variable.encoding) + and ("dtype" not in variable.attrs) + ): + dims, data, attrs, encoding = unpack_for_encoding(variable) + attrs["dtype"] = "bool" + data = duck_array_ops.astype(data, dtype="i1", copy=True) + + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + if variable.attrs.get("dtype", False) == "bool": + dims, data, attrs, encoding = unpack_for_decoding(variable) + # overwrite (!) dtype in encoding, and remove from attrs + # needed for correct subsequent encoding + encoding["dtype"] = attrs.pop("dtype") + data = BoolTypeArray(data) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + +class EndianCoder(VariableCoder): + """Decode Endianness to native.""" + + def encode(self): + raise NotImplementedError() + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + dims, data, attrs, encoding = unpack_for_decoding(variable) + if not data.dtype.isnative: + data = NativeEndiannessArray(data) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + +class NonStringCoder(VariableCoder): + """Encode NonString variables if dtypes differ.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + if "dtype" in variable.encoding and variable.encoding["dtype"] not in ( + "S1", + str, + ): + dims, data, attrs, encoding = unpack_for_encoding(variable) + dtype = np.dtype(encoding.pop("dtype")) + if dtype != variable.dtype: + if np.issubdtype(dtype, np.integer): + if ( + np.issubdtype(variable.dtype, np.floating) + and "_FillValue" not in variable.attrs + and "missing_value" not in variable.attrs + ): + warnings.warn( + f"saving variable {name} with floating " + "point data as an integer dtype without " + "any _FillValue to use for NaNs", + SerializationWarning, + stacklevel=10, + ) + data = np.around(data) + data = data.astype(dtype=dtype) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self): + raise NotImplementedError() + + +class ObjectVLenStringCoder(VariableCoder): + def encode(self): + raise NotImplementedError + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + if variable.dtype == object and variable.encoding.get("dtype", False) == str: + variable = variable.astype(variable.encoding["dtype"]) + return variable + else: + return variable + + +class NativeEnumCoder(VariableCoder): + """Encode Enum into variable dtype metadata.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + if ( + "dtype" in variable.encoding + and np.dtype(variable.encoding["dtype"]).metadata + and "enum" in variable.encoding["dtype"].metadata + ): + dims, data, attrs, encoding = unpack_for_encoding(variable) + data = data.astype(dtype=variable.encoding.pop("dtype")) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + raise NotImplementedError() diff --git a/xarray/conventions.py b/xarray/conventions.py index 780172879c6..6eff45c5b2d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,22 +1,22 @@ from __future__ import annotations -import warnings from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union import numpy as np import pandas as pd from xarray.coding import strings, times, variables from xarray.coding.variables import SerializationWarning, pop_to -from xarray.core import duck_array_ops, indexing +from xarray.core import indexing from xarray.core.common import ( _contains_datetime_like_objects, contains_cftime_datetimes, ) -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import IndexVariable, Variable +from xarray.namedarray.utils import is_duck_dask_array CF_RELATED_DATA = ( "bounds", @@ -48,133 +48,23 @@ T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore] -class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): - """Decode arrays on the fly from non-native to native endianness - - This is useful for decoding arrays from netCDF3 files (which are all - big endian) into native endianness, so they can be used with Cython - functions, such as those found in bottleneck and pandas. - - >>> x = np.arange(5, dtype=">i2") - - >>> x.dtype - dtype('>i2') - - >>> NativeEndiannessArray(x).dtype - dtype('int16') - - >>> indexer = indexing.BasicIndexer((slice(None),)) - >>> NativeEndiannessArray(x)[indexer].dtype - dtype('int16') - """ - - __slots__ = ("array",) - - def __init__(self, array): - self.array = indexing.as_indexable(array) - - @property - def dtype(self): - return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize)) - - def __getitem__(self, key): - return np.asarray(self.array[key], dtype=self.dtype) - - -class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin): - """Decode arrays on the fly from integer to boolean datatype - - This is useful for decoding boolean arrays from integer typed netCDF - variables. - - >>> x = np.array([1, 0, 1, 1, 0], dtype="i1") - - >>> x.dtype - dtype('int8') - - >>> BoolTypeArray(x).dtype - dtype('bool') - - >>> indexer = indexing.BasicIndexer((slice(None),)) - >>> BoolTypeArray(x)[indexer].dtype - dtype('bool') - """ - - __slots__ = ("array",) - - def __init__(self, array): - self.array = indexing.as_indexable(array) - - @property - def dtype(self): - return np.dtype("bool") - - def __getitem__(self, key): - return np.asarray(self.array[key], dtype=self.dtype) - - -def _var_as_tuple(var: Variable) -> T_VarTuple: - return var.dims, var.data, var.attrs.copy(), var.encoding.copy() - - -def maybe_encode_nonstring_dtype(var: Variable, name: T_Name = None) -> Variable: - if "dtype" in var.encoding and var.encoding["dtype"] not in ("S1", str): - dims, data, attrs, encoding = _var_as_tuple(var) - dtype = np.dtype(encoding.pop("dtype")) - if dtype != var.dtype: - if np.issubdtype(dtype, np.integer): - if ( - np.issubdtype(var.dtype, np.floating) - and "_FillValue" not in var.attrs - and "missing_value" not in var.attrs - ): - warnings.warn( - f"saving variable {name} with floating " - "point data as an integer dtype without " - "any _FillValue to use for NaNs", - SerializationWarning, - stacklevel=10, - ) - data = np.around(data) - data = data.astype(dtype=dtype) - var = Variable(dims, data, attrs, encoding, fastpath=True) - return var - - -def maybe_default_fill_value(var: Variable) -> Variable: - # make NaN the fill value for float types: - if ( - "_FillValue" not in var.attrs - and "_FillValue" not in var.encoding - and np.issubdtype(var.dtype, np.floating) - ): - var.attrs["_FillValue"] = var.dtype.type(np.nan) - return var - - -def maybe_encode_bools(var: Variable) -> Variable: - if ( - (var.dtype == bool) - and ("dtype" not in var.encoding) - and ("dtype" not in var.attrs) - ): - dims, data, attrs, encoding = _var_as_tuple(var) - attrs["dtype"] = "bool" - data = duck_array_ops.astype(data, dtype="i1", copy=True) - var = Variable(dims, data, attrs, encoding, fastpath=True) - return var - - -def _infer_dtype(array, name: T_Name = None) -> np.dtype: - """Given an object array with no missing values, infer its dtype from its - first element - """ +def _infer_dtype(array, name=None): + """Given an object array with no missing values, infer its dtype from all elements.""" if array.dtype.kind != "O": raise TypeError("infer_type must be called on a dtype=object array") if array.size == 0: return np.dtype(float) + native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) + if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: + raise ValueError( + "unable to infer dtype on variable {!r}; object array " + "contains mixed native types: {}".format( + name, ", ".join(x.__name__ for x in native_dtypes) + ) + ) + element = array[(0,) * array.ndim] # We use the base types to avoid subclasses of bytes and str (which might # not play nice with e.g. hdf5 datatypes), such as those from numpy @@ -188,21 +78,23 @@ def _infer_dtype(array, name: T_Name = None) -> np.dtype: return dtype raise ValueError( - "unable to infer dtype on variable {!r}; xarray " - "cannot serialize arbitrary Python objects".format(name) + f"unable to infer dtype on variable {name!r}; xarray " + "cannot serialize arbitrary Python objects" ) def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None: - if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex): - raise NotImplementedError( - "variable {!r} is a MultiIndex, which cannot yet be " - "serialized to netCDF files. Instead, either use reset_index() " - "to convert MultiIndex levels into coordinate variables instead " - "or use https://cf-xarray.readthedocs.io/en/latest/coding.html.".format( - name + # only the pandas multi-index dimension coordinate cannot be serialized (tuple values) + if isinstance(var._data, indexing.PandasMultiIndexingAdapter): + if name is None and isinstance(var, IndexVariable): + name = var.name + if var.dims == (name,): + raise NotImplementedError( + f"variable {name!r} is a MultiIndex, which cannot yet be " + "serialized. Instead, either use reset_index() " + "to convert MultiIndex levels into coordinate variables instead " + "or use https://cf-xarray.readthedocs.io/en/latest/coding.html." ) - ) def _copy_with_dtype(data, dtype: np.typing.DTypeLike): @@ -219,16 +111,20 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike): def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: # TODO: move this from conventions to backends? (it's not CF related) if var.dtype.kind == "O": - dims, data, attrs, encoding = _var_as_tuple(var) + dims, data, attrs, encoding = variables.unpack_for_encoding(var) + + # leave vlen dtypes unchanged + if strings.check_vlen_dtype(data.dtype) is not None: + return var if is_duck_dask_array(data): - warnings.warn( - "variable {} has data in the form of a dask array with " + emit_user_level_warning( + f"variable {name} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " "to determine a data type that can be safely stored on disk. " "To avoid this, coerce this variable to a fixed-size dtype " - "with astype() before saving it.".format(name), - SerializationWarning, + "with astype() before saving it.", + category=SerializationWarning, ) data = data.compute() @@ -266,7 +162,7 @@ def encode_cf_variable( var: Variable, needs_copy: bool = True, name: T_Name = None ) -> Variable: """ - Converts an Variable into an Variable which follows some + Converts a Variable into a Variable which follows some of the CF conventions: - Nans are masked using _FillValue (or the deprecated missing_value) @@ -292,13 +188,14 @@ def encode_cf_variable( variables.CFScaleOffsetCoder(), variables.CFMaskCoder(), variables.UnsignedIntegerCoder(), + variables.NativeEnumCoder(), + variables.NonStringCoder(), + variables.DefaultFillvalueCoder(), + variables.BooleanCoder(), ]: var = coder.encode(var, name=name) - # TODO(shoyer): convert all of these to use coders, too: - var = maybe_encode_nonstring_dtype(var, name=name) - var = maybe_default_fill_value(var) - var = maybe_encode_bools(var) + # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends: var = ensure_dtype_not_object(var, name=name) for attr_name in CF_RELATED_DATA: @@ -376,6 +273,10 @@ def decode_cf_variable( var = strings.CharacterArrayCoder().decode(var, name=name) var = strings.EncodedStringCoder().decode(var) + if original_dtype == object: + var = variables.ObjectVLenStringCoder().decode(var) + original_dtype = var.dtype + if mask_and_scale: for coder in [ variables.UnsignedIntegerCoder(), @@ -389,19 +290,15 @@ def decode_cf_variable( if decode_times: var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name) - dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) - # TODO(shoyer): convert everything below to use coders + if decode_endianness and not var.dtype.isnative: + var = variables.EndianCoder().decode(var) + original_dtype = var.dtype - if decode_endianness and not data.dtype.isnative: - # do this last, so it's only done if we didn't already unmask/scale - data = NativeEndiannessArray(data) - original_dtype = data.dtype + var = variables.BooleanCoder().decode(var) - encoding.setdefault("dtype", original_dtype) + dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) - if "dtype" in attributes and attributes["dtype"] == "bool": - del attributes["dtype"] - data = BoolTypeArray(data) + encoding.setdefault("dtype", original_dtype) if not is_duck_dask_array(data): data = indexing.LazilyIndexedArray(data) @@ -469,15 +366,14 @@ def _update_bounds_encoding(variables: T_Variables) -> None: and "bounds" in attrs and attrs["bounds"] in variables ): - warnings.warn( - "Variable '{0}' has datetime type and a " - "bounds variable but {0}.encoding does not have " - "units specified. The units encodings for '{0}' " - "and '{1}' will be determined independently " + emit_user_level_warning( + f"Variable {name:s} has datetime type and a " + f"bounds variable but {name:s}.encoding does not have " + f"units specified. The units encodings for {name:s} " + f"and {attrs['bounds']} will be determined independently " "and may not be equal, counter to CF-conventions. " "If this is a concern, specify a units encoding for " - "'{0}' before writing to a file.".format(name, attrs["bounds"]), - UserWarning, + f"{name:s} before writing to a file.", ) if has_date_units and "bounds" in attrs: @@ -494,7 +390,7 @@ def decode_cf_variables( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -552,15 +448,18 @@ def stackable(dim: Hashable) -> bool: decode_timedelta=decode_timedelta, ) except Exception as e: - raise type(e)(f"Failed to decode variable {k!r}: {e}") + raise type(e)(f"Failed to decode variable {k!r}: {e}") from e if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: - coord_str = var_attrs["coordinates"] - var_coord_names = coord_str.split() - if all(k in variables for k in var_coord_names): - new_vars[k].encoding["coordinates"] = coord_str - del var_attrs["coordinates"] + var_coord_names = [ + c for c in var_attrs["coordinates"].split() if c in variables + ] + # propagate as is + new_vars[k].encoding["coordinates"] = var_attrs["coordinates"] + del var_attrs["coordinates"] + # but only use as coordinate if existing + if var_coord_names: coord_names.update(var_coord_names) if decode_coords == "all": @@ -576,8 +475,8 @@ def stackable(dim: Hashable) -> bool: for role_or_name in part.split() ] if len(roles_and_names) % 2 == 1: - warnings.warn( - f"Attribute {attr_name:s} malformed", stacklevel=5 + emit_user_level_warning( + f"Attribute {attr_name:s} malformed" ) var_names = roles_and_names[1::2] if all(var_name in variables for var_name in var_names): @@ -589,9 +488,8 @@ def stackable(dim: Hashable) -> bool: for proj_name in var_names if proj_name not in variables ] - warnings.warn( + emit_user_level_warning( f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}", - stacklevel=5, ) del var_attrs[attr_name] @@ -608,7 +506,7 @@ def decode_cf( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -736,23 +634,28 @@ def cf_decoder( decode_cf_variable """ variables, attributes, _ = decode_cf_variables( - variables, attributes, concat_characters, mask_and_scale, decode_times + variables, + attributes, + concat_characters, + mask_and_scale, + decode_times, ) return variables, attributes -def _encode_coordinates(variables, attributes, non_dim_coord_names): +def _encode_coordinates( + variables: T_Variables, attributes: T_Attrs, non_dim_coord_names +): # calculate global and variable specific coordinates non_dim_coord_names = set(non_dim_coord_names) for name in list(non_dim_coord_names): if isinstance(name, str) and " " in name: - warnings.warn( - "coordinate {!r} has a space in its name, which means it " + emit_user_level_warning( + f"coordinate {name!r} has a space in its name, which means it " "cannot be marked as a coordinate on disk and will be " - "saved as a data variable instead".format(name), - SerializationWarning, - stacklevel=6, + "saved as a data variable instead", + category=SerializationWarning, ) non_dim_coord_names.discard(name) @@ -770,7 +673,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): variable_coordinates[k].add(coord_name) if any( - attr_name in v.encoding and coord_name in v.encoding.get(attr_name) + coord_name in v.encoding.get(attr_name, tuple()) for attr_name in CF_RELATED_DATA ): not_technically_coordinates.add(coord_name) @@ -807,7 +710,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if not coords_str and variable_coordinates[name]: coordinates_text = " ".join( str(coord_name) - for coord_name in variable_coordinates[name] + for coord_name in sorted(variable_coordinates[name]) if coord_name not in not_technically_coordinates ) if coordinates_text: @@ -825,19 +728,19 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if global_coordinates: attributes = dict(attributes) if "coordinates" in attributes: - warnings.warn( + emit_user_level_warning( f"cannot serialize global coordinates {global_coordinates!r} because the global " f"attribute 'coordinates' already exists. This may prevent faithful roundtripping" f"of xarray datasets", - SerializationWarning, + category=SerializationWarning, ) else: - attributes["coordinates"] = " ".join(map(str, global_coordinates)) + attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates))) return variables, attributes -def encode_dataset_coordinates(dataset): +def encode_dataset_coordinates(dataset: Dataset): """Encode coordinates on the given dataset object into variable specific and global attributes. @@ -859,7 +762,7 @@ def encode_dataset_coordinates(dataset): ) -def cf_encoder(variables, attributes): +def cf_encoder(variables: T_Variables, attributes: T_Attrs): """ Encode a set of CF encoded variables and attributes. Takes a dicts of variables and attributes and encodes them diff --git a/xarray/convert.py b/xarray/convert.py index 5863352ae41..b8d81ccf9f0 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -1,18 +1,17 @@ """Functions for converting to and from xarray objects """ + from collections import Counter import numpy as np -import pandas as pd from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder from xarray.conventions import decode_cf from xarray.core import duck_array_ops from xarray.core.dataarray import DataArray from xarray.core.dtypes import get_fill_value -from xarray.core.pycompat import array_type +from xarray.namedarray.pycompat import array_type -cdms2_ignored_attrs = {"name", "tileIndex"} iris_forbidden_keys = { "standard_name", "long_name", @@ -60,92 +59,6 @@ def _filter_attrs(attrs, ignored_attrs): return {k: v for k, v in attrs.items() if k not in ignored_attrs} -def from_cdms2(variable): - """Convert a cdms2 variable into an DataArray""" - values = np.asarray(variable) - name = variable.id - dims = variable.getAxisIds() - coords = {} - for axis in variable.getAxisList(): - coords[axis.id] = DataArray( - np.asarray(axis), - dims=[axis.id], - attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs), - ) - grid = variable.getGrid() - if grid is not None: - ids = [a.id for a in grid.getAxisList()] - for axis in grid.getLongitude(), grid.getLatitude(): - if axis.id not in variable.getAxisIds(): - coords[axis.id] = DataArray( - np.asarray(axis[:]), - dims=ids, - attrs=_filter_attrs(axis.attributes, cdms2_ignored_attrs), - ) - attrs = _filter_attrs(variable.attributes, cdms2_ignored_attrs) - dataarray = DataArray(values, dims=dims, coords=coords, name=name, attrs=attrs) - return decode_cf(dataarray.to_dataset())[dataarray.name] - - -def to_cdms2(dataarray, copy=True): - """Convert a DataArray into a cdms2 variable""" - # we don't want cdms2 to be a hard dependency - import cdms2 - - def set_cdms2_attrs(var, attrs): - for k, v in attrs.items(): - setattr(var, k, v) - - # 1D axes - axes = [] - for dim in dataarray.dims: - coord = encode(dataarray.coords[dim]) - axis = cdms2.createAxis(coord.values, id=dim) - set_cdms2_attrs(axis, coord.attrs) - axes.append(axis) - - # Data - var = encode(dataarray) - cdms2_var = cdms2.createVariable( - var.values, axes=axes, id=dataarray.name, mask=pd.isnull(var.values), copy=copy - ) - - # Attributes - set_cdms2_attrs(cdms2_var, var.attrs) - - # Curvilinear and unstructured grids - if dataarray.name not in dataarray.coords: - cdms2_axes = {} - for coord_name in set(dataarray.coords.keys()) - set(dataarray.dims): - coord_array = dataarray.coords[coord_name].to_cdms2() - - cdms2_axis_cls = ( - cdms2.coord.TransientAxis2D - if coord_array.ndim - else cdms2.auxcoord.TransientAuxAxis1D - ) - cdms2_axis = cdms2_axis_cls(coord_array) - if cdms2_axis.isLongitude(): - cdms2_axes["lon"] = cdms2_axis - elif cdms2_axis.isLatitude(): - cdms2_axes["lat"] = cdms2_axis - - if "lon" in cdms2_axes and "lat" in cdms2_axes: - if len(cdms2_axes["lon"].shape) == 2: - cdms2_grid = cdms2.hgrid.TransientCurveGrid( - cdms2_axes["lat"], cdms2_axes["lon"] - ) - else: - cdms2_grid = cdms2.gengrid.AbstractGenericGrid( - cdms2_axes["lat"], cdms2_axes["lon"] - ) - for axis in cdms2_grid.getAxisList(): - cdms2_var.setAxis(cdms2_var.getAxisIds().index(axis.id), axis) - cdms2_var.setGrid(cdms2_grid) - - return cdms2_var - - def _pick_attrs(attrs, keys): """Return attrs with keys in keys list""" return {k: v for k, v in attrs.items() if k in keys} diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 3051502beba..bee6afd5a19 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -1,4 +1,5 @@ """Mixin classes with reduction operations.""" + # This file was generated using xarray.util.generate_aggregations. Do not edit manually. from __future__ import annotations @@ -8,8 +9,8 @@ from xarray.core import duck_array_ops from xarray.core.options import OPTIONS -from xarray.core.types import Dims -from xarray.core.utils import contains_only_dask_or_numpy, module_available +from xarray.core.types import Dims, Self +from xarray.core.utils import contains_only_chunked_or_numpy, module_available if TYPE_CHECKING: from xarray.core.dataarray import DataArray @@ -30,7 +31,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> Dataset: + ) -> Self: raise NotImplementedError() def count( @@ -39,7 +40,7 @@ def count( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``count`` along some dimension(s). @@ -65,8 +66,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count DataArray.count :ref:`agg` User guide on reduction or aggregation operations. @@ -74,28 +75,28 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.count() - + Size: 8B Dimensions: () Data variables: - da int64 5 + da int64 8B 5 """ return self.reduce( duck_array_ops.count, @@ -111,7 +112,7 @@ def all( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``all`` along some dimension(s). @@ -149,25 +150,25 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.all() - + Size: 1B Dimensions: () Data variables: - da bool False + da bool 1B False """ return self.reduce( duck_array_ops.array_all, @@ -183,7 +184,7 @@ def any( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``any`` along some dimension(s). @@ -221,25 +222,25 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.any() - + Size: 1B Dimensions: () Data variables: - da bool True + da bool 1B True """ return self.reduce( duck_array_ops.array_any, @@ -256,7 +257,7 @@ def max( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``max`` along some dimension(s). @@ -296,36 +297,36 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.max() - + Size: 8B Dimensions: () Data variables: - da float64 3.0 + da float64 8B 3.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.max(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan """ return self.reduce( duck_array_ops.max, @@ -343,7 +344,7 @@ def min( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``min`` along some dimension(s). @@ -383,36 +384,36 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.min() - + Size: 8B Dimensions: () Data variables: - da float64 1.0 + da float64 8B 0.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.min(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan """ return self.reduce( duck_array_ops.min, @@ -430,7 +431,7 @@ def mean( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``mean`` along some dimension(s). @@ -474,36 +475,36 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.mean() - + Size: 8B Dimensions: () Data variables: - da float64 1.8 + da float64 8B 1.6 Use ``skipna`` to control whether NaNs are ignored. >>> ds.mean(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan """ return self.reduce( duck_array_ops.mean, @@ -522,7 +523,7 @@ def prod( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``prod`` along some dimension(s). @@ -572,44 +573,44 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.prod() - + Size: 8B Dimensions: () Data variables: - da float64 12.0 + da float64 8B 0.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.prod(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan Specify ``min_count`` for finer control over when NaNs are ignored. >>> ds.prod(skipna=True, min_count=2) - + Size: 8B Dimensions: () Data variables: - da float64 12.0 + da float64 8B 0.0 """ return self.reduce( duck_array_ops.prod, @@ -629,7 +630,7 @@ def sum( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``sum`` along some dimension(s). @@ -679,44 +680,44 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.sum() - + Size: 8B Dimensions: () Data variables: - da float64 9.0 + da float64 8B 8.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.sum(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan Specify ``min_count`` for finer control over when NaNs are ignored. >>> ds.sum(skipna=True, min_count=2) - + Size: 8B Dimensions: () Data variables: - da float64 9.0 + da float64 8B 8.0 """ return self.reduce( duck_array_ops.sum, @@ -736,7 +737,7 @@ def std( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``std`` along some dimension(s). @@ -783,44 +784,44 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.std() - + Size: 8B Dimensions: () Data variables: - da float64 0.7483 + da float64 8B 1.02 Use ``skipna`` to control whether NaNs are ignored. >>> ds.std(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan Specify ``ddof=1`` for an unbiased estimate. >>> ds.std(skipna=True, ddof=1) - + Size: 8B Dimensions: () Data variables: - da float64 0.8367 + da float64 8B 1.14 """ return self.reduce( duck_array_ops.std, @@ -840,7 +841,7 @@ def var( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``var`` along some dimension(s). @@ -887,44 +888,44 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.var() - + Size: 8B Dimensions: () Data variables: - da float64 0.56 + da float64 8B 1.04 Use ``skipna`` to control whether NaNs are ignored. >>> ds.var(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan Specify ``ddof=1`` for an unbiased estimate. >>> ds.var(skipna=True, ddof=1) - + Size: 8B Dimensions: () Data variables: - da float64 0.7 + da float64 8B 1.3 """ return self.reduce( duck_array_ops.var, @@ -943,7 +944,7 @@ def median( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``median`` along some dimension(s). @@ -987,36 +988,36 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.median() - + Size: 8B Dimensions: () Data variables: - da float64 2.0 + da float64 8B 2.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.median(skipna=False) - + Size: 8B Dimensions: () Data variables: - da float64 nan + da float64 8B nan """ return self.reduce( duck_array_ops.median, @@ -1034,7 +1035,7 @@ def cumsum( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``cumsum`` along some dimension(s). @@ -1078,38 +1079,38 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.cumsum() - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 3.0 6.0 7.0 9.0 9.0 + da (time) float64 48B 1.0 3.0 6.0 6.0 8.0 8.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.cumsum(skipna=False) - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 3.0 6.0 7.0 9.0 nan + da (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan """ return self.reduce( duck_array_ops.cumsum, @@ -1127,7 +1128,7 @@ def cumprod( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> Dataset: + ) -> Self: """ Reduce this Dataset's data by applying ``cumprod`` along some dimension(s). @@ -1171,38 +1172,38 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.cumprod() - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 12.0 12.0 + da (time) float64 48B 1.0 2.0 6.0 0.0 0.0 0.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.cumprod(skipna=False) - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 12.0 nan + da (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan """ return self.reduce( duck_array_ops.cumprod, @@ -1226,7 +1227,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> DataArray: + ) -> Self: raise NotImplementedError() def count( @@ -1235,7 +1236,7 @@ def count( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``count`` along some dimension(s). @@ -1261,8 +1262,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count Dataset.count :ref:`agg` User guide on reduction or aggregation operations. @@ -1270,22 +1271,22 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.count() - + Size: 8B array(5) """ return self.reduce( @@ -1301,7 +1302,7 @@ def all( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``all`` along some dimension(s). @@ -1339,19 +1340,19 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.all() - + Size: 1B array(False) """ return self.reduce( @@ -1367,7 +1368,7 @@ def any( *, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``any`` along some dimension(s). @@ -1405,19 +1406,19 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.any() - + Size: 1B array(True) """ return self.reduce( @@ -1434,7 +1435,7 @@ def max( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``max`` along some dimension(s). @@ -1474,28 +1475,28 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.max() - + Size: 8B array(3.) Use ``skipna`` to control whether NaNs are ignored. >>> da.max(skipna=False) - + Size: 8B array(nan) """ return self.reduce( @@ -1513,7 +1514,7 @@ def min( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``min`` along some dimension(s). @@ -1553,28 +1554,28 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.min() - - array(1.) + Size: 8B + array(0.) Use ``skipna`` to control whether NaNs are ignored. >>> da.min(skipna=False) - + Size: 8B array(nan) """ return self.reduce( @@ -1592,7 +1593,7 @@ def mean( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``mean`` along some dimension(s). @@ -1636,28 +1637,28 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.mean() - - array(1.8) + Size: 8B + array(1.6) Use ``skipna`` to control whether NaNs are ignored. >>> da.mean(skipna=False) - + Size: 8B array(nan) """ return self.reduce( @@ -1676,7 +1677,7 @@ def prod( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``prod`` along some dimension(s). @@ -1726,35 +1727,35 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.prod() - - array(12.) + Size: 8B + array(0.) Use ``skipna`` to control whether NaNs are ignored. >>> da.prod(skipna=False) - + Size: 8B array(nan) Specify ``min_count`` for finer control over when NaNs are ignored. >>> da.prod(skipna=True, min_count=2) - - array(12.) + Size: 8B + array(0.) """ return self.reduce( duck_array_ops.prod, @@ -1773,7 +1774,7 @@ def sum( min_count: int | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``sum`` along some dimension(s). @@ -1823,35 +1824,35 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.sum() - - array(9.) + Size: 8B + array(8.) Use ``skipna`` to control whether NaNs are ignored. >>> da.sum(skipna=False) - + Size: 8B array(nan) Specify ``min_count`` for finer control over when NaNs are ignored. >>> da.sum(skipna=True, min_count=2) - - array(9.) + Size: 8B + array(8.) """ return self.reduce( duck_array_ops.sum, @@ -1870,7 +1871,7 @@ def std( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``std`` along some dimension(s). @@ -1917,35 +1918,35 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.std() - - array(0.74833148) + Size: 8B + array(1.0198039) Use ``skipna`` to control whether NaNs are ignored. >>> da.std(skipna=False) - + Size: 8B array(nan) Specify ``ddof=1`` for an unbiased estimate. >>> da.std(skipna=True, ddof=1) - - array(0.83666003) + Size: 8B + array(1.14017543) """ return self.reduce( duck_array_ops.std, @@ -1964,7 +1965,7 @@ def var( ddof: int = 0, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``var`` along some dimension(s). @@ -2011,35 +2012,35 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.var() - - array(0.56) + Size: 8B + array(1.04) Use ``skipna`` to control whether NaNs are ignored. >>> da.var(skipna=False) - + Size: 8B array(nan) Specify ``ddof=1`` for an unbiased estimate. >>> da.var(skipna=True, ddof=1) - - array(0.7) + Size: 8B + array(1.3) """ return self.reduce( duck_array_ops.var, @@ -2057,7 +2058,7 @@ def median( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``median`` along some dimension(s). @@ -2101,28 +2102,28 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.median() - + Size: 8B array(2.) Use ``skipna`` to control whether NaNs are ignored. >>> da.median(skipna=False) - + Size: 8B array(nan) """ return self.reduce( @@ -2140,7 +2141,7 @@ def cumsum( skipna: bool | None = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``cumsum`` along some dimension(s). @@ -2184,35 +2185,35 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.cumsum() - - array([1., 3., 6., 7., 9., 9.]) + Size: 48B + array([1., 3., 6., 6., 8., 8.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.cumsum(skipna=False) - - array([ 1., 3., 6., 7., 9., nan]) + Size: 48B + array([ 1., 3., 6., 6., 8., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) DataArray: + ) -> Self: """ Reduce this DataArray's data by applying ``cumprod`` along some dimension(s). @@ -2273,35 +2274,35 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.cumprod() - - array([ 1., 2., 6., 6., 12., 12.]) + Size: 48B + array([1., 2., 6., 0., 0., 0.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.cumprod(skipna=False) - - array([ 1., 2., 6., 6., 12., nan]) + Size: 48B + array([ 1., 2., 6., 0., 0., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) Dataset: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -2367,8 +2381,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count Dataset.count :ref:`groupby` User guide on groupby operations. @@ -2385,35 +2399,35 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").count() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) int64 1 2 2 + da (labels) int64 24B 1 2 2 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="count", @@ -2424,7 +2438,7 @@ def count( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.count, dim=dim, numeric_only=False, @@ -2486,32 +2500,32 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").all() - + Size: 27B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) bool False True True + da (labels) bool 3B False True True """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="all", @@ -2522,7 +2536,7 @@ def all( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_all, dim=dim, numeric_only=False, @@ -2584,32 +2598,32 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").any() - + Size: 27B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) bool True True True + da (labels) bool 3B True True True """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="any", @@ -2620,7 +2634,7 @@ def any( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_any, dim=dim, numeric_only=False, @@ -2685,45 +2699,45 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").max() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 3.0 + da (labels) float64 24B 1.0 2.0 3.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").max(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 3.0 + da (labels) float64 24B nan 2.0 3.0 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="max", @@ -2735,7 +2749,7 @@ def max( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.max, dim=dim, skipna=skipna, @@ -2801,45 +2815,45 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").min() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 1.0 + da (labels) float64 24B 1.0 2.0 0.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").min(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 1.0 + da (labels) float64 24B nan 2.0 0.0 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="min", @@ -2851,7 +2865,7 @@ def min( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.min, dim=dim, skipna=skipna, @@ -2919,45 +2933,45 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").mean() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 2.0 + da (labels) float64 24B 1.0 2.0 1.5 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").mean(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 2.0 + da (labels) float64 24B nan 2.0 1.5 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="mean", @@ -2969,7 +2983,7 @@ def mean( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -3044,55 +3058,55 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").prod() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 4.0 3.0 + da (labels) float64 24B 1.0 4.0 0.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").prod(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 3.0 + da (labels) float64 24B nan 4.0 0.0 Specify ``min_count`` for finer control over when NaNs are ignored. >>> ds.groupby("labels").prod(skipna=True, min_count=2) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 3.0 + da (labels) float64 24B nan 4.0 0.0 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="prod", @@ -3105,7 +3119,7 @@ def prod( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -3181,55 +3195,55 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").sum() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 4.0 4.0 + da (labels) float64 24B 1.0 4.0 3.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").sum(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 4.0 + da (labels) float64 24B nan 4.0 3.0 Specify ``min_count`` for finer control over when NaNs are ignored. >>> ds.groupby("labels").sum(skipna=True, min_count=2) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 4.0 4.0 + da (labels) float64 24B nan 4.0 3.0 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="sum", @@ -3242,7 +3256,7 @@ def sum( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -3315,55 +3329,55 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").std() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 0.0 0.0 1.0 + da (labels) float64 24B 0.0 0.0 1.5 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").std(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.0 + da (labels) float64 24B nan 0.0 1.5 Specify ``ddof=1`` for an unbiased estimate. >>> ds.groupby("labels").std(skipna=True, ddof=1) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.414 + da (labels) float64 24B nan 0.0 2.121 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="std", @@ -3376,7 +3390,7 @@ def std( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.std, dim=dim, skipna=skipna, @@ -3449,55 +3463,55 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").var() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 0.0 0.0 1.0 + da (labels) float64 24B 0.0 0.0 2.25 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").var(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 1.0 + da (labels) float64 24B nan 0.0 2.25 Specify ``ddof=1`` for an unbiased estimate. >>> ds.groupby("labels").var(skipna=True, ddof=1) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 0.0 2.0 + da (labels) float64 24B nan 0.0 4.5 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="var", @@ -3510,7 +3524,7 @@ def var( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.var, dim=dim, skipna=skipna, @@ -3579,42 +3593,42 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").median() - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 1.0 2.0 2.0 + da (labels) float64 24B 1.0 2.0 1.5 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").median(skipna=False) - + Size: 48B Dimensions: (labels: 3) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Data variables: - da (labels) float64 nan 2.0 2.0 + da (labels) float64 24B nan 2.0 1.5 """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.median, dim=dim, skipna=skipna, @@ -3682,40 +3696,40 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").cumsum() - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 4.0 4.0 1.0 + da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 1.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").cumsum(skipna=False) - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 4.0 4.0 nan + da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.cumsum, dim=dim, skipna=skipna, @@ -3783,40 +3797,40 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.groupby("labels").cumprod() - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 3.0 4.0 1.0 + da (time) float64 48B 1.0 2.0 3.0 0.0 4.0 1.0 Use ``skipna`` to control whether NaNs are ignored. >>> ds.groupby("labels").cumprod(skipna=False) - + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 3.0 3.0 4.0 nan + da (time) float64 48B 1.0 2.0 3.0 0.0 4.0 nan """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.cumprod, dim=dim, skipna=skipna, @@ -3829,6 +3843,19 @@ def cumprod( class DatasetResampleAggregations: _obj: Dataset + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -3881,8 +3908,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count Dataset.count :ref:`resampling` User guide on resampling operations. @@ -3899,35 +3926,35 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").count() - + >>> ds.resample(time="3ME").count() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) int64 1 3 1 + da (time) int64 24B 1 3 1 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="count", @@ -3938,7 +3965,7 @@ def count( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.count, dim=dim, numeric_only=False, @@ -4000,32 +4027,32 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").all() - + >>> ds.resample(time="3ME").all() + Size: 27B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) bool True True False + da (time) bool 3B True True False """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="all", @@ -4036,7 +4063,7 @@ def all( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_all, dim=dim, numeric_only=False, @@ -4098,32 +4125,32 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 78B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").any() - + >>> ds.resample(time="3ME").any() + Size: 27B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) bool True True True + da (time) bool 3B True True True """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="any", @@ -4134,7 +4161,7 @@ def any( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_any, dim=dim, numeric_only=False, @@ -4199,45 +4226,45 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").max() - + >>> ds.resample(time="3ME").max() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 3.0 2.0 + da (time) float64 24B 1.0 3.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").max(skipna=False) - + >>> ds.resample(time="3ME").max(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 3.0 nan + da (time) float64 24B 1.0 3.0 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="max", @@ -4249,7 +4276,7 @@ def max( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.max, dim=dim, skipna=skipna, @@ -4315,45 +4342,45 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").min() - + >>> ds.resample(time="3ME").min() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 1.0 2.0 + da (time) float64 24B 1.0 0.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").min(skipna=False) - + >>> ds.resample(time="3ME").min(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 1.0 nan + da (time) float64 24B 1.0 0.0 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="min", @@ -4365,7 +4392,7 @@ def min( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.min, dim=dim, skipna=skipna, @@ -4433,45 +4460,45 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").mean() - + >>> ds.resample(time="3ME").mean() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 2.0 + da (time) float64 24B 1.0 1.667 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").mean(skipna=False) - + >>> ds.resample(time="3ME").mean(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 nan + da (time) float64 24B 1.0 1.667 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="mean", @@ -4483,7 +4510,7 @@ def mean( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -4558,55 +4585,55 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").prod() - + >>> ds.resample(time="3ME").prod() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 2.0 + da (time) float64 24B 1.0 0.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").prod(skipna=False) - + >>> ds.resample(time="3ME").prod(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 nan + da (time) float64 24B 1.0 0.0 nan Specify ``min_count`` for finer control over when NaNs are ignored. - >>> ds.resample(time="3M").prod(skipna=True, min_count=2) - + >>> ds.resample(time="3ME").prod(skipna=True, min_count=2) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 6.0 nan + da (time) float64 24B nan 0.0 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="prod", @@ -4619,7 +4646,7 @@ def prod( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -4695,55 +4722,55 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").sum() - + >>> ds.resample(time="3ME").sum() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 2.0 + da (time) float64 24B 1.0 5.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").sum(skipna=False) - + >>> ds.resample(time="3ME").sum(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 6.0 nan + da (time) float64 24B 1.0 5.0 nan Specify ``min_count`` for finer control over when NaNs are ignored. - >>> ds.resample(time="3M").sum(skipna=True, min_count=2) - + >>> ds.resample(time="3ME").sum(skipna=True, min_count=2) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 6.0 nan + da (time) float64 24B nan 5.0 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="sum", @@ -4756,7 +4783,7 @@ def sum( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -4829,55 +4856,55 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").std() - + >>> ds.resample(time="3ME").std() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.8165 0.0 + da (time) float64 24B 0.0 1.247 0.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").std(skipna=False) - + >>> ds.resample(time="3ME").std(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.8165 nan + da (time) float64 24B 0.0 1.247 nan Specify ``ddof=1`` for an unbiased estimate. - >>> ds.resample(time="3M").std(skipna=True, ddof=1) - + >>> ds.resample(time="3ME").std(skipna=True, ddof=1) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 1.0 nan + da (time) float64 24B nan 1.528 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="std", @@ -4890,7 +4917,7 @@ def std( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.std, dim=dim, skipna=skipna, @@ -4963,55 +4990,55 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").var() - + >>> ds.resample(time="3ME").var() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.6667 0.0 + da (time) float64 24B 0.0 1.556 0.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").var(skipna=False) - + >>> ds.resample(time="3ME").var(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 0.0 0.6667 nan + da (time) float64 24B 0.0 1.556 nan Specify ``ddof=1`` for an unbiased estimate. - >>> ds.resample(time="3M").var(skipna=True, ddof=1) - + >>> ds.resample(time="3ME").var(skipna=True, ddof=1) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 nan 1.0 nan + da (time) float64 24B nan 2.333 nan """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="var", @@ -5024,7 +5051,7 @@ def var( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.var, dim=dim, skipna=skipna, @@ -5093,42 +5120,42 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").median() - + >>> ds.resample(time="3ME").median() + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 2.0 + da (time) float64 24B 1.0 2.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").median(skipna=False) - + >>> ds.resample(time="3ME").median(skipna=False) + Size: 48B Dimensions: (time: 3) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Data variables: - da (time) float64 1.0 2.0 nan + da (time) float64 24B 1.0 2.0 nan """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.median, dim=dim, skipna=skipna, @@ -5196,40 +5223,40 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").cumsum() - + >>> ds.resample(time="3ME").cumsum() + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 5.0 6.0 2.0 2.0 + da (time) float64 48B 1.0 2.0 5.0 5.0 2.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").cumsum(skipna=False) - + >>> ds.resample(time="3ME").cumsum(skipna=False) + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 5.0 6.0 2.0 nan + da (time) float64 48B 1.0 2.0 5.0 5.0 2.0 nan """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.cumsum, dim=dim, skipna=skipna, @@ -5297,40 +5324,40 @@ def cumprod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> ds = xr.Dataset(dict(da=da)) >>> ds - + Size: 120B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> ds.resample(time="3M").cumprod() - + >>> ds.resample(time="3ME").cumprod() + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 2.0 2.0 + da (time) float64 48B 1.0 2.0 6.0 0.0 2.0 2.0 Use ``skipna`` to control whether NaNs are ignored. - >>> ds.resample(time="3M").cumprod(skipna=False) - + >>> ds.resample(time="3ME").cumprod(skipna=False) + Size: 48B Dimensions: (time: 6) Dimensions without coordinates: time Data variables: - da (time) float64 1.0 2.0 6.0 6.0 2.0 nan + da (time) float64 48B 1.0 2.0 6.0 0.0 2.0 nan """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.cumprod, dim=dim, skipna=skipna, @@ -5343,6 +5370,19 @@ def cumprod( class DataArrayGroupByAggregations: _obj: DataArray + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -5395,8 +5435,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count DataArray.count :ref:`groupby` User guide on groupby operations. @@ -5413,30 +5453,30 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").count() - + Size: 24B array([1, 2, 2]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="count", @@ -5446,7 +5486,7 @@ def count( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.count, dim=dim, keep_attrs=keep_attrs, @@ -5507,27 +5547,27 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").all() - + Size: 3B array([False, True, True]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="all", @@ -5537,7 +5577,7 @@ def all( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_all, dim=dim, keep_attrs=keep_attrs, @@ -5598,27 +5638,27 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").any() - + Size: 3B array([ True, True, True]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="any", @@ -5628,7 +5668,7 @@ def any( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_any, dim=dim, keep_attrs=keep_attrs, @@ -5692,38 +5732,38 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").max() - + Size: 24B array([1., 2., 3.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").max(skipna=False) - + Size: 24B array([nan, 2., 3.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="max", @@ -5734,7 +5774,7 @@ def max( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.max, dim=dim, skipna=skipna, @@ -5799,38 +5839,38 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").min() - - array([1., 2., 1.]) + Size: 24B + array([1., 2., 0.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").min(skipna=False) - - array([nan, 2., 1.]) + Size: 24B + array([nan, 2., 0.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="min", @@ -5841,7 +5881,7 @@ def min( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.min, dim=dim, skipna=skipna, @@ -5908,38 +5948,38 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").mean() - - array([1., 2., 2.]) + Size: 24B + array([1. , 2. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").mean(skipna=False) - - array([nan, 2., 2.]) + Size: 24B + array([nan, 2. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="mean", @@ -5950,7 +5990,7 @@ def mean( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -6024,46 +6064,46 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").prod() - - array([1., 4., 3.]) + Size: 24B + array([1., 4., 0.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").prod(skipna=False) - - array([nan, 4., 3.]) + Size: 24B + array([nan, 4., 0.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Specify ``min_count`` for finer control over when NaNs are ignored. >>> da.groupby("labels").prod(skipna=True, min_count=2) - - array([nan, 4., 3.]) + Size: 24B + array([nan, 4., 0.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="prod", @@ -6075,7 +6115,7 @@ def prod( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -6150,46 +6190,46 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").sum() - - array([1., 4., 4.]) + Size: 24B + array([1., 4., 3.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").sum(skipna=False) - - array([nan, 4., 4.]) + Size: 24B + array([nan, 4., 3.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Specify ``min_count`` for finer control over when NaNs are ignored. >>> da.groupby("labels").sum(skipna=True, min_count=2) - - array([nan, 4., 4.]) + Size: 24B + array([nan, 4., 3.]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="sum", @@ -6201,7 +6241,7 @@ def sum( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -6273,46 +6313,46 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").std() - - array([0., 0., 1.]) + Size: 24B + array([0. , 0. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").std(skipna=False) - - array([nan, 0., 1.]) + Size: 24B + array([nan, 0. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Specify ``ddof=1`` for an unbiased estimate. >>> da.groupby("labels").std(skipna=True, ddof=1) - - array([ nan, 0. , 1.41421356]) + Size: 24B + array([ nan, 0. , 2.12132034]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="std", @@ -6324,7 +6364,7 @@ def std( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.std, dim=dim, skipna=skipna, @@ -6396,46 +6436,46 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").var() - - array([0., 0., 1.]) + Size: 24B + array([0. , 0. , 2.25]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").var(skipna=False) - - array([nan, 0., 1.]) + Size: 24B + array([ nan, 0. , 2.25]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Specify ``ddof=1`` for an unbiased estimate. >>> da.groupby("labels").var(skipna=True, ddof=1) - - array([nan, 0., 2.]) + Size: 24B + array([nan, 0. , 4.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="var", @@ -6447,7 +6487,7 @@ def var( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.var, dim=dim, skipna=skipna, @@ -6515,35 +6555,35 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").median() - - array([1., 2., 2.]) + Size: 24B + array([1. , 2. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' Use ``skipna`` to control whether NaNs are ignored. >>> da.groupby("labels").median(skipna=False) - - array([nan, 2., 2.]) + Size: 24B + array([nan, 2. , 1.5]) Coordinates: - * labels (labels) object 'a' 'b' 'c' + * labels (labels) object 24B 'a' 'b' 'c' """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.median, dim=dim, skipna=skipna, @@ -6610,37 +6650,37 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").cumsum() - - array([1., 2., 3., 4., 4., 1.]) + Size: 48B + array([1., 2., 3., 3., 4., 1.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").cumsum(skipna=False) - - array([ 1., 2., 3., 4., 4., nan]) + Size: 48B + array([ 1., 2., 3., 3., 4., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").cumprod() - - array([1., 2., 3., 3., 4., 1.]) + Size: 48B + array([1., 2., 3., 0., 4., 1.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.groupby("labels").cumprod(skipna=False) - - array([ 1., 2., 3., 3., 4., nan]) + Size: 48B + array([ 1., 2., 3., 0., 4., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) DataArray: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -6801,8 +6854,8 @@ def count( See Also -------- - numpy.count - dask.array.count + pandas.DataFrame.count + dask.dataframe.DataFrame.count DataArray.count :ref:`resampling` User guide on resampling operations. @@ -6819,30 +6872,30 @@ def count( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").count() - + >>> da.resample(time="3ME").count() + Size: 24B array([1, 3, 1]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="count", @@ -6852,7 +6905,7 @@ def count( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.count, dim=dim, keep_attrs=keep_attrs, @@ -6913,27 +6966,27 @@ def all( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").all() - + >>> da.resample(time="3ME").all() + Size: 3B array([ True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="all", @@ -6943,7 +6996,7 @@ def all( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_all, dim=dim, keep_attrs=keep_attrs, @@ -7004,27 +7057,27 @@ def any( ... np.array([True, True, True, True, True, False], dtype=bool), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - + Size: 6B array([ True, True, True, True, True, False]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").any() - + >>> da.resample(time="3ME").any() + Size: 3B array([ True, True, True]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="any", @@ -7034,7 +7087,7 @@ def any( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.array_any, dim=dim, keep_attrs=keep_attrs, @@ -7098,38 +7151,38 @@ def max( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").max() - + >>> da.resample(time="3ME").max() + Size: 24B array([1., 3., 2.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").max(skipna=False) - + >>> da.resample(time="3ME").max(skipna=False) + Size: 24B array([ 1., 3., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="max", @@ -7140,7 +7193,7 @@ def max( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.max, dim=dim, skipna=skipna, @@ -7205,38 +7258,38 @@ def min( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").min() - - array([1., 1., 2.]) + >>> da.resample(time="3ME").min() + Size: 24B + array([1., 0., 2.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").min(skipna=False) - - array([ 1., 1., nan]) + >>> da.resample(time="3ME").min(skipna=False) + Size: 24B + array([ 1., 0., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="min", @@ -7247,7 +7300,7 @@ def min( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.min, dim=dim, skipna=skipna, @@ -7314,38 +7367,38 @@ def mean( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").mean() - - array([1., 2., 2.]) + >>> da.resample(time="3ME").mean() + Size: 24B + array([1. , 1.66666667, 2. ]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").mean(skipna=False) - - array([ 1., 2., nan]) + >>> da.resample(time="3ME").mean(skipna=False) + Size: 24B + array([1. , 1.66666667, nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="mean", @@ -7356,7 +7409,7 @@ def mean( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.mean, dim=dim, skipna=skipna, @@ -7430,46 +7483,46 @@ def prod( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").prod() - - array([1., 6., 2.]) + >>> da.resample(time="3ME").prod() + Size: 24B + array([1., 0., 2.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").prod(skipna=False) - - array([ 1., 6., nan]) + >>> da.resample(time="3ME").prod(skipna=False) + Size: 24B + array([ 1., 0., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Specify ``min_count`` for finer control over when NaNs are ignored. - >>> da.resample(time="3M").prod(skipna=True, min_count=2) - - array([nan, 6., nan]) + >>> da.resample(time="3ME").prod(skipna=True, min_count=2) + Size: 24B + array([nan, 0., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="prod", @@ -7481,7 +7534,7 @@ def prod( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.prod, dim=dim, skipna=skipna, @@ -7556,46 +7609,46 @@ def sum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").sum() - - array([1., 6., 2.]) + >>> da.resample(time="3ME").sum() + Size: 24B + array([1., 5., 2.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").sum(skipna=False) - - array([ 1., 6., nan]) + >>> da.resample(time="3ME").sum(skipna=False) + Size: 24B + array([ 1., 5., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Specify ``min_count`` for finer control over when NaNs are ignored. - >>> da.resample(time="3M").sum(skipna=True, min_count=2) - - array([nan, 6., nan]) + >>> da.resample(time="3ME").sum(skipna=True, min_count=2) + Size: 24B + array([nan, 5., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="sum", @@ -7607,7 +7660,7 @@ def sum( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.sum, dim=dim, skipna=skipna, @@ -7679,46 +7732,46 @@ def std( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").std() - - array([0. , 0.81649658, 0. ]) + >>> da.resample(time="3ME").std() + Size: 24B + array([0. , 1.24721913, 0. ]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").std(skipna=False) - - array([0. , 0.81649658, nan]) + >>> da.resample(time="3ME").std(skipna=False) + Size: 24B + array([0. , 1.24721913, nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Specify ``ddof=1`` for an unbiased estimate. - >>> da.resample(time="3M").std(skipna=True, ddof=1) - - array([nan, 1., nan]) + >>> da.resample(time="3ME").std(skipna=True, ddof=1) + Size: 24B + array([ nan, 1.52752523, nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="std", @@ -7730,7 +7783,7 @@ def std( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.std, dim=dim, skipna=skipna, @@ -7802,46 +7855,46 @@ def var( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").var() - - array([0. , 0.66666667, 0. ]) + >>> da.resample(time="3ME").var() + Size: 24B + array([0. , 1.55555556, 0. ]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").var(skipna=False) - - array([0. , 0.66666667, nan]) + >>> da.resample(time="3ME").var(skipna=False) + Size: 24B + array([0. , 1.55555556, nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Specify ``ddof=1`` for an unbiased estimate. - >>> da.resample(time="3M").var(skipna=True, ddof=1) - - array([nan, 1., nan]) + >>> da.resample(time="3ME").var(skipna=True, ddof=1) + Size: 24B + array([ nan, 2.33333333, nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="var", @@ -7853,7 +7906,7 @@ def var( **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.var, dim=dim, skipna=skipna, @@ -7921,35 +7974,35 @@ def median( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").median() - + >>> da.resample(time="3ME").median() + Size: 24B array([1., 2., 2.]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 Use ``skipna`` to control whether NaNs are ignored. - >>> da.resample(time="3M").median(skipna=False) - + >>> da.resample(time="3ME").median(skipna=False) + Size: 24B array([ 1., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 + * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.median, dim=dim, skipna=skipna, @@ -8016,37 +8069,37 @@ def cumsum( Examples -------- >>> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").cumsum() - - array([1., 2., 5., 6., 2., 2.]) + >>> da.resample(time="3ME").cumsum() + Size: 48B + array([1., 2., 5., 5., 2., 2.]) Coordinates: - labels (time) >> da.resample(time="3M").cumsum(skipna=False) - - array([ 1., 2., 5., 6., 2., nan]) + >>> da.resample(time="3ME").cumsum(skipna=False) + Size: 48B + array([ 1., 2., 5., 5., 2., nan]) Coordinates: - labels (time) >> da = xr.DataArray( - ... np.array([1, 2, 3, 1, 2, np.nan]), + ... np.array([1, 2, 3, 0, 2, np.nan]), ... dims="time", ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ... ), ... ) >>> da - - array([ 1., 2., 3., 1., 2., nan]) + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) Coordinates: - * time (time) datetime64[ns] 2001-01-31 2001-02-28 ... 2001-06-30 - labels (time) >> da.resample(time="3M").cumprod() - - array([1., 2., 6., 6., 2., 2.]) + >>> da.resample(time="3ME").cumprod() + Size: 48B + array([1., 2., 6., 0., 2., 2.]) Coordinates: - labels (time) >> da.resample(time="3M").cumprod(skipna=False) - - array([ 1., 2., 6., 6., 2., nan]) + >>> da.resample(time="3ME").cumprod(skipna=False) + Size: 48B + array([ 1., 2., 6., 0., 2., nan]) Coordinates: - labels (time) Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lt__(self, other): + def __lshift__(self, other: DsCompatible) -> Self: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: DsCompatible) -> Self: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def _unary_op(self, f, *args, **kwargs): + def __ilshift__(self, other: DsCompatible) -> Self: + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other: DsCompatible) -> Self: + return self._inplace_binary_op(other, operator.irshift) + + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -160,6 +194,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -186,6 +222,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = operator.abs.__doc__ @@ -199,145 +237,163 @@ def conjugate(self, *args, **kwargs): class DataArrayOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lt__(self, other): + def __lshift__(self, other: DaCompatible) -> Self: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: DaCompatible) -> Self: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def _unary_op(self, f, *args, **kwargs): + def __ilshift__(self, other: DaCompatible) -> Self: + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other: DaCompatible) -> Self: + return self._inplace_binary_op(other, operator.irshift) + + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -350,6 +406,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -376,6 +434,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = operator.abs.__doc__ @@ -389,145 +449,271 @@ def conjugate(self, *args, **kwargs): class VariableOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: VarCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + @overload + def __add__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __add__(self, other: VarCompatible) -> Self: ... + + def __add__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.add) - def __sub__(self, other): + @overload + def __sub__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __sub__(self, other: VarCompatible) -> Self: ... + + def __sub__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + @overload + def __mul__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __mul__(self, other: VarCompatible) -> Self: ... + + def __mul__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + @overload + def __pow__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __pow__(self, other: VarCompatible) -> Self: ... + + def __pow__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + @overload + def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __truediv__(self, other: VarCompatible) -> Self: ... + + def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + @overload + def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __floordiv__(self, other: VarCompatible) -> Self: ... + + def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + @overload + def __mod__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __mod__(self, other: VarCompatible) -> Self: ... + + def __mod__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mod) - def __and__(self, other): + @overload + def __and__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __and__(self, other: VarCompatible) -> Self: ... + + def __and__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + @overload + def __xor__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __xor__(self, other: VarCompatible) -> Self: ... + + def __xor__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.xor) - def __or__(self, other): + @overload + def __or__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __or__(self, other: VarCompatible) -> Self: ... + + def __or__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.or_) - def __lt__(self, other): + @overload + def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __lshift__(self, other: VarCompatible) -> Self: ... + + def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: + return self._binary_op(other, operator.lshift) + + @overload + def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __rshift__(self, other: VarCompatible) -> Self: ... + + def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: + return self._binary_op(other, operator.rshift) + + @overload + def __lt__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __lt__(self, other: VarCompatible) -> Self: ... + + def __lt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lt) - def __le__(self, other): + @overload + def __le__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __le__(self, other: VarCompatible) -> Self: ... + + def __le__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.le) - def __gt__(self, other): + @overload + def __gt__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __gt__(self, other: VarCompatible) -> Self: ... + + def __gt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + @overload + def __ge__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __ge__(self, other: VarCompatible) -> Self: ... + + def __ge__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + @overload # type:ignore[override] + def __eq__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __eq__(self, other: VarCompatible) -> Self: ... + + def __eq__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + @overload # type:ignore[override] + def __ne__(self, other: T_DataArray) -> T_DataArray: ... + + @overload + def __ne__(self, other: VarCompatible) -> Self: ... + + def __ne__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ior) - def _unary_op(self, f, *args, **kwargs): + def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] + return self._inplace_binary_op(other, operator.ilshift) + + def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] + return self._inplace_binary_op(other, operator.irshift) + + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -540,6 +726,8 @@ def conjugate(self, *args, **kwargs): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -566,6 +754,8 @@ def conjugate(self, *args, **kwargs): __iand__.__doc__ = operator.iand.__doc__ __ixor__.__doc__ = operator.ixor.__doc__ __ior__.__doc__ = operator.ior.__doc__ + __ilshift__.__doc__ = operator.ilshift.__doc__ + __irshift__.__doc__ = operator.irshift.__doc__ __neg__.__doc__ = operator.neg.__doc__ __pos__.__doc__ = operator.pos.__doc__ __abs__.__doc__ = operator.abs.__doc__ @@ -579,85 +769,97 @@ def conjugate(self, *args, **kwargs): class DatasetGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: GroupByCompatible, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_) - def __lt__(self, other): + def __lshift__(self, other: GroupByCompatible) -> Dataset: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: GroupByCompatible) -> Dataset: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -670,6 +872,8 @@ def __ror__(self, other): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ @@ -691,85 +895,97 @@ def __ror__(self, other): class DataArrayGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: T_Xarray, f: Callable, reflexive: bool = False + ) -> T_Xarray: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_) - def __lt__(self, other): + def __lshift__(self, other: T_Xarray) -> T_Xarray: + return self._binary_op(other, operator.lshift) + + def __rshift__(self, other: T_Xarray) -> T_Xarray: + return self._binary_op(other, operator.rshift) + + def __lt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment] + + def __radd__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -782,6 +998,8 @@ def __ror__(self, other): __and__.__doc__ = operator.and_.__doc__ __xor__.__doc__ = operator.xor.__doc__ __or__.__doc__ = operator.or_.__doc__ + __lshift__.__doc__ = operator.lshift.__doc__ + __rshift__.__doc__ = operator.rshift.__doc__ __lt__.__doc__ = operator.lt.__doc__ __le__.__doc__ = operator.le.__doc__ __gt__.__doc__ = operator.gt.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi deleted file mode 100644 index 98a17a47cd5..00000000000 --- a/xarray/core/_typed_ops.pyi +++ /dev/null @@ -1,732 +0,0 @@ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike - -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( - DaCompatible, - DsCompatible, - GroupByIncompatible, - ScalarOrArray, - VarCompatible, -) -from .variable import Variable - -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -class DatasetOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Dataset) -> T_Dataset: ... - def __pos__(self: T_Dataset) -> T_Dataset: ... - def __abs__(self: T_Dataset) -> T_Dataset: ... - def __invert__(self: T_Dataset) -> T_Dataset: ... - def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - -class DataArrayOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_DataArray) -> T_DataArray: ... - def __pos__(self: T_DataArray) -> T_DataArray: ... - def __abs__(self: T_DataArray) -> T_DataArray: ... - def __invert__(self: T_DataArray) -> T_DataArray: ... - def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - -class VariableOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Variable) -> T_Variable: ... - def __pos__(self: T_Variable) -> T_Variable: ... - def __abs__(self: T_Variable) -> T_Variable: ... - def __invert__(self: T_Variable) -> T_Variable: ... - def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... - -class DatasetGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DataArray") -> "Dataset": ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DataArray") -> "Dataset": ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DataArray") -> "Dataset": ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DataArray") -> "Dataset": ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DataArray") -> "Dataset": ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DataArray") -> "Dataset": ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... - -class DataArrayGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 7e6d4ab82d7..41b982d268b 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -7,13 +7,15 @@ import pandas as pd from xarray.coding.times import infer_calendar_name +from xarray.core import duck_array_ops from xarray.core.common import ( _contains_datetime_like_objects, is_np_datetime_like, is_np_timedelta_like, ) -from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import T_DataArray +from xarray.core.variable import IndexVariable +from xarray.namedarray.utils import is_duck_dask_array if TYPE_CHECKING: from numpy.typing import DTypeLike @@ -48,13 +50,17 @@ def _access_through_cftimeindex(values, name): """ from xarray.coding.cftimeindex import CFTimeIndex - values_as_cftimeindex = CFTimeIndex(values.ravel()) + if not isinstance(values, CFTimeIndex): + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) + else: + values_as_cftimeindex = values if name == "season": months = values_as_cftimeindex.month field_values = _season_from_months(months) elif name == "date": raise AttributeError( - "'CFTimeIndex' object has no attribute `date`. Consider using the floor method instead, for instance: `.time.dt.floor('D')`." + "'CFTimeIndex' object has no attribute `date`. Consider using the floor method " + "instead, for instance: `.time.dt.floor('D')`." ) else: field_values = getattr(values_as_cftimeindex, name) @@ -65,16 +71,32 @@ def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) if name == "season": months = values_as_series.dt.month.values field_values = _season_from_months(months) + elif name == "total_seconds": + field_values = values_as_series.dt.total_seconds().values elif name == "isocalendar": + # special NaT-handling can be removed when + # https://github.com/pandas-dev/pandas/issues/54657 is resolved + field_values = values_as_series.dt.isocalendar() + # test for and apply needed dtype + hasna = any(field_values.year.isnull()) + if hasna: + field_values = np.dstack( + [ + getattr(field_values, name).astype(np.float64, copy=False).values + for name in ["year", "week", "day"] + ] + ) + else: + field_values = np.array(field_values, dtype=np.int64) # isocalendar returns iso- year, week, and weekday -> reshape - field_values = np.array(values_as_series.dt.isocalendar(), dtype=np.int64) return field_values.T.reshape(3, *values.shape) else: field_values = getattr(values_as_series.dt, name).values + return field_values.reshape(values.shape) @@ -106,7 +128,7 @@ def _get_date_field(values, name, dtype): from dask.array import map_blocks new_axis = chunks = None - # isocalendar adds adds an axis + # isocalendar adds an axis if name == "isocalendar": chunks = (3,) + values.chunksize new_axis = 0 @@ -115,7 +137,12 @@ def _get_date_field(values, name, dtype): access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks ) else: - return access_method(values, name) + out = access_method(values, name) + # cast only for integer types to keep float64 in presence of NaT + # see https://github.com/pydata/xarray/issues/7928 + if np.issubdtype(out.dtype, np.integer): + out = out.astype(dtype, copy=False) + return out def _round_through_series_or_index(values, name, freq): @@ -125,10 +152,10 @@ def _round_through_series_or_index(values, name, freq): from xarray.coding.cftimeindex import CFTimeIndex if is_np_datetime_like(values.dtype): - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) method = getattr(values_as_series.dt, name) else: - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) method = getattr(values_as_cftimeindex, name) field_values = method(freq=freq).values @@ -172,7 +199,7 @@ def _strftime_through_cftimeindex(values, date_format: str): """ from xarray.coding.cftimeindex import CFTimeIndex - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) field_values = values_as_cftimeindex.strftime(date_format) return field_values.values.reshape(values.shape) @@ -182,7 +209,7 @@ def _strftime_through_series(values, date_format: str): """Coerce an array of datetime-like values to a pandas Series and apply string formatting """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) strs = values_as_series.dt.strftime(date_format) return strs.values.reshape(values.shape) @@ -200,6 +227,13 @@ def _strftime(values, date_format): return access_method(values, date_format) +def _index_or_data(obj): + if isinstance(obj.variable, IndexVariable): + return obj.to_index() + else: + return obj.data + + class TimeAccessor(Generic[T_DataArray]): __slots__ = ("_obj",) @@ -209,14 +243,14 @@ def __init__(self, obj: T_DataArray) -> None: def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray: if dtype is None: dtype = self._obj.dtype - obj_type = type(self._obj) - result = _get_date_field(self._obj.data, name, dtype) - return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims) + result = _get_date_field(_index_or_data(self._obj), name, dtype) + newvar = self._obj.variable.copy(data=result, deep=False) + return self._obj._replace(newvar, name=name) def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray: - obj_type = type(self._obj) - result = _round_field(self._obj.data, name, freq) - return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims) + result = _round_field(_index_or_data(self._obj), name, freq) + newvar = self._obj.variable.copy(data=result, deep=False) + return self._obj._replace(newvar, name=name) def floor(self, freq: str) -> T_DataArray: """ @@ -279,7 +313,7 @@ class DatetimeAccessor(TimeAccessor[T_DataArray]): >>> dates = pd.date_range(start="2000/01/01", freq="D", periods=10) >>> ts = xr.DataArray(dates, dims=("time")) >>> ts - + Size: 80B array(['2000-01-01T00:00:00.000000000', '2000-01-02T00:00:00.000000000', '2000-01-03T00:00:00.000000000', '2000-01-04T00:00:00.000000000', '2000-01-05T00:00:00.000000000', '2000-01-06T00:00:00.000000000', @@ -287,19 +321,19 @@ class DatetimeAccessor(TimeAccessor[T_DataArray]): '2000-01-09T00:00:00.000000000', '2000-01-10T00:00:00.000000000'], dtype='datetime64[ns]') Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + * time (time) datetime64[ns] 80B 2000-01-01 2000-01-02 ... 2000-01-10 >>> ts.dt # doctest: +ELLIPSIS >>> ts.dt.dayofyear - + Size: 80B array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + * time (time) datetime64[ns] 80B 2000-01-01 2000-01-02 ... 2000-01-10 >>> ts.dt.quarter - + Size: 80B array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + * time (time) datetime64[ns] 80B 2000-01-01 2000-01-02 ... 2000-01-10 """ @@ -325,7 +359,7 @@ def strftime(self, date_format: str) -> T_DataArray: >>> import datetime >>> rng = xr.Dataset({"time": datetime.datetime(2000, 1, 1)}) >>> rng["time"].dt.strftime("%B %d, %Y, %r") - + Size: 8B array('January 01, 2000, 12:00:00 AM', dtype=object) """ obj_type = type(self._obj) @@ -423,11 +457,6 @@ def dayofweek(self) -> T_DataArray: weekday = dayofweek - @property - def weekday_name(self) -> T_DataArray: - """The name of day in a week""" - return self._date_field("weekday_name", object) - @property def dayofyear(self) -> T_DataArray: """The ordinal day of the year""" @@ -512,10 +541,10 @@ class TimedeltaAccessor(TimeAccessor[T_DataArray]): Examples -------- - >>> dates = pd.timedelta_range(start="1 day", freq="6H", periods=20) + >>> dates = pd.timedelta_range(start="1 day", freq="6h", periods=20) >>> ts = xr.DataArray(dates, dims=("time")) >>> ts - + Size: 160B array([ 86400000000000, 108000000000000, 129600000000000, 151200000000000, 172800000000000, 194400000000000, 216000000000000, 237600000000000, 259200000000000, 280800000000000, 302400000000000, 324000000000000, @@ -523,26 +552,33 @@ class TimedeltaAccessor(TimeAccessor[T_DataArray]): 432000000000000, 453600000000000, 475200000000000, 496800000000000], dtype='timedelta64[ns]') Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 160B 1 days 00:00:00 ... 5 days 18:00:00 >>> ts.dt # doctest: +ELLIPSIS >>> ts.dt.days - + Size: 160B array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 160B 1 days 00:00:00 ... 5 days 18:00:00 >>> ts.dt.microseconds - + Size: 160B array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 160B 1 days 00:00:00 ... 5 days 18:00:00 >>> ts.dt.seconds - + Size: 160B array([ 0, 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, 21600, 43200, 64800]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 160B 1 days 00:00:00 ... 5 days 18:00:00 + >>> ts.dt.total_seconds() + Size: 160B + array([ 86400., 108000., 129600., 151200., 172800., 194400., 216000., + 237600., 259200., 280800., 302400., 324000., 345600., 367200., + 388800., 410400., 432000., 453600., 475200., 496800.]) + Coordinates: + * time (time) timedelta64[ns] 160B 1 days 00:00:00 ... 5 days 18:00:00 """ @property @@ -565,17 +601,26 @@ def nanoseconds(self) -> T_DataArray: """Number of nanoseconds (>= 0 and less than 1 microsecond) for each element""" return self._date_field("nanoseconds", np.int64) + # Not defined as a property in order to match the Pandas API + def total_seconds(self) -> T_DataArray: + """Total duration of each element expressed in seconds.""" + return self._date_field("total_seconds", np.float64) + class CombinedDatetimelikeAccessor( DatetimeAccessor[T_DataArray], TimedeltaAccessor[T_DataArray] ): def __new__(cls, obj: T_DataArray) -> CombinedDatetimelikeAccessor: - # CombinedDatetimelikeAccessor isn't really instatiated. Instead + # CombinedDatetimelikeAccessor isn't really instantiated. Instead # we need to choose which parent (datetime or timedelta) is # appropriate. Since we're checking the dtypes anyway, we'll just # do all the validation here. if not _contains_datetime_like_objects(obj.variable): - raise TypeError( + # We use an AttributeError here so that `obj.dt` raises an error that + # `getattr` expects; https://github.com/pydata/xarray/issues/8718. It's a + # bit unusual in a `__new__`, but that's the only case where we use this + # class. + raise AttributeError( "'.dt' accessor only available for " "DataArray with datetime64 timedelta64 dtype or " "for arrays containing cftime datetime " diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 16e22ec1c66..a48fbc91faf 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -51,6 +51,7 @@ import numpy as np +from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.types import T_DataArray @@ -147,7 +148,7 @@ class StringAccessor(Generic[T_DataArray]): >>> da = xr.DataArray(["some", "text", "in", "an", "array"]) >>> da.str.len() - + Size: 40B array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 @@ -158,7 +159,7 @@ class StringAccessor(Generic[T_DataArray]): >>> da1 = xr.DataArray(["first", "second", "third"], dims=["X"]) >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) >>> da1.str + da2 - + Size: 252B array([['first1', 'first2', 'first3'], ['second1', 'second2', 'second3'], ['third1', 'third2', 'third3']], dtype='>> da1 = xr.DataArray(["a", "b", "c", "d"], dims=["X"]) >>> reps = xr.DataArray([3, 4], dims=["Y"]) >>> da1.str * reps - + Size: 128B array([['aaa', 'aaaa'], ['bbb', 'bbbb'], ['ccc', 'cccc'], @@ -178,7 +179,7 @@ class StringAccessor(Generic[T_DataArray]): >>> da2 = xr.DataArray([1, 2], dims=["Y"]) >>> da3 = xr.DataArray([0.1, 0.2], dims=["Z"]) >>> da1.str % (da2, da3) - + Size: 240B array([[['1_0.1', '1_0.2'], ['2_0.1', '2_0.2']], @@ -196,8 +197,8 @@ class StringAccessor(Generic[T_DataArray]): >>> da1 = xr.DataArray(["%(a)s"], dims=["X"]) >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) >>> da1 % {"a": da2} - - array(['\narray([1, 2, 3])\nDimensions without coordinates: Y'], + Size: 8B + array([' Size: 24B\narray([1, 2, 3])\nDimensions without coordinates: Y'], dtype=object) Dimensions without coordinates: X """ @@ -470,7 +471,7 @@ def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray: ... ) >>> values_2 = np.array(3.4) >>> values_3 = "" - >>> values_4 = np.array("test", dtype=np.unicode_) + >>> values_4 = np.array("test", dtype=np.str_) Determine the separator to use @@ -482,7 +483,7 @@ def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray: Concatenate the arrays using the separator >>> myarray.str.cat(values_1, values_2, values_3, values_4, sep=seps) - + Size: 1kB array([[['11111 a 3.4 test', '11111, a, 3.4, , test'], ['11111 bb 3.4 test', '11111, bb, 3.4, , test'], ['11111 cccc 3.4 test', '11111, cccc, 3.4, , test']], @@ -555,7 +556,7 @@ def join( Join the strings along a given dimension >>> values.str.join(dim="Y", sep=seps) - + Size: 192B array([['a-bab-abc', 'a_bab_abc'], ['abcd--abcdef', 'abcd__abcdef']], dtype='>> values.str.format(noun0, noun1, adj0=adj0, adj1=adj1) - + Size: 1kB array([[['spam is unexpected', 'spam is unexpected'], ['egg is unexpected', 'egg is unexpected']], @@ -672,6 +673,23 @@ def capitalize(self) -> T_DataArray: Returns ------- capitalized : same type as values + + Examples + -------- + >>> da = xr.DataArray( + ... ["temperature", "PRESSURE", "PreCipiTation", "daily rainfall"], dims="x" + ... ) + >>> da + Size: 224B + array(['temperature', 'PRESSURE', 'PreCipiTation', 'daily rainfall'], + dtype='>> capitalized = da.str.capitalize() + >>> capitalized + Size: 224B + array(['Temperature', 'Pressure', 'Precipitation', 'Daily rainfall'], + dtype=' T_DataArray: Returns ------- - lowerd : same type as values + lowered : same type as values + + Examples + -------- + >>> da = xr.DataArray(["Temperature", "PRESSURE"], dims="x") + >>> da + Size: 88B + array(['Temperature', 'PRESSURE'], dtype='>> lowered = da.str.lower() + >>> lowered + Size: 88B + array(['temperature', 'pressure'], dtype=' T_DataArray: Returns ------- swapcased : same type as values + + Examples + -------- + >>> import xarray as xr + >>> da = xr.DataArray(["temperature", "PRESSURE", "HuMiDiTy"], dims="x") + >>> da + Size: 132B + array(['temperature', 'PRESSURE', 'HuMiDiTy'], dtype='>> swapcased = da.str.swapcase() + >>> swapcased + Size: 132B + array(['TEMPERATURE', 'pressure', 'hUmIdItY'], dtype=' T_DataArray: Returns ------- titled : same type as values + + Examples + -------- + >>> da = xr.DataArray(["temperature", "PRESSURE", "HuMiDiTy"], dims="x") + >>> da + Size: 132B + array(['temperature', 'PRESSURE', 'HuMiDiTy'], dtype='>> titled = da.str.title() + >>> titled + Size: 132B + array(['Temperature', 'Pressure', 'Humidity'], dtype=' T_DataArray: Returns ------- uppered : same type as values + + Examples + -------- + >>> da = xr.DataArray(["temperature", "HuMiDiTy"], dims="x") + >>> da + Size: 88B + array(['temperature', 'HuMiDiTy'], dtype='>> uppered = da.str.upper() + >>> uppered + Size: 88B + array(['TEMPERATURE', 'HUMIDITY'], dtype=' T_DataArray: Casefolding is similar to converting to lowercase, but removes all case distinctions. This is important in some languages that have more complicated - cases and case conversions. + cases and case conversions. For example, + the 'ß' character in German is case-folded to 'ss', whereas it is lowercased + to 'ß'. Returns ------- casefolded : same type as values + + Examples + -------- + >>> da = xr.DataArray(["TEMPERATURE", "HuMiDiTy"], dims="x") + >>> da + Size: 88B + array(['TEMPERATURE', 'HuMiDiTy'], dtype='>> casefolded = da.str.casefold() + >>> casefolded + Size: 88B + array(['temperature', 'humidity'], dtype='>> da = xr.DataArray(["ß", "İ"], dims="x") + >>> da + Size: 8B + array(['ß', 'İ'], dtype='>> casefolded = da.str.casefold() + >>> casefolded + Size: 16B + array(['ss', 'i̇'], dtype=' T_DataArray: ------- isalnum : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["H2O", "NaCl-"], dims="x") + >>> da + Size: 40B + array(['H2O', 'NaCl-'], dtype='>> isalnum = da.str.isalnum() + >>> isalnum + Size: 2B + array([ True, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isalnum(), dtype=bool) @@ -771,6 +881,19 @@ def isalpha(self) -> T_DataArray: ------- isalpha : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["Mn", "H2O", "NaCl-"], dims="x") + >>> da + Size: 60B + array(['Mn', 'H2O', 'NaCl-'], dtype='>> isalpha = da.str.isalpha() + >>> isalpha + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isalpha(), dtype=bool) @@ -782,6 +905,19 @@ def isdecimal(self) -> T_DataArray: ------- isdecimal : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["2.3", "123", "0"], dims="x") + >>> da + Size: 36B + array(['2.3', '123', '0'], dtype='>> isdecimal = da.str.isdecimal() + >>> isdecimal + Size: 3B + array([False, True, True]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isdecimal(), dtype=bool) @@ -793,6 +929,19 @@ def isdigit(self) -> T_DataArray: ------- isdigit : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["123", "1.2", "0", "CO2", "NaCl"], dims="x") + >>> da + Size: 80B + array(['123', '1.2', '0', 'CO2', 'NaCl'], dtype='>> isdigit = da.str.isdigit() + >>> isdigit + Size: 5B + array([ True, False, True, False, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isdigit(), dtype=bool) @@ -803,7 +952,21 @@ def islower(self) -> T_DataArray: Returns ------- islower : array of bool - Array of boolean values with the same shape as the original array. + Array of boolean values with the same shape as the original array indicating whether all characters of each + element of the string array are lowercase (True) or not (False). + + Examples + -------- + >>> da = xr.DataArray(["temperature", "HUMIDITY", "pREciPiTaTioN"], dims="x") + >>> da + Size: 156B + array(['temperature', 'HUMIDITY', 'pREciPiTaTioN'], dtype='>> islower = da.str.islower() + >>> islower + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.islower(), dtype=bool) @@ -815,6 +978,19 @@ def isnumeric(self) -> T_DataArray: ------- isnumeric : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["123", "2.3", "H2O", "NaCl-", "Mn"], dims="x") + >>> da + Size: 100B + array(['123', '2.3', 'H2O', 'NaCl-', 'Mn'], dtype='>> isnumeric = da.str.isnumeric() + >>> isnumeric + Size: 5B + array([ True, False, False, False, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isnumeric(), dtype=bool) @@ -826,6 +1002,19 @@ def isspace(self) -> T_DataArray: ------- isspace : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["", " ", "\\t", "\\n"], dims="x") + >>> da + Size: 16B + array(['', ' ', '\\t', '\\n'], dtype='>> isspace = da.str.isspace() + >>> isspace + Size: 4B + array([False, True, True, True]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isspace(), dtype=bool) @@ -837,6 +1026,27 @@ def istitle(self) -> T_DataArray: ------- istitle : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray( + ... [ + ... "The Evolution Of Species", + ... "The Theory of relativity", + ... "the quantum mechanics of atoms", + ... ], + ... dims="title", + ... ) + >>> da + Size: 360B + array(['The Evolution Of Species', 'The Theory of relativity', + 'the quantum mechanics of atoms'], dtype='>> istitle = da.str.istitle() + >>> istitle + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: title """ return self._apply(func=lambda x: x.istitle(), dtype=bool) @@ -848,6 +1058,19 @@ def isupper(self) -> T_DataArray: ------- isupper : array of bool Array of boolean values with the same shape as the original array. + + Examples + -------- + >>> da = xr.DataArray(["TEMPERATURE", "humidity", "PreCIpiTAtioN"], dims="x") + >>> da + Size: 156B + array(['TEMPERATURE', 'humidity', 'PreCIpiTAtioN'], dtype='>> isupper = da.str.isupper() + >>> isupper + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: x """ return self._apply(func=lambda x: x.isupper(), dtype=bool) @@ -883,6 +1106,46 @@ def count( Returns ------- counts : array of int + + Examples + -------- + >>> da = xr.DataArray(["jjklmn", "opjjqrs", "t-JJ99vwx"], dims="x") + >>> da + Size: 108B + array(['jjklmn', 'opjjqrs', 't-JJ99vwx'], dtype='>> da.str.count("jj") + Size: 24B + array([1, 1, 0]) + Dimensions without coordinates: x + + Enable case-insensitive matching by setting case to false: + >>> counts = da.str.count("jj", case=False) + >>> counts + Size: 24B + array([1, 1, 1]) + Dimensions without coordinates: x + + Using regex: + >>> pat = "JJ[0-9]{2}[a-z]{3}" + >>> counts = da.str.count(pat) + >>> counts + Size: 24B + array([0, 0, 1]) + Dimensions without coordinates: x + + Using an array of strings (the pattern will be broadcast against the array): + + >>> pat = xr.DataArray(["jj", "JJ"], dims="y") + >>> counts = da.str.count(pat) + >>> counts + Size: 48B + array([[1, 0], + [1, 0], + [0, 1]]) + Dimensions without coordinates: x, y """ pat = self._re_compile(pat=pat, flags=flags, case=case) @@ -907,6 +1170,19 @@ def startswith(self, pat: str | bytes | Any) -> T_DataArray: startswith : array of bool An array of booleans indicating whether the given pattern matches the start of each string element. + + Examples + -------- + >>> da = xr.DataArray(["$100", "£23", "100"], dims="x") + >>> da + Size: 48B + array(['$100', '£23', '100'], dtype='>> startswith = da.str.startswith("$") + >>> startswith + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: x """ pat = self._stringify(pat) func = lambda x, y: x.startswith(y) @@ -930,6 +1206,19 @@ def endswith(self, pat: str | bytes | Any) -> T_DataArray: endswith : array of bool A Series of booleans indicating whether the given pattern matches the end of each string element. + + Examples + -------- + >>> da = xr.DataArray(["10C", "10c", "100F"], dims="x") + >>> da + Size: 48B + array(['10C', '10c', '100F'], dtype='>> endswith = da.str.endswith("C") + >>> endswith + Size: 3B + array([ True, False, False]) + Dimensions without coordinates: x """ pat = self._stringify(pat) func = lambda x, y: x.endswith(y) @@ -963,6 +1252,66 @@ def pad( ------- filled : same type as values Array with a minimum number of char in each element. + + Examples + -------- + Pad strings in the array with a single string on the left side. + + Define the string in the array. + + >>> da = xr.DataArray(["PAR184", "TKO65", "NBO9139", "NZ39"], dims="x") + >>> da + Size: 112B + array(['PAR184', 'TKO65', 'NBO9139', 'NZ39'], dtype='>> filled = da.str.pad(8, side="left", fillchar="0") + >>> filled + Size: 128B + array(['00PAR184', '000TKO65', '0NBO9139', '0000NZ39'], dtype='>> filled = da.str.pad(8, side="right", fillchar="0") + >>> filled + Size: 128B + array(['PAR18400', 'TKO65000', 'NBO91390', 'NZ390000'], dtype='>> filled = da.str.pad(8, side="both", fillchar="0") + >>> filled + Size: 128B + array(['0PAR1840', '0TKO6500', 'NBO91390', '00NZ3900'], dtype='>> width = xr.DataArray([8, 10], dims="y") + >>> filled = da.str.pad(width, side="left", fillchar="0") + >>> filled + Size: 320B + array([['00PAR184', '0000PAR184'], + ['000TKO65', '00000TKO65'], + ['0NBO9139', '000NBO9139'], + ['0000NZ39', '000000NZ39']], dtype='>> fillchar = xr.DataArray(["0", "-"], dims="y") + >>> filled = da.str.pad(8, side="left", fillchar=fillchar) + >>> filled + Size: 256B + array([['00PAR184', '--PAR184'], + ['000TKO65', '---TKO65'], + ['0NBO9139', '-NBO9139'], + ['0000NZ39', '----NZ39']], dtype='>> value.str.extract(r"(\w+)_Xy_(\d*)", dim="match") - + Size: 288B array([[['a', '0'], ['bab', '110'], ['abc', '01']], @@ -1737,13 +2086,16 @@ def _get_res_multi(val, pat): else: # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_get_res_multi, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxgroups}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_get_res_multi, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxgroups}, + ), + self._obj.dtype.kind, + ) def extractall( self, @@ -1826,7 +2178,7 @@ def extractall( >>> value.str.extractall( ... r"(\w+)_Xy_(\d*)", group_dim="group", match_dim="match" ... ) - + Size: 1kB array([[[['a', '0'], ['', ''], ['', '']], @@ -1910,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype): return res - return self._apply( - # dtype MUST be object or strings can be truncated - # See: https://github.com/numpy/numpy/issues/8352 - func=_get_res, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[group_dim, match_dim]], - output_sizes={group_dim: maxgroups, match_dim: maxcount}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + func=_get_res, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[group_dim, match_dim]], + output_sizes={group_dim: maxgroups, match_dim: maxcount}, + ), + self._obj.dtype.kind, + ) def findall( self, @@ -1987,7 +2342,7 @@ def findall( Extract matches >>> value.str.findall(r"(\w+)_Xy_(\d*)") - + Size: 48B array([[list([('a', '0')]), list([('bab', '110'), ('baab', '1100')]), list([('abc', '01'), ('cbc', '2210')])], [list([('abcd', ''), ('dcd', '33210'), ('dccd', '332210')]), @@ -2031,19 +2386,22 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value] + return self._obj.copy().expand_dims({dim: 0}, axis=-1) arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=arrfunc, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: 3}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=arrfunc, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: 3}, + ), + self._obj.dtype.kind, + ) def partition( self, @@ -2162,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_dosplit, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxsplit}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_dosplit, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxsplit}, + ), + self._obj.dtype.kind, + ) def split( self, @@ -2216,7 +2577,7 @@ def split( Split once and put the results in a new dimension >>> values.str.split(dim="splitted", maxsplit=1) - + Size: 864B array([[['abc', 'def'], ['spam', 'eggs\tswallow'], ['red_blue', '']], @@ -2229,7 +2590,7 @@ def split( Split as many times as needed and put the results in a new dimension >>> values.str.split(dim="splitted") - + Size: 768B array([[['abc', 'def', '', ''], ['spam', 'eggs', 'swallow', ''], ['red_blue', '', '', '']], @@ -2242,7 +2603,7 @@ def split( Split once and put the results in lists >>> values.str.split(dim=None, maxsplit=1) - + Size: 48B array([[list(['abc', 'def']), list(['spam', 'eggs\tswallow']), list(['red_blue'])], [list(['test0', 'test1\ntest2\n\ntest3']), list([]), @@ -2252,7 +2613,7 @@ def split( Split as many times as needed and put the results in a list >>> values.str.split(dim=None) - + Size: 48B array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), list(['red_blue'])], [list(['test0', 'test1', 'test2', 'test3']), list([]), @@ -2262,7 +2623,7 @@ def split( Split only on spaces >>> values.str.split(dim="splitted", sep=" ") - + Size: 2kB array([[['abc', 'def', ''], ['spam\t\teggs\tswallow', '', ''], ['red_blue', '', '']], @@ -2334,7 +2695,7 @@ def rsplit( Split once and put the results in a new dimension >>> values.str.rsplit(dim="splitted", maxsplit=1) - + Size: 816B array([[['abc', 'def'], ['spam\t\teggs', 'swallow'], ['', 'red_blue']], @@ -2347,7 +2708,7 @@ def rsplit( Split as many times as needed and put the results in a new dimension >>> values.str.rsplit(dim="splitted") - + Size: 768B array([[['', '', 'abc', 'def'], ['', 'spam', 'eggs', 'swallow'], ['', '', '', 'red_blue']], @@ -2360,7 +2721,7 @@ def rsplit( Split once and put the results in lists >>> values.str.rsplit(dim=None, maxsplit=1) - + Size: 48B array([[list(['abc', 'def']), list(['spam\t\teggs', 'swallow']), list(['red_blue'])], [list(['test0\ntest1\ntest2', 'test3']), list([]), @@ -2370,7 +2731,7 @@ def rsplit( Split as many times as needed and put the results in a list >>> values.str.rsplit(dim=None) - + Size: 48B array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), list(['red_blue'])], [list(['test0', 'test1', 'test2', 'test3']), list([]), @@ -2380,7 +2741,7 @@ def rsplit( Split only on spaces >>> values.str.rsplit(dim="splitted", sep=" ") - + Size: 2kB array([[['', 'abc', 'def'], ['', '', 'spam\t\teggs\tswallow'], ['', '', 'red_blue']], @@ -2447,7 +2808,7 @@ def get_dummies( Extract dummy values >>> values.str.get_dummies(dim="dummies") - + Size: 30B array([[[ True, False, True, False, True], [False, True, False, False, False], [ True, False, True, True, False]], @@ -2456,7 +2817,7 @@ def get_dummies( [False, False, True, False, True], [ True, False, False, False, False]]]) Coordinates: - * dummies (dummies) tuple[NormalizedIndexes, NormalizedIndexVars]: """Normalize the indexes/indexers used for re-indexing or alignment. @@ -196,7 +200,7 @@ def _normalize_indexes( f"Indexer has dimensions {idx.dims} that are different " f"from that to be indexed along '{k}'" ) - data = as_compatible_data(idx) + data: T_DuckArray = as_compatible_data(idx) pd_idx = safe_cast_to_index(data) pd_idx.name = k if isinstance(pd_idx, pd.MultiIndex): @@ -320,7 +324,7 @@ def assert_no_index_conflict(self) -> None: "- they may be used to reindex data along common dimensions" ) - def _need_reindex(self, dims, cmp_indexes) -> bool: + def _need_reindex(self, dim, cmp_indexes) -> bool: """Whether or not we need to reindex variables for a set of matching indexes. @@ -336,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: return True unindexed_dims_sizes = {} - for dim in dims: - if dim in self.unindexed_dim_sizes: - sizes = self.unindexed_dim_sizes[dim] + for d in dim: + if d in self.unindexed_dim_sizes: + sizes = self.unindexed_dim_sizes[d] if len(sizes) > 1: # reindex if different sizes are found for unindexed dims return True else: - unindexed_dims_sizes[dim] = next(iter(sizes)) + unindexed_dims_sizes[d] = next(iter(sizes)) if unindexed_dims_sizes: indexed_dims_sizes = {} @@ -352,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: for var in index_vars.values(): indexed_dims_sizes.update(var.sizes) - for dim, size in unindexed_dims_sizes.items(): - if indexed_dims_sizes.get(dim, -1) != size: + for d, size in unindexed_dims_sizes.items(): + if indexed_dims_sizes.get(d, -1) != size: # reindex if unindexed dimension size doesn't match return True @@ -510,7 +514,7 @@ def _get_dim_pos_indexers( def _get_indexes_and_vars( self, - obj: DataAlignable, + obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]: new_indexes = {} @@ -533,13 +537,13 @@ def _get_indexes_and_vars( def _reindex_one( self, - obj: DataAlignable, + obj: T_Alignable, matching_indexes: dict[MatchingIndexKey, Index], - ) -> DataAlignable: + ) -> T_Alignable: new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes) dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes) - new_obj = obj._reindex_callback( + return obj._reindex_callback( self, dim_pos_indexers, new_variables, @@ -548,8 +552,6 @@ def _reindex_one( self.exclude_dims, self.exclude_vars, ) - new_obj.encoding = obj.encoding - return new_obj def reindex_all(self) -> None: self.results = tuple( @@ -574,18 +576,113 @@ def align(self) -> None: if self.join == "override": self.override_indexes() + elif self.join == "exact" and not self.copy: + self.results = self.objects else: self.reindex_all() +T_Obj1 = TypeVar("T_Obj1", bound="Alignable") +T_Obj2 = TypeVar("T_Obj2", bound="Alignable") +T_Obj3 = TypeVar("T_Obj3", bound="Alignable") +T_Obj4 = TypeVar("T_Obj4", bound="Alignable") +T_Obj5 = TypeVar("T_Obj5", bound="Alignable") + + +@overload +def align( + obj1: T_Obj1, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1]: ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2]: ... + + +@overload def align( - *objects: DataAlignable, + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, join: JoinOptions = "inner", copy: bool = True, indexes=None, - exclude=frozenset(), + exclude: str | Iterable[Hashable] = frozenset(), fill_value=dtypes.NA, -) -> tuple[DataAlignable, ...]: +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: ... + + +@overload +def align( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: ... + + +@overload +def align( + *objects: T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Alignable, ...]: ... + + +def align( + *objects: T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Alignable, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -622,7 +719,7 @@ def align( indexes : dict-like, optional Any indexes explicitly provided with the `indexes` argument should be used in preference to the aligned indexes. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must be excluded from alignment fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps @@ -655,102 +752,102 @@ def align( ... ) >>> x - + Size: 32B array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 >>> y - + Size: 32B array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y) >>> a - + Size: 16B array([[25, 35]]) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 8B 35.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 16B array([[20, 5]]) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 8B 35.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y, join="outer") >>> a - + Size: 48B array([[25., 35.], [10., 24.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 48B array([[20., 5.], [nan, nan], [ 7., 13.]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y, join="outer", fill_value=-999) >>> a - + Size: 48B array([[ 25, 35], [ 10, 24], [-999, -999]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 48B array([[ 20, 5], [-999, -999], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y, join="left") >>> a - + Size: 32B array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 32B array([[20., 5.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y, join="right") >>> a - + Size: 32B array([[25., 35.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 32B array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 42.0 + * lon (lon) float64 16B 100.0 120.0 >>> a, b = xr.align(x, y, join="exact") Traceback (most recent call last): @@ -759,19 +856,19 @@ def align( >>> a, b = xr.align(x, y, join="override") >>> a - + Size: 32B array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 >>> b - + Size: 32B array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 """ aligner = Aligner( @@ -789,16 +886,17 @@ def align( def deep_align( objects: Iterable[Any], join: JoinOptions = "inner", - copy=True, + copy: bool = True, indexes=None, - exclude=frozenset(), - raise_on_invalid=True, + exclude: str | Iterable[Hashable] = frozenset(), + raise_on_invalid: bool = True, fill_value=dtypes.NA, -): +) -> list[Any]: """Align objects for merging, recursing into dictionary values. This function is not public API. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -806,14 +904,14 @@ def deep_align( indexes = {} def is_alignable(obj): - return isinstance(obj, (DataArray, Dataset)) - - positions = [] - keys = [] - out = [] - targets = [] - no_key = object() - not_replaced = object() + return isinstance(obj, (Coordinates, DataArray, Dataset)) + + positions: list[int] = [] + keys: list[type[object] | Hashable] = [] + out: list[Any] = [] + targets: list[Alignable] = [] + no_key: Final = object() + not_replaced: Final = object() for position, variables in enumerate(objects): if is_alignable(variables): positions.append(position) @@ -840,7 +938,7 @@ def is_alignable(obj): elif raise_on_invalid: raise ValueError( "object to align is neither an xarray.Dataset, " - "an xarray.DataArray nor a dictionary: {!r}".format(variables) + f"an xarray.DataArray nor a dictionary: {variables!r}" ) else: out.append(variables) @@ -858,13 +956,13 @@ def is_alignable(obj): if key is no_key: out[position] = aligned_obj else: - out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this? + out[position][key] = aligned_obj return out def reindex( - obj: DataAlignable, + obj: T_Alignable, indexers: Mapping[Any, Any], method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -872,7 +970,7 @@ def reindex( fill_value: Any = dtypes.NA, sparse: bool = False, exclude_vars: Iterable[Hashable] = frozenset(), -) -> DataAlignable: +) -> T_Alignable: """Re-index either a Dataset or a DataArray. Not public API. @@ -903,13 +1001,13 @@ def reindex( def reindex_like( - obj: DataAlignable, + obj: T_Alignable, other: Dataset | DataArray, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = dtypes.NA, -) -> DataAlignable: +) -> T_Alignable: """Re-index either a Dataset or a DataArray like another Dataset/DataArray. Not public API. @@ -951,8 +1049,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude): def _broadcast_helper( - arg: T_DataWithCoords, exclude, dims_map, common_coords -) -> T_DataWithCoords: + arg: T_Alignable, exclude, dims_map, common_coords +) -> T_Alignable: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -982,16 +1080,70 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset: # remove casts once https://github.com/python/mypy/issues/12800 is resolved if isinstance(arg, DataArray): - return cast("T_DataWithCoords", _broadcast_array(arg)) + return cast(T_Alignable, _broadcast_array(arg)) elif isinstance(arg, Dataset): - return cast("T_DataWithCoords", _broadcast_dataset(arg)) + return cast(T_Alignable, _broadcast_dataset(arg)) else: raise ValueError("all input must be Dataset or DataArray objects") -# TODO: this typing is too restrictive since it cannot deal with mixed -# DataArray and Dataset types...? Is this a problem? -def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]: +@overload +def broadcast( + obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1]: ... + + +@overload +def broadcast( + obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1, T_Obj2]: ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: ... + + +@overload +def broadcast( + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: ... + + +@overload +def broadcast( + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: ... + + +def broadcast( + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: """Explicitly broadcast any number of DataArray or Dataset objects against one another. @@ -1005,7 +1157,7 @@ def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ---------- *args : DataArray or Dataset Arrays to broadcast against each other. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must not be broadcasted Returns @@ -1021,22 +1173,22 @@ def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, >>> a = xr.DataArray([1, 2, 3], dims="x") >>> b = xr.DataArray([5, 6], dims="y") >>> a - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: x >>> b - + Size: 16B array([5, 6]) Dimensions without coordinates: y >>> a2, b2 = xr.broadcast(a, b) >>> a2 - + Size: 48B array([[1, 1], [2, 2], [3, 3]]) Dimensions without coordinates: x, y >>> b2 - + Size: 48B array([[5, 6], [5, 6], [5, 6]]) @@ -1047,12 +1199,12 @@ def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, >>> ds = xr.Dataset({"a": a, "b": b}) >>> (ds2,) = xr.broadcast(ds) # use tuple unpacking to extract one dataset >>> ds2 - + Size: 96B Dimensions: (x: 3, y: 2) Dimensions without coordinates: x, y Data variables: - a (x, y) int64 1 1 2 2 3 3 - b (x, y) int64 5 6 5 6 5 6 + a (x, y) int64 48B 1 1 2 2 3 3 + b (x, y) int64 48B 5 6 5 6 5 6 """ if exclude is None: diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 5b2cf38ee2e..452c7115b75 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -1,4 +1,5 @@ """Base classes implementing arithmetic for xarray objects.""" + from __future__ import annotations import numbers @@ -14,13 +15,9 @@ VariableOpsMixin, ) from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce -from xarray.core.ops import ( - IncludeCumMethods, - IncludeNumpySameMethods, - IncludeReduceMethods, -) +from xarray.core.ops import IncludeNumpySameMethods, IncludeReduceMethods from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import is_duck_array +from xarray.namedarray.utils import is_duck_array class SupportsArithmetic: @@ -56,10 +53,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if ufunc.signature is not None: raise NotImplementedError( - "{} not supported: xarray objects do not directly implement " + f"{ufunc} not supported: xarray objects do not directly implement " "generalized ufuncs. Instead, use xarray.apply_ufunc or " "explicitly convert to xarray objects to NumPy arrays " - "(e.g., with `.values`).".format(ufunc) + "(e.g., with `.values`)." ) if method != "__call__": @@ -99,8 +96,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): class VariableArithmetic( ImplementsArrayReduce, - IncludeReduceMethods, - IncludeCumMethods, IncludeNumpySameMethods, SupportsArithmetic, VariableOpsMixin, diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 946f71e5d28..5cb0a3417fa 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import warnings from collections import Counter from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Literal, Union @@ -110,9 +109,9 @@ def _infer_concat_order_from_coords(datasets): ascending = False else: raise ValueError( - "Coordinate variable {} is neither " + f"Coordinate variable {dim} is neither " "monotonically increasing nor " - "monotonically decreasing on all datasets".format(dim) + "monotonically decreasing on all datasets" ) # Assume that any two datasets whose coord along dim starts @@ -180,7 +179,7 @@ def _check_shape_tile_ids(combined_tile_ids): raise ValueError( "The supplied objects do not form a hypercube " "because sub-lists do not have consistent " - "lengths along dimension" + str(dim) + f"lengths along dimension {dim}" ) @@ -222,10 +221,8 @@ def _combine_nd( n_dims = len(example_tile_id) if len(concat_dims) != n_dims: raise ValueError( - "concat_dims has length {} but the datasets " - "passed are nested in a {}-dimensional structure".format( - len(concat_dims), n_dims - ) + f"concat_dims has length {len(concat_dims)} but the datasets " + f"passed are nested in a {n_dims}-dimensional structure" ) # Each iteration of this loop reduces the length of the tile_ids tuples @@ -369,14 +366,13 @@ def _nested_combine( return combined -# Define type for arbitrarily-nested list of lists recursively -# Currently mypy cannot handle this but other linters can (https://stackoverflow.com/a/53845083/3154101) -DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] # type: ignore[misc] +# Define type for arbitrarily-nested list of lists recursively: +DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] def combine_nested( datasets: DATASET_HYPERCUBE, - concat_dim: (str | DataArray | None | Sequence[str | DataArray | pd.Index | None]), + concat_dim: str | DataArray | None | Sequence[str | DataArray | pd.Index | None], compat: str = "no_conflicts", data_vars: str = "all", coords: str = "different", @@ -488,12 +484,12 @@ def combine_nested( ... } ... ) >>> x1y1 - + Size: 64B Dimensions: (x: 2, y: 2) Dimensions without coordinates: x, y Data variables: - temperature (x, y) float64 1.764 0.4002 0.9787 2.241 - precipitation (x, y) float64 1.868 -0.9773 0.9501 -0.1514 + temperature (x, y) float64 32B 1.764 0.4002 0.9787 2.241 + precipitation (x, y) float64 32B 1.868 -0.9773 0.9501 -0.1514 >>> x1y2 = xr.Dataset( ... { ... "temperature": (("x", "y"), np.random.randn(2, 2)), @@ -517,12 +513,12 @@ def combine_nested( >>> ds_grid = [[x1y1, x1y2], [x2y1, x2y2]] >>> combined = xr.combine_nested(ds_grid, concat_dim=["x", "y"]) >>> combined - + Size: 256B Dimensions: (x: 4, y: 4) Dimensions without coordinates: x, y Data variables: - temperature (x, y) float64 1.764 0.4002 -0.1032 ... 0.04576 -0.1872 - precipitation (x, y) float64 1.868 -0.9773 0.761 ... -0.7422 0.1549 0.3782 + temperature (x, y) float64 128B 1.764 0.4002 -0.1032 ... 0.04576 -0.1872 + precipitation (x, y) float64 128B 1.868 -0.9773 0.761 ... 0.1549 0.3782 ``combine_nested`` can also be used to explicitly merge datasets with different variables. For example if we have 4 datasets, which are divided @@ -532,19 +528,19 @@ def combine_nested( >>> t1temp = xr.Dataset({"temperature": ("t", np.random.randn(5))}) >>> t1temp - + Size: 40B Dimensions: (t: 5) Dimensions without coordinates: t Data variables: - temperature (t) float64 -0.8878 -1.981 -0.3479 0.1563 1.23 + temperature (t) float64 40B -0.8878 -1.981 -0.3479 0.1563 1.23 >>> t1precip = xr.Dataset({"precipitation": ("t", np.random.randn(5))}) >>> t1precip - + Size: 40B Dimensions: (t: 5) Dimensions without coordinates: t Data variables: - precipitation (t) float64 1.202 -0.3873 -0.3023 -1.049 -1.42 + precipitation (t) float64 40B 1.202 -0.3873 -0.3023 -1.049 -1.42 >>> t2temp = xr.Dataset({"temperature": ("t", np.random.randn(5))}) >>> t2precip = xr.Dataset({"precipitation": ("t", np.random.randn(5))}) @@ -553,12 +549,12 @@ def combine_nested( >>> ds_grid = [[t1temp, t1precip], [t2temp, t2precip]] >>> combined = xr.combine_nested(ds_grid, concat_dim=["t", None]) >>> combined - + Size: 160B Dimensions: (t: 10) Dimensions without coordinates: t Data variables: - temperature (t) float64 -0.8878 -1.981 -0.3479 ... -0.5097 -0.4381 -1.253 - precipitation (t) float64 1.202 -0.3873 -0.3023 ... -0.2127 -0.8955 0.3869 + temperature (t) float64 80B -0.8878 -1.981 -0.3479 ... -0.4381 -1.253 + precipitation (t) float64 80B 1.202 -0.3873 -0.3023 ... -0.8955 0.3869 See also -------- @@ -648,13 +644,12 @@ def _combine_single_variable_hypercube( if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): raise ValueError( "Resulting object does not have monotonic" - " global indexes along dimension {}".format(dim) + f" global indexes along dimension {dim}" ) return concatenated -# TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( data_objects: Iterable[Dataset | DataArray] = [], compat: CompatOptions = "no_conflicts", @@ -663,7 +658,6 @@ def combine_by_coords( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "no_conflicts", - datasets: Iterable[Dataset] | None = None, ) -> Dataset | DataArray: """ @@ -745,7 +739,7 @@ def combine_by_coords( dimension must have the same size in all objects. combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ - "override"} or callable, default: "drop" + "override"} or callable, default: "no_conflicts" A callable or a string indicating how to combine attrs of the objects being merged: @@ -761,8 +755,6 @@ def combine_by_coords( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. - datasets : Iterable of Datasets - Returns ------- combined : xarray.Dataset or xarray.DataArray @@ -805,74 +797,74 @@ def combine_by_coords( ... ) >>> x1 - + Size: 136B Dimensions: (y: 2, x: 3) Coordinates: - * y (y) int64 0 1 - * x (x) int64 10 20 30 + * y (y) int64 16B 0 1 + * x (x) int64 24B 10 20 30 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 10.9 8.473 12.92 - precipitation (y, x) float64 0.4376 0.8918 0.9637 0.3834 0.7917 0.5289 + temperature (y, x) float64 48B 10.98 14.3 12.06 10.9 8.473 12.92 + precipitation (y, x) float64 48B 0.4376 0.8918 0.9637 0.3834 0.7917 0.5289 >>> x2 - + Size: 136B Dimensions: (y: 2, x: 3) Coordinates: - * y (y) int64 2 3 - * x (x) int64 10 20 30 + * y (y) int64 16B 2 3 + * x (x) int64 24B 10 20 30 Data variables: - temperature (y, x) float64 11.36 18.51 1.421 1.743 0.4044 16.65 - precipitation (y, x) float64 0.7782 0.87 0.9786 0.7992 0.4615 0.7805 + temperature (y, x) float64 48B 11.36 18.51 1.421 1.743 0.4044 16.65 + precipitation (y, x) float64 48B 0.7782 0.87 0.9786 0.7992 0.4615 0.7805 >>> x3 - + Size: 136B Dimensions: (y: 2, x: 3) Coordinates: - * y (y) int64 2 3 - * x (x) int64 40 50 60 + * y (y) int64 16B 2 3 + * x (x) int64 24B 40 50 60 Data variables: - temperature (y, x) float64 2.365 12.8 2.867 18.89 10.44 8.293 - precipitation (y, x) float64 0.2646 0.7742 0.4562 0.5684 0.01879 0.6176 + temperature (y, x) float64 48B 2.365 12.8 2.867 18.89 10.44 8.293 + precipitation (y, x) float64 48B 0.2646 0.7742 0.4562 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x2, x1]) - + Size: 248B Dimensions: (y: 4, x: 3) Coordinates: - * y (y) int64 0 1 2 3 - * x (x) int64 10 20 30 + * y (y) int64 32B 0 1 2 3 + * x (x) int64 24B 10 20 30 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 10.9 ... 1.743 0.4044 16.65 - precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.7992 0.4615 0.7805 + temperature (y, x) float64 96B 10.98 14.3 12.06 ... 1.743 0.4044 16.65 + precipitation (y, x) float64 96B 0.4376 0.8918 0.9637 ... 0.4615 0.7805 >>> xr.combine_by_coords([x3, x1]) - + Size: 464B Dimensions: (y: 4, x: 6) Coordinates: - * y (y) int64 0 1 2 3 - * x (x) int64 10 20 30 40 50 60 + * y (y) int64 32B 0 1 2 3 + * x (x) int64 48B 10 20 30 40 50 60 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 nan ... nan 18.89 10.44 8.293 - precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 + temperature (y, x) float64 192B 10.98 14.3 12.06 ... 18.89 10.44 8.293 + precipitation (y, x) float64 192B 0.4376 0.8918 0.9637 ... 0.01879 0.6176 >>> xr.combine_by_coords([x3, x1], join="override") - + Size: 256B Dimensions: (y: 2, x: 6) Coordinates: - * y (y) int64 0 1 - * x (x) int64 10 20 30 40 50 60 + * y (y) int64 16B 0 1 + * x (x) int64 48B 10 20 30 40 50 60 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 2.365 ... 18.89 10.44 8.293 - precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 + temperature (y, x) float64 96B 10.98 14.3 12.06 ... 18.89 10.44 8.293 + precipitation (y, x) float64 96B 0.4376 0.8918 0.9637 ... 0.01879 0.6176 >>> xr.combine_by_coords([x1, x2, x3]) - + Size: 464B Dimensions: (y: 4, x: 6) Coordinates: - * y (y) int64 0 1 2 3 - * x (x) int64 10 20 30 40 50 60 + * y (y) int64 32B 0 1 2 3 + * x (x) int64 48B 10 20 30 40 50 60 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 nan ... 18.89 10.44 8.293 - precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 + temperature (y, x) float64 192B 10.98 14.3 12.06 ... 18.89 10.44 8.293 + precipitation (y, x) float64 192B 0.4376 0.8918 0.9637 ... 0.01879 0.6176 You can also combine DataArray objects, but the behaviour will differ depending on whether or not the DataArrays are named. If all DataArrays are named then they will @@ -883,50 +875,42 @@ def combine_by_coords( ... name="a", data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x" ... ) >>> named_da1 - + Size: 16B array([1., 2.]) Coordinates: - * x (x) int64 0 1 + * x (x) int64 16B 0 1 >>> named_da2 = xr.DataArray( ... name="a", data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x" ... ) >>> named_da2 - + Size: 16B array([3., 4.]) Coordinates: - * x (x) int64 2 3 + * x (x) int64 16B 2 3 >>> xr.combine_by_coords([named_da1, named_da2]) - + Size: 64B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 Data variables: - a (x) float64 1.0 2.0 3.0 4.0 + a (x) float64 32B 1.0 2.0 3.0 4.0 If all the DataArrays are unnamed, a single DataArray will be returned, e.g. >>> unnamed_da1 = xr.DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") >>> unnamed_da2 = xr.DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") >>> xr.combine_by_coords([unnamed_da1, unnamed_da2]) - + Size: 32B array([1., 2., 3., 4.]) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 Finally, if you attempt to combine a mix of unnamed DataArrays with either named DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation). """ - # TODO remove after version 0.21, see PR4696 - if datasets is not None: - warnings.warn( - "The datasets argument has been renamed to `data_objects`." - " From 0.21 on passing a value for datasets will raise an error." - ) - data_objects = datasets - if not data_objects: return Dataset() @@ -971,10 +955,9 @@ def combine_by_coords( # Perform the multidimensional combine on each group of data variables # before merging back together - concatenated_grouped_by_data_vars = [] - for vars, datasets_with_same_vars in grouped_by_vars: - concatenated = _combine_single_variable_hypercube( - list(datasets_with_same_vars), + concatenated_grouped_by_data_vars = tuple( + _combine_single_variable_hypercube( + tuple(datasets_with_same_vars), fill_value=fill_value, data_vars=data_vars, coords=coords, @@ -982,7 +965,8 @@ def combine_by_coords( join=join, combine_attrs=combine_attrs, ) - concatenated_grouped_by_data_vars.append(concatenated) + for vars, datasets_with_same_vars in grouped_by_vars + ) return merge( concatenated_grouped_by_data_vars, diff --git a/xarray/core/common.py b/xarray/core/common.py index af935ae15d2..7b9a049c662 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -13,14 +13,14 @@ from xarray.core import dtypes, duck_array_ops, formatting, formatting_html, ops from xarray.core.indexing import BasicIndexer, ExplicitlyIndexed from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pdcompat import _convert_base_to_offset -from xarray.core.pycompat import is_duck_dask_array from xarray.core.utils import ( Frozen, either_dict_or_kwargs, - emit_user_level_warning, is_scalar, ) +from xarray.namedarray.core import _raise_if_any_duplicate_dimensions +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import is_chunked_array try: import cftime @@ -45,7 +45,9 @@ DatetimeLike, DTypeLikeSave, ScalarOrArray, + Self, SideOptions, + T_Chunks, T_DataWithCoords, T_Variable, ) @@ -159,7 +161,7 @@ def __int__(self: Any) -> int: def __complex__(self: Any) -> complex: return complex(self.values) - def __array__(self: Any, dtype: DTypeLike = None) -> np.ndarray: + def __array__(self: Any, dtype: DTypeLike | None = None) -> np.ndarray: return np.asarray(self.values, dtype=dtype) def __repr__(self) -> str: @@ -196,6 +198,12 @@ def __iter__(self: Any) -> Iterator[Any]: raise TypeError("iteration over a 0-d array") return self._iter() + @overload + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... + + @overload + def get_axis_num(self, dim: Hashable) -> int: ... + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: """Return axis number(s) corresponding to dimension(s) in this array. @@ -209,19 +217,20 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, . int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if isinstance(dim, Iterable) and not isinstance(dim, str): + if not isinstance(dim, str) and isinstance(dim, Iterable): return tuple(self._get_axis_num(d) for d in dim) else: return self._get_axis_num(dim) def _get_axis_num(self: Any, dim: Hashable) -> int: + _raise_if_any_duplicate_dimensions(self.dims) try: return self.dims.index(dim) except ValueError: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @property - def sizes(self: Any) -> Frozen[Hashable, int]: + def sizes(self: Any) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -305,9 +314,7 @@ def __setattr__(self, name: str, value: Any) -> None: except AttributeError as e: # Don't accidentally shadow custom AttributeErrors, e.g. # DataArray.dims.setter - if str(e) != "{!r} object has no attribute {!r}".format( - type(self).__name__, name - ): + if str(e) != f"{type(self).__name__!r} object has no attribute {name!r}": raise raise AttributeError( f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" @@ -382,11 +389,11 @@ class DataWithCoords(AttrAccessMixin): __slots__ = ("_close",) def squeeze( - self: T_DataWithCoords, + self, dim: Hashable | Iterable[Hashable] | None = None, drop: bool = False, axis: int | Iterable[int] | None = None, - ) -> T_DataWithCoords: + ) -> Self: """Return a new object with squeezed data. Parameters @@ -415,12 +422,12 @@ def squeeze( return self.isel(drop=drop, **{d: 0 for d in dims}) def clip( - self: T_DataWithCoords, + self, min: ScalarOrArray | None = None, max: ScalarOrArray | None = None, *, keep_attrs: bool | None = None, - ) -> T_DataWithCoords: + ) -> Self: """ Return an array whose values are limited to ``[min, max]``. At least one of max or min must be given. @@ -473,10 +480,10 @@ def _calc_assign_results( return {k: v(self) if callable(v) else v for k, v in kwargs.items()} def assign_coords( - self: T_DataWithCoords, - coords: Mapping[Any, Any] | None = None, + self, + coords: Mapping | None = None, **coords_kwargs: Any, - ) -> T_DataWithCoords: + ) -> Self: """Assign new coordinates to this object. Returns a new object with all the original data in addition to the new @@ -484,15 +491,21 @@ def assign_coords( Parameters ---------- - coords : dict-like or None, optional - A dict where the keys are the names of the coordinates - with the new values to assign. If the values are callable, they are - computed on this object and assigned to new coordinate variables. - If the values are not callable, (e.g. a ``DataArray``, scalar, or - array), they are simply assigned. A new coordinate can also be - defined and attached to an existing dimension using a tuple with - the first element the dimension name and the second element the - values for this new coordinate. + coords : mapping of dim to coord, optional + A mapping whose keys are the names of the coordinates and values are the + coordinates to assign. The mapping will generally be a dict or + :class:`Coordinates`. + + * If a value is a standard data value — for example, a ``DataArray``, + scalar, or array — the data is simply assigned as a coordinate. + + * If a value is callable, it is called with this object as the only + parameter, and the return value is used as new coordinate variables. + + * A coordinate can also be defined and attached to an existing dimension + using a tuple with the first element the dimension name and the second + element the values for this new coordinate. + **coords_kwargs : optional The keyword arguments form of ``coords``. One of ``coords`` or ``coords_kwargs`` must be provided. @@ -513,33 +526,33 @@ def assign_coords( ... dims="lon", ... ) >>> da - + Size: 32B array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: - * lon (lon) int64 358 359 0 1 + * lon (lon) int64 32B 358 359 0 1 >>> da.assign_coords(lon=(((da.lon + 180) % 360) - 180)) - + Size: 32B array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: - * lon (lon) int64 -2 -1 0 1 + * lon (lon) int64 32B -2 -1 0 1 The function also accepts dictionary arguments: >>> da.assign_coords({"lon": (((da.lon + 180) % 360) - 180)}) - + Size: 32B array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: - * lon (lon) int64 -2 -1 0 1 + * lon (lon) int64 32B -2 -1 0 1 New coordinate can also be attached to an existing dimension: >>> lon_2 = np.array([300, 289, 0, 1]) >>> da.assign_coords(lon_2=("lon", lon_2)) - + Size: 32B array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: - * lon (lon) int64 358 359 0 1 - lon_2 (lon) int64 300 289 0 1 + * lon (lon) int64 32B 358 359 0 1 + lon_2 (lon) int64 32B 300 289 0 1 Note that the same result can also be obtained with a dict e.g. @@ -565,57 +578,55 @@ def assign_coords( ... attrs=dict(description="Weather-related data"), ... ) >>> ds - + Size: 360B Dimensions: (x: 2, y: 2, time: 4) Coordinates: - lon (x, y) float64 260.2 260.7 260.2 260.8 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B 260.2 260.7 260.2 260.8 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 32B 2014-09-06 ... 2014-09-09 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Data variables: - temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 - precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 2.0 + temperature (x, y, time) float64 128B 20.0 20.8 21.6 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 128B 2.0 0.0 0.0 0.0 ... 0.0 0.0 2.0 Attributes: description: Weather-related data >>> ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)) - + Size: 360B Dimensions: (x: 2, y: 2, time: 4) Coordinates: - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 32B 2014-09-06 ... 2014-09-09 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Data variables: - temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 - precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 2.0 + temperature (x, y, time) float64 128B 20.0 20.8 21.6 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 128B 2.0 0.0 0.0 0.0 ... 0.0 0.0 2.0 Attributes: description: Weather-related data - Notes - ----- - Since ``coords_kwargs`` is a dictionary, the order of your arguments - may not be preserved, and so the order of the new variables is not well - defined. Assigning multiple variables within the same ``assign_coords`` - is possible, but you cannot reference other variables created within - the same ``assign_coords`` call. - See Also -------- Dataset.assign Dataset.swap_dims Dataset.set_coords """ + from xarray.core.coordinates import Coordinates + coords_combined = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords") data = self.copy(deep=False) - results: dict[Hashable, Any] = self._calc_assign_results(coords_combined) + + results: Coordinates | dict[Hashable, Any] + if isinstance(coords, Coordinates): + results = coords + else: + results = self._calc_assign_results(coords_combined) + data.coords.update(results) return data - def assign_attrs( - self: T_DataWithCoords, *args: Any, **kwargs: Any - ) -> T_DataWithCoords: + def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: """Assign new attrs to this object. Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``. @@ -627,6 +638,36 @@ def assign_attrs( **kwargs keyword arguments passed into ``attrs.update``. + Examples + -------- + >>> dataset = xr.Dataset({"temperature": [25, 30, 27]}) + >>> dataset + Size: 24B + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 24B 25 30 27 + Data variables: + *empty* + + >>> new_dataset = dataset.assign_attrs( + ... units="Celsius", description="Temperature data" + ... ) + >>> new_dataset + Size: 24B + Dimensions: (temperature: 3) + Coordinates: + * temperature (temperature) int64 24B 25 30 27 + Data variables: + *empty* + Attributes: + units: Celsius + description: Temperature data + + # Attributes of the new dataset + + >>> new_dataset.attrs + {'units': 'Celsius', 'description': 'Temperature data'} + Returns ------- assigned : same type as caller @@ -705,14 +746,14 @@ def pipe( ... coords={"lat": [10, 20], "lon": [150, 160]}, ... ) >>> x - + Size: 96B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 - precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 >>> def adder(data, arg): ... return data + arg @@ -724,38 +765,38 @@ def pipe( ... return (data * mult_arg) - sub_arg ... >>> x.pipe(adder, 2) - + Size: 96B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 12.98 16.3 14.06 12.9 - precipitation (lat, lon) float64 2.424 2.646 2.438 2.892 + temperature_c (lat, lon) float64 32B 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 32B 2.424 2.646 2.438 2.892 >>> x.pipe(adder, arg=2) - + Size: 96B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 12.98 16.3 14.06 12.9 - precipitation (lat, lon) float64 2.424 2.646 2.438 2.892 + temperature_c (lat, lon) float64 32B 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 32B 2.424 2.646 2.438 2.892 >>> ( ... x.pipe(adder, arg=2) ... .pipe(div, arg=2) ... .pipe(sub_mult, sub_arg=2, mult_arg=2) ... ) - + Size: 96B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 - precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 See Also -------- @@ -824,7 +865,6 @@ def _resample( base: int | None, offset: pd.Timedelta | datetime.timedelta | str | None, origin: str | DatetimeLike, - keep_attrs: bool | None, loffset: datetime.timedelta | str | None, restore_coord_dims: bool | None, **indexer_kwargs: str, @@ -906,36 +946,96 @@ def _resample( ... dims="time", ... ) >>> da - + Size: 96B array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 >>> da.resample(time="QS-DEC").mean() - + Size: 32B array([ 1., 4., 7., 10.]) Coordinates: - * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 + * time (time) datetime64[ns] 32B 1999-12-01 2000-03-01 ... 2000-09-01 Upsample monthly time-series data to daily data: >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS - + Size: 3kB array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, + 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, + 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, + 0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419, + 1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452, + 1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484, + 1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516, + 1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548, + 1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581, + 1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552, + 2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931, + 2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 , + 2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 , + 2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069, + 2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448, + 2.96551724, 3. , 3.03225806, 3.06451613, 3.09677419, + 3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452, ... + 7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. , + 8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032, + 8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065, + 8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097, + 8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129, + 8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161, + 8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194, + 9. , 9.03333333, 9.06666667, 9.1 , 9.13333333, + 9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 , + 9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667, + 9.5 , 9.53333333, 9.56666667, 9.6 , 9.63333333, + 9.66666667, 9.7 , 9.73333333, 9.76666667, 9.8 , + 9.83333333, 9.86666667, 9.9 , 9.93333333, 9.96666667, + 10. , 10.03225806, 10.06451613, 10.09677419, 10.12903226, + 10.16129032, 10.19354839, 10.22580645, 10.25806452, 10.29032258, + 10.32258065, 10.35483871, 10.38709677, 10.41935484, 10.4516129 , + 10.48387097, 10.51612903, 10.5483871 , 10.58064516, 10.61290323, + 10.64516129, 10.67741935, 10.70967742, 10.74193548, 10.77419355, 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, 10.96774194, 11. ]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 Limit scope of upsampling method >>> da.resample(time="1D").nearest(tolerance="1D") - - array([ 0., 0., nan, ..., nan, 11., 11.]) + Size: 3kB + array([ 0., 0., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 1., 1., 1., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 2., 2., 2., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3., + 3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + 6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 11., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 See Also -------- @@ -949,16 +1049,9 @@ def _resample( # TODO support non-string indexer after removing the old API. from xarray.core.dataarray import DataArray - from xarray.core.groupby import TimeResampleGrouper + from xarray.core.groupby import ResolvedGrouper, TimeResampler from xarray.core.resample import RESAMPLE_DIM - if keep_attrs is not None: - warnings.warn( - "Passing ``keep_attrs`` to ``resample`` has no effect and will raise an" - " error in xarray 0.20. Pass ``keep_attrs`` directly to the applied" - " function, e.g. ``resample(...).mean(keep_attrs=True)``." - ) - # note: the second argument (now 'skipna') use to be 'dim' if ( (skipna is not None and not isinstance(skipna, bool)) @@ -979,53 +1072,39 @@ def _resample( dim_name: Hashable = dim dim_coord = self[dim] - if loffset is not None: - emit_user_level_warning( - "Following pandas, the `loffset` parameter to resample will be deprecated " - "in a future version of xarray. Switch to using time offset arithmetic.", - FutureWarning, - ) - - if base is not None: - emit_user_level_warning( - "Following pandas, the `base` parameter to resample will be deprecated in " - "a future version of xarray. Switch to using `origin` or `offset` instead.", - FutureWarning, - ) - - if base is not None and offset is not None: - raise ValueError("base and offset cannot be present at the same time") - - if base is not None: - index = self._indexes[dim_name].to_pandas_index() - offset = _convert_base_to_offset(base, freq, index) + group = DataArray( + dim_coord, + coords=dim_coord.coords, + dims=dim_coord.dims, + name=RESAMPLE_DIM, + ) - grouper = TimeResampleGrouper( + grouper = TimeResampler( freq=freq, closed=closed, label=label, origin=origin, offset=offset, loffset=loffset, + base=base, ) - group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM - ) + rgrouper = ResolvedGrouper(grouper, group, self) + return resample_cls( self, - group=group, + (rgrouper,), dim=dim_name, - grouper=grouper, resample_dim=RESAMPLE_DIM, restore_coord_dims=restore_coord_dims, ) - def where( - self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False - ) -> T_DataWithCoords: + def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: """Filter elements from this object according to a condition. + Returns elements from 'DataArray', where 'cond' is True, + otherwise fill in 'other'. + This operation follows the normal broadcasting and alignment rules that xarray uses for binary arithmetic. @@ -1033,10 +1112,12 @@ def where( ---------- cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. - If a callable, it must expect this object as its only parameter. - other : scalar, DataArray or Dataset, optional + If a callable, the callable is passed this object, and the result is used as + the value for cond. + other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. - By default, these locations filled with NA. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. drop : bool, default: False If True, coordinate labels that only correspond to False values of the condition are dropped from the result. @@ -1050,7 +1131,7 @@ def where( -------- >>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) >>> a - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], @@ -1059,7 +1140,7 @@ def where( Dimensions without coordinates: x, y >>> a.where(a.x + a.y < 4) - + Size: 200B array([[ 0., 1., 2., 3., nan], [ 5., 6., 7., nan, nan], [10., 11., nan, nan, nan], @@ -1068,7 +1149,7 @@ def where( Dimensions without coordinates: x, y >>> a.where(a.x + a.y < 5, -1) - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, -1], [10, 11, 12, -1, -1], @@ -1077,29 +1158,30 @@ def where( Dimensions without coordinates: x, y >>> a.where(a.x + a.y < 4, drop=True) - + Size: 128B array([[ 0., 1., 2., 3.], [ 5., 6., 7., nan], [10., 11., nan, nan], [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(lambda x: x.x + x.y < 4, drop=True) - + >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x) + Size: 200B + array([[ 0, 1, 2, 3, -4], + [ 5, 6, 7, -8, -9], + [ 10, 11, -12, -13, -14], + [ 15, -16, -17, -18, -19], + [-20, -21, -22, -23, -24]]) + Dimensions without coordinates: x, y + + >>> a.where(a.x + a.y < 4, drop=True) + Size: 128B array([[ 0., 1., 2., 3.], [ 5., 6., 7., nan], [10., 11., nan, nan], [15., nan, nan, nan]]) Dimensions without coordinates: x, y - >>> a.where(a.x + a.y < 4, -1, drop=True) - - array([[ 0, 1, 2, 3], - [ 5, 6, 7, -1], - [10, 11, -1, -1], - [15, -1, -1, -1]]) - Dimensions without coordinates: x, y - See Also -------- numpy.where : corresponding numpy function @@ -1111,14 +1193,16 @@ def where( if callable(cond): cond = cond(self) + if callable(other): + other = other(self) if drop: if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." ) - self, cond = align(self, cond) # type: ignore[assignment] + self, cond = align(self, cond) def _dataarray_indexer(dim: Hashable) -> DataArray: return cond.any(dim=(d for d in cond.dims if d != dim)) @@ -1127,8 +1211,8 @@ def _dataset_indexer(dim: Hashable) -> DataArray: cond_wdim = cond.drop_vars( var for var in cond if dim not in cond[var].dims ) - keepany = cond_wdim.any(dim=(d for d in cond.dims.keys() if d != dim)) - return keepany.to_array().any("variable") + keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim)) + return keepany.to_dataarray().any("variable") _get_indexer = ( _dataarray_indexer if isinstance(cond, DataArray) else _dataset_indexer @@ -1165,9 +1249,7 @@ def close(self) -> None: self._close() self._close = None - def isnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def isnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is a missing value. Parameters @@ -1190,11 +1272,11 @@ def isnull( -------- >>> array = xr.DataArray([1, np.nan, 3], dims="x") >>> array - + Size: 24B array([ 1., nan, 3.]) Dimensions without coordinates: x >>> array.isnull() - + Size: 3B array([False, True, False]) Dimensions without coordinates: x """ @@ -1210,9 +1292,7 @@ def isnull( keep_attrs=keep_attrs, ) - def notnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def notnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is not a missing value. Parameters @@ -1235,11 +1315,11 @@ def notnull( -------- >>> array = xr.DataArray([1, np.nan, 3], dims="x") >>> array - + Size: 24B array([ 1., nan, 3.]) Dimensions without coordinates: x >>> array.notnull() - + Size: 3B array([ True, False, True]) Dimensions without coordinates: x """ @@ -1255,7 +1335,7 @@ def notnull( keep_attrs=keep_attrs, ) - def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: + def isin(self, test_elements: Any) -> Self: """Tests each value in the array for whether it is in test elements. Parameters @@ -1274,7 +1354,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: -------- >>> array = xr.DataArray([1, 2, 3], dims="x") >>> array.isin([1, 3]) - + Size: 3B array([ True, False, True]) Dimensions without coordinates: x @@ -1289,9 +1369,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: if isinstance(test_elements, Dataset): raise TypeError( - "isin() argument must be convertible to an array: {}".format( - test_elements - ) + f"isin() argument must be convertible to an array: {test_elements}" ) elif isinstance(test_elements, (Variable, DataArray)): # need to explicitly pull out data to support dask arrays as the @@ -1306,7 +1384,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: ) def astype( - self: T_DataWithCoords, + self, dtype, *, order=None, @@ -1314,7 +1392,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_DataWithCoords: + ) -> Self: """ Copy of the xarray object, with data cast to a specified type. Leaves coordinate dtype unchanged. @@ -1381,7 +1459,7 @@ def astype( dask="allowed", ) - def __enter__(self: T_DataWithCoords) -> T_DataWithCoords: + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_value, traceback) -> None: @@ -1394,47 +1472,77 @@ def __getitem__(self, value): @overload def full_like( - other: DataArray, fill_value: Any, dtype: DTypeLikeSave = None -) -> DataArray: - ... + other: DataArray, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... @overload def full_like( - other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping = None -) -> Dataset: - ... + other: Dataset, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... @overload def full_like( - other: Variable, fill_value: Any, dtype: DTypeLikeSave = None -) -> Variable: - ... + other: Variable, + fill_value: Any, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... @overload def full_like( - other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None -) -> Dataset | DataArray: - ... + other: Dataset | DataArray, + fill_value: Any, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = {}, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... @overload def full_like( other: Dataset | DataArray | Variable, fill_value: Any, - dtype: DTypeMaybeMapping = None, -) -> Dataset | DataArray | Variable: - ... + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... def full_like( other: Dataset | DataArray | Variable, fill_value: Any, - dtype: DTypeMaybeMapping = None, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: - """Return a new object with the same shape and type as a given object. + """ + Return a new object with the same shape and type as a given object. + + Returned object will be chunked if if the given object is chunked, or if chunks or chunked_array_type are specified. Parameters ---------- @@ -1447,6 +1555,18 @@ def full_like( dtype : dtype or dict-like of dtype, optional dtype of the new array. If a dict-like, maps dtypes to variables. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. Returns ------- @@ -1464,72 +1584,72 @@ def full_like( ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, ... ) >>> x - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.full_like(x, 1) - + Size: 48B array([[1, 1, 1], [1, 1, 1]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.full_like(x, 0.5) - + Size: 48B array([[0, 0, 0], [0, 0, 0]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.full_like(x, 0.5, dtype=np.double) - + Size: 48B array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.full_like(x, np.nan, dtype=np.double) - + Size: 48B array([[nan, nan, nan], [nan, nan, nan]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> ds = xr.Dataset( ... {"a": ("x", [3, 5, 2]), "b": ("x", [9, 1, 0])}, coords={"x": [2, 4, 6]} ... ) >>> ds - + Size: 72B Dimensions: (x: 3) Coordinates: - * x (x) int64 2 4 6 + * x (x) int64 24B 2 4 6 Data variables: - a (x) int64 3 5 2 - b (x) int64 9 1 0 + a (x) int64 24B 3 5 2 + b (x) int64 24B 9 1 0 >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}) - + Size: 72B Dimensions: (x: 3) Coordinates: - * x (x) int64 2 4 6 + * x (x) int64 24B 2 4 6 Data variables: - a (x) int64 1 1 1 - b (x) int64 2 2 2 + a (x) int64 24B 1 1 1 + b (x) int64 24B 2 2 2 >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}, dtype={"a": bool, "b": float}) - + Size: 51B Dimensions: (x: 3) Coordinates: - * x (x) int64 2 4 6 + * x (x) int64 24B 2 4 6 Data variables: - a (x) bool True True True - b (x) float64 2.0 2.0 2.0 + a (x) bool 3B True True True + b (x) float64 24B 2.0 2.0 2.0 See Also -------- @@ -1560,7 +1680,12 @@ def full_like( data_vars = { k: _full_like_variable( - v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None) + v.variable, + fill_value.get(k, dtypes.NA), + dtype_.get(k, None), + chunks, + chunked_array_type, + from_array_kwargs, ) for k, v in other.data_vars.items() } @@ -1569,7 +1694,14 @@ def full_like( if isinstance(dtype, Mapping): raise ValueError("'dtype' cannot be dict-like when passing a DataArray") return DataArray( - _full_like_variable(other.variable, fill_value, dtype), + _full_like_variable( + other.variable, + fill_value, + dtype, + chunks, + chunked_array_type, + from_array_kwargs, + ), dims=other.dims, coords=other.coords, attrs=other.attrs, @@ -1578,13 +1710,20 @@ def full_like( elif isinstance(other, Variable): if isinstance(dtype, Mapping): raise ValueError("'dtype' cannot be dict-like when passing a Variable") - return _full_like_variable(other, fill_value, dtype) + return _full_like_variable( + other, fill_value, dtype, chunks, chunked_array_type, from_array_kwargs + ) else: raise TypeError("Expected DataArray, Dataset, or Variable") def _full_like_variable( - other: Variable, fill_value: Any, dtype: DTypeLike = None + other: Variable, + fill_value: Any, + dtype: DTypeLike | None = None, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Variable: """Inner function of full_like, where other must be a variable""" from xarray.core.variable import Variable @@ -1592,13 +1731,28 @@ def _full_like_variable( if fill_value is dtypes.NA: fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype) - if is_duck_dask_array(other.data): - import dask.array + if ( + is_chunked_array(other.data) + or chunked_array_type is not None + or chunks is not None + ): + if chunked_array_type is None: + chunkmanager = get_chunked_array_type(other.data) + else: + chunkmanager = guess_chunkmanager(chunked_array_type) if dtype is None: dtype = other.dtype - data = dask.array.full( - other.shape, fill_value, dtype=dtype, chunks=other.data.chunks + + if from_array_kwargs is None: + from_array_kwargs = {} + + data = chunkmanager.array_api.full( + other.shape, + fill_value, + dtype=dtype, + chunks=chunks if chunks else other.data.chunks, + **from_array_kwargs, ) else: data = np.full_like(other.data, fill_value, dtype=dtype) @@ -1607,36 +1761,67 @@ def _full_like_variable( @overload -def zeros_like(other: DataArray, dtype: DTypeLikeSave = None) -> DataArray: - ... +def zeros_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... @overload -def zeros_like(other: Dataset, dtype: DTypeMaybeMapping = None) -> Dataset: - ... +def zeros_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... @overload -def zeros_like(other: Variable, dtype: DTypeLikeSave = None) -> Variable: - ... +def zeros_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... @overload def zeros_like( - other: Dataset | DataArray, dtype: DTypeMaybeMapping = None -) -> Dataset | DataArray: - ... + other: Dataset | DataArray, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... @overload def zeros_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None -) -> Dataset | DataArray | Variable: - ... + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... def zeros_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: """Return a new object of zeros with the same shape and type as a given dataarray or dataset. @@ -1647,6 +1832,18 @@ def zeros_like( The reference object. The output will have the same dimensions and coordinates as this object. dtype : dtype, optional dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. Returns ------- @@ -1661,28 +1858,28 @@ def zeros_like( ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, ... ) >>> x - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.zeros_like(x) - + Size: 48B array([[0, 0, 0], [0, 0, 0]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.zeros_like(x, dtype=float) - + Size: 48B array([[0., 0., 0.], [0., 0., 0.]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 See Also -------- @@ -1690,40 +1887,78 @@ def zeros_like( full_like """ - return full_like(other, 0, dtype) + return full_like( + other, + 0, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) @overload -def ones_like(other: DataArray, dtype: DTypeLikeSave = None) -> DataArray: - ... +def ones_like( + other: DataArray, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> DataArray: ... @overload -def ones_like(other: Dataset, dtype: DTypeMaybeMapping = None) -> Dataset: - ... +def ones_like( + other: Dataset, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset: ... @overload -def ones_like(other: Variable, dtype: DTypeLikeSave = None) -> Variable: - ... +def ones_like( + other: Variable, + dtype: DTypeLikeSave | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Variable: ... @overload def ones_like( - other: Dataset | DataArray, dtype: DTypeMaybeMapping = None -) -> Dataset | DataArray: - ... + other: Dataset | DataArray, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray: ... @overload def ones_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None -) -> Dataset | DataArray | Variable: - ... + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, +) -> Dataset | DataArray | Variable: ... def ones_like( - other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None + other: Dataset | DataArray | Variable, + dtype: DTypeMaybeMapping | None = None, + *, + chunks: T_Chunks = None, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, ) -> Dataset | DataArray | Variable: """Return a new object of ones with the same shape and type as a given dataarray or dataset. @@ -1734,6 +1969,18 @@ def ones_like( The reference object. The output will have the same dimensions and coordinates as this object. dtype : dtype, optional dtype of the new array. If omitted, it defaults to other.dtype. + chunks : int, "auto", tuple of int or mapping of Hashable to int, optional + Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, ``(5, 5)`` or + ``{"x": 5, "y": 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEnetryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. Returns ------- @@ -1748,20 +1995,20 @@ def ones_like( ... coords={"lat": [1, 2], "lon": [0, 1, 2]}, ... ) >>> x - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 >>> xr.ones_like(x) - + Size: 48B array([[1, 1, 1], [1, 1, 1]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 16B 1 2 + * lon (lon) int64 24B 0 1 2 See Also -------- @@ -1769,7 +2016,14 @@ def ones_like( full_like """ - return full_like(other, 1, dtype) + return full_like( + other, + 1, + dtype, + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) def get_chunksizes( diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9af7fcd89a4..f29f6c4dd35 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1,6 +1,7 @@ """ Functions for applying functions that act on arrays to xarray's labeled data. """ + from __future__ import annotations import functools @@ -8,8 +9,8 @@ import operator import warnings from collections import Counter -from collections.abc import Hashable, Iterable, Mapping, Sequence, Set -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload import numpy as np @@ -17,13 +18,16 @@ from xarray.core.alignment import align, deep_align from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.formatting import limit_lines from xarray.core.indexes import Index, filter_indexes_from_coords from xarray.core.merge import merge_attrs, merge_coordinates_without_align from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import Dims, T_DataArray -from xarray.core.utils import is_dict_like, is_scalar +from xarray.core.utils import is_dict_like, is_duck_dask_array, is_scalar, parse_dims from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array +from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: from xarray.core.coordinates import Coordinates @@ -31,6 +35,8 @@ from xarray.core.dataset import Dataset from xarray.core.types import CombineAttrsOptions, JoinOptions + MissingCoreDimOptions = Literal["raise", "copy", "drop"] + _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) @@ -41,6 +47,7 @@ def _first_of_type(args, kind): for arg in args: if isinstance(arg, kind): return arg + raise ValueError("This should be unreachable.") @@ -158,7 +165,7 @@ def to_gufunc_string(self, exclude_dims=frozenset()): if exclude_dims: exclude_dims = [self.dims_map[dim] for dim in exclude_dims] - counter = Counter() + counter: Counter = Counter() def _enumerate(dim): if dim in exclude_dims: @@ -284,8 +291,14 @@ def apply_dataarray_vfunc( from xarray.core.dataarray import DataArray if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) objs = _all_of_type(args, DataArray) @@ -346,7 +359,7 @@ def assert_and_return_exact_match(all_keys): if keys != first_keys: raise ValueError( "exact match required for all data variable names, " - f"but {keys!r} != {first_keys!r}" + f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both." ) return first_keys @@ -375,7 +388,7 @@ def collect_dict_values( ] -def _as_variables_or_variable(arg): +def _as_variables_or_variable(arg) -> Variable | tuple[Variable]: try: return arg.variables except AttributeError: @@ -395,8 +408,39 @@ def _unpack_dict_tuples( return out +def _check_core_dims(signature, variable_args, name): + """ + Check if an arg has all the core dims required by the signature. + + Slightly awkward design, of returning the error message. But we want to + give a detailed error message, which requires inspecting the variable in + the inner loop. + """ + missing = [] + for i, (core_dims, variable_arg) in enumerate( + zip(signature.input_core_dims, variable_args) + ): + # Check whether all the dims are on the variable. Note that we need the + # `hasattr` to check for a dims property, to protect against the case where + # a numpy array is passed in. + if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims): + missing += [[i, variable_arg, core_dims]] + if missing: + message = "" + for i, variable_arg, core_dims in missing: + message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n" + message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. " + return message + return True + + def apply_dict_of_variables_vfunc( - func, *args, signature: _UFuncSignature, join="inner", fill_value=None + func, + *args, + signature: _UFuncSignature, + join="inner", + fill_value=None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ): """Apply a variable level function over dicts of DataArray, DataArray, Variable and ndarray objects. @@ -407,7 +451,20 @@ def apply_dict_of_variables_vfunc( result_vars = {} for name, variable_args in zip(names, grouped_by_name): - result_vars[name] = func(*variable_args) + core_dim_present = _check_core_dims(signature, variable_args, name) + if core_dim_present is True: + result_vars[name] = func(*variable_args) + else: + if on_missing_core_dim == "raise": + raise ValueError(core_dim_present) + elif on_missing_core_dim == "copy": + result_vars[name] = variable_args[0] + elif on_missing_core_dim == "drop": + pass + else: + raise ValueError( + f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}" + ) if signature.num_outputs > 1: return _unpack_dict_tuples(result_vars, signature.num_outputs) @@ -440,6 +497,7 @@ def apply_dataset_vfunc( fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs="override", + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Dataset | tuple[Dataset, ...]: """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -456,8 +514,14 @@ def apply_dataset_vfunc( objs = _all_of_type(args, Dataset) if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) list_of_coords, list_of_indexes = build_output_coords_and_indexes( @@ -466,7 +530,12 @@ def apply_dataset_vfunc( args = tuple(getattr(arg, "data_vars", arg) for arg in args) result_vars = apply_dict_of_variables_vfunc( - func, *args, signature=signature, join=dataset_join, fill_value=fill_value + func, + *args, + signature=signature, + join=dataset_join, + fill_value=fill_value, + on_missing_core_dim=on_missing_core_dim, ) out: Dataset | tuple[Dataset, ...] @@ -515,18 +584,20 @@ def apply_groupby_func(func, *args): groupbys = [arg for arg in args if isinstance(arg, GroupBy)] assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] - if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): + (grouper,) = first_groupby.groupers + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr] raise ValueError( "apply_ufunc can only perform operations over " "multiple GroupBy objects at once if they are all " "grouped the same way" ) - grouped_dim = first_groupby._group.name - unique_values = first_groupby._unique_coord.values + grouped_dim = grouper.name + unique_values = grouper.unique_coord.values iterators = [] for arg in args: + iterator: Iterator[Any] if isinstance(arg, GroupBy): iterator = (value for _, value in arg) elif hasattr(arg, "dims") and grouped_dim in arg.dims: @@ -541,9 +612,9 @@ def apply_groupby_func(func, *args): iterator = itertools.repeat(arg) iterators.append(iterator) - applied = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) applied_example, applied = peek_at(applied) - combine = first_groupby._combine + combine = first_groupby._combine # type: ignore[attr-defined] if isinstance(applied_example, tuple): combined = tuple(combine(output) for output in zip(*applied)) else: @@ -593,17 +664,9 @@ def broadcast_compat_data( return data set_old_dims = set(old_dims) - missing_core_dims = [d for d in core_dims if d not in set_old_dims] - if missing_core_dims: - raise ValueError( - "operand to apply_ufunc has required core dimensions {}, but " - "some of these dimensions are absent on an input variable: {}".format( - list(core_dims), missing_core_dims - ) - ) - set_new_dims = set(new_dims) unexpected_dims = [d for d in old_dims if d not in set_new_dims] + if unexpected_dims: raise ValueError( "operand to apply_ufunc encountered unexpected " @@ -657,6 +720,7 @@ def apply_variable_ufunc( dask_gufunc_kwargs=None, ) -> Variable | tuple[Variable, ...]: """Apply a ndarray level function over Variable and/or ndarray objects.""" + from xarray.core.formatting import short_array_repr from xarray.core.variable import Variable, as_compatible_data dim_sizes = unified_dim_sizes( @@ -668,22 +732,26 @@ def apply_variable_ufunc( output_dims = [broadcast_dims + out for out in signature.output_core_dims] input_data = [ - broadcast_compat_data(arg, broadcast_dims, core_dims) - if isinstance(arg, Variable) - else arg + ( + broadcast_compat_data(arg, broadcast_dims, core_dims) + if isinstance(arg, Variable) + else arg + ) for arg, core_dims in zip(args, signature.input_core_dims) ] - if any(is_duck_dask_array(array) for array in input_data): + if any(is_chunked_array(array) for array in input_data): if dask == "forbidden": raise ValueError( - "apply_ufunc encountered a dask array on an " - "argument, but handling for dask arrays has not " + "apply_ufunc encountered a chunked array on an " + "argument, but handling for chunked arrays has not " "been enabled. Either set the ``dask`` argument " "or load your data into memory first with " "``.load()`` or ``.compute()``" ) elif dask == "parallelized": + chunkmanager = get_chunked_array_type(*input_data) + numpy_func = func if dask_gufunc_kwargs is None: @@ -696,7 +764,7 @@ def apply_variable_ufunc( for n, (data, core_dims) in enumerate( zip(input_data, signature.input_core_dims) ): - if is_duck_dask_array(data): + if is_chunked_array(data): # core dimensions cannot span multiple chunks for axis, dim in enumerate(core_dims, start=-len(core_dims)): if len(data.chunks[axis]) != 1: @@ -704,7 +772,7 @@ def apply_variable_ufunc( f"dimension {dim} on {n}th function argument to " "apply_ufunc with dask='parallelized' consists of " "multiple chunks, but is also a core dimension. To " - "fix, either rechunk into a single dask array chunk along " + "fix, either rechunk into a single array chunk along " f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or " "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " "but beware that this may significantly increase memory usage." @@ -731,9 +799,7 @@ def apply_variable_ufunc( ) def func(*arrays): - import dask.array as da - - res = da.apply_gufunc( + res = chunkmanager.apply_gufunc( numpy_func, signature.to_gufunc_string(exclude_dims), *arrays, @@ -748,8 +814,7 @@ def func(*arrays): pass else: raise ValueError( - "unknown setting for dask array handling in " - "apply_ufunc: {}".format(dask) + f"unknown setting for chunked array handling in apply_ufunc: {dask}" ) else: if vectorize: @@ -765,11 +830,11 @@ def func(*arrays): not isinstance(result_data, tuple) or len(result_data) != signature.num_outputs ): raise ValueError( - "applied function does not have the number of " - "outputs specified in the ufunc signature. " - "Result is not a tuple of {} elements: {!r}".format( - signature.num_outputs, result_data - ) + f"applied function does not have the number of " + f"outputs specified in the ufunc signature. " + f"Received a {type(result_data)} with {len(result_data)} elements. " + f"Expected a tuple of {signature.num_outputs} elements:\n\n" + f"{limit_lines(repr(result_data), limit=10)}" ) objs = _all_of_type(args, Variable) @@ -783,21 +848,22 @@ def func(*arrays): data = as_compatible_data(data) if data.ndim != len(dims): raise ValueError( - "applied function returned data with unexpected " + "applied function returned data with an unexpected " f"number of dimensions. Received {data.ndim} dimension(s) but " - f"expected {len(dims)} dimensions with names: {dims!r}" + f"expected {len(dims)} dimensions with names {dims!r}, from:\n\n" + f"{short_array_repr(data)}" ) var = Variable(dims, data, fastpath=True) for dim, new_size in var.sizes.items(): if dim in dim_sizes and new_size != dim_sizes[dim]: raise ValueError( - "size of dimension {!r} on inputs was unexpectedly " - "changed by applied function from {} to {}. Only " + f"size of dimension '{dim}' on inputs was unexpectedly " + f"changed by applied function from {dim_sizes[dim]} to {new_size}. Only " "dimensions specified in ``exclude_dims`` with " - "xarray.apply_ufunc are allowed to change size.".format( - dim, dim_sizes[dim], new_size - ) + "xarray.apply_ufunc are allowed to change size. " + "The data returned was:\n\n" + f"{short_array_repr(data)}" ) var.attrs = attrs @@ -811,7 +877,7 @@ def func(*arrays): def apply_array_ufunc(func, *args, dask="forbidden"): """Apply a ndarray level function over ndarray objects.""" - if any(is_duck_dask_array(arg) for arg in args): + if any(is_chunked_array(arg) for arg in args): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -844,11 +910,12 @@ def apply_ufunc( dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool | str | None = None, kwargs: Mapping | None = None, - dask: str = "forbidden", + dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden", output_dtypes: Sequence | None = None, output_sizes: Mapping[Any, int] | None = None, meta: Any = None, dask_gufunc_kwargs: dict[str, Any] | None = None, + on_missing_core_dim: MissingCoreDimOptions = "raise", ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. @@ -962,6 +1029,8 @@ def apply_ufunc( :py:func:`dask.array.apply_gufunc`. ``meta`` should be given in the ``dask_gufunc_kwargs`` parameter . It will be removed as direct parameter a future version. + on_missing_core_dim : {"raise", "copy", "drop"}, default: "raise" + How to handle missing core dimensions on input variables. Returns ------- @@ -990,10 +1059,10 @@ def apply_ufunc( >>> array = xr.DataArray([1, 2, 3], coords=[("x", [0.1, 0.2, 0.3])]) >>> magnitude(array, -array) - + Size: 24B array([1.41421356, 2.82842712, 4.24264069]) Coordinates: - * x (x) float64 0.1 0.2 0.3 + * x (x) float64 24B 0.1 0.2 0.3 Plain scalars, numpy arrays and a mix of these with xarray objects is also supported: @@ -1003,10 +1072,10 @@ def apply_ufunc( >>> magnitude(3, np.array([0, 4])) array([3., 5.]) >>> magnitude(array, 0) - + Size: 24B array([1., 2., 3.]) Coordinates: - * x (x) float64 0.1 0.2 0.3 + * x (x) float64 24B 0.1 0.2 0.3 Other examples of how you could use ``apply_ufunc`` to write functions to (very nearly) replicate existing xarray functionality: @@ -1076,9 +1145,13 @@ def apply_ufunc( numba.guvectorize dask.array.apply_gufunc xarray.map_blocks + :ref:`dask.automatic-parallelization` User guide describing :py:func:`apply_ufunc` and :py:func:`map_blocks`. + :doc:`xarray-tutorial:advanced/apply_ufunc/apply_ufunc` + Advanced Tutorial on applying numpy function using :py:func:`apply_ufunc` + References ---------- .. [1] https://numpy.org/doc/stable/reference/ufuncs.html @@ -1190,6 +1263,7 @@ def apply_ufunc( dataset_join=dataset_join, fill_value=dataset_fill_value, keep_attrs=keep_attrs, + on_missing_core_dim=on_missing_core_dim, ) # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc elif any(isinstance(a, DataArray) for a in args): @@ -1210,7 +1284,11 @@ def apply_ufunc( def cov( - da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1 + da_a: T_DataArray, + da_b: T_DataArray, + dim: Dims = None, + ddof: int = 1, + weights: T_DataArray | None = None, ) -> T_DataArray: """ Compute covariance between two DataArray objects along a shared dimension. @@ -1226,6 +1304,8 @@ def cov( ddof : int, default: 1 If ddof=1, covariance is normalized by N-1, giving an unbiased estimate, else normalization is by N. + weights : DataArray, optional + Array of weights. Returns ------- @@ -1248,13 +1328,13 @@ def cov( ... ], ... ) >>> da_a - + Size: 72B array([[1. , 2. , 3. ], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]) Coordinates: - * space (space) >> da_b = DataArray( ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), ... dims=("space", "time"), @@ -1264,34 +1344,58 @@ def cov( ... ], ... ) >>> da_b - + Size: 72B array([[ 0.2, 0.4, 0.6], [15. , 10. , 5. ], [ 3.2, 0.6, 1.8]]) Coordinates: - * space (space) >> xr.cov(da_a, da_b) - + Size: 8B array(-3.53055556) >>> xr.cov(da_a, da_b, dim="time") - + Size: 24B array([ 0.2 , -0.5 , 1.69333333]) Coordinates: - * space (space) >> weights = DataArray( + ... [4, 2, 1], + ... dims=("space"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ], + ... ) + >>> weights + Size: 24B + array([4, 2, 1]) + Coordinates: + * space (space) >> xr.cov(da_a, da_b, dim="space", weights=weights) + Size: 24B + array([-4.69346939, -4.49632653, -3.37959184]) + Coordinates: + * time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03 """ from xarray.core.dataarray import DataArray if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): raise TypeError( "Only xr.DataArray is supported." - "Given {}.".format([type(arr) for arr in [da_a, da_b]]) + f"Given {[type(arr) for arr in [da_a, da_b]]}." ) + if weights is not None: + if not isinstance(weights, DataArray): + raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.") + return _cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov") - return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov") - -def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray: +def corr( + da_a: T_DataArray, + da_b: T_DataArray, + dim: Dims = None, + weights: T_DataArray | None = None, +) -> T_DataArray: """ Compute the Pearson correlation coefficient between two DataArray objects along a shared dimension. @@ -1304,6 +1408,8 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray: Array to compute. dim : str, iterable of hashable, "..." or None, optional The dimension along which the correlation will be computed + weights : DataArray, optional + Array of weights. Returns ------- @@ -1326,13 +1432,13 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray: ... ], ... ) >>> da_a - + Size: 72B array([[1. , 2. , 3. ], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]) Coordinates: - * space (space) >> da_b = DataArray( ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), ... dims=("space", "time"), @@ -1342,36 +1448,56 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray: ... ], ... ) >>> da_b - + Size: 72B array([[ 0.2, 0.4, 0.6], [15. , 10. , 5. ], [ 3.2, 0.6, 1.8]]) Coordinates: - * space (space) >> xr.corr(da_a, da_b) - + Size: 8B array(-0.57087777) >>> xr.corr(da_a, da_b, dim="time") - + Size: 24B array([ 1., -1., 1.]) Coordinates: - * space (space) >> weights = DataArray( + ... [4, 2, 1], + ... dims=("space"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ], + ... ) + >>> weights + Size: 24B + array([4, 2, 1]) + Coordinates: + * space (space) >> xr.corr(da_a, da_b, dim="space", weights=weights) + Size: 24B + array([-0.50240504, -0.83215028, -0.99057446]) + Coordinates: + * time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03 """ from xarray.core.dataarray import DataArray if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]): raise TypeError( "Only xr.DataArray is supported." - "Given {}.".format([type(arr) for arr in [da_a, da_b]]) + f"Given {[type(arr) for arr in [da_a, da_b]]}." ) - - return _cov_corr(da_a, da_b, dim=dim, method="corr") + if weights is not None: + if not isinstance(weights, DataArray): + raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.") + return _cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr") def _cov_corr( da_a: T_DataArray, da_b: T_DataArray, + weights: T_DataArray | None = None, dim: Dims = None, ddof: int = 0, method: Literal["cov", "corr", None] = None, @@ -1387,28 +1513,46 @@ def _cov_corr( valid_values = da_a.notnull() & da_b.notnull() da_a = da_a.where(valid_values) da_b = da_b.where(valid_values) - valid_count = valid_values.sum(dim) - ddof # 3. Detrend along the given dim - demeaned_da_a = da_a - da_a.mean(dim=dim) - demeaned_da_b = da_b - da_b.mean(dim=dim) + if weights is not None: + demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim) + demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim) + else: + demeaned_da_a = da_a - da_a.mean(dim=dim) + demeaned_da_b = da_b - da_b.mean(dim=dim) # 4. Compute covariance along the given dim # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g. # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]) - cov = (demeaned_da_a.conj() * demeaned_da_b).sum( - dim=dim, skipna=True, min_count=1 - ) / (valid_count) + if weights is not None: + cov = ( + (demeaned_da_a.conj() * demeaned_da_b) + .weighted(weights) + .mean(dim=dim, skipna=True) + ) + else: + cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True) if method == "cov": - return cov # type: ignore[return-value] + # Adjust covariance for degrees of freedom + valid_count = valid_values.sum(dim) + adjust = valid_count / (valid_count - ddof) + # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be + # the same with `T_DatasetOrArray`) + # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026 + return cast(T_DataArray, cov * adjust) else: - # compute std + corr - da_a_std = da_a.std(dim=dim) - da_b_std = da_b.std(dim=dim) + # Compute std and corr + if weights is not None: + da_a_std = da_a.weighted(weights).std(dim=dim) + da_b_std = da_b.weighted(weights).std(dim=dim) + else: + da_a_std = da_a.std(dim=dim) + da_b_std = da_b.std(dim=dim) corr = cov / (da_a_std * da_b_std) - return corr # type: ignore[return-value] + return cast(T_DataArray, corr) def cross( @@ -1441,7 +1585,7 @@ def cross( >>> a = xr.DataArray([1, 2, 3]) >>> b = xr.DataArray([4, 5, 6]) >>> xr.cross(a, b, dim="dim_0") - + Size: 24B array([-3, 6, -3]) Dimensions without coordinates: dim_0 @@ -1451,7 +1595,7 @@ def cross( >>> a = xr.DataArray([1, 2]) >>> b = xr.DataArray([4, 5]) >>> xr.cross(a, b, dim="dim_0") - + Size: 8B array(-3) Vector cross-product with 3 dimensions but zeros at the last axis @@ -1460,7 +1604,7 @@ def cross( >>> a = xr.DataArray([1, 2, 0]) >>> b = xr.DataArray([4, 5, 0]) >>> xr.cross(a, b, dim="dim_0") - + Size: 24B array([ 0, 0, -3]) Dimensions without coordinates: dim_0 @@ -1477,10 +1621,10 @@ def cross( ... coords=dict(cartesian=(["cartesian"], ["x", "y", "z"])), ... ) >>> xr.cross(a, b, dim="cartesian") - + Size: 24B array([12, -6, -3]) Coordinates: - * cartesian (cartesian) >> xr.cross(a, b, dim="cartesian") - + Size: 24B array([-10, 2, 5]) Coordinates: - * cartesian (cartesian) >> xr.cross(a, b, dim="cartesian") - + Size: 48B array([[-3, 6, -3], [ 3, -6, 3]]) Coordinates: - * time (time) int64 0 1 - * cartesian (cartesian) >> ds_a = xr.Dataset(dict(x=("dim_0", [1]), y=("dim_0", [2]), z=("dim_0", [3]))) >>> ds_b = xr.Dataset(dict(x=("dim_0", [4]), y=("dim_0", [5]), z=("dim_0", [6]))) >>> c = xr.cross( - ... ds_a.to_array("cartesian"), ds_b.to_array("cartesian"), dim="cartesian" + ... ds_a.to_dataarray("cartesian"), + ... ds_b.to_dataarray("cartesian"), + ... dim="cartesian", ... ) >>> c.to_dataset(dim="cartesian") - + Size: 24B Dimensions: (dim_0: 1) Dimensions without coordinates: dim_0 Data variables: - x (dim_0) int64 -3 - y (dim_0) int64 6 - z (dim_0) int64 -3 + x (dim_0) int64 8B -3 + y (dim_0) int64 8B 6 + z (dim_0) int64 8B -3 See Also -------- @@ -1619,29 +1765,41 @@ def cross( return c +@deprecate_dims def dot( *arrays, - dims: Dims = None, + dim: Dims = None, **kwargs: Any, ): - """Generalized dot product for xarray objects. Like np.einsum, but - provides a simpler interface based on array dimensions. + """Generalized dot product for xarray objects. Like ``np.einsum``, but + provides a simpler interface based on array dimension names. Parameters ---------- *arrays : DataArray or Variable Arrays to compute. - dims : str, iterable of hashable, "..." or None, optional + dim : str, iterable of hashable, "..." or None, optional Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. If not specified, then all the common dimensions are summed over. **kwargs : dict - Additional keyword arguments passed to numpy.einsum or - dask.array.einsum + Additional keyword arguments passed to ``numpy.einsum`` or + ``dask.array.einsum`` Returns ------- DataArray + See Also + -------- + numpy.einsum + dask.array.einsum + opt_einsum.contract + + Notes + ----- + We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``, + which is passed through to ``np.einsum``, and works for most array backends. + Examples -------- >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) @@ -1649,14 +1807,14 @@ def dot( >>> da_c = xr.DataArray(np.arange(2 * 3).reshape(2, 3), dims=["c", "d"]) >>> da_a - + Size: 48B array([[0, 1], [2, 3], [4, 5]]) Dimensions without coordinates: a, b >>> da_b - + Size: 96B array([[[ 0, 1], [ 2, 3]], @@ -1668,36 +1826,36 @@ def dot( Dimensions without coordinates: a, b, c >>> da_c - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Dimensions without coordinates: c, d - >>> xr.dot(da_a, da_b, dims=["a", "b"]) - + >>> xr.dot(da_a, da_b, dim=["a", "b"]) + Size: 16B array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=["a"]) - + >>> xr.dot(da_a, da_b, dim=["a"]) + Size: 32B array([[40, 46], [70, 79]]) Dimensions without coordinates: b, c - >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"]) - + >>> xr.dot(da_a, da_b, da_c, dim=["b", "c"]) + Size: 72B array([[ 9, 14, 19], [ 93, 150, 207], [273, 446, 619]]) Dimensions without coordinates: a, d >>> xr.dot(da_a, da_b) - + Size: 16B array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=...) - + >>> xr.dot(da_a, da_b, dim=...) + Size: 8B array(235) """ from xarray.core.dataarray import DataArray @@ -1706,7 +1864,7 @@ def dot( if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): raise TypeError( "Only xr.DataArray and xr.Variable are supported." - "Given {}.".format([type(arr) for arr in arrays]) + f"Given {[type(arr) for arr in arrays]}." ) if len(arrays) == 0: @@ -1720,18 +1878,16 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - if dims is ...: - dims = all_dims - elif isinstance(dims, str): - dims = (dims,) - elif dims is None: - # find dimensions that occur more than one times + if dim is None: + # find dimensions that occur more than once dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dims = tuple(d for d, c in dim_counts.items() if c > 1) + dim = tuple(d for d, c in dim_counts.items() if c > 1) + else: + dim = parse_dims(dim, all_dims=tuple(all_dims)) - dot_dims: set[Hashable] = set(dims) + dot_dims: set[Hashable] = set(dim) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims @@ -1803,16 +1959,16 @@ def where(cond, x, y, keep_attrs=None): ... name="sst", ... ) >>> x - + Size: 80B array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) Coordinates: - * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 + * lat (lat) int64 80B 0 1 2 3 4 5 6 7 8 9 >>> xr.where(x < 0.5, x, x * 100) - + Size: 80B array([ 0. , 0.1, 0.2, 0.3, 0.4, 50. , 60. , 70. , 80. , 90. ]) Coordinates: - * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 + * lat (lat) int64 80B 0 1 2 3 4 5 6 7 8 9 >>> y = xr.DataArray( ... 0.1 * np.arange(9).reshape(3, 3), @@ -1821,27 +1977,27 @@ def where(cond, x, y, keep_attrs=None): ... name="sst", ... ) >>> y - + Size: 72B array([[0. , 0.1, 0.2], [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]]) Coordinates: - * lat (lat) int64 0 1 2 - * lon (lon) int64 10 11 12 + * lat (lat) int64 24B 0 1 2 + * lon (lon) int64 24B 10 11 12 >>> xr.where(y.lat < 1, y, -1) - + Size: 72B array([[ 0. , 0.1, 0.2], [-1. , -1. , -1. ], [-1. , -1. , -1. ]]) Coordinates: - * lat (lat) int64 0 1 2 - * lon (lon) int64 10 11 12 + * lat (lat) int64 24B 0 1 2 + * lon (lon) int64 24B 10 11 12 >>> cond = xr.DataArray([True, False], dims=["x"]) >>> x = xr.DataArray([1, 2], dims=["y"]) >>> xr.where(cond, x, 0) - + Size: 32B array([[1, 2], [0, 0]]) Dimensions without coordinates: x, y @@ -1894,29 +2050,25 @@ def where(cond, x, y, keep_attrs=None): @overload def polyval( coord: DataArray, coeffs: DataArray, degree_dim: Hashable = "degree" -) -> DataArray: - ... +) -> DataArray: ... @overload def polyval( coord: DataArray, coeffs: Dataset, degree_dim: Hashable = "degree" -) -> Dataset: - ... +) -> Dataset: ... @overload def polyval( coord: Dataset, coeffs: DataArray, degree_dim: Hashable = "degree" -) -> Dataset: - ... +) -> Dataset: ... @overload def polyval( coord: Dataset, coeffs: Dataset, degree_dim: Hashable = "degree" -) -> Dataset: - ... +) -> Dataset: ... @overload @@ -1924,8 +2076,7 @@ def polyval( coord: Dataset | DataArray, coeffs: Dataset | DataArray, degree_dim: Hashable = "degree", -) -> Dataset | DataArray: - ... +) -> Dataset | DataArray: ... def polyval( @@ -2012,7 +2163,7 @@ def to_floatable(x: DataArray) -> DataArray: ) elif x.dtype.kind == "m": # timedeltas - return x.astype(float) + return duck_array_ops.astype(x, dtype=float) return x if isinstance(data, Dataset): @@ -2045,9 +2196,13 @@ def _calc_idxminmax( raise ValueError("Must supply 'dim' argument for multidimensional arrays") if dim not in array.dims: - raise KeyError(f'Dimension "{dim}" not in dimension') + raise KeyError( + f"Dimension {dim!r} not found in array dimensions {array.dims!r}" + ) if dim not in array.coords: - raise KeyError(f'Dimension "{dim}" does not have coordinates') + raise KeyError( + f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}" + ) # These are dtypes with NaN values argmin and argmax can handle na_dtypes = "cfO" @@ -2060,13 +2215,13 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Handle dask arrays. - if is_duck_dask_array(array.data): - import dask.array - + # Handle chunked arrays (e.g. dask). + if is_chunked_array(array.data): + chunkmanager = get_chunked_array_type(array.data) chunks = dict(zip(array.dims, array.chunks)) - dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) + dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) + data = dask_coord[duck_array_ops.ravel(indx.data)] + res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) # we need to attach back the dim name res.name = dim else: @@ -2090,23 +2245,19 @@ def _calc_idxminmax( @overload -def unify_chunks(__obj: _T) -> tuple[_T]: - ... +def unify_chunks(__obj: _T) -> tuple[_T]: ... @overload -def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: - ... +def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ... @overload -def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: - ... +def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ... @overload -def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: - ... +def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ... def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: @@ -2152,16 +2303,14 @@ def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, .. if not unify_chunks_args: return objects - # Run dask.array.core.unify_chunks - from dask.array.core import unify_chunks - - _, dask_data = unify_chunks(*unify_chunks_args) - dask_data_iter = iter(dask_data) + chunkmanager = get_chunked_array_type(*[arg for arg in unify_chunks_args]) + _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args) + chunked_data_iter = iter(chunked_data) out: list[Dataset | DataArray] = [] for obj, ds in zip(objects, datasets): for k, v in ds._variables.items(): if v.chunks is not None: - ds._variables[k] = v.copy(data=next(dask_data_iter)) + ds._variables[k] = v.copy(data=next(chunked_data_iter)) out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds) return tuple(out) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index f092911948f..d95cbccd36a 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,8 +1,9 @@ from __future__ import annotations from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, Union, overload +import numpy as np import pandas as pd from xarray.core import dtypes, utils @@ -15,7 +16,7 @@ merge_attrs, merge_collected, ) -from xarray.core.types import T_DataArray, T_Dataset +from xarray.core.types import T_DataArray, T_Dataset, T_Variable from xarray.core.variable import Variable from xarray.core.variable import concat as concat_vars @@ -27,41 +28,41 @@ JoinOptions, ) + T_DataVars = Union[ConcatOptions, Iterable[Hashable]] + @overload def concat( objs: Iterable[T_Dataset], - dim: Hashable | T_DataArray | pd.Index, - data_vars: ConcatOptions | list[Hashable] = "all", + dim: Hashable | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> T_Dataset: - ... +) -> T_Dataset: ... @overload def concat( objs: Iterable[T_DataArray], - dim: Hashable | T_DataArray | pd.Index, - data_vars: ConcatOptions | list[Hashable] = "all", + dim: Hashable | T_Variable | T_DataArray | pd.Index, + data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> T_DataArray: - ... +) -> T_DataArray: ... def concat( objs, dim, - data_vars="all", + data_vars: T_DataVars = "all", coords="different", compat: CompatOptions = "equals", positions=None, @@ -77,11 +78,11 @@ def concat( xarray objects to concatenate together. Each object is expected to consist of variables and coordinates with matching shapes except for along the concatenated dimension. - dim : Hashable or DataArray or pandas.Index + dim : Hashable or Variable or DataArray or pandas.Index Name of the dimension to concatenate along. This can either be a new dimension name, in which case it is added along axis=0, or an existing dimension name, in which case the location of the dimension is - unchanged. If dimension is provided as a DataArray or Index, its name + unchanged. If dimension is provided as a Variable, DataArray or Index, its name is used as the dimension to concatenate along and the values are added as a coordinate. data_vars : {"minimal", "different", "all"} or list of Hashable, optional @@ -176,46 +177,46 @@ def concat( ... np.arange(6).reshape(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] ... ) >>> da - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) >> xr.concat([da.isel(y=slice(0, 1)), da.isel(y=slice(1, None))], dim="y") - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "x") - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - x (new_dim) >> xr.concat([da.isel(x=0), da.isel(x=1)], pd.Index([-90, -100], name="new_dim")) - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - x (new_dim) ") -class Coordinates(Mapping[Hashable, "T_DataArray"]): +class AbstractCoordinates(Mapping[Hashable, "T_DataArray"]): _data: DataWithCoords __slots__ = ("_data",) def __getitem__(self, key: Hashable) -> T_DataArray: raise NotImplementedError() - def __setitem__(self, key: Hashable, value: Any) -> None: - self.update({key: value}) - @property def _names(self) -> set[Hashable]: raise NotImplementedError() @property - def dims(self) -> Mapping[Hashable, int] | tuple[Hashable, ...]: + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: raise NotImplementedError() @property @@ -54,10 +63,22 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: @property def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Coordinates object has indexes that cannot + be coerced to pandas.Index objects. + + See Also + -------- + Coordinates.xindexes + """ return self._data.indexes @property def xindexes(self) -> Indexes[Index]: + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return self._data.xindexes @property @@ -67,7 +88,7 @@ def variables(self): def _update_coords(self, coords, indexes): raise NotImplementedError() - def _maybe_drop_multiindex_coords(self, coords): + def _drop_coords(self, coord_names): raise NotImplementedError() def __iter__(self) -> Iterator[Hashable]: @@ -109,7 +130,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: elif set(ordered_dims) != set(self.dims): raise ValueError( "ordered_dims must match dims, but does not: " - "{} vs {}".format(ordered_dims, self.dims) + f"{ordered_dims} vs {self.dims}" ) if len(ordered_dims) == 0: @@ -125,7 +146,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: index_lengths = np.fromiter( (len(index) for index in indexes), dtype=np.intp ) - cumprod_lengths = np.cumproduct(index_lengths) + cumprod_lengths = np.cumprod(index_lengths) if cumprod_lengths[-1] == 0: # if any factor is empty, the cartesian product is empty @@ -163,13 +184,279 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index: return pd.MultiIndex(level_list, code_list, names=names) - def update(self, other: Mapping[Any, Any]) -> None: - other_vars = getattr(other, "variables", other) - self._maybe_drop_multiindex_coords(set(other_vars)) - coords, indexes = merge_coords( - [self.variables, other_vars], priority_arg=1, indexes=self.xindexes + +class Coordinates(AbstractCoordinates): + """Dictionary like container for Xarray coordinates (variables + indexes). + + This collection is a mapping of coordinate names to + :py:class:`~xarray.DataArray` objects. + + It can be passed directly to the :py:class:`~xarray.Dataset` and + :py:class:`~xarray.DataArray` constructors via their `coords` argument. This + will add both the coordinates variables and their index. + + Coordinates are either: + + - returned via the :py:attr:`Dataset.coords` and :py:attr:`DataArray.coords` + properties + - built from Pandas or other index objects + (e.g., :py:meth:`Coordinates.from_pandas_multiindex`) + - built directly from coordinate data and Xarray ``Index`` objects (beware that + no consistency check is done on those inputs) + + Parameters + ---------- + coords: dict-like, optional + Mapping where keys are coordinate names and values are objects that + can be converted into a :py:class:`~xarray.Variable` object + (see :py:func:`~xarray.as_variable`). If another + :py:class:`~xarray.Coordinates` object is passed, its indexes + will be added to the new created object. + indexes: dict-like, optional + Mapping where keys are coordinate names and values are + :py:class:`~xarray.indexes.Index` objects. If None (default), + pandas indexes will be created for each dimension coordinate. + Passing an empty dictionary will skip this default behavior. + + Examples + -------- + Create a dimension coordinate with a default (pandas) index: + + >>> xr.Coordinates({"x": [1, 2]}) + Coordinates: + * x (x) int64 16B 1 2 + + Create a dimension coordinate with no index: + + >>> xr.Coordinates(coords={"x": [1, 2]}, indexes={}) + Coordinates: + x (x) int64 16B 1 2 + + Create a new Coordinates object from existing dataset coordinates + (indexes are passed): + + >>> ds = xr.Dataset(coords={"x": [1, 2]}) + >>> xr.Coordinates(ds.coords) + Coordinates: + * x (x) int64 16B 1 2 + + Create indexed coordinates from a ``pandas.MultiIndex`` object: + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> xr.Coordinates.from_pandas_multiindex(midx, "x") + Coordinates: + * x (x) object 32B MultiIndex + * x_level_0 (x) object 32B 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 32B 0 1 0 1 + + Create a new Dataset object by passing a Coordinates object: + + >>> midx_coords = xr.Coordinates.from_pandas_multiindex(midx, "x") + >>> xr.Dataset(coords=midx_coords) + Size: 96B + Dimensions: (x: 4) + Coordinates: + * x (x) object 32B MultiIndex + * x_level_0 (x) object 32B 'a' 'a' 'b' 'b' + * x_level_1 (x) int64 32B 0 1 0 1 + Data variables: + *empty* + + """ + + _data: DataWithCoords + + __slots__ = ("_data",) + + def __init__( + self, + coords: Mapping[Any, Any] | None = None, + indexes: Mapping[Any, Index] | None = None, + ) -> None: + # When coordinates are constructed directly, an internal Dataset is + # created so that it is compatible with the DatasetCoordinates and + # DataArrayCoordinates classes serving as a proxy for the data. + # TODO: refactor DataArray / Dataset so that Coordinates store the data. + from xarray.core.dataset import Dataset + + if coords is None: + coords = {} + + variables: dict[Hashable, Variable] + default_indexes: dict[Hashable, PandasIndex] = {} + coords_obj_indexes: dict[Hashable, Index] = {} + + if isinstance(coords, Coordinates): + if indexes is not None: + raise ValueError( + "passing both a ``Coordinates`` object and a mapping of indexes " + "to ``Coordinates.__init__`` is not allowed " + "(this constructor does not support merging them)" + ) + variables = {k: v.copy() for k, v in coords.variables.items()} + coords_obj_indexes = dict(coords.xindexes) + else: + variables = {} + for name, data in coords.items(): + var = as_variable(data, name=name) + if var.dims == (name,) and indexes is None: + index, index_vars = create_default_index_implicit(var, list(coords)) + default_indexes.update({k: index for k in index_vars}) + variables.update(index_vars) + else: + variables[name] = var + + if indexes is None: + indexes = {} + else: + indexes = dict(indexes) + + indexes.update(default_indexes) + indexes.update(coords_obj_indexes) + + no_coord_index = set(indexes) - set(variables) + if no_coord_index: + raise ValueError( + f"no coordinate variables found for these indexes: {no_coord_index}" + ) + + for k, idx in indexes.items(): + if not isinstance(idx, Index): + raise TypeError(f"'{k}' is not an `xarray.indexes.Index` object") + + # maybe convert to base variable + for k, v in variables.items(): + if k not in indexes: + variables[k] = v.to_base_variable() + + self._data = Dataset._construct_direct( + coord_names=set(variables), variables=variables, indexes=indexes ) - self._update_coords(coords, indexes) + + @classmethod + def _construct_direct( + cls, + coords: dict[Any, Variable], + indexes: dict[Any, Index], + dims: dict[Any, int] | None = None, + ) -> Self: + from xarray.core.dataset import Dataset + + obj = object.__new__(cls) + obj._data = Dataset._construct_direct( + coord_names=set(coords), + variables=coords, + indexes=indexes, + dims=dims, + ) + return obj + + @classmethod + def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) -> Self: + """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). + + The returned coordinates can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + midx : :py:class:`pandas.MultiIndex` + Pandas multi-index object. + dim : str + Dimension name. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the multi-index. + + """ + xr_idx = PandasMultiIndex(midx, dim) + + variables = xr_idx.create_variables() + indexes = {k: xr_idx for k in variables} + + return cls(coords=variables, indexes=indexes) + + @property + def _names(self) -> set[Hashable]: + return self._data._coord_names + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: + """Mapping from dimension names to lengths or tuple of dimension names.""" + return self._data.dims + + @property + def sizes(self) -> Frozen[Hashable, int]: + """Mapping from dimension names to lengths.""" + return self._data.sizes + + @property + def dtypes(self) -> Frozen[Hashable, np.dtype]: + """Mapping from coordinate names to dtypes. + + Cannot be modified directly. + + See Also + -------- + Dataset.dtypes + """ + return Frozen({n: v.dtype for n, v in self._data.variables.items()}) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to Coordinates contents as dict of Variable objects. + + This dictionary is frozen to prevent mutation. + """ + return self._data.variables + + def to_dataset(self) -> Dataset: + """Convert these coordinates into a new Dataset.""" + names = [name for name in self._data._variables if name in self._names] + return self._data._copy_listed(names) + + def __getitem__(self, key: Hashable) -> DataArray: + return self._data[key] + + def __delitem__(self, key: Hashable) -> None: + # redirect to DatasetCoordinates.__delitem__ + del self._data.coords[key] + + def equals(self, other: Self) -> bool: + """Two Coordinates objects are equal if they have matching variables, + all of which are equal. + + See Also + -------- + Coordinates.identical + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().equals(other.to_dataset()) + + def identical(self, other: Self) -> bool: + """Like equals, but also checks all variable attributes. + + See Also + -------- + Coordinates.equals + """ + if not isinstance(other, Coordinates): + return False + return self.to_dataset().identical(other.to_dataset()) + + def _update_coords( + self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index] + ) -> None: + # redirect to DatasetCoordinates._update_coords + self._data.coords._update_coords(coords, indexes) + + def _drop_coords(self, coord_names): + # redirect to DatasetCoordinates._drop_coords + self._data.coords._drop_coords(coord_names) def _merge_raw(self, other, reflexive): """For use with binary arithmetic.""" @@ -200,7 +487,7 @@ def _merge_inplace(self, other): yield self._update_coords(variables, indexes) - def merge(self, other: Coordinates | None) -> Dataset: + def merge(self, other: Mapping[Any, Any] | None) -> Dataset: """Merge two sets of coordinates to create a new Dataset The method implements the logic used for joining coordinates in the @@ -214,8 +501,9 @@ def merge(self, other: Coordinates | None) -> Dataset: Parameters ---------- - other : DatasetCoordinates or DataArrayCoordinates - The coordinates from another dataset or data array. + other : dict-like, optional + A :py:class:`Coordinates` object or any mapping that can be turned + into coordinates. Returns ------- @@ -236,13 +524,171 @@ def merge(self, other: Coordinates | None) -> Dataset: variables=coords, coord_names=coord_names, indexes=indexes ) + def __setitem__(self, key: Hashable, value: Any) -> None: + self.update({key: value}) + + def update(self, other: Mapping[Any, Any]) -> None: + """Update this Coordinates variables with other coordinate variables.""" + + if not len(other): + return + + other_coords: Coordinates + + if isinstance(other, Coordinates): + # Coordinates object: just pass it (default indexes won't be created) + other_coords = other + else: + other_coords = create_coords_with_default_indexes( + getattr(other, "variables", other) + ) + + # Discard original indexed coordinates prior to merge allows to: + # - fail early if the new coordinates don't preserve the integrity of existing + # multi-coordinate indexes + # - drop & replace coordinates without alignment (note: we must keep indexed + # coordinates extracted from the DataArray objects passed as values to + # `other` - if any - as those are still used for aligning the old/new coordinates) + coords_to_align = drop_indexed_coords(set(other_coords) & set(other), self) + + coords, indexes = merge_coords( + [coords_to_align, other_coords], + priority_arg=1, + indexes=coords_to_align.xindexes, + ) + + # special case for PandasMultiIndex: updating only its dimension coordinate + # is still allowed but depreciated. + # It is the only case where we need to actually drop coordinates here (multi-index levels) + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + self._drop_coords(self._names - coords_to_align._names) + + self._update_coords(coords, indexes) + + def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self: + """Assign new coordinates (and indexes) to a Coordinates object, returning + a new object with all the original coordinates in addition to the new ones. + + Parameters + ---------- + coords : mapping of dim to coord, optional + A mapping whose keys are the names of the coordinates and values are the + coordinates to assign. The mapping will generally be a dict or + :class:`Coordinates`. + + * If a value is a standard data value — for example, a ``DataArray``, + scalar, or array — the data is simply assigned as a coordinate. + + * A coordinate can also be defined and attached to an existing dimension + using a tuple with the first element the dimension name and the second + element the values for this new coordinate. + + **coords_kwargs + The keyword arguments form of ``coords``. + One of ``coords`` or ``coords_kwargs`` must be provided. + + Returns + ------- + new_coords : Coordinates + A new Coordinates object with the new coordinates (and indexes) + in addition to all the existing coordinates. + + Examples + -------- + >>> coords = xr.Coordinates() + >>> coords + Coordinates: + *empty* + + >>> coords.assign(x=[1, 2]) + Coordinates: + * x (x) int64 16B 1 2 + + >>> midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]]) + >>> coords.assign(xr.Coordinates.from_pandas_multiindex(midx, "y")) + Coordinates: + * y (y) object 32B MultiIndex + * y_level_0 (y) object 32B 'a' 'a' 'b' 'b' + * y_level_1 (y) int64 32B 0 1 0 1 + + """ + # TODO: this doesn't support a callable, which is inconsistent with `DataArray.assign_coords` + coords = either_dict_or_kwargs(coords, coords_kwargs, "assign") + new_coords = self.copy() + new_coords.update(coords) + return new_coords + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: + results = self.to_dataset()._overwrite_indexes(indexes, variables) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, results.coords) + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: + """Callback called from ``Aligner`` to create a new reindexed Coordinate.""" + aligned = self.to_dataset()._reindex_callback( + aligner, + dim_pos_indexers, + variables, + indexes, + fill_value, + exclude_dims, + exclude_vars, + ) + + # TODO: remove cast once we get rid of DatasetCoordinates + # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates) + return cast(Self, aligned.coords) + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + return self._data._ipython_key_completions_() + + def copy( + self, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> Self: + """Return a copy of this Coordinates object.""" + # do not copy indexes (may corrupt multi-coordinate indexes) + # TODO: disable variables deepcopy? it may also be problematic when they + # encapsulate index objects like pd.Index + variables = { + k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items() + } + + # TODO: getting an error with `self._construct_direct`, possibly because of how + # a subclass implements `_construct_direct`. (This was originally the same + # runtime code, but we switched the type definitions in #8216, which + # necessitates the cast.) + return cast( + Self, + Coordinates._construct_direct( + coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + ), + ) + class DatasetCoordinates(Coordinates): - """Dictionary like container for Dataset coordinates. + """Dictionary like container for Dataset coordinates (variables + indexes). - Essentially an immutable dictionary with keys given by the array's - dimensions and the values given by the corresponding xarray.Coordinate - objects. + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. """ _data: Dataset @@ -257,7 +703,7 @@ def _names(self) -> set[Hashable]: return self._data._coord_names @property - def dims(self) -> Mapping[Hashable, int]: + def dims(self) -> Frozen[Hashable, int]: return self._data.dims @property @@ -318,21 +764,28 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes - def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: - """Drops variables in coords, and any associated variables as well.""" + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._variables[name] + del self._data._indexes[name] + self._data._coord_names.difference_update(coord_names) + + def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None: assert self._data.xindexes is not None - variables, indexes = drop_coords( - coords, self._data._variables, self._data.xindexes - ) - self._data._coord_names.intersection_update(variables) - self._data._variables = variables - self._data._indexes = indexes + new_coords = drop_indexed_coords(coords_to_drop, self) + for name in self._data._coord_names - new_coords._names: + del self._data._variables[name] + self._data._indexes = dict(new_coords.xindexes) + self._data._coord_names.intersection_update(new_coords._names) def __delitem__(self, key: Hashable) -> None: if key in self: del self._data[key] else: - raise KeyError(f"{key!r} is not a coordinate variable.") + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) def _ipython_key_completions_(self): """Provide method for the key-autocompletions in IPython.""" @@ -343,11 +796,12 @@ def _ipython_key_completions_(self): ] -class DataArrayCoordinates(Coordinates["T_DataArray"]): - """Dictionary like container for DataArray coordinates. +class DataArrayCoordinates(Coordinates, Generic[T_DataArray]): + """Dictionary like container for DataArray coordinates (variables + indexes). - Essentially a dict with keys given by the array's - dimensions and the values given by corresponding DataArray objects. + This collection can be passed directly to the :py:class:`~xarray.Dataset` + and :py:class:`~xarray.DataArray` constructors via their `coords` argument. + This will add both the coordinates variables and their index. """ _data: T_DataArray @@ -398,13 +852,11 @@ def _update_coords( original_indexes.update(indexes) self._data._indexes = original_indexes - def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None: - """Drops variables in coords, and any associated variables as well.""" - variables, indexes = drop_coords( - coords, self._data._coords, self._data.xindexes - ) - self._data._coords = variables - self._data._indexes = indexes + def _drop_coords(self, coord_names): + # should drop indexed coordinates only + for name in coord_names: + del self._data._coords[name] + del self._data._indexes[name] @property def variables(self): @@ -419,7 +871,9 @@ def to_dataset(self) -> Dataset: def __delitem__(self, key: Hashable) -> None: if key not in self: - raise KeyError(f"{key!r} is not a coordinate variable.") + raise KeyError( + f"{key!r} is not in coordinate variables {tuple(self.keys())}" + ) assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] @@ -431,40 +885,51 @@ def _ipython_key_completions_(self): return self._data._ipython_key_completions_() -def drop_coords( - coords_to_drop: set[Hashable], variables, indexes: Indexes -) -> tuple[dict, dict]: - """Drop index variables associated with variables in coords_to_drop.""" - # Only warn when we're dropping the dimension with the multi-indexed coordinate - # If asked to drop a subset of the levels in a multi-index, we raise an error - # later but skip the warning here. - new_variables = dict(variables.copy()) - new_indexes = dict(indexes.copy()) - for key in coords_to_drop & set(indexes): - maybe_midx = indexes[key] - idx_coord_names = set(indexes.get_all_coords(key)) - if ( - isinstance(maybe_midx, PandasMultiIndex) - and key == maybe_midx.dim - and (idx_coord_names - coords_to_drop) - ): - warnings.warn( - f"Updating MultiIndexed coordinate {key!r} would corrupt indices for " - f"other variables: {list(maybe_midx.index.names)!r}. " - f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before " +def drop_indexed_coords( + coords_to_drop: set[Hashable], coords: Coordinates +) -> Coordinates: + """Drop indexed coordinates associated with coordinates in coords_to_drop. + + This will raise an error in case it corrupts any passed index and its + coordinate variables. + + """ + new_variables = dict(coords.variables) + new_indexes = dict(coords.xindexes) + + for idx, idx_coords in coords.xindexes.group_by_index(): + idx_drop_coords = set(idx_coords) & coords_to_drop + + # special case for pandas multi-index: still allow but deprecate + # dropping only its dimension coordinate. + # TODO: remove when removing PandasMultiIndex's dimension coordinate. + if isinstance(idx, PandasMultiIndex) and idx_drop_coords == {idx.dim}: + idx_drop_coords.update(idx.index.names) + emit_user_level_warning( + f"updating coordinate {idx.dim!r} with a PandasMultiIndex would leave " + f"the multi-index level coordinates {list(idx.index.names)!r} in an inconsistent state. " + f"This will raise an error in the future. Use `.drop_vars({list(idx_coords)!r})` before " "assigning new coordinate values.", FutureWarning, - stacklevel=4, ) - for k in idx_coord_names: - del new_variables[k] - del new_indexes[k] - return new_variables, new_indexes + elif idx_drop_coords and len(idx_drop_coords) != len(idx_coords): + idx_drop_coords_str = ", ".join(f"{k!r}" for k in idx_drop_coords) + idx_coords_str = ", ".join(f"{k!r}" for k in idx_coords) + raise ValueError( + f"cannot drop or update coordinate(s) {idx_drop_coords_str}, which would corrupt " + f"the following index built from coordinates {idx_coords_str}:\n" + f"{idx}" + ) -def assert_coordinate_consistent( - obj: T_DataArray | Dataset, coords: Mapping[Any, Variable] -) -> None: + for k in idx_drop_coords: + del new_variables[k] + del new_indexes[k] + + return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes) + + +def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None: """Make sure the dimension coordinate of obj is consistent with coords. obj: DataArray or Dataset @@ -477,3 +942,81 @@ def assert_coordinate_consistent( f"dimension coordinate {k!r} conflicts between " f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" ) + + +def create_coords_with_default_indexes( + coords: Mapping[Any, Any], data_vars: DataVars | None = None +) -> Coordinates: + """Returns a Coordinates object from a mapping of coordinates (arbitrary objects). + + Create default (pandas) indexes for each of the input dimension coordinates. + Extract coordinates from each input DataArray. + + """ + # Note: data_vars is needed here only because a pd.MultiIndex object + # can be promoted as coordinates. + # TODO: It won't be relevant anymore when this behavior will be dropped + # in favor of the more explicit ``Coordinates.from_pandas_multiindex()``. + + from xarray.core.dataarray import DataArray + + all_variables = dict(coords) + if data_vars is not None: + all_variables.update(data_vars) + + indexes: dict[Hashable, Index] = {} + variables: dict[Hashable, Variable] = {} + + # promote any pandas multi-index in data_vars as coordinates + coords_promoted: dict[Hashable, Any] = {} + pd_mindex_keys: list[Hashable] = [] + + for k, v in all_variables.items(): + if isinstance(v, pd.MultiIndex): + coords_promoted[k] = v + pd_mindex_keys.append(k) + elif k in coords: + coords_promoted[k] = v + + if pd_mindex_keys: + pd_mindex_keys_fmt = ",".join([f"'{k}'" for k in pd_mindex_keys]) + emit_user_level_warning( + f"the `pandas.MultiIndex` object(s) passed as {pd_mindex_keys_fmt} coordinate(s) or " + "data variable(s) will no longer be implicitly promoted and wrapped into " + "multiple indexed coordinates in the future " + "(i.e., one coordinate for each multi-index level + one dimension coordinate). " + "If you want to keep this behavior, you need to first wrap it explicitly using " + "`mindex_coords = xarray.Coordinates.from_pandas_multiindex(mindex_obj, 'dim')` " + "and pass it as coordinates, e.g., `xarray.Dataset(coords=mindex_coords)`, " + "`dataset.assign_coords(mindex_coords)` or `dataarray.assign_coords(mindex_coords)`.", + FutureWarning, + ) + + dataarray_coords: list[DataArrayCoordinates] = [] + + for name, obj in coords_promoted.items(): + if isinstance(obj, DataArray): + dataarray_coords.append(obj.coords) + + variable = as_variable(obj, name=name) + + if variable.dims == (name,): + idx, idx_vars = create_default_index_implicit(variable, all_variables) + indexes.update({k: idx for k in idx_vars}) + variables.update(idx_vars) + all_variables.update(idx_vars) + else: + variables[name] = variable + + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + # extract and merge coordinates and indexes from input DataArrays + if dataarray_coords: + prioritized = {k: (v, indexes.get(k, None)) for k, v in variables.items()} + variables, indexes = merge_coordinates_without_align( + dataarray_coords + [new_coords], + prioritized=prioritized, + ) + new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) + + return new_coords diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 24c5f698a27..98ff9002856 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,9 +1,5 @@ from __future__ import annotations -from functools import partial - -from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] - from xarray.core import dtypes, nputils @@ -63,10 +59,11 @@ def push(array, n, axis): """ Dask-aware bottleneck.push """ - import bottleneck import dask.array as da import numpy as np + from xarray.core.duck_array_ops import _push + def _fill_with_last_one(a, b): # cumreduction apply the push func over all the blocks first so, the only missing part is filling # the missing values using the last data of the previous chunk @@ -89,43 +86,10 @@ def _fill_with_last_one(a, b): # The method parameter makes that the tests for python 3.7 fails. return da.reductions.cumreduction( - func=bottleneck.push, + func=_push, binop=_fill_with_last_one, ident=np.nan, x=array, axis=axis, dtype=array.dtype, ) - - -def _first_last_wrapper(array, *, axis, op, keepdims): - return op(array, axis, keepdims=keepdims) - - -def _first_or_last(darray, axis, op): - import dask.array - - # This will raise the same error message seen for numpy - axis = normalize_axis_index(axis, darray.ndim) - - wrapped_op = partial(_first_last_wrapper, op=op) - return dask.array.reduction( - darray, - chunk=wrapped_op, - aggregate=wrapped_op, - axis=axis, - dtype=darray.dtype, - keepdims=False, # match numpy version - ) - - -def nanfirst(darray, axis): - from xarray.core.duck_array_ops import nanfirst - - return _first_or_last(darray, axis, op=nanfirst) - - -def nanlast(darray, axis): - from xarray.core.duck_array_ops import nanlast - - return _first_or_last(darray, axis, op=nanlast) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4161941a190..bc07de8c908 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2,9 +2,19 @@ import datetime import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + TypeVar, + Union, + overload, +) import numpy as np import pandas as pd @@ -23,7 +33,12 @@ from xarray.core.arithmetic import DataArrayArithmetic from xarray.core.common import AbstractArray, DataWithCoords, get_chunksizes from xarray.core.computation import unify_chunks -from xarray.core.coordinates import DataArrayCoordinates, assert_coordinate_consistent +from xarray.core.coordinates import ( + Coordinates, + DataArrayCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) from xarray.core.dataset import Dataset from xarray.core.formatting import format_item from xarray.core.indexes import ( @@ -34,14 +49,22 @@ isel_indexes, ) from xarray.core.indexing import is_fancy_indexer, map_index_queries -from xarray.core.merge import PANDAS_TYPES, MergeError, _create_indexes_from_coords +from xarray.core.merge import PANDAS_TYPES, MergeError from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import ( + DaCompatible, + T_DataArray, + T_DataArrayOrSet, + ZarrWriteModes, +) from xarray.core.utils import ( Default, HybridMappingProxy, ReprObject, _default, either_dict_or_kwargs, + hashable, + infix_dims, ) from xarray.core.variable import ( IndexVariable, @@ -51,25 +74,15 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs +from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims if TYPE_CHECKING: - from typing import TypeVar, Union - + from dask.dataframe import DataFrame as DaskDataFrame + from dask.delayed import Delayed + from iris.cube import Cube as iris_Cube from numpy.typing import ArrayLike - try: - from dask.delayed import Delayed - except ImportError: - Delayed = None # type: ignore - try: - from cdms2 import Variable as cdms2_Variable - except ImportError: - cdms2_Variable = None - try: - from iris.cube import Cube as iris_Cube - except ImportError: - iris_Cube = None - + from xarray.backends import ZarrStore from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes from xarray.core.groupby import DataArrayGroupBy from xarray.core.resample import DataArrayResample @@ -88,18 +101,41 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + Self, SideOptions, - T_DataArray, + T_Chunks, T_Xarray, ) from xarray.core.weighted import DataArrayWeighted + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) +def _check_coords_dims(shape, coords, dim): + sizes = dict(zip(dim, shape)) + for k, v in coords.items(): + if any(d not in dim for d in v.dims): + raise ValueError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dim}" + ) + + for d, s in v.sizes.items(): + if s != sizes[d]: + raise ValueError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + + def _infer_coords_and_dims( - shape, coords, dims -) -> tuple[dict[Hashable, Variable], tuple[Hashable, ...]]: + shape: tuple[int, ...], + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + dims: str | Iterable[Hashable] | None, +) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: """All the logic for creating a new DataArray""" if ( @@ -115,8 +151,7 @@ def _infer_coords_and_dims( if isinstance(dims, str): dims = (dims,) - - if dims is None: + elif dims is None: dims = [f"dim_{n}" for n in range(len(shape))] if coords is not None and len(coords) == len(shape): # try to infer dimensions from coords @@ -126,56 +161,41 @@ def _infer_coords_and_dims( for n, (dim, coord) in enumerate(zip(dims, coords)): coord = as_variable(coord, name=dims[n]).to_index_variable() dims[n] = coord.name - dims = tuple(dims) - elif len(dims) != len(shape): + dims_tuple = tuple(dims) + if len(dims_tuple) != len(shape): raise ValueError( "different number of dimensions on data " - f"and dims: {len(shape)} vs {len(dims)}" + f"and dims: {len(shape)} vs {len(dims_tuple)}" ) - else: - for d in dims: - if not isinstance(d, str): - raise TypeError(f"dimension {d} is not a string") - - new_coords: dict[Hashable, Variable] = {} - - if utils.is_dict_like(coords): - for k, v in coords.items(): - new_coords[k] = as_variable(v, name=k) - elif coords is not None: - for dim, coord in zip(dims, coords): - var = as_variable(coord, name=dim) - var.dims = (dim,) - new_coords[dim] = var.to_index_variable() - - sizes = dict(zip(dims, shape)) - for k, v in new_coords.items(): - if any(d not in dims for d in v.dims): - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dims}" - ) + for d in dims_tuple: + if not hashable(d): + raise TypeError(f"Dimension {d} is not hashable") - for d, s in zip(v.dims, v.shape): - if s != sizes[d]: - raise ValueError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) + new_coords: Mapping[Hashable, Any] - if k in sizes and v.shape != (sizes[k],): - raise ValueError( - f"coordinate {k!r} is a DataArray dimension, but " - f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " - "matching the dimension size" - ) + if isinstance(coords, Coordinates): + new_coords = coords + else: + new_coords = {} + if utils.is_dict_like(coords): + for k, v in coords.items(): + new_coords[k] = as_variable(v, name=k) + elif coords is not None: + for dim, coord in zip(dims_tuple, coords): + var = as_variable(coord, name=dim) + var.dims = (dim,) + new_coords[dim] = var.to_index_variable() + + _check_coords_dims(shape, new_coords, dims_tuple) - return new_coords, dims + return new_coords, dims_tuple -def _check_data_shape(data, coords, dims): +def _check_data_shape( + data: Any, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + dims: str | Iterable[Hashable] | None, +) -> Any: if data is dtypes.NA: data = np.nan if coords is not None and utils.is_scalar(data, include_0d=False): @@ -193,13 +213,13 @@ def _check_data_shape(data, coords, dims): return data -class _LocIndexer: +class _LocIndexer(Generic[T_DataArray]): __slots__ = ("data_array",) - def __init__(self, data_array: DataArray): + def __init__(self, data_array: T_DataArray): self.data_array = data_array - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> T_DataArray: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) @@ -259,7 +279,7 @@ class DataArray( or pandas object, attempts are made to use this array's metadata to fill in other unspecified arguments. A view of the array's data is used instead of a copy if possible. - coords : sequence or dict of array_like, optional + coords : sequence or dict of array_like or :py:class:`~xarray.Coordinates`, optional Coordinates (tick labels) to use for indexing along each dimension. The following notations are accepted: @@ -279,6 +299,10 @@ class DataArray( - mapping {coord name: (dimension name, array-like)} - mapping {coord name: (tuple of dimension names, array-like)} + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. dims : Hashable or sequence of Hashable, optional Name(s) of the data dimension(s). Must be either a Hashable (only for 1D data) or a sequence of Hashables with length equal @@ -290,6 +314,11 @@ class DataArray( attrs : dict_like or None, optional Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. + indexes : py:class:`~xarray.Indexes` or dict-like, optional + For internal use only. For passing indexes objects to the + new DataArray, use the ``coords`` argument instead with a + :py:class:`~xarray.Coordinate` object (both coordinate variables + and indexes will be extracted from the latter). Examples -------- @@ -319,17 +348,17 @@ class DataArray( ... ), ... ) >>> da - + Size: 96B array([[[29.11241877, 18.20125767, 22.82990387], [32.92714559, 29.94046392, 7.18177696]], [[22.60070734, 13.78914233, 14.17424919], [18.28478802, 16.15234857, 26.63418806]]]) Coordinates: - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Attributes: description: Ambient temperature. @@ -338,13 +367,13 @@ class DataArray( Find out where the coldest temperature was: >>> da.isel(da.argmin(...)) - + Size: 8B array(7.18177696) Coordinates: - lon float64 -99.32 - lat float64 42.21 - time datetime64[ns] 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon float64 8B -99.32 + lat float64 8B 42.21 + time datetime64[ns] 8B 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Attributes: description: Ambient temperature. units: degC @@ -372,14 +401,12 @@ class DataArray( def __init__( self, data: Any = dtypes.NA, - coords: Sequence[Sequence[Any] | pd.Index | DataArray] - | Mapping[Any, Any] - | None = None, - dims: Hashable | Sequence[Hashable] | None = None, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + dims: str | Iterable[Hashable] | None = None, name: Hashable | None = None, attrs: Mapping | None = None, # internal parameters - indexes: dict[Hashable, Index] | None = None, + indexes: Mapping[Any, Index] | None = None, fastpath: bool = False, ) -> None: if fastpath: @@ -388,10 +415,11 @@ def __init__( assert attrs is None assert indexes is not None else: - # TODO: (benbovy - explicit indexes) remove - # once it becomes part of the public interface if indexes is not None: - raise ValueError("Providing explicit indexes is not supported yet") + raise ValueError( + "Explicitly passing indexes via the `indexes` argument is not supported " + "when `fastpath=False`. Use the `coords` argument instead." + ) # try to fill in arguments from data if they weren't supplied if coords is None: @@ -415,28 +443,29 @@ def __init__( data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) variable = Variable(dims, data, attrs, fastpath=True) - indexes, coords = _create_indexes_from_coords(coords) + + if not isinstance(coords, Coordinates): + coords = create_coords_with_default_indexes(coords) + indexes = dict(coords.xindexes) + coords = {k: v.copy() for k, v in coords.variables.items()} # These fully describe a DataArray self._variable = variable assert isinstance(coords, dict) self._coords = coords self._name = name - - # TODO(shoyer): document this argument, once it becomes part of the - # public interface. - self._indexes = indexes + self._indexes = indexes # type: ignore[assignment] self._close = None @classmethod def _construct_direct( - cls: type[T_DataArray], + cls, variable: Variable, coords: dict[Any, Variable], name: Hashable, indexes: dict[Hashable, Index], - ) -> T_DataArray: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -449,12 +478,12 @@ def _construct_direct( return obj def _replace( - self: T_DataArray, + self, variable: Variable | None = None, coords=None, name: Hashable | None | Default = _default, indexes=None, - ) -> T_DataArray: + ) -> Self: if variable is None: variable = self.variable if coords is None: @@ -466,10 +495,10 @@ def _replace( return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( - self: T_DataArray, + self, variable: Variable, name: Hashable | None | Default = _default, - ) -> T_DataArray: + ) -> Self: if variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() indexes = self._indexes @@ -491,18 +520,18 @@ def _replace_maybe_drop_dims( return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( - self: T_DataArray, + self, indexes: Mapping[Any, Index], - coords: Mapping[Any, Variable] | None = None, + variables: Mapping[Any, Variable] | None = None, drop_coords: list[Hashable] | None = None, rename_dims: Mapping[Any, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Maybe replace indexes and their corresponding coordinates.""" if not indexes: return self - if coords is None: - coords = {} + if variables is None: + variables = {} if drop_coords is None: drop_coords = [] @@ -511,7 +540,7 @@ def _overwrite_indexes( new_indexes = dict(self._indexes) for name in indexes: - new_coords[name] = coords[name] + new_coords[name] = variables[name] new_indexes[name] = indexes[name] for name in drop_coords: @@ -529,8 +558,8 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self: T_DataArray, dataset: Dataset, name: Hashable | None | Default = _default - ) -> T_DataArray: + self, dataset: Dataset, name: Hashable | None | Default = _default + ) -> Self: variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes @@ -544,9 +573,24 @@ def subset(dim, label): array.attrs = {} return as_variable(array) - variables = {label: subset(dim, label) for label in self.get_index(dim)} - variables.update({k: v for k, v in self._coords.items() if k != dim}) + variables_from_split = { + label: subset(dim, label) for label in self.get_index(dim) + } coord_names = set(self._coords) - {dim} + + ambiguous_vars = set(variables_from_split) & coord_names + if ambiguous_vars: + rename_msg_fmt = ", ".join([f"{v}=..." for v in sorted(ambiguous_vars)]) + raise ValueError( + f"Splitting along the dimension {dim!r} would produce the variables " + f"{tuple(sorted(ambiguous_vars))} which are also existing coordinate " + f"variables. Use DataArray.rename({rename_msg_fmt}) or " + f"DataArray.assign_coords({dim}=...) to resolve this ambiguity." + ) + + variables = variables_from_split | { + k: v for k, v in self._coords.items() if k != dim + } indexes = filter_indexes_from_coords(self._indexes, coord_names) dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs @@ -742,7 +786,7 @@ def to_numpy(self) -> np.ndarray: """ return self.variable.to_numpy() - def as_numpy(self: T_DataArray) -> T_DataArray: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. @@ -797,7 +841,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: + def _getitem_coord(self, key: Any) -> Self: from xarray.core.dataset import _get_virtual_variable try: @@ -808,7 +852,7 @@ def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: return self._replace_maybe_drop_dims(var, name=key) - def __getitem__(self: T_DataArray, key: Any) -> T_DataArray: + def __getitem__(self, key: Any) -> Self: if isinstance(key, str): return self._getitem_coord(key) else: @@ -825,6 +869,7 @@ def __setitem__(self, key: Any, value: Any) -> None: obj = self[key] if isinstance(value, DataArray): assert_coordinate_consistent(value, obj.coords.variables) + value = value.variable # DataArray key -> Variable key key = { k: v.variable if isinstance(v, DataArray) else v @@ -877,6 +922,18 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = dict(value) + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: + """Return a new DataArray without encoding on the array or any attached + coords.""" + ds = self._to_temp_dataset().drop_encoding() + return self._from_temp_dataset(ds) + @property def indexes(self) -> Indexes: """Mapping of pandas.Index objects used for label based indexing. @@ -893,36 +950,45 @@ def indexes(self) -> Indexes: @property def xindexes(self) -> Indexes: - """Mapping of xarray Index objects used for label based indexing.""" + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes}) @property def coords(self) -> DataArrayCoordinates: - """Dictionary-like container of coordinate arrays.""" + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates + """ return DataArrayCoordinates(self) @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: Literal[False] = False, - ) -> Dataset: - ... + ) -> Dataset: ... @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, *, drop: Literal[True], - ) -> T_DataArray: - ... + ) -> Self: ... + @_deprecate_positional_args("v2023.10.0") def reset_coords( - self: T_DataArray, + self, names: Dims = None, + *, drop: bool = False, - ) -> T_DataArray | Dataset: + ) -> Self | Dataset: """Given names of coordinates, reset them to become variables. Parameters @@ -953,43 +1019,43 @@ def reset_coords( ... name="Temperature", ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - lon (x) int64 10 11 12 13 14 - lat (y) int64 20 21 22 23 24 - Pressure (x, y) int64 50 51 52 53 54 55 56 57 ... 67 68 69 70 71 72 73 74 + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 + Pressure (x, y) int64 200B 50 51 52 53 54 55 56 57 ... 68 69 70 71 72 73 74 Dimensions without coordinates: x, y Return Dataset with target coordinate as a data variable rather than a coordinate variable: >>> da.reset_coords(names="Pressure") - + Size: 480B Dimensions: (x: 5, y: 5) Coordinates: - lon (x) int64 10 11 12 13 14 - lat (y) int64 20 21 22 23 24 + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 Dimensions without coordinates: x, y Data variables: - Pressure (x, y) int64 50 51 52 53 54 55 56 57 ... 68 69 70 71 72 73 74 - Temperature (x, y) int64 0 1 2 3 4 5 6 7 8 9 ... 16 17 18 19 20 21 22 23 24 + Pressure (x, y) int64 200B 50 51 52 53 54 55 56 ... 68 69 70 71 72 73 74 + Temperature (x, y) int64 200B 0 1 2 3 4 5 6 7 8 ... 17 18 19 20 21 22 23 24 Return DataArray without targeted coordinate: >>> da.reset_coords(names="Pressure", drop=True) - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - lon (x) int64 10 11 12 13 14 - lat (y) int64 20 21 22 23 24 + lon (x) int64 40B 10 11 12 13 14 + lat (y) int64 40B 20 21 22 23 24 Dimensions without coordinates: x, y """ if names is None: @@ -1004,7 +1070,7 @@ def reset_coords( dataset[self.name] = self.variable return dataset - def __dask_tokenize__(self): + def __dask_tokenize__(self) -> object: from dask.base import normalize_token return normalize_token((type(self), self._variable, self._coords, self._name)) @@ -1034,15 +1100,15 @@ def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() return self._dask_finalize, (self.name, func) + args - @staticmethod - def _dask_finalize(results, name, func, *args, **kwargs) -> DataArray: + @classmethod + def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables indexes = ds._indexes - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return cls(variable, coords, name=name, indexes=indexes, fastpath=True) - def load(self: T_DataArray, **kwargs) -> T_DataArray: + def load(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -1066,7 +1132,7 @@ def load(self: T_DataArray, **kwargs) -> T_DataArray: self._coords = new._coords return self - def compute(self: T_DataArray, **kwargs) -> T_DataArray: + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -1088,7 +1154,7 @@ def compute(self: T_DataArray, **kwargs) -> T_DataArray: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataArray, **kwargs) -> T_DataArray: + def persist(self, **kwargs) -> Self: """Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in @@ -1107,7 +1173,7 @@ def persist(self: T_DataArray, **kwargs) -> T_DataArray: ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: + def copy(self, deep: bool = True, data: Any = None) -> Self: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -1139,37 +1205,37 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: >>> array = xr.DataArray([1, 2, 3], dims="x", coords={"x": ["a", "b", "c"]}) >>> array.copy() - + Size: 24B array([1, 2, 3]) Coordinates: - * x (x) >> array_0 = array.copy(deep=False) >>> array_0[0] = 7 >>> array_0 - + Size: 24B array([7, 2, 3]) Coordinates: - * x (x) >> array - + Size: 24B array([7, 2, 3]) Coordinates: - * x (x) >> array.copy(data=[0.1, 0.2, 0.3]) - + Size: 24B array([0.1, 0.2, 0.3]) Coordinates: - * x (x) >> array - + Size: 24B array([7, 2, 3]) Coordinates: - * x (x) T_DataArray: return self._copy(deep=deep, data=data) def _copy( - self: T_DataArray, + self, deep: bool = True, data: Any = None, memo: dict[int, Any] | None = None, - ) -> T_DataArray: + ) -> Self: variable = self.variable._copy(deep=deep, data=data, memo=memo) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1195,12 +1261,10 @@ def _copy( return self._replace(variable, coords, indexes=indexes) - def __copy__(self: T_DataArray) -> T_DataArray: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__( - self: T_DataArray, memo: dict[int, Any] | None = None - ) -> T_DataArray: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) # mutable objects should not be Hashable @@ -1240,21 +1304,19 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: all_variables = [self.variable] + [c.variable for c in self.coords.values()] return get_chunksizes(all_variables) + @_deprecate_positional_args("v2023.10.0") def chunk( - self: T_DataArray, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + *, name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, **chunks_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Coerce this array's data into a dask arrays with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1274,12 +1336,21 @@ def chunk( Prefix for the name of the new dask array. token : str, optional Token uniquely identifying this array. - lock : optional + lock : bool, default: False Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. - inline_array: optional + inline_array: bool, default: False Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce the underlying data array to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. @@ -1305,8 +1376,13 @@ def chunk( if isinstance(chunks, (float, str, int)): # ignoring type; unclear why it won't accept a Literal into the value. - chunks = dict.fromkeys(self.dims, chunks) # type: ignore + chunks = dict.fromkeys(self.dims, chunks) elif isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimension names as keys.", + category=DeprecationWarning, + ) chunks = dict(zip(self.dims, chunks)) else: chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") @@ -1317,16 +1393,18 @@ def chunk( token=token, lock=lock, inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, ) return self._from_temp_dataset(ds) def isel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting indexes along the specified dimension(s). @@ -1360,11 +1438,17 @@ def isel( Dataset.isel DataArray.sel + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + Examples -------- >>> da = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], @@ -1376,7 +1460,7 @@ def isel( >>> tgt_y = xr.DataArray(np.arange(0, 5), dims="points") >>> da = da.isel(x=tgt_x, y=tgt_y) >>> da - + Size: 40B array([ 0, 6, 12, 18, 24]) Dimensions without coordinates: points """ @@ -1412,13 +1496,13 @@ def isel( return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance=None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -1492,6 +1576,12 @@ def sel( Dataset.sel DataArray.isel + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + Examples -------- >>> da = xr.DataArray( @@ -1500,25 +1590,25 @@ def sel( ... dims=("x", "y"), ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - * x (x) int64 0 1 2 3 4 - * y (y) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 + * y (y) int64 40B 0 1 2 3 4 >>> tgt_x = xr.DataArray(np.linspace(0, 4, num=5), dims="points") >>> tgt_y = xr.DataArray(np.linspace(0, 4, num=5), dims="points") >>> da = da.sel(x=tgt_x, y=tgt_y, method="nearest") >>> da - + Size: 40B array([ 0, 6, 12, 18, 24]) Coordinates: - x (points) int64 0 1 2 3 4 - y (points) int64 0 1 2 3 4 + x (points) int64 40B 0 1 2 3 4 + y (points) int64 40B 0 1 2 3 4 Dimensions without coordinates: points """ ds = self._to_temp_dataset().sel( @@ -1531,10 +1621,10 @@ def sel( return self._from_temp_dataset(ds) def head( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the first `n` values along the specified dimension(s). Default `n` = 5 @@ -1551,7 +1641,7 @@ def head( ... dims=("x", "y"), ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], @@ -1560,12 +1650,12 @@ def head( Dimensions without coordinates: x, y >>> da.head(x=1) - + Size: 40B array([[0, 1, 2, 3, 4]]) Dimensions without coordinates: x, y >>> da.head({"x": 2, "y": 2}) - + Size: 32B array([[0, 1], [5, 6]]) Dimensions without coordinates: x, y @@ -1574,10 +1664,10 @@ def head( return self._from_temp_dataset(ds) def tail( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the last `n` values along the specified dimension(s). Default `n` = 5 @@ -1594,7 +1684,7 @@ def tail( ... dims=("x", "y"), ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], @@ -1603,7 +1693,7 @@ def tail( Dimensions without coordinates: x, y >>> da.tail(y=1) - + Size: 40B array([[ 4], [ 9], [14], @@ -1612,7 +1702,7 @@ def tail( Dimensions without coordinates: x, y >>> da.tail({"x": 2, "y": 2}) - + Size: 32B array([[18, 19], [23, 24]]) Dimensions without coordinates: x, y @@ -1621,10 +1711,10 @@ def tail( return self._from_temp_dataset(ds) def thin( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by each `n` value along the specified dimension(s). @@ -1640,26 +1730,26 @@ def thin( ... coords={"x": [0, 1], "y": np.arange(0, 13)}, ... ) >>> x - + Size: 208B array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]]) Coordinates: - * x (x) int64 0 1 - * y (y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 + * x (x) int64 16B 0 1 + * y (y) int64 104B 0 1 2 3 4 5 6 7 8 9 10 11 12 >>> >>> x.thin(3) - + Size: 40B array([[ 0, 3, 6, 9, 12]]) Coordinates: - * x (x) int64 0 - * y (y) int64 0 3 6 9 12 + * x (x) int64 8B 0 + * y (y) int64 40B 0 3 6 9 12 >>> x.thin({"x": 2, "y": 5}) - + Size: 24B array([[ 0, 5, 10]]) Coordinates: - * x (x) int64 0 - * y (y) int64 0 5 10 + * x (x) int64 8B 0 + * y (y) int64 24B 0 5 10 See Also -------- @@ -1670,11 +1760,13 @@ def thin( ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def broadcast_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, exclude: Iterable[Hashable] | None = None, - ) -> T_DataArray: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -1713,28 +1805,28 @@ def broadcast_like( ... coords={"x": ["a", "b", "c"], "y": ["a", "b"]}, ... ) >>> arr1 - + Size: 48B array([[ 1.76405235, 0.40015721, 0.97873798], [ 2.2408932 , 1.86755799, -0.97727788]]) Coordinates: - * x (x) >> arr2 - + Size: 48B array([[ 0.95008842, -0.15135721], [-0.10321885, 0.4105985 ], [ 0.14404357, 1.45427351]]) Coordinates: - * x (x) >> arr1.broadcast_like(arr2) - + Size: 72B array([[ 1.76405235, 0.40015721, 0.97873798], [ 2.2408932 , 1.86755799, -0.97727788], [ nan, nan, nan]]) Coordinates: - * x (x) T_DataArray: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed DataArray.""" if isinstance(fill_value, dict): @@ -1777,18 +1867,26 @@ def _reindex_callback( exclude_dims, exclude_vars, ) - return self._from_temp_dataset(reindexed) + da = self._from_temp_dataset(reindexed) + da.encoding = self.encoding + + return da + + @_deprecate_positional_args("v2023.10.0") def reindex_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, + *, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value=dtypes.NA, - ) -> T_DataArray: - """Conform this object onto the indexes of another object, filling in - missing values with ``fill_value``. The default fill value is NaN. + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. Parameters ---------- @@ -1841,103 +1939,116 @@ def reindex_like( ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, ... ) >>> da1 - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) int64 10 20 30 40 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 >>> da2 = xr.DataArray( ... data=data, ... dims=["x", "y"], ... coords={"x": [40, 30, 20, 10], "y": [90, 80, 70]}, ... ) >>> da2 - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) int64 40 30 20 10 - * y (y) int64 90 80 70 + * x (x) int64 32B 40 30 20 10 + * y (y) int64 24B 90 80 70 Reindexing with both DataArrays having the same coordinates set, but in different order: >>> da1.reindex_like(da2) - + Size: 96B array([[11, 10, 9], [ 8, 7, 6], [ 5, 4, 3], [ 2, 1, 0]]) Coordinates: - * x (x) int64 40 30 20 10 - * y (y) int64 90 80 70 + * x (x) int64 32B 40 30 20 10 + * y (y) int64 24B 90 80 70 - Reindexing with the other array having coordinates which the source array doesn't have: + Reindexing with the other array having additional coordinates: - >>> data = np.arange(12).reshape(4, 3) - >>> da1 = xr.DataArray( - ... data=data, - ... dims=["x", "y"], - ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, - ... ) - >>> da2 = xr.DataArray( + >>> da3 = xr.DataArray( ... data=data, ... dims=["x", "y"], ... coords={"x": [20, 10, 29, 39], "y": [70, 80, 90]}, ... ) - >>> da1.reindex_like(da2) - + >>> da1.reindex_like(da3) + Size: 96B array([[ 3., 4., 5.], [ 0., 1., 2.], [nan, nan, nan], [nan, nan, nan]]) Coordinates: - * x (x) int64 20 10 29 39 - * y (y) int64 70 80 90 + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 Filling missing values with the previous valid index with respect to the coordinates' value: - >>> da1.reindex_like(da2, method="ffill") - + >>> da1.reindex_like(da3, method="ffill") + Size: 96B array([[3, 4, 5], [0, 1, 2], [3, 4, 5], [6, 7, 8]]) Coordinates: - * x (x) int64 20 10 29 39 - * y (y) int64 70 80 90 + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 Filling missing values while tolerating specified error for inexact matches: - >>> da1.reindex_like(da2, method="ffill", tolerance=5) - + >>> da1.reindex_like(da3, method="ffill", tolerance=5) + Size: 96B array([[ 3., 4., 5.], [ 0., 1., 2.], [nan, nan, nan], [nan, nan, nan]]) Coordinates: - * x (x) int64 20 10 29 39 - * y (y) int64 70 80 90 + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 Filling missing values with manually specified values: - >>> da1.reindex_like(da2, fill_value=19) - + >>> da1.reindex_like(da3, fill_value=19) + Size: 96B array([[ 3, 4, 5], [ 0, 1, 2], [19, 19, 19], [19, 19, 19]]) Coordinates: - * x (x) int64 20 10 29 39 - * y (y) int64 70 80 90 + * x (x) int64 32B 20 10 29 39 + * y (y) int64 24B 70 80 90 + + Note that unlike ``broadcast_like``, ``reindex_like`` doesn't create new dimensions: + + >>> da1.sel(x=20) + Size: 24B + array([3, 4, 5]) + Coordinates: + x int64 8B 20 + * y (y) int64 24B 70 80 90 + + ...so ``b`` in not added here: + + >>> da1.sel(x=20).reindex_like(da1) + Size: 24B + array([3, 4, 5]) + Coordinates: + x int64 8B 20 + * y (y) int64 24B 70 80 90 See Also -------- DataArray.reindex + DataArray.broadcast_like align """ return alignment.reindex_like( @@ -1949,15 +2060,17 @@ def reindex_like( fill_value=fill_value, ) + @_deprecate_positional_args("v2023.10.0") def reindex( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, + *, method: ReindexMethodOptions = None, tolerance: float | Iterable[float] | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2015,15 +2128,15 @@ def reindex( ... dims="lat", ... ) >>> da - + Size: 32B array([0, 1, 2, 3]) Coordinates: - * lat (lat) int64 90 89 88 87 + * lat (lat) int64 32B 90 89 88 87 >>> da.reindex(lat=da.lat[::-1]) - + Size: 32B array([3, 2, 1, 0]) Coordinates: - * lat (lat) int64 87 88 89 90 + * lat (lat) int64 32B 87 88 89 90 See Also -------- @@ -2041,13 +2154,13 @@ def reindex( ) def interp( - self: T_DataArray, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Interpolate a DataArray onto new coordinates Performs univariate or multivariate interpolation of a DataArray onto @@ -2102,6 +2215,9 @@ def interp( scipy.interpolate.interp1d scipy.interpolate.interpn + :doc:`xarray-tutorial:fundamentals/02.2_manipulating_dimensions` + Tutorial material on manipulating data resolution using :py:func:`~xarray.DataArray.interp` + Examples -------- >>> da = xr.DataArray( @@ -2110,37 +2226,37 @@ def interp( ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, ... ) >>> da - + Size: 96B array([[ 1., 4., 2., 9.], [ 2., 7., 6., nan], [ 6., nan, 5., 8.]]) Coordinates: - * x (x) int64 0 1 2 - * y (y) int64 10 12 14 16 + * x (x) int64 24B 0 1 2 + * y (y) int64 32B 10 12 14 16 1D linear interpolation (the default): >>> da.interp(x=[0, 0.75, 1.25, 1.75]) - + Size: 128B array([[1. , 4. , 2. , nan], [1.75, 6.25, 5. , nan], [3. , nan, 5.75, nan], [5. , nan, 5.25, nan]]) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 1D nearest interpolation: >>> da.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") - + Size: 128B array([[ 1., 4., 2., 9.], [ 2., 7., 6., nan], [ 2., 7., 6., nan], [ 6., nan, 5., 8.]]) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 1D linear extrapolation: @@ -2149,31 +2265,30 @@ def interp( ... method="linear", ... kwargs={"fill_value": "extrapolate"}, ... ) - + Size: 128B array([[ 2. , 7. , 6. , nan], [ 4. , nan, 5.5, nan], [ 8. , nan, 4.5, nan], [12. , nan, 3.5, nan]]) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 1.0 1.5 2.5 3.5 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 1.0 1.5 2.5 3.5 2D linear interpolation: >>> da.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") - + Size: 96B array([[2.5 , 3. , nan], [4. , 5.625, nan], [ nan, nan, nan], [ nan, nan, nan]]) Coordinates: - * x (x) float64 0.0 0.75 1.25 1.75 - * y (y) int64 11 13 15 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + * y (y) int64 24B 11 13 15 """ if self.dtype.kind not in "uifc": raise TypeError( - "interp only works for a numeric type array. " - "Given {}.".format(self.dtype) + f"interp only works for a numeric type array. Given {self.dtype}." ) ds = self._to_temp_dataset().interp( coords, @@ -2185,12 +2300,12 @@ def interp( return self._from_temp_dataset(ds) def interp_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling out of range values with NaN. @@ -2240,52 +2355,52 @@ def interp_like( ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, ... ) >>> da1 - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) int64 10 20 30 40 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 >>> da2 = xr.DataArray( ... data=data, ... dims=["x", "y"], ... coords={"x": [10, 20, 29, 39], "y": [70, 80, 90]}, ... ) >>> da2 - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) int64 10 20 29 39 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 29 39 + * y (y) int64 24B 70 80 90 Interpolate the values in the coordinates of the other DataArray with respect to the source's values: >>> da2.interp_like(da1) - + Size: 96B array([[0. , 1. , 2. ], [3. , 4. , 5. ], [6.3, 7.3, 8.3], [nan, nan, nan]]) Coordinates: - * x (x) int64 10 20 30 40 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 Could also extrapolate missing values: >>> da2.interp_like(da1, kwargs={"fill_value": "extrapolate"}) - + Size: 96B array([[ 0. , 1. , 2. ], [ 3. , 4. , 5. ], [ 6.3, 7.3, 8.3], [ 9.3, 10.3, 11.3]]) Coordinates: - * x (x) int64 10 20 30 40 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 Notes ----- @@ -2300,21 +2415,18 @@ def interp_like( """ if self.dtype.kind not in "uifc": raise TypeError( - "interp only works for a numeric type array. " - "Given {}.".format(self.dtype) + f"interp only works for a numeric type array. Given {self.dtype}." ) ds = self._to_temp_dataset().interp_like( other, method=method, kwargs=kwargs, assume_sorted=assume_sorted ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def rename( self, new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> DataArray: + ) -> Self: """Returns a new DataArray with renamed coordinates, dimensions or a new name. Parameters @@ -2355,10 +2467,10 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self: T_DataArray, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs, - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with swapped dimensions. Parameters @@ -2383,25 +2495,25 @@ def swap_dims( ... coords={"x": ["a", "b"], "y": ("x", [0, 1])}, ... ) >>> arr - + Size: 16B array([0, 1]) Coordinates: - * x (x) >> arr.swap_dims({"x": "y"}) - + Size: 16B array([0, 1]) Coordinates: - x (y) >> arr.swap_dims({"x": "z"}) - + Size: 16B array([0, 1]) Coordinates: - x (z) DataArray: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -2462,20 +2572,20 @@ def expand_dims( -------- >>> da = xr.DataArray(np.arange(5), dims=("x")) >>> da - + Size: 40B array([0, 1, 2, 3, 4]) Dimensions without coordinates: x Add new dimension of length 2: >>> da.expand_dims(dim={"y": 2}) - + Size: 80B array([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) Dimensions without coordinates: y, x >>> da.expand_dims(dim={"y": 2}, axis=1) - + Size: 80B array([[0, 0], [1, 1], [2, 2], @@ -2486,14 +2596,14 @@ def expand_dims( Add a new dimension with coordinates from array: >>> da.expand_dims(dim={"y": np.arange(5)}, axis=0) - + Size: 200B array([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) Coordinates: - * y (y) int64 0 1 2 3 4 + * y (y) int64 40B 0 1 2 3 4 Dimensions without coordinates: x """ if isinstance(dim, int): @@ -2503,20 +2613,18 @@ def expand_dims( raise ValueError("dims should not contain duplicate values.") dim = dict.fromkeys(dim, 1) elif dim is not None and not isinstance(dim, Mapping): - dim = {cast(Hashable, dim): 1} + dim = {dim: 1} dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims") ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> DataArray: + ) -> Self: """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -2551,20 +2659,20 @@ def set_index( ... coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, ... ) >>> arr - + Size: 48B array([[1., 1., 1.], [1., 1., 1.]]) Coordinates: - * x (x) int64 0 1 - * y (y) int64 0 1 2 - a (x) int64 3 4 + * x (x) int64 16B 0 1 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 3 4 >>> arr.set_index(x="a") - + Size: 48B array([[1., 1., 1.], [1., 1., 1.]]) Coordinates: - * x (x) int64 3 4 - * y (y) int64 0 1 2 + * x (x) int64 16B 3 4 + * y (y) int64 24B 0 1 2 See Also -------- @@ -2574,13 +2682,11 @@ def set_index( ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], drop: bool = False, - ) -> DataArray: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -2614,11 +2720,11 @@ def reset_index( return self._from_temp_dataset(ds) def set_xindex( - self: T_DataArray, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_DataArray: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -2643,10 +2749,10 @@ def set_xindex( return self._from_temp_dataset(ds) def reorder_levels( - self: T_DataArray, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_DataArray: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -2669,12 +2775,12 @@ def reorder_levels( return self._from_temp_dataset(ds) def stack( - self: T_DataArray, + self, dimensions: Mapping[Any, Sequence[Hashable]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], - ) -> T_DataArray: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -2713,12 +2819,12 @@ def stack( ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], ... ) >>> arr - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) >> stacked = arr.stack(z=("x", "y")) >>> stacked.indexes["z"] MultiIndex([('a', 0), @@ -2741,14 +2847,14 @@ def stack( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, + *, fill_value: Any = dtypes.NA, sparse: bool = False, - ) -> DataArray: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -2780,12 +2886,12 @@ def unstack( ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], ... ) >>> arr - + Size: 48B array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) >> stacked = arr.stack(z=("x", "y")) >>> stacked.indexes["z"] MultiIndex([('a', 0), @@ -2803,7 +2909,7 @@ def unstack( -------- DataArray.stack """ - ds = self._to_temp_dataset().unstack(dim, fill_value, sparse) + ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse) return self._from_temp_dataset(ds) def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset: @@ -2832,19 +2938,19 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data ... ) >>> data = xr.Dataset({"a": arr, "b": arr.isel(y=0)}) >>> data - + Size: 96B Dimensions: (x: 2, y: 3) Coordinates: - * x (x) >> stacked = data.to_stacked_array("z", ["x"]) >>> stacked.indexes["z"] - MultiIndex([('a', 0.0), - ('a', 1.0), - ('a', 2.0), + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), ('b', nan)], name='z') >>> roundtripped = stacked.to_unstacked_dataset(dim="z") @@ -2872,11 +2978,11 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data return Dataset(data_dict) def transpose( - self: T_DataArray, + self, *dims: Hashable, transpose_coords: bool = True, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_DataArray: + ) -> Self: """Return a new DataArray object with transposed dimensions. Parameters @@ -2910,7 +3016,7 @@ def transpose( Dataset.transpose """ if dims: - dims = tuple(utils.infix_dims(dims, self.dims, missing_dims)) + dims = tuple(infix_dims(dims, self.dims, missing_dims)) variable = self.variable.transpose(*dims) if transpose_coords: coords: dict[Hashable, Variable] = {} @@ -2922,23 +3028,22 @@ def transpose( return self._replace(variable) @property - def T(self: T_DataArray) -> T_DataArray: + def T(self) -> Self: return self.transpose() - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def drop_vars( self, - names: Hashable | Iterable[Hashable], + names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]], *, errors: ErrorOptions = "raise", - ) -> DataArray: + ) -> Self: """Returns an array with dropped variables. Parameters ---------- - names : Hashable or iterable of Hashable - Name(s) of variables to drop. + names : Hashable or iterable of Hashable or Callable + Name(s) of variables to drop. If a Callable, this object is passed as its + only argument and its result is used. errors : {"raise", "ignore"}, default: "raise" If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the @@ -2958,46 +3063,56 @@ def drop_vars( ... coords={"x": [10, 20, 30, 40], "y": [70, 80, 90]}, ... ) >>> da - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) int64 10 20 30 40 - * y (y) int64 70 80 90 + * x (x) int64 32B 10 20 30 40 + * y (y) int64 24B 70 80 90 Removing a single variable: >>> da.drop_vars("x") - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * y (y) int64 70 80 90 + * y (y) int64 24B 70 80 90 Dimensions without coordinates: x Removing a list of variables: >>> da.drop_vars(["x", "y"]) - + Size: 96B + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Dimensions without coordinates: x, y + + >>> da.drop_vars(lambda x: x.coords) + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Dimensions without coordinates: x, y """ + if callable(names): + names = names(self) ds = self._to_temp_dataset().drop_vars(names, errors=errors) return self._from_temp_dataset(ds) def drop_indexes( - self: T_DataArray, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_DataArray: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -3018,13 +3133,13 @@ def drop_indexes( return self._from_temp_dataset(ds) def drop( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, dim: Hashable | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -3038,12 +3153,12 @@ def drop( return self._from_temp_dataset(ds) def drop_sel( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Drop index labels from this DataArray. Parameters @@ -3070,34 +3185,34 @@ def drop_sel( ... dims=("x", "y"), ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - * x (x) int64 0 2 4 6 8 - * y (y) int64 0 3 6 9 12 + * x (x) int64 40B 0 2 4 6 8 + * y (y) int64 40B 0 3 6 9 12 >>> da.drop_sel(x=[0, 2], y=9) - + Size: 96B array([[10, 11, 12, 14], [15, 16, 17, 19], [20, 21, 22, 24]]) Coordinates: - * x (x) int64 4 6 8 - * y (y) int64 0 3 6 12 + * x (x) int64 24B 4 6 8 + * y (y) int64 32B 0 3 6 12 >>> da.drop_sel({"x": 6, "y": [0, 3]}) - + Size: 96B array([[ 2, 3, 4], [ 7, 8, 9], [12, 13, 14], [22, 23, 24]]) Coordinates: - * x (x) int64 0 2 4 8 - * y (y) int64 6 9 12 + * x (x) int64 32B 0 2 4 8 + * y (y) int64 24B 6 9 12 """ if labels_kwargs or isinstance(labels, dict): labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") @@ -3106,8 +3221,8 @@ def drop_sel( return self._from_temp_dataset(ds) def drop_isel( - self: T_DataArray, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs - ) -> T_DataArray: + self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs + ) -> Self: """Drop index positions from this DataArray. Parameters @@ -3129,7 +3244,7 @@ def drop_isel( -------- >>> da = xr.DataArray(np.arange(25).reshape(5, 5), dims=("X", "Y")) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], @@ -3138,14 +3253,14 @@ def drop_isel( Dimensions without coordinates: X, Y >>> da.drop_isel(X=[0, 4], Y=2) - + Size: 96B array([[ 5, 6, 8, 9], [10, 11, 13, 14], [15, 16, 18, 19]]) Dimensions without coordinates: X, Y >>> da.drop_isel({"X": 3, "Y": 3}) - + Size: 128B array([[ 0, 1, 2, 4], [ 5, 6, 7, 9], [10, 11, 12, 14], @@ -3156,12 +3271,14 @@ def drop_isel( dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) return self._from_temp_dataset(dataset) + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_DataArray, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, - ) -> T_DataArray: + ) -> Self: """Returns a new array with dropped labels for missing values along the provided dimension. @@ -3198,41 +3315,41 @@ def dropna( ... ), ... ) >>> da - + Size: 128B array([[ 0., 4., 2., 9.], [nan, nan, nan, nan], [nan, 4., 2., 0.], [ 3., 1., 0., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 - lon (X) float64 10.0 10.25 10.5 10.75 + lat (Y) float64 32B -20.0 -20.25 -20.5 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 Dimensions without coordinates: Y, X >>> da.dropna(dim="Y", how="any") - + Size: 64B array([[0., 4., 2., 9.], [3., 1., 0., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.75 - lon (X) float64 10.0 10.25 10.5 10.75 + lat (Y) float64 16B -20.0 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 Dimensions without coordinates: Y, X Drop values only if all values along the dimension are NaN: >>> da.dropna(dim="Y", how="all") - + Size: 96B array([[ 0., 4., 2., 9.], [nan, 4., 2., 0.], [ 3., 1., 0., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.5 -20.75 - lon (X) float64 10.0 10.25 10.5 10.75 + lat (Y) float64 24B -20.0 -20.5 -20.75 + lon (X) float64 32B 10.0 10.25 10.5 10.75 Dimensions without coordinates: Y, X """ ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) return self._from_temp_dataset(ds) - def fillna(self: T_DataArray, value: Any) -> T_DataArray: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3262,29 +3379,29 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray: ... ), ... ) >>> da - + Size: 48B array([ 1., 4., nan, 0., 3., nan]) Coordinates: - * Z (Z) int64 0 1 2 3 4 5 - height (Z) int64 0 10 20 30 40 50 + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 Fill all NaN values with 0: >>> da.fillna(0) - + Size: 48B array([1., 4., 0., 0., 3., 0.]) Coordinates: - * Z (Z) int64 0 1 2 3 4 5 - height (Z) int64 0 10 20 30 40 50 + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 Fill NaN values with corresponding values in array: >>> da.fillna(np.array([2, 9, 4, 2, 8, 9])) - + Size: 48B array([1., 4., 4., 0., 3., 9.]) Coordinates: - * Z (Z) int64 0 1 2 3 4 5 - height (Z) int64 0 10 20 30 40 50 + * Z (Z) int64 48B 0 1 2 3 4 5 + height (Z) int64 48B 0 10 20 30 40 50 """ if utils.is_dict_like(value): raise TypeError( @@ -3295,7 +3412,7 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray: return out def interpolate_na( - self: T_DataArray, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -3311,7 +3428,7 @@ def interpolate_na( ) = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -3319,7 +3436,7 @@ def interpolate_na( dim : Hashable or None, optional Specifies the dimension along which to interpolate. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3328,15 +3445,15 @@ def interpolate_na( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. use_coordinate : bool or str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if - eqaully-spaced along ``dim``. If True, the IndexVariable `dim` is + equally-spaced along ``dim``. If True, the IndexVariable `dim` is used. If ``use_coordinate`` is a string, it specifies the name of a - coordinate variariable to use as the index. + coordinate variable to use as the index. limit : int or None, default: None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. This filling is done regardless of the size of @@ -3388,22 +3505,22 @@ def interpolate_na( ... [np.nan, 2, 3, np.nan, 0], dims="x", coords={"x": [0, 1, 2, 3, 4]} ... ) >>> da - + Size: 40B array([nan, 2., 3., nan, 0.]) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 >>> da.interpolate_na(dim="x", method="linear") - + Size: 40B array([nan, 2. , 3. , 1.5, 0. ]) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 >>> da.interpolate_na(dim="x", method="linear", fill_value="extrapolate") - + Size: 40B array([1. , 2. , 3. , 1.5, 0. ]) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 """ from xarray.core.missing import interp_na @@ -3418,9 +3535,7 @@ def interpolate_na( **kwargs, ) - def ffill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -3461,52 +3576,50 @@ def ffill( ... ), ... ) >>> da - + Size: 120B array([[nan, 1., 3.], [ 0., nan, 5.], [ 5., nan, nan], [ 3., nan, nan], [ 0., 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X Fill all NaN values: >>> da.ffill(dim="Y", limit=None) - + Size: 120B array([[nan, 1., 3.], [ 0., 1., 5.], [ 5., 1., 5.], [ 3., 1., 5.], [ 0., 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X Fill only the first of consecutive NaN values: >>> da.ffill(dim="Y", limit=1) - + Size: 120B array([[nan, 1., 3.], [ 0., 1., 5.], [ 5., nan, 5.], [ 3., nan, nan], [ 0., 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X """ from xarray.core.missing import ffill return ffill(self, dim, limit=limit) - def bfill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -3547,50 +3660,50 @@ def bfill( ... ), ... ) >>> da - + Size: 120B array([[ 0., 1., 3.], [ 0., nan, 5.], [ 5., nan, nan], [ 3., nan, nan], [nan, 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X Fill all NaN values: >>> da.bfill(dim="Y", limit=None) - + Size: 120B array([[ 0., 1., 3.], [ 0., 2., 5.], [ 5., 2., 0.], [ 3., 2., 0.], [nan, 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X Fill only the first of consecutive NaN values: >>> da.bfill(dim="Y", limit=1) - + Size: 120B array([[ 0., 1., 3.], [ 0., nan, 5.], [ 5., nan, nan], [ 3., 2., 0.], [nan, 2., 0.]]) Coordinates: - lat (Y) float64 -20.0 -20.25 -20.5 -20.75 -21.0 - lon (X) float64 10.0 10.25 10.5 + lat (Y) float64 40B -20.0 -20.25 -20.5 -20.75 -21.0 + lon (X) float64 24B 10.0 10.25 10.5 Dimensions without coordinates: Y, X """ from xarray.core.missing import bfill return bfill(self, dim, limit=limit) - def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. This operation follows the normal broadcasting and alignment rules of @@ -3609,7 +3722,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: return ops.fillna(self, other, join="outer") def reduce( - self: T_DataArray, + self, func: Callable[..., Any], dim: Dims = None, *, @@ -3617,7 +3730,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Reduce this array by applying `func` along some dimension(s). Parameters @@ -3655,7 +3768,7 @@ def reduce( var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) - def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame: + def to_pandas(self) -> Self | pd.Series | pd.DataFrame: """Convert this array into a pandas object with the same shape. The type of the returned object depends on the number of DataArray @@ -3801,8 +3914,23 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, invalid_netcdf: bool = False, - ) -> bytes: - ... + ) -> bytes: ... + + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: ... # default return None @overload @@ -3817,10 +3945,10 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] | None = None, compute: Literal[True] = True, invalid_netcdf: bool = False, - ) -> None: - ... + ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -3831,11 +3959,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: - ... + ) -> Delayed | None: ... def to_netcdf( self, @@ -3849,7 +3975,7 @@ def to_netcdf( compute: bool = True, invalid_netcdf: bool = False, ) -> bytes | Delayed | None: - """Write dataset contents to a netCDF file. + """Write DataArray contents to a netCDF file. Parameters ---------- @@ -3931,6 +4057,9 @@ def to_netcdf( name is the same as a coordinate name, then it is given the name ``"__xarray_dataarray_variable__"``. + [netCDF4 backend only] netCDF4 enums are decoded into the + dataarray dtype metadata. + See Also -------- Dataset.to_netcdf @@ -3963,7 +4092,214 @@ def to_netcdf( invalid_netcdf=invalid_netcdf, ) - def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: + # compute=True (default) returns ZarrStore + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + *, + encoding: Mapping | None = None, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> ZarrStore: ... + + # compute=False returns dask.Delayed + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[False], + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> Delayed: ... + + def to_zarr( + self, + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: ZarrWriteModes | None = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: bool = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + zarr_version: int | None = None, + ) -> ZarrStore | Delayed: + """Write DataArray contents to a Zarr store + + Zarr chunks are determined in the following way: + + - From the ``chunks`` attribute in each variable's ``encoding`` + (can be set via `DataArray.chunk`). + - If the variable is a Dask array, from the dask chunks + - If neither Dask chunks nor encoding chunks are present, chunks will + be determined automatically by Zarr + - If both Dask chunks and encoding chunks are present, encoding chunks + will be used, provided that there is a many-to-one relationship between + encoding chunks and dask chunks (i.e. Dask chunks are bigger than and + evenly divide encoding chunks); otherwise raise a ``ValueError``. + This restriction ensures that no synchronization / locks are required + when writing. To disable this restriction, use ``safe_chunks=False``. + + Parameters + ---------- + store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system. + chunk_store : MutableMapping, str or path-like, optional + Store or path to directory in local or remote file system only for Zarr + array chunks. Requires zarr-python v2.4.0 or later. + mode : {"w", "w-", "a", "a-", r+", None}, optional + Persistence mode: "w" means create (overwrite if exists); + "w-" means create (fail if exists); + "a" means override all existing variables including dimension coordinates (create if does not exist); + "a-" means only append those variables that have ``append_dim``. + "r+" means modify existing array *values* only (raise an error if + any metadata or shapes would change). + The default mode is "a" if ``append_dim`` is set. Otherwise, it is + "r+" if ``region`` is set and ``w-`` otherwise. + synchronizer : object, optional + Zarr array synchronizer. + group : str, optional + Group path. (a.k.a. `path` in zarr terminology.) + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` + compute : bool, default: True + If True write array data immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed to write + array data later. Metadata is always updated eagerly. + consolidated : bool, optional + If True, apply zarr's `consolidate_metadata` function to the store + after writing metadata and read existing stores with consolidated + metadata; if False, do not. The default (`consolidated=None`) means + write consolidated metadata and attempt to read consolidated + metadata for existing stores (falling back to non-consolidated). + + When the experimental ``zarr_version=3``, ``consolidated`` must be + either be ``None`` or ``False``. + append_dim : hashable, optional + If set, the dimension along which the data will be appended. All + other dimensions on overridden variables must remain the same size. + region : dict, optional + Optional mapping from dimension names to integer slices along + dataarray dimensions to indicate the region of existing zarr array(s) + in which to write this datarray's data. For example, + ``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate + that values should be written to the region ``0:1000`` along ``x`` + and ``10000:11000`` along ``y``. + + Two restrictions apply to the use of ``region``: + + - If ``region`` is set, _all_ variables in a dataarray must have at + least one dimension in common with the region. Other variables + should be written in a separate call to ``to_zarr()``. + - Dimensions cannot be included in both ``region`` and + ``append_dim`` at the same time. To create empty arrays to fill + in with ``region``, use a separate call to ``to_zarr()`` with + ``compute=False``. See "Appending to existing Zarr stores" in + the reference documentation for full details. + safe_chunks : bool, default: True + If True, only allow writes to when there is a many-to-one relationship + between Zarr chunks (specified in encoding) and Dask chunks. + Set False to override this restriction; however, data may become corrupted + if Zarr arrays are written in parallel. This option may be useful in combination + with ``compute=False`` to initialize a Zarr store from an existing + DataArray with arbitrary chunk structure. + storage_options : dict, optional + Any additional parameters for the storage backend (ignored for local + paths). + zarr_version : int or None, optional + The desired zarr spec version to target (currently 2 or 3). The + default of None will attempt to determine the zarr version from + ``store`` when possible, otherwise defaulting to 2. + + Returns + ------- + * ``dask.delayed.Delayed`` if compute is False + * ZarrStore otherwise + + References + ---------- + https://zarr.readthedocs.io/ + + Notes + ----- + Zarr chunking behavior: + If chunks are found in the encoding argument or attribute + corresponding to any DataArray, those chunks are used. + If a DataArray is a dask array, it is written with those chunks. + If not other chunks are found, Zarr uses its own heuristics to + choose automatic chunk sizes. + + encoding: + The encoding attribute (if exists) of the DataArray(s) will be + used. Override any existing encodings by providing the ``encoding`` kwarg. + + See Also + -------- + Dataset.to_zarr + :ref:`io.zarr` + The I/O user guide, with more details and examples. + """ + from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_zarr + + if self.name is None: + # If no name is set then use a generic xarray name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + elif self.name in self.coords or self.name in self.dims: + # The name is the same as one of the coords names, which the netCDF data model + # does not support, so rename it but keep track of the old name + dataset = self.to_dataset(name=DATAARRAY_VARIABLE) + dataset.attrs[DATAARRAY_NAME] = self.name + else: + # No problems with the name - so we're fine! + dataset = self.to_dataset() + + return to_zarr( # type: ignore[call-overload,misc] + dataset, + store=store, + chunk_store=chunk_store, + mode=mode, + synchronizer=synchronizer, + group=group, + encoding=encoding, + compute=compute, + consolidated=consolidated, + append_dim=append_dim, + region=region, + safe_chunks=safe_chunks, + storage_options=storage_options, + zarr_version=zarr_version, + ) + + def to_dict( + self, data: bool | Literal["list", "array"] = "list", encoding: bool = False + ) -> dict[str, Any]: """ Convert this xarray.DataArray into a dictionary following xarray naming conventions. @@ -3974,9 +4310,14 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: Parameters ---------- - data : bool, default: True + data : bool or {"list", "array"}, default: "list" Whether to include the actual data in the dictionary. When set to - False, returns just the schema. + False, returns just the schema. If set to "array", returns data as + underlying array type. If set to "list" (or True for backwards + compatibility), returns data in lists of Python data types. Note + that for obtaining the "list" output efficiently, use + `da.compute().to_dict(data="list")`. + encoding : bool, default: False Whether to include the Dataset's encoding in the dictionary. @@ -3998,7 +4339,7 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: return d @classmethod - def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: + def from_dict(cls, d: Mapping[str, Any]) -> Self: """Convert a dictionary into an xarray.DataArray Parameters @@ -4020,7 +4361,7 @@ def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: >>> d = {"dims": "t", "data": [1, 2, 3]} >>> da = xr.DataArray.from_dict(d) >>> da - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: t @@ -4035,10 +4376,10 @@ def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: ... } >>> da = xr.DataArray.from_dict(d) >>> da - + Size: 24B array([10, 20, 30]) Coordinates: - * t (t) int64 0 1 2 + * t (t) int64 24B 0 1 2 Attributes: title: air temperature """ @@ -4052,7 +4393,7 @@ def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: except KeyError as e: raise ValueError( "cannot convert dict when coords are missing the key " - "'{dims_data}'".format(dims_data=str(e.args[0])) + f"'{str(e.args[0])}'" ) try: data = d["data"] @@ -4090,23 +4431,10 @@ def from_series(cls, series: pd.Series, sparse: bool = False) -> DataArray: temp_name = "__temporary_name" df = pd.DataFrame({temp_name: series}) ds = Dataset.from_dataframe(df, sparse=sparse) - result = cast(DataArray, ds[temp_name]) + result = ds[temp_name] result.name = series.name return result - def to_cdms2(self) -> cdms2_Variable: - """Convert this array into a cdms2.Variable""" - from xarray.convert import to_cdms2 - - return to_cdms2(self) - - @classmethod - def from_cdms2(cls, variable: cdms2_Variable) -> DataArray: - """Convert a cdms2.Variable into an xarray.DataArray""" - from xarray.convert import from_cdms2 - - return from_cdms2(variable) - def to_iris(self) -> iris_Cube: """Convert this array into a iris.cube.Cube""" from xarray.convert import to_iris @@ -4114,13 +4442,13 @@ def to_iris(self) -> iris_Cube: return to_iris(self) @classmethod - def from_iris(cls, cube: iris_Cube) -> DataArray: + def from_iris(cls, cube: iris_Cube) -> Self: """Convert a iris.cube.Cube into an xarray.DataArray""" from xarray.convert import from_iris return from_iris(cube) - def _all_compat(self: T_DataArray, other: T_DataArray, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals, broadcast_equals, and identical""" def compat(x, y): @@ -4130,7 +4458,7 @@ def compat(x, y): self, other ) - def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two DataArrays are broadcast equal if they are equal after broadcasting them against each other such that they have the same dimensions. @@ -4155,16 +4483,16 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: >>> a = xr.DataArray([1, 2], dims="X") >>> b = xr.DataArray([[1, 1], [2, 2]], dims=["X", "Y"]) >>> a - + Size: 16B array([1, 2]) Dimensions without coordinates: X >>> b - + Size: 32B array([[1, 1], [2, 2]]) Dimensions without coordinates: X, Y - .equals returns True if two DataArrays have the same values, dimensions, and coordinates. .broadcast_equals returns True if the results of broadcasting two DataArrays against eachother have the same values, dimensions, and coordinates. + .equals returns True if two DataArrays have the same values, dimensions, and coordinates. .broadcast_equals returns True if the results of broadcasting two DataArrays against each other have the same values, dimensions, and coordinates. >>> a.equals(b) False @@ -4179,7 +4507,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def equals(self: T_DataArray, other: T_DataArray) -> bool: + def equals(self, other: Self) -> bool: """True if two DataArrays have the same dimensions, coordinates and values; otherwise False. @@ -4211,21 +4539,21 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool: >>> c = xr.DataArray([1, 2, 3], dims="Y") >>> d = xr.DataArray([3, 2, 1], dims="X") >>> a - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: X >>> b - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: X Attributes: units: m >>> c - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: Y >>> d - + Size: 24B array([3, 2, 1]) Dimensions without coordinates: X @@ -4241,7 +4569,7 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def identical(self: T_DataArray, other: T_DataArray) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks the array name and attributes, and attributes on all coordinates. @@ -4266,19 +4594,19 @@ def identical(self: T_DataArray, other: T_DataArray) -> bool: >>> b = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="m"), name="Width") >>> c = xr.DataArray([1, 2, 3], dims="X", attrs=dict(units="ft"), name="Width") >>> a - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: X Attributes: units: m >>> b - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: X Attributes: units: m >>> c - + Size: 24B array([1, 2, 3]) Dimensions without coordinates: X Attributes: @@ -4308,19 +4636,19 @@ def _result_name(self, other: Any = None) -> Hashable | None: else: return None - def __array_wrap__(self: T_DataArray, obj, context=None) -> T_DataArray: + def __array_wrap__(self, obj, context=None) -> Self: new_var = self.variable.__array_wrap__(obj, context) return self._replace(new_var) - def __matmul__(self: T_DataArray, obj: T_DataArray) -> T_DataArray: + def __matmul__(self, obj: T_Xarray) -> T_Xarray: return self.dot(obj) - def __rmatmul__(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def __rmatmul__(self, other: T_Xarray) -> T_Xarray: # currently somewhat duplicative, as only other DataArrays are # compatible with matmul return computation.dot(other, self) - def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: + def _unary_op(self, f: Callable, *args, **kwargs) -> Self: keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -4336,32 +4664,29 @@ def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: return da def _binary_op( - self: T_DataArray, - other: Any, - f: Callable, - reflexive: bool = False, - ) -> T_DataArray: + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) # type: ignore - other_variable = getattr(other, "variable", other) + self, other = align(self, other, join=align_type, copy=False) + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) variable = ( - f(self.variable, other_variable) + f(self.variable, other_variable_or_arraylike) if not reflexive - else f(other_variable, self.variable) + else f(other_variable_or_arraylike, self.variable) ) coords, indexes = self.coords._merge_raw(other_coords, reflexive) name = self._result_name(other) return self._replace(variable, coords, name, indexes=indexes) - def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray: + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, GroupBy): @@ -4411,11 +4736,7 @@ def _title_for_slice(self, truncate: int = 50) -> str: for dim, coord in self.coords.items(): if coord.size == 1: one_dims.append( - "{dim} = {v}{unit}".format( - dim=dim, - v=format_item(coord.values), - unit=_get_units_from_attrs(coord), - ) + f"{dim} = {format_item(coord.values)}{_get_units_from_attrs(coord)}" ) title = ", ".join(one_dims) @@ -4424,12 +4745,14 @@ def _title_for_slice(self, truncate: int = 50) -> str: return title + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_DataArray, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_DataArray: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -4457,15 +4780,15 @@ def diff( -------- >>> arr = xr.DataArray([5, 5, 6, 6], [[1, 2, 3, 4]], ["x"]) >>> arr.diff("x") - + Size: 24B array([0, 1, 0]) Coordinates: - * x (x) int64 2 3 4 + * x (x) int64 24B 2 3 4 >>> arr.diff("x", 2) - + Size: 16B array([ 1, -1]) Coordinates: - * x (x) int64 3 4 + * x (x) int64 16B 3 4 See Also -------- @@ -4475,11 +4798,11 @@ def diff( return self._from_temp_dataset(ds) def shift( - self: T_DataArray, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Shift this DataArray by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. This is consistent @@ -4515,7 +4838,7 @@ def shift( -------- >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.shift(x=1) - + Size: 24B array([nan, 5., 6.]) Dimensions without coordinates: x """ @@ -4525,11 +4848,11 @@ def shift( return self._replace(variable=variable) def roll( - self: T_DataArray, + self, shifts: Mapping[Hashable, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Roll this array by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -4564,7 +4887,7 @@ def roll( -------- >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.roll(x=1) - + Size: 24B array([7, 5, 6]) Dimensions without coordinates: x """ @@ -4574,7 +4897,7 @@ def roll( return self._from_temp_dataset(ds) @property - def real(self: T_DataArray) -> T_DataArray: + def real(self) -> Self: """ The real part of the array. @@ -4585,7 +4908,7 @@ def real(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.real) @property - def imag(self: T_DataArray) -> T_DataArray: + def imag(self) -> Self: """ The imaginary part of the array. @@ -4595,11 +4918,12 @@ def imag(self: T_DataArray) -> T_DataArray: """ return self._replace(self.variable.imag) + @deprecate_dims def dot( - self: T_DataArray, - other: T_DataArray, - dims: Dims = None, - ) -> T_DataArray: + self, + other: T_Xarray, + dim: Dims = None, + ) -> T_Xarray: """Perform dot product of two DataArrays along their shared dims. Equivalent to taking taking tensordot over all shared dims. @@ -4608,7 +4932,7 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. - dims : ..., str, Iterable of Hashable or None, optional + dim : ..., str, Iterable of Hashable or None, optional Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. If not specified, then all the common dimensions are summed over. @@ -4647,15 +4971,18 @@ def dot( if not isinstance(other, DataArray): raise TypeError("dot only operates on DataArrays.") - return computation.dot(self, other, dims=dims) + return computation.dot(self, other, dim=dim) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def sortby( self, - variables: Hashable | DataArray | Sequence[Hashable | DataArray], + variables: ( + Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]] + ), ascending: bool = True, - ) -> DataArray: + ) -> Self: """Sort object by labels or values (along an axis). Sorts the dataarray, either along specified dimensions, @@ -4674,9 +5001,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or sequence of Hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords whose values are used to sort this array. + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -4696,34 +5024,47 @@ def sortby( Examples -------- >>> da = xr.DataArray( - ... np.random.rand(5), + ... np.arange(5, 0, -1), ... coords=[pd.date_range("1/1/2000", periods=5)], ... dims="time", ... ) >>> da - - array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ]) + Size: 40B + array([5, 4, 3, 2, 1]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 + * time (time) datetime64[ns] 40B 2000-01-01 2000-01-02 ... 2000-01-05 >>> da.sortby(da) - - array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937]) + Size: 40B + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 40B 2000-01-05 2000-01-04 ... 2000-01-01 + + >>> da.sortby(lambda x: x) + Size: 40B + array([1, 2, 3, 4, 5]) Coordinates: - * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02 + * time (time) datetime64[ns] 40B 2000-01-05 2000-01-04 ... 2000-01-01 """ + # We need to convert the callable here rather than pass it through to the + # dataset method, since otherwise the dataset method would try to call the + # callable with the dataset as the object + if callable(variables): + variables = variables(self) ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def quantile( - self: T_DataArray, + self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> T_DataArray: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -4739,15 +5080,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -4761,8 +5102,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool or None, optional If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -4794,29 +5133,29 @@ def quantile( ... dims=("x", "y"), ... ) >>> da.quantile(0) # or da.quantile(0, dim=...) - + Size: 8B array(0.7) Coordinates: - quantile float64 0.0 + quantile float64 8B 0.0 >>> da.quantile(0, dim="x") - + Size: 32B array([0.7, 4.2, 2.6, 1.5]) Coordinates: - * y (y) float64 1.0 1.5 2.0 2.5 - quantile float64 0.0 + * y (y) float64 32B 1.0 1.5 2.0 2.5 + quantile float64 8B 0.0 >>> da.quantile([0, 0.5, 1]) - + Size: 24B array([0.7, 3.4, 9.4]) Coordinates: - * quantile (quantile) float64 0.0 0.5 1.0 + * quantile (quantile) float64 24B 0.0 0.5 1.0 >>> da.quantile([0, 0.5, 1], dim="x") - + Size: 96B array([[0.7 , 4.2 , 2.6 , 1.5 ], [3.6 , 5.75, 6. , 1.7 ], [6.5 , 7.3 , 9.4 , 1.9 ]]) Coordinates: - * y (y) float64 1.0 1.5 2.0 2.5 - * quantile (quantile) float64 0.0 0.5 1.0 + * y (y) float64 32B 1.0 1.5 2.0 2.5 + * quantile (quantile) float64 24B 0.0 0.5 1.0 References ---------- @@ -4835,12 +5174,14 @@ def quantile( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_DataArray, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_DataArray: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -4871,7 +5212,7 @@ def rank( -------- >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.rank("x") - + Size: 24B array([1., 2., 3.]) Dimensions without coordinates: x """ @@ -4880,11 +5221,11 @@ def rank( return self._from_temp_dataset(ds) def differentiate( - self: T_DataArray, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions = None, - ) -> T_DataArray: + ) -> Self: """ Differentiate the array with the second order accurate central differences. @@ -4898,9 +5239,10 @@ def differentiate( The coordinate to be used to compute the gradient. edge_order : {1, 2}, default: 1 N-th order accurate differences at the boundaries. - datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", \ + datetime_unit : {"W", "D", "h", "m", "s", "ms", \ "us", "ns", "ps", "fs", "as", None}, optional - Unit to compute gradient. Only valid for datetime coordinate. + Unit to compute gradient. Only valid for datetime coordinate. "Y" and "M" are not available as + datetime_unit. Returns ------- @@ -4919,35 +5261,33 @@ def differentiate( ... coords={"x": [0, 0.1, 1.1, 1.2]}, ... ) >>> da - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) float64 0.0 0.1 1.1 1.2 + * x (x) float64 32B 0.0 0.1 1.1 1.2 Dimensions without coordinates: y >>> >>> da.differentiate("x") - + Size: 96B array([[30. , 30. , 30. ], [27.54545455, 27.54545455, 27.54545455], [27.54545455, 27.54545455, 27.54545455], [30. , 30. , 30. ]]) Coordinates: - * x (x) float64 0.0 0.1 1.1 1.2 + * x (x) float64 32B 0.0 0.1 1.1 1.2 Dimensions without coordinates: y """ ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -4980,30 +5320,28 @@ def integrate( ... coords={"x": [0, 0.1, 1.1, 1.2]}, ... ) >>> da - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) float64 0.0 0.1 1.1 1.2 + * x (x) float64 32B 0.0 0.1 1.1 1.2 Dimensions without coordinates: y >>> >>> da.integrate("x") - + Size: 24B array([5.4, 6.6, 7.8]) Dimensions without coordinates: y """ ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def cumulative_integrate( self, coord: Hashable | Sequence[Hashable] = None, datetime_unit: DatetimeUnitOptions = None, - ) -> DataArray: + ) -> Self: """Integrate cumulatively along the given coordinate using the trapezoidal rule. .. note:: @@ -5039,29 +5377,29 @@ def cumulative_integrate( ... coords={"x": [0, 0.1, 1.1, 1.2]}, ... ) >>> da - + Size: 96B array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) Coordinates: - * x (x) float64 0.0 0.1 1.1 1.2 + * x (x) float64 32B 0.0 0.1 1.1 1.2 Dimensions without coordinates: y >>> >>> da.cumulative_integrate("x") - + Size: 96B array([[0. , 0. , 0. ], [0.15, 0.25, 0.35], [4.65, 5.75, 6.85], [5.4 , 6.6 , 7.8 ]]) Coordinates: - * x (x) float64 0.0 0.1 1.1 1.2 + * x (x) float64 32B 0.0 0.1 1.1 1.2 Dimensions without coordinates: y """ ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit) return self._from_temp_dataset(ds) - def unify_chunks(self) -> DataArray: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this DataArray. Returns @@ -5133,6 +5471,9 @@ def map_blocks( dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks xarray.DataArray.map_blocks + :doc:`xarray-tutorial:advanced/map_blocks/map_blocks` + Advanced Tutorial on map_blocks with dask + Examples -------- Calculate an anomaly from climatology using ``.groupby()``. Using @@ -5144,7 +5485,7 @@ def map_blocks( ... clim = gb.mean(dim="time") ... return gb - clim ... - >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) >>> np.random.seed(123) >>> array = xr.DataArray( @@ -5153,15 +5494,15 @@ def map_blocks( ... coords={"time": time, "month": month}, ... ).chunk() >>> array.map_blocks(calculate_anomaly, template=array).compute() - + Size: 192B array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments to the function being applied in ``xr.map_blocks()``: @@ -5169,11 +5510,11 @@ def map_blocks( >>> array.map_blocks( ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array ... ) # doctest: +ELLIPSIS - + Size: 192B dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 dask.array + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array """ from xarray.core.parallel import map_blocks @@ -5239,28 +5580,27 @@ def polyfit( numpy.polyfit numpy.polyval xarray.polyval + DataArray.curvefit """ return self._to_temp_dataset().polyfit( dim, deg, skipna=skipna, rcond=rcond, w=w, full=full, cov=cov ) def pad( - self: T_DataArray, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", - stat_length: int - | tuple[int, int] - | Mapping[Any, tuple[int, int]] - | None = None, - constant_values: float - | tuple[float, float] - | Mapping[Any, tuple[float, float]] - | None = None, + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, + constant_values: ( + float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + ) = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Pad this array along one or more dimensions. .. warning:: @@ -5365,10 +5705,10 @@ def pad( -------- >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) >>> arr.pad(x=(1, 2), constant_values=0) - + Size: 48B array([0, 5, 6, 7, 0, 0]) Coordinates: - * x (x) float64 nan 0.0 1.0 2.0 nan nan + * x (x) float64 48B nan 0.0 1.0 2.0 nan nan >>> da = xr.DataArray( ... [[0, 1, 2, 3], [10, 11, 12, 13]], @@ -5376,29 +5716,29 @@ def pad( ... coords={"x": [0, 1], "y": [10, 20, 30, 40], "z": ("x", [100, 200])}, ... ) >>> da.pad(x=1) - + Size: 128B array([[nan, nan, nan, nan], [ 0., 1., 2., 3.], [10., 11., 12., 13.], [nan, nan, nan, nan]]) Coordinates: - * x (x) float64 nan 0.0 1.0 nan - * y (y) int64 10 20 30 40 - z (x) float64 nan 100.0 200.0 nan + * x (x) float64 32B nan 0.0 1.0 nan + * y (y) int64 32B 10 20 30 40 + z (x) float64 32B nan 100.0 200.0 nan Careful, ``constant_values`` are coerced to the data type of the array which may lead to a loss of precision: >>> da.pad(x=1, constant_values=1.23456789) - + Size: 128B array([[ 1, 1, 1, 1], [ 0, 1, 2, 3], [10, 11, 12, 13], [ 1, 1, 1, 1]]) Coordinates: - * x (x) float64 nan 0.0 1.0 nan - * y (y) int64 10 20 30 40 - z (x) float64 nan 100.0 200.0 nan + * x (x) float64 32B nan 0.0 1.0 nan + * y (y) int64 32B 10 20 30 40 + z (x) float64 32B nan 100.0 200.0 nan """ ds = self._to_temp_dataset().pad( pad_width=pad_width, @@ -5412,13 +5752,15 @@ def pad( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5465,39 +5807,39 @@ def idxmin( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) >>> array.min() - + Size: 8B array(-2) >>> array.argmin(...) - {'x': + {'x': Size: 8B array(4)} >>> array.idxmin() - + Size: 4B array('e', dtype='>> array = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, ... ) >>> array.min(dim="x") - + Size: 24B array([-2., -4., 1.]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 >>> array.argmin(dim="x") - + Size: 24B array([4, 0, 2]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 >>> array.idxmin(dim="x") - + Size: 24B array([16., 0., 4.]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 """ return computation._calc_idxminmax( array=self, @@ -5508,13 +5850,15 @@ def idxmin( keep_attrs=keep_attrs, ) + @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5561,39 +5905,39 @@ def idxmax( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) >>> array.max() - + Size: 8B array(2) >>> array.argmax(...) - {'x': + {'x': Size: 8B array(1)} >>> array.idxmax() - + Size: 4B array('b', dtype='>> array = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, ... ) >>> array.max(dim="x") - + Size: 24B array([2., 2., 1.]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 >>> array.argmax(dim="x") - + Size: 24B array([0, 2, 2]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 >>> array.idxmax(dim="x") - + Size: 24B array([0., 4., 4.]) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 """ return computation._calc_idxminmax( array=self, @@ -5604,15 +5948,15 @@ def idxmax( keep_attrs=keep_attrs, ) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmin( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the minimum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -5654,13 +5998,13 @@ def argmin( -------- >>> array = xr.DataArray([0, 2, -1, 3], dims="x") >>> array.min() - + Size: 8B array(-1) >>> array.argmin(...) - {'x': + {'x': Size: 8B array(2)} >>> array.isel(array.argmin(...)) - + Size: 8B array(-1) >>> array = xr.DataArray( @@ -5668,35 +6012,35 @@ def argmin( ... dims=("x", "y", "z"), ... ) >>> array.min(dim="x") - + Size: 72B array([[ 1, 2, 1], [ 2, -5, 1], [ 2, 1, 1]]) Dimensions without coordinates: y, z >>> array.argmin(dim="x") - + Size: 72B array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]) Dimensions without coordinates: y, z >>> array.argmin(dim=["x"]) - {'x': + {'x': Size: 72B array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]) Dimensions without coordinates: y, z} >>> array.min(dim=("x", "z")) - + Size: 24B array([ 1, -5, 1]) Dimensions without coordinates: y >>> array.argmin(dim=["x", "z"]) - {'x': + {'x': Size: 24B array([0, 1, 0]) - Dimensions without coordinates: y, 'z': + Dimensions without coordinates: y, 'z': Size: 24B array([2, 1, 1]) Dimensions without coordinates: y} >>> array.isel(array.argmin(dim=["x", "z"])) - + Size: 24B array([ 1, -5, 1]) Dimensions without coordinates: y """ @@ -5706,15 +6050,15 @@ def argmin( else: return self._replace_maybe_drop_dims(result) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved + @_deprecate_positional_args("v2023.10.0") def argmax( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the maximum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -5756,13 +6100,13 @@ def argmax( -------- >>> array = xr.DataArray([0, 2, -1, 3], dims="x") >>> array.max() - + Size: 8B array(3) >>> array.argmax(...) - {'x': + {'x': Size: 8B array(3)} >>> array.isel(array.argmax(...)) - + Size: 8B array(3) >>> array = xr.DataArray( @@ -5770,35 +6114,35 @@ def argmax( ... dims=("x", "y", "z"), ... ) >>> array.max(dim="x") - + Size: 72B array([[3, 3, 2], [3, 5, 2], [2, 3, 3]]) Dimensions without coordinates: y, z >>> array.argmax(dim="x") - + Size: 72B array([[0, 1, 1], [0, 1, 0], [0, 1, 0]]) Dimensions without coordinates: y, z >>> array.argmax(dim=["x"]) - {'x': + {'x': Size: 72B array([[0, 1, 1], [0, 1, 0], [0, 1, 0]]) Dimensions without coordinates: y, z} >>> array.max(dim=("x", "z")) - + Size: 24B array([3, 5, 3]) Dimensions without coordinates: y >>> array.argmax(dim=["x", "z"]) - {'x': + {'x': Size: 24B array([0, 1, 0]) - Dimensions without coordinates: y, 'z': + Dimensions without coordinates: y, 'z': Size: 24B array([0, 1, 2]) Dimensions without coordinates: y} >>> array.isel(array.argmax(dim=["x", "z"])) - + Size: 24B array([3, 5, 3]) Dimensions without coordinates: y """ @@ -5868,11 +6212,11 @@ def query( -------- >>> da = xr.DataArray(np.arange(0, 5, 1), dims="x", name="a") >>> da - + Size: 40B array([0, 1, 2, 3, 4]) Dimensions without coordinates: x >>> da.query(x="a > 2") - + Size: 16B array([3, 4]) Dimensions without coordinates: x """ @@ -5893,9 +6237,10 @@ def curvefit( func: Callable[..., Any], reduce_dims: Dims = None, skipna: bool = True, - p0: dict[str, Any] | None = None, - bounds: dict[str, Any] | None = None, + p0: Mapping[str, float | DataArray] | None = None, + bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, ) -> Dataset: """ @@ -5925,17 +6270,25 @@ def curvefit( Whether to skip missing values when fitting. Default is True. p0 : dict-like or None, optional Optional dictionary of parameter names to initial guesses passed to the - `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will - be assigned initial values following the default scipy behavior. - bounds : dict-like or None, optional - Optional dictionary of parameter names to bounding values passed to the - `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest - will be unbounded following the default scipy behavior. + `curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately + broadcast to the coordinates of the array. If none or only some parameters are + passed, the rest will be assigned initial values following the default scipy + behavior. + bounds : dict-like, optional + Optional dictionary of parameter names to tuples of bounding values passed to the + `curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be + appropriately broadcast to the coordinates of the array. If none or only some + parameters are passed, the rest will be unbounded following the default scipy + behavior. param_names : sequence of Hashable or None, optional Sequence of names for the fittable parameters of `func`. If not supplied, this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. @@ -5949,6 +6302,86 @@ def curvefit( [var]_curvefit_covariance The covariance matrix of the coefficient estimates. + Examples + -------- + Generate some exponentially decaying data, where the decay constant and amplitude are + different for different values of the coordinate ``x``: + + >>> rng = np.random.default_rng(seed=0) + >>> def exp_decay(t, time_constant, amplitude): + ... return np.exp(-t / time_constant) * amplitude + ... + >>> t = np.arange(11) + >>> da = xr.DataArray( + ... np.stack( + ... [ + ... exp_decay(t, 1, 0.1), + ... exp_decay(t, 2, 0.2), + ... exp_decay(t, 3, 0.3), + ... ] + ... ) + ... + rng.normal(size=(3, t.size)) * 0.01, + ... coords={"x": [0, 1, 2], "time": t}, + ... ) + >>> da + Size: 264B + array([[ 0.1012573 , 0.0354669 , 0.01993775, 0.00602771, -0.00352513, + 0.00428975, 0.01328788, 0.009562 , -0.00700381, -0.01264187, + -0.0062282 ], + [ 0.20041326, 0.09805582, 0.07138797, 0.03216692, 0.01974438, + 0.01097441, 0.00679441, 0.01015578, 0.01408826, 0.00093645, + 0.01501222], + [ 0.29334805, 0.21847449, 0.16305984, 0.11130396, 0.07164415, + 0.04744543, 0.03602333, 0.03129354, 0.01074885, 0.01284436, + 0.00910995]]) + Coordinates: + * x (x) int64 24B 0 1 2 + * time (time) int64 88B 0 1 2 3 4 5 6 7 8 9 10 + + Fit the exponential decay function to the data along the ``time`` dimension: + + >>> fit_result = da.curvefit("time", exp_decay) + >>> fit_result["curvefit_coefficients"].sel( + ... param="time_constant" + ... ) # doctest: +NUMBER + Size: 24B + array([1.05692036, 1.73549638, 2.94215771]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result["curvefit_coefficients"].sel(param="amplitude") + Size: 24B + array([0.1005489 , 0.19631423, 0.30003579]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result = da.curvefit( + ... "time", + ... exp_decay, + ... p0={ + ... "amplitude": 0.2, + ... "time_constant": xr.DataArray([1, 2, 3], coords=[da.x]), + ... }, + ... ) + >>> fit_result["curvefit_coefficients"].sel(param="time_constant") + Size: 24B + array([1.0569213 , 1.73550052, 2.94215733]) + Coordinates: + * x (x) int64 24B 0 1 2 + param >> fit_result["curvefit_coefficients"].sel(param="amplitude") + Size: 24B + array([0.10054889, 0.1963141 , 0.3000358 ]) + Coordinates: + * x (x) int64 24B 0 1 2 + param T_DataArray: + ) -> Self: """Returns a new DataArray with duplicate dimension values removed. Parameters @@ -5999,47 +6435,47 @@ def drop_duplicates( ... coords={"x": np.array([0, 0, 1, 2, 3]), "y": np.array([0, 1, 2, 3, 3])}, ... ) >>> da - + Size: 200B array([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - * x (x) int64 0 0 1 2 3 - * y (y) int64 0 1 2 3 3 + * x (x) int64 40B 0 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 >>> da.drop_duplicates(dim="x") - + Size: 160B array([[ 0, 1, 2, 3, 4], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - * x (x) int64 0 1 2 3 - * y (y) int64 0 1 2 3 3 + * x (x) int64 32B 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 >>> da.drop_duplicates(dim="x", keep="last") - + Size: 160B array([[ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) Coordinates: - * x (x) int64 0 1 2 3 - * y (y) int64 0 1 2 3 3 + * x (x) int64 32B 0 1 2 3 + * y (y) int64 40B 0 1 2 3 3 Drop all duplicate dimension values: >>> da.drop_duplicates(dim=...) - + Size: 128B array([[ 0, 1, 2, 3], [10, 11, 12, 13], [15, 16, 17, 18], [20, 21, 22, 23]]) Coordinates: - * x (x) int64 0 1 2 3 - * y (y) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 + * y (y) int64 32B 0 1 2 3 """ deduplicated = self._to_temp_dataset().drop_duplicates(dim, keep=keep) return self._from_temp_dataset(deduplicated) @@ -6051,7 +6487,7 @@ def convert_calendar( align_on: str | None = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> DataArray: + ) -> Self: """Convert the DataArray to another calendar. Only converts the individual timestamps, does not modify any data except @@ -6171,7 +6607,7 @@ def interp_calendar( self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: str = "time", - ) -> DataArray: + ) -> Self: """Interpolates the DataArray to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal @@ -6201,7 +6637,7 @@ def interp_calendar( def groupby( self, group: Hashable | DataArray | IndexVariable, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6235,42 +6671,51 @@ def groupby( ... dims="time", ... ) >>> da - + Size: 15kB array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, 1.826e+03]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 + * time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31 >>> da.groupby("time.dayofyear") - da.groupby("time.dayofyear").mean("time") - + Size: 15kB array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 - dayofyear (time) int64 1 2 3 4 5 6 7 8 ... 359 360 361 362 363 364 365 366 + * time (time) datetime64[ns] 15kB 2000-01-01 2000-01-02 ... 2004-12-31 + dayofyear (time) int64 15kB 1 2 3 4 5 6 7 8 ... 360 361 362 363 364 365 366 See Also -------- :ref:`groupby` Users guide explanation of how to group and bin data. + + :doc:`xarray-tutorial:intermediate/01-high-level-computation-patterns` + Tutorial on :py:func:`~xarray.DataArray.Groupby` for windowed computation + + :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray` + Tutorial on :py:func:`~xarray.DataArray.Groupby` demonstrating reductions, transformation and comparison with :py:func:`~xarray.DataArray.resample` + DataArray.groupby_bins Dataset.groupby core.groupby.DataArrayGroupBy + DataArray.coarsen pandas.DataFrame.groupby + Dataset.resample + DataArray.resample """ - from xarray.core.groupby import DataArrayGroupBy - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import ( + DataArrayGroupBy, + ResolvedGrouper, + UniqueGrouper, + _validate_groupby_squeeze, + ) + _validate_groupby_squeeze(squeeze) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DataArrayGroupBy( - self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -6281,7 +6726,7 @@ def groupby_bins( labels: ArrayLike | Literal[False] | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6341,14 +6786,16 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DataArrayGroupBy + from xarray.core.groupby import ( + BinGrouper, + DataArrayGroupBy, + ResolvedGrouper, + _validate_groupby_squeeze, + ) - return DataArrayGroupBy( - self, - group, - squeeze=squeeze, + _validate_groupby_squeeze(squeeze) + grouper = BinGrouper( bins=bins, - restore_coord_dims=restore_coord_dims, cut_kwargs={ "right": right, "labels": labels, @@ -6356,6 +6803,14 @@ def groupby_bins( "include_lowest": include_lowest, }, ) + rgrouper = ResolvedGrouper(grouper, group, self) + + return DataArrayGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) def weighted(self, weights: DataArray) -> DataArrayWeighted: """ @@ -6380,6 +6835,13 @@ def weighted(self, weights: DataArray) -> DataArrayWeighted: See Also -------- Dataset.weighted + + :ref:`comput.weighted` + User guide on weighted array reduction using :py:func:`~xarray.DataArray.weighted` + + :doc:`xarray-tutorial:fundamentals/03.4_weighted` + Tutorial on Weighted Reduction using :py:func:`~xarray.DataArray.weighted` + """ from xarray.core.weighted import DataArrayWeighted @@ -6431,28 +6893,29 @@ def rolling( ... dims="time", ... ) >>> da - + Size: 96B array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 >>> da.rolling(time=3, center=True).mean() - + Size: 96B array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 Remove the NaNs using ``dropna()``: >>> da.rolling(time=3, center=True).mean().dropna("time") - + Size: 80B array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) Coordinates: - * time (time) datetime64[ns] 2000-01-15 2000-02-15 ... 2000-10-15 + * time (time) datetime64[ns] 80B 2000-01-15 2000-02-15 ... 2000-10-15 See Also -------- - core.rolling.DataArrayRolling + DataArray.cumulative Dataset.rolling + core.rolling.DataArrayRolling """ from xarray.core.rolling import DataArrayRolling @@ -6461,6 +6924,81 @@ def rolling( self, dim, min_periods=min_periods, center=center, pad=pad ) + def cumulative( + self, + dim: str | Iterable[Hashable], + min_periods: int = 1, + ) -> DataArrayRolling: + """ + Accumulating object for DataArrays. + + Parameters + ---------- + dims : iterable of hashable + The name(s) of the dimensions to create the cumulative window along + min_periods : int, default: 1 + Minimum number of observations in window required to have a value + (otherwise result is NA). The default is 1 (note this is different + from ``Rolling``, whose default is the size of the window). + + Returns + ------- + core.rolling.DataArrayRolling + + Examples + -------- + Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + + >>> da + Size: 96B + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + + >>> da.cumulative("time").sum() + Size: 96B + array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45., 55., 66.]) + Coordinates: + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 + + See Also + -------- + DataArray.rolling + Dataset.cumulative + core.rolling.DataArrayRolling + """ + from xarray.core.rolling import DataArrayRolling + + # Could we abstract this "normalize and check 'dim'" logic? It's currently shared + # with the same method in Dataset. + if isinstance(dim, str): + if dim not in self.dims: + raise ValueError( + f"Dimension {dim} not found in data dimensions: {self.dims}" + ) + dim = {dim: self.sizes[dim]} + else: + missing_dims = set(dim) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {missing_dims} not found in data dimensions: {self.dims}" + ) + dim = {d: self.sizes[d] for d in dim} + + return DataArrayRolling(self, dim, min_periods=min_periods, center=False) + def coarsen( self, dim: Mapping[Any, int] | None = None, @@ -6499,31 +7037,101 @@ def coarsen( ... coords={"time": pd.date_range("1999-12-15", periods=364)}, ... ) >>> da # +doctest: ELLIPSIS - + Size: 3kB array([ 0. , 1.00275482, 2.00550964, 3.00826446, 4.01101928, 5.0137741 , 6.01652893, 7.01928375, 8.02203857, 9.02479339, 10.02754821, 11.03030303, + 12.03305785, 13.03581267, 14.03856749, 15.04132231, + 16.04407713, 17.04683196, 18.04958678, 19.0523416 , + 20.05509642, 21.05785124, 22.06060606, 23.06336088, + 24.0661157 , 25.06887052, 26.07162534, 27.07438017, + 28.07713499, 29.07988981, 30.08264463, 31.08539945, + 32.08815427, 33.09090909, 34.09366391, 35.09641873, + 36.09917355, 37.10192837, 38.1046832 , 39.10743802, + 40.11019284, 41.11294766, 42.11570248, 43.1184573 , + 44.12121212, 45.12396694, 46.12672176, 47.12947658, + 48.1322314 , 49.13498623, 50.13774105, 51.14049587, + 52.14325069, 53.14600551, 54.14876033, 55.15151515, + 56.15426997, 57.15702479, 58.15977961, 59.16253444, + 60.16528926, 61.16804408, 62.1707989 , 63.17355372, + 64.17630854, 65.17906336, 66.18181818, 67.184573 , + 68.18732782, 69.19008264, 70.19283747, 71.19559229, + 72.19834711, 73.20110193, 74.20385675, 75.20661157, + 76.20936639, 77.21212121, 78.21487603, 79.21763085, ... + 284.78236915, 285.78512397, 286.78787879, 287.79063361, + 288.79338843, 289.79614325, 290.79889807, 291.80165289, + 292.80440771, 293.80716253, 294.80991736, 295.81267218, + 296.815427 , 297.81818182, 298.82093664, 299.82369146, + 300.82644628, 301.8292011 , 302.83195592, 303.83471074, + 304.83746556, 305.84022039, 306.84297521, 307.84573003, + 308.84848485, 309.85123967, 310.85399449, 311.85674931, + 312.85950413, 313.86225895, 314.86501377, 315.8677686 , + 316.87052342, 317.87327824, 318.87603306, 319.87878788, + 320.8815427 , 321.88429752, 322.88705234, 323.88980716, + 324.89256198, 325.8953168 , 326.89807163, 327.90082645, + 328.90358127, 329.90633609, 330.90909091, 331.91184573, + 332.91460055, 333.91735537, 334.92011019, 335.92286501, + 336.92561983, 337.92837466, 338.93112948, 339.9338843 , + 340.93663912, 341.93939394, 342.94214876, 343.94490358, + 344.9476584 , 345.95041322, 346.95316804, 347.95592287, + 348.95867769, 349.96143251, 350.96418733, 351.96694215, + 352.96969697, 353.97245179, 354.97520661, 355.97796143, 356.98071625, 357.98347107, 358.9862259 , 359.98898072, 360.99173554, 361.99449036, 362.99724518, 364. ]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-12-12 + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-12-12 >>> da.coarsen(time=3, boundary="trim").mean() # +doctest: ELLIPSIS - + Size: 968B array([ 1.00275482, 4.01101928, 7.01928375, 10.02754821, 13.03581267, 16.04407713, 19.0523416 , 22.06060606, 25.06887052, 28.07713499, 31.08539945, 34.09366391, - ... + 37.10192837, 40.11019284, 43.1184573 , 46.12672176, + 49.13498623, 52.14325069, 55.15151515, 58.15977961, + 61.16804408, 64.17630854, 67.184573 , 70.19283747, + 73.20110193, 76.20936639, 79.21763085, 82.22589532, + 85.23415978, 88.24242424, 91.25068871, 94.25895317, + 97.26721763, 100.27548209, 103.28374656, 106.29201102, + 109.30027548, 112.30853994, 115.31680441, 118.32506887, + 121.33333333, 124.3415978 , 127.34986226, 130.35812672, + 133.36639118, 136.37465565, 139.38292011, 142.39118457, + 145.39944904, 148.4077135 , 151.41597796, 154.42424242, + 157.43250689, 160.44077135, 163.44903581, 166.45730028, + 169.46556474, 172.4738292 , 175.48209366, 178.49035813, + 181.49862259, 184.50688705, 187.51515152, 190.52341598, + 193.53168044, 196.5399449 , 199.54820937, 202.55647383, + 205.56473829, 208.57300275, 211.58126722, 214.58953168, + 217.59779614, 220.60606061, 223.61432507, 226.62258953, + 229.63085399, 232.63911846, 235.64738292, 238.65564738, + 241.66391185, 244.67217631, 247.68044077, 250.68870523, + 253.6969697 , 256.70523416, 259.71349862, 262.72176309, + 265.73002755, 268.73829201, 271.74655647, 274.75482094, + 277.7630854 , 280.77134986, 283.77961433, 286.78787879, + 289.79614325, 292.80440771, 295.81267218, 298.82093664, + 301.8292011 , 304.83746556, 307.84573003, 310.85399449, + 313.86225895, 316.87052342, 319.87878788, 322.88705234, + 325.8953168 , 328.90358127, 331.91184573, 334.92011019, + 337.92837466, 340.93663912, 343.94490358, 346.95316804, 349.96143251, 352.96969697, 355.97796143, 358.9862259 , 361.99449036]) Coordinates: - * time (time) datetime64[ns] 1999-12-16 1999-12-19 ... 2000-12-10 + * time (time) datetime64[ns] 968B 1999-12-16 1999-12-19 ... 2000-12-10 >>> See Also -------- core.rolling.DataArrayCoarsen Dataset.coarsen + + :ref:`reshape.coarsen` + User guide describing :py:func:`~xarray.DataArray.coarsen` + + :ref:`compute.coarsen` + User guide on block arrgragation :py:func:`~xarray.DataArray.coarsen` + + :doc:`xarray-tutorial:fundamentals/03.3_windowed` + Tutorial on windowed computation using :py:func:`~xarray.DataArray.coarsen` + """ from xarray.core.rolling import DataArrayCoarsen @@ -6545,7 +7153,6 @@ def resample( base: int | None = None, offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", - keep_attrs: bool | None = None, loffset: datetime.timedelta | str | None = None, restore_coord_dims: bool | None = None, **indexer_kwargs: str, @@ -6587,6 +7194,12 @@ def resample( loffset : timedelta or str, optional Offset used to adjust the resampled time labels. Some pandas date offset strings are supported. + + .. deprecated:: 2023.03.0 + Following pandas, the ``loffset`` parameter is deprecated in favor + of using time offset arithmetic, and will be removed in a future + version of xarray. + restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -6615,36 +7228,96 @@ def resample( ... dims="time", ... ) >>> da - + Size: 96B array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 + * time (time) datetime64[ns] 96B 1999-12-15 2000-01-15 ... 2000-11-15 >>> da.resample(time="QS-DEC").mean() - + Size: 32B array([ 1., 4., 7., 10.]) Coordinates: - * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 + * time (time) datetime64[ns] 32B 1999-12-01 2000-03-01 ... 2000-09-01 Upsample monthly time-series data to daily data: >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS - + Size: 3kB array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, + 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, + 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, + 0.96774194, 1. , 1.03225806, 1.06451613, 1.09677419, + 1.12903226, 1.16129032, 1.19354839, 1.22580645, 1.25806452, + 1.29032258, 1.32258065, 1.35483871, 1.38709677, 1.41935484, + 1.4516129 , 1.48387097, 1.51612903, 1.5483871 , 1.58064516, + 1.61290323, 1.64516129, 1.67741935, 1.70967742, 1.74193548, + 1.77419355, 1.80645161, 1.83870968, 1.87096774, 1.90322581, + 1.93548387, 1.96774194, 2. , 2.03448276, 2.06896552, + 2.10344828, 2.13793103, 2.17241379, 2.20689655, 2.24137931, + 2.27586207, 2.31034483, 2.34482759, 2.37931034, 2.4137931 , + 2.44827586, 2.48275862, 2.51724138, 2.55172414, 2.5862069 , + 2.62068966, 2.65517241, 2.68965517, 2.72413793, 2.75862069, + 2.79310345, 2.82758621, 2.86206897, 2.89655172, 2.93103448, + 2.96551724, 3. , 3.03225806, 3.06451613, 3.09677419, + 3.12903226, 3.16129032, 3.19354839, 3.22580645, 3.25806452, ... + 7.87096774, 7.90322581, 7.93548387, 7.96774194, 8. , + 8.03225806, 8.06451613, 8.09677419, 8.12903226, 8.16129032, + 8.19354839, 8.22580645, 8.25806452, 8.29032258, 8.32258065, + 8.35483871, 8.38709677, 8.41935484, 8.4516129 , 8.48387097, + 8.51612903, 8.5483871 , 8.58064516, 8.61290323, 8.64516129, + 8.67741935, 8.70967742, 8.74193548, 8.77419355, 8.80645161, + 8.83870968, 8.87096774, 8.90322581, 8.93548387, 8.96774194, + 9. , 9.03333333, 9.06666667, 9.1 , 9.13333333, + 9.16666667, 9.2 , 9.23333333, 9.26666667, 9.3 , + 9.33333333, 9.36666667, 9.4 , 9.43333333, 9.46666667, + 9.5 , 9.53333333, 9.56666667, 9.6 , 9.63333333, + 9.66666667, 9.7 , 9.73333333, 9.76666667, 9.8 , + 9.83333333, 9.86666667, 9.9 , 9.93333333, 9.96666667, + 10. , 10.03225806, 10.06451613, 10.09677419, 10.12903226, + 10.16129032, 10.19354839, 10.22580645, 10.25806452, 10.29032258, + 10.32258065, 10.35483871, 10.38709677, 10.41935484, 10.4516129 , + 10.48387097, 10.51612903, 10.5483871 , 10.58064516, 10.61290323, + 10.64516129, 10.67741935, 10.70967742, 10.74193548, 10.77419355, 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, 10.96774194, 11. ]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 Limit scope of upsampling method >>> da.resample(time="1D").nearest(tolerance="1D") - - array([ 0., 0., nan, ..., nan, 11., 11.]) + Size: 3kB + array([ 0., 0., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 1., 1., 1., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 2., 2., 2., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 3., + 3., 3., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 4., 4., 4., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, 5., 5., 5., nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + 6., 6., 6., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, 7., 7., 7., nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, 8., 8., 8., nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, 9., 9., 9., nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, 10., 10., 10., nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, + nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 11., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + * time (time) datetime64[ns] 3kB 1999-12-15 1999-12-16 ... 2000-11-15 See Also -------- @@ -6667,12 +7340,76 @@ def resample( base=base, offset=offset, origin=origin, - keep_attrs=keep_attrs, loffset=loffset, restore_coord_dims=restore_coord_dims, **indexer_kwargs, ) + def to_dask_dataframe( + self, + dim_order: Sequence[Hashable] | None = None, + set_index: bool = False, + ) -> DaskDataFrame: + """Convert this array into a dask.dataframe.DataFrame. + + Parameters + ---------- + dim_order : Sequence of Hashable or None , optional + Hierarchical dimension order for the resulting dataframe. + Array content is transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major influence + on which operations are efficient on the resulting dask dataframe. + set_index : bool, default: False + If set_index=True, the dask DataFrame is indexed by this dataset's + coordinate. Since dask DataFrames do not support multi-indexes, + set_index only works if the dataset only contains one dimension. + + Returns + ------- + dask.dataframe.DataFrame + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(4 * 2 * 2).reshape(4, 2, 2), + ... dims=("time", "lat", "lon"), + ... coords={ + ... "time": np.arange(4), + ... "lat": [-30, -20], + ... "lon": [120, 130], + ... }, + ... name="eg_dataarray", + ... attrs={"units": "Celsius", "description": "Random temperature data"}, + ... ) + >>> da.to_dask_dataframe(["lat", "lon", "time"]).compute() + lat lon time eg_dataarray + 0 -30 120 0 0 + 1 -30 120 1 4 + 2 -30 120 2 8 + 3 -30 120 3 12 + 4 -30 130 0 1 + 5 -30 130 1 5 + 6 -30 130 2 9 + 7 -30 130 3 13 + 8 -20 120 0 2 + 9 -20 120 1 6 + 10 -20 120 2 10 + 11 -20 120 3 14 + 12 -20 130 0 3 + 13 -20 130 1 7 + 14 -20 130 2 11 + 15 -20 130 3 15 + """ + if self.name is None: + raise ValueError( + "Cannot convert an unnamed DataArray to a " + "dask dataframe : use the ``.rename`` method to assign a name." + ) + name = self.name + ds = self._to_dataset_whole(name, shallow_copy=False) + return ds.to_dask_dataframe(dim_order, set_index) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 569267bdc12..54386e42206 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -24,6 +24,13 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload import numpy as np + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + import pandas as pd from xarray.coding.calendar_ops import convert_calendar, interp_calendar @@ -50,7 +57,12 @@ get_chunksizes, ) from xarray.core.computation import unify_chunks -from xarray.core.coordinates import DatasetCoordinates, assert_coordinate_consistent +from xarray.core.coordinates import ( + Coordinates, + DatasetCoordinates, + assert_coordinate_consistent, + create_coords_with_default_indexes, +) from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.indexes import ( Index, @@ -69,23 +81,35 @@ dataset_merge_method, dataset_update_method, merge_coordinates_without_align, - merge_data_and_coords, + merge_core, ) from xarray.core.missing import get_clean_interp_index from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import array_type, is_duck_array, is_duck_dask_array -from xarray.core.types import QuantileMethods, T_Dataset +from xarray.core.types import ( + QuantileMethods, + Self, + T_ChunkDim, + T_Chunks, + T_DataArray, + T_DataArrayOrSet, + T_Dataset, + ZarrWriteModes, +) from xarray.core.utils import ( Default, Frozen, + FrozenMappingWarningOnValuesAccess, HybridMappingProxy, OrderedSet, _default, decode_numpy_dict_values, drop_dims_from_indexers, either_dict_or_kwargs, + emit_user_level_warning, infix_dims, is_dict_like, + is_duck_array, + is_duck_dask_array, is_scalar, maybe_wrap_array, ) @@ -96,17 +120,21 @@ broadcast_variables, calculate_dimensions, ) +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import array_type, is_chunked_array from xarray.plot.accessor import DatasetPlotAccessor +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: + from dask.dataframe import DataFrame as DaskDataFrame + from dask.delayed import Delayed from numpy.typing import ArrayLike from xarray.backends import AbstractDataStore, ZarrStore from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes - from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy - from xarray.core.merge import CoercibleMapping + from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult from xarray.core.resample import DatasetResample from xarray.core.rolling import DatasetCoarsen, DatasetRolling from xarray.core.types import ( @@ -114,9 +142,11 @@ CoarsenBoundaryOptions, CombineAttrsOptions, CompatOptions, + DataVars, DatetimeLike, DatetimeUnitOptions, Dims, + DsCompatible, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, @@ -130,15 +160,7 @@ T_Xarray, ) from xarray.core.weighted import DatasetWeighted - - try: - from dask.delayed import Delayed - except ImportError: - Delayed = None # type: ignore - try: - from dask.dataframe import DataFrame as DaskDataFrame - except ImportError: - DaskDataFrame = None # type: ignore + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -197,18 +219,11 @@ def _get_virtual_variable( return ref_name, var_name, virtual_var -def _assert_empty(args: tuple, msg: str = "%s") -> None: - if args: - raise ValueError(msg % args) - - -def _get_chunk(var, chunks): +def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): """ Return map from each dim to chunk sizes, accounting for backend's preferred chunks. """ - import dask.array as da - if isinstance(var, IndexVariable): return {} dims = var.dims @@ -225,7 +240,8 @@ def _get_chunk(var, chunks): chunks.get(dim, None) or preferred_chunk_sizes for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape) ) - chunk_shape = da.core.normalize_chunks( + + chunk_shape = chunkmanager.normalize_chunks( chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape ) @@ -242,7 +258,7 @@ def _get_chunk(var, chunks): # expresses the preferred chunks, the sequence sums to the size. preferred_stops = ( range(preferred_chunk_sizes, size, preferred_chunk_sizes) - if isinstance(preferred_chunk_sizes, Number) + if isinstance(preferred_chunk_sizes, int) else itertools.accumulate(preferred_chunk_sizes[:-1]) ) # Gather any stop indices of the specified chunks that are not a stop index @@ -253,7 +269,7 @@ def _get_chunk(var, chunks): ) if breaks: warnings.warn( - "The specified Dask chunks separate the stored chunks along " + "The specified chunks separate the stored chunks along " f'dimension "{dim}" starting at index {min(breaks)}. This could ' "degrade performance. Instead, consider rechunking after loading." ) @@ -270,18 +286,41 @@ def _maybe_chunk( name_prefix="xarray-", overwrite_encoded_chunks=False, inline_array=False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, ): - from dask.base import tokenize + + from xarray.namedarray.daskmanager import DaskManager if chunks is not None: chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks} + if var.ndim: - # when rechunking by different amounts, make sure dask names change - # by provinding chunks as an input to tokenize. - # subtle bugs result otherwise. see GH3350 - token2 = tokenize(name, token if token else var._data, chunks) - name2 = f"{name_prefix}{name}-{token2}" - var = var.chunk(chunks, name=name2, lock=lock, inline_array=inline_array) + chunked_array_type = guess_chunkmanager( + chunked_array_type + ) # coerce string to ChunkManagerEntrypoint type + if isinstance(chunked_array_type, DaskManager): + from dask.base import tokenize + + # when rechunking by different amounts, make sure dask names change + # by providing chunks as an input to tokenize. + # subtle bugs result otherwise. see GH3350 + # we use str() for speed, and use the name for the final array name on the next line + token2 = tokenize(token if token else var._data, str(chunks)) + name2 = f"{name_prefix}{name}-{token2}" + + from_array_kwargs = utils.consolidate_dask_from_array_kwargs( + from_array_kwargs, + name=name2, + lock=lock, + inline_array=inline_array, + ) + + var = var.chunk( + chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) if overwrite_encoded_chunks and var.chunks is not None: var.encoding["chunks"] = tuple(x[0] for x in var.chunks) @@ -332,17 +371,24 @@ def _initialize_curvefit_params(params, p0, bounds, func_args): """Set initial guess and bounds for curvefit. Priority: 1) passed args 2) func signature 3) scipy defaults """ + from xarray.core.computation import where def _initialize_feasible(lb, ub): # Mimics functionality of scipy.optimize.minpack._initialize_feasible lb_finite = np.isfinite(lb) ub_finite = np.isfinite(ub) - p0 = np.nansum( - [ - 0.5 * (lb + ub) * int(lb_finite & ub_finite), - (lb + 1) * int(lb_finite & ~ub_finite), - (ub - 1) * int(~lb_finite & ub_finite), - ] + p0 = where( + lb_finite, + where( + ub_finite, + 0.5 * (lb + ub), # both bounds finite + lb + 1, # lower bound finite, upper infinite + ), + where( + ub_finite, + ub - 1, # lower bound infinite, upper finite + 0, # both bounds infinite + ), ) return p0 @@ -352,14 +398,38 @@ def _initialize_feasible(lb, ub): if p in func_args and func_args[p].default is not func_args[p].empty: param_defaults[p] = func_args[p].default if p in bounds: - bounds_defaults[p] = tuple(bounds[p]) - if param_defaults[p] < bounds[p][0] or param_defaults[p] > bounds[p][1]: - param_defaults[p] = _initialize_feasible(bounds[p][0], bounds[p][1]) + lb, ub = bounds[p] + bounds_defaults[p] = (lb, ub) + param_defaults[p] = where( + (param_defaults[p] < lb) | (param_defaults[p] > ub), + _initialize_feasible(lb, ub), + param_defaults[p], + ) if p in p0: param_defaults[p] = p0[p] return param_defaults, bounds_defaults +def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult: + """Used in Dataset.__init__.""" + if isinstance(coords, Coordinates): + coords = coords.copy() + else: + coords = create_coords_with_default_indexes(coords, data_vars) + + # exclude coords from alignment (all variables in a Coordinates object should + # already be aligned together) and use coordinates' indexes to align data_vars + return merge_core( + [data_vars, coords], + compat="broadcast_equals", + join="outer", + explicit_coords=tuple(coords), + indexes=coords.xindexes, + priority_arg=1, + skip_align_args=[1], + ) + + class DataVariables(Mapping[Any, "DataArray"]): __slots__ = ("_dataset",) @@ -374,14 +444,16 @@ def __iter__(self) -> Iterator[Hashable]: ) def __len__(self) -> int: - return len(self._dataset._variables) - len(self._dataset._coord_names) + length = len(self._dataset._variables) - len(self._dataset._coord_names) + assert length >= 0, "something is wrong with Dataset._coord_names" + return length def __contains__(self, key: Hashable) -> bool: return key in self._dataset._variables and key not in self._dataset._coord_names def __getitem__(self, key: Hashable) -> DataArray: if key not in self._dataset._coord_names: - return cast("DataArray", self._dataset[key]) + return self._dataset[key] raise KeyError(key) def __repr__(self) -> str: @@ -451,8 +523,11 @@ class Dataset( Dataset implements the mapping interface with keys given by variable names and values given by DataArray objects for each variable name. - One dimensional variables with name equal to their dimension are - index coordinates used for label based indexing. + By default, pandas indexes are created for one dimensional variables with + name equal to their dimension (i.e., :term:`Dimension coordinate`) so those + variables can be readily used as coordinates for label based indexing. When a + :py:class:`~xarray.Coordinates` object is passed to ``coords``, any existing + index(es) built from those coordinates will be added to the Dataset. To load data from a file or file-like object, use the `open_dataset` function. @@ -473,22 +548,21 @@ class Dataset( - mapping {var name: (dimension name, array-like)} - mapping {var name: (tuple of dimension names, array-like)} - mapping {dimension name: array-like} - (it will be automatically moved to coords, see below) + (if array-like is not a scalar it will be automatically moved to coords, + see below) Each dimension must have the same length in all variables in which it appears. - coords : dict-like, optional - Another mapping in similar form as the `data_vars` argument, - except the each item is saved on the dataset as a "coordinate". + coords : :py:class:`~xarray.Coordinates` or dict-like, optional + A :py:class:`~xarray.Coordinates` object or another mapping in + similar form as the `data_vars` argument, except that each item + is saved on the dataset as a "coordinate". These variables have an associated meaning: they describe constant/fixed/independent quantities, unlike the varying/measured/dependent quantities that belong in - `variables`. Coordinates values may be given by 1-dimensional - arrays or scalars, in which case `dims` do not need to be - supplied: 1D arrays will be assumed to give index values along - the dimension with the same name. + `variables`. - The following notations are accepted: + The following notations are accepted for arbitrary mappings: - mapping {coord name: DataArray} - mapping {coord name: Variable} @@ -498,8 +572,16 @@ class Dataset( (the dimension name is implicitly set to be the same as the coord name) - The last notation implies that the coord name is the same as - the dimension name. + The last notation implies either that the coordinate value is a scalar + or that it is a 1-dimensional array and the coord name is the same as + the dimension name (i.e., a :term:`Dimension coordinate`). In the latter + case, the 1-dimensional array will be assumed to give index values + along the dimension with the same name. + + Alternatively, a :py:class:`~xarray.Coordinates` object may be used in + order to explicitly pass indexes (e.g., a multi-index or any custom + Xarray index) or to bypass the creation of a default index for any + :term:`Dimension coordinate` included in that object. attrs : dict-like, optional Global attributes to save on this dataset. @@ -532,17 +614,17 @@ class Dataset( ... attrs=dict(description="Weather related data."), ... ) >>> ds - + Size: 288B Dimensions: (x: 2, y: 2, time: 3) Coordinates: - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Data variables: - temperature (x, y, time) float64 29.11 18.2 22.83 ... 18.28 16.15 26.63 - precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 + temperature (x, y, time) float64 96B 29.11 18.2 22.83 ... 16.15 26.63 + precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805 Attributes: description: Weather related data. @@ -550,18 +632,19 @@ class Dataset( other variables had: >>> ds.isel(ds.temperature.argmin(...)) - + Size: 48B Dimensions: () Coordinates: - lon float64 -99.32 - lat float64 42.21 - time datetime64[ns] 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon float64 8B -99.32 + lat float64 8B 42.21 + time datetime64[ns] 8B 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Data variables: - temperature float64 7.182 - precipitation float64 8.326 + temperature float64 8B 7.182 + precipitation float64 8B 8.326 Attributes: description: Weather related data. + """ _attrs: dict[Hashable, Any] | None @@ -589,12 +672,10 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Any, Any] | None = None, + data_vars: DataVars | None = None, coords: Mapping[Any, Any] | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: - # TODO(shoyer): expose indexes as a public argument in __init__ - if data_vars is None: data_vars = {} if coords is None: @@ -607,13 +688,13 @@ def __init__( ) if isinstance(coords, Dataset): - coords = coords.variables + coords = coords._variables variables, coord_names, dims, indexes, _ = merge_data_and_coords( - data_vars, coords, compat="broadcast_equals" + data_vars, coords ) - self._attrs = dict(attrs) if attrs is not None else None + self._attrs = dict(attrs) if attrs else None self._close = None self._encoding = None self._variables = variables @@ -621,8 +702,13 @@ def __init__( self._dims = dims self._indexes = indexes + # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping + # related to https://github.com/python/mypy/issues/9319? + def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override] + return super().__eq__(other) + @classmethod - def load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset: + def load_store(cls, store, decoder=None) -> Self: """Create a new dataset from the contents of a backends.*DataStore object """ @@ -653,7 +739,7 @@ def attrs(self) -> dict[Any, Any]: @attrs.setter def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) + self._attrs = dict(value) if value else None @property def encoding(self) -> dict[Any, Any]: @@ -666,6 +752,18 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self._encoding = dict(value) + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: + """Return a new Dataset without encoding on the dataset or any of its + variables/coords.""" + variables = {k: v.drop_encoding() for k, v in self.variables.items()} + return self._replace(variables=variables, encoding={}) + @property def dims(self) -> Frozen[Hashable, int]: """Mapping from dimension names to lengths. @@ -674,14 +772,15 @@ def dims(self) -> Frozen[Hashable, int]: Note that type of this object differs from `DataArray.dims`. See `Dataset.sizes` and `DataArray.sizes` for consistently named - properties. + properties. This property will be changed to return a type more consistent with + `DataArray.dims` in the future, i.e. a set of dimension names. See Also -------- Dataset.sizes DataArray.dims """ - return Frozen(self._dims) + return FrozenMappingWarningOnValuesAccess(self._dims) @property def sizes(self) -> Frozen[Hashable, int]: @@ -696,7 +795,7 @@ def sizes(self) -> Frozen[Hashable, int]: -------- DataArray.sizes """ - return self.dims + return Frozen(self._dims) @property def dtypes(self) -> Frozen[Hashable, np.dtype]: @@ -716,7 +815,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: } ) - def load(self: T_Dataset, **kwargs) -> T_Dataset: + def load(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return this dataset. Unlike compute, the original dataset is modified and returned. @@ -737,13 +836,15 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset: """ # access .data to coerce everything to numpy or dask arrays lazy_data = { - k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) + k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) } if lazy_data: - import dask.array as da + chunkmanager = get_chunked_array_type(*lazy_data.values()) - # evaluate all the dask arrays simultaneously - evaluated_data = da.compute(*lazy_data.values(), **kwargs) + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( + *lazy_data.values(), **kwargs + ) for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data @@ -755,11 +856,11 @@ def load(self: T_Dataset, **kwargs) -> T_Dataset: return self - def __dask_tokenize__(self): + def __dask_tokenize__(self) -> object: from dask.base import normalize_token return normalize_token( - (type(self), self._variables, self._coord_names, self._attrs) + (type(self), self._variables, self._coord_names, self._attrs or None) ) def __dask_graph__(self): @@ -816,7 +917,7 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): return self._dask_postpersist, () - def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset: + def _dask_postcompute(self, results: Iterable[Variable]) -> Self: import dask variables = {} @@ -839,8 +940,8 @@ def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset ) def _dask_postpersist( - self: T_Dataset, dsk: Mapping, *, rename: Mapping[str, str] | None = None - ) -> T_Dataset: + self, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> Self: from dask import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.optimization import cull @@ -889,7 +990,7 @@ def _dask_postpersist( self._close, ) - def compute(self: T_Dataset, **kwargs) -> T_Dataset: + def compute(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return a new dataset. Unlike load, the original dataset is left unaltered. @@ -911,7 +1012,7 @@ def compute(self: T_Dataset, **kwargs) -> T_Dataset: new = self.copy(deep=False) return new.load(**kwargs) - def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: + def _persist_inplace(self, **kwargs) -> Self: """Persist all Dask arrays in memory""" # access .data to coerce everything to numpy or dask arrays lazy_data = { @@ -928,7 +1029,7 @@ def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset: return self - def persist(self: T_Dataset, **kwargs) -> T_Dataset: + def persist(self, **kwargs) -> Self: """Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask @@ -951,7 +1052,7 @@ def persist(self: T_Dataset, **kwargs) -> T_Dataset: @classmethod def _construct_direct( - cls: type[T_Dataset], + cls, variables: dict[Any, Variable], coord_names: set[Hashable], dims: dict[Any, int] | None = None, @@ -959,7 +1060,7 @@ def _construct_direct( indexes: dict[Any, Index] | None = None, encoding: dict | None = None, close: Callable[[], None] | None = None, - ) -> T_Dataset: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -978,7 +1079,7 @@ def _construct_direct( return obj def _replace( - self: T_Dataset, + self, variables: dict[Hashable, Variable] | None = None, coord_names: set[Hashable] | None = None, dims: dict[Any, int] | None = None, @@ -986,7 +1087,7 @@ def _replace( indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Fastpath constructor for internal use. Returns an object with optionally with replaced attributes. @@ -1028,13 +1129,13 @@ def _replace( return obj def _replace_with_new_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, attrs: dict[Hashable, Any] | None | Default = _default, indexes: dict[Hashable, Index] | None = None, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Replace variables with recalculated dimensions.""" dims = calculate_dimensions(variables) return self._replace( @@ -1042,13 +1143,13 @@ def _replace_with_new_dims( ) def _replace_vars_and_dims( - self: T_Dataset, + self, variables: dict[Hashable, Variable], coord_names: set | None = None, dims: dict[Hashable, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, inplace: bool = False, - ) -> T_Dataset: + ) -> Self: """Deprecated version of _replace_with_new_dims(). Unlike _replace_with_new_dims(), this method always recalculates @@ -1061,13 +1162,13 @@ def _replace_vars_and_dims( ) def _overwrite_indexes( - self: T_Dataset, + self, indexes: Mapping[Hashable, Index], variables: Mapping[Hashable, Variable] | None = None, drop_variables: list[Hashable] | None = None, drop_indexes: list[Hashable] | None = None, rename_dims: Mapping[Hashable, Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Maybe replace indexes. This function may do a lot more depending on index query @@ -1134,9 +1235,7 @@ def _overwrite_indexes( else: return replaced - def copy( - self: T_Dataset, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None - ) -> T_Dataset: + def copy(self, deep: bool = False, data: DataVars | None = None) -> Self: """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. @@ -1174,60 +1273,60 @@ def copy( ... coords={"x": ["one", "two"]}, ... ) >>> ds.copy() - + Size: 88B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds_0 = ds.copy(deep=False) >>> ds_0["foo"][0, 0] = 7 >>> ds_0 - + Size: 88B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds - + Size: 88B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds.copy(data={"foo": np.arange(6).reshape(2, 3), "bar": ["a", "b"]}) - + Size: 80B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds - + Size: 88B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) T_Dataset: + ) -> Self: if data is None: data = {} elif not utils.is_dict_like(data): @@ -1253,13 +1352,13 @@ def _copy( if keys_not_in_vars: raise ValueError( "Data must only contain variables in original " - "dataset. Extra variables: {}".format(keys_not_in_vars) + f"dataset. Extra variables: {keys_not_in_vars}" ) keys_missing_from_data = var_keys - data_keys if keys_missing_from_data: raise ValueError( "Data must contain all variables in original " - "dataset. Data is missing {}".format(keys_missing_from_data) + f"dataset. Data is missing {keys_missing_from_data}" ) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1278,13 +1377,13 @@ def _copy( return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) - def __copy__(self: T_Dataset) -> T_Dataset: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) - def as_numpy(self: T_Dataset) -> T_Dataset: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. @@ -1296,7 +1395,7 @@ def as_numpy(self: T_Dataset) -> T_Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. """ @@ -1309,7 +1408,7 @@ def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name, self.dims + self._variables, name, self.sizes ) variables[var_name] = var if ref_name in self._coord_names or ref_name in self.dims: @@ -1324,7 +1423,7 @@ def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: for v in variables.values(): needed_dims.update(v.dims) - dims = {k: self.dims[k] for k in needed_dims} + dims = {k: self.sizes[k] for k in needed_dims} # preserves ordering of coordinates for k in self._variables: @@ -1346,15 +1445,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: try: variable = self._variables[name] except KeyError: - _, name, variable = _get_virtual_variable(self._variables, name, self.dims) + _, name, variable = _get_virtual_variable(self._variables, name, self.sizes) needed_dims = set(variable.dims) coords: dict[Hashable, Variable] = {} # preserve ordering for k in self._variables: - if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: - coords[k] = self.variables[k] + if k in self._coord_names and set(self._variables[k].dims) <= needed_dims: + coords[k] = self._variables[k] indexes = filter_indexes_from_coords(self._indexes, set(coords)) @@ -1373,7 +1472,7 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) # virtual coordinates - yield HybridMappingProxy(keys=self.dims, mapping=self) + yield HybridMappingProxy(keys=self.sizes, mapping=self) def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether @@ -1390,13 +1489,20 @@ def __bool__(self) -> bool: def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) - def __array__(self, dtype=None): - raise TypeError( - "cannot directly convert an xarray.Dataset into a " - "numpy array. Instead, create an xarray.DataArray " - "first, either with indexing on the Dataset or by " - "invoking the `to_array()` method." - ) + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None # type: ignore[var-annotated,unused-ignore] + + else: + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_dataarray()` method." + ) @property def nbytes(self) -> int: @@ -1409,33 +1515,39 @@ def nbytes(self) -> int: return sum(v.nbytes for v in self.variables.values()) @property - def loc(self: T_Dataset) -> _LocIndexer[T_Dataset]: + def loc(self) -> _LocIndexer[Self]: """Attribute for location based indexing. Only supports __getitem__, and only when the key is a dict of the form {dim: labels}. """ return _LocIndexer(self) @overload - def __getitem__(self, key: Hashable) -> DataArray: - ... + def __getitem__(self, key: Hashable) -> DataArray: ... # Mapping is Iterable @overload - def __getitem__(self: T_Dataset, key: Iterable[Hashable]) -> T_Dataset: - ... + def __getitem__(self, key: Iterable[Hashable]) -> Self: ... def __getitem__( - self: T_Dataset, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] - ) -> T_Dataset | DataArray: + self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] + ) -> Self | DataArray: """Access variables or coordinates of this dataset as a :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset. Indexing with a list of names will return a new ``Dataset`` object. """ + from xarray.core.formatting import shorten_list_repr + if utils.is_dict_like(key): return self.isel(**key) if utils.hashable(key): - return self._construct_dataarray(key) + try: + return self._construct_dataarray(key) + except KeyError as e: + raise KeyError( + f"No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}" + ) from e + if utils.iterable_of_hashable(key): return self._copy_listed(key) raise ValueError(f"Unsupported key-type {type(key)}") @@ -1569,7 +1681,7 @@ def _setitem_check(self, key, value): val = np.array(val) # type conversion - new_value[name] = val.astype(var_k.dtype, copy=False) + new_value[name] = duck_array_ops.astype(val, dtype=var_k.dtype, copy=False) # check consistency of dimension sizes and dimension coordinates if isinstance(value, DataArray) or isinstance(value, Dataset): @@ -1591,7 +1703,7 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore[assignment] - def _all_compat(self, other: Dataset, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals and identical""" # some stores (e.g., scipy) do not seem to preserve order, so don't @@ -1603,7 +1715,7 @@ def compat(x: Variable, y: Variable) -> bool: self._variables, other._variables, compat=compat ) - def broadcast_equals(self, other: Dataset) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two Datasets are broadcast equal if they are equal after broadcasting all variables against each other. @@ -1611,17 +1723,66 @@ def broadcast_equals(self, other: Dataset) -> bool: the other dataset can still be broadcast equal if the the non-scalar variable is a constant. + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> a = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> a + Size: 56B + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 8B 0 + * time (time) int64 24B 0 1 2 + Data variables: + variable_name (space, time) int64 24B 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> b = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> b + Size: 56B + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 24B 0 1 2 + * space (space) int64 8B 0 + Data variables: + variable_name (time, space) int64 24B 1 2 3 + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + >>> a.equals(b) + False + + >>> a.broadcast_equals(b) + True + + >>> a2, b2 = xr.broadcast(a, b) + >>> a2.equals(b2) + True + See Also -------- Dataset.equals Dataset.identical + Dataset.broadcast """ try: return self._all_compat(other, "broadcast_equals") except (TypeError, AttributeError): return False - def equals(self, other: Dataset) -> bool: + def equals(self, other: Self) -> bool: """Two Datasets are equal if they have matching variables and coordinates, all of which are equal. @@ -1631,6 +1792,67 @@ def equals(self, other: Dataset) -> bool: This method is necessary because `v1 == v2` for ``Dataset`` does element-wise comparisons (like numpy.ndarrays). + Examples + -------- + + # 2D array with shape (1, 3) + + >>> data = np.array([[1, 2, 3]]) + >>> dataset1 = xr.Dataset( + ... {"variable_name": (("space", "time"), data)}, + ... coords={"space": [0], "time": [0, 1, 2]}, + ... ) + >>> dataset1 + Size: 56B + Dimensions: (space: 1, time: 3) + Coordinates: + * space (space) int64 8B 0 + * time (time) int64 24B 0 1 2 + Data variables: + variable_name (space, time) int64 24B 1 2 3 + + # 2D array with shape (3, 1) + + >>> data = np.array([[1], [2], [3]]) + >>> dataset2 = xr.Dataset( + ... {"variable_name": (("time", "space"), data)}, + ... coords={"time": [0, 1, 2], "space": [0]}, + ... ) + >>> dataset2 + Size: 56B + Dimensions: (time: 3, space: 1) + Coordinates: + * time (time) int64 24B 0 1 2 + * space (space) int64 8B 0 + Data variables: + variable_name (time, space) int64 24B 1 2 3 + >>> dataset1.equals(dataset2) + False + + >>> dataset1.broadcast_equals(dataset2) + True + + .equals returns True if two Datasets have the same values, dimensions, and coordinates. .broadcast_equals returns True if the + results of broadcasting two Datasets against each other have the same values, dimensions, and coordinates. + + Similar for missing values too: + + >>> ds1 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + + >>> ds2 = xr.Dataset( + ... { + ... "temperature": (["x", "y"], [[1, np.nan], [3, 4]]), + ... }, + ... coords={"x": [0, 1], "y": [0, 1]}, + ... ) + >>> ds1.equals(ds2) + True + See Also -------- Dataset.broadcast_equals @@ -1641,10 +1863,70 @@ def equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def identical(self, other: Dataset) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates. + Example + ------- + + >>> a = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> b = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "m"}, + ... ) + >>> c = xr.Dataset( + ... {"Width": ("X", [1, 2, 3])}, + ... coords={"X": [1, 2, 3]}, + ... attrs={"units": "ft"}, + ... ) + >>> a + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: m + + >>> b + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: m + + >>> c + Size: 48B + Dimensions: (X: 3) + Coordinates: + * X (X) int64 24B 1 2 3 + Data variables: + Width (X) int64 24B 1 2 3 + Attributes: + units: ft + + >>> a.equals(b) + True + + >>> a.identical(b) + True + + >>> a.equals(c) + True + + >>> a.identical(c) + False + See Also -------- Dataset.broadcast_equals @@ -1673,13 +1955,19 @@ def indexes(self) -> Indexes[pd.Index]: @property def xindexes(self) -> Indexes[Index]: - """Mapping of xarray Index objects used for label based indexing.""" + """Mapping of :py:class:`~xarray.indexes.Index` objects + used for label based indexing. + """ return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) @property def coords(self) -> DatasetCoordinates: - """Dictionary of xarray.DataArray objects corresponding to coordinate - variables + """Mapping of :py:class:`~xarray.DataArray` objects corresponding to + coordinate variables. + + See Also + -------- + Coordinates """ return DatasetCoordinates(self) @@ -1688,7 +1976,7 @@ def data_vars(self) -> DataVariables: """Dictionary of DataArray objects corresponding to data variables""" return DataVariables(self) - def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset: + def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self: """Given names of one or more variables, set them as coordinates Parameters @@ -1696,6 +1984,33 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas names : hashable or iterable of hashable Name(s) of variables in this dataset to convert into coordinates. + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "pressure": ("time", [1.013, 1.2, 3.5]), + ... "time": pd.date_range("2023-01-01", periods=3), + ... } + ... ) + >>> dataset + Size: 48B + Dimensions: (time: 3) + Coordinates: + * time (time) datetime64[ns] 24B 2023-01-01 2023-01-02 2023-01-03 + Data variables: + pressure (time) float64 24B 1.013 1.2 3.5 + + >>> dataset.set_coords("pressure") + Size: 48B + Dimensions: (time: 3) + Coordinates: + pressure (time) float64 24B 1.013 1.2 3.5 + * time (time) datetime64[ns] 24B 2023-01-01 2023-01-02 2023-01-03 + Data variables: + *empty* + + On calling ``set_coords`` , these data variables are converted to coordinates, as shown in the final dataset. + Returns ------- Dataset @@ -1719,10 +2034,10 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas return obj def reset_coords( - self: T_Dataset, + self, names: Dims = None, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Given names of coordinates, reset them to become variables Parameters @@ -1734,9 +2049,66 @@ def reset_coords( If True, remove coordinates instead of converting them into variables. + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "lat", "lon"], + ... [[[25, 26], [27, 28]], [[29, 30], [31, 32]]], + ... ), + ... "precipitation": ( + ... ["time", "lat", "lon"], + ... [[[0.5, 0.8], [0.2, 0.4]], [[0.3, 0.6], [0.7, 0.9]]], + ... ), + ... }, + ... coords={ + ... "time": pd.date_range(start="2023-01-01", periods=2), + ... "lat": [40, 41], + ... "lon": [-80, -79], + ... "altitude": 1000, + ... }, + ... ) + + # Dataset before resetting coordinates + + >>> dataset + Size: 184B + Dimensions: (time: 2, lat: 2, lon: 2) + Coordinates: + * time (time) datetime64[ns] 16B 2023-01-01 2023-01-02 + * lat (lat) int64 16B 40 41 + * lon (lon) int64 16B -80 -79 + altitude int64 8B 1000 + Data variables: + temperature (time, lat, lon) int64 64B 25 26 27 28 29 30 31 32 + precipitation (time, lat, lon) float64 64B 0.5 0.8 0.2 0.4 0.3 0.6 0.7 0.9 + + # Reset the 'altitude' coordinate + + >>> dataset_reset = dataset.reset_coords("altitude") + + # Dataset after resetting coordinates + + >>> dataset_reset + Size: 184B + Dimensions: (time: 2, lat: 2, lon: 2) + Coordinates: + * time (time) datetime64[ns] 16B 2023-01-01 2023-01-02 + * lat (lat) int64 16B 40 41 + * lon (lon) int64 16B -80 -79 + Data variables: + temperature (time, lat, lon) int64 64B 25 26 27 28 29 30 31 32 + precipitation (time, lat, lon) float64 64B 0.5 0.8 0.2 0.4 0.3 0.6 0.7 0.9 + altitude int64 8B 1000 + Returns ------- Dataset + + See Also + -------- + Dataset.set_coords """ if names is None: names = self._coord_names - set(self._indexes) @@ -1779,8 +2151,23 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, invalid_netcdf: bool = False, - ) -> bytes: - ... + ) -> bytes: ... + + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: ... # default return None @overload @@ -1795,10 +2182,10 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] | None = None, compute: Literal[True] = True, invalid_netcdf: bool = False, - ) -> None: - ... + ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -1809,11 +2196,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: - ... + ) -> Delayed | None: ... def to_netcdf( self, @@ -1873,7 +2258,9 @@ def to_netcdf( Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, - "zlib": True}, ...}`` + "zlib": True}, ...}``. + If ``encoding`` is specified the original encoding of the variables of + the dataset is ignored. The `h5netcdf` engine supports both the NetCDF4-style compression encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py @@ -1928,19 +2315,21 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, - ) -> ZarrStore: - ... + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, + ) -> ZarrStore: ... # compute=False returns dask.Delayed @overload @@ -1948,7 +2337,7 @@ def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, @@ -1956,27 +2345,32 @@ def to_zarr( compute: Literal[False], consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, - ) -> Delayed: - ... + zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, + ) -> Delayed: ... def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: Literal["w", "w-", "a", "r+", None] = None, + mode: ZarrWriteModes | None = None, synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, - region: Mapping[str, slice] | None = None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, zarr_version: int | None = None, + write_empty_chunks: bool | None = None, + chunkmanager_store_kwargs: dict[str, Any] | None = None, ) -> ZarrStore | Delayed: """Write dataset contents to a zarr group. @@ -2001,10 +2395,11 @@ def to_zarr( chunk_store : MutableMapping, str or path-like, optional Store or path to directory in local or remote file system only for Zarr array chunks. Requires zarr-python v2.4.0 or later. - mode : {"w", "w-", "a", "r+", None}, optional + mode : {"w", "w-", "a", "a-", r+", None}, optional Persistence mode: "w" means create (overwrite if exists); "w-" means create (fail if exists); - "a" means override existing variables (create if does not exist); + "a" means override all existing variables including dimension coordinates (create if does not exist); + "a-" means only append those variables that have ``append_dim``. "r+" means modify existing array *values* only (raise an error if any metadata or shapes would change). The default mode is "a" if ``append_dim`` is set. Otherwise, it is @@ -2017,12 +2412,12 @@ def to_zarr( Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` - compute : bool, optional + compute : bool, default: True If True write array data immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed to write array data later. Metadata is always updated eagerly. consolidated : bool, optional - If True, apply zarr's `consolidate_metadata` function to the store + If True, apply :func:`zarr.convenience.consolidate_metadata` after writing metadata and read existing stores with consolidated metadata; if False, do not. The default (`consolidated=None`) means write consolidated metadata and attempt to read consolidated @@ -2033,7 +2428,7 @@ def to_zarr( append_dim : hashable, optional If set, the dimension along which the data will be appended. All other dimensions on overridden variables must remain the same size. - region : dict, optional + region : dict or "auto", optional Optional mapping from dimension names to integer slices along dataset dimensions to indicate the region of existing zarr array(s) in which to write this dataset's data. For example, @@ -2041,6 +2436,12 @@ def to_zarr( that values should be written to the region ``0:1000`` along ``x`` and ``10000:11000`` along ``y``. + Can also specify ``"auto"``, in which case the existing store will be + opened and the region inferred by matching the new data's coordinates. + ``"auto"`` can be used as a single string, which will automatically infer + the region for all dimensions, or as dictionary values for specific + dimensions mixed together with explicit slices for other dimensions. + Two restrictions apply to the use of ``region``: - If ``region`` is set, _all_ variables in a dataset must have at @@ -2051,7 +2452,7 @@ def to_zarr( in with ``region``, use a separate call to ``to_zarr()`` with ``compute=False``. See "Appending to existing Zarr stores" in the reference documentation for full details. - safe_chunks : bool, optional + safe_chunks : bool, default: True If True, only allow writes to when there is a many-to-one relationship between Zarr chunks (specified in encoding) and Dask chunks. Set False to override this restriction; however, data may become corrupted @@ -2065,6 +2466,21 @@ def to_zarr( The desired zarr spec version to target (currently 2 or 3). The default of None will attempt to determine the zarr version from ``store`` when possible, otherwise defaulting to 2. + write_empty_chunks : bool or None, optional + If True, all chunks will be stored regardless of their + contents. If False, each chunk is compared to the array's fill value + prior to storing. If a chunk is uniformly equal to the fill value, then + that chunk is not be stored, and the store entry for that chunk's key + is deleted. This setting enables sparser storage, as only chunks with + non-fill-value data are stored, at the expense of overhead associated + with checking the data of each chunk. If None (default) fall back to + specification(s) in ``encoding`` or Zarr defaults. A ``ValueError`` + will be raised if the value of this (if not None) differs with + ``encoding``. + chunkmanager_store_kwargs : dict, optional + Additional keyword arguments passed on to the `ChunkManager.store` method used to store + chunked arrays. For example for a dask array additional kwargs will be passed eventually to + :py:func:`dask.array.store()`. Experimental API that should not be relied upon. Returns ------- @@ -2095,7 +2511,7 @@ def to_zarr( """ from xarray.backends.api import to_zarr - return to_zarr( # type: ignore + return to_zarr( # type: ignore[call-overload,misc] self, store=store, chunk_store=chunk_store, @@ -2110,6 +2526,8 @@ def to_zarr( region=region, safe_chunks=safe_chunks, zarr_version=zarr_version, + write_empty_chunks=write_empty_chunks, + chunkmanager_store_kwargs=chunkmanager_store_kwargs, ) def __repr__(self) -> str: @@ -2140,7 +2558,7 @@ def info(self, buf: IO | None = None) -> None: lines = [] lines.append("xarray.Dataset {") lines.append("dimensions:") - for name, size in self.dims.items(): + for name, size in self.sizes.items(): lines.append(f"\t{name} = {size} ;") lines.append("\nvariables:") for name, da in self.variables.items(): @@ -2190,16 +2608,16 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: return get_chunksizes(self.variables.values()) def chunk( - self: T_Dataset, - chunks: ( - int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + self, + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, - **chunks_kwargs: None | int | str | tuple[int, ...], - ) -> T_Dataset: + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + **chunks_kwargs: T_ChunkDim, + ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. @@ -2225,6 +2643,15 @@ def chunk( inline_array: bool, default: False Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided @@ -2240,27 +2667,47 @@ def chunk( xarray.unify_chunks dask.array.from_array """ - if chunks is None and chunks_kwargs is None: + if chunks is None and not chunks_kwargs: warnings.warn( "None value for 'chunks' is deprecated. " "It will raise an error in the future. Use instead '{}'", - category=FutureWarning, + category=DeprecationWarning, ) chunks = {} - - if isinstance(chunks, (Number, str, int)): - chunks = dict.fromkeys(self.dims, chunks) + chunks_mapping: Mapping[Any, Any] + if not isinstance(chunks, Mapping) and chunks is not None: + if isinstance(chunks, (tuple, list)): + utils.emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimensions as keys.", + category=DeprecationWarning, + ) + chunks_mapping = dict.fromkeys(self.dims, chunks) else: - chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") - bad_dims = chunks.keys() - self.dims.keys() + bad_dims = chunks_mapping.keys() - self.sizes.keys() if bad_dims: raise ValueError( - f"some chunks keys are not dimensions on this object: {bad_dims}" + f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}" ) + chunkmanager = guess_chunkmanager(chunked_array_type) + if from_array_kwargs is None: + from_array_kwargs = {} + variables = { - k: _maybe_chunk(k, v, chunks, token, lock, name_prefix) + k: _maybe_chunk( + k, + v, + chunks_mapping, + token, + lock, + name_prefix, + inline_array=inline_array, + chunked_array_type=chunkmanager, + from_array_kwargs=from_array_kwargs.copy(), + ) for k, v in self.variables.items() } return self._replace(variables) @@ -2298,14 +2745,14 @@ def _validate_indexers( if v.dtype.kind in "US": index = self._indexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): - v = v.astype("datetime64[ns]") + v = duck_array_ops.astype(v, dtype="datetime64[ns]") elif isinstance(index, CFTimeIndex): v = _parse_array_of_cftime_strings(v, index.date_type) if v.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k) + f"used for indexing: {k}" ) yield k, v @@ -2345,9 +2792,9 @@ def _get_indexers_coords_and_indexes(self, indexers): if v.dtype.kind == "b": if v.ndim != 1: # we only support 1-d boolean array raise ValueError( - "{:d}d-boolean array is used for indexing along " - "dimension {!r}, but only 1d boolean arrays are " - "supported.".format(v.ndim, k) + f"{v.ndim:d}d-boolean array is used for indexing along " + f"dimension {k!r}, but only 1d boolean arrays are " + "supported." ) # Make sure in case of boolean DataArray, its # coordinate also should be indexed. @@ -2370,12 +2817,12 @@ def _get_indexers_coords_and_indexes(self, indexers): return attached_coords, attached_indexes def isel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along the specified dimension(s). @@ -2417,10 +2864,74 @@ def isel( in this dataset, unless vectorized indexing was triggered by using an array indexer, in which case the data will be a copy. + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # A specific element from the dataset is selected + + >>> dataset.isel(student=1, test=0) + Size: 68B + Dimensions: () + Coordinates: + student >> slice_of_data = dataset.isel(student=slice(0, 2), test=slice(0, 2)) + >>> slice_of_data + Size: 168B + Dimensions: (student: 2, test: 2) + Coordinates: + * student (student) >> index_array = xr.DataArray([0, 2], dims="student") + >>> indexed_data = dataset.isel(student=index_array) + >>> indexed_data + Size: 224B + Dimensions: (student: 2, test: 3) + Coordinates: + * student (student) T_Dataset: + ) -> Self: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} @@ -2502,13 +3013,13 @@ def _isel_fancy( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -2568,6 +3079,13 @@ def sel( -------- Dataset.isel DataArray.sel + + :doc:`xarray-tutorial:intermediate/indexing/indexing` + Tutorial material on indexing with Xarray objects + + :doc:`xarray-tutorial:fundamentals/02.1_indexing_Basic` + Tutorial material on basics of indexing + """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") query_results = map_index_queries( @@ -2588,10 +3106,10 @@ def sel( return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -2605,6 +3123,50 @@ def head( The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. + Examples + -------- + >>> dates = pd.date_range(start="2023-01-01", periods=5) + >>> pageviews = [1200, 1500, 900, 1800, 2000] + >>> visitors = [800, 1000, 600, 1200, 1500] + >>> dataset = xr.Dataset( + ... { + ... "pageviews": (("date"), pageviews), + ... "visitors": (("date"), visitors), + ... }, + ... coords={"date": dates}, + ... ) + >>> busiest_days = dataset.sortby("pageviews", ascending=False) + >>> busiest_days.head() + Size: 120B + Dimensions: (date: 5) + Coordinates: + * date (date) datetime64[ns] 40B 2023-01-05 2023-01-04 ... 2023-01-03 + Data variables: + pageviews (date) int64 40B 2000 1800 1500 1200 900 + visitors (date) int64 40B 1500 1200 1000 800 600 + + # Retrieve the 3 most busiest days in terms of pageviews + + >>> busiest_days.head(3) + Size: 72B + Dimensions: (date: 3) + Coordinates: + * date (date) datetime64[ns] 24B 2023-01-05 2023-01-04 2023-01-02 + Data variables: + pageviews (date) int64 24B 2000 1800 1500 + visitors (date) int64 24B 1500 1200 1000 + + # Using a dictionary to specify the number of elements for specific dimensions + + >>> busiest_days.head({"date": 3}) + Size: 72B + Dimensions: (date: 3) + Coordinates: + * date (date) datetime64[ns] 24B 2023-01-05 2023-01-04 2023-01-02 + Data variables: + pageviews (date) int64 24B 2000 1800 1500 + visitors (date) int64 24B 1500 1200 1000 + See Also -------- Dataset.tail @@ -2634,10 +3196,10 @@ def head( return self.isel(indexers_slices) def tail( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the last `n` values of each array for the specified dimension(s). @@ -2651,6 +3213,48 @@ def tail( The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. + Examples + -------- + >>> activity_names = ["Walking", "Running", "Cycling", "Swimming", "Yoga"] + >>> durations = [30, 45, 60, 45, 60] # in minutes + >>> energies = [150, 300, 250, 400, 100] # in calories + >>> dataset = xr.Dataset( + ... { + ... "duration": (["activity"], durations), + ... "energy_expenditure": (["activity"], energies), + ... }, + ... coords={"activity": activity_names}, + ... ) + >>> sorted_dataset = dataset.sortby("energy_expenditure", ascending=False) + >>> sorted_dataset + Size: 240B + Dimensions: (activity: 5) + Coordinates: + * activity (activity) >> sorted_dataset.tail(3) + Size: 144B + Dimensions: (activity: 3) + Coordinates: + * activity (activity) >> sorted_dataset.tail({"activity": 3}) + Size: 144B + Dimensions: (activity: 3) + Coordinates: + * activity (activity) T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along every `n`-th value for the specified dimension(s) @@ -2713,28 +3317,28 @@ def thin( ... ) >>> x_ds = xr.Dataset({"foo": x}) >>> x_ds - + Size: 328B Dimensions: (x: 2, y: 13) Coordinates: - * x (x) int64 0 1 - * y (y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 + * x (x) int64 16B 0 1 + * y (y) int64 104B 0 1 2 3 4 5 6 7 8 9 10 11 12 Data variables: - foo (x, y) int64 0 1 2 3 4 5 6 7 8 9 ... 16 17 18 19 20 21 22 23 24 25 + foo (x, y) int64 208B 0 1 2 3 4 5 6 7 8 ... 17 18 19 20 21 22 23 24 25 >>> x_ds.thin(3) - + Size: 88B Dimensions: (x: 1, y: 5) Coordinates: - * x (x) int64 0 - * y (y) int64 0 3 6 9 12 + * x (x) int64 8B 0 + * y (y) int64 40B 0 3 6 9 12 Data variables: - foo (x, y) int64 0 3 6 9 12 + foo (x, y) int64 40B 0 3 6 9 12 >>> x.thin({"x": 2, "y": 5}) - + Size: 24B array([[ 0, 5, 10]]) Coordinates: - * x (x) int64 0 - * y (y) int64 0 5 10 + * x (x) int64 8B 0 + * y (y) int64 24B 0 5 10 See Also -------- @@ -2768,10 +3372,10 @@ def thin( return self.isel(indexers_slices) def broadcast_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_DataArrayOrSet, exclude: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -2791,9 +3395,7 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_Dataset", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( self, @@ -2804,7 +3406,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> Dataset: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed Dataset.""" new_variables = variables.copy() @@ -2852,18 +3454,22 @@ def _reindex_callback( new_variables, new_coord_names, indexes=new_indexes ) + reindexed.encoding = self.encoding + return reindexed def reindex_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_Xarray, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, - ) -> T_Dataset: - """Conform this object onto the indexes of another object, filling in - missing values with ``fill_value``. The default fill value is NaN. + ) -> Self: + """ + Conform this object onto the indexes of another object, for indexes which the + objects share. Missing values are filled with ``fill_value``. The default fill + value is NaN. Parameters ---------- @@ -2909,7 +3515,9 @@ def reindex_like( See Also -------- Dataset.reindex + DataArray.reindex_like align + """ return alignment.reindex_like( self, @@ -2921,14 +3529,14 @@ def reindex_like( ) def reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2994,13 +3602,13 @@ def reindex( ... coords={"station": ["boston", "nyc", "seattle", "denver"]}, ... ) >>> x - + Size: 176B Dimensions: (station: 4) Coordinates: - * station (station) >> x.indexes Indexes: station Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') @@ -3010,37 +3618,37 @@ def reindex( >>> new_index = ["boston", "austin", "seattle", "lincoln"] >>> x.reindex({"station": new_index}) - + Size: 176B Dimensions: (station: 4) Coordinates: - * station (station) >> x.reindex({"station": new_index}, fill_value=0) - + Size: 176B Dimensions: (station: 4) Coordinates: - * station (station) >> x.reindex( ... {"station": new_index}, fill_value={"temperature": 0, "pressure": 100} ... ) - + Size: 176B Dimensions: (station: 4) Coordinates: - * station (station) >> x2 - + Size: 144B Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2019-01-01 2019-01-02 ... 2019-01-06 + * time (time) datetime64[ns] 48B 2019-01-01 2019-01-02 ... 2019-01-06 Data variables: - temperature (time) float64 15.57 12.77 nan 0.3081 16.59 15.12 - pressure (time) float64 481.8 191.7 395.9 264.4 284.0 462.8 + temperature (time) float64 48B 15.57 12.77 nan 0.3081 16.59 15.12 + pressure (time) float64 48B 481.8 191.7 395.9 264.4 284.0 462.8 Suppose we decide to expand the dataset to cover a wider date range. >>> time_index2 = pd.date_range("12/29/2018", periods=10, freq="D") >>> x2.reindex({"time": time_index2}) - + Size: 240B Dimensions: (time: 10) Coordinates: - * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 2019-01-07 + * time (time) datetime64[ns] 80B 2018-12-29 2018-12-30 ... 2019-01-07 Data variables: - temperature (time) float64 nan nan nan 15.57 ... 0.3081 16.59 15.12 nan - pressure (time) float64 nan nan nan 481.8 ... 264.4 284.0 462.8 nan + temperature (time) float64 80B nan nan nan 15.57 ... 0.3081 16.59 15.12 nan + pressure (time) float64 80B nan nan nan 481.8 ... 264.4 284.0 462.8 nan The index entries that did not have a value in the original data frame (for example, `2018-12-29`) are by default filled with NaN. If desired, we can fill in the missing values using one of several options. @@ -3093,33 +3701,33 @@ def reindex( >>> x3 = x2.reindex({"time": time_index2}, method="bfill") >>> x3 - + Size: 240B Dimensions: (time: 10) Coordinates: - * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 2019-01-07 + * time (time) datetime64[ns] 80B 2018-12-29 2018-12-30 ... 2019-01-07 Data variables: - temperature (time) float64 15.57 15.57 15.57 15.57 ... 16.59 15.12 nan - pressure (time) float64 481.8 481.8 481.8 481.8 ... 284.0 462.8 nan + temperature (time) float64 80B 15.57 15.57 15.57 15.57 ... 16.59 15.12 nan + pressure (time) float64 80B 481.8 481.8 481.8 481.8 ... 284.0 462.8 nan Please note that the `NaN` value present in the original dataset (at index value `2019-01-03`) will not be filled by any of the value propagation schemes. >>> x2.where(x2.temperature.isnull(), drop=True) - + Size: 24B Dimensions: (time: 1) Coordinates: - * time (time) datetime64[ns] 2019-01-03 + * time (time) datetime64[ns] 8B 2019-01-03 Data variables: - temperature (time) float64 nan - pressure (time) float64 395.9 + temperature (time) float64 8B nan + pressure (time) float64 8B 395.9 >>> x3.where(x3.temperature.isnull(), drop=True) - + Size: 48B Dimensions: (time: 2) Coordinates: - * time (time) datetime64[ns] 2019-01-03 2019-01-07 + * time (time) datetime64[ns] 16B 2019-01-03 2019-01-07 Data variables: - temperature (time) float64 nan nan - pressure (time) float64 395.9 nan + temperature (time) float64 16B nan nan + pressure (time) float64 16B 395.9 nan This is because filling while reindexing does not look at dataset values, but only compares the original and desired indexes. If you do want to fill in the `NaN` values present in the @@ -3137,7 +3745,7 @@ def reindex( ) def _reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -3145,7 +3753,7 @@ def _reindex( fill_value: Any = xrdtypes.NA, sparse: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Same as reindex but supports sparse option. """ @@ -3161,14 +3769,14 @@ def _reindex( ) def interp( - self: T_Dataset, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", **coords_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Interpolate a Dataset onto new coordinates Performs univariate or multivariate interpolation of a Dataset onto @@ -3188,7 +3796,7 @@ def interp( If DataArrays are passed as new coordinates, their dimensions are used for the broadcasting. Missing values are skipped. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3197,7 +3805,7 @@ def interp( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. assume_sorted : bool, default: False @@ -3230,6 +3838,9 @@ def interp( scipy.interpolate.interp1d scipy.interpolate.interpn + :doc:`xarray-tutorial:fundamentals/02.2_manipulating_dimensions` + Tutorial material on manipulating data resolution using :py:func:`~xarray.Dataset.interp` + Examples -------- >>> ds = xr.Dataset( @@ -3243,38 +3854,38 @@ def interp( ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, ... ) >>> ds - + Size: 176B Dimensions: (x: 3, y: 4) Coordinates: - * x (x) int64 0 1 2 - * y (y) int64 10 12 14 16 + * x (x) int64 24B 0 1 2 + * y (y) int64 32B 10 12 14 16 Data variables: - a (x) int64 5 7 4 - b (x, y) float64 1.0 4.0 2.0 9.0 2.0 7.0 6.0 nan 6.0 nan 5.0 8.0 + a (x) int64 24B 5 7 4 + b (x, y) float64 96B 1.0 4.0 2.0 9.0 2.0 7.0 6.0 nan 6.0 nan 5.0 8.0 1D interpolation with the default method (linear): >>> ds.interp(x=[0, 0.75, 1.25, 1.75]) - + Size: 224B Dimensions: (x: 4, y: 4) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 Data variables: - a (x) float64 5.0 6.5 6.25 4.75 - b (x, y) float64 1.0 4.0 2.0 nan 1.75 6.25 ... nan 5.0 nan 5.25 nan + a (x) float64 32B 5.0 6.5 6.25 4.75 + b (x, y) float64 128B 1.0 4.0 2.0 nan 1.75 ... nan 5.0 nan 5.25 nan 1D interpolation with a different method: >>> ds.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") - + Size: 224B Dimensions: (x: 4, y: 4) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 0.0 0.75 1.25 1.75 Data variables: - a (x) float64 5.0 7.0 7.0 4.0 - b (x, y) float64 1.0 4.0 2.0 9.0 2.0 7.0 ... 6.0 nan 6.0 nan 5.0 8.0 + a (x) float64 32B 5.0 7.0 7.0 4.0 + b (x, y) float64 128B 1.0 4.0 2.0 9.0 2.0 7.0 ... nan 6.0 nan 5.0 8.0 1D extrapolation: @@ -3283,26 +3894,26 @@ def interp( ... method="linear", ... kwargs={"fill_value": "extrapolate"}, ... ) - + Size: 224B Dimensions: (x: 4, y: 4) Coordinates: - * y (y) int64 10 12 14 16 - * x (x) float64 1.0 1.5 2.5 3.5 + * y (y) int64 32B 10 12 14 16 + * x (x) float64 32B 1.0 1.5 2.5 3.5 Data variables: - a (x) float64 7.0 5.5 2.5 -0.5 - b (x, y) float64 2.0 7.0 6.0 nan 4.0 nan ... 4.5 nan 12.0 nan 3.5 nan + a (x) float64 32B 7.0 5.5 2.5 -0.5 + b (x, y) float64 128B 2.0 7.0 6.0 nan 4.0 ... nan 12.0 nan 3.5 nan 2D interpolation: >>> ds.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") - + Size: 184B Dimensions: (x: 4, y: 3) Coordinates: - * x (x) float64 0.0 0.75 1.25 1.75 - * y (y) int64 11 13 15 + * x (x) float64 32B 0.0 0.75 1.25 1.75 + * y (y) int64 24B 11 13 15 Data variables: - a (x) float64 5.0 6.5 6.25 4.75 - b (x, y) float64 2.5 3.0 nan 4.0 5.625 nan nan nan nan nan nan nan + a (x) float64 32B 5.0 6.5 6.25 4.75 + b (x, y) float64 96B 2.5 3.0 nan 4.0 5.625 ... nan nan nan nan nan """ from xarray.core import missing @@ -3330,7 +3941,7 @@ def maybe_variable(obj, k): try: return obj._variables[k] except KeyError: - return as_variable((k, range(obj.dims[k]))) + return as_variable((k, range(obj.sizes[k]))) def _validate_interp_indexer(x, new_x): # In the case of datetimes, the restrictions placed on indexers @@ -3344,7 +3955,7 @@ def _validate_interp_indexer(x, new_x): "coordinate, the coordinates to " "interpolate to must be either datetime " "strings or datetimes. " - "Instead got\n{}".format(new_x) + f"Instead got\n{new_x}" ) return x, new_x @@ -3441,12 +4052,12 @@ def _validate_interp_indexer(x, new_x): def interp_like( self, - other: Dataset | DataArray, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", - ) -> Dataset: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -3464,7 +4075,7 @@ def interp_like( names to an 1d array-like, which provides coordinates upon which to index the variables in this dataset. Missing values are skipped. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -3473,7 +4084,7 @@ def interp_like( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. assume_sorted : bool, default: False @@ -3554,7 +4165,7 @@ def _rename_vars( return variables, coord_names def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]: - return {name_dict.get(k, k): v for k, v in self.dims.items()} + return {name_dict.get(k, k): v for k, v in self.sizes.items()} def _rename_indexes( self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] @@ -3596,10 +4207,10 @@ def _rename_all( return variables, coord_names, dims, indexes def _rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Also used internally by DataArray so that the warning (if any) is raised at the right stack level. """ @@ -3614,6 +4225,9 @@ def _rename( create_dim_coord = False new_k = name_dict[k] + if k == new_k: + continue # Same name, nothing to do + if k in self.dims and new_k in self._coord_names: coord_dims = self._variables[name_dict[k]].dims if coord_dims == (k,): @@ -3638,10 +4252,10 @@ def _rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables, coordinates and dimensions. Parameters @@ -3668,10 +4282,10 @@ def rename( return self._rename(name_dict=name_dict, **names) def rename_dims( - self: T_Dataset, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed dimensions only. Parameters @@ -3700,8 +4314,8 @@ def rename_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - f"cannot rename {k!r} because it is not a " - "dimension in this dataset" + f"cannot rename {k!r} because it is not found " + f"in the dimensions of this dataset {tuple(self.dims)}" ) if v in self.dims or v in self: raise ValueError( @@ -3715,10 +4329,10 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables including coordinates Parameters @@ -3755,8 +4369,8 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self: T_Dataset, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs - ) -> T_Dataset: + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs + ) -> Self: """Returns a new object with swapped dimensions. Parameters @@ -3780,35 +4394,35 @@ def swap_dims( ... coords={"x": ["a", "b"], "y": ("x", [0, 1])}, ... ) >>> ds - + Size: 56B Dimensions: (x: 2) Coordinates: - * x (x) >> ds.swap_dims({"x": "y"}) - + Size: 56B Dimensions: (y: 2) Coordinates: - x (y) >> ds.swap_dims({"x": "z"}) - + Size: 56B Dimensions: (z: 2) Coordinates: - x (z) Dataset: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -3900,6 +4514,64 @@ def expand_dims( expanded : Dataset This object, but with additional dimension(s). + Examples + -------- + >>> dataset = xr.Dataset({"temperature": ([], 25.0)}) + >>> dataset + Size: 8B + Dimensions: () + Data variables: + temperature float64 8B 25.0 + + # Expand the dataset with a new dimension called "time" + + >>> dataset.expand_dims(dim="time") + Size: 8B + Dimensions: (time: 1) + Dimensions without coordinates: time + Data variables: + temperature (time) float64 8B 25.0 + + # 1D data + + >>> temperature_1d = xr.DataArray([25.0, 26.5, 24.8], dims="x") + >>> dataset_1d = xr.Dataset({"temperature": temperature_1d}) + >>> dataset_1d + Size: 24B + Dimensions: (x: 3) + Dimensions without coordinates: x + Data variables: + temperature (x) float64 24B 25.0 26.5 24.8 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_1d.expand_dims(dim="time", axis=0) + Size: 24B + Dimensions: (time: 1, x: 3) + Dimensions without coordinates: time, x + Data variables: + temperature (time, x) float64 24B 25.0 26.5 24.8 + + # 2D data + + >>> temperature_2d = xr.DataArray(np.random.rand(3, 4), dims=("y", "x")) + >>> dataset_2d = xr.Dataset({"temperature": temperature_2d}) + >>> dataset_2d + Size: 96B + Dimensions: (y: 3, x: 4) + Dimensions without coordinates: y, x + Data variables: + temperature (y, x) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + + # Expand the dataset with a new dimension called "time" using axis argument + + >>> dataset_2d.expand_dims(dim="time", axis=2) + Size: 96B + Dimensions: (y: 3, x: 4, time: 1) + Dimensions without coordinates: y, x, time + Data variables: + temperature (y, x, time) float64 96B 0.5488 0.7152 0.6028 ... 0.7917 0.5289 + See Also -------- DataArray.expand_dims @@ -3936,8 +4608,7 @@ def expand_dims( raise ValueError(f"Dimension {d} already exists.") if d in self._variables and not utils.is_scalar(self._variables[d]): raise ValueError( - "{dim} already exists as coordinate or" - " variable name.".format(dim=d) + f"{d} already exists as coordinate or" " variable name." ) variables: dict[Hashable, Variable] = {} @@ -3960,8 +4631,7 @@ def expand_dims( pass # Do nothing if the dimensions value is just an int else: raise TypeError( - "The value of new dimension {k} must be " - "an iterable or an int".format(k=k) + f"The value of new dimension {k} must be " "an iterable or an int" ) for k, v in self._variables.items(): @@ -4000,14 +4670,12 @@ def expand_dims( variables, coord_names=coord_names, indexes=indexes ) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> Dataset: + ) -> Self: """Set Dataset (multi-)indexes using one or more existing coordinates or variables. @@ -4043,22 +4711,22 @@ def set_index( ... ) >>> ds = xr.Dataset({"v": arr}) >>> ds - + Size: 104B Dimensions: (x: 2, y: 3) Coordinates: - * x (x) int64 0 1 - * y (y) int64 0 1 2 - a (x) int64 3 4 + * x (x) int64 16B 0 1 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 3 4 Data variables: - v (x, y) float64 1.0 1.0 1.0 1.0 1.0 1.0 + v (x, y) float64 48B 1.0 1.0 1.0 1.0 1.0 1.0 >>> ds.set_index(x="a") - + Size: 88B Dimensions: (x: 2, y: 3) Coordinates: - * x (x) int64 3 4 - * y (y) int64 0 1 2 + * x (x) int64 16B 3 4 + * y (y) int64 24B 0 1 2 Data variables: - v (x, y) float64 1.0 1.0 1.0 1.0 1.0 1.0 + v (x, y) float64 48B 1.0 1.0 1.0 1.0 1.0 1.0 See Also -------- @@ -4105,7 +4773,9 @@ def set_index( if len(var_names) == 1 and (not append or dim not in self._indexes): var_name = var_names[0] var = self._variables[var_name] - if var.dims != (dim,): + # an error with a better message will be raised for scalar variables + # when creating the PandasIndex + if var.ndim > 0 and var.dims != (dim,): raise ValueError( f"dimension mismatch: try setting an index for dimension {dim!r} with " f"variable {var_name!r} that has dimensions {var.dims}" @@ -4165,11 +4835,13 @@ def set_index( variables, coord_names=coord_names, indexes=indexes_ ) + @_deprecate_positional_args("v2023.10.0") def reset_index( - self: T_Dataset, + self, dims_or_levels: Hashable | Sequence[Hashable], + *, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Reset the specified index(es) or multi-index level(s). This legacy method is specific to pandas (multi-)indexes and @@ -4277,11 +4949,11 @@ def drop_or_convert(var_names): ) def set_xindex( - self: T_Dataset, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_Dataset: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -4389,10 +5061,10 @@ def set_xindex( ) def reorder_levels( - self: T_Dataset, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_Dataset: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -4485,7 +5157,7 @@ def _get_stack_index( if dim in self._variables: var = self._variables[dim] else: - _, _, var = _get_virtual_variable(self._variables, dim, self.dims) + _, _, var = _get_virtual_variable(self._variables, dim, self.sizes) # dummy index (only `stack_coords` will be used to construct the multi-index) stack_index = PandasIndex([0], dim) stack_coords = {dim: var} @@ -4493,12 +5165,12 @@ def _get_stack_index( return stack_index, stack_coords def _stack_once( - self: T_Dataset, + self, dims: Sequence[Hashable | ellipsis], new_dim: Hashable, index_cls: type[Index], create_index: bool | None = True, - ) -> T_Dataset: + ) -> Self: if dims == ...: raise ValueError("Please use [...] for dims, rather than just ...") if ... in dims: @@ -4512,7 +5184,7 @@ def _stack_once( if any(d in var.dims for d in dims): add_dims = [d for d in dims if d not in var.dims] vdims = list(var.dims) + add_dims - shape = [self.dims[d] for d in vdims] + shape = [self.sizes[d] for d in vdims] exp_var = var.set_dims(vdims, shape) stacked_var = exp_var.stack(**{new_dim: dims}) new_variables[name] = stacked_var @@ -4552,12 +5224,12 @@ def _stack_once( ) def stack( - self: T_Dataset, + self, dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable | ellipsis], - ) -> T_Dataset: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -4611,7 +5283,7 @@ def to_stacked_array( """Combine variables of differing dimensionality into a DataArray without broadcasting. - This method is similar to Dataset.to_array but does not broadcast the + This method is similar to Dataset.to_dataarray but does not broadcast the variables. Parameters @@ -4640,7 +5312,7 @@ def to_stacked_array( See Also -------- - Dataset.to_array + Dataset.to_dataarray Dataset.stack DataArray.to_unstacked_dataset @@ -4655,23 +5327,23 @@ def to_stacked_array( ... ) >>> data - + Size: 76B Dimensions: (x: 2, y: 3) Coordinates: - * y (y) >> data.to_stacked_array("z", sample_dims=["x"]) - + Size: 64B array([[0, 1, 2, 6], [3, 4, 5, 7]]) Coordinates: - * z (z) object MultiIndex - * variable (z) object 'a' 'a' 'a' 'b' - * y (z) object 'u' 'v' 'w' nan + * z (z) object 32B MultiIndex + * variable (z) object 32B 'a' 'a' 'a' 'b' + * y (z) object 32B 'u' 'v' 'w' nan Dimensions without coordinates: x """ @@ -4679,34 +5351,31 @@ def to_stacked_array( stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) - for variable in self: - dims = self[variable].dims - dims_include_sample_dims = set(sample_dims) <= set(dims) - if not dims_include_sample_dims: + for key, da in self.data_vars.items(): + missing_sample_dims = set(sample_dims) - set(da.dims) + if missing_sample_dims: raise ValueError( - "All variables in the dataset must contain the " - "dimensions {}.".format(dims) + "Variables in the dataset must contain all ``sample_dims`` " + f"({sample_dims!r}) but '{key}' misses {sorted(map(str, missing_sample_dims))}" ) - def ensure_stackable(val): - assign_coords = {variable_dim: val.name} - for dim in stacking_dims: - if dim not in val.dims: - assign_coords[dim] = None + def stack_dataarray(da): + # add missing dims/ coords and the name of the variable - expand_dims = set(stacking_dims).difference(set(val.dims)) - expand_dims.add(variable_dim) - # must be list for .expand_dims - expand_dims = list(expand_dims) + missing_stack_coords = {variable_dim: da.name} + for dim in set(stacking_dims) - set(da.dims): + missing_stack_coords[dim] = None + + missing_stack_dims = list(missing_stack_coords) return ( - val.assign_coords(**assign_coords) - .expand_dims(expand_dims) + da.assign_coords(**missing_stack_coords) + .expand_dims(missing_stack_dims) .stack({new_dim: (variable_dim,) + stacking_dims}) ) # concatenate the arrays - stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] + stackable_vars = [stack_dataarray(da) for da in self.data_vars.values()] data_array = concat(stackable_vars, dim=new_dim) if name is not None: @@ -4715,12 +5384,12 @@ def ensure_stackable(val): return data_array def _unstack_once( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -4755,12 +5424,12 @@ def _unstack_once( ) def _unstack_full_reindex( - self: T_Dataset, + self, dim: Hashable, index_and_vars: tuple[Index, dict[Hashable, Variable]], fill_value, sparse: bool, - ) -> T_Dataset: + ) -> Self: index, index_vars = index_and_vars variables: dict[Hashable, Variable] = {} indexes = {k: v for k, v in self._indexes.items() if k != dim} @@ -4805,12 +5474,14 @@ def _unstack_full_reindex( variables, coord_names=coord_names, indexes=indexes ) + @_deprecate_positional_args("v2023.10.0") def unstack( - self: T_Dataset, + self, dim: Dims = None, + *, fill_value: Any = xrdtypes.NA, sparse: bool = False, - ) -> T_Dataset: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -4847,10 +5518,10 @@ def unstack( else: dims = list(dim) - missing_dims = [d for d in dims if d not in self.dims] + missing_dims = set(dims) - set(self.dims) if missing_dims: raise ValueError( - f"Dataset does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" ) # each specified dimension must have exactly one multi-index @@ -4907,7 +5578,7 @@ def unstack( result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse) return result - def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: + def update(self, other: CoercibleMapping) -> Self: """Update this dataset's variables with those from another dataset. Just like :py:meth:`dict.update` this is a in-place operation. @@ -4947,14 +5618,14 @@ def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: return self._replace(inplace=True, **merge_result._asdict()) def merge( - self: T_Dataset, + self, other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), compat: CompatOptions = "no_conflicts", join: JoinOptions = "outer", fill_value: Any = xrdtypes.NA, combine_attrs: CombineAttrsOptions = "override", - ) -> T_Dataset: + ) -> Self: """Merge the arrays of two datasets into a single dataset. This method generally does not allow for overriding data, with the @@ -5058,45 +5729,149 @@ def _assert_all_in_dataset( ) def drop_vars( - self: T_Dataset, - names: Hashable | Iterable[Hashable], + self, + names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop variables from this dataset. Parameters ---------- - names : hashable or iterable of hashable - Name(s) of variables to drop. + names : Hashable or iterable of Hashable or Callable + Name(s) of variables to drop. If a Callable, this object is passed as its + only argument and its result is used. errors : {"raise", "ignore"}, default: "raise" If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the dataset are dropped and no error is raised. + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "latitude", "longitude"], + ... [[[25.5, 26.3], [27.1, 28.0]]], + ... ), + ... "humidity": ( + ... ["time", "latitude", "longitude"], + ... [[[65.0, 63.8], [58.2, 59.6]]], + ... ), + ... "wind_speed": ( + ... ["time", "latitude", "longitude"], + ... [[[10.2, 8.5], [12.1, 9.8]]], + ... ), + ... }, + ... coords={ + ... "time": pd.date_range("2023-07-01", periods=1), + ... "latitude": [40.0, 40.2], + ... "longitude": [-75.0, -74.8], + ... }, + ... ) + >>> dataset + Size: 136B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop the 'humidity' variable + + >>> dataset.drop_vars(["humidity"]) + Size: 104B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop the 'humidity', 'temperature' variables + + >>> dataset.drop_vars(["humidity", "temperature"]) + Size: 72B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Drop all indexes + + >>> dataset.drop_vars(lambda x: x.indexes) + Size: 96B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Dimensions without coordinates: time, latitude, longitude + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Attempt to drop non-existent variable with errors="ignore" + + >>> dataset.drop_vars(["pressure"], errors="ignore") + Size: 136B + Dimensions: (time: 1, latitude: 2, longitude: 2) + Coordinates: + * time (time) datetime64[ns] 8B 2023-07-01 + * latitude (latitude) float64 16B 40.0 40.2 + * longitude (longitude) float64 16B -75.0 -74.8 + Data variables: + temperature (time, latitude, longitude) float64 32B 25.5 26.3 27.1 28.0 + humidity (time, latitude, longitude) float64 32B 65.0 63.8 58.2 59.6 + wind_speed (time, latitude, longitude) float64 32B 10.2 8.5 12.1 9.8 + + Attempt to drop non-existent variable with errors="raise" + + >>> dataset.drop_vars(["pressure"], errors="raise") + Traceback (most recent call last): + ValueError: These variables cannot be found in this dataset: ['pressure'] + + Raises + ------ + ValueError + Raised if you attempt to drop a variable which is not present, and the kwarg ``errors='raise'``. + Returns ------- dropped : Dataset + See Also + -------- + DataArray.drop_vars + """ + if callable(names): + names = names(self) # the Iterable check is required for mypy if is_scalar(names) or not isinstance(names, Iterable): - names = {names} + names_set = {names} else: - names = set(names) + names_set = set(names) if errors == "raise": - self._assert_all_in_dataset(names) + self._assert_all_in_dataset(names_set) # GH6505 other_names = set() - for var in names: + for var in names_set: maybe_midx = self._indexes.get(var, None) if isinstance(maybe_midx, PandasMultiIndex): - idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim]) - idx_other_names = idx_coord_names - set(names) + idx_coord_names = set(list(maybe_midx.index.names) + [maybe_midx.dim]) + idx_other_names = idx_coord_names - set(names_set) other_names.update(idx_other_names) if other_names: - names |= set(other_names) + names_set |= set(other_names) warnings.warn( f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. " f"Please also drop the following variables: {other_names!r} to avoid an error in the future.", @@ -5104,21 +5879,21 @@ def drop_vars( stacklevel=2, ) - assert_no_index_corrupted(self.xindexes, names) + assert_no_index_corrupted(self.xindexes, names_set) - variables = {k: v for k, v in self._variables.items() if k not in names} + variables = {k: v for k, v in self._variables.items() if k not in names_set} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self._indexes.items() if k not in names} + indexes = {k: v for k, v in self._indexes.items() if k not in names_set} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) def drop_indexes( - self: T_Dataset, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -5145,7 +5920,10 @@ def drop_indexes( if errors == "raise": invalid_coords = coord_names - self._coord_names if invalid_coords: - raise ValueError(f"those coordinates don't exist: {invalid_coords}") + raise ValueError( + f"The coordinates {tuple(invalid_coords)} are not found in the " + f"dataset coordinates {tuple(self.coords.keys())}" + ) unindexed_coords = set(coord_names) - set(self._indexes) if unindexed_coords: @@ -5167,13 +5945,13 @@ def drop_indexes( return self._replace(variables=variables, indexes=indexes) def drop( - self: T_Dataset, + self, labels=None, dim=None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_Dataset: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -5187,10 +5965,9 @@ def drop( raise ValueError('errors must be either "raise" or "ignore"') if is_dict_like(labels) and not isinstance(labels, dict): - warnings.warn( - "dropping coordinates using `drop` is be deprecated; use drop_vars.", - FutureWarning, - stacklevel=2, + emit_user_level_warning( + "dropping coordinates using `drop` is deprecated; use drop_vars.", + DeprecationWarning, ) return self.drop_vars(labels, errors=errors) @@ -5200,11 +5977,13 @@ def drop( labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)): - warnings.warn( - "dropping variables using `drop` will be deprecated; using drop_vars is encouraged.", - PendingDeprecationWarning, - stacklevel=2, + emit_user_level_warning( + "dropping variables using `drop` is deprecated; use drop_vars.", + DeprecationWarning, ) + # for mypy + if is_scalar(labels): + labels = [labels] return self.drop_vars(labels, errors=errors) if dim is not None: warnings.warn( @@ -5215,16 +5994,15 @@ def drop( ) return self.drop_sel({dim: labels}, errors=errors, **labels_kwargs) - warnings.warn( - "dropping labels using `drop` will be deprecated; using drop_sel is encouraged.", - PendingDeprecationWarning, - stacklevel=2, + emit_user_level_warning( + "dropping labels using `drop` is deprecated; use `drop_sel` instead.", + DeprecationWarning, ) return self.drop_sel(labels, errors=errors) def drop_sel( - self: T_Dataset, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs - ) -> T_Dataset: + self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs + ) -> Self: """Drop index labels from this dataset. Parameters @@ -5249,29 +6027,29 @@ def drop_sel( >>> labels = ["a", "b", "c"] >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) >>> ds - + Size: 60B Dimensions: (x: 2, y: 3) Coordinates: - * y (y) >> ds.drop_sel(y=["a", "c"]) - + Size: 20B Dimensions: (x: 2, y: 1) Coordinates: - * y (y) >> ds.drop_sel(y="b") - + Size: 40B Dimensions: (x: 2, y: 2) Coordinates: - * y (y) T_Dataset: + def drop_isel(self, indexers=None, **indexers_kwargs) -> Self: """Drop index positions from this Dataset. Parameters @@ -5317,29 +6095,29 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: >>> labels = ["a", "b", "c"] >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels}) >>> ds - + Size: 60B Dimensions: (x: 2, y: 3) Coordinates: - * y (y) >> ds.drop_isel(y=[0, 2]) - + Size: 20B Dimensions: (x: 2, y: 1) Coordinates: - * y (y) >> ds.drop_isel(y=1) - + Size: 40B Dimensions: (x: 2, y: 2) Coordinates: - * y (y) T_Dataset: return ds def drop_dims( - self: T_Dataset, + self, drop_dims: str | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop dimensions and associated variables from this dataset. Parameters @@ -5393,17 +6171,17 @@ def drop_dims( missing_dims = drop_dims - set(self.dims) if missing_dims: raise ValueError( - f"Dataset does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" ) drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} return self.drop_vars(drop_vars) def transpose( - self: T_Dataset, + self, *dims: Hashable, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -5455,13 +6233,15 @@ def transpose( ds._variables[name] = var.transpose(*var_dims) return ds + @_deprecate_positional_args("v2023.10.0") def dropna( - self: T_Dataset, + self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, subset: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with dropped labels for missing values along the provided dimension. @@ -5480,6 +6260,70 @@ def dropna( Which variables to check for missing values. By default, all variables in the dataset are checked. + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "temperature": ( + ... ["time", "location"], + ... [[23.4, 24.1], [np.nan, 22.1], [21.8, 24.2], [20.5, 25.3]], + ... ) + ... }, + ... coords={"time": [1, 2, 3, 4], "location": ["A", "B"]}, + ... ) + >>> dataset + Size: 104B + Dimensions: (time: 4, location: 2) + Coordinates: + * time (time) int64 32B 1 2 3 4 + * location (location) >> dataset.dropna(dim="time") + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) >> dataset.dropna(dim="time", how="any") + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) >> dataset.dropna(dim="time", how="all") + Size: 104B + Dimensions: (time: 4, location: 2) + Coordinates: + * time (time) int64 32B 1 2 3 4 + * location (location) >> dataset.dropna(dim="time", thresh=2) + Size: 80B + Dimensions: (time: 3, location: 2) + Coordinates: + * time (time) int64 24B 1 3 4 + * location (location) = thresh @@ -5517,7 +6363,7 @@ def dropna( return self.isel({dim: mask}) - def fillna(self: T_Dataset, value: Any) -> T_Dataset: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -5550,42 +6396,42 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset: ... coords={"x": [0, 1, 2, 3]}, ... ) >>> ds - + Size: 160B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 Data variables: - A (x) float64 nan 2.0 nan 0.0 - B (x) float64 3.0 4.0 nan 1.0 - C (x) float64 nan nan nan 5.0 - D (x) float64 nan 3.0 nan 4.0 + A (x) float64 32B nan 2.0 nan 0.0 + B (x) float64 32B 3.0 4.0 nan 1.0 + C (x) float64 32B nan nan nan 5.0 + D (x) float64 32B nan 3.0 nan 4.0 Replace all `NaN` values with 0s. >>> ds.fillna(0) - + Size: 160B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 Data variables: - A (x) float64 0.0 2.0 0.0 0.0 - B (x) float64 3.0 4.0 0.0 1.0 - C (x) float64 0.0 0.0 0.0 5.0 - D (x) float64 0.0 3.0 0.0 4.0 + A (x) float64 32B 0.0 2.0 0.0 0.0 + B (x) float64 32B 3.0 4.0 0.0 1.0 + C (x) float64 32B 0.0 0.0 0.0 5.0 + D (x) float64 32B 0.0 3.0 0.0 4.0 Replace all `NaN` elements in column ‘A’, ‘B’, ‘C’, and ‘D’, with 0, 1, 2, and 3 respectively. >>> values = {"A": 0, "B": 1, "C": 2, "D": 3} >>> ds.fillna(value=values) - + Size: 160B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 32B 0 1 2 3 Data variables: - A (x) float64 0.0 2.0 0.0 0.0 - B (x) float64 3.0 4.0 1.0 1.0 - C (x) float64 2.0 2.0 2.0 5.0 - D (x) float64 3.0 3.0 3.0 4.0 + A (x) float64 32B 0.0 2.0 0.0 0.0 + B (x) float64 32B 3.0 4.0 1.0 1.0 + C (x) float64 32B 2.0 2.0 2.0 5.0 + D (x) float64 32B 3.0 3.0 3.0 4.0 """ if utils.is_dict_like(value): value_keys = getattr(value, "data_vars", value).keys() @@ -5598,7 +6444,7 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset: return out def interpolate_na( - self: T_Dataset, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -5607,7 +6453,7 @@ def interpolate_na( int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta ) = None, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -5615,7 +6461,7 @@ def interpolate_na( dim : Hashable or None, optional Specifies the dimension along which to interpolate. method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \ - "barycentric", "krog", "pchip", "spline", "akima"}, default: "linear" + "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear" String indicating which method to use for interpolation: - 'linear': linear interpolation. Additional keyword @@ -5624,15 +6470,15 @@ def interpolate_na( are passed to :py:func:`scipy.interpolate.interp1d`. If ``method='polynomial'``, the ``order`` keyword argument must also be provided. - - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their + - 'barycentric', 'krogh', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. use_coordinate : bool or Hashable, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if - eqaully-spaced along ``dim``. If True, the IndexVariable `dim` is + equally-spaced along ``dim``. If True, the IndexVariable `dim` is used. If ``use_coordinate`` is a string, it specifies the name of a - coordinate variariable to use as the index. + coordinate variable to use as the index. limit : int, default: None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. This filling is done regardless of the size of @@ -5669,6 +6515,11 @@ def interpolate_na( interpolated: Dataset Filled in Dataset. + Warning + -------- + When passing fill_value as a keyword argument with method="linear", it does not use + ``numpy.interp`` but it uses ``scipy.interpolate.interp1d``, which provides the fill_value parameter. + See Also -------- numpy.interp @@ -5686,37 +6537,37 @@ def interpolate_na( ... coords={"x": [0, 1, 2, 3, 4]}, ... ) >>> ds - + Size: 200B Dimensions: (x: 5) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 Data variables: - A (x) float64 nan 2.0 3.0 nan 0.0 - B (x) float64 3.0 4.0 nan 1.0 7.0 - C (x) float64 nan nan nan 5.0 0.0 - D (x) float64 nan 3.0 nan -1.0 4.0 + A (x) float64 40B nan 2.0 3.0 nan 0.0 + B (x) float64 40B 3.0 4.0 nan 1.0 7.0 + C (x) float64 40B nan nan nan 5.0 0.0 + D (x) float64 40B nan 3.0 nan -1.0 4.0 >>> ds.interpolate_na(dim="x", method="linear") - + Size: 200B Dimensions: (x: 5) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 Data variables: - A (x) float64 nan 2.0 3.0 1.5 0.0 - B (x) float64 3.0 4.0 2.5 1.0 7.0 - C (x) float64 nan nan nan 5.0 0.0 - D (x) float64 nan 3.0 1.0 -1.0 4.0 + A (x) float64 40B nan 2.0 3.0 1.5 0.0 + B (x) float64 40B 3.0 4.0 2.5 1.0 7.0 + C (x) float64 40B nan nan nan 5.0 0.0 + D (x) float64 40B nan 3.0 1.0 -1.0 4.0 >>> ds.interpolate_na(dim="x", method="linear", fill_value="extrapolate") - + Size: 200B Dimensions: (x: 5) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 Data variables: - A (x) float64 1.0 2.0 3.0 1.5 0.0 - B (x) float64 3.0 4.0 2.5 1.0 7.0 - C (x) float64 20.0 15.0 10.0 5.0 0.0 - D (x) float64 5.0 3.0 1.0 -1.0 4.0 + A (x) float64 40B 1.0 2.0 3.0 1.5 0.0 + B (x) float64 40B 3.0 4.0 2.5 1.0 7.0 + C (x) float64 40B 20.0 15.0 10.0 5.0 0.0 + D (x) float64 40B 5.0 3.0 1.0 -1.0 4.0 """ from xarray.core.missing import _apply_over_vars_with_dim, interp_na @@ -5732,7 +6583,7 @@ def interpolate_na( ) return new - def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -5740,8 +6591,7 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset Parameters ---------- dim : Hashable - Specifies the dimension along which to propagate values when - filling. + Specifies the dimension along which to propagate values when filling. limit : int or None, optional The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of @@ -5749,16 +6599,55 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset than 0 or None for no limit. Must be None or greater than or equal to axis length if filling along chunked axes (dimensions). + Examples + -------- + >>> time = pd.date_range("2023-01-01", periods=10, freq="D") + >>> data = np.array( + ... [1, np.nan, np.nan, np.nan, 5, np.nan, np.nan, 8, np.nan, 10] + ... ) + >>> dataset = xr.Dataset({"data": (("time",), data)}, coords={"time": time}) + >>> dataset + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan nan nan 5.0 nan nan 8.0 nan 10.0 + + # Perform forward fill (ffill) on the dataset + + >>> dataset.ffill(dim="time") + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 1.0 1.0 1.0 5.0 5.0 5.0 8.0 8.0 10.0 + + # Limit the forward filling to a maximum of 2 consecutive NaN values + + >>> dataset.ffill(dim="time", limit=2) + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 1.0 1.0 nan 5.0 5.0 5.0 8.0 8.0 10.0 + Returns ------- Dataset + + See Also + -------- + Dataset.bfill """ from xarray.core.missing import _apply_over_vars_with_dim, ffill new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -5775,16 +6664,55 @@ def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset than 0 or None for no limit. Must be None or greater than or equal to axis length if filling along chunked axes (dimensions). + Examples + -------- + >>> time = pd.date_range("2023-01-01", periods=10, freq="D") + >>> data = np.array( + ... [1, np.nan, np.nan, np.nan, 5, np.nan, np.nan, 8, np.nan, 10] + ... ) + >>> dataset = xr.Dataset({"data": (("time",), data)}, coords={"time": time}) + >>> dataset + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan nan nan 5.0 nan nan 8.0 nan 10.0 + + # filled dataset, fills NaN values by propagating values backward + + >>> dataset.bfill(dim="time") + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 5.0 5.0 5.0 5.0 8.0 8.0 8.0 10.0 10.0 + + # Limit the backward filling to a maximum of 2 consecutive NaN values + + >>> dataset.bfill(dim="time", limit=2) + Size: 160B + Dimensions: (time: 10) + Coordinates: + * time (time) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10 + Data variables: + data (time) float64 80B 1.0 nan 5.0 5.0 5.0 8.0 8.0 8.0 10.0 10.0 + Returns ------- Dataset + + See Also + -------- + Dataset.ffill """ from xarray.core.missing import _apply_over_vars_with_dim, bfill new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: + def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -5804,7 +6732,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: return out def reduce( - self: T_Dataset, + self, func: Callable, dim: Dims = None, *, @@ -5812,7 +6740,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Reduce this dataset by applying `func` along some dimension(s). Parameters @@ -5842,6 +6770,38 @@ def reduce( reduced : Dataset Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. + + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Calculate the 75th percentile of math scores for each student using np.percentile + + >>> percentile_scores = dataset.reduce(np.percentile, q=75, dim="test") + >>> percentile_scores + Size: 132B + Dimensions: (student: 3) + Coordinates: + * student (student) T_Dataset: + ) -> Self: """Apply a function to each data variable in this dataset Parameters @@ -5938,19 +6898,19 @@ def map( >>> da = xr.DataArray(np.random.randn(2, 3)) >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])}) >>> ds - + Size: 64B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Dimensions without coordinates: dim_0, dim_1, x Data variables: - foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773 - bar (x) int64 -1 2 + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 -0.9773 + bar (x) int64 16B -1 2 >>> ds.map(np.fabs) - + Size: 64B Dimensions: (dim_0: 2, dim_1: 3, x: 2) Dimensions without coordinates: dim_0, dim_1, x Data variables: - foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773 - bar (x) float64 1.0 2.0 + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773 + bar (x) float64 16B 1.0 2.0 """ if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -5965,12 +6925,12 @@ def map( return type(self)(variables, attrs=attrs) def apply( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Backward compatible implementation of ``map`` @@ -5986,10 +6946,10 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self: T_Dataset, + self, variables: Mapping[Any, Any] | None = None, **variables_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -6018,6 +6978,10 @@ def assign( possible, but you cannot reference other variables created within the same ``assign`` call. + The new assigned variables that replace existing coordinates in the + original dataset are still listed as coordinates in the returned + Dataset. + See Also -------- pandas.DataFrame.assign @@ -6035,52 +6999,64 @@ def assign( ... coords={"lat": [10, 20], "lon": [150, 160]}, ... ) >>> x - + Size: 96B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 - precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 Where the value is a callable, evaluated on dataset: >>> x.assign(temperature_f=lambda x: x.temperature_c * 9 / 5 + 32) - + Size: 128B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 - precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 - temperature_f (lat, lon) float64 51.76 57.75 53.7 51.62 + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 32B 51.76 57.75 53.7 51.62 Alternatively, the same behavior can be achieved by directly referencing an existing dataarray: >>> x.assign(temperature_f=x["temperature_c"] * 9 / 5 + 32) - + Size: 128B Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 16B 10 20 + * lon (lon) int64 16B 150 160 Data variables: - temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 - precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 - temperature_f (lat, lon) float64 51.76 57.75 53.7 51.62 + temperature_c (lat, lon) float64 32B 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 32B 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 32B 51.76 57.75 53.7 51.62 """ variables = either_dict_or_kwargs(variables, variables_kwargs, "assign") data = self.copy() + # do all calculations first... results: CoercibleMapping = data._calc_assign_results(variables) - data.coords._maybe_drop_multiindex_coords(set(results.keys())) + + # split data variables to add/replace vs. coordinates to replace + results_data_vars: dict[Hashable, CoercibleValue] = {} + results_coords: dict[Hashable, CoercibleValue] = {} + for k, v in results.items(): + if k in data._coord_names: + results_coords[k] = v + else: + results_data_vars[k] = v + # ... and then assign - data.update(results) + data.coords.update(results_coords) + data.update(results_data_vars) + return data - def to_array( + def to_dataarray( self, dim: Hashable = "variable", name: Hashable | None = None ) -> DataArray: """Convert this dataset into an xarray.DataArray @@ -6117,6 +7093,12 @@ def to_array( return DataArray._construct_direct(variable, coords, name, indexes) + def to_array( + self, dim: Hashable = "variable", name: Hashable | None = None + ) -> DataArray: + """Deprecated version of to_dataarray""" + return self.to_dataarray(dim=dim, name=name) + def _normalize_dim_order( self, dim_order: Sequence[Hashable] | None = None ) -> dict[Hashable, int]: @@ -6139,11 +7121,11 @@ def _normalize_dim_order( dim_order = list(self.dims) elif set(dim_order) != set(self.dims): raise ValueError( - "dim_order {} does not match the set of dimensions of this " - "Dataset: {}".format(dim_order, list(self.dims)) + f"dim_order {dim_order} does not match the set of dimensions of this " + f"Dataset: {list(self.dims)}" ) - ordered_dims = {k: self.dims[k] for k in dim_order} + ordered_dims = {k: self.sizes[k] for k in dim_order} return ordered_dims @@ -6163,9 +7145,9 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: if len(self.dims) == 1: return self.to_dataframe() raise ValueError( - "cannot convert Datasets with %s dimensions into " + f"cannot convert Datasets with {len(self.dims)} dimensions into " "pandas objects without changing the number of dimensions. " - "Please use Dataset.to_dataframe() instead." % len(self.dims) + "Please use Dataset.to_dataframe() instead." ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): @@ -6278,18 +7260,17 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, data) @classmethod - def from_dataframe( - cls: type[T_Dataset], dataframe: pd.DataFrame, sparse: bool = False - ) -> T_Dataset: + def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: """Convert a pandas.DataFrame into an xarray.Dataset Each column will be converted into an independent variable in the Dataset. If the dataframe's index is a MultiIndex, it will be expanded into a tensor product of one-dimensional indices (filling in missing - values with NaN). This method will produce a Dataset very similar to + values with NaN). If you rather preserve the MultiIndex use + `xr.Dataset(df)`. This method will produce a Dataset very similar to that on which the 'to_dataframe' method was called, except with possibly redundant dimensions (since all dataset variables will have - the same dimensionality) + the same dimensionality). Parameters ---------- @@ -6396,13 +7377,16 @@ def to_dask_dataframe( columns.extend(k for k in self.coords if k not in self.dims) columns.extend(self.data_vars) + ds_chunks = self.chunks + series_list = [] + df_meta = pd.DataFrame() for name in columns: try: var = self.variables[name] except KeyError: # dimension without a matching coordinate - size = self.dims[name] + size = self.sizes[name] data = da.arange(size, chunks=size, dtype=np.int64) var = Variable((name,), data) @@ -6415,8 +7399,11 @@ def to_dask_dataframe( if not is_duck_dask_array(var._data): var = var.chunk() - dask_array = var.set_dims(ordered_dims).chunk(self.chunks).data - series = dd.from_array(dask_array.reshape(-1), columns=[name]) + # Broadcast then flatten the array: + var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks) + dask_array = var_new_dims._data.reshape(-1) + + series = dd.from_dask_array(dask_array, columns=name, meta=df_meta) series_list.append(series) df = dd.concat(series_list, axis=1) @@ -6434,7 +7421,9 @@ def to_dask_dataframe( return df - def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: + def to_dict( + self, data: bool | Literal["list", "array"] = "list", encoding: bool = False + ) -> dict[str, Any]: """ Convert this dataset to a dictionary following xarray naming conventions. @@ -6445,9 +7434,14 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: Parameters ---------- - data : bool, default: True + data : bool or {"list", "array"}, default: "list" Whether to include the actual data in the dictionary. When set to - False, returns just the schema. + False, returns just the schema. If set to "array", returns data as + underlying array type. If set to "list" (or True for backwards + compatibility), returns data in lists of Python data types. Note + that for obtaining the "list" output efficiently, use + `ds.compute().to_dict(data="list")`. + encoding : bool, default: False Whether to include the Dataset's encoding in the dictionary. @@ -6465,7 +7459,7 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: d: dict = { "coords": {}, "attrs": decode_numpy_dict_values(self.attrs), - "dims": dict(self.dims), + "dims": dict(self.sizes), "data_vars": {}, } for k in self.coords: @@ -6481,7 +7475,7 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]: return d @classmethod - def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: + def from_dict(cls, d: Mapping[Any, Any]) -> Self: """Convert a dictionary into an xarray.Dataset. Parameters @@ -6509,13 +7503,13 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: ... } >>> ds = xr.Dataset.from_dict(d) >>> ds - + Size: 60B Dimensions: (t: 3) Coordinates: - * t (t) int64 0 1 2 + * t (t) int64 24B 0 1 2 Data variables: - a (t) >> d = { ... "coords": { @@ -6530,13 +7524,13 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: ... } >>> ds = xr.Dataset.from_dict(d) >>> ds - + Size: 60B Dimensions: (t: 3) Coordinates: - * t (t) int64 0 1 2 + * t (t) int64 24B 0 1 2 Data variables: - a (t) int64 10 20 30 - b (t) T_Dataset: ) try: variable_dict = { - k: (v["dims"], v["data"], v.get("attrs")) for k, v in variables + k: (v["dims"], v["data"], v.get("attrs"), v.get("encoding")) + for k, v in variables } except KeyError as e: - raise ValueError( - "cannot convert dict without the key " - "'{dims_data}'".format(dims_data=str(e.args[0])) - ) + raise ValueError(f"cannot convert dict without the key '{str(e.args[0])}'") obj = cls(variable_dict) # what if coords aren't dims? @@ -6571,7 +7563,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: return obj - def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: + def _unary_op(self, f, *args, **kwargs) -> Self: variables = {} keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: @@ -6582,7 +7574,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: else: variables[k] = f(v, *args, **kwargs) if keep_attrs: - variables[k].attrs = v._attrs + variables[k]._attrs = v._attrs attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, attrs=attrs) @@ -6594,7 +7586,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment] + self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) keep_attrs = _get_keep_attrs(default=False) @@ -6602,7 +7594,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: ds.attrs = self.attrs return ds - def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset: + def _inplace_binary_op(self, other, f) -> Self: from xarray.core.dataarray import DataArray from xarray.core.groupby import GroupBy @@ -6676,12 +7668,14 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs + @_deprecate_positional_args("v2023.10.0") def diff( - self: T_Dataset, + self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", - ) -> T_Dataset: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -6709,17 +7703,17 @@ def diff( -------- >>> ds = xr.Dataset({"foo": ("x", [5, 5, 6, 6])}) >>> ds.diff("x") - + Size: 24B Dimensions: (x: 3) Dimensions without coordinates: x Data variables: - foo (x) int64 0 1 0 + foo (x) int64 24B 0 1 0 >>> ds.diff("x", 2) - + Size: 16B Dimensions: (x: 2) Dimensions without coordinates: x Data variables: - foo (x) int64 1 -1 + foo (x) int64 16B 1 -1 See Also -------- @@ -6764,11 +7758,11 @@ def diff( return difference def shift( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = xrdtypes.NA, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -6805,16 +7799,18 @@ def shift( -------- >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}) >>> ds.shift(x=2) - + Size: 40B Dimensions: (x: 5) Dimensions without coordinates: x Data variables: - foo (x) object nan nan 'a' 'b' 'c' + foo (x) object 40B nan nan 'a' 'b' 'c' """ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") - invalid = [k for k in shifts if k not in self.dims] + invalid = tuple(k for k in shifts if k not in self.dims) if invalid: - raise ValueError(f"dimensions {invalid!r} do not exist") + raise ValueError( + f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}" + ) variables = {} for name, var in self.variables.items(): @@ -6833,11 +7829,11 @@ def shift( return self._replace(variables) def roll( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Roll this dataset by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -6872,26 +7868,28 @@ def roll( -------- >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}, coords={"x": np.arange(5)}) >>> ds.roll(x=2) - + Size: 60B Dimensions: (x: 5) Coordinates: - * x (x) int64 0 1 2 3 4 + * x (x) int64 40B 0 1 2 3 4 Data variables: - foo (x) >> ds.roll(x=2, roll_coords=True) - + Size: 60B Dimensions: (x: 5) Coordinates: - * x (x) int64 3 4 0 1 2 + * x (x) int64 40B 3 4 0 1 2 Data variables: - foo (x) T_Dataset: + ) -> Self: """ Sort object by labels or values (along an axis). @@ -6940,9 +7943,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or list of hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords/data_vars whose values are used to sort the dataset. + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -6968,27 +7972,37 @@ def sortby( ... }, ... coords={"x": ["b", "a"], "y": [1, 0]}, ... ) - >>> ds = ds.sortby("x") - >>> ds - + >>> ds.sortby("x") + Size: 88B + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x) >> ds.sortby(lambda x: -x["y"]) + Size: 88B Dimensions: (x: 2, y: 2) Coordinates: - * x (x) T_Dataset: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements for each variable @@ -7028,15 +8044,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -7050,8 +8066,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool, optional If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -7084,35 +8098,35 @@ def quantile( ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, ... ) >>> ds.quantile(0) # or ds.quantile(0, dim=...) - + Size: 16B Dimensions: () Coordinates: - quantile float64 0.0 + quantile float64 8B 0.0 Data variables: - a float64 0.7 + a float64 8B 0.7 >>> ds.quantile(0, dim="x") - + Size: 72B Dimensions: (y: 4) Coordinates: - * y (y) float64 1.0 1.5 2.0 2.5 - quantile float64 0.0 + * y (y) float64 32B 1.0 1.5 2.0 2.5 + quantile float64 8B 0.0 Data variables: - a (y) float64 0.7 4.2 2.6 1.5 + a (y) float64 32B 0.7 4.2 2.6 1.5 >>> ds.quantile([0, 0.5, 1]) - + Size: 48B Dimensions: (quantile: 3) Coordinates: - * quantile (quantile) float64 0.0 0.5 1.0 + * quantile (quantile) float64 24B 0.0 0.5 1.0 Data variables: - a (quantile) float64 0.7 3.4 9.4 + a (quantile) float64 24B 0.7 3.4 9.4 >>> ds.quantile([0, 0.5, 1], dim="x") - + Size: 152B Dimensions: (quantile: 3, y: 4) Coordinates: - * y (y) float64 1.0 1.5 2.0 2.5 - * quantile (quantile) float64 0.0 0.5 1.0 + * y (y) float64 32B 1.0 1.5 2.0 2.5 + * quantile (quantile) float64 24B 0.0 0.5 1.0 Data variables: - a (quantile, y) float64 0.7 4.2 2.6 1.5 3.6 ... 1.7 6.5 7.3 9.4 1.9 + a (quantile, y) float64 96B 0.7 4.2 2.6 1.5 3.6 ... 6.5 7.3 9.4 1.9 References ---------- @@ -7142,10 +8156,11 @@ def quantile( else: dims = set(dim) - _assert_empty( - tuple(d for d in dims if d not in self.dims), - "Dataset does not contain the dimensions: %s", - ) + invalid_dims = set(dims) - set(self.dims) + if invalid_dims: + raise ValueError( + f"Dimensions {tuple(invalid_dims)} not found in data dimensions {tuple(self.dims)}" + ) q = np.asarray(q, dtype=np.float64) @@ -7181,12 +8196,14 @@ def quantile( ) return new.assign_coords(quantile=q) + @_deprecate_positional_args("v2023.10.0") def rank( - self: T_Dataset, + self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -7221,7 +8238,9 @@ def rank( ) if dim not in self.dims: - raise ValueError(f"Dataset does not contain the dimension: {dim}") + raise ValueError( + f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}" + ) variables = {} for name, var in self.variables.items(): @@ -7238,11 +8257,11 @@ def rank( return self._replace(variables, coord_names, attrs=attrs) def differentiate( - self: T_Dataset, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions | None = None, - ) -> T_Dataset: + ) -> Self: """ Differentiate with the second order accurate central differences. @@ -7271,13 +8290,16 @@ def differentiate( from xarray.core.variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError(f"Coordinate {coord} does not exist.") + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) coord_var = self[coord].variable if coord_var.ndim != 1: raise ValueError( - "Coordinate {} must be 1 dimensional but is {}" - " dimensional".format(coord, coord_var.ndim) + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" ) dim = coord_var.dims[0] @@ -7307,10 +8329,10 @@ def differentiate( return self._replace(variables) def integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -7341,26 +8363,26 @@ def integrate( ... coords={"x": [0, 1, 2, 3], "y": ("x", [1, 7, 3, 5])}, ... ) >>> ds - + Size: 128B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 - y (x) int64 1 7 3 5 + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 Data variables: - a (x) int64 5 5 6 6 - b (x) int64 1 2 1 0 + a (x) int64 32B 5 5 6 6 + b (x) int64 32B 1 2 1 0 >>> ds.integrate("x") - + Size: 16B Dimensions: () Data variables: - a float64 16.5 - b float64 3.5 + a float64 8B 16.5 + b float64 8B 3.5 >>> ds.integrate("y") - + Size: 16B Dimensions: () Data variables: - a float64 20.0 - b float64 4.0 + a float64 8B 20.0 + b float64 8B 4.0 """ if not isinstance(coord, (list, tuple)): coord = (coord,) @@ -7373,13 +8395,16 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): from xarray.core.variable import Variable if coord not in self.variables and coord not in self.dims: - raise ValueError(f"Coordinate {coord} does not exist.") + variables_and_dims = tuple(set(self.variables.keys()).union(self.dims)) + raise ValueError( + f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}." + ) coord_var = self[coord].variable if coord_var.ndim != 1: raise ValueError( - "Coordinate {} must be 1 dimensional but is {}" - " dimensional".format(coord, coord_var.ndim) + f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}" + " dimensional" ) dim = coord_var.dims[0] @@ -7423,10 +8448,10 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): ) def cumulative_integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -7461,32 +8486,32 @@ def cumulative_integrate( ... coords={"x": [0, 1, 2, 3], "y": ("x", [1, 7, 3, 5])}, ... ) >>> ds - + Size: 128B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 - y (x) int64 1 7 3 5 + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 Data variables: - a (x) int64 5 5 6 6 - b (x) int64 1 2 1 0 + a (x) int64 32B 5 5 6 6 + b (x) int64 32B 1 2 1 0 >>> ds.cumulative_integrate("x") - + Size: 128B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 - y (x) int64 1 7 3 5 + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 Data variables: - a (x) float64 0.0 5.0 10.5 16.5 - b (x) float64 0.0 1.5 3.0 3.5 + a (x) float64 32B 0.0 5.0 10.5 16.5 + b (x) float64 32B 0.0 1.5 3.0 3.5 >>> ds.cumulative_integrate("y") - + Size: 128B Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 - y (x) int64 1 7 3 5 + * x (x) int64 32B 0 1 2 3 + y (x) int64 32B 1 7 3 5 Data variables: - a (x) float64 0.0 30.0 8.0 20.0 - b (x) float64 0.0 9.0 3.0 4.0 + a (x) float64 32B 0.0 30.0 8.0 20.0 + b (x) float64 32B 0.0 9.0 3.0 4.0 """ if not isinstance(coord, (list, tuple)): coord = (coord,) @@ -7498,7 +8523,7 @@ def cumulative_integrate( return result @property - def real(self: T_Dataset) -> T_Dataset: + def real(self) -> Self: """ The real part of each data variable. @@ -7509,7 +8534,7 @@ def real(self: T_Dataset) -> T_Dataset: return self.map(lambda x: x.real, keep_attrs=True) @property - def imag(self: T_Dataset) -> T_Dataset: + def imag(self) -> Self: """ The imaginary part of each data variable. @@ -7521,7 +8546,7 @@ def imag(self: T_Dataset) -> T_Dataset: plot = utils.UncachedAccessor(DatasetPlotAccessor) - def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: + def filter_by_attrs(self, **kwargs) -> Self: """Returns a ``Dataset`` with variables that match specific conditions. Can pass in ``key=value`` or ``key=callable``. A Dataset is returned @@ -7574,32 +8599,32 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: Get variables matching a specific standard_name: >>> ds.filter_by_attrs(standard_name="convective_precipitation_flux") - + Size: 192B Dimensions: (x: 2, y: 2, time: 3) Coordinates: - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Data variables: - precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 + precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805 Get all variables that have a standard_name attribute: >>> standard_name = lambda v: v is not None >>> ds.filter_by_attrs(standard_name=standard_name) - + Size: 288B Dimensions: (x: 2, y: 2, time: 3) Coordinates: - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 - lat (x, y) float64 42.25 42.21 42.63 42.59 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 - reference_time datetime64[ns] 2014-09-05 + lon (x, y) float64 32B -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 32B 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 24B 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 8B 2014-09-05 Dimensions without coordinates: x, y Data variables: - temperature (x, y, time) float64 29.11 18.2 22.83 ... 18.28 16.15 26.63 - precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 + temperature (x, y, time) float64 96B 29.11 18.2 22.83 ... 16.15 26.63 + precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805 """ selection = [] @@ -7616,7 +8641,7 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: selection.append(var_name) return self[selection] - def unify_chunks(self: T_Dataset) -> T_Dataset: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this Dataset. Returns @@ -7688,6 +8713,10 @@ def map_blocks( dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks xarray.DataArray.map_blocks + :doc:`xarray-tutorial:advanced/map_blocks/map_blocks` + Advanced Tutorial on map_blocks with dask + + Examples -------- Calculate an anomaly from climatology using ``.groupby()``. Using @@ -7699,7 +8728,7 @@ def map_blocks( ... clim = gb.mean(dim="time") ... return gb - clim ... - >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) >>> np.random.seed(123) >>> array = xr.DataArray( @@ -7709,13 +8738,13 @@ def map_blocks( ... ).chunk() >>> ds = xr.Dataset({"a": array}) >>> ds.map_blocks(calculate_anomaly, template=ds).compute() - + Size: 576B Dimensions: (time: 24) Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 Data variables: - a (time) float64 0.1289 0.1132 -0.0856 ... 0.2287 0.1906 -0.05901 + a (time) float64 192B 0.1289 0.1132 -0.0856 ... 0.1906 -0.05901 Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments to the function being applied in ``xr.map_blocks()``: @@ -7725,20 +8754,20 @@ def map_blocks( ... kwargs={"groupby_type": "time.year"}, ... template=ds, ... ) - + Size: 576B Dimensions: (time: 24) Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 dask.array + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array Data variables: - a (time) float64 dask.array + a (time) float64 192B dask.array """ from xarray.core.parallel import map_blocks return map_blocks(func, self, args, kwargs, template) def polyfit( - self: T_Dataset, + self, dim: Hashable, deg: int, skipna: bool | None = None, @@ -7746,7 +8775,7 @@ def polyfit( w: Hashable | Any = None, full: bool = False, cov: bool | Literal["unscaled"] = False, - ) -> T_Dataset: + ) -> Self: """ Least squares polynomial fit. @@ -7878,13 +8907,13 @@ def polyfit( scale_da = scale if w is not None: - rhs *= w[:, np.newaxis] + rhs = rhs * w[:, np.newaxis] with warnings.catch_warnings(): if full: # Copy np.polyfit behavior - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) else: # Raise only once per variable - warnings.simplefilter("once", np.RankWarning) + warnings.simplefilter("once", RankWarning) coeffs, residuals = duck_array_ops.least_squares( lhs, rhs.data, rcond=rcond, skipna=skipna_da @@ -7934,13 +8963,12 @@ def polyfit( return type(self)(data_vars=variables, attrs=self.attrs.copy()) def pad( - self: T_Dataset, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", - stat_length: int - | tuple[int, int] - | Mapping[Any, tuple[int, int]] - | None = None, + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, constant_values: ( float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None ) = None, @@ -7948,7 +8976,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Pad this dataset along one or more dimensions. .. warning:: @@ -8054,11 +9082,11 @@ def pad( -------- >>> ds = xr.Dataset({"foo": ("x", range(5))}) >>> ds.pad(x=(1, 2)) - + Size: 64B Dimensions: (x: 8) Dimensions without coordinates: x Data variables: - foo (x) float64 nan 0.0 1.0 2.0 3.0 4.0 nan nan + foo (x) float64 64B nan 0.0 1.0 2.0 3.0 4.0 nan nan """ pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad") @@ -8119,13 +9147,15 @@ def pad( attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) + @_deprecate_positional_args("v2023.10.0") def idxmin( - self: T_Dataset, + self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -8174,37 +9204,37 @@ def idxmin( >>> array2 = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, ... ) >>> ds = xr.Dataset({"int": array1, "float": array2}) >>> ds.min(dim="x") - + Size: 56B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int int64 -2 - float (y) float64 -2.0 -4.0 1.0 + int int64 8B -2 + float (y) float64 24B -2.0 -4.0 1.0 >>> ds.argmin(dim="x") - + Size: 56B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int int64 4 - float (y) int64 4 0 2 + int int64 8B 4 + float (y) int64 24B 4 0 2 >>> ds.idxmin(dim="x") - + Size: 52B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int T_Dataset: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -8271,37 +9303,37 @@ def idxmax( >>> array2 = xr.DataArray( ... [ ... [2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... [-4.0, np.nan, 2.0, np.nan, -2.0], + ... [np.nan, np.nan, 1.0, np.nan, np.nan], ... ], ... dims=["y", "x"], ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, ... ) >>> ds = xr.Dataset({"int": array1, "float": array2}) >>> ds.max(dim="x") - + Size: 56B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int int64 2 - float (y) float64 2.0 2.0 1.0 + int int64 8B 2 + float (y) float64 24B 2.0 2.0 1.0 >>> ds.argmax(dim="x") - + Size: 56B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int int64 1 - float (y) int64 0 2 2 + int int64 8B 1 + float (y) int64 24B 0 2 2 >>> ds.idxmax(dim="x") - + Size: 52B Dimensions: (y: 3) Coordinates: - * y (y) int64 -1 0 1 + * y (y) int64 24B -1 0 1 Data variables: - int T_Dataset: + def argmin(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -8341,8 +9373,52 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: ------- result : Dataset + Examples + -------- + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 79], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [39, 96, 78]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Indices of the minimum values along the 'student' dimension are calculated + + >>> argmin_indices = dataset.argmin(dim="student") + + >>> min_score_in_math = dataset["student"].isel( + ... student=argmin_indices["math_scores"] + ... ) + >>> min_score_in_math + Size: 84B + array(['Bob', 'Bob', 'Alice'], dtype='>> min_score_in_english = dataset["student"].isel( + ... student=argmin_indices["english_scores"] + ... ) + >>> min_score_in_english + Size: 84B + array(['Charlie', 'Bob', 'Charlie'], dtype=' T_Dataset: "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -8400,6 +9476,39 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: ------- result : Dataset + Examples + -------- + + >>> dataset = xr.Dataset( + ... { + ... "math_scores": ( + ... ["student", "test"], + ... [[90, 85, 92], [78, 80, 85], [95, 92, 98]], + ... ), + ... "english_scores": ( + ... ["student", "test"], + ... [[88, 90, 92], [75, 82, 79], [93, 96, 91]], + ... ), + ... }, + ... coords={ + ... "student": ["Alice", "Bob", "Charlie"], + ... "test": ["Test 1", "Test 2", "Test 3"], + ... }, + ... ) + + # Indices of the maximum values along the 'student' dimension are calculated + + >>> argmax_indices = dataset.argmax(dim="test") + + >>> argmax_indices + Size: 132B + Dimensions: (student: 3) + Coordinates: + * student (student) T_Dataset: "Dataset.argmin() with a sequence or ... for dim" ) + def eval( + self, + statement: str, + *, + parser: QueryParserOptions = "pandas", + ) -> Self | T_DataArray: + """ + Calculate an expression supplied as a string in the context of the dataset. + + This is currently experimental; the API may change particularly around + assignments, which currently returnn a ``Dataset`` with the additional variable. + Currently only the ``python`` engine is supported, which has the same + performance as executing in python. + + Parameters + ---------- + statement : str + String containing the Python-like expression to evaluate. + + Returns + ------- + result : Dataset or DataArray, depending on whether ``statement`` contains an + assignment. + + Examples + -------- + >>> ds = xr.Dataset( + ... {"a": ("x", np.arange(0, 5, 1)), "b": ("x", np.linspace(0, 1, 5))} + ... ) + >>> ds + Size: 80B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 + + >>> ds.eval("a + b") + Size: 40B + array([0. , 1.25, 2.5 , 3.75, 5. ]) + Dimensions without coordinates: x + + >>> ds.eval("c = a + b") + Size: 120B + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 + c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 + """ + + return pd.eval( + statement, + resolvers=[self], + target=self, + parser=parser, + # Because numexpr returns a numpy array, using that engine results in + # different behavior. We'd be very open to a contribution handling this. + engine="python", + ) + def query( - self: T_Dataset, + self, queries: Mapping[Any, Any] | None = None, parser: QueryParserOptions = "pandas", engine: QueryEngineOptions = None, missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Return a new dataset with each array indexed along the specified dimension(s), where the indexers are given as strings containing Python expressions to be evaluated against the data variables in the @@ -8495,19 +9666,19 @@ def query( >>> b = np.linspace(0, 1, 5) >>> ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) >>> ds - + Size: 80B Dimensions: (x: 5) Dimensions without coordinates: x Data variables: - a (x) int64 0 1 2 3 4 - b (x) float64 0.0 0.25 0.5 0.75 1.0 + a (x) int64 40B 0 1 2 3 4 + b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 >>> ds.query(x="a > 2") - + Size: 32B Dimensions: (x: 2) Dimensions without coordinates: x Data variables: - a (x) int64 3 4 - b (x) float64 0.75 1.0 + a (x) int64 16B 3 4 + b (x) float64 16B 0.75 1.0 """ # allow queries to be given either as a dict or as kwargs @@ -8529,16 +9700,17 @@ def query( return self.isel(indexers, missing_dims=missing_dims) def curvefit( - self: T_Dataset, + self, coords: str | DataArray | Iterable[str | DataArray], func: Callable[..., Any], reduce_dims: Dims = None, skipna: bool = True, - p0: dict[str, Any] | None = None, - bounds: dict[str, Any] | None = None, + p0: Mapping[str, float | DataArray] | None = None, + bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, - ) -> T_Dataset: + ) -> Self: """ Curve fitting optimization for arbitrary functions. @@ -8566,17 +9738,25 @@ def curvefit( Whether to skip missing values when fitting. Default is True. p0 : dict-like, optional Optional dictionary of parameter names to initial guesses passed to the - `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will - be assigned initial values following the default scipy behavior. + `curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately + broadcast to the coordinates of the array. If none or only some parameters are + passed, the rest will be assigned initial values following the default scipy + behavior. bounds : dict-like, optional - Optional dictionary of parameter names to bounding values passed to the - `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest - will be unbounded following the default scipy behavior. + Optional dictionary of parameter names to tuples of bounding values passed to the + `curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be + appropriately broadcast to the coordinates of the array. If none or only some + parameters are passed, the rest will be unbounded following the default scipy + behavior. param_names : sequence of hashable, optional Sequence of names for the fittable parameters of `func`. If not supplied, this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. @@ -8638,29 +9818,56 @@ def curvefit( "in fitting on scalar data." ) + # Check that initial guess and bounds only contain coordinates that are in preserved_dims + for param, guess in p0.items(): + if isinstance(guess, DataArray): + unexpected = set(guess.dims) - set(preserved_dims) + if unexpected: + raise ValueError( + f"Initial guess for '{param}' has unexpected dimensions " + f"{tuple(unexpected)}. It should only have dimensions that are in data " + f"dimensions {preserved_dims}." + ) + for param, (lb, ub) in bounds.items(): + for label, bound in zip(("Lower", "Upper"), (lb, ub)): + if isinstance(bound, DataArray): + unexpected = set(bound.dims) - set(preserved_dims) + if unexpected: + raise ValueError( + f"{label} bound for '{param}' has unexpected dimensions " + f"{tuple(unexpected)}. It should only have dimensions that are in data " + f"dimensions {preserved_dims}." + ) + + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + # Broadcast all coords with each other coords_ = broadcast(*coords_) coords_ = [ coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_ ] + n_coords = len(coords_) params, func_args = _get_func_args(func, param_names) param_defaults, bounds_defaults = _initialize_curvefit_params( params, p0, bounds, func_args ) n_params = len(params) - kwargs.setdefault("p0", [param_defaults[p] for p in params]) - kwargs.setdefault( - "bounds", - [ - [bounds_defaults[p][0] for p in params], - [bounds_defaults[p][1] for p in params], - ], - ) - def _wrapper(Y, *coords_, **kwargs): + def _wrapper(Y, *args, **kwargs): # Wrap curve_fit with raveled coordinates and pointwise NaN handling - x = np.vstack([c.ravel() for c in coords_]) + # *args contains: + # - the coordinates + # - initial guess + # - lower bounds + # - upper bounds + coords__ = args[:n_coords] + p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params] + lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params] + ub = args[n_coords + 2 * n_params :] + + x = np.vstack([c.ravel() for c in coords__]) y = Y.ravel() if skipna: mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0) @@ -8671,7 +9878,15 @@ def _wrapper(Y, *coords_, **kwargs): pcov = np.full([n_params, n_params], np.nan) return popt, pcov x = np.squeeze(x) - popt, pcov = curve_fit(func, x, y, **kwargs) + + try: + popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) + except RuntimeError: + if errors == "raise": + raise + popt = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + return popt, pcov result = type(self)() @@ -8681,13 +9896,21 @@ def _wrapper(Y, *coords_, **kwargs): else: name = f"{str(name)}_" + input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)] + input_core_dims.extend( + [[] for _ in range(3 * n_params)] + ) # core_dims for p0 and bounds + popt, pcov = apply_ufunc( _wrapper, da, *coords_, + *param_defaults.values(), + *[b[0] for b in bounds_defaults.values()], + *[b[1] for b in bounds_defaults.values()], vectorize=True, dask="parallelized", - input_core_dims=[reduce_dims_ for d in range(len(coords_) + 1)], + input_core_dims=input_core_dims, output_core_dims=[["param"], ["cov_i", "cov_j"]], dask_gufunc_kwargs={ "output_sizes": { @@ -8710,11 +9933,13 @@ def _wrapper(Y, *coords_, **kwargs): return result + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( - self: T_Dataset, + self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", - ) -> T_Dataset: + ) -> Self: """Returns a new Dataset with duplicate dimension values removed. Parameters @@ -8746,19 +9971,21 @@ def drop_duplicates( missing_dims = set(dims) - set(self.dims) if missing_dims: - raise ValueError(f"'{missing_dims}' not found in dimensions") + raise ValueError( + f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}" + ) indexes = {dim: ~self.get_index(dim).duplicated(keep=keep) for dim in dims} return self.isel(indexes) def convert_calendar( - self: T_Dataset, + self, calendar: CFCalendar, dim: Hashable = "time", align_on: Literal["date", "year", None] = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Convert the Dataset to another calendar. Only converts the individual timestamps, does not modify any data except @@ -8875,10 +10102,10 @@ def convert_calendar( ) def interp_calendar( - self: T_Dataset, + self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: Hashable = "time", - ) -> T_Dataset: + ) -> Self: """Interpolates the Dataset to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal @@ -8908,7 +10135,7 @@ def interp_calendar( def groupby( self, group: Hashable | DataArray | IndexVariable, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -8936,28 +10163,36 @@ def groupby( -------- :ref:`groupby` Users guide explanation of how to group and bin data. + + :doc:`xarray-tutorial:intermediate/01-high-level-computation-patterns` + Tutorial on :py:func:`~xarray.Dataset.Groupby` for windowed computation. + + :doc:`xarray-tutorial:fundamentals/03.2_groupby_with_xarray` + Tutorial on :py:func:`~xarray.Dataset.Groupby` demonstrating reductions, transformation and comparison with :py:func:`~xarray.Dataset.resample`. + Dataset.groupby_bins DataArray.groupby core.groupby.DatasetGroupBy pandas.DataFrame.groupby + Dataset.coarsen Dataset.resample DataArray.resample """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import ( + DatasetGroupBy, + ResolvedGrouper, + UniqueGrouper, + _validate_groupby_squeeze, + ) - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + _validate_groupby_squeeze(squeeze) + rgrouper = ResolvedGrouper(UniqueGrouper(), group, self) return DatasetGroupBy( - self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -8968,7 +10203,7 @@ def groupby_bins( labels: ArrayLike | None = None, precision: int = 3, include_lowest: bool = False, - squeeze: bool = True, + squeeze: bool | None = None, restore_coord_dims: bool = False, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -9028,14 +10263,16 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import DatasetGroupBy + from xarray.core.groupby import ( + BinGrouper, + DatasetGroupBy, + ResolvedGrouper, + _validate_groupby_squeeze, + ) - return DatasetGroupBy( - self, - group, - squeeze=squeeze, + _validate_groupby_squeeze(squeeze) + grouper = BinGrouper( bins=bins, - restore_coord_dims=restore_coord_dims, cut_kwargs={ "right": right, "labels": labels, @@ -9043,6 +10280,14 @@ def groupby_bins( "include_lowest": include_lowest, }, ) + rgrouper = ResolvedGrouper(grouper, group, self) + + return DatasetGroupBy( + self, + (rgrouper,), + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) def weighted(self, weights: DataArray) -> DatasetWeighted: """ @@ -9067,6 +10312,13 @@ def weighted(self, weights: DataArray) -> DatasetWeighted: See Also -------- DataArray.weighted + + :ref:`comput.weighted` + User guide on weighted array reduction using :py:func:`~xarray.Dataset.weighted` + + :doc:`xarray-tutorial:fundamentals/03.4_weighted` + Tutorial on Weighted Reduction using :py:func:`~xarray.Dataset.weighted` + """ from xarray.core.weighted import DatasetWeighted @@ -9104,8 +10356,9 @@ def rolling( See Also -------- - core.rolling.DatasetRolling + Dataset.cumulative DataArray.rolling + core.rolling.DatasetRolling """ from xarray.core.rolling import DatasetRolling @@ -9114,6 +10367,51 @@ def rolling( self, dim, min_periods=min_periods, center=center, pad=pad ) + def cumulative( + self, + dim: str | Iterable[Hashable], + min_periods: int = 1, + ) -> DatasetRolling: + """ + Accumulating object for Datasets + + Parameters + ---------- + dims : iterable of hashable + The name(s) of the dimensions to create the cumulative window along + min_periods : int, default: 1 + Minimum number of observations in window required to have a value + (otherwise result is NA). The default is 1 (note this is different + from ``Rolling``, whose default is the size of the window). + + Returns + ------- + core.rolling.DatasetRolling + + See Also + -------- + Dataset.rolling + DataArray.cumulative + core.rolling.DatasetRolling + """ + from xarray.core.rolling import DatasetRolling + + if isinstance(dim, str): + if dim not in self.dims: + raise ValueError( + f"Dimension {dim} not found in data dimensions: {self.dims}" + ) + dim = {dim: self.sizes[dim]} + else: + missing_dims = set(dim) - set(self.dims) + if missing_dims: + raise ValueError( + f"Dimensions {missing_dims} not found in data dimensions: {self.dims}" + ) + dim = {d: self.sizes[d] for d in dim} + + return DatasetRolling(self, dim, min_periods=min_periods, center=False) + def coarsen( self, dim: Mapping[Any, int] | None = None, @@ -9146,6 +10444,16 @@ def coarsen( -------- core.rolling.DatasetCoarsen DataArray.coarsen + + :ref:`reshape.coarsen` + User guide describing :py:func:`~xarray.Dataset.coarsen` + + :ref:`compute.coarsen` + User guide on block arrgragation :py:func:`~xarray.Dataset.coarsen` + + :doc:`xarray-tutorial:fundamentals/03.3_windowed` + Tutorial on windowed computation using :py:func:`~xarray.Dataset.coarsen` + """ from xarray.core.rolling import DatasetCoarsen @@ -9167,7 +10475,6 @@ def resample( base: int | None = None, offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", - keep_attrs: bool | None = None, loffset: datetime.timedelta | str | None = None, restore_coord_dims: bool | None = None, **indexer_kwargs: str, @@ -9209,6 +10516,12 @@ def resample( loffset : timedelta or str, optional Offset used to adjust the resampled time labels. Some pandas date offset strings are supported. + + .. deprecated:: 2023.03.0 + Following pandas, the ``loffset`` parameter is deprecated in favor + of using time offset arithmetic, and will be removed in a future + version of xarray. + restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -9244,7 +10557,6 @@ def resample( base=base, offset=offset, origin=origin, - keep_attrs=keep_attrs, loffset=loffset, restore_coord_dims=restore_coord_dims, **indexer_kwargs, diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 4d8583cfe65..ccf84146819 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from typing import Any import numpy as np @@ -37,14 +38,14 @@ def __eq__(self, other): # instead of following NumPy's own type-promotion rules. These type promotion # rules match pandas instead. For reference, see the NumPy type hierarchy: # https://numpy.org/doc/stable/reference/arrays.scalars.html -PROMOTE_TO_OBJECT = [ - {np.number, np.character}, # numpy promotes to character - {np.bool_, np.character}, # numpy promotes to character - {np.bytes_, np.unicode_}, # numpy promotes to unicode -] +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) -def maybe_promote(dtype): +def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: """Simpler equivalent of pandas.core.common._maybe_promote Parameters @@ -57,24 +58,33 @@ def maybe_promote(dtype): fill_value : Valid missing value for the promoted dtype. """ # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any if np.issubdtype(dtype, np.floating): + dtype_ = dtype fill_value = np.nan elif np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") + dtype_ = dtype elif np.issubdtype(dtype, np.integer): - dtype = np.float32 if dtype.itemsize <= 2 else np.float64 + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype fill_value = np.datetime64("NaT") else: - dtype = object + dtype_ = object fill_value = np.nan - return np.dtype(dtype), fill_value + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} @@ -156,7 +166,9 @@ def is_datetime_like(dtype): return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) -def result_type(*arrays_and_dtypes): +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype: """Like np.result_type, but with type promotion rules matching pandas. Examples of changed behavior: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 84e66803fe8..ef497e78ebf 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -3,12 +3,14 @@ Currently, this means Dask or NumPy arrays. None of these functions should accept or return xarray objects. """ + from __future__ import annotations import contextlib import datetime import inspect import warnings +from functools import partial from importlib import import_module import numpy as np @@ -17,7 +19,7 @@ from numpy import any as array_any # noqa from numpy import ( # noqa around, # noqa - einsum, + full_like, gradient, isclose, isin, @@ -26,15 +28,28 @@ tensordot, transpose, unravel_index, - zeros_like, # noqa ) from numpy import concatenate as _concatenate from numpy.lib.stride_tricks import sliding_window_view # noqa +from packaging.version import Version from xarray.core import dask_array_ops, dtypes, nputils -from xarray.core.nputils import nanfirst, nanlast -from xarray.core.pycompat import array_type, is_duck_dask_array -from xarray.core.utils import is_duck_array, module_available +from xarray.core.options import OPTIONS +from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available +from xarray.namedarray import pycompat +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import array_type, is_chunked_array + +# remove once numpy 2.0 is the oldest supported version +if module_available("numpy", minversion="2.0.0.dev0"): + from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore] + normalize_axis_index, + ) +else: + from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore] + normalize_axis_index, + ) + dask_available = module_available("dask") @@ -46,6 +61,17 @@ def get_array_namespace(x): return np +def einsum(*args, **kwargs): + from xarray.core.options import OPTIONS + + if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"): + import opt_einsum + + return opt_einsum.contract(*args, **kwargs) + else: + return np.einsum(*args, **kwargs) + + def _dask_or_eager_func( name, eager_module=np, @@ -127,7 +153,7 @@ def isnull(data): return xp.isnan(data) elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): # these types cannot represent missing values - return zeros_like(data, dtype=bool) + return full_like(data, dtype=bool, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray): @@ -182,6 +208,9 @@ def cumulative_trapezoid(y, x, axis): def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) return xp.astype(data, dtype, **kwargs) return data.astype(dtype, **kwargs) @@ -192,7 +221,10 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays): + array_type_cupy = array_type("cupy") + if array_type_cupy and any( + isinstance(x, array_type_cupy) for x in scalars_or_arrays + ): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] @@ -315,7 +347,10 @@ def fillna(data, other): def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - if hasattr(arrays[0], "__array_namespace__"): + # TODO: remove the additional check once `numpy` adds `concat` to its array namespace + if hasattr(arrays[0], "__array_namespace__") and not isinstance( + arrays[0], np.ndarray + ): xp = get_array_namespace(arrays[0]) return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) @@ -332,6 +367,10 @@ def reshape(array, shape): return xp.reshape(array, shape) +def ravel(array): + return reshape(array, (-1,)) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -358,7 +397,7 @@ def f(values, axis=None, skipna=None, **kwargs): values = asarray(values) if coerce_strings and values.dtype.kind in "SU": - values = values.astype(object) + values = astype(values, object) func = None if skipna or (skipna is None and values.dtype.kind in "cfO"): @@ -640,10 +679,10 @@ def first(values, axis, skipna=None): """Return the first non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN - if is_duck_dask_array(values): - return dask_array_ops.nanfirst(values, axis) + if is_chunked_array(values): + return chunked_nanfirst(values, axis) else: - return nanfirst(values, axis) + return nputils.nanfirst(values, axis) return take(values, 0, axis=axis) @@ -651,10 +690,10 @@ def last(values, axis, skipna=None): """Return the last non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN - if is_duck_dask_array(values): - return dask_array_ops.nanlast(values, axis) + if is_chunked_array(values): + return chunked_nanlast(values, axis) else: - return nanlast(values, axis) + return nputils.nanlast(values, axis) return take(values, -1, axis=axis) @@ -666,10 +705,70 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) -def push(array, n, axis): - from bottleneck import push +def _push(array, n: int | None = None, axis: int = -1): + """ + Use either bottleneck or numbagg depending on options & what's available + """ + + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) + if OPTIONS["use_numbagg"] and module_available("numbagg"): + import numbagg + + if pycompat.mod_version("numbagg") < Version("0.6.2"): + warnings.warn( + f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead." + ) + else: + return numbagg.ffill(array, limit=n, axis=axis) + # work around for bottleneck 178 + limit = n if n is not None else array.shape[axis] + + import bottleneck as bn + + return bn.push(array, limit, axis) + + +def push(array, n, axis): + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) if is_duck_dask_array(array): return dask_array_ops.push(array, n, axis) else: - return push(array, n, axis) + return _push(array, n, axis) + + +def _first_last_wrapper(array, *, axis, op, keepdims): + return op(array, axis, keepdims=keepdims) + + +def _chunked_first_or_last(darray, axis, op): + chunkmanager = get_chunked_array_type(darray) + + # This will raise the same error message seen for numpy + axis = normalize_axis_index(axis, darray.ndim) + + wrapped_op = partial(_first_last_wrapper, op=op) + return chunkmanager.reduction( + darray, + func=wrapped_op, + aggregate_func=wrapped_op, + axis=axis, + dtype=darray.dtype, + keepdims=False, # match numpy version + ) + + +def chunked_nanfirst(darray, axis): + return _chunked_first_or_last(darray, axis, op=nputils.nanfirst) + + +def chunked_nanlast(darray, axis): + return _chunked_first_or_last(darray, axis, op=nputils.nanlast) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ed548771809..260dabd9d31 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -1,25 +1,32 @@ """String formatting routines for __repr__. """ + from __future__ import annotations import contextlib import functools import math from collections import defaultdict -from collections.abc import Collection, Hashable +from collections.abc import Collection, Hashable, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr +from typing import TYPE_CHECKING import numpy as np import pandas as pd from pandas.errors import OutOfBoundsDatetime -from xarray.core.duck_array_ops import array_equiv +from xarray.core.duck_array_ops import array_equiv, astype from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default -from xarray.core.pycompat import array_type from xarray.core.utils import is_duck_array +from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy + +if TYPE_CHECKING: + from xarray.core.coordinates import AbstractCoordinates + +UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") def pretty_print(x, numchars: int): @@ -64,6 +71,8 @@ def first_n_items(array, n_desired): # might not be a numpy.ndarray. Moreover, access to elements of the array # could be very expensive (e.g. if it's only available over DAP), so go out # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + if n_desired < 1: raise ValueError("must request at least one item") @@ -74,7 +83,14 @@ def first_n_items(array, n_desired): if n_desired < array.size: indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=False) array = array[indexer] - return np.asarray(array).flat[:n_desired] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[:n_desired] def last_n_items(array, n_desired): @@ -83,13 +99,22 @@ def last_n_items(array, n_desired): # might not be a numpy.ndarray. Moreover, access to elements of the array # could be very expensive (e.g. if it's only available over DAP), so go out # of our way to get them in a single call to __getitem__ using only slices. + from xarray.core.variable import Variable + if (n_desired == 0) or (array.size == 0): return [] if n_desired < array.size: indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=True) array = array[indexer] - return np.asarray(array).flat[-n_desired:] + + # We pass variable objects in to handle indexing + # with indexer above. It would not work with our + # lazy indexing classes at the moment, so we cannot + # pass Variable._data + if isinstance(array, Variable): + array = array._data + return np.ravel(to_duck_array(array))[-n_desired:] def last_item(array): @@ -99,7 +124,8 @@ def last_item(array): return [] indexer = (slice(-1, None),) * array.ndim - return np.ravel(np.asarray(array[indexer])).tolist() + # to_numpy since dask doesn't support tolist + return np.ravel(to_numpy(array[indexer])).tolist() def calc_max_rows_first(max_rows: int) -> int: @@ -114,9 +140,9 @@ def calc_max_rows_last(max_rows: int) -> int: def format_timestamp(t): """Cast given object to a Timestamp and return a nicely formatted string""" - # Timestamp is only valid for 1678 to 2262 try: - datetime_str = str(pd.Timestamp(t)) + timestamp = pd.Timestamp(t) + datetime_str = timestamp.isoformat(sep=" ") except OutOfBoundsDatetime: datetime_str = str(t) @@ -156,6 +182,8 @@ def format_item(x, timedelta_format=None, quote_strings=True): if isinstance(x, (np.timedelta64, timedelta)): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): + if hasattr(x, "dtype"): + x = x.item() return repr(x) if quote_strings else x elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating): return f"{x.item():.4}" @@ -165,10 +193,10 @@ def format_item(x, timedelta_format=None, quote_strings=True): def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" - x = np.asarray(x) + x = to_duck_array(x) timedelta_format = "datetime" if np.issubdtype(x.dtype, np.timedelta64): - x = np.asarray(x, dtype="timedelta64[ns]") + x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, "ns") @@ -308,7 +336,9 @@ def summarize_variable( dims_str = "({}) ".format(", ".join(map(str, variable.dims))) else: dims_str = "" - front_str = f"{first_col}{dims_str}{variable.dtype} " + + nbytes_str = f" {render_human_readable_nbytes(variable.nbytes)}" + front_str = f"{first_col}{dims_str}{variable.dtype}{nbytes_str} " values_width = max_width - len(front_str) values_str = inline_variable_array_repr(variable, values_width) @@ -332,7 +362,7 @@ def summarize_attr(key, value, col_width=None): def _calculate_col_width(col_items): - max_name_length = max(len(str(s)) for s in col_items) if col_items else 0 + max_name_length = max((len(str(s)) for s in col_items), default=0) col_width = max(max_name_length, 7) + 6 return col_width @@ -398,7 +428,7 @@ def _mapping_repr( ) -def coords_repr(coords, col_width=None, max_rows=None): +def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(coords) return _mapping_repr( @@ -412,7 +442,7 @@ def coords_repr(coords, col_width=None, max_rows=None): ) -def inline_index_repr(index, max_width=None): +def inline_index_repr(index: pd.Index, max_width=None): if hasattr(index, "_repr_inline_"): repr_ = index._repr_inline_(max_width=max_width) else: @@ -424,21 +454,37 @@ def inline_index_repr(index, max_width=None): def summarize_index( - name: Hashable, index, col_width: int, max_width: int | None = None -): + names: tuple[Hashable, ...], + index, + col_width: int, + max_width: int | None = None, +) -> str: if max_width is None: max_width = OPTIONS["display_width"] - preformatted = pretty_print(f" {name} ", col_width) + def prefixes(length: int) -> list[str]: + if length in (0, 1): + return [" "] + + return ["┌"] + ["│"] * max(length - 2, 0) + ["└"] + + preformatted = [ + pretty_print(f" {prefix} {name}", col_width) + for prefix, name in zip(prefixes(len(names)), names) + ] - index_width = max_width - len(preformatted) + head, *tail = preformatted + index_width = max_width - len(head) repr_ = inline_index_repr(index, max_width=index_width) - return preformatted + repr_ + return "\n".join([head + repr_] + [line.rstrip() for line in tail]) -def nondefault_indexes(indexes): +def filter_nondefault_indexes(indexes, filter_indexes: bool): from xarray.core.indexes import PandasIndex, PandasMultiIndex + if not filter_indexes: + return indexes + default_indexes = (PandasIndex, PandasMultiIndex) return { @@ -448,7 +494,9 @@ def nondefault_indexes(indexes): } -def indexes_repr(indexes, col_width=None, max_rows=None): +def indexes_repr(indexes, max_rows: int | None = None) -> str: + col_width = _calculate_col_width(chain.from_iterable(indexes)) + return _mapping_repr( indexes, "Indexes", @@ -557,8 +605,12 @@ def limit_lines(string: str, *, limit: int): return string -def short_numpy_repr(array): - array = np.asarray(array) +def short_array_repr(array): + from xarray.core.common import AbstractArray + + if isinstance(array, AbstractArray): + array = array.data + array = to_duck_array(array) # default to lower precision so a full (abbreviated) line can fit on # one line with the default display_width @@ -582,16 +634,22 @@ def short_data_repr(array): """Format "data" for DataArray and Variable.""" internal_data = getattr(array, "variable", array)._data if isinstance(array, np.ndarray): - return short_numpy_repr(array) + return short_array_repr(array) elif is_duck_array(internal_data): return limit_lines(repr(array.data), limit=40) - elif array._in_memory: - return short_numpy_repr(array) + elif getattr(array, "_in_memory", None): + return short_array_repr(array) else: # internal xarray array type return f"[{array.size} values with dtype={array.dtype}]" +def _get_indexes_dict(indexes): + return { + tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index() + } + + @recursive_repr("") def array_repr(arr): from xarray.core.variable import Variable @@ -615,11 +673,11 @@ def array_repr(arr): start = f"", + f"{start}({dims})> Size: {nbytes_str}", data_repr, ] - if hasattr(arr, "coords"): if arr.coords: col_width = _calculate_col_width(arr.coords) @@ -636,15 +694,13 @@ def array_repr(arr): display_default_indexes = _get_boolean_with_default( "display_default_indexes", False ) - if display_default_indexes: - xindexes = arr.xindexes - else: - xindexes = nondefault_indexes(arr.xindexes) + + xindexes = filter_nondefault_indexes( + _get_indexes_dict(arr.xindexes), not display_default_indexes + ) if xindexes: - summary.append( - indexes_repr(xindexes, col_width=col_width, max_rows=max_rows) - ) + summary.append(indexes_repr(xindexes, max_rows=max_rows)) if arr.attrs: summary.append(attrs_repr(arr.attrs, max_rows=max_rows)) @@ -654,7 +710,8 @@ def array_repr(arr): @recursive_repr("") def dataset_repr(ds): - summary = [f""] + nbytes_str = render_human_readable_nbytes(ds.nbytes) + summary = [f" Size: {nbytes_str}"] col_width = _calculate_col_width(ds.variables) max_rows = OPTIONS["display_max_rows"] @@ -675,12 +732,11 @@ def dataset_repr(ds): display_default_indexes = _get_boolean_with_default( "display_default_indexes", False ) - if display_default_indexes: - xindexes = ds.xindexes - else: - xindexes = nondefault_indexes(ds.xindexes) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(ds.xindexes), not display_default_indexes + ) if xindexes: - summary.append(indexes_repr(xindexes, col_width=col_width, max_rows=max_rows)) + summary.append(indexes_repr(xindexes, max_rows=max_rows)) if ds.attrs: summary.append(attrs_repr(ds.attrs, max_rows=max_rows)) @@ -689,10 +745,8 @@ def dataset_repr(ds): def diff_dim_summary(a, b): - if a.dims != b.dims: - return "Differing dimensions:\n ({}) != ({})".format( - dim_summary(a), dim_summary(b) - ) + if a.sizes != b.sizes: + return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})" else: return "" @@ -735,9 +789,11 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): try: # compare xarray variable if not callable(compat): - compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + compatible = getattr(a_mapping[k].variable, compat)( + b_mapping[k].variable + ) else: - compatible = compat(a_mapping[k], b_mapping[k]) + compatible = compat(a_mapping[k].variable, b_mapping[k].variable) is_variable = True except AttributeError: # compare attribute value @@ -756,11 +812,21 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): if compat == "identical" and is_variable: attrs_summary = [] + a_attrs = a_mapping[k].attrs + b_attrs = b_mapping[k].attrs + attrs_to_print = set(a_attrs) ^ set(b_attrs) + attrs_to_print.update( + {k for k in set(a_attrs) & set(b_attrs) if a_attrs[k] != b_attrs[k]} + ) for m in (a_mapping, b_mapping): attr_s = "\n".join( - summarize_attr(ak, av) for ak, av in m[k].attrs.items() + " " + summarize_attr(ak, av) + for ak, av in m[k].attrs.items() + if ak in attrs_to_print ) + if attr_s: + attr_s = " Differing variable attributes:\n" + attr_s attrs_summary.append(attr_s) temp = [ @@ -768,6 +834,18 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): for var_s, attr_s in zip(temp, attrs_summary) ] + # TODO: It should be possible recursively use _diff_mapping_repr + # instead of explicitly handling variable attrs specially. + # That would require some refactoring. + # newdiff = _diff_mapping_repr( + # {k: v for k,v in a_attrs.items() if k in attrs_to_print}, + # {k: v for k,v in b_attrs.items() if k in attrs_to_print}, + # compat=compat, + # summarizer=summarize_attr, + # title="Variable Attributes" + # ) + # temp += [newdiff] + diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)] if diff_items: @@ -819,9 +897,7 @@ def _compat_to_str(compat): def diff_array_repr(a, b, compat): # used for DataArray, Variable and IndexVariable summary = [ - "Left and right {} objects are not {}".format( - type(a).__name__, _compat_to_str(compat) - ) + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" ] summary.append(diff_dim_summary(a, b)) @@ -831,7 +907,7 @@ def diff_array_repr(a, b, compat): equiv = array_equiv if not equiv(a.data, b.data): - temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)] + temp = [wrap_indent(short_array_repr(obj), start=" ") for obj in (a, b)] diff_data_repr = [ ab_side + "\n" + ab_data_repr for ab_side, ab_data_repr in zip(("L", "R"), temp) @@ -852,9 +928,7 @@ def diff_array_repr(a, b, compat): def diff_dataset_repr(a, b, compat): summary = [ - "Left and right {} objects are not {}".format( - type(a).__name__, _compat_to_str(compat) - ) + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" ] col_width = _calculate_col_width(set(list(a.variables) + list(b.variables))) @@ -869,3 +943,59 @@ def diff_dataset_repr(a, b, compat): summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) return "\n".join(summary) + + +def shorten_list_repr(items: Sequence, max_items: int) -> str: + if len(items) <= max_items: + return repr(items) + else: + first_half = repr(items[: max_items // 2])[ + 1:-1 + ] # Convert to string and remove brackets + second_half = repr(items[-max_items // 2 :])[ + 1:-1 + ] # Convert to string and remove brackets + return f"[{first_half}, ..., {second_half}]" + + +def render_human_readable_nbytes( + nbytes: int, + /, + *, + attempt_constant_width: bool = False, +) -> str: + """Renders simple human-readable byte count representation + + This is only a quick representation that should not be relied upon for precise needs. + + To get the exact byte count, please use the ``nbytes`` attribute directly. + + Parameters + ---------- + nbytes + Byte count + attempt_constant_width + For reasonable nbytes sizes, tries to render a fixed-width representation. + + Returns + ------- + Human-readable representation of the byte count + """ + dividend = float(nbytes) + divisor = 1000.0 + last_unit_available = UNITS[-1] + + for unit in UNITS: + if dividend < divisor or unit == last_unit_available: + break + dividend /= divisor + + dividend_str = f"{dividend:.0f}" + unit_str = f"{unit}" + + if attempt_constant_width: + dividend_str = dividend_str.rjust(3) + unit_str = unit_str.ljust(2) + + string = f"{dividend_str}{unit_str}" + return string diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index d8d20a9e2c0..2c76b182207 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -4,7 +4,7 @@ from collections import OrderedDict from functools import lru_cache, partial from html import escape -from importlib.resources import read_binary +from importlib.resources import files from xarray.core.formatting import ( inline_index_repr, @@ -23,12 +23,12 @@ def _load_static_files(): """Lazily load the resource files into memory the first time they are needed""" return [ - read_binary(package, resource).decode("utf-8") + files(package).joinpath(resource).read_text(encoding="utf-8") for package, resource in STATIC_FILES ] -def short_data_repr_html(array): +def short_data_repr_html(array) -> str: """Format "data" for DataArray and Variable.""" internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): @@ -37,42 +37,43 @@ def short_data_repr_html(array): return f"
{text}
" -def format_dims(dims, dims_with_index): - if not dims: +def format_dims(dim_sizes, dims_with_index) -> str: + if not dim_sizes: return "" dim_css_map = { - dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims + dim: " class='xr-has-index'" if dim in dims_with_index else "" + for dim in dim_sizes } dims_li = "".join( - f"
  • " f"{escape(str(dim))}: {size}
  • " - for dim, size in dims.items() + f"
  • {escape(str(dim))}: {size}
  • " + for dim, size in dim_sizes.items() ) return f"
      {dims_li}
    " -def summarize_attrs(attrs): +def summarize_attrs(attrs) -> str: attrs_dl = "".join( - f"
    {escape(str(k))} :
    " f"
    {escape(str(v))}
    " + f"
    {escape(str(k))} :
    {escape(str(v))}
    " for k, v in attrs.items() ) return f"
    {attrs_dl}
    " -def _icon(icon_name): +def _icon(icon_name) -> str: # icon_name should be defined in xarray/static/html/icon-svg-inline.html return ( - "" - "" + f"" + f"" "" - "".format(icon_name) + "" ) -def summarize_variable(name, var, is_index=False, dtype=None): +def summarize_variable(name, var, is_index=False, dtype=None) -> str: variable = var.variable if hasattr(var, "variable") else var cssclass_idx = " class='xr-has-index'" if is_index else "" @@ -109,7 +110,7 @@ def summarize_variable(name, var, is_index=False, dtype=None): ) -def summarize_coords(variables): +def summarize_coords(variables) -> str: li_items = [] for k, v in variables.items(): li_content = summarize_variable(k, v, is_index=k in variables.xindexes) @@ -120,7 +121,7 @@ def summarize_coords(variables): return f"
      {vars_li}
    " -def summarize_vars(variables): +def summarize_vars(variables) -> str: vars_li = "".join( f"
  • {summarize_variable(k, v)}
  • " for k, v in variables.items() @@ -129,14 +130,14 @@ def summarize_vars(variables): return f"
      {vars_li}
    " -def short_index_repr_html(index): +def short_index_repr_html(index) -> str: if hasattr(index, "_repr_html_"): return index._repr_html_() return f"
    {escape(repr(index))}
    " -def summarize_index(coord_names, index): +def summarize_index(coord_names, index) -> str: name = "
    ".join([escape(str(n)) for n in coord_names]) index_id = f"index-{uuid.uuid4()}" @@ -155,7 +156,7 @@ def summarize_index(coord_names, index): ) -def summarize_indexes(indexes): +def summarize_indexes(indexes) -> str: indexes_li = "".join( f"
  • {summarize_index(v, i)}
  • " for v, i in indexes.items() @@ -165,7 +166,7 @@ def summarize_indexes(indexes): def collapsible_section( name, inline_details="", details="", n_items=None, enabled=True, collapsed=False -): +) -> str: # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) @@ -187,7 +188,7 @@ def collapsible_section( def _mapping_section( mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True -): +) -> str: n_items = len(mapping) expanded = _get_boolean_with_default( expand_option_name, n_items < max_items_collapse @@ -203,15 +204,15 @@ def _mapping_section( ) -def dim_section(obj): - dim_list = format_dims(obj.dims, obj.xindexes.dims) +def dim_section(obj) -> str: + dim_list = format_dims(obj.sizes, obj.xindexes.dims) return collapsible_section( "Dimensions", inline_details=dim_list, enabled=False, collapsed=True ) -def array_section(obj): +def array_section(obj) -> str: # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) collapsed = ( @@ -296,7 +297,7 @@ def _obj_repr(obj, header_components, sections): ) -def array_repr(arr): +def array_repr(arr) -> str: dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) if hasattr(arr, "xindexes"): indexed_dims = arr.xindexes.dims @@ -326,7 +327,7 @@ def array_repr(arr): return _obj_repr(arr, header_components, sections) -def dataset_repr(ds): +def dataset_repr(ds) -> str: obj_type = f"xarray.{type(ds).__name__}" header_components = [f"
    {escape(obj_type)}
    "] diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7de975c9c0a..3fbfb74d985 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,13 +1,18 @@ from __future__ import annotations +import copy import datetime import warnings +from abc import ABC, abstractmethod from collections.abc import Hashable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union, cast +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union import numpy as np import pandas as pd +from packaging.version import Version +from xarray.coding.cftime_offsets import _new_to_legacy_freq from xarray.core import dtypes, duck_array_ops, nputils, ops from xarray.core._aggregations import ( DataArrayGroupByAggregations, @@ -24,29 +29,37 @@ safe_cast_to_index, ) from xarray.core.options import _get_keep_attrs -from xarray.core.pycompat import integer_types -from xarray.core.types import Dims, QuantileMethods, T_Xarray +from xarray.core.types import ( + Dims, + QuantileMethods, + T_DataArray, + T_DataWithCoords, + T_Xarray, +) from xarray.core.utils import ( + FrozenMappingWarningOnValuesAccess, either_dict_or_kwargs, + emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import DatetimeLike, SideOptions from xarray.core.utils import Frozen GroupKey = Any - - T_GroupIndicesListInt = list[list[int]] - T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] + GroupIndex = Union[int, slice, list[int]] + T_GroupIndices = list[GroupIndex] def check_reduce_dims(reduce_dims, dimensions): @@ -57,12 +70,32 @@ def check_reduce_dims(reduce_dims, dimensions): raise ValueError( f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' " f"to reduce over all dimensions or one or more of {dimensions!r}." + f" Try passing .groupby(..., squeeze=False)" ) +def _maybe_squeeze_indices( + indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool +): + is_unique_grouper = isinstance(grouper.grouper, UniqueGrouper) + can_squeeze = is_unique_grouper and grouper.grouper.can_squeeze + if squeeze in [None, True] and can_squeeze: + if isinstance(indices, slice): + if indices.stop - indices.start == 1: + if (squeeze is None and warn) or squeeze is True: + emit_user_level_warning( + "The `squeeze` kwarg to GroupBy is being removed." + "Pass .groupby(..., squeeze=False) to disable squeezing," + " which is the new default, and to silence this warning." + ) + + indices = indices.start + return indices + + def unique_value_groups( ar, sort: bool = True -) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]: +) -> tuple[np.ndarray | pd.Index, np.ndarray]: """Group an array by its unique values. Parameters @@ -83,12 +116,12 @@ def unique_value_groups( inverse, values = pd.factorize(ar, sort=sort) if isinstance(values, pd.MultiIndex): values.names = ar.names - groups = _codes_to_groups(inverse, len(values)) - return values, groups, inverse + return values, inverse -def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt: - groups: T_GroupIndicesListInt = [[] for _ in range(N)] +def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices: + assert inverse.ndim == 1 + groups: T_GroupIndices = [[] for _ in range(N)] for n, g in enumerate(inverse): if g >= 0: groups[g].append(n) @@ -129,13 +162,13 @@ def _dummy_copy(xarray_obj): return res -def _is_one_or_none(obj): +def _is_one_or_none(obj) -> bool: return obj == 1 or obj is None -def _consolidate_slices(slices): +def _consolidate_slices(slices: list[slice]) -> list[slice]: """Consolidate adjacent slices in a list of slices.""" - result = [] + result: list[slice] = [] last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): @@ -179,13 +212,13 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray return newpositions[newpositions != -1] -class _DummyGroup: +class _DummyGroup(Generic[T_Xarray]): """Class for keeping track of grouped dimensions without coordinates. Should not be user visible. """ - __slots__ = ("name", "coords", "size") + __slots__ = ("name", "coords", "size", "dataarray") def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name @@ -208,55 +241,71 @@ def values(self) -> range: def data(self) -> range: return range(self.size) + def __array__(self) -> np.ndarray: + return np.arange(self.size) + @property def shape(self) -> tuple[int]: return (self.size,) + @property + def attrs(self) -> dict: + return {} + def __getitem__(self, key): if isinstance(key, tuple): key = key[0] return self.values[key] + def to_index(self) -> pd.Index: + # could be pd.RangeIndex? + return pd.Index(np.arange(self.size)) + def copy(self, deep: bool = True, data: Any = None): raise NotImplementedError + def to_dataarray(self) -> DataArray: + from xarray.core.dataarray import DataArray -T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) + return DataArray( + data=self.data, dims=(self.name,), coords=self.coords, name=self.name + ) + def to_array(self) -> DataArray: + """Deprecated version of to_dataarray.""" + return self.to_dataarray() + + +T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] -def _ensure_1d( - group: T_Group, obj: T_Xarray -) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: - # 1D cases: do nothing - from xarray.core.dataarray import DataArray +def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ + T_Group, + T_DataWithCoords, + Hashable | None, + list[Hashable], +]: + # 1D cases: do nothing if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] + from xarray.core.dataarray import DataArray + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] - # the copy is necessary here, otherwise read only array raises error - # in pandas: https://github.com/pydata/pandas/issues/12813 - newgroup = group.stack({stacked_dim: orig_dims}).copy() + newgroup = group.stack({stacked_dim: orig_dims}) newobj = obj.stack({stacked_dim: orig_dims}) - return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims + return newgroup, newobj, stacked_dim, inserted_dims raise TypeError( f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." ) -def _unique_and_monotonic(group: T_Group) -> bool: - if isinstance(group, _DummyGroup): - return True - index = safe_cast_to_index(group) - return index.is_unique and index.is_monotonic_increasing - - def _apply_loffset( loffset: str | pd.DateOffset | datetime.timedelta | pd.Timedelta, result: pd.Series | pd.DataFrame, @@ -294,91 +343,446 @@ def _apply_loffset( result.index = result.index + loffset -def _get_index_and_items(index, grouper): - first_items, codes = grouper.first_items(index) - full_index = first_items.index - if first_items.isnull().any(): - first_items = first_items.dropna() - return full_index, first_items, codes +class Grouper(ABC): + """Base class for Grouper objects that allow specializing GroupBy instructions.""" + @property + def can_squeeze(self) -> bool: + """TODO: delete this when the `squeeze` kwarg is deprecated. Only `UniqueGrouper` + should override it.""" + return False -def _factorize_grouper( - group, grouper -) -> tuple[ - DataArray | IndexVariable | _DummyGroup, - T_GroupIndices, - np.ndarray, - pd.Index, -]: - index = safe_cast_to_index(group) - if not index.is_monotonic_increasing: - # TODO: sort instead of raising an error - raise ValueError("index must be monotonic for resampling") - full_index, first_items, codes = _get_index_and_items(index, grouper) - sbins = first_items.values.astype(np.int64) - group_indices: T_GroupIndices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) - ] + [slice(sbins[-1], None)] - unique_coord = IndexVariable(group.name, first_items.index) - return unique_coord, group_indices, codes, full_index - - -def _factorize_bins( - group, bins, cut_kwargs: Mapping | None -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray, pd.IntervalIndex, DataArray]: - from xarray.core.dataarray import DataArray + @abstractmethod + def factorize(self, group) -> EncodedGroups: + """ + Takes the group, and creates intermediates necessary for GroupBy. + These intermediates are + 1. codes - Same shape as `group` containing a unique integer code for each group. + 2. group_indices - Indexes that let us index out the members of each group. + 3. unique_coord - Unique groups present in the dataset. + 4. full_index - Unique groups in the output. This differs from `unique_coord` in the + case of resampling and binning, where certain groups in the output are not present in + the input. + """ + pass - if cut_kwargs is None: - cut_kwargs = {} - - if duck_array_ops.isnull(bins).all(): - raise ValueError("All bin edges are NaN.") - binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) - codes = binned.codes - if (codes == -1).all(): - raise ValueError(f"None of the data falls within bins with edges {bins!r}") - full_index = binned.categories - unique_values = binned.unique().dropna() - group_indices = [g for g in _codes_to_groups(codes, len(full_index)) if g] - - if len(group_indices) == 0: - raise ValueError(f"None of the data falls within bins with edges {bins!r}") - - new_dim_name = str(group.name) + "_bins" - group_ = DataArray(binned, getattr(group, "coords", None), name=new_dim_name) - unique_coord = IndexVariable(new_dim_name, unique_values) - return unique_coord, group_indices, codes, full_index, group_ - - -def _factorize_rest( - group, -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray]: - # look through group to find the unique values - group_as_index = safe_cast_to_index(group) - sort = not isinstance(group_as_index, pd.MultiIndex) - unique_values, group_indices, codes = unique_value_groups(group_as_index, sort=sort) - if len(group_indices) == 0: - raise ValueError( - "Failed to group data. Are you grouping by a variable that is all NaN?" + +class Resampler(Grouper): + """Base class for Grouper objects that allow specializing resampling-type GroupBy instructions. + Currently only used for TimeResampler, but could be used for SpaceResampler in the future. + """ + + pass + + +@dataclass +class EncodedGroups: + """ + Dataclass for storing intermediate values for GroupBy operation. + Returned by factorize method on Grouper objects. + + Parameters + ---------- + codes: integer codes for each group + full_index: pandas Index for the group coordinate + group_indices: optional, List of indices of array elements belonging + to each group. Inferred if not provided. + unique_coord: Unique group values present in dataset. Inferred if not provided + """ + + codes: DataArray + full_index: pd.Index + group_indices: T_GroupIndices | None = field(default=None) + unique_coord: IndexVariable | _DummyGroup | None = field(default=None) + + +@dataclass +class ResolvedGrouper(Generic[T_DataWithCoords]): + """ + Wrapper around a Grouper object. + + The Grouper object represents an abstract instruction to group an object. + The ResolvedGrouper object is a concrete version that contains all the common + logic necessary for a GroupBy problem including the intermediates necessary for + executing a GroupBy calculation. Specialization to the grouping problem at hand, + is accomplished by calling the `factorize` method on the encapsulated Grouper + object. + + This class is private API, while Groupers are public. + """ + + grouper: Grouper + group: T_Group + obj: T_DataWithCoords + + # returned by factorize: + codes: DataArray = field(init=False) + full_index: pd.Index = field(init=False) + group_indices: T_GroupIndices = field(init=False) + unique_coord: IndexVariable | _DummyGroup = field(init=False) + + # _ensure_1d: + group1d: T_Group = field(init=False) + stacked_obj: T_DataWithCoords = field(init=False) + stacked_dim: Hashable | None = field(init=False) + inserted_dims: list[Hashable] = field(init=False) + + def __post_init__(self) -> None: + # This copy allows the BinGrouper.factorize() method + # to update BinGrouper.bins when provided as int, using the output + # of pd.cut + # We do not want to modify the original object, since the same grouper + # might be used multiple times. + self.grouper = copy.deepcopy(self.grouper) + + self.group: T_Group = _resolve_group(self.obj, self.group) + + ( + self.group1d, + self.stacked_obj, + self.stacked_dim, + self.inserted_dims, + ) = _ensure_1d(group=self.group, obj=self.obj) + + self.factorize() + + @property + def name(self) -> Hashable: + # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper + return self.unique_coord.name + + @property + def size(self) -> int: + return len(self) + + def __len__(self) -> int: + return len(self.full_index) + + @property + def dims(self): + return self.group1d.dims + + def factorize(self) -> None: + encoded = self.grouper.factorize(self.group1d) + + self.codes = encoded.codes + self.full_index = encoded.full_index + + if encoded.group_indices is not None: + self.group_indices = encoded.group_indices + else: + self.group_indices = [ + g + for g in _codes_to_group_indices(self.codes.data, len(self.full_index)) + if g + ] + if encoded.unique_coord is None: + unique_values = self.full_index[np.unique(encoded.codes)] + self.unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs + ) + else: + self.unique_coord = encoded.unique_coord + + +@dataclass +class UniqueGrouper(Grouper): + """Grouper object for grouping by a categorical variable.""" + + _group_as_index: pd.Index | None = None + + @property + def is_unique_and_monotonic(self) -> bool: + if isinstance(self.group, _DummyGroup): + return True + index = self.group_as_index + return index.is_unique and index.is_monotonic_increasing + + @property + def group_as_index(self) -> pd.Index: + if self._group_as_index is None: + self._group_as_index = self.group.to_index() + return self._group_as_index + + @property + def can_squeeze(self) -> bool: + is_dimension = self.group.dims == (self.group.name,) + return is_dimension and self.is_unique_and_monotonic + + def factorize(self, group1d) -> EncodedGroups: + self.group = group1d + + if self.can_squeeze: + return self._factorize_dummy() + else: + return self._factorize_unique() + + def _factorize_unique(self) -> EncodedGroups: + # look through group to find the unique values + sort = not isinstance(self.group_as_index, pd.MultiIndex) + unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) + if (codes_ == -1).all(): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = self.group.copy(data=codes_) + unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs ) - unique_coord = IndexVariable(group.name, unique_values) - return unique_coord, group_indices, codes + full_index = unique_coord + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) -def _factorize_dummy( - group, squeeze: bool -) -> tuple[IndexVariable, T_GroupIndices, np.ndarray]: - # no need to factorize - group_indices: T_GroupIndices - if not squeeze: + def _factorize_dummy(self) -> EncodedGroups: + size = self.group.size + # no need to factorize # use slices to do views instead of fancy indexing # equivalent to: group_indices = group_indices.reshape(-1, 1) - group_indices = [slice(i, i + 1) for i in range(group.size)] + group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] + size_range = np.arange(size) + if isinstance(self.group, _DummyGroup): + codes = self.group.to_dataarray().copy(data=size_range) + else: + codes = self.group.copy(data=size_range) + unique_coord = self.group + full_index = IndexVariable( + self.group.name, unique_coord.values, self.group.attrs + ) + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) + + +@dataclass +class BinGrouper(Grouper): + """Grouper object for binning numeric data.""" + + bins: Any # TODO: What is the typing? + cut_kwargs: Mapping = field(default_factory=dict) + binned: Any = None + name: Any = None + + def __post_init__(self) -> None: + if duck_array_ops.isnull(self.bins).all(): + raise ValueError("All bin edges are NaN.") + + def factorize(self, group) -> EncodedGroups: + from xarray.core.dataarray import DataArray + + data = group.data + + binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True) + + binned_codes = binned.codes + if (binned_codes == -1).all(): + raise ValueError( + f"None of the data falls within bins with edges {self.bins!r}" + ) + + new_dim_name = f"{group.name}_bins" + + full_index = binned.categories + uniques = np.sort(pd.unique(binned_codes)) + unique_values = full_index[uniques[uniques != -1]] + + codes = DataArray( + binned_codes, getattr(group, "coords", None), name=new_dim_name + ) + unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) + + +@dataclass +class TimeResampler(Resampler): + """Grouper object specialized to resampling the time coordinate.""" + + freq: str + closed: SideOptions | None = field(default=None) + label: SideOptions | None = field(default=None) + origin: str | DatetimeLike = field(default="start_day") + offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None) + loffset: datetime.timedelta | str | None = field(default=None) + base: int | None = field(default=None) + + index_grouper: CFTimeGrouper | pd.Grouper = field(init=False) + group_as_index: pd.Index = field(init=False) + + def __post_init__(self): + if self.loffset is not None: + emit_user_level_warning( + "Following pandas, the `loffset` parameter to resample is deprecated. " + "Switch to updating the resampled dataset time coordinate using " + "time offset arithmetic. For example:\n" + " >>> offset = pd.tseries.frequencies.to_offset(freq) / 2\n" + ' >>> resampled_ds["time"] = resampled_ds.get_index("time") + offset', + FutureWarning, + ) + + if self.base is not None: + emit_user_level_warning( + "Following pandas, the `base` parameter to resample will be deprecated in " + "a future version of xarray. Switch to using `origin` or `offset` instead.", + FutureWarning, + ) + + if self.base is not None and self.offset is not None: + raise ValueError("base and offset cannot be present at the same time") + + def _init_properties(self, group: T_Group) -> None: + from xarray import CFTimeIndex + from xarray.core.pdcompat import _convert_base_to_offset + + group_as_index = safe_cast_to_index(group) + + if self.base is not None: + # grouper constructor verifies that grouper.offset is None at this point + offset = _convert_base_to_offset(self.base, self.freq, group_as_index) + else: + offset = self.offset + + if not group_as_index.is_monotonic_increasing: + # TODO: sort instead of raising an error + raise ValueError("index must be monotonic for resampling") + + if isinstance(group_as_index, CFTimeIndex): + from xarray.core.resample_cftime import CFTimeGrouper + + index_grouper = CFTimeGrouper( + freq=self.freq, + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + loffset=self.loffset, + ) + else: + index_grouper = pd.Grouper( + # TODO remove once requiring pandas >= 2.2 + freq=_new_to_legacy_freq(self.freq), + closed=self.closed, + label=self.label, + origin=self.origin, + offset=offset, + ) + self.index_grouper = index_grouper + self.group_as_index = group_as_index + + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: + first_items, codes = self.first_items() + full_index = first_items.index + if first_items.isnull().any(): + first_items = first_items.dropna() + + full_index = full_index.rename("__resample_dim__") + return full_index, first_items, codes + + def first_items(self) -> tuple[pd.Series, np.ndarray]: + from xarray import CFTimeIndex + + if isinstance(self.group_as_index, CFTimeIndex): + return self.index_grouper.first_items(self.group_as_index) + else: + s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) + grouped = s.groupby(self.index_grouper) + first_items = grouped.first() + counts = grouped.count() + # This way we generate codes for the final output index: full_index. + # So for _flox_reduce we avoid one reindex and copy by avoiding + # _maybe_restore_empty_groups + codes = np.repeat(np.arange(len(first_items)), counts) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) + return first_items, codes + + def factorize(self, group) -> EncodedGroups: + self._init_properties(group) + full_index, first_items, codes_ = self._get_index_and_items() + sbins = first_items.values.astype(np.int64) + group_indices: T_GroupIndices = [ + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:]) + ] + group_indices += [slice(sbins[-1], None)] + + unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + codes = group.copy(data=codes_) + + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) + + +def _validate_groupby_squeeze(squeeze: bool | None) -> None: + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if squeeze is not None and not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be None, True or False, but {squeeze} was supplied" + ) + + +def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: + from xarray.core.dataarray import DataArray + + error_msg = ( + "the group variable's length does not " + "match the length of this variable along its " + "dimensions" + ) + + newgroup: T_Group + if isinstance(group, DataArray): + try: + align(obj, group, join="exact", copy=False) + except ValueError: + raise ValueError(error_msg) + + newgroup = group.copy(deep=False) + newgroup.name = group.name or "group" + + elif isinstance(group, IndexVariable): + # This assumption is built in to _ensure_1d. + if group.ndim != 1: + raise ValueError( + "Grouping by multi-dimensional IndexVariables is not allowed." + "Convert to and pass a DataArray instead." + ) + (group_dim,) = group.dims + if len(group) != obj.sizes[group_dim]: + raise ValueError(error_msg) + newgroup = DataArray(group) + else: - group_indices = np.arange(group.size) - codes = np.arange(group.size) - unique_coord = group - return unique_coord, group_indices, codes + if not hashable(group): + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension. " + f"Received {group!r} instead." + ) + group_da: DataArray = obj[group] + if group_da.name not in obj._indexes and group_da.name in obj.dims: + # DummyGroups should not appear on groupby results + newgroup = _DummyGroup(obj, group_da.name, group_da.coords) + else: + newgroup = group_da + + if newgroup.size == 0: + raise ValueError(f"{newgroup.name} must not be empty") + + return newgroup class GroupBy(Generic[T_Xarray]): @@ -405,6 +809,7 @@ class GroupBy(Generic[T_Xarray]): "_group_dim", "_group_indices", "_groups", + "groupers", "_obj", "_restore_coord_dims", "_stacked_dim", @@ -419,16 +824,26 @@ class GroupBy(Generic[T_Xarray]): "_codes", ) _obj: T_Xarray + groupers: tuple[ResolvedGrouper] + _squeeze: bool | None + _restore_coord_dims: bool + + _original_obj: T_Xarray + _original_group: T_Group + _group_indices: T_GroupIndices + _codes: DataArray + _group_dim: Hashable + + _groups: dict[GroupKey, GroupIndex] | None + _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None + _sizes: Mapping[Hashable, int] | None def __init__( self, obj: T_Xarray, - group: Hashable | DataArray | IndexVariable, - squeeze: bool = False, - grouper: pd.Grouper | None = None, - bins: ArrayLike | None = None, + groupers: tuple[ResolvedGrouper], + squeeze: bool | None = False, restore_coord_dims: bool = True, - cut_kwargs: Mapping[Any, Any] | None = None, ) -> None: """Create a GroupBy object @@ -436,103 +851,36 @@ def __init__( ---------- obj : Dataset or DataArray Object to group. - group : Hashable, DataArray or Index - Array with the group values or name of the variable. - squeeze : bool, default: False - If "group" is a coordinate of object, `squeeze` controls whether - the subarrays have a dimension of length 1 along that coordinate or - if the dimension is squeezed out. - grouper : pandas.Grouper, optional - Used for grouping values along the `group` array. - bins : array-like, optional - If `bins` is specified, the groups will be discretized into the - specified bins by `pandas.cut`. + grouper : Grouper + Grouper object restore_coord_dims : bool, default: True If True, also restore the dimension order of multi-dimensional coordinates. - cut_kwargs : dict-like, optional - Extra keyword arguments to pass to `pandas.cut` - """ - from xarray.core.dataarray import DataArray - - if grouper is not None and bins is not None: - raise TypeError("can't specify both `grouper` and `bins`") - - if not isinstance(group, (DataArray, IndexVariable)): - if not hashable(group): - raise TypeError( - "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension. " - f"Received {group!r} instead." - ) - group = obj[group] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - - if group.name not in obj.coords and group.name in obj.dims: - # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) - - if getattr(group, "name", None) is None: - group.name = "group" + self.groupers = groupers - self._original_obj: T_Xarray = obj - self._original_group = group - self._bins = bins + self._original_obj = obj - group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) - (group_dim,) = group.dims - - expected_size = obj.sizes[group_dim] - if group.size != expected_size: - raise ValueError( - "the group variable's length does not " - "match the length of this variable along its " - "dimension" - ) - - self._codes: DataArray - if grouper is not None: - unique_coord, group_indices, codes, full_index = _factorize_grouper( - group, grouper - ) - self._codes = group.copy(data=codes) - elif bins is not None: - unique_coord, group_indices, codes, full_index, group = _factorize_bins( - group, bins, cut_kwargs - ) - self._codes = group.copy(data=codes) - elif group.dims == (group.name,) and _unique_and_monotonic(group): - unique_coord, group_indices, codes = _factorize_dummy(group, squeeze) - full_index = None - self._codes = obj[group.name].copy(data=codes) - else: - unique_coord, group_indices, codes = _factorize_rest(group) - full_index = None - self._codes = group.copy(data=codes) + (grouper,) = self.groupers + self._original_group = grouper.group # specification for the groupby operation - self._obj: T_Xarray = obj - self._group = group - self._group_dim = group_dim - self._group_indices = group_indices - self._unique_coord = unique_coord - self._stacked_dim = stacked_dim - self._inserted_dims = inserted_dims - self._full_index = full_index + self._obj = grouper.stacked_obj self._restore_coord_dims = restore_coord_dims - self._bins = bins self._squeeze = squeeze - self._codes = self._maybe_unstack(self._codes) + # These should generalize to multiple groupers + self._group_indices = grouper.group_indices + self._codes = self._maybe_unstack(grouper.codes) + + (self._group_dim,) = grouper.group1d.dims # cached attributes - self._groups: dict[GroupKey, slice | int | list[int]] | None = None - self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None - self._sizes: Frozen[Hashable, int] | None = None + self._groups = None + self._dims = None + self._sizes = None @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. @@ -543,9 +891,14 @@ def sizes(self) -> Frozen[Hashable, int]: Dataset.sizes """ if self._sizes is None: - self._sizes = self._obj.isel( - {self._group_dim: self._group_indices[0]} - ).sizes + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], + self._squeeze, + grouper, + warn=True, + ) + self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes @@ -572,46 +925,63 @@ def reduce( raise NotImplementedError() @property - def groups(self) -> dict[GroupKey, slice | int | list[int]]: + def groups(self) -> dict[GroupKey, GroupIndex]: """ Mapping from group labels to indices. The indices can be used to index the underlying object. """ # provided to mimic pandas.groupby if self._groups is None: - self._groups = dict(zip(self._unique_coord.values, self._group_indices)) + (grouper,) = self.groupers + squeezed_indices = ( + _maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx > 0) + for idx, ind in enumerate(self._group_indices) + ) + self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices)) return self._groups def __getitem__(self, key: GroupKey) -> T_Xarray: """ Get DataArray or Dataset corresponding to a particular group label. """ - return self._obj.isel({self._group_dim: self.groups[key]}) + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self.groups[key], self._squeeze, grouper, warn=True + ) + return self._obj.isel({self._group_dim: index}) def __len__(self) -> int: - return self._unique_coord.size + (grouper,) = self.groupers + return grouper.size def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: - return zip(self._unique_coord.values, self._iter_grouped()) + (grouper,) = self.groupers + return zip(grouper.unique_coord.data, self._iter_grouped()) def __repr__(self) -> str: + (grouper,) = self.groupers return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( self.__class__.__name__, - self._unique_coord.name, - self._unique_coord.size, - ", ".join(format_array_flat(self._unique_coord, 30).split()), + grouper.name, + grouper.full_index.size, + ", ".join(format_array_flat(grouper.full_index, 30).split()), ) - def _iter_grouped(self) -> Iterator[T_Xarray]: + def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]: """Iterate over each element in this group""" - for indices in self._group_indices: + (grouper,) = self.groupers + for idx, indices in enumerate(self._group_indices): + indices = _maybe_squeeze_indices( + indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 + ) yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): + (grouper,) = self.groupers if self._group_dim in applied_example.dims: - coord = self._group + coord = grouper.group1d positions = self._group_indices else: - coord = self._unique_coord + coord = grouper.unique_coord positions = None (dim,) = coord.dims if isinstance(coord, _DummyGroup): @@ -625,19 +995,19 @@ def _binary_op(self, other, f, reflexive=False): g = f if not reflexive else lambda x, y: f(y, x) + (grouper,) = self.groupers obj = self._original_obj - group = self._original_group + group = grouper.group codes = self._codes dims = group.dims if isinstance(group, _DummyGroup): - group = obj[group.name] - coord = group + group = coord = group.to_dataarray() else: - coord = self._unique_coord + coord = grouper.unique_coord if not isinstance(coord, DataArray): - coord = DataArray(self._unique_coord) - name = self._group.name + coord = DataArray(grouper.unique_coord) + name = grouper.name if not isinstance(other, (Dataset, DataArray)): raise TypeError( @@ -666,9 +1036,27 @@ def _binary_op(self, other, f, reflexive=False): mask = codes == -1 if mask.any(): obj = obj.where(~mask, drop=True) + group = group.where(~mask, drop=True) codes = codes.where(~mask, drop=True).astype(int) - other, _ = align(other, coord, join="outer") + # if other is dask-backed, that's a hint that the + # "expanded" dataset is too big to hold in memory. + # this can be the case when `other` was read from disk + # and contains our lazy indexing classes + # We need to check for dask-backed Datasets + # so utils.is_duck_dask_array does not work for this check + if obj.chunks and not other.chunks: + # TODO: What about datasets with some dask vars, and others not? + # This handles dims other than `name`` + chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims} + # a chunk size of 1 seems reasonable since we expect individual elements of + # other to be repeated multiple times across the reduced dimension(s) + chunks[name] = 1 + other = other.chunk(chunks) + + # codes are defined for coord, so we align `other` with `coord` + # before indexing + other, _ = align(other, coord, join="right", copy=False) expanded = other.isel({name: codes}) result = g(obj, expanded) @@ -688,20 +1076,27 @@ def _binary_op(self, other, f, reflexive=False): return result def _maybe_restore_empty_groups(self, combined): - """Our index contained empty groups (e.g., from a resampling). If we + """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. """ - if self._full_index is not None and self._group.name in combined.dims: - indexers = {self._group.name: self._full_index} + (grouper,) = self.groupers + if ( + isinstance(grouper.grouper, (BinGrouper, TimeResampler)) + and grouper.name in combined.dims + ): + indexers = {grouper.name: grouper.full_index} combined = combined.reindex(**indexers) return combined def _maybe_unstack(self, obj): """This gets called if we are applying on an array with a multidimensional group.""" - if self._stacked_dim is not None and self._stacked_dim in obj.dims: - obj = obj.unstack(self._stacked_dim) - for dim in self._inserted_dims: + (grouper,) = self.groupers + stacked_dim = grouper.stacked_dim + inserted_dims = grouper.inserted_dims + if stacked_dim is not None and stacked_dim in obj.dims: + obj = obj.unstack(stacked_dim) + for dim in inserted_dims: if dim in obj.coords: del obj.coords[dim] obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) @@ -714,20 +1109,23 @@ def _flox_reduce( **kwargs: Any, ): """Adaptor function that translates our groupby API to that of flox.""" + import flox from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset obj = self._original_obj - group = self._original_group + (grouper,) = self.groupers + isbin = isinstance(grouper.grouper, BinGrouper) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - # preserve current strategy (approximately) for dask groupby. - # We want to control the default anyway to prevent surprises - # if flox decides to change its default - kwargs.setdefault("method", "cohorts") + if Version(flox.__version__) < Version("0.9"): + # preserve current strategy (approximately) for dask groupby + # on older flox versions to prevent surprises. + # flox >=0.9 will choose this on its own. + kwargs.setdefault("method", "cohorts") numeric_only = kwargs.pop("numeric_only", None) if numeric_only: @@ -739,25 +1137,30 @@ def _flox_reduce( else: non_numeric = {} + if "min_count" in kwargs: + if kwargs["func"] not in ["sum", "prod"]: + raise TypeError("Received an unexpected keyword argument 'min_count'") + elif kwargs["min_count"] is None: + # set explicitly to avoid unnecessarily accumulating count + kwargs["min_count"] = 0 + # weird backcompat # reducing along a unique indexed dimension with squeeze=True # should raise an error - if ( - dim is None or dim == self._group.name - ) and self._group.name in obj.xindexes: - index = obj.indexes[self._group.name] + if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: + index = obj.indexes[grouper.name] if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") unindexed_dims: tuple[Hashable, ...] = tuple() - if isinstance(group, _DummyGroup) and self._bins is None: - unindexed_dims = (group.name,) + if isinstance(grouper.group, _DummyGroup) and not isbin: + unindexed_dims = (grouper.name,) parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): parsed_dim = (dim,) elif dim is None: - parsed_dim = group.dims + parsed_dim = grouper.group.dims elif dim is ...: parsed_dim = tuple(obj.dims) else: @@ -765,12 +1168,12 @@ def _flox_reduce( # Do this so we raise the same error message whether flox is present or not. # Better to control it here than in flox. - if any(d not in group.dims and d not in obj.dims for d in parsed_dim): + if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim): raise ValueError(f"cannot reduce over dimensions {dim}.") if kwargs["func"] not in ["all", "any", "count"]: kwargs.setdefault("fill_value", np.nan) - if self._bins is not None and kwargs["func"] == "count": + if isbin and kwargs["func"] == "count": # This is an annoying hack. Xarray returns np.nan # when there are no observations in a bin, instead of 0. # We can fake that here by forcing min_count=1. @@ -779,7 +1182,7 @@ def _flox_reduce( kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) - output_index = self._get_output_index() + output_index = grouper.full_index result = xarray_reduce( obj.drop_vars(non_numeric.keys()), self._codes, @@ -793,35 +1196,29 @@ def _flox_reduce( # we did end up reducing over dimension(s) that are # in the grouped variable - if set(self._codes.dims).issubset(set(parsed_dim)): - result[self._unique_coord.name] = output_index + group_dims = grouper.group.dims + if set(group_dims).issubset(set(parsed_dim)): + result[grouper.name] = output_index result = result.drop_vars(unindexed_dims) # broadcast and restore non-numeric data variables (backcompat) for name, var in non_numeric.items(): if all(d not in var.dims for d in parsed_dim): result[name] = var.variable.set_dims( - (group.name,) + var.dims, (result.sizes[group.name],) + var.shape + (grouper.name,) + var.dims, + (result.sizes[grouper.name],) + var.shape, ) - if self._bins is not None: + if isbin: # Fix dimension order when binning a dimension coordinate # Needed as long as we do a separate code path for pint; # For some reason Datasets and DataArrays behave differently! - if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: - result = result.transpose(self._group.name, ...) + (group_dim,) = grouper.dims + if isinstance(self._obj, Dataset) and group_dim in self._obj.dims: + result = result.transpose(grouper.name, ...) return result - def _get_output_index(self) -> pd.Index: - """Return pandas.Index object for the output array.""" - if self._full_index is not None: - # binning and resample - return self._full_index.rename(self._unique_coord.name) - if isinstance(self._unique_coord, _DummyGroup): - return IndexVariable(self._group.name, self._unique_coord.values) - return self._unique_coord - def fillna(self, value: Any) -> T_Xarray: """Fill missing values in this object by group. @@ -848,10 +1245,12 @@ def fillna(self, value: Any) -> T_Xarray: """ return ops.fillna(self, value) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, @@ -873,15 +1272,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -891,9 +1290,8 @@ def quantile( * "midpoint" * "nearest" - See :py:func:`numpy.quantile` or [1]_ for details. Methods marked with - an asterix require numpy version 1.22 or newer. The "method" argument was - previously called "interpolation", renamed in accordance with numpy + See :py:func:`numpy.quantile` or [1]_ for details. The "method" argument + was previously called "interpolation", renamed in accordance with numpy version 1.22.0. keep_attrs : bool or None, default: None If True, the dataarray's attributes (`attrs`) will be copied from @@ -929,23 +1327,23 @@ def quantile( ... ) >>> ds = xr.Dataset({"a": da}) >>> da.groupby("x").quantile(0) - + Size: 64B array([[0.7, 4.2, 0.7, 1.5], [6.5, 7.3, 2.6, 1.9]]) Coordinates: - * y (y) int64 1 1 2 2 - quantile float64 0.0 - * x (x) int64 0 1 + * y (y) int64 32B 1 1 2 2 + quantile float64 8B 0.0 + * x (x) int64 16B 0 1 >>> ds.groupby("y").quantile(0, dim=...) - + Size: 40B Dimensions: (y: 2) Coordinates: - quantile float64 0.0 - * y (y) int64 1 2 + quantile float64 8B 0.0 + * y (y) int64 16B 1 2 Data variables: - a (y) float64 0.7 0.7 + a (y) float64 16B 0.7 0.7 >>> da.groupby("x").quantile([0, 0.5, 1]) - + Size: 192B array([[[0.7 , 1. , 1.3 ], [4.2 , 6.3 , 8.4 ], [0.7 , 5.05, 9.4 ], @@ -956,17 +1354,17 @@ def quantile( [2.6 , 2.6 , 2.6 ], [1.9 , 1.9 , 1.9 ]]]) Coordinates: - * y (y) int64 1 1 2 2 - * quantile (quantile) float64 0.0 0.5 1.0 - * x (x) int64 0 1 + * y (y) int64 32B 1 1 2 2 + * quantile (quantile) float64 24B 0.0 0.5 1.0 + * x (x) int64 16B 0 1 >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...) - + Size: 88B Dimensions: (y: 2, quantile: 3) Coordinates: - * quantile (quantile) float64 0.0 0.5 1.0 - * y (y) int64 1 2 + * quantile (quantile) float64 24B 0.0 0.5 1.0 + * y (y) int64 16B 1 2 Data variables: - a (y, quantile) float64 0.7 5.35 8.4 0.7 2.25 9.4 + a (y, quantile) float64 48B 0.7 5.35 8.4 0.7 2.25 9.4 References ---------- @@ -975,7 +1373,8 @@ def quantile( The American Statistician, 50(4), pp. 361-365, 1996 """ if dim is None: - dim = (self._group_dim,) + (grouper,) = self.groupers + dim = grouper.group1d.dims return self.map( self._obj.__class__.quantile, @@ -1010,7 +1409,11 @@ def where(self, cond, other=dtypes.NA) -> T_Xarray: return ops.where_method(self, cond, other) def _first_or_last(self, op, skipna, keep_attrs): - if isinstance(self._group_indices[0], integer_types): + if all( + isinstance(maybe_slice, slice) + and (maybe_slice.stop == maybe_slice.start + 1) + for maybe_slice in self._group_indices + ): # NB. this is currently only used for reductions along an existing # dimension return self._obj @@ -1058,16 +1461,24 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): @property def dims(self) -> tuple[Hashable, ...]: if self._dims is None: - self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], self._squeeze, grouper, warn=True + ) + self._dims = self._obj.isel({self._group_dim: index}).dims return self._dims - def _iter_grouped_shortcut(self): + def _iter_grouped_shortcut(self, warn_squeeze=True): """Fast version of `_iter_grouped` that yields Variables without metadata """ var = self._obj.variable - for indices in self._group_indices: + (grouper,) = self.groupers + for idx, indices in enumerate(self._group_indices): + indices = _maybe_squeeze_indices( + indices, self._squeeze, grouper, warn=warn_squeeze and idx == 0 + ) yield var[{self._group_dim: indices}] def _concat_shortcut(self, applied, dim, positions=None): @@ -1078,13 +1489,24 @@ def _concat_shortcut(self, applied, dim, positions=None): # TODO: benbovy - explicit indexes: this fast implementation doesn't # create an explicit index for the stacked dim coordinate stacked = Variable.concat(applied, dim, shortcut=True) - reordered = _maybe_reorder(stacked, dim, positions, N=self._group.size) + + (grouper,) = self.groupers + reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size) return self._obj._replace_maybe_drop_dims(reordered) def _restore_dim_order(self, stacked: DataArray) -> DataArray: + (grouper,) = self.groupers + group = grouper.group1d + + groupby_coord = ( + f"{group.name}_bins" + if isinstance(grouper.grouper, BinGrouper) + else group.name + ) + def lookup_order(dimension): - if dimension == self._group.name: - (dimension,) = self._group.dims + if dimension == groupby_coord: + (dimension,) = group.dims if dimension in self._obj.dims: axis = self._obj.get_axis_num(dimension) else: @@ -1142,7 +1564,24 @@ def map( applied : DataArray The result of splitting, applying and combining this array. """ - grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() + return self._map_maybe_warn( + func, args, warn_squeeze=True, shortcut=shortcut, **kwargs + ) + + def _map_maybe_warn( + self, + func: Callable[..., DataArray], + args: tuple[Any, ...] = (), + *, + warn_squeeze: bool = True, + shortcut: bool | None = None, + **kwargs: Any, + ) -> DataArray: + grouped = ( + self._iter_grouped_shortcut(warn_squeeze) + if shortcut + else self._iter_grouped(warn_squeeze) + ) applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) return self._combine(applied, shortcut=shortcut) @@ -1169,7 +1608,8 @@ def _combine(self, applied, shortcut=False): combined = self._concat_shortcut(applied, dim, positions) else: combined = concat(applied, dim) - combined = _maybe_reorder(combined, dim, positions, N=self._group.size) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) if isinstance(combined, type(self._obj)): # only restore dimension order for arrays @@ -1243,6 +1683,68 @@ def reduce_array(ar: DataArray) -> DataArray: return self.map(reduce_array, shortcut=shortcut) + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. If None, apply over the + groupby dimension, if "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_array(ar: DataArray) -> DataArray: + return ar.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The `squeeze` kwarg") + check_reduce_dims(dim, self.dims) + + return self._map_maybe_warn(reduce_array, shortcut=shortcut, warn_squeeze=False) + # https://github.com/python/mypy/issues/9031 class DataArrayGroupBy( # type: ignore[misc] @@ -1260,9 +1762,16 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): @property def dims(self) -> Frozen[Hashable, int]: if self._dims is None: - self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims + (grouper,) = self.groupers + index = _maybe_squeeze_indices( + self._group_indices[0], + self._squeeze, + grouper, + warn=True, + ) + self._dims = self._obj.isel({self._group_dim: index}).dims - return self._dims + return FrozenMappingWarningOnValuesAccess(self._dims) def map( self, @@ -1300,8 +1809,18 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Dataset], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + warn_squeeze: bool = False, + **kwargs: Any, + ) -> Dataset: # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) return self._combine(applied) def apply(self, func, args=(), shortcut=None, **kwargs): @@ -1325,7 +1844,8 @@ def _combine(self, applied): applied_example, applied = peek_at(applied) coord, dim, positions = self._infer_concat_args(applied_example) combined = concat(applied, dim) - combined = _maybe_reorder(combined, dim, positions, N=self._group.size) + (grouper,) = self.groupers + combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size) # assign coord when the applied function does not return that coord if coord is not None and dim not in applied_example.dims: index, index_vars = create_default_index_implicit(coord) @@ -1395,6 +1915,68 @@ def reduce_dataset(ds: Dataset) -> Dataset: return self.map(reduce_dataset) + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : ..., str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default apply over the + groupby dimension, with "..." apply over all dimensions. + axis : int or sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dimension' + and 'axis' arguments can be supplied. If neither are supplied, then + `func` is calculated over all dimension for each group item. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Dataset + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim is None: + dim = [self._group_dim] + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + def reduce_dataset(ds: Dataset) -> Dataset: + return ds.reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + **kwargs, + ) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The `squeeze` kwarg") + check_reduce_dims(dim, self.dims) + + return self._map_maybe_warn(reduce_dataset, warn_squeeze=False) + def assign(self, **kwargs: Any) -> Dataset: """Assign data variables by group. @@ -1412,56 +1994,3 @@ class DatasetGroupBy( # type: ignore[misc] ImplementsDatasetReduce, ): __slots__ = () - - -class TimeResampleGrouper: - def __init__( - self, - freq: str, - closed: SideOptions | None, - label: SideOptions | None, - origin: str | DatetimeLike, - offset: pd.Timedelta | datetime.timedelta | str | None, - loffset: datetime.timedelta | str | None, - ): - self.freq = freq - self.closed = closed - self.label = label - self.origin = origin - self.offset = offset - self.loffset = loffset - - def first_items(self, index): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper - - if isinstance(index, CFTimeIndex): - grouper = CFTimeGrouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, - loffset=self.loffset, - ) - return grouper.first_items(index) - else: - s = pd.Series(np.arange(index.size), index, copy=False) - grouper = pd.Grouper( - freq=self.freq, - closed=self.closed, - label=self.label, - origin=self.origin, - offset=self.offset, - ) - - grouped = s.groupby(grouper) - first_items = grouped.first() - counts = grouped.count() - # This way we generate codes for the final output index: full_index. - # So for _flox_reduce we avoid one reindex and copy by avoiding - # _maybe_restore_empty_groups - codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) - return first_items, codes diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5f42c50e26f..e71c4a6f073 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -15,20 +15,47 @@ PandasIndexingAdapter, PandasMultiIndexingAdapter, ) -from xarray.core.utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar +from xarray.core.utils import ( + Frozen, + emit_user_level_warning, + get_valid_numpy_dtype, + is_dict_like, + is_scalar, +) if TYPE_CHECKING: - from xarray.core.types import ErrorOptions, T_Index + from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable + IndexVars = dict[Any, "Variable"] class Index: - """Base class inherited by all xarray-compatible indexes. + """ + Base class inherited by all xarray-compatible indexes. - Do not use this class directly for creating index objects. + Do not use this class directly for creating index objects. Xarray indexes + are created exclusively from subclasses of ``Index``, mostly via Xarray's + public API like ``Dataset.set_xindex``. + + Every subclass must at least implement :py:meth:`Index.from_variables`. The + (re)implementation of the other methods of this base class is optional but + mostly required in order to support operations relying on indexes such as + label-based selection or alignment. + The ``Index`` API closely follows the :py:meth:`Dataset` and + :py:meth:`DataArray` API, e.g., for an index to support ``.sel()`` it needs + to implement :py:meth:`Index.sel`, to support ``.stack()`` and + ``.unstack()`` it needs to implement :py:meth:`Index.stack` and + :py:meth:`Index.unstack`, etc. + + When a method is not (re)implemented, depending on the case the + corresponding operation on a :py:meth:`Dataset` or :py:meth:`DataArray` + either will raise a ``NotImplementedError`` or will simply drop/pass/copy + the index from/to the result. + + Do not use this class directly for creating index objects. """ @classmethod @@ -37,30 +64,129 @@ def from_variables( variables: Mapping[Any, Variable], *, options: Mapping[str, Any], - ) -> Index: + ) -> Self: + """Create a new index object from one or more coordinate variables. + + This factory method must be implemented in all subclasses of Index. + + The coordinate variables may be passed here in an arbitrary number and + order and each with arbitrary dimensions. It is the responsibility of + the index to check the consistency and validity of these coordinates. + + Parameters + ---------- + variables : dict-like + Mapping of :py:class:`Variable` objects holding the coordinate labels + to index. + + Returns + ------- + index : Index + A new Index object. + """ raise NotImplementedError() @classmethod def concat( - cls: type[T_Index], - indexes: Sequence[T_Index], + cls, + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> T_Index: + ) -> Self: + """Create a new index by concatenating one or more indexes of the same + type. + + Implementation is optional but required in order to support + ``concat``. Otherwise it will raise an error if the index needs to be + updated during the operation. + + Parameters + ---------- + indexes : sequence of Index objects + Indexes objects to concatenate together. All objects must be of the + same type. + dim : Hashable + Name of the dimension to concatenate along. + positions : None or list of integer arrays, optional + List of integer arrays which specifies the integer positions to which + to assign each dataset along the concatenated dimension. If not + supplied, objects are concatenated in the provided order. + + Returns + ------- + index : Index + A new Index object. + """ raise NotImplementedError() @classmethod - def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Index: + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self: + """Create a new index by stacking coordinate variables into a single new + dimension. + + Implementation is optional but required in order to support ``stack``. + Otherwise it will raise an error when trying to pass the Index subclass + as argument to :py:meth:`Dataset.stack`. + + Parameters + ---------- + variables : dict-like + Mapping of :py:class:`Variable` objects to stack together. + dim : Hashable + Name of the new, stacked dimension. + + Returns + ------- + index + A new Index object. + """ raise NotImplementedError( f"{cls!r} cannot be used for creating an index of stacked coordinates" ) def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: + """Unstack a (multi-)index into multiple (single) indexes. + + Implementation is optional but required in order to support unstacking + the coordinates from which this index has been built. + + Returns + ------- + indexes : tuple + A 2-length tuple where the 1st item is a dictionary of unstacked + Index objects and the 2nd item is a :py:class:`pandas.MultiIndex` + object used to unstack unindexed coordinate variables or data + variables. + """ raise NotImplementedError() def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: + """Maybe create new coordinate variables from this index. + + This method is useful if the index data can be reused as coordinate + variable data. It is often the case when the underlying index structure + has an array-like interface, like :py:class:`pandas.Index` objects. + + The variables given as argument (if any) are either returned as-is + (default behavior) or can be used to copy their metadata (attributes and + encoding) into the new returned coordinate variables. + + Note: the input variables may or may not have been filtered for this + index. + + Parameters + ---------- + variables : dict-like, optional + Mapping of :py:class:`Variable` objects. + + Returns + ------- + index_variables : dict-like + Dictionary of :py:class:`Variable` or :py:class:`IndexVariable` + objects. + """ if variables is not None: # pass through return dict(**variables) @@ -68,54 +194,213 @@ def create_variables( return {} def to_pandas_index(self) -> pd.Index: - """Cast this xarray index to a pandas.Index object or raise a TypeError - if this is not supported. + """Cast this xarray index to a pandas.Index object or raise a + ``TypeError`` if this is not supported. - This method is used by all xarray operations that expect/require a - pandas.Index object. + This method is used by all xarray operations that still rely on + pandas.Index objects. + By default it raises a ``TypeError``, unless it is re-implemented in + subclasses of Index. """ raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] - ) -> Index | None: + ) -> Self | None: + """Maybe returns a new index from the current index itself indexed by + positional indexers. + + This method should be re-implemented in subclasses of Index if the + wrapped index structure supports indexing operations. For example, + indexing a ``pandas.Index`` is pretty straightforward as it behaves very + much like an array. By contrast, it may be harder doing so for a + structure like a kd-tree that differs much from a simple array. + + If not re-implemented in subclasses of Index, this method returns + ``None``, i.e., calling :py:meth:`Dataset.isel` will either drop the + index in the resulting dataset or pass it unchanged if its corresponding + coordinate(s) are not indexed. + + Parameters + ---------- + indexers : dict + A dictionary of positional indexers as passed from + :py:meth:`Dataset.isel` and where the entries have been filtered + for the current index. + + Returns + ------- + maybe_index : Index + A new Index object or ``None``. + """ return None def sel(self, labels: dict[Any, Any]) -> IndexSelResult: + """Query the index with arbitrary coordinate label indexers. + + Implementation is optional but required in order to support label-based + selection. Otherwise it will raise an error when trying to call + :py:meth:`Dataset.sel` with labels for this index coordinates. + + Coordinate label indexers can be of many kinds, e.g., scalar, list, + tuple, array-like, slice, :py:class:`Variable`, :py:class:`DataArray`, etc. + It is the responsibility of the index to handle those indexers properly. + + Parameters + ---------- + labels : dict + A dictionary of coordinate label indexers passed from + :py:meth:`Dataset.sel` and where the entries have been filtered + for the current index. + + Returns + ------- + sel_results : :py:class:`IndexSelResult` + An index query result object that contains dimension positional indexers. + It may also contain new indexes, coordinate variables, etc. + """ raise NotImplementedError(f"{self!r} doesn't support label-based selection") - def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: + def join(self, other: Self, how: JoinOptions = "inner") -> Self: + """Return a new index from the combination of this index with another + index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object to combine with this index. + join : str, optional + Method for joining the two indexes (see :py:func:`~xarray.align`). + + Returns + ------- + joined : Index + A new Index object. + """ raise NotImplementedError( f"{self!r} doesn't support alignment with inner/outer join method" ) - def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + def reindex_like(self, other: Self) -> dict[Hashable, Any]: + """Query the index with another index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object used to query this index. + + Returns + ------- + dim_positional_indexers : dict + A dictionary where keys are dimension names and values are positional + indexers. + """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self, other): # pragma: no cover + def equals(self, other: Self) -> bool: + """Compare this index with another index of the same type. + + Implementation is optional but required in order to support alignment. + + Parameters + ---------- + other : Index + The other Index object to compare with this object. + + Returns + ------- + is_equal : bool + ``True`` if the indexes are equal, ``False`` otherwise. + """ raise NotImplementedError() - def roll(self, shifts: Mapping[Any, int]) -> Index | None: + def roll(self, shifts: Mapping[Any, int]) -> Self | None: + """Roll this index by an offset along one or more dimensions. + + This method can be re-implemented in subclasses of Index, e.g., when the + index can be itself indexed. + + If not re-implemented, this method returns ``None``, i.e., calling + :py:meth:`Dataset.roll` will either drop the index in the resulting + dataset or pass it unchanged if its corresponding coordinate(s) are not + rolled. + + Parameters + ---------- + shifts : mapping of hashable to int, optional + A dict with keys matching dimensions and values given + by integers to rotate each of the given dimensions, as passed + :py:meth:`Dataset.roll`. + + Returns + ------- + rolled : Index + A new index with rolled data. + """ return None def rename( - self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable] - ) -> Index: + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + """Maybe update the index with new coordinate and dimension names. + + This method should be re-implemented in subclasses of Index if it has + attributes that depend on coordinate or dimension names. + + By default (if not re-implemented), it returns the index itself. + + Warning: the input names are not filtered for this method, they may + correspond to any variable or dimension of a Dataset or a DataArray. + + Parameters + ---------- + name_dict : dict-like + Mapping of current variable or coordinate names to the desired names, + as passed from :py:meth:`Dataset.rename_vars`. + dims_dict : dict-like + Mapping of current dimension names to the desired names, as passed + from :py:meth:`Dataset.rename_dims`. + + Returns + ------- + renamed : Index + Index with renamed attributes. + """ return self - def __copy__(self) -> Index: - return self._copy(deep=False) + def copy(self, deep: bool = True) -> Self: + """Return a (deep) copy of this index. - def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: - return self._copy(deep=True, memo=memo) + Implementation in subclasses of Index is optional. The base class + implements the default (deep) copy semantics. - def copy(self: T_Index, deep: bool = True) -> T_Index: + Parameters + ---------- + deep : bool, optional + If true (default), a copy of the internal structures + (e.g., wrapped index) is returned with the new object. + + Returns + ------- + index : Index + A new Index object. + """ return self._copy(deep=deep) - def _copy( - self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None - ) -> T_Index: + def __copy__(self) -> Self: + return self.copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: + return self._copy(deep=True, memo=memo) + + def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self: cls = self.__class__ copied = cls.__new__(cls) if deep: @@ -125,7 +410,7 @@ def _copy( copied.__dict__.update(self.__dict__) return copied - def __getitem__(self, indexer: Any): + def __getitem__(self, indexer: Any) -> Self: raise NotImplementedError() def _repr_inline_(self, max_width): @@ -135,7 +420,7 @@ def _repr_inline_(self, max_width): def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index: from xarray.coding.cftimeindex import CFTimeIndex - if len(index) > 0 and index.dtype == "O": + if len(index) > 0 and index.dtype == "O" and not isinstance(index, CFTimeIndex): try: return CFTimeIndex(index) except (ImportError, TypeError): @@ -166,9 +451,21 @@ def safe_cast_to_index(array: Any) -> pd.Index: elif isinstance(array, PandasIndexingAdapter): index = array.array else: - kwargs = {} - if hasattr(array, "dtype") and array.dtype.kind == "O": - kwargs["dtype"] = object + kwargs: dict[str, str] = {} + if hasattr(array, "dtype"): + if array.dtype.kind == "O": + kwargs["dtype"] = "object" + elif array.dtype == "float16": + emit_user_level_warning( + ( + "`pandas.Index` does not support the `float16` dtype." + " Casting to `float64` for you, but in the future please" + " manually cast to either `float32` and `float64`." + ), + category=DeprecationWarning, + ) + kwargs["dtype"] = "float64" + index = pd.Index(np.asarray(array), **kwargs) return _maybe_cast_to_cftimeindex(index) @@ -259,6 +556,8 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): labels """ flat_labels = np.ravel(labels) + if flat_labels.dtype == "float16": + flat_labels = flat_labels.astype("float64") flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) indexer = flat_indexer.reshape(labels.shape) return indexer @@ -313,7 +612,14 @@ def from_variables( name, var = next(iter(variables.items())) - if var.ndim != 1: + if var.ndim == 0: + raise ValueError( + f"cannot set a PandasIndex from the scalar variable {name!r}, " + "only 1-dimensional variables are supported. " + f"Note: you might want to use `obj.expand_dims({name!r})` to create a " + f"new dimension and turn {name!r} as an indexed dimension coordinate." + ) + elif var.ndim != 1: raise ValueError( "PandasIndex only accepts a 1-dimensional variable, " f"variable {name!r} has {var.ndim} dimensions" @@ -364,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @classmethod def concat( cls, - indexes: Sequence[PandasIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -490,7 +796,11 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + def join( + self, + other: Self, + how: str = "inner", + ) -> Self: if how == "outer": index = self.index.union(other.index) else: @@ -501,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( - self, other: PandasIndex, method=None, tolerance=None + self, other: Self, method=None, tolerance=None ) -> dict[Hashable, Any]: if not self.index.is_unique: raise ValueError( @@ -653,12 +963,12 @@ def from_variables( return obj @classmethod - def concat( # type: ignore[override] + def concat( cls, - indexes: Sequence[PandasMultiIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasMultiIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -707,6 +1017,13 @@ def stack( def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: clean_index = remove_unused_levels_categories(self.index) + if not clean_index.is_unique: + raise ValueError( + "Cannot unstack MultiIndex containing duplicates. Make sure entries " + f"are unique, e.g., by calling ``.drop_duplicates('{self.dim}')``, " + "before unstacking." + ) + new_indexes: dict[Hashable, Index] = {} for name, lev in zip(clean_index.names, clean_index.levels): idx = PandasIndex( @@ -893,12 +1210,12 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: coord_name, label = next(iter(labels.items())) if is_dict_like(label): - invalid_levels = [ + invalid_levels = tuple( name for name in label if name not in self.index.names - ] + ) if invalid_levels: raise ValueError( - f"invalid multi-index level names {invalid_levels}" + f"multi-index level names {invalid_levels} not found in indexes {tuple(self.index.names)}" ) return self.sel(label) @@ -1078,19 +1395,22 @@ def create_default_index_implicit( class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): - """Immutable proxy for Dataset or DataArrary indexes. + """Immutable proxy for Dataset or DataArray indexes. - Keys are coordinate names and values may correspond to either pandas or - xarray indexes. + It is a mapping where keys are coordinate names and values are either pandas + or xarray indexes. - Also provides some utility methods. + It also contains the indexed coordinate variables and provides some utility + methods. """ + _index_type: type[Index] | type[pd.Index] _indexes: dict[Any, T_PandasOrXarrayIndex] _variables: dict[Any, Variable] __slots__ = ( + "_index_type", "_indexes", "_variables", "_dims", @@ -1101,8 +1421,9 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): def __init__( self, - indexes: dict[Any, T_PandasOrXarrayIndex], - variables: dict[Any, Variable], + indexes: Mapping[Any, T_PandasOrXarrayIndex] | None = None, + variables: Mapping[Any, Variable] | None = None, + index_type: type[Index] | type[pd.Index] = Index, ): """Constructor not for public consumption. @@ -1111,11 +1432,33 @@ def __init__( indexes : dict Indexes held by this object. variables : dict - Indexed coordinate variables in this object. + Indexed coordinate variables in this object. Entries must + match those of `indexes`. + index_type : type + The type of all indexes, i.e., either :py:class:`xarray.indexes.Index` + or :py:class:`pandas.Index`. """ - self._indexes = indexes - self._variables = variables + if indexes is None: + indexes = {} + if variables is None: + variables = {} + + unmatched_keys = set(indexes) ^ set(variables) + if unmatched_keys: + raise ValueError( + f"unmatched keys found in indexes and variables: {unmatched_keys}" + ) + + if any(not isinstance(idx, index_type) for idx in indexes.values()): + index_type_str = f"{index_type.__module__}.{index_type.__name__}" + raise TypeError( + f"values of indexes must all be instances of {index_type_str}" + ) + + self._index_type = index_type + self._indexes = dict(**indexes) + self._variables = dict(**variables) self._dims: Mapping[Hashable, int] | None = None self.__coord_name_id: dict[Any, int] | None = None @@ -1263,10 +1606,10 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: elif isinstance(idx, Index): indexes[k] = idx.to_pandas_index() - return Indexes(indexes, self._variables) + return Indexes(indexes, self._variables, index_type=pd.Index) def copy_indexes( - self, deep: bool = True, memo: dict[int, Any] | None = None + self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: """Return a new dictionary with copies of indexes, preserving unique indexes. @@ -1283,6 +1626,7 @@ def copy_indexes( new_indexes = {} new_index_vars = {} + idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True @@ -1318,7 +1662,8 @@ def __getitem__(self, key) -> T_PandasOrXarrayIndex: return self._indexes[key] def __repr__(self): - return formatting.indexes_repr(self) + indexes = formatting._get_indexes_dict(self) + return formatting.indexes_repr(indexes) def default_indexes( @@ -1342,7 +1687,7 @@ def default_indexes( coord_names = set(coords) for name, var in coords.items(): - if name in dims: + if name in dims and var.ndim == 1: index, index_vars = create_default_index_implicit(var, coords) if set(index_vars) <= coord_names: indexes.update({k: index for k in index_vars}) @@ -1475,7 +1820,7 @@ def filter_indexes_from_coords( of coordinate names. """ - filtered_indexes: dict[Any, Index] = dict(**indexes) + filtered_indexes: dict[Any, Index] = dict(indexes) index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set) for name, idx in indexes.items(): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7109d4fdd2c..82ee4ccb0e4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -17,26 +17,26 @@ from xarray.core import duck_array_ops from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS -from xarray.core.pycompat import ( - array_type, - integer_types, - is_duck_array, - is_duck_dask_array, -) from xarray.core.types import T_Xarray from xarray.core.utils import ( NDArrayMixin, either_dict_or_kwargs, get_valid_numpy_dtype, + is_duck_array, + is_duck_dask_array, is_scalar, to_0d_array, ) +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: from numpy.typing import DTypeLike from xarray.core.indexes import Index from xarray.core.variable import Variable + from xarray.namedarray._typing import _Shape, duckarray + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @dataclass @@ -142,7 +142,10 @@ def group_indexers_by_index( elif key in obj.coords: raise KeyError(f"no index found for coordinate {key!r}") elif key not in obj.dims: - raise KeyError(f"{key!r} is not a valid dimension or coordinate") + raise KeyError( + f"{key!r} is not a valid dimension or coordinate for " + f"{obj.__class__.__name__} with dimensions {obj.dims!r}" + ) elif len(options): raise ValueError( f"cannot supply selection options {options!r} for dimension {key!r}" @@ -162,7 +165,7 @@ def map_index_queries( obj: T_Xarray, indexers: Mapping[Any, Any], method=None, - tolerance=None, + tolerance: int | float | Iterable[int | float] | None = None, **indexers_kwargs: Any, ) -> IndexSelResult: """Execute index queries from a DataArray / Dataset and label-based indexers @@ -233,17 +236,17 @@ def expanded_indexer(key, ndim): return tuple(new_key) -def _expand_slice(slice_, size): +def _expand_slice(slice_, size: int) -> np.ndarray: return np.arange(*slice_.indices(size)) -def _normalize_slice(sl, size): +def _normalize_slice(sl: slice, size) -> slice: """Ensure that given slice only contains positive start and stop values (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])""" return slice(*sl.indices(size)) -def slice_slice(old_slice, applied_slice, size): +def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice: """Given a slice and the size of the dimension to which it will be applied, index it with another slice to return a new slice equivalent to applying the slices sequentially @@ -272,7 +275,7 @@ def slice_slice(old_slice, applied_slice, size): return slice(start, stop, step) -def _index_indexer_1d(old_indexer, applied_indexer, size): +def _index_indexer_1d(old_indexer, applied_indexer, size: int): assert isinstance(applied_indexer, integer_types + (slice, np.ndarray)) if isinstance(applied_indexer, slice) and applied_indexer == slice(None): # shortcut for the usual case @@ -281,7 +284,7 @@ def _index_indexer_1d(old_indexer, applied_indexer, size): if isinstance(applied_indexer, slice): indexer = slice_slice(old_indexer, applied_indexer, size) else: - indexer = _expand_slice(old_indexer, size)[applied_indexer] + indexer = _expand_slice(old_indexer, size)[applied_indexer] # type: ignore[assignment] else: indexer = old_indexer[applied_indexer] return indexer @@ -300,16 +303,16 @@ class ExplicitIndexer: __slots__ = ("_key",) - def __init__(self, key): + def __init__(self, key: tuple[Any, ...]): if type(self) is ExplicitIndexer: raise TypeError("cannot instantiate base ExplicitIndexer objects") self._key = tuple(key) @property - def tuple(self): + def tuple(self) -> tuple[Any, ...]: return self._key - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({self.tuple})" @@ -324,6 +327,28 @@ def as_integer_slice(value): return slice(start, stop, step) +class IndexCallable: + """Provide getitem and setitem syntax for callable objects.""" + + __slots__ = ("getter", "setter") + + def __init__( + self, getter: Callable[..., Any], setter: Callable[..., Any] | None = None + ): + self.getter = getter + self.setter = setter + + def __getitem__(self, key: Any) -> Any: + return self.getter(key) + + def __setitem__(self, key: Any, value: Any) -> None: + if self.setter is None: + raise NotImplementedError( + "Setting values is not supported for this indexer." + ) + self.setter(key, value) + + class BasicIndexer(ExplicitIndexer): """Tuple for basic indexing. @@ -334,7 +359,7 @@ class BasicIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__(self, key: tuple[int | np.integer | slice, ...]): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -350,7 +375,7 @@ def __init__(self, key): ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class OuterIndexer(ExplicitIndexer): @@ -364,7 +389,12 @@ class OuterIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__( + self, + key: tuple[ + int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ... + ], + ): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -379,19 +409,19 @@ def __init__(self, key): raise TypeError( f"invalid indexer array, does not have integer dtype: {k!r}" ) - if k.ndim > 1: + if k.ndim > 1: # type: ignore[union-attr] raise TypeError( f"invalid indexer array for {type(self).__name__}; must be scalar " f"or have 1 dimension: {k!r}" ) - k = k.astype(np.int64) + k = k.astype(np.int64) # type: ignore[union-attr] else: raise TypeError( f"unexpected indexer type for {type(self).__name__}: {k!r}" ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class VectorizedIndexer(ExplicitIndexer): @@ -406,7 +436,7 @@ class VectorizedIndexer(ExplicitIndexer): __slots__ = () - def __init__(self, key): + def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...]): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -427,21 +457,21 @@ def __init__(self, key): f"invalid indexer array, does not have integer dtype: {k!r}" ) if ndim is None: - ndim = k.ndim + ndim = k.ndim # type: ignore[union-attr] elif ndim != k.ndim: ndims = [k.ndim for k in key if isinstance(k, np.ndarray)] raise ValueError( "invalid indexer key: ndarray arguments " f"have different numbers of dimensions: {ndims}" ) - k = k.astype(np.int64) + k = k.astype(np.int64) # type: ignore[union-attr] else: raise TypeError( f"unexpected indexer type for {type(self).__name__}: {k!r}" ) new_key.append(k) - super().__init__(new_key) + super().__init__(tuple(new_key)) class ExplicitlyIndexed: @@ -449,13 +479,60 @@ class ExplicitlyIndexed: __slots__ = () + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # Leave casting to an array up to the underlying array type. + return np.asarray(self.get_duck_array(), dtype=dtype) + + def get_duck_array(self): + return self.array + class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): __slots__ = () - def __array__(self, dtype=None): + def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) - return np.asarray(self[key], dtype=dtype) + return self[key] + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # This is necessary because we apply the indexing key in self.get_duck_array() + # Note this is the base class for all lazy indexing classes + return np.asarray(self.get_duck_array(), dtype=dtype) + + def _oindex_get(self, indexer: OuterIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_get method should be overridden" + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_get method should be overridden" + ) + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._oindex_set method should be overridden" + ) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + f"{self.__class__.__name__}._vindex_set method should be overridden" + ) + + def _check_and_raise_if_non_basic_indexer(self, indexer: ExplicitIndexer) -> None: + if isinstance(indexer, (VectorizedIndexer, OuterIndexer)): + raise TypeError( + "Vectorized indexing with vectorized or outer indexers is not supported. " + "Please use .vindex and .oindex properties to index the array." + ) + + @property + def oindex(self) -> IndexCallable: + return IndexCallable(self._oindex_get, self._oindex_set) + + @property + def vindex(self) -> IndexCallable: + return IndexCallable(self._vindex_get, self._vindex_set) class ImplicitToExplicitIndexingAdapter(NDArrayMixin): @@ -463,16 +540,22 @@ class ImplicitToExplicitIndexingAdapter(NDArrayMixin): __slots__ = ("array", "indexer_cls") - def __init__(self, array, indexer_cls=BasicIndexer): + def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.array = as_indexable(array) self.indexer_cls = indexer_cls - def __array__(self, dtype=None): - return np.asarray(self.array, dtype=dtype) + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) - def __getitem__(self, key): + def get_duck_array(self): + return self.array.get_duck_array() + + def __getitem__(self, key: Any): key = expanded_indexer(key, self.ndim) - result = self.array[self.indexer_cls(key)] + indexer = self.indexer_cls(key) + + result = apply_indexer(self.array, indexer) + if isinstance(result, ExplicitlyIndexed): return type(self)(result, self.indexer_cls) else: @@ -486,7 +569,7 @@ class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "key") - def __init__(self, array, key=None): + def __init__(self, array: Any, key: ExplicitIndexer | None = None): """ Parameters ---------- @@ -498,8 +581,8 @@ def __init__(self, array, key=None): """ if isinstance(array, type(self)) and key is None: # unwrap - key = array.key - array = array.array + key = array.key # type: ignore[has-type] + array = array.array # type: ignore[has-type] if key is None: key = BasicIndexer((slice(None),) * array.ndim) @@ -507,7 +590,7 @@ def __init__(self, array, key=None): self.array = as_indexable(array) self.key = key - def _updated_key(self, new_key): + def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) full_key = [] for size, k in zip(self.array.shape, self.key.tuple): @@ -515,14 +598,14 @@ def _updated_key(self, new_key): full_key.append(k) else: full_key.append(_index_indexer_1d(k, next(iter_new_key), size)) - full_key = tuple(full_key) + full_key_tuple = tuple(full_key) - if all(isinstance(k, integer_types + (slice,)) for k in full_key): - return BasicIndexer(full_key) - return OuterIndexer(full_key) + if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple): + return BasicIndexer(full_key_tuple) + return OuterIndexer(full_key_tuple) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: shape = [] for size, k in zip(self.array.shape, self.key.tuple): if isinstance(k, slice): @@ -531,29 +614,52 @@ def shape(self) -> tuple[int, ...]: shape.append(k.size) return tuple(shape) - def __array__(self, dtype=None): - array = as_indexable(self.array) - return np.asarray(array[self.key], dtype=None) + def get_duck_array(self): + if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + array = apply_indexer(self.array, self.key) + else: + # If the array is not an ExplicitlyIndexedNDArrayMixin, + # it may wrap a BackendArray so use its __getitem__ + array = self.array[self.key] + + # self.array[self.key] is now a numpy array when + # self.array is a BackendArray subclass + # and self.key is BasicIndexer((slice(None, None, None),)) + # so we need the explicit check for ExplicitlyIndexed + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) def transpose(self, order): return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) - def __getitem__(self, indexer): - if isinstance(indexer, VectorizedIndexer): - array = LazilyVectorizedIndexedArray(self.array, self.key) - return array[indexer] + def _oindex_get(self, indexer: OuterIndexer): return type(self)(self.array, self._updated_key(indexer)) - def __setitem__(self, key, value): - if isinstance(key, VectorizedIndexer): - raise NotImplementedError( - "Lazy item assignment with the vectorized indexer is not yet " - "implemented. Load your data first by .load() or compute()." - ) + def _vindex_get(self, indexer: VectorizedIndexer): + array = LazilyVectorizedIndexedArray(self.array, self.key) + return array.vindex[indexer] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None: + raise NotImplementedError( + "Lazy item assignment with the vectorized indexer is not yet " + "implemented. Load your data first by .load() or compute()." + ) + + def _oindex_set(self, key: OuterIndexer, value: Any) -> None: + full_key = self._updated_key(key) + self.array.oindex[full_key] = value + + def __setitem__(self, key: BasicIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(key) full_key = self._updated_key(key) self.array[full_key] = value - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" @@ -566,7 +672,7 @@ class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "key") - def __init__(self, array, key): + def __init__(self, array: duckarray[Any, Any], key: ExplicitIndexer): """ Parameters ---------- @@ -576,21 +682,40 @@ def __init__(self, array, key): """ if isinstance(key, (BasicIndexer, OuterIndexer)): self.key = _outer_to_vectorized_indexer(key, array.shape) - else: + elif isinstance(key, VectorizedIndexer): self.key = _arrayize_vectorized_indexer(key, array.shape) self.array = as_indexable(array) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape - def __array__(self, dtype=None): - return np.asarray(self.array[self.key], dtype=None) - - def _updated_key(self, new_key): + def get_duck_array(self): + if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + array = apply_indexer(self.array, self.key) + else: + # If the array is not an ExplicitlyIndexedNDArrayMixin, + # it may wrap a BackendArray so use its __getitem__ + array = self.array[self.key] + # self.array[self.key] is now a numpy array when + # self.array is a BackendArray subclass + # and self.key is BasicIndexer((slice(None, None, None),)) + # so we need the explicit check for ExplicitlyIndexed + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + def _updated_key(self, new_key: ExplicitIndexer): return _combine_indexers(self.key, self.shape, new_key) - def __getitem__(self, indexer): + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(self.array, self._updated_key(indexer)) + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) # If the indexed array becomes a scalar, return LazilyIndexedArray if all(isinstance(ind, integer_types) for ind in indexer.tuple): key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple)) @@ -601,13 +726,13 @@ def transpose(self, order): key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple)) return type(self)(self.array, key) - def __setitem__(self, key, value): + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: raise NotImplementedError( "Lazy item assignment with the vectorized indexer is not yet " "implemented. Load your data first by .load() or compute()." ) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" @@ -622,7 +747,7 @@ def _wrap_numpy_scalars(array): class CopyOnWriteArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_copied") - def __init__(self, array): + def __init__(self, array: duckarray[Any, Any]): self.array = as_indexable(array) self._copied = False @@ -631,18 +756,35 @@ def _ensure_copied(self): self.array = as_indexable(np.array(self.array)) self._copied = True - def __array__(self, dtype=None): - return np.asarray(self.array, dtype=dtype) + def get_duck_array(self): + return self.array.get_duck_array() + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) - def __getitem__(self, key): - return type(self)(_wrap_numpy_scalars(self.array[key])) + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def __setitem__(self, key, value): + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self._ensure_copied() + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self._ensure_copied() + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) self._ensure_copied() - self.array[key] = value + + self.array[indexer] = value def __deepcopy__(self, memo): # CopyOnWriteArray is used to wrap backend array objects, which might @@ -658,21 +800,37 @@ def __init__(self, array): self.array = _wrap_numpy_scalars(as_indexable(array)) def _ensure_cached(self): - if not isinstance(self.array, NumpyIndexingAdapter): - self.array = NumpyIndexingAdapter(np.asarray(self.array)) + self.array = as_indexable(self.array.get_duck_array()) + + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) - def __array__(self, dtype=None): + def get_duck_array(self): self._ensure_cached() - return np.asarray(self.array, dtype=dtype) + return self.array.get_duck_array() + + def _oindex_get(self, indexer: OuterIndexer): + return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) + + def _vindex_get(self, indexer: VectorizedIndexer): + return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer])) - def __getitem__(self, key): - return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return type(self)(_wrap_numpy_scalars(self.array[indexer])) def transpose(self, order): return self.array.transpose(order) - def __setitem__(self, key, value): - self.array[key] = value + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer] = value + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array.oindex[indexer] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer] = value def as_indexable(array): @@ -697,12 +855,14 @@ def as_indexable(array): raise TypeError(f"Invalid array type: {type(array)}") -def _outer_to_vectorized_indexer(key, shape): +def _outer_to_vectorized_indexer( + indexer: BasicIndexer | OuterIndexer, shape: _Shape +) -> VectorizedIndexer: """Convert an OuterIndexer into an vectorized indexer. Parameters ---------- - key : Outer/Basic Indexer + indexer : Outer/Basic Indexer An indexer to convert. shape : tuple Shape of the array subject to the indexing. @@ -714,7 +874,7 @@ def _outer_to_vectorized_indexer(key, shape): Each element is an array: broadcasting them together gives the shape of the result. """ - key = key.tuple + key = indexer.tuple n_dim = len([k for k in key if not isinstance(k, integer_types)]) i_dim = 0 @@ -726,18 +886,18 @@ def _outer_to_vectorized_indexer(key, shape): if isinstance(k, slice): k = np.arange(*k.indices(size)) assert k.dtype.kind in {"i", "u"} - shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] - new_key.append(k.reshape(*shape)) + new_shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)] + new_key.append(k.reshape(*new_shape)) i_dim += 1 return VectorizedIndexer(tuple(new_key)) -def _outer_to_numpy_indexer(key, shape): +def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape): """Convert an OuterIndexer into an indexer for NumPy. Parameters ---------- - key : Basic/OuterIndexer + indexer : Basic/OuterIndexer An indexer to convert. shape : tuple Shape of the array subject to the indexing. @@ -747,16 +907,16 @@ def _outer_to_numpy_indexer(key, shape): tuple Tuple suitable for use to index a NumPy array. """ - if len([k for k in key.tuple if not isinstance(k, slice)]) <= 1: + if len([k for k in indexer.tuple if not isinstance(k, slice)]) <= 1: # If there is only one vector and all others are slice, # it can be safely used in mixed basic/advanced indexing. # Boolean index should already be converted to integer array. - return key.tuple + return indexer.tuple else: - return _outer_to_vectorized_indexer(key, shape).tuple + return _outer_to_vectorized_indexer(indexer, shape).tuple -def _combine_indexers(old_key, shape, new_key): +def _combine_indexers(old_key, shape: _Shape, new_key) -> VectorizedIndexer: """Combine two indexers. Parameters @@ -798,9 +958,9 @@ class IndexingSupport(enum.Enum): def explicit_indexing_adapter( key: ExplicitIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, - raw_indexing_method: Callable, + raw_indexing_method: Callable[..., Any], ) -> Any: """Support explicit indexing by delegating to a raw indexing method. @@ -827,12 +987,33 @@ def explicit_indexing_adapter( result = raw_indexing_method(raw_key.tuple) if numpy_indices.tuple: # index the loaded np.ndarray - result = NumpyIndexingAdapter(np.asarray(result))[numpy_indices] + indexable = NumpyIndexingAdapter(result) + result = apply_indexer(indexable, numpy_indices) return result +def apply_indexer(indexable, indexer: ExplicitIndexer): + """Apply an indexer to an indexable object.""" + if isinstance(indexer, VectorizedIndexer): + return indexable.vindex[indexer] + elif isinstance(indexer, OuterIndexer): + return indexable.oindex[indexer] + else: + return indexable[indexer] + + +def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None: + """Set values in an indexable object using an indexer.""" + if isinstance(indexer, VectorizedIndexer): + indexable.vindex[indexer] = value + elif isinstance(indexer, OuterIndexer): + indexable.oindex[indexer] = value + else: + indexable[indexer] = value + + def decompose_indexer( - indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport + indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport ) -> tuple[ExplicitIndexer, ExplicitIndexer]: if isinstance(indexer, VectorizedIndexer): return _decompose_vectorized_indexer(indexer, shape, indexing_support) @@ -871,7 +1052,7 @@ def _decompose_slice(key: slice, size: int) -> tuple[slice, slice]: def _decompose_vectorized_indexer( indexer: VectorizedIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, ) -> tuple[ExplicitIndexer, ExplicitIndexer]: """ @@ -902,10 +1083,10 @@ def _decompose_vectorized_indexer( >>> array = np.arange(36).reshape(6, 6) >>> backend_indexer = OuterIndexer((np.array([0, 1, 3]), np.array([2, 3]))) >>> # load subslice of the array - ... array = NumpyIndexingAdapter(array)[backend_indexer] + ... array = NumpyIndexingAdapter(array).oindex[backend_indexer] >>> np_indexer = VectorizedIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) >>> # vectorized indexing for on-memory np.ndarray. - ... NumpyIndexingAdapter(array)[np_indexer] + ... NumpyIndexingAdapter(array).vindex[np_indexer] array([ 2, 21, 8]) """ assert isinstance(indexer, VectorizedIndexer) @@ -953,7 +1134,7 @@ def _decompose_vectorized_indexer( def _decompose_outer_indexer( indexer: BasicIndexer | OuterIndexer, - shape: tuple[int, ...], + shape: _Shape, indexing_support: IndexingSupport, ) -> tuple[ExplicitIndexer, ExplicitIndexer]: """ @@ -987,7 +1168,7 @@ def _decompose_outer_indexer( ... array = NumpyIndexingAdapter(array)[backend_indexer] >>> np_indexer = OuterIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) >>> # outer indexing for on-memory np.ndarray. - ... NumpyIndexingAdapter(array)[np_indexer] + ... NumpyIndexingAdapter(array).oindex[np_indexer] array([[ 2, 3, 2], [14, 15, 14], [ 8, 9, 8]]) @@ -1026,9 +1207,11 @@ def _decompose_outer_indexer( # some backends such as h5py supports only 1 vector in indexers # We choose the most efficient axis gains = [ - (np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) - if isinstance(k, np.ndarray) - else 0 + ( + (np.max(k) - np.min(k) + 1.0) / len(np.unique(k)) + if isinstance(k, np.ndarray) + else 0 + ) for k in indexer_elems ] array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None @@ -1092,7 +1275,9 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) -def _arrayize_vectorized_indexer(indexer, shape): +def _arrayize_vectorized_indexer( + indexer: VectorizedIndexer, shape: _Shape +) -> VectorizedIndexer: """Return an identical vindex but slices are replaced by arrays""" slices = [v for v in indexer.tuple if isinstance(v, slice)] if len(slices) == 0: @@ -1112,31 +1297,35 @@ def _arrayize_vectorized_indexer(indexer, shape): return VectorizedIndexer(tuple(new_key)) -def _dask_array_with_chunks_hint(array, chunks): - """Create a dask array using the chunks hint for dimensions of size > 1.""" - import dask.array as da +def _chunked_array_with_chunks_hint( + array, chunks, chunkmanager: ChunkManagerEntrypoint[Any] +): + """Create a chunked array using the chunks hint for dimensions of size > 1.""" if len(chunks) < array.ndim: raise ValueError("not enough chunks in hint") new_chunks = [] for chunk, size in zip(chunks, array.shape): new_chunks.append(chunk if size > 1 else (1,)) - return da.from_array(array, new_chunks) + return chunkmanager.from_array(array, new_chunks) # type: ignore[arg-type] def _logical_any(args): return functools.reduce(operator.or_, args) -def _masked_result_drop_slice(key, data=None): +def _masked_result_drop_slice(key, data: duckarray[Any, Any] | None = None): key = (k for k in key if not isinstance(k, slice)) chunks_hint = getattr(data, "chunks", None) new_keys = [] for k in key: if isinstance(k, np.ndarray): - if is_duck_dask_array(data): - new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint)) + if is_chunked_array(data): # type: ignore[arg-type] + chunkmanager = get_chunked_array_type(data) + new_keys.append( + _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager) + ) elif isinstance(data, array_type("sparse")): import sparse @@ -1150,7 +1339,9 @@ def _masked_result_drop_slice(key, data=None): return mask -def create_mask(indexer, shape, data=None): +def create_mask( + indexer: ExplicitIndexer, shape: _Shape, data: duckarray[Any, Any] | None = None +): """Create a mask for indexing with a fill-value. Parameters @@ -1195,7 +1386,9 @@ def create_mask(indexer, shape, data=None): return mask -def _posify_mask_subindexer(index): +def _posify_mask_subindexer( + index: np.ndarray[Any, np.dtype[np.generic]], +) -> np.ndarray[Any, np.dtype[np.generic]]: """Convert masked indices in a flat array to the nearest unmasked index. Parameters @@ -1221,7 +1414,7 @@ def _posify_mask_subindexer(index): return new_index -def posify_mask_indexer(indexer): +def posify_mask_indexer(indexer: ExplicitIndexer) -> ExplicitIndexer: """Convert masked values (-1) in an indexer to nearest unmasked values. This routine is useful for dask, where it can be much faster to index @@ -1239,9 +1432,11 @@ def posify_mask_indexer(indexer): replaced by an adjacent non-masked element. """ key = tuple( - _posify_mask_subindexer(k.ravel()).reshape(k.shape) - if isinstance(k, np.ndarray) - else k + ( + _posify_mask_subindexer(k.ravel()).reshape(k.shape) + if isinstance(k, np.ndarray) + else k + ) for k in indexer.tuple ) return type(indexer)(key) @@ -1270,40 +1465,35 @@ def __init__(self, array): if not isinstance(array, np.ndarray): raise TypeError( "NumpyIndexingAdapter only wraps np.ndarray. " - "Trying to wrap {}".format(type(array)) + f"Trying to wrap {type(array)}" ) self.array = array - def _indexing_array_and_key(self, key): - if isinstance(key, OuterIndexer): - array = self.array - key = _outer_to_numpy_indexer(key, self.array.shape) - elif isinstance(key, VectorizedIndexer): - array = NumpyVIndexAdapter(self.array) - key = key.tuple - elif isinstance(key, BasicIndexer): - array = self.array - # We want 0d slices rather than scalars. This is achieved by - # appending an ellipsis (see - # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). - key = key.tuple + (Ellipsis,) - else: - raise TypeError(f"unexpected key type: {type(key)}") - - return array, key - def transpose(self, order): return self.array.transpose(order) - def __getitem__(self, key): - array, key = self._indexing_array_and_key(key) + def _oindex_get(self, indexer: OuterIndexer): + key = _outer_to_numpy_indexer(indexer, self.array.shape) + return self.array[key] + + def _vindex_get(self, indexer: VectorizedIndexer): + array = NumpyVIndexAdapter(self.array) + return array[indexer.tuple] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) return array[key] - def __setitem__(self, key, value): - array, key = self._indexing_array_and_key(key) + def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None: try: array[key] = value - except ValueError: + except ValueError as exc: # More informative exception if read-only view if not array.flags.writeable and not array.flags.owndata: raise ValueError( @@ -1311,7 +1501,24 @@ def __setitem__(self, key, value): "Do you want to .copy() array first?" ) else: - raise + raise exc + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + key = _outer_to_numpy_indexer(indexer, self.array.shape) + self._safe_setitem(self.array, key, value) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + array = NumpyVIndexAdapter(self.array) + self._safe_setitem(array, indexer.tuple, value) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + array = self.array + # We want 0d slices rather than scalars. This is achieved by + # appending an ellipsis (see + # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes). + key = indexer.tuple + (Ellipsis,) + self._safe_setitem(array, key, value) class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter): @@ -1339,30 +1546,30 @@ def __init__(self, array): ) self.array = array - def __getitem__(self, key): - if isinstance(key, BasicIndexer): - return self.array[key.tuple] - elif isinstance(key, OuterIndexer): - # manual orthogonal indexing (implemented like DaskIndexingAdapter) - key = key.tuple - value = self.array - for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey, Ellipsis)] - return value - else: - if isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + def _oindex_get(self, indexer: OuterIndexer): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = indexer.tuple + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value - def __setitem__(self, key, value): - if isinstance(key, (BasicIndexer, OuterIndexer)): - self.array[key.tuple] = value - else: - if isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + def _vindex_get(self, indexer: VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError("Vectorized indexing is not supported") + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value def transpose(self, order): xp = self.array.__array_namespace__() @@ -1380,55 +1587,38 @@ def __init__(self, array): """ self.array = array - def __getitem__(self, key): - if not isinstance(key, VectorizedIndexer): - # if possible, short-circuit when keys are effectively slice(None) - # This preserves dask name and passes lazy array equivalence checks - # (see duck_array_ops.lazy_array_equiv) - rewritten_indexer = False - new_indexer = [] - for idim, k in enumerate(key.tuple): - if isinstance(k, Iterable) and ( - not is_duck_dask_array(k) - and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim])) - ): - new_indexer.append(slice(None)) - rewritten_indexer = True - else: - new_indexer.append(k) - if rewritten_indexer: - key = type(key)(tuple(new_indexer)) - - if isinstance(key, BasicIndexer): - return self.array[key.tuple] - elif isinstance(key, VectorizedIndexer): - return self.array.vindex[key.tuple] - else: - assert isinstance(key, OuterIndexer) - key = key.tuple - try: - return self.array[key] - except NotImplementedError: - # manual orthogonal indexing. - # TODO: port this upstream into dask in a saner way. - value = self.array - for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey,)] - return value - - def __setitem__(self, key, value): - if isinstance(key, BasicIndexer): - self.array[key.tuple] = value - elif isinstance(key, VectorizedIndexer): - self.array.vindex[key.tuple] = value - elif isinstance(key, OuterIndexer): - num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple) - if num_non_slices > 1: - raise NotImplementedError( - "xarray can't set arrays with multiple " - "array indices to dask yet." - ) - self.array[key.tuple] = value + def _oindex_get(self, indexer: OuterIndexer): + key = indexer.tuple + try: + return self.array[key] + except NotImplementedError: + # manual orthogonal indexing + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey,)] + return value + + def _vindex_get(self, indexer: VectorizedIndexer): + return self.array.vindex[indexer.tuple] + + def __getitem__(self, indexer: ExplicitIndexer): + self._check_and_raise_if_non_basic_indexer(indexer) + return self.array[indexer.tuple] + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer.tuple) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " "array indices to dask yet." + ) + self.array[indexer.tuple] = value + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + self.array.vindex[indexer.tuple] = value + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + self._check_and_raise_if_non_basic_indexer(indexer) + self.array[indexer.tuple] = value def transpose(self, order): return self.array.transpose(order) @@ -1463,8 +1653,11 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: array = array.astype("object") return np.asarray(array.values, dtype=dtype) + def get_duck_array(self) -> np.ndarray: + return np.asarray(self) + @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> _Shape: return (len(self.array),) def _convert_scalar(self, item): @@ -1487,8 +1680,14 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) + def _oindex_get(self, indexer: OuterIndexer): + return self.__getitem__(indexer) + + def _vindex_get(self, indexer: VectorizedIndexer): + return self.__getitem__(indexer) + def __getitem__( - self, indexer + self, indexer: ExplicitIndexer ) -> ( PandasIndexingAdapter | NumpyIndexingAdapter @@ -1503,7 +1702,8 @@ def __getitem__( (key,) = key if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(np.asarray(self))[indexer] + indexable = NumpyIndexingAdapter(np.asarray(self)) + return apply_indexer(indexable, indexer) result = self.array[key] @@ -1566,7 +1766,7 @@ def _convert_scalar(self, item): item = item[idx] return super()._convert_scalar(item) - def __getitem__(self, indexer): + def __getitem__(self, indexer: ExplicitIndexer): result = super().__getitem__(indexer) if isinstance(result, type(self)): result.level = self.level @@ -1603,9 +1803,9 @@ def _repr_inline_(self, max_width: int) -> str: return format_array_flat(self._get_array_subset(), max_width) def _repr_html_(self) -> str: - from xarray.core.formatting import short_numpy_repr + from xarray.core.formatting import short_array_repr - array_repr = short_numpy_repr(self._get_array_subset()) + array_repr = short_array_repr(self._get_array_subset()) return f"
    {escape(array_repr)}
    " def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter: diff --git a/xarray/core/merge.py b/xarray/core/merge.py index bf7288ad7ed..a689620e524 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -11,7 +11,6 @@ from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import ( Index, - Indexes, create_default_index_implicit, filter_indexes_from_coords, indexes_equal, @@ -34,7 +33,7 @@ tuple[DimsLike, ArrayLike, Mapping, Mapping], ] XarrayValue = Union[DataArray, Variable, VariableLike] - DatasetLike = Union[Dataset, Mapping[Any, XarrayValue]] + DatasetLike = Union[Dataset, Coordinates, Mapping[Any, XarrayValue]] CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] CoercibleMapping = Union[Dataset, Mapping[Any, CoercibleValue]] @@ -195,11 +194,11 @@ def _assert_prioritized_valid( def merge_collected( - grouped: dict[Hashable, list[MergeElement]], + grouped: dict[Any, list[MergeElement]], prioritized: Mapping[Any, MergeElement] | None = None, compat: CompatOptions = "minimal", combine_attrs: CombineAttrsOptions = "override", - equals: dict[Hashable, bool] | None = None, + equals: dict[Any, bool] | None = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -306,22 +305,27 @@ def merge_collected( def collect_variables_and_indexes( - list_of_mappings: list[DatasetLike], + list_of_mappings: Iterable[DatasetLike], indexes: Mapping[Any, Any] | None = None, ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. - Mappings must either be Dataset objects, or have values of one of the - following types: + Mappings can be Dataset or Coordinates objects, in which case both + variables and indexes are extracted from it. + + It can also have values of one of the following types: - an xarray.Variable - a tuple `(dims, data[, attrs[, encoding]])` that can be converted in an xarray.Variable - or an xarray.DataArray If a mapping of indexes is given, those indexes are assigned to all variables - with a matching key/name. + with a matching key/name. For dimension variables with no matching index, a + default (pandas) index is assigned. DataArray indexes that don't match mapping + keys are also extracted. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -338,8 +342,8 @@ def append_all(variables, indexes): append(name, variable, indexes.get(name)) for mapping in list_of_mappings: - if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping._indexes) + if isinstance(mapping, (Coordinates, Dataset)): + append_all(mapping.variables, mapping.xindexes) continue for name, variable in mapping.items(): @@ -466,13 +470,15 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik List of Dataset or dictionary objects. Any inputs or values in the inputs that were pandas objects have been converted into native xarray objects. """ + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - out = [] + out: list[DatasetLike] = [] for obj in objects: - if isinstance(obj, Dataset): - variables: DatasetLike = obj + variables: DatasetLike + if isinstance(obj, (Dataset, Coordinates)): + variables = obj else: variables = {} if isinstance(obj, PANDAS_TYPES): @@ -486,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], + objects: Sequence[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals", ) -> dict[Hashable, MergeElement]: @@ -498,7 +504,7 @@ def _get_priority_vars_and_indexes( Parameters ---------- - objects : list of dict-like of Variable + objects : sequence of dict-like of Variable Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. @@ -556,56 +562,11 @@ def merge_coords( return variables, out_indexes -def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): - """Used in Dataset.__init__.""" - indexes, coords = _create_indexes_from_coords(coords, data_vars) - objects = [data_vars, coords] - explicit_coords = coords.keys() - return merge_core( - objects, - compat, - join, - explicit_coords=explicit_coords, - indexes=Indexes(indexes, coords), - ) - - -def _create_indexes_from_coords(coords, data_vars=None): - """Maybe create default indexes from a mapping of coordinates. - - Return those indexes and updated coordinates. - """ - all_variables = dict(coords) - if data_vars is not None: - all_variables.update(data_vars) - - indexes = {} - updated_coords = {} - - # this is needed for backward compatibility: when a pandas multi-index - # is given as data variable, it is promoted as index / level coordinates - # TODO: depreciate this implicit behavior - index_vars = { - k: v - for k, v in all_variables.items() - if k in coords or isinstance(v, pd.MultiIndex) - } - - for name, obj in index_vars.items(): - variable = as_variable(obj, name=name) - - if variable.dims == (name,): - idx, idx_vars = create_default_index_implicit(variable, all_variables) - indexes.update({k: idx for k in idx_vars}) - updated_coords.update(idx_vars) - all_variables.update(idx_vars) - else: - updated_coords[name] = obj - - return indexes, updated_coords - - -def assert_valid_explicit_coords(variables, dims, explicit_coords): +def assert_valid_explicit_coords( + variables: Mapping[Any, Any], + dims: Mapping[Any, int], + explicit_coords: Iterable[Hashable], +) -> None: """Validate explicit coordinate names/dims. Raise a MergeError if an explicit coord shares a name with a dimension @@ -688,9 +649,10 @@ def merge_core( join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", priority_arg: int | None = None, - explicit_coords: Sequence | None = None, + explicit_coords: Iterable[Hashable] | None = None, indexes: Mapping[Any, Any] | None = None, fill_value: object = dtypes.NA, + skip_align_args: list[int] | None = None, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -716,6 +678,8 @@ def merge_core( may be cast to pandas.Index objects. fill_value : scalar, optional Value to use for newly missing values + skip_align_args : list of int, optional + Optional arguments in `objects` that are not included in alignment. Returns ------- @@ -737,10 +701,20 @@ def merge_core( _assert_compat_valid(compat) + objects = list(objects) + if skip_align_args is None: + skip_align_args = [] + + skip_align_objs = [(pos, objects.pop(pos)) for pos in skip_align_args] + coerced = coerce_pandas_values(objects) aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) + + for pos, obj in skip_align_objs: + aligned.insert(pos, obj) + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( @@ -750,6 +724,9 @@ def merge_core( dims = calculate_dimensions(variables) coord_names, noncoord_names = determine_coords(coerced) + if compat == "minimal": + # coordinates may be dropped in merged results + coord_names.intersection_update(variables) if explicit_coords is not None: assert_valid_explicit_coords(variables, dims, explicit_coords) coord_names.update(explicit_coords) @@ -862,124 +839,124 @@ def merge( ... ) >>> x - + Size: 32B array([[1., 2.], [3., 5.]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 >>> y - + Size: 32B array([[5., 6.], [7., 8.]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 150.0 + * lat (lat) float64 16B 35.0 42.0 + * lon (lon) float64 16B 100.0 150.0 >>> z - + Size: 32B array([[0., 3.], [4., 9.]]) Coordinates: - * time (time) float64 30.0 60.0 - * lon (lon) float64 100.0 150.0 + * time (time) float64 16B 30.0 60.0 + * lon (lon) float64 16B 100.0 150.0 >>> xr.merge([x, y, z]) - + Size: 256B Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan - var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 - var3 (time, lon) float64 0.0 nan 3.0 4.0 nan 9.0 + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 >>> xr.merge([x, y, z], compat="identical") - + Size: 256B Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan - var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 - var3 (time, lon) float64 0.0 nan 3.0 4.0 nan 9.0 + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 >>> xr.merge([x, y, z], compat="equals") - + Size: 256B Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan - var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 - var3 (time, lon) float64 0.0 nan 3.0 4.0 nan 9.0 + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 >>> xr.merge([x, y, z], compat="equals", fill_value=-999.0) - + Size: 256B Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 -999.0 3.0 ... -999.0 -999.0 -999.0 - var2 (lat, lon) float64 5.0 -999.0 6.0 -999.0 ... -999.0 7.0 -999.0 8.0 - var3 (time, lon) float64 0.0 -999.0 3.0 4.0 -999.0 9.0 + var1 (lat, lon) float64 72B 1.0 2.0 -999.0 3.0 ... -999.0 -999.0 -999.0 + var2 (lat, lon) float64 72B 5.0 -999.0 6.0 -999.0 ... 7.0 -999.0 8.0 + var3 (time, lon) float64 48B 0.0 -999.0 3.0 4.0 -999.0 9.0 >>> xr.merge([x, y, z], join="override") - + Size: 144B Dimensions: (lat: 2, lon: 2, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 16B 35.0 40.0 + * lon (lon) float64 16B 100.0 120.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 3.0 5.0 - var2 (lat, lon) float64 5.0 6.0 7.0 8.0 - var3 (time, lon) float64 0.0 3.0 4.0 9.0 + var1 (lat, lon) float64 32B 1.0 2.0 3.0 5.0 + var2 (lat, lon) float64 32B 5.0 6.0 7.0 8.0 + var3 (time, lon) float64 32B 0.0 3.0 4.0 9.0 >>> xr.merge([x, y, z], join="inner") - + Size: 64B Dimensions: (lat: 1, lon: 1, time: 2) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 8B 35.0 + * lon (lon) float64 8B 100.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 - var2 (lat, lon) float64 5.0 - var3 (time, lon) float64 0.0 4.0 + var1 (lat, lon) float64 8B 1.0 + var2 (lat, lon) float64 8B 5.0 + var3 (time, lon) float64 16B 0.0 4.0 >>> xr.merge([x, y, z], compat="identical", join="inner") - + Size: 64B Dimensions: (lat: 1, lon: 1, time: 2) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 8B 35.0 + * lon (lon) float64 8B 100.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 - var2 (lat, lon) float64 5.0 - var3 (time, lon) float64 0.0 4.0 + var1 (lat, lon) float64 8B 1.0 + var2 (lat, lon) float64 8B 5.0 + var3 (time, lon) float64 16B 0.0 4.0 >>> xr.merge([x, y, z], compat="broadcast_equals", join="outer") - + Size: 256B Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 24B 35.0 40.0 42.0 + * lon (lon) float64 24B 100.0 120.0 150.0 + * time (time) float64 16B 30.0 60.0 Data variables: - var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan - var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 - var3 (time, lon) float64 0.0 nan 3.0 4.0 nan 9.0 + var1 (lat, lon) float64 72B 1.0 2.0 nan 3.0 5.0 nan nan nan nan + var2 (lat, lon) float64 72B 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 + var3 (time, lon) float64 48B 0.0 nan 3.0 4.0 nan 9.0 >>> xr.merge([x, y, z], join="exact") Traceback (most recent call last): @@ -997,18 +974,23 @@ def merge( combine_nested combine_by_coords """ + + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset dict_like_objects = [] for obj in objects: - if not isinstance(obj, (DataArray, Dataset, dict)): + if not isinstance(obj, (DataArray, Dataset, Coordinates, dict)): raise TypeError( "objects must be an iterable containing only " "Dataset(s), DataArray(s), and dictionaries." ) - obj = obj.to_dataset(promote_attrs=True) if isinstance(obj, DataArray) else obj + if isinstance(obj, DataArray): + obj = obj.to_dataset(promote_attrs=True) + elif isinstance(obj, Coordinates): + obj = obj.to_dataset() dict_like_objects.append(obj) merge_result = merge_core( @@ -1035,7 +1017,7 @@ def dataset_merge_method( # method due for backwards compatibility # TODO: consider deprecating it? - if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str): + if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable): overwrite_vars = set(overwrite_vars) else: overwrite_vars = {overwrite_vars} diff --git a/xarray/core/missing.py b/xarray/core/missing.py index d7f0be5fa08..8aa2ff2f042 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -13,12 +13,18 @@ from xarray.core import utils from xarray.core.common import _contains_datetime_like_objects, ones_like from xarray.core.computation import apply_ufunc -from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric -from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import is_duck_dask_array +from xarray.core.duck_array_ops import ( + datetime_to_numeric, + push, + reshape, + timedelta_to_numeric, +) +from xarray.core.options import _get_keep_attrs from xarray.core.types import Interp1dOptions, InterpOptions from xarray.core.utils import OrderedSet, is_scalar from xarray.core.variable import Variable, broadcast_variables +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array if TYPE_CHECKING: from xarray.core.dataarray import DataArray @@ -66,9 +72,7 @@ def __call__(self, x): return self.f(x, **self.call_kwargs) def __repr__(self): - return "{type}: method={method}".format( - type=self.__class__.__name__, method=self.method - ) + return f"{self.__class__.__name__}: method={self.method}" class NumpyInterpolator(BaseInterpolator): @@ -415,11 +419,6 @@ def _bfill(arr, n=None, axis=-1): def ffill(arr, dim=None, limit=None): """forward fill missing values""" - if not OPTIONS["use_bottleneck"]: - raise RuntimeError( - "ffill requires bottleneck to be enabled." - " Call `xr.set_options(use_bottleneck=True)` to enable it." - ) axis = arr.get_axis_num(dim) @@ -438,11 +437,6 @@ def ffill(arr, dim=None, limit=None): def bfill(arr, dim=None, limit=None): """backfill missing values""" - if not OPTIONS["use_bottleneck"]: - raise RuntimeError( - "bfill requires bottleneck to be enabled." - " Call `xr.set_options(use_bottleneck=True)` to enable it." - ) axis = arr.get_axis_num(dim) @@ -476,9 +470,9 @@ def _get_interpolator( returns interpolator class and keyword arguments for the class """ - interp_class: type[NumpyInterpolator] | type[ScipyInterpolator] | type[ - SplineInterpolator - ] + interp_class: ( + type[NumpyInterpolator] | type[ScipyInterpolator] | type[SplineInterpolator] + ) interp1d_methods = get_args(Interp1dOptions) valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) @@ -503,7 +497,7 @@ def _get_interpolator( ) elif method == "barycentric": interp_class = _import_interpolant("BarycentricInterpolator", method) - elif method == "krog": + elif method in ["krogh", "krog"]: interp_class = _import_interpolant("KroghInterpolator", method) elif method == "pchip": interp_class = _import_interpolant("PchipInterpolator", method) @@ -639,7 +633,7 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): var.transpose(*original_dims).data, x, destination, method, kwargs ) - result = Variable(new_dims, interped, attrs=var.attrs) + result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True) # dimension of the output array out_dims: OrderedSet = OrderedSet() @@ -648,7 +642,8 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): out_dims.update(indexes_coords[d][1].dims) else: out_dims.add(d) - result = result.transpose(*out_dims) + if len(out_dims) > 1: + result = result.transpose(*out_dims) return result @@ -679,7 +674,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): Notes ----- - This requiers scipy installed. + This requires scipy installed. See Also -------- @@ -693,8 +688,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): else: func, kwargs = _get_interpolator_nd(method, **kwargs) - if is_duck_dask_array(var): - import dask.array as da + if is_chunked_array(var): + chunkmanager = get_chunked_array_type(var) ndim = var.ndim nconst = ndim - len(x) @@ -709,40 +704,36 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): ] new_x_arginds = [item for pair in new_x_arginds for item in pair] - args = ( - var, - range(ndim), - *x_arginds, - *new_x_arginds, - ) + args = (var, range(ndim), *x_arginds, *new_x_arginds) - _, rechunked = da.unify_chunks(*args) + _, rechunked = chunkmanager.unify_chunks(*args) args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair) new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + new_x0_chunks = new_x[0].chunks + new_x0_shape = new_x[0].shape + new_x0_chunks_is_not_none = new_x0_chunks is not None new_axes = { - ndim + i: new_x[0].chunks[i] - if new_x[0].chunks is not None - else new_x[0].shape[i] + ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] for i in range(new_x[0].ndim) } - # if useful, re-use localize for each chunk of new_x - localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + # if useful, reuse localize for each chunk of new_x + localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: if not issubclass(var.dtype.type, np.inexact): - dtype = np.float_ + dtype = float else: dtype = var.dtype meta = var._meta - return da.blockwise( - _dask_aware_interpnd, + return chunkmanager.blockwise( + _chunked_aware_interpnd, out_ind, *args, interp_func=func, @@ -763,7 +754,7 @@ def _interp1d(var, x, new_x, func, kwargs): x, new_x = x[0], new_x[0] rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: - return rslt.reshape(var.shape[:-1] + new_x.shape) + return reshape(rslt, (var.shape[:-1] + new_x.shape)) if new_x.ndim == 0: return rslt[..., -1] return rslt @@ -782,11 +773,11 @@ def _interpnd(var, x, new_x, func, kwargs): rslt = func(x, var, xi, **kwargs) # move back the interpolation axes to the last position rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) - return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) + return reshape(rslt, rslt.shape[:-1] + new_x[0].shape) -def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): - """Wrapper for `_interpnd` through `blockwise` +def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): + """Wrapper for `_interpnd` through `blockwise` for chunked arrays. The first half arrays in `coords` are original coordinates, the other half are destination coordinates diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 022de845c4c..fc7240139aa 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -4,8 +4,9 @@ import numpy as np -from xarray.core import dtypes, nputils, utils +from xarray.core import dtypes, duck_array_ops, nputils, utils from xarray.core.duck_array_ops import ( + astype, count, fillna, isnull, @@ -20,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1): xarray version of pandas.core.nanops._maybe_null_out """ if axis is not None and getattr(result, "ndim", False): - null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 + null_mask = ( + np.take(mask.shape, axis).prod() + - duck_array_ops.sum(mask, axis) + - min_count + ) < 0 dtype, fill_value = dtypes.maybe_promote(result.dtype) - result = where(null_mask, fill_value, result.astype(dtype)) + result = where(null_mask, fill_value, astype(result, dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: - null_mask = mask.size - mask.sum() + null_mask = mask.size - duck_array_ops.sum(mask) result = where(null_mask < min_count, np.nan, result) return result @@ -140,7 +145,7 @@ def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): value_mean = _nanmean_ddof_object( ddof=0, value=value, axis=axis, keepdims=True, **kwargs ) - squared = (value.astype(value_mean.dtype) - value_mean) ** 2 + squared = (astype(value, value_mean.dtype) - value_mean) ** 2 return _nanmean_ddof_object(ddof, squared, axis=axis, keepdims=keepdims, **kwargs) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 1c5b0d3d972..6970d37402f 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -1,22 +1,41 @@ from __future__ import annotations import warnings +from typing import Callable import numpy as np import pandas as pd -from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] +from packaging.version import Version + +from xarray.core.utils import is_duck_array, module_available +from xarray.namedarray import pycompat + +# remove once numpy 2.0 is the oldest supported version +if module_available("numpy", minversion="2.0.0.dev0"): + from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore] + normalize_axis_index, + ) +else: + from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore] + normalize_axis_index, + ) + +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning # type: ignore[attr-defined,no-redef,unused-ignore] from xarray.core.options import OPTIONS -from xarray.core.pycompat import is_duck_array try: import bottleneck as bn - _USE_BOTTLENECK = True + _BOTTLENECK_AVAILABLE = True except ImportError: # use numpy methods instead bn = np - _USE_BOTTLENECK = False + _BOTTLENECK_AVAILABLE = False def _select_along_axis(values, idx, axis): @@ -155,13 +174,51 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def _create_bottleneck_method(name, npmodule=np): +def _create_method(name, npmodule=np) -> Callable: def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) if ( - _USE_BOTTLENECK + module_available("numbagg") + and pycompat.mod_version("numbagg") >= Version("0.5.0") + and OPTIONS["use_numbagg"] + and isinstance(values, np.ndarray) + # numbagg<0.7.0 uses ddof=1 only, but numpy uses ddof=0 by default + and ( + pycompat.mod_version("numbagg") >= Version("0.7.0") + or ("var" not in name and "std" not in name) + or kwargs.get("ddof", 0) == 1 + ) + # TODO: bool? + and values.dtype.kind in "uifc" + # and values.dtype.isnative + and (dtype is None or np.dtype(dtype) == values.dtype) + # numbagg.nanquantile only available after 0.8.0 and with linear method + and ( + name != "nanquantile" + or ( + pycompat.mod_version("numbagg") >= Version("0.8.0") + and kwargs.get("method", "linear") == "linear" + ) + ) + ): + import numbagg + + nba_func = getattr(numbagg, name, None) + if nba_func is not None: + # numbagg does not use dtype + kwargs.pop("dtype", None) + # prior to 0.7.0, numbagg did not support ddof; we ensure it's limited + # to ddof=1 above. + if pycompat.mod_version("numbagg") < Version("0.7.0"): + kwargs.pop("ddof", None) + if name == "nanquantile": + kwargs["quantiles"] = kwargs.pop("q") + kwargs.pop("method", None) + return nba_func(values, axis=axis, **kwargs) + if ( + _BOTTLENECK_AVAILABLE and OPTIONS["use_bottleneck"] and isinstance(values, np.ndarray) and bn_func is not None @@ -187,14 +244,14 @@ def _nanpolyfit_1d(arr, x, rcond=None): mask = np.isnan(arr) if not np.all(mask): out[:-1], resid, rank, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond) - out[-1] = resid if resid.size > 0 else np.nan + out[-1] = resid[0] if resid.size > 0 else np.nan warn_on_deficient_rank(rank, x.shape[1]) return out def warn_on_deficient_rank(rank, order): if rank != order: - warnings.warn("Polyfit may be poorly conditioned", np.RankWarning, stacklevel=2) + warnings.warn("Polyfit may be poorly conditioned", RankWarning, stacklevel=2) def least_squares(lhs, rhs, rcond=None, skipna=False): @@ -227,14 +284,15 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return coeffs, residuals -nanmin = _create_bottleneck_method("nanmin") -nanmax = _create_bottleneck_method("nanmax") -nanmean = _create_bottleneck_method("nanmean") -nanmedian = _create_bottleneck_method("nanmedian") -nanvar = _create_bottleneck_method("nanvar") -nanstd = _create_bottleneck_method("nanstd") -nanprod = _create_bottleneck_method("nanprod") -nancumsum = _create_bottleneck_method("nancumsum") -nancumprod = _create_bottleneck_method("nancumprod") -nanargmin = _create_bottleneck_method("nanargmin") -nanargmax = _create_bottleneck_method("nanargmax") +nanmin = _create_method("nanmin") +nanmax = _create_method("nanmax") +nanmean = _create_method("nanmean") +nanmedian = _create_method("nanmedian") +nanvar = _create_method("nanvar") +nanstd = _create_method("nanstd") +nanprod = _create_method("nanprod") +nancumsum = _create_method("nancumsum") +nancumprod = _create_method("nancumprod") +nanargmin = _create_method("nanargmin") +nanargmax = _create_method("nanargmax") +nanquantile = _create_method("nanquantile") diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 009616c5f12..c67b46692b8 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -4,6 +4,7 @@ NumPy's __array_ufunc__ and mixin classes instead of the unintuitive "inject" functions. """ + from __future__ import annotations import operator @@ -33,6 +34,8 @@ "and", "xor", "or", + "lshift", + "rshift", ] # methods which pass on the numpy return value unchanged @@ -51,7 +54,6 @@ "var", "median", ] -NAN_CUM_METHODS = ["cumsum", "cumprod"] # TODO: wrap take, dot, sort @@ -261,20 +263,6 @@ def inject_reduce_methods(cls): setattr(cls, name, func) -def inject_cum_methods(cls): - methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS] - for name, f, include_skipna in methods: - numeric_only = getattr(f, "numeric_only", False) - func = cls._reduce_method(f, include_skipna, numeric_only) - func.__name__ = name - func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format( - name=name, - cls=cls.__name__, - extra_args=cls._cum_extra_args_docstring.format(name=name), - ) - setattr(cls, name, func) - - def op_str(name): return f"__{name}__" @@ -314,16 +302,6 @@ def __init_subclass__(cls, **kwargs): inject_reduce_methods(cls) -class IncludeCumMethods: - __slots__ = () - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - if getattr(cls, "_reduce_method", None): - inject_cum_methods(cls) - - class IncludeNumpySameMethods: __slots__ = () diff --git a/xarray/core/options.py b/xarray/core/options.py index eb0c56c7ee0..f5614104357 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -6,10 +6,8 @@ from xarray.core.utils import FrozenDict if TYPE_CHECKING: - try: - from matplotlib.colors import Colormap - except ImportError: - Colormap = str + from matplotlib.colors import Colormap + Options = Literal[ "arithmetic_join", "cmap_divergent", @@ -22,6 +20,7 @@ "display_expand_coords", "display_expand_data_vars", "display_expand_data", + "display_expand_groups", "display_expand_indexes", "display_default_indexes", "enable_cftimeindex", @@ -29,10 +28,13 @@ "keep_attrs", "warn_for_unclosed_files", "use_bottleneck", + "use_numbagg", + "use_opt_einsum", "use_flox", ] class T_Options(TypedDict): + arithmetic_broadcast: bool arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] cmap_divergent: str | Colormap cmap_sequential: str | Colormap @@ -44,6 +46,7 @@ class T_Options(TypedDict): display_expand_coords: Literal["default", True, False] display_expand_data_vars: Literal["default", True, False] display_expand_data: Literal["default", True, False] + display_expand_groups: Literal["default", True, False] display_expand_indexes: Literal["default", True, False] display_default_indexes: Literal["default", True, False] enable_cftimeindex: bool @@ -52,9 +55,12 @@ class T_Options(TypedDict): warn_for_unclosed_files: bool use_bottleneck: bool use_flox: bool + use_numbagg: bool + use_opt_einsum: bool OPTIONS: T_Options = { + "arithmetic_broadcast": True, "arithmetic_join": "inner", "cmap_divergent": "RdBu_r", "cmap_sequential": "viridis", @@ -66,6 +72,7 @@ class T_Options(TypedDict): "display_expand_coords": "default", "display_expand_data_vars": "default", "display_expand_data": "default", + "display_expand_groups": "default", "display_expand_indexes": "default", "display_default_indexes": False, "enable_cftimeindex": True, @@ -74,6 +81,8 @@ class T_Options(TypedDict): "warn_for_unclosed_files": False, "use_bottleneck": True, "use_flox": True, + "use_numbagg": True, + "use_opt_einsum": True, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -85,6 +94,7 @@ def _positive_integer(value: int) -> bool: _VALIDATORS = { + "arithmetic_broadcast": lambda value: isinstance(value, bool), "arithmetic_join": _JOIN_OPTIONS.__contains__, "display_max_rows": _positive_integer, "display_values_threshold": _positive_integer, @@ -100,6 +110,8 @@ def _positive_integer(value: int) -> bool: "file_cache_maxsize": _positive_integer, "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), + "use_numbagg": lambda value: isinstance(value, bool), + "use_opt_einsum": lambda value: isinstance(value, bool), "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -164,11 +176,11 @@ class set_options: cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r" Colormap to use for divergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis" Colormap to use for nondivergent data plots. If string, must be matplotlib built-in colormap. Can also be a Colormap object - (e.g. mpl.cm.magma) + (e.g. mpl.colormaps["magma"]) display_expand_attrs : {"default", True, False} Whether to expand the attributes section for display of ``DataArray`` or ``Dataset`` objects. Can be @@ -232,6 +244,11 @@ class set_options: use_flox : bool, default: True Whether to use ``numpy_groupies`` and `flox`` to accelerate groupby and resampling reductions. + use_numbagg : bool, default: True + Whether to use ``numbagg`` to accelerate reductions. + Takes precedence over ``use_bottleneck`` when both are True. + use_opt_einsum : bool, default: True + Whether to use ``opt_einsum`` to accelerate dot products. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. @@ -244,10 +261,10 @@ class set_options: >>> with xr.set_options(display_width=40): ... print(ds) ... - + Size: 8kB Dimensions: (x: 1000) Coordinates: - * x (x) int64 0 1 2 ... 998 999 + * x (x) int64 8kB 0 1 ... 999 Data variables: *empty* diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2f8612c5a9b..41311497f8b 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -4,19 +4,30 @@ import itertools import operator from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.pycompat import is_dask_collection +from xarray.core.indexes import Index +from xarray.core.merge import merge +from xarray.core.utils import is_dask_collection +from xarray.core.variable import Variable if TYPE_CHECKING: from xarray.core.types import T_Xarray +class ExpectedDict(TypedDict): + shapes: dict[Hashable, int] + coords: set[Hashable] + data_vars: set[Hashable] + indexes: dict[Hashable, Index] + + def unzip(iterable): return zip(*iterable) @@ -31,7 +42,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset): def check_result_variables( - result: DataArray | Dataset, expected: Mapping[str, Any], kind: str + result: DataArray | Dataset, + expected: ExpectedDict, + kind: Literal["coords", "data_vars"], ): if kind == "coords": nice_str = "coordinate" @@ -144,6 +157,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping return slice(None) +def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index +): + """ + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. + """ + import dask + + # this will become [[name1, variable1], + # [name2, variable2], + # ...] + # which is passed to dict and then to Dataset + data_vars = [] + coords = [] + + chunk_tuple = tuple(chunk_index.values()) + chunk_dims_set = set(chunk_index) + variable: Variable + for name, variable in dataset.variables.items(): + # make a task that creates tuple of (dims, chunk) + if dask.is_dask_collection(variable.data): + # get task name for chunk + chunk = ( + variable.data.name, + *tuple(chunk_index[dim] for dim in variable.dims), + ) + + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple + graph[chunk_variable_task] = ( + tuple, + [variable.dims, chunk, variable.attrs], + ) + else: + assert name in dataset.dims or variable.ndim == 0 + + # non-dask array possibly with dimensions chunked on other variables + # index into variable appropriately + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + if set(variable.dims) < chunk_dims_set: + this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims) + else: + this_var_chunk_tuple = chunk_tuple + + chunk_variable_task = ( + f"{name}-{gname}-{dask.base.tokenize(subsetter)}", + ) + this_var_chunk_tuple + # We are including a dimension coordinate, + # minimize duplication by not copying it in the graph for every chunk. + if variable.ndim == 0 or chunk_variable_task not in graph: + subset = variable.isel(subsetter) + graph[chunk_variable_task] = ( + tuple, + [subset.dims, subset._data, subset.attrs], + ) + + # this task creates dict mapping variable name to above tuple + if name in dataset._coord_names: + coords.append([name, chunk_variable_task]) + else: + data_vars.append([name, chunk_variable_task]) + + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + def map_blocks( func: Callable[..., T_Xarray], obj: DataArray | Dataset, @@ -186,8 +268,9 @@ def map_blocks( Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of the - function. + obj : same as obj + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- @@ -214,7 +297,7 @@ def map_blocks( ... clim = gb.mean(dim="time") ... return gb - clim ... - >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> time = xr.cftime_range("1990-01", "1992-01", freq="ME") >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) >>> np.random.seed(123) >>> array = xr.DataArray( @@ -223,15 +306,15 @@ def map_blocks( ... coords={"time": time, "month": month}, ... ).chunk() >>> array.map_blocks(calculate_anomaly, template=array).compute() - + Size: 192B array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B 1 2 3 4 5 6 7 8 9 10 ... 3 4 5 6 7 8 9 10 11 12 Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments to the function being applied in ``xr.map_blocks()``: @@ -241,11 +324,11 @@ def map_blocks( ... kwargs={"groupby_type": "time.year"}, ... template=array, ... ) # doctest: +ELLIPSIS - + Size: 192B dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 - month (time) int64 dask.array + * time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 192B dask.array """ def _wrapper( @@ -253,7 +336,7 @@ def _wrapper( args: list, kwargs: dict, arg_is_array: Iterable[bool], - expected: dict, + expected: ExpectedDict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -267,6 +350,10 @@ def _wrapper( result = func(*converted_args, **kwargs) + merged_coordinates = merge( + [arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))] + ).coords + # check all dims are present missing_dimensions = set(expected["shapes"]) - set(result.sizes) if missing_dimensions: @@ -282,12 +369,16 @@ def _wrapper( f"Received dimension {name!r} of length {result.sizes[name]}. " f"Expected length {expected['shapes'][name]}." ) - if name in expected["indexes"]: - expected_index = expected["indexes"][name] - if not index.equals(expected_index): - raise ValueError( - f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." - ) + + # ChainMap wants MutableMapping, but xindexes is Mapping + merged_indexes = collections.ChainMap( + expected["indexes"], merged_coordinates.xindexes # type: ignore[arg-type] + ) + expected_index = merged_indexes.get(name, None) + if expected_index is not None and not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) # check that all expected variables were returned check_result_variables(result, expected, "coords") @@ -343,6 +434,10 @@ def _wrapper( dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg for arg in aligned ) + # rechunk any numpy variables appropriately + xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) + + merged_coordinates = merge([arg.coords for arg in aligned]).coords _, npargs = unzip( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) @@ -350,27 +445,37 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg._indexes) + coordinates: Coordinates if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template._indexes) - preserved_indexes = template_indexes & set(input_indexes) - new_indexes = template_indexes - set(input_indexes) - indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template._indexes[k] for k in new_indexes}) + template_coords = set(template.coords) + preserved_coord_vars = template_coords & set(merged_coordinates) + new_coord_vars = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars] + # preserved_coords contains all coordinates variables that share a dimension + # with any index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. + preserved_coords = preserved_coords.drop_vars( + tuple(k for k in preserved_coords.variables if k not in template_coords) + ) + + coordinates = merge( + (preserved_coords, template.coords.to_dataset()[new_coord_vars]) + ).coords output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template._indexes) + coordinates = template.coords output_chunks = template.chunksizes if not output_chunks: raise ValueError( @@ -378,6 +483,13 @@ def _wrapper( " Please construct a template with appropriately chunked dask arrays." ) + new_indexes = set(template.xindexes) - set(merged_coordinates) + modified_indexes = set( + name + for name, xindex in coordinates.xindexes.items() + if not xindex.equals(merged_coordinates.xindexes.get(name, None)) + ) + for dim in output_chunks: if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): raise ValueError( @@ -406,9 +518,7 @@ def _wrapper( new_layers: collections.defaultdict[str, dict[Any, Any]] = collections.defaultdict( dict ) - gname = "{}-{}".format( - dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) - ) + gname = f"{dask.utils.funcname(func)}-{dask.base.tokenize(npargs[0], args, kwargs)}" # map dims to list of chunk indexes ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} @@ -420,85 +530,41 @@ def _wrapper( dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } - def subset_dataset_to_block( - graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index - ): - """ - Creates a task that subsets an xarray dataset to a block determined by chunk_index. - Block extents are determined by input_chunk_bounds. - Also subtasks that subset the constituent variables of a dataset. - """ - - # this will become [[name1, variable1], - # [name2, variable2], - # ...] - # which is passed to dict and then to Dataset - data_vars = [] - coords = [] - - chunk_tuple = tuple(chunk_index.values()) - for name, variable in dataset.variables.items(): - # make a task that creates tuple of (dims, chunk) - if dask.is_dask_collection(variable.data): - # recursively index into dask_keys nested list to get chunk - chunk = variable.__dask_keys__() - for dim in variable.dims: - chunk = chunk[chunk_index[dim]] - - chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple - graph[chunk_variable_task] = ( - tuple, - [variable.dims, chunk, variable.attrs], - ) - else: - # non-dask array possibly with dimensions chunked on other variables - # index into variable appropriately - subsetter = { - dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) - for dim in variable.dims - } - subset = variable.isel(subsetter) - chunk_variable_task = ( - f"{name}-{gname}-{dask.base.tokenize(subset)}", - ) + chunk_tuple - graph[chunk_variable_task] = ( - tuple, - [subset.dims, subset, subset.attrs], - ) - - # this task creates dict mapping variable name to above tuple - if name in dataset._coord_names: - coords.append([name, chunk_variable_task]) - else: - data_vars.append([name, chunk_variable_task]) - - return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) - + computed_variables = set(template.variables) - set(coordinates.indexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) blocked_args = [ - subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index) - if isxr - else arg + ( + subset_dataset_to_block( + graph, gname, arg, input_chunk_bounds, chunk_index + ) + if isxr + else arg + ) for isxr, arg in zip(is_xarray, npargs) ] - # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected = {} - # input chunk 0 along a dimension maps to output chunk 0 along the same dimension - # even if length of dimension is changed by the applied function - expected["shapes"] = { - k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks - } - expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] - expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes + expected: ExpectedDict = { + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + # only include new or modified indexes to minimize duplication of data, and graph size. + "indexes": { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + }, } from_wrapper = (gname,) + chunk_tuple @@ -506,20 +572,16 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name, variable in template.variables.items(): - if name in indexes: - continue + for name in computed_variables: + variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l - key: tuple[Any, ...] = (gname_l,) - for dim in variable.dims: - if dim in chunk_index: - key += (chunk_index[dim],) - else: - # unchunked dimensions in the input have one chunk in the result - # output can have new dimensions with exactly one chunk - key += (0,) + # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk + key: tuple[Any, ...] = (gname_l,) + tuple( + chunk_index[dim] if dim in chunk_index else 0 for dim in variable.dims + ) # We're adding multiple new layers to the graph: # The first new layer is the result of the computation on @@ -544,12 +606,7 @@ def subset_dataset_to_block( }, ) - # TODO: benbovy - flexible indexes: make it work with custom indexes - # this will need to pass both indexes and coords to the Dataset constructor - result = Dataset( - coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, - attrs=template.attrs, - ) + result = Dataset(coords=coordinates, attrs=template.attrs) for index in result._indexes: result[index].attrs = template[index].attrs diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index b20a96bb8d6..c09dd82b074 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -39,6 +39,7 @@ from typing import Literal import pandas as pd +from packaging.version import Version from xarray.coding import cftime_offsets @@ -82,6 +83,7 @@ def _convert_base_to_offset(base, freq, index): from xarray.coding.cftimeindex import CFTimeIndex if isinstance(index, pd.DatetimeIndex): + freq = cftime_offsets._new_to_legacy_freq(freq) freq = pd.tseries.frequencies.to_offset(freq) if isinstance(freq, pd.offsets.Tick): return pd.Timedelta(base * freq.nanos // freq.n) @@ -91,3 +93,15 @@ def _convert_base_to_offset(base, freq, index): return base * freq.as_timedelta() // freq.n else: raise ValueError("Can only resample using a DatetimeIndex or CFTimeIndex.") + + +def nanosecond_precision_timestamp(*args, **kwargs) -> pd.Timestamp: + """Return a nanosecond-precision Timestamp object. + + Note this function should no longer be needed after addressing GitHub issue + #7493. + """ + if Version(pd.__version__) >= Version("2.0.0"): + return pd.Timestamp(*args, **kwargs).as_unit("ns") + else: + return pd.Timestamp(*args, **kwargs) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py deleted file mode 100644 index 4a3f3638d14..00000000000 --- a/xarray/core/pycompat.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -from importlib import import_module -from types import ModuleType -from typing import TYPE_CHECKING, Any, Literal - -import numpy as np -from packaging.version import Version - -from xarray.core.utils import is_duck_array, is_scalar, module_available - -integer_types = (int, np.integer) - -if TYPE_CHECKING: - ModType = Literal["dask", "pint", "cupy", "sparse"] - DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic - - -class DuckArrayModule: - """ - Solely for internal isinstance and version checks. - - Motivated by having to only import pint when required (as pint currently imports xarray) - https://github.com/pydata/xarray/pull/5561#discussion_r664815718 - """ - - module: ModuleType | None - version: Version - type: DuckArrayTypes - available: bool - - def __init__(self, mod: ModType) -> None: - duck_array_module: ModuleType | None = None - duck_array_version: Version - duck_array_type: DuckArrayTypes - try: - duck_array_module = import_module(mod) - duck_array_version = Version(duck_array_module.__version__) - - if mod == "dask": - duck_array_type = (import_module("dask.array").Array,) - elif mod == "pint": - duck_array_type = (duck_array_module.Quantity,) - elif mod == "cupy": - duck_array_type = (duck_array_module.ndarray,) - elif mod == "sparse": - duck_array_type = (duck_array_module.SparseArray,) - else: - raise NotImplementedError - - except ImportError: # pragma: no cover - duck_array_module = None - duck_array_version = Version("0.0.0") - duck_array_type = () - - self.module = duck_array_module - self.version = duck_array_version - self.type = duck_array_type - self.available = duck_array_module is not None - - -def array_type(mod: ModType) -> DuckArrayTypes: - """Quick wrapper to get the array class of the module.""" - return DuckArrayModule(mod).type - - -def mod_version(mod: ModType) -> Version: - """Quick wrapper to get the version of the module.""" - return DuckArrayModule(mod).version - - -def is_dask_collection(x): - if module_available("dask"): - from dask.base import is_dask_collection - - return is_dask_collection(x) - return False - - -def is_duck_dask_array(x): - return is_duck_array(x) and is_dask_collection(x) - - -def is_0d_dask_array(x): - return is_duck_dask_array(x) and is_scalar(x) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ad9b8379322..3bb158acfdb 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -63,7 +63,7 @@ def _drop_coords(self) -> T_Xarray: obj = self._obj for k, v in obj.coords.items(): if k != self._dim and self._dim in v.dims: - obj = obj.drop_vars(k) + obj = obj.drop_vars([k]) return obj def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: @@ -84,8 +84,9 @@ def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: padded : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="pad", tolerance=tolerance + {self._dim: grouper.full_index}, method="pad", tolerance=tolerance ) ffill = pad @@ -108,8 +109,9 @@ def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray backfilled : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="backfill", tolerance=tolerance + {self._dim: grouper.full_index}, method="backfill", tolerance=tolerance ) bfill = backfill @@ -133,8 +135,9 @@ def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: upsampled : DataArray or Dataset """ obj = self._drop_coords() + (grouper,) = self.groupers return obj.reindex( - {self._dim: self._full_index}, method="nearest", tolerance=tolerance + {self._dim: grouper.full_index}, method="nearest", tolerance=tolerance ) def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray: @@ -170,8 +173,9 @@ def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray: def _interpolate(self, kind="linear") -> T_Xarray: """Apply scipy.interpolate.interp1d along resampling dimension.""" obj = self._drop_coords() + (grouper,) = self.groupers return obj.interp( - coords={self._dim: self._full_index}, + coords={self._dim: grouper.full_index}, assume_sorted=True, method=kind, kwargs={"bounds_error": False}, @@ -184,6 +188,51 @@ class DataArrayResample(Resample["DataArray"], DataArrayGroupByBase, DataArrayRe specified dimension """ + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> DataArray: + """Reduce the items in this group by applying `func` along the + pre-defined resampling dimension. + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of collapsing + an np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. + keep_attrs : bool, optional + If True, the datasets's attributes (`attrs`) will be copied from + the original object to the new one. If False (default), the new + object will be returned without attributes. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data and the indicated dimension(s) + removed. + """ + return super().reduce( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + shortcut=shortcut, + **kwargs, + ) + def map( self, func: Callable[..., Any], @@ -232,15 +281,27 @@ def map( applied : DataArray The result of splitting, applying and combining this array. """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = False, + warn_squeeze: bool = True, + **kwargs: Any, + ) -> DataArray: # TODO: the argument order for Resample doesn't match that for its parent, # GroupBy - combined = super().map(func, shortcut=shortcut, args=args, **kwargs) + combined = super()._map_maybe_warn( + func, shortcut=shortcut, args=args, warn_squeeze=warn_squeeze, **kwargs + ) # If the aggregation function didn't drop the original resampling # dimension, then we need to do so before we can rename the proxy # dimension we used. if self._dim in combined.coords: - combined = combined.drop_vars(self._dim) + combined = combined.drop_vars([self._dim]) if RESAMPLE_DIM in combined.dims: combined = combined.rename({RESAMPLE_DIM: self._dim}) @@ -314,8 +375,18 @@ def map( applied : Dataset The result of splitting, applying and combining this dataset. """ + return self._map_maybe_warn(func, args, shortcut, warn_squeeze=True, **kwargs) + + def _map_maybe_warn( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + warn_squeeze: bool = True, + **kwargs: Any, + ) -> Dataset: # ignore shortcut if set (for now) - applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped()) + applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped(warn_squeeze)) combined = self._combine(applied) # If the aggregation function didn't drop the original resampling @@ -390,6 +461,27 @@ def reduce( **kwargs, ) + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> Dataset: + return super()._reduce_without_squeeze_warn( + func=func, + dim=dim, + axis=axis, + keep_attrs=keep_attrs, + keepdims=keepdims, + shortcut=shortcut, + **kwargs, + ) + def asfreq(self) -> Dataset: """Return values of original object at the new up-sampling frequency; essentially a re-index with new times set to NaN. diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 43edbc08456..216bd8fca6b 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -1,4 +1,5 @@ """Resampling for CFTimeIndex. Does not support non-integer freq.""" + # The mechanisms for resampling CFTimeIndex was copied and adapted from # the source code defined in pandas.core.resample # @@ -45,7 +46,6 @@ from xarray.coding.cftime_offsets import ( BaseCFTimeOffset, - Day, MonthEnd, QuarterEnd, Tick, @@ -152,7 +152,10 @@ def first_items(self, index: CFTimeIndex): f"Got {self.loffset}." ) - labels = labels + pd.to_timedelta(self.loffset) + if isinstance(self.loffset, datetime.timedelta): + labels = labels + self.loffset + else: + labels = labels + to_offset(self.loffset) # check binner fits data if index[0] < datetime_bins[0]: @@ -254,8 +257,7 @@ def _adjust_bin_edges( labels: np.ndarray, ): """This is required for determining the bin edges resampling with - daily frequencies greater than one day, month end, and year end - frequencies. + month end, quarter end, and year end frequencies. Consider the following example. Let's say you want to downsample the time series with the following coordinates to month end frequency: @@ -283,14 +285,8 @@ def _adjust_bin_edges( The labels are still: CFTimeIndex([2000-01-31 00:00:00, 2000-02-29 00:00:00], dtype='object') - - This is also required for daily frequencies longer than one day and - year-end frequencies. """ - is_super_daily = isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)) or ( - isinstance(freq, Day) and freq.n > 1 - ) - if is_super_daily: + if isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)): if closed == "right": datetime_bins = datetime_bins + datetime.timedelta(days=1, microseconds=-1) if datetime_bins[-2] > index.max(): diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index b9375987f94..7bc492d4296 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -8,13 +8,18 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar import numpy as np +from packaging.version import Version from xarray.core import dtypes, duck_array_ops, utils from xarray.core.arithmetic import CoarsenArithmetic from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray -from xarray.core.utils import either_dict_or_kwargs +from xarray.core.utils import ( + either_dict_or_kwargs, + is_duck_dask_array, + module_available, +) +from xarray.namedarray import pycompat try: import bottleneck @@ -100,6 +105,12 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "pad", "dim") _attributes = ("window", "min_periods", "center", "pad", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int + pad: bool def __init__( self, @@ -134,8 +145,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -146,6 +157,14 @@ def __init__( self.pad = self._mapping_to_list(pad, default=True) self.obj: T_Xarray = obj + missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) + if missing_dims: + # NOTE: we raise KeyError here but ValueError in Coarsen. + raise KeyError( + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" + ) + # attributes if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") @@ -178,7 +197,13 @@ def _reduce_method( # type: ignore[misc] name: str, fillna: Any, rolling_agg_func: Callable | None = None ) -> Callable[..., T_Xarray]: """Constructs reduction methods built on a numpy reduction function (e.g. sum), - a bottleneck reduction function (e.g. move_sum), or a Rolling reduction (_mean). + a numbagg reduction function (e.g. move_sum), a bottleneck reduction function + (e.g. move_sum), or a Rolling reduction (_mean). + + The logic here for which function to run is quite diffuse, across this method & + _array_reduce. Arguably we could refactor this. But one constraint is that we + need context of xarray options, of the functions each library offers, of + the array (e.g. dtype). """ if rolling_agg_func: array_agg_func = None @@ -186,14 +211,21 @@ def _reduce_method( # type: ignore[misc] array_agg_func = getattr(duck_array_ops, name) bottleneck_move_func = getattr(bottleneck, "move_" + name, None) + if module_available("numbagg"): + import numbagg + + numbagg_move_func = getattr(numbagg, "move_" + name, None) + else: + numbagg_move_func = None def method(self, keep_attrs=None, **kwargs): keep_attrs = self._get_keep_attrs(keep_attrs) - return self._numpy_or_bottleneck_reduce( - array_agg_func, - bottleneck_move_func, - rolling_agg_func, + return self._array_reduce( + array_agg_func=array_agg_func, + bottleneck_move_func=bottleneck_move_func, + numbagg_move_func=numbagg_move_func, + rolling_agg_func=rolling_agg_func, keep_attrs=keep_attrs, fillna=fillna, **kwargs, @@ -204,9 +236,9 @@ def method(self, keep_attrs=None, **kwargs): return method def _mean(self, keep_attrs, **kwargs): - result = self.sum(keep_attrs=False, **kwargs) / self.count( - keep_attrs=False - ).astype(self.obj.dtype, copy=False) + result = self.sum(keep_attrs=False, **kwargs) / duck_array_ops.astype( + self.count(keep_attrs=False), dtype=self.obj.dtype, copy=False + ) if keep_attrs: result.attrs = self.obj.attrs return result @@ -453,7 +485,7 @@ def construct( >>> rolling = da.rolling(b=3) >>> rolling.construct("window_dim") - + Size: 192B array([[[nan, nan, 0.], [nan, 0., 1.], [ 0., 1., 2.], @@ -467,7 +499,7 @@ def construct( >>> rolling = da.rolling(b=3, center=True) >>> rolling.construct("window_dim") - + Size: 192B array([[[nan, 0., 1.], [ 0., 1., 2.], [ 1., 2., 3.], @@ -511,7 +543,7 @@ def _construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 + window_dim, allow_default=False, allow_allsame=False ) strides = self._mapping_to_list(stride, default=1) @@ -565,7 +597,7 @@ def reduce( >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) >>> rolling = da.rolling(b=3) >>> rolling.construct("window_dim") - + Size: 192B array([[[nan, nan, 0.], [nan, 0., 1.], [ 0., 1., 2.], @@ -578,14 +610,14 @@ def reduce( Dimensions without coordinates: a, b, window_dim >>> rolling.reduce(np.sum) - + Size: 64B array([[nan, nan, 3., 6.], [nan, nan, 15., 18.]]) Dimensions without coordinates: a, b >>> rolling = da.rolling(b=3, min_periods=1) >>> rolling.reduce(np.nansum) - + Size: 64B array([[ 0., 1., 3., 6.], [ 4., 9., 15., 18.]]) Dimensions without coordinates: a, b @@ -608,9 +640,8 @@ def reduce( obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna ) - result = windows.reduce( - func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs - ) + dim = list(rolling_dim.values()) + result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs) # Find valid windows based on count. counts = self._counts(keep_attrs=False) @@ -627,6 +658,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. + dim = list(rolling_dim.values()) counts = ( self.obj.notnull(keep_attrs=keep_attrs) .rolling( @@ -635,13 +667,51 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: pad={d: self.pad[i] for i, d in enumerate(self.dim)}, ) .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) - .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs) + .sum(dim=dim, skipna=False, keep_attrs=keep_attrs) ) return counts - def _bottleneck_reduce(self, func, keep_attrs, **kwargs): - from xarray.core.dataarray import DataArray + def _numbagg_reduce(self, func, keep_attrs, **kwargs): + # Some of this is copied from `_bottleneck_reduce`, we could reduce this as part + # of a wider refactor. + + axis = self.obj.get_axis_num(self.dim[0]) + + padded = self.obj.variable + if self.center[0]: + if is_duck_dask_array(padded.data): + # workaround to make the padded chunk size larger than + # self.window - 1 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 + valid = (slice(None),) * axis + ( + slice(offset, offset + self.obj.shape[axis]), + ) + else: + shift = (-self.window[0] // 2) + 1 + valid = (slice(None),) * axis + (slice(-shift, None),) + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") + if is_duck_dask_array(padded.data) and False: + raise AssertionError("should not be reachable") + else: + values = func( + padded.data, + window=self.window[0], + min_count=self.min_periods, + axis=axis, + ) + + if self.center[0]: + values = values[valid] + + attrs = self.obj.attrs if keep_attrs else {} + + return self.obj.__class__( + values, self.obj.coords, attrs=attrs, name=self.obj.name + ) + + def _bottleneck_reduce(self, func, keep_attrs, **kwargs): # bottleneck doesn't allow min_count to be 0, although it should # work the same as if min_count = 1 # Note bottleneck only works with 1d-rolling. @@ -673,6 +743,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func in [bottleneck.move_argmin, bottleneck.move_argmax]: + values = self.window[0] - 1 - values if self.center[0]: values = values[valid] @@ -680,14 +755,15 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): attrs = self.obj.attrs if keep_attrs else {} output_selector = self._get_output_dim_selector() - return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name).isel( - output_selector - ) + return type(self.obj)( + values, self.obj.coords, attrs=attrs, name=self.obj.name + ).isel(output_selector) - def _numpy_or_bottleneck_reduce( + def _array_reduce( self, array_agg_func, bottleneck_move_func, + numbagg_move_func, rolling_agg_func, keep_attrs, fillna, @@ -703,20 +779,51 @@ def _numpy_or_bottleneck_reduce( ) del kwargs["dim"] + if ( + OPTIONS["use_numbagg"] + and module_available("numbagg") + and pycompat.mod_version("numbagg") >= Version("0.6.3") + and numbagg_move_func is not None + # TODO: we could at least allow this for the equivalent of `apply_ufunc`'s + # "parallelized". `rolling_exp` does this, as an example (but rolling_exp is + # much simpler) + and not is_duck_dask_array(self.obj.data) + # Numbagg doesn't handle object arrays and generally has dtype consistency, + # so doesn't deal well with bool arrays which are expected to change type. + and self.obj.data.dtype.kind not in "ObMm" + # TODO: we could also allow this, probably as part of a refactoring of this + # module, so we can use the machinery in `self.reduce`. + and self.ndim == 1 + ): + import numbagg + + # Numbagg has a default ddof of 1. I (@max-sixty) think we should make + # this the default in xarray too, but until we do, don't use numbagg for + # std and var unless ddof is set to 1. + if ( + numbagg_move_func not in [numbagg.move_std, numbagg.move_var] + or kwargs.get("ddof") == 1 + ): + return self._numbagg_reduce( + numbagg_move_func, keep_attrs=keep_attrs, **kwargs + ) + if ( OPTIONS["use_bottleneck"] and bottleneck_move_func is not None and not is_duck_dask_array(self.obj.data) and self.ndim == 1 ): - # TODO: renable bottleneck with dask after the issues + # TODO: re-enable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs ) + if rolling_agg_func: return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs)) + if fillna is not None: if fillna is dtypes.INF: fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True) @@ -845,7 +952,7 @@ def _counts(self, keep_attrs: bool | None) -> Dataset: DataArrayRolling._counts, keep_attrs=keep_attrs ) - def _numpy_or_bottleneck_reduce( + def _array_reduce( self, array_agg_func, bottleneck_move_func, @@ -855,7 +962,7 @@ def _numpy_or_bottleneck_reduce( ): return self._dataset_implementation( functools.partial( - DataArrayRolling._numpy_or_bottleneck_reduce, + DataArrayRolling._array_reduce, array_agg_func=array_agg_func, bottleneck_move_func=bottleneck_move_func, rolling_agg_func=rolling_agg_func, @@ -905,7 +1012,7 @@ def construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 + window_dim, allow_default=False, allow_allsame=False ) strides = self._mapping_to_list(stride, default=1) @@ -930,11 +1037,14 @@ def construct( if not keep_attrs: dataset[key].attrs = {} + # Need to stride coords as well. TODO: is there a better way? + coords = self.obj.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + ).coords + attrs = self.obj.attrs if keep_attrs else {} - return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( - {d: slice(None, None, s) for d, s in zip(self.dim, strides)} - ) + return Dataset(dataset, coords=coords, attrs=attrs) class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): @@ -956,6 +1066,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -985,23 +1099,28 @@ def __init__( Returns ------- coarsen + """ self.obj = obj self.windows = windows self.side = side self.boundary = boundary - absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims] - if absent_dims: + missing_dims = tuple(dim for dim in windows.keys() if dim not in self.obj.dims) + if missing_dims: raise ValueError( - f"Dimensions {absent_dims!r} not found in {self.obj.__class__.__name__}." + f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " + f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: @@ -1051,7 +1170,7 @@ def construct( -------- >>> da = xr.DataArray(np.arange(24), dims="time") >>> da.coarsen(time=12).construct(time=("year", "month")) - + Size: 192B array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]) Dimensions without coordinates: year, month @@ -1207,7 +1326,7 @@ def reduce( >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) >>> coarsen = da.coarsen(b=2) >>> coarsen.reduce(np.sum) - + Size: 32B array([[ 1, 5], [ 9, 13]]) Dimensions without coordinates: a, b diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 91edd3acb7c..4e085a0a7eb 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -6,76 +6,47 @@ import numpy as np from packaging.version import Version +from xarray.core.computation import apply_ufunc from xarray.core.options import _get_keep_attrs from xarray.core.pdcompat import count_not_none -from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import T_DataWithCoords +from xarray.core.utils import module_available +from xarray.namedarray import pycompat -def _get_alpha(com=None, span=None, halflife=None, alpha=None): - # pandas defines in terms of com (converting to alpha in the algo) - # so use its function to get a com and then convert to alpha - - com = _get_center_of_mass(com, span, halflife, alpha) - return 1 / (1 + com) - - -def move_exp_nanmean(array, *, axis, alpha): - if is_duck_dask_array(array): - raise TypeError("rolling_exp is not currently support for dask-like arrays") - import numbagg - - # No longer needed in numbag > 0.2.0; remove in time - if axis == (): - return array.astype(np.float64) - else: - return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) - - -def move_exp_nansum(array, *, axis, alpha): - if is_duck_dask_array(array): - raise TypeError("rolling_exp is not currently supported for dask-like arrays") - import numbagg - - # numbagg <= 0.2.0 did not have a __version__ attribute - if Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.0"): - raise ValueError("`rolling_exp(...).sum() requires numbagg>=0.2.1.") - - return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha) - - -def _get_center_of_mass(comass, span, halflife, alpha): +def _get_alpha( + com: float | None = None, + span: float | None = None, + halflife: float | None = None, + alpha: float | None = None, +) -> float: """ - Vendored from pandas.core.window.common._get_center_of_mass - - See licenses/PANDAS_LICENSE for the function's license + Convert com, span, halflife to alpha. """ - valid_count = count_not_none(comass, span, halflife, alpha) + valid_count = count_not_none(com, span, halflife, alpha) if valid_count > 1: - raise ValueError("comass, span, halflife, and alpha are mutually exclusive") + raise ValueError("com, span, halflife, and alpha are mutually exclusive") - # Convert to center of mass; domain checks ensure 0 < alpha <= 1 - if comass is not None: - if comass < 0: - raise ValueError("comass must satisfy: comass >= 0") + # Convert to alpha + if com is not None: + if com < 0: + raise ValueError("commust satisfy: com>= 0") + return 1 / (com + 1) elif span is not None: if span < 1: raise ValueError("span must satisfy: span >= 1") - comass = (span - 1) / 2.0 + return 2 / (span + 1) elif halflife is not None: if halflife <= 0: raise ValueError("halflife must satisfy: halflife > 0") - decay = 1 - np.exp(np.log(0.5) / halflife) - comass = 1 / decay - 1 + return 1 - np.exp(np.log(0.5) / halflife) elif alpha is not None: - if alpha <= 0 or alpha > 1: + if not 0 < alpha <= 1: raise ValueError("alpha must satisfy: 0 < alpha <= 1") - comass = (1.0 - alpha) / alpha + return alpha else: raise ValueError("Must pass one of comass, span, halflife, or alpha") - return float(comass) - class RollingExp(Generic[T_DataWithCoords]): """ @@ -104,11 +75,31 @@ def __init__( obj: T_DataWithCoords, windows: Mapping[Any, int | float], window_type: str = "span", + min_weight: float = 0.0, ): + if not module_available("numbagg"): + raise ImportError( + "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed" + ) + elif pycompat.mod_version("numbagg") < Version("0.2.1"): + raise ImportError( + f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed" + ) + elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0: + raise ImportError( + f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed" + ) + self.obj: T_DataWithCoords = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) + self.min_weight = min_weight + # Don't pass min_weight=0 so we can support older versions of numbagg + kwargs = dict(alpha=self.alpha, axis=-1) + if min_weight > 0: + kwargs["min_weight"] = min_weight + self.kwargs = kwargs def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ @@ -125,17 +116,28 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: -------- >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") >>> da.rolling_exp(x=2, window_type="span").mean() - + Size: 40B array([1. , 1. , 1.69230769, 1.9 , 1.96694215]) Dimensions without coordinates: x """ + import numbagg + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.obj.reduce( - move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs - ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanmean, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ @@ -152,14 +154,159 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: -------- >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") >>> da.rolling_exp(x=2, window_type="span").sum() - + Size: 40B array([1. , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ]) Dimensions without coordinates: x """ + import numbagg + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - return self.obj.reduce( - move_exp_nansum, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs - ) + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nansum, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=keep_attrs, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def std(self) -> T_DataWithCoords: + """ + Exponentially weighted moving standard deviation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").std() + Size: 40B + array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" + ) + import numbagg + + dim_order = self.obj.dims + + return apply_ufunc( + numbagg.move_exp_nanstd, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def var(self) -> T_DataWithCoords: + """ + Exponentially weighted moving variance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").var() + Size: 40B + array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281]) + Dimensions without coordinates: x + """ + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nanvar, + self.obj, + input_core_dims=[[self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving covariance. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").cov(da**2) + Size: 40B + array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nancov, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) + + def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: + """ + Exponentially weighted moving correlation. + + `keep_attrs` is always True for this method. Drop attrs separately to remove attrs. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").corr(da.shift(x=1)) + Size: 40B + array([ nan, nan, nan, 0.4330127 , 0.48038446]) + Dimensions without coordinates: x + """ + + if pycompat.mod_version("numbagg") < Version("0.4.0"): + raise ImportError( + f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed" + ) + dim_order = self.obj.dims + import numbagg + + return apply_ufunc( + numbagg.move_exp_nancorr, + self.obj, + other, + input_core_dims=[[self.dim], [self.dim]], + kwargs=self.kwargs, + output_core_dims=[[self.dim]], + keep_attrs=True, + on_missing_core_dim="copy", + dask="parallelized", + ).transpose(*dim_order) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py new file mode 100644 index 00000000000..b3e6e43f306 --- /dev/null +++ b/xarray/core/treenode.py @@ -0,0 +1,679 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator, Mapping +from pathlib import PurePosixPath +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, +) + +from xarray.core.utils import Frozen, is_dict_like + +if TYPE_CHECKING: + from xarray.core.types import T_DataArray + + +class InvalidTreeError(Exception): + """Raised when user attempts to create an invalid tree in some way.""" + + +class NotFoundInTreeError(ValueError): + """Raised when operation can't be completed because one node is not part of the expected tree.""" + + +class NodePath(PurePosixPath): + """Represents a path from one node to another within a tree.""" + + def __init__(self, *pathsegments): + if sys.version_info >= (3, 12): + super().__init__(*pathsegments) + else: + super().__new__(PurePosixPath, *pathsegments) + if self.drive: + raise ValueError("NodePaths cannot have drives") + + if self.root not in ["/", ""]: + raise ValueError( + 'Root of NodePath can only be either "/" or "", with "" meaning the path is relative.' + ) + # TODO should we also forbid suffixes to avoid node names with dots in them? + + +Tree = TypeVar("Tree", bound="TreeNode") + + +class TreeNode(Generic[Tree]): + """ + Base class representing a node of a tree, with methods for traversing and altering the tree. + + This class stores no data, it has only parents and children attributes, and various methods. + + Stores child nodes in an dict, ensuring that equality checks between trees + and order of child nodes is preserved (since python 3.7). + + Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can + find the key it is stored under via the .name property. + + The .parent attribute is read-only: to replace the parent using public API you must set this node as the child of a + new parent using `new_parent.children[name] = child_node`, or to instead detach from the current parent use + `child_node.orphan()`. + + This class is intended to be subclassed by DataTree, which will overwrite some of the inherited behaviour, + in particular to make names an inherent attribute, and allow setting parents directly. The intention is to mirror + the class structure of xarray.Variable & xarray.DataArray, where Variable is unnamed but DataArray is (optionally) + named. + + Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'. + + (This class is heavily inspired by the anytree library's NodeMixin class.) + + """ + + _parent: Tree | None + _children: dict[str, Tree] + + def __init__(self, children: Mapping[str, Tree] | None = None): + """Create a parentless node.""" + self._parent = None + self._children = {} + if children is not None: + self.children = children + + @property + def parent(self) -> Tree | None: + """Parent of this node.""" + return self._parent + + def _set_parent( + self, new_parent: Tree | None, child_name: str | None = None + ) -> None: + # TODO is it possible to refactor in a way that removes this private method? + + if new_parent is not None and not isinstance(new_parent, TreeNode): + raise TypeError( + "Parent nodes must be of type DataTree or None, " + f"not type {type(new_parent)}" + ) + + old_parent = self._parent + if new_parent is not old_parent: + self._check_loop(new_parent) + self._detach(old_parent) + self._attach(new_parent, child_name) + + def _check_loop(self, new_parent: Tree | None) -> None: + """Checks that assignment of this new parent will not create a cycle.""" + if new_parent is not None: + if new_parent is self: + raise InvalidTreeError( + f"Cannot set parent, as node {self} cannot be a parent of itself." + ) + + if self._is_descendant_of(new_parent): + raise InvalidTreeError( + "Cannot set parent, as intended parent is already a descendant of this node." + ) + + def _is_descendant_of(self, node: Tree) -> bool: + return any(n is self for n in node.parents) + + def _detach(self, parent: Tree | None) -> None: + if parent is not None: + self._pre_detach(parent) + parents_children = parent.children + parent._children = { + name: child + for name, child in parents_children.items() + if child is not self + } + self._parent = None + self._post_detach(parent) + + def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: + if parent is not None: + if child_name is None: + raise ValueError( + "To directly set parent, child needs a name, but child is unnamed" + ) + + self._pre_attach(parent) + parentchildren = parent._children + assert not any( + child is self for child in parentchildren + ), "Tree is corrupt." + parentchildren[child_name] = self + self._parent = parent + self._post_attach(parent) + else: + self._parent = None + + def orphan(self) -> None: + """Detach this node from its parent.""" + self._set_parent(new_parent=None) + + @property + def children(self: Tree) -> Mapping[str, Tree]: + """Child nodes of this node, stored under a mapping via their names.""" + return Frozen(self._children) + + @children.setter + def children(self: Tree, children: Mapping[str, Tree]) -> None: + self._check_children(children) + children = {**children} + + old_children = self.children + del self.children + try: + self._pre_attach_children(children) + for name, child in children.items(): + child._set_parent(new_parent=self, child_name=name) + self._post_attach_children(children) + assert len(self.children) == len(children) + except Exception: + # if something goes wrong then revert to previous children + self.children = old_children + raise + + @children.deleter + def children(self) -> None: + # TODO this just detaches all the children, it doesn't actually delete them... + children = self.children + self._pre_detach_children(children) + for child in self.children.values(): + child.orphan() + assert len(self.children) == 0 + self._post_detach_children(children) + + @staticmethod + def _check_children(children: Mapping[str, Tree]) -> None: + """Check children for correct types and for any duplicates.""" + if not is_dict_like(children): + raise TypeError( + "children must be a dict-like mapping from names to node objects" + ) + + seen = set() + for name, child in children.items(): + if not isinstance(child, TreeNode): + raise TypeError( + f"Cannot add object {name}. It is of type {type(child)}, " + "but can only add children of type DataTree" + ) + + childid = id(child) + if childid not in seen: + seen.add(childid) + else: + raise InvalidTreeError( + f"Cannot add same node {name} multiple times as different children." + ) + + def __repr__(self) -> str: + return f"TreeNode(children={dict(self._children)})" + + def _pre_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call before detaching `children`.""" + pass + + def _post_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call after detaching `children`.""" + pass + + def _pre_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call before attaching `children`.""" + pass + + def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + """Method call after attaching `children`.""" + pass + + def _iter_parents(self: Tree) -> Iterator[Tree]: + """Iterate up the tree, starting from the current node.""" + node: Tree | None = self.parent + while node is not None: + yield node + node = node.parent + + def iter_lineage(self: Tree) -> tuple[Tree, ...]: + """Iterate up the tree, starting from the current node.""" + from warnings import warn + + warn( + "`iter_lineage` has been deprecated, and in the future will raise an error." + "Please use `parents` from now on.", + DeprecationWarning, + ) + return tuple((self, *self.parents)) + + @property + def lineage(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the closest.""" + from warnings import warn + + warn( + "`lineage` has been deprecated, and in the future will raise an error." + "Please use `parents` from now on.", + DeprecationWarning, + ) + return self.iter_lineage() + + @property + def parents(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the closest.""" + return tuple(self._iter_parents()) + + @property + def ancestors(self: Tree) -> tuple[Tree, ...]: + """All parent nodes and their parent nodes, starting with the most distant.""" + + from warnings import warn + + warn( + "`ancestors` has been deprecated, and in the future will raise an error." + "Please use `parents`. Example: `tuple(reversed(node.parents))`", + DeprecationWarning, + ) + return tuple((*reversed(self.parents), self)) + + @property + def root(self: Tree) -> Tree: + """Root node of the tree""" + node = self + while node.parent is not None: + node = node.parent + return node + + @property + def is_root(self) -> bool: + """Whether this node is the tree root.""" + return self.parent is None + + @property + def is_leaf(self) -> bool: + """ + Whether this node is a leaf node. + + Leaf nodes are defined as nodes which have no children. + """ + return self.children == {} + + @property + def leaves(self: Tree) -> tuple[Tree, ...]: + """ + All leaf nodes. + + Leaf nodes are defined as nodes which have no children. + """ + return tuple([node for node in self.subtree if node.is_leaf]) + + @property + def siblings(self: Tree) -> dict[str, Tree]: + """ + Nodes with the same parent as this node. + """ + if self.parent: + return { + name: child + for name, child in self.parent.children.items() + if child is not self + } + else: + return {} + + @property + def subtree(self: Tree) -> Iterator[Tree]: + """ + An iterator over all nodes in this tree, including both self and all descendants. + + Iterates depth-first. + + See Also + -------- + DataTree.descendants + """ + from xarray.datatree_.datatree import iterators + + return iterators.PreOrderIter(self) + + @property + def descendants(self: Tree) -> tuple[Tree, ...]: + """ + Child nodes and all their child nodes. + + Returned in depth-first order. + + See Also + -------- + DataTree.subtree + """ + all_nodes = tuple(self.subtree) + this_node, *descendants = all_nodes + return tuple(descendants) + + @property + def level(self: Tree) -> int: + """ + Level of this node. + + Level means number of parent nodes above this node before reaching the root. + The root node is at level 0. + + Returns + ------- + level : int + + See Also + -------- + depth + width + """ + return len(self.parents) + + @property + def depth(self: Tree) -> int: + """ + Maximum level of this tree. + + Measured from the root, which has a depth of 0. + + Returns + ------- + depth : int + + See Also + -------- + level + width + """ + return max(node.level for node in self.root.subtree) + + @property + def width(self: Tree) -> int: + """ + Number of nodes at this level in the tree. + + Includes number of immediate siblings, but also "cousins" in other branches and so-on. + + Returns + ------- + depth : int + + See Also + -------- + level + depth + """ + return len([node for node in self.root.subtree if node.level == self.level]) + + def _pre_detach(self: Tree, parent: Tree) -> None: + """Method call before detaching from `parent`.""" + pass + + def _post_detach(self: Tree, parent: Tree) -> None: + """Method call after detaching from `parent`.""" + pass + + def _pre_attach(self: Tree, parent: Tree) -> None: + """Method call before attaching to `parent`.""" + pass + + def _post_attach(self: Tree, parent: Tree) -> None: + """Method call after attaching to `parent`.""" + pass + + def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None: + """ + Return the child node with the specified key. + + Only looks for the node within the immediate children of this node, + not in other nodes of the tree. + """ + if key in self.children: + return self.children[key] + else: + return default + + # TODO `._walk` method to be called by both `_get_item` and `_set_item` + + def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: + """ + Returns the object lying at the given path. + + Raises a KeyError if there is no object at the given path. + """ + if isinstance(path, str): + path = NodePath(path) + + if path.root: + current_node = self.root + root, *parts = list(path.parts) + else: + current_node = self + parts = list(path.parts) + + for part in parts: + if part == "..": + if current_node.parent is None: + raise KeyError(f"Could not find node at {path}") + else: + current_node = current_node.parent + elif part in ("", "."): + pass + else: + if current_node.get(part) is None: + raise KeyError(f"Could not find node at {path}") + else: + current_node = current_node.get(part) + return current_node + + def _set(self: Tree, key: str, val: Tree) -> None: + """ + Set the child node with the specified key to value. + + Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree. + """ + new_children = {**self.children, key: val} + self.children = new_children + + def _set_item( + self: Tree, + path: str | NodePath, + item: Tree | T_DataArray, + new_nodes_along_path: bool = False, + allow_overwrite: bool = True, + ) -> None: + """ + Set a new item in the tree, overwriting anything already present at that path. + + The given value either forms a new node of the tree or overwrites an + existing item at that location. + + Parameters + ---------- + path + item + new_nodes_along_path : bool + If true, then if necessary new nodes will be created along the + given path, until the tree can reach the specified location. + allow_overwrite : bool + Whether or not to overwrite any existing node at the location given + by path. + + Raises + ------ + KeyError + If node cannot be reached, and new_nodes_along_path=False. + Or if a node already exists at the specified path, and allow_overwrite=False. + """ + if isinstance(path, str): + path = NodePath(path) + + if not path.name: + raise ValueError("Can't set an item under a path which has no name") + + if path.root: + # absolute path + current_node = self.root + root, *parts, name = path.parts + else: + # relative path + current_node = self + *parts, name = path.parts + + if parts: + # Walk to location of new node, creating intermediate node objects as we go if necessary + for part in parts: + if part == "..": + if current_node.parent is None: + # We can't create a parent if `new_nodes_along_path=True` as we wouldn't know what to name it + raise KeyError(f"Could not reach node at path {path}") + else: + current_node = current_node.parent + elif part in ("", "."): + pass + else: + if part in current_node.children: + current_node = current_node.children[part] + elif new_nodes_along_path: + # Want child classes (i.e. DataTree) to populate tree with their own types + new_node = type(self)() + current_node._set(part, new_node) + current_node = current_node.children[part] + else: + raise KeyError(f"Could not reach node at path {path}") + + if name in current_node.children: + # Deal with anything already existing at this location + if allow_overwrite: + current_node._set(name, item) + else: + raise KeyError(f"Already a node object at path {path}") + else: + current_node._set(name, item) + + def __delitem__(self: Tree, key: str): + """Remove a child node from this tree object.""" + if key in self.children: + child = self._children[key] + del self._children[key] + child.orphan() + else: + raise KeyError("Cannot delete") + + def same_tree(self, other: Tree) -> bool: + """True if other node is in the same tree as this node.""" + return self.root is other.root + + +class NamedNode(TreeNode, Generic[Tree]): + """ + A TreeNode which knows its own name. + + Implements path-like relationships to other nodes in its tree. + """ + + _name: str | None + _parent: Tree | None + _children: dict[str, Tree] + + def __init__(self, name=None, children=None): + super().__init__(children=children) + self._name = None + self.name = name + + @property + def name(self) -> str | None: + """The name of this node.""" + return self._name + + @name.setter + def name(self, name: str | None) -> None: + if name is not None: + if not isinstance(name, str): + raise TypeError("node name must be a string or None") + if "/" in name: + raise ValueError("node names cannot contain forward slashes") + self._name = name + + def __repr__(self, level=0): + repr_value = "\t" * level + self.__str__() + "\n" + for child in self.children: + repr_value += self.get(child).__repr__(level + 1) + return repr_value + + def __str__(self) -> str: + return f"NamedNode('{self.name}')" if self.name else "NamedNode()" + + def _post_attach(self: NamedNode, parent: NamedNode) -> None: + """Ensures child has name attribute corresponding to key under which it has been stored.""" + key = next(k for k, v in parent.children.items() if v is self) + self.name = key + + @property + def path(self) -> str: + """Return the file-like path from the root to this node.""" + if self.is_root: + return "/" + else: + root, *ancestors = tuple(reversed(self.parents)) + # don't include name of root because (a) root might not have a name & (b) we want path relative to root. + names = [*(node.name for node in ancestors), self.name] + return "/" + "/".join(names) + + def relative_to(self: NamedNode, other: NamedNode) -> str: + """ + Compute the relative path from this node to node `other`. + + If other is not in this tree, or it's otherwise impossible, raise a ValueError. + """ + if not self.same_tree(other): + raise NotFoundInTreeError( + "Cannot find relative path because nodes do not lie within the same tree" + ) + + this_path = NodePath(self.path) + if other.path in list(parent.path for parent in (self, *self.parents)): + return str(this_path.relative_to(other.path)) + else: + common_ancestor = self.find_common_ancestor(other) + path_to_common_ancestor = other._path_to_ancestor(common_ancestor) + return str( + path_to_common_ancestor / this_path.relative_to(common_ancestor.path) + ) + + def find_common_ancestor(self, other: NamedNode) -> NamedNode: + """ + Find the first common ancestor of two nodes in the same tree. + + Raise ValueError if they are not in the same tree. + """ + if self is other: + return self + + other_paths = [op.path for op in other.parents] + for parent in (self, *self.parents): + if parent.path in other_paths: + return parent + + raise NotFoundInTreeError( + "Cannot find common ancestor because nodes do not lie within the same tree" + ) + + def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: + """Return the relative path from this node to the given ancestor node""" + + if not self.same_tree(ancestor): + raise NotFoundInTreeError( + "Cannot find relative path to ancestor because nodes do not lie within the same tree" + ) + if ancestor.path not in list(a.path for a in (self, *self.parents)): + raise NotFoundInTreeError( + "Cannot find relative path to ancestor because given node is not an ancestor of this node" + ) + + parents_paths = list(parent.path for parent in (self, *self.parents)) + generation_gap = list(parents_paths).index(ancestor.path) + path_upwards = "../" * generation_gap if generation_gap > 0 else "." + return NodePath(path_upwards) diff --git a/xarray/core/types.py b/xarray/core/types.py index 0f11b16b003..410cf3de00b 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,12 +1,14 @@ from __future__ import annotations import datetime -from collections.abc import Hashable, Iterable, Sequence +import sys +from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, Callable, Literal, + Protocol, SupportsIndex, TypeVar, Union, @@ -14,18 +16,30 @@ import numpy as np import pandas as pd -from packaging.version import Version + +try: + if sys.version_info >= (3, 11): + from typing import Self, TypeAlias + else: + from typing_extensions import Self, TypeAlias +except ImportError: + if TYPE_CHECKING: + raise + else: + Self: Any = None if TYPE_CHECKING: from numpy._typing import _SupportsDType from numpy.typing import ArrayLike from xarray.backends.common import BackendEntrypoint + from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords + from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupby import DataArrayGroupBy, GroupBy - from xarray.core.indexes import Index + from xarray.core.indexes import Index, Indexes + from xarray.core.utils import Frozen from xarray.core.variable import Variable try: @@ -33,25 +47,22 @@ except ImportError: DaskArray = np.ndarray # type: ignore - # TODO: Turn on when https://github.com/python/mypy/issues/11871 is fixed. - # Can be uncommented if using pyright though. - # import sys + try: + from cubed import Array as CubedArray + except ImportError: + CubedArray = np.ndarray - # try: - # if sys.version_info >= (3, 11): - # from typing import Self - # else: - # from typing_extensions import Self - # except ImportError: - # Self: Any = None - Self: Any = None + try: + from zarr.core import Array as ZarrArray + except ImportError: + ZarrArray = np.ndarray # Anything that can be coerced to a shape tuple _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _DTypeLikeNested = Any # TODO: wait for support for recursive types # Xarray requires a Mapping[Hashable, dtype] in many places which - # conflics with numpys own DTypeLike (with dtypes for fields). + # conflicts with numpys own DTypeLike (with dtypes for fields). # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike # This is a copy of this DTypeLike that allows only non-Mapping dtypes. DTypeLikeSave = Union[ @@ -79,31 +90,102 @@ CFTimeDatetime = Any DatetimeLike = Union[pd.Timestamp, datetime.datetime, np.datetime64, CFTimeDatetime] else: - Self: Any = None DTypeLikeSave: Any = None +class Alignable(Protocol): + """Represents any Xarray type that supports alignment. + + It may be ``Dataset``, ``DataArray`` or ``Coordinates``. This protocol class + is needed since those types do not all have a common base class. + + """ + + @property + def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: ... + + @property + def sizes(self) -> Mapping[Hashable, int]: ... + + @property + def xindexes(self) -> Indexes[Index]: ... + + def _reindex_callback( + self, + aligner: Aligner, + dim_pos_indexers: dict[Hashable, Any], + variables: dict[Hashable, Variable], + indexes: dict[Hashable, Index], + fill_value: Any, + exclude_dims: frozenset[Hashable], + exclude_vars: frozenset[Hashable], + ) -> Self: ... + + def _overwrite_indexes( + self, + indexes: Mapping[Any, Index], + variables: Mapping[Any, Variable] | None = None, + ) -> Self: ... + + def __len__(self) -> int: ... + + def __iter__(self) -> Iterator[Hashable]: ... + + def copy( + self, + deep: bool = False, + ) -> Self: ... + + +T_Alignable = TypeVar("T_Alignable", bound="Alignable") + T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint") T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") T_Variable = TypeVar("T_Variable", bound="Variable") +T_Coordinates = TypeVar("T_Coordinates", bound="Coordinates") T_Array = TypeVar("T_Array", bound="AbstractArray") T_Index = TypeVar("T_Index", bound="Index") +# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used +# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of +# the same concrete type, either "DataArray" or "Dataset". This is generally preferred +# over `T_DataArrayOrSet`, given the type system can determine the exact type. +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") + +# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or +# "Dataset". Use it for functions that might return either type, but where the exact +# type cannot be determined statically using the type system. T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"]) -# Maybe we rename this to T_Data or something less Fortran-y? -T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") +# For working directly with `DataWithCoords`. It will only allow using methods defined +# on `DataWithCoords`. T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + +# Temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more: +T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True) + + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] -DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] -DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] -GroupByIncompatible = Union["Variable", "GroupBy"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] + +# Don't change to Hashable | Collection[Hashable] +# Read: https://github.com/pydata/xarray/issues/6142 +Dims = Union[str, Collection[Hashable], "ellipsis", None] + +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] +T_NormalizedChunks = tuple[tuple[int, ...], ...] + +DataVars = Mapping[Any, Any] -Dims = Union[str, Iterable[Hashable], "ellipsis", None] -OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] @@ -121,7 +203,7 @@ Interp1dOptions = Literal[ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" ] -InterpolantOptions = Literal["barycentric", "krog", "pchip", "spline", "akima"] +InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] InterpOptions = Union[Interp1dOptions, InterpolantOptions] DatetimeUnitOptions = Literal[ @@ -179,27 +261,21 @@ ] -if Version(np.__version__) >= Version("1.22.0"): - QuantileMethods = Literal[ - "inverted_cdf", - "averaged_inverted_cdf", - "closest_observation", - "interpolated_inverted_cdf", - "hazen", - "weibull", - "linear", - "median_unbiased", - "normal_unbiased", - "lower", - "higher", - "midpoint", - "nearest", - ] -else: - QuantileMethods = Literal[ # type: ignore[misc] - "linear", - "lower", - "higher", - "midpoint", - "nearest", - ] +QuantileMethods = Literal[ + "inverted_cdf", + "averaged_inverted_cdf", + "closest_observation", + "interpolated_inverted_cdf", + "hazen", + "weibull", + "linear", + "median_unbiased", + "normal_unbiased", + "lower", + "higher", + "midpoint", + "nearest", +] + + +ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 68474c2bb0c..1d109d304b3 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -1,4 +1,5 @@ """Internal utilities; not for external use""" + # Some functions in this module are derived from functions in pandas. For # reference, here is a copy of the pandas copyright notice: @@ -37,7 +38,6 @@ import contextlib import functools -import importlib import inspect import io import itertools @@ -50,14 +50,18 @@ Collection, Container, Hashable, + ItemsView, Iterable, Iterator, + KeysView, Mapping, MutableMapping, MutableSet, Sequence, + ValuesView, ) from enum import Enum +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -66,15 +70,27 @@ Literal, TypeVar, Union, - cast, overload, ) import numpy as np import pandas as pd +from xarray.namedarray.utils import ( # noqa: F401 + ReprObject, + drop_missing_dims, + either_dict_or_kwargs, + infix_dims, + is_dask_collection, + is_dict_like, + is_duck_array, + is_duck_dask_array, + module_available, + to_0d_object_array, +) + if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims + from xarray.core.types import Dims, ErrorOptionsWithWarn K = TypeVar("K") V = TypeVar("V") @@ -114,7 +130,9 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index): dtype = np.dtype("O") elif hasattr(array, "categories"): # category isn't a real numpy dtype - dtype = array.categories.dtype # type: ignore[union-attr] + dtype = array.categories.dtype + if not is_valid_numpy_dtype(dtype): + dtype = np.dtype("O") elif not is_valid_numpy_dtype(array.dtype): dtype = np.dtype("O") else: @@ -241,11 +259,6 @@ def remove_incompatible_items( del first_dict[k] -# It's probably OK to give this as a TypeGuard; though it's not perfectly robust. -def is_dict_like(value: Any) -> TypeGuard[Mapping]: - return hasattr(value, "keys") and hasattr(value, "__getitem__") - - def is_full_slice(value: Any) -> bool: return isinstance(value, slice) and value == slice(None) @@ -254,39 +267,6 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]: return isinstance(value, (list, tuple)) -def is_duck_array(value: Any) -> bool: - if isinstance(value, np.ndarray): - return True - return ( - hasattr(value, "ndim") - and hasattr(value, "shape") - and hasattr(value, "dtype") - and ( - (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) - or hasattr(value, "__array_namespace__") - ) - ) - - -def either_dict_or_kwargs( - pos_kwargs: Mapping[Any, T] | None, - kw_kwargs: Mapping[str, T], - func_name: str, -) -> Mapping[Hashable, T]: - if pos_kwargs is None or pos_kwargs == {}: - # Need an explicit cast to appease mypy due to invariance; see - # https://github.com/python/mypy/issues/6228 - return cast(Mapping[Hashable, T], kw_kwargs) - - if not is_dict_like(pos_kwargs): - raise ValueError(f"the first argument to .{func_name} must be a dictionary") - if kw_kwargs: - raise ValueError( - f"cannot specify both keyword and positional arguments to .{func_name}" - ) - return pos_kwargs - - def _is_scalar(value, include_0d): from xarray.core.variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES @@ -342,13 +322,6 @@ def is_valid_numpy_dtype(dtype: Any) -> bool: return True -def to_0d_object_array(value: Any) -> np.ndarray: - """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" - result = np.empty((), dtype=object) - result[()] = value - return result - - def to_0d_array(value: Any) -> np.ndarray: """Given a value, wrap it in a 0-D numpy.ndarray.""" if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0): @@ -472,6 +445,55 @@ def FrozenDict(*args, **kwargs) -> Frozen: return Frozen(dict(*args, **kwargs)) +class FrozenMappingWarningOnValuesAccess(Frozen[K, V]): + """ + Class which behaves like a Mapping but warns if the values are accessed. + + Temporary object to aid in deprecation cycle of `Dataset.dims` (see GH issue #8496). + `Dataset.dims` is being changed from returning a mapping of dimension names to lengths to just + returning a frozen set of dimension names (to increase consistency with `DataArray.dims`). + This class retains backwards compatibility but raises a warning only if the return value + of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that + would also be valid for a FrozenSet, e.g. iteration). + """ + + __slots__ = ("mapping",) + + def _warn(self) -> None: + emit_user_level_warning( + "The return type of `Dataset.dims` will be changed to return a set of dimension names in future, " + "in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, " + "please use `Dataset.sizes`.", + FutureWarning, + ) + + def __getitem__(self, key: K) -> V: + self._warn() + return super().__getitem__(key) + + @overload + def get(self, key: K, /) -> V | None: ... + + @overload + def get(self, key: K, /, default: V | T) -> V | T: ... + + def get(self, key: K, default: T | None = None) -> V | T | None: + self._warn() + return super().get(key, default) + + def keys(self) -> KeysView[K]: + self._warn() + return super().keys() + + def items(self) -> ItemsView[K, V]: + self._warn() + return super().items() + + def values(self) -> ValuesView[V]: + self._warn() + return super().values() + + class HybridMappingProxy(Mapping[K, V]): """Implements the Mapping interface. Uses the wrapped mapping for item lookup and a separate wrapped keys collection for iteration. @@ -537,8 +559,7 @@ def discard(self, value: T) -> None: # Additional methods def update(self, values: Iterable[T]) -> None: - for v in values: - self._d[v] = None + self._d.update(dict.fromkeys(values)) def __repr__(self) -> str: return f"{type(self).__name__}({list(self)!r})" @@ -607,31 +628,6 @@ def __repr__(self: Any) -> str: return f"{type(self).__name__}(array={self.array!r})" -class ReprObject: - """Object that prints as the given value, for use with sentinel values.""" - - __slots__ = ("_value",) - - def __init__(self, value: str): - self._value = value - - def __repr__(self) -> str: - return self._value - - def __eq__(self, other) -> bool: - if isinstance(other, ReprObject): - return self._value == other._value - return False - - def __hash__(self) -> int: - return hash((type(self), self._value)) - - def __dask_tokenize__(self): - from dask.base import normalize_token - - return normalize_token((type(self), self._value)) - - @contextlib.contextmanager def close_on_error(f): """Context manager to ensure that a file opened by xarray is closed if an @@ -791,36 +787,6 @@ def __len__(self) -> int: return len(self._data) - num_hidden -def infix_dims( - dims_supplied: Collection, - dims_all: Collection, - missing_dims: ErrorOptionsWithWarn = "raise", -) -> Iterator: - """ - Resolves a supplied list containing an ellipsis representing other items, to - a generator with the 'realized' list of all items - """ - if ... in dims_supplied: - if len(set(dims_all)) != len(dims_all): - raise ValueError("Cannot use ellipsis with repeated dims") - if list(dims_supplied).count(...) > 1: - raise ValueError("More than one ellipsis supplied") - other_dims = [d for d in dims_all if d not in dims_supplied] - existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) - for d in existing_dims: - if d is ...: - yield from other_dims - else: - yield d - else: - existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) - if set(existing_dims) ^ set(dims_all): - raise ValueError( - f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" - ) - yield from existing_dims - - def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: """Get an new dimension name based on new_dim, that is not used in dims. If the same name exists, we add an underscore(s) in the head. @@ -886,72 +852,24 @@ def drop_dims_from_indexers( ) -def drop_missing_dims( - supplied_dims: Iterable[Hashable], - dims: Iterable[Hashable], - missing_dims: ErrorOptionsWithWarn, -) -> Iterable[Hashable]: - """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that - are not present in dims. - - Parameters - ---------- - supplied_dims : Iterable of Hashable - dims : Iterable of Hashable - missing_dims : {"raise", "warn", "ignore"} - """ - - if missing_dims == "raise": - supplied_dims_set = {val for val in supplied_dims if val is not ...} - invalid = supplied_dims_set - set(dims) - if invalid: - raise ValueError( - f"Dimensions {invalid} do not exist. Expected one or more of {dims}" - ) - - return supplied_dims - - elif missing_dims == "warn": - invalid = set(supplied_dims) - set(dims) - if invalid: - warnings.warn( - f"Dimensions {invalid} do not exist. Expected one or more of {dims}" - ) - - return [val for val in supplied_dims if val in dims or val is ...] - - elif missing_dims == "ignore": - return [val for val in supplied_dims if val in dims or val is ...] - - else: - raise ValueError( - f"Unrecognised option {missing_dims} for missing_dims argument" - ) - - -T_None = TypeVar("T_None", None, "ellipsis") - - @overload def parse_dims( - dim: str | Iterable[Hashable] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[True] = True, -) -> tuple[Hashable, ...]: - ... +) -> tuple[Hashable, ...]: ... @overload def parse_dims( - dim: str | Iterable[Hashable] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | T_None: - ... +) -> tuple[Hashable, ...] | None | ellipsis: ... def parse_dims( @@ -997,28 +915,26 @@ def parse_dims( @overload def parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[True] = True, -) -> tuple[Hashable, ...]: - ... +) -> tuple[Hashable, ...]: ... @overload def parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | T_None: - ... +) -> tuple[Hashable, ...] | None | ellipsis: ... def parse_ordered_dims( - dim: OrderedDims, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, @@ -1072,9 +988,9 @@ def parse_ordered_dims( ) -def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None: - wrong_dims = dim - all_dims - if wrong_dims and wrong_dims != {...}: +def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: + wrong_dims = (dim - all_dims) - {...} + if wrong_dims: wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) raise ValueError( f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}" @@ -1096,12 +1012,10 @@ def __init__(self, accessor: type[_Accessor]) -> None: self._accessor = accessor @overload - def __get__(self, obj: None, cls) -> type[_Accessor]: - ... + def __get__(self, obj: None, cls) -> type[_Accessor]: ... @overload - def __get__(self, obj: object, cls) -> _Accessor: - ... + def __get__(self, obj: object, cls) -> _Accessor: ... def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor: if obj is None: @@ -1185,49 +1099,31 @@ def to_list(arg): return dim_list, arr_args -def contains_only_dask_or_numpy(obj) -> bool: - """Returns True if xarray object contains only numpy or dask arrays. +def contains_only_chunked_or_numpy(obj) -> bool: + """Returns True if xarray object contains only numpy arrays or chunked arrays (i.e. pure dask or cubed). Expects obj to be Dataset or DataArray""" from xarray.core.dataarray import DataArray - from xarray.core.pycompat import is_duck_dask_array + from xarray.namedarray.pycompat import is_chunked_array if isinstance(obj, DataArray): obj = obj._to_temp_dataset() return all( [ - isinstance(var.data, np.ndarray) or is_duck_dask_array(var.data) + isinstance(var.data, np.ndarray) or is_chunked_array(var.data) for var in obj.variables.values() ] ) -def module_available(module: str) -> bool: - """Checks whether a module is installed without importing it. - - Use this for a lightweight check and lazy imports. - - Parameters - ---------- - module : str - Name of the module. - - Returns - ------- - available : bool - Whether the module is installed. - """ - return importlib.util.find_spec(module) is not None - - def find_stack_level(test_mode=False) -> int: - """Find the first place in the stack that is not inside xarray. + """Find the first place in the stack that is not inside xarray or the Python standard library. This is unless the code emanates from a test, in which case we would prefer to see the xarray source. - This function is taken from pandas. + This function is taken from pandas and modified to exclude standard library paths. Parameters ---------- @@ -1238,19 +1134,32 @@ def find_stack_level(test_mode=False) -> int: Returns ------- stacklevel : int - First level in the stack that is not part of xarray. + First level in the stack that is not part of xarray or the Python standard library. """ import xarray as xr - pkg_dir = os.path.dirname(xr.__file__) - test_dir = os.path.join(pkg_dir, "tests") + pkg_dir = Path(xr.__file__).parent + test_dir = pkg_dir / "tests" + + std_lib_init = sys.modules["os"].__file__ + # Mostly to appease mypy; I don't think this can happen... + if std_lib_init is None: + return 0 + + std_lib_dir = Path(std_lib_init).parent - # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow frame = inspect.currentframe() n = 0 while frame: fname = inspect.getfile(frame) - if fname.startswith(pkg_dir) and (not fname.startswith(test_dir) or test_mode): + if ( + fname.startswith(str(pkg_dir)) + and (not fname.startswith(str(test_dir)) or test_mode) + ) or ( + fname.startswith(str(std_lib_dir)) + and "site-packages" not in fname + and "dist-packages" not in fname + ): frame = frame.f_back n += 1 else: @@ -1258,7 +1167,70 @@ def find_stack_level(test_mode=False) -> int: return n -def emit_user_level_warning(message, category=None): +def emit_user_level_warning(message, category=None) -> None: """Emit a warning at the user level by inspecting the stack trace.""" stacklevel = find_stack_level() - warnings.warn(message, category=category, stacklevel=stacklevel) + return warnings.warn(message, category=category, stacklevel=stacklevel) + + +def consolidate_dask_from_array_kwargs( + from_array_kwargs: dict[Any, Any], + name: str | None = None, + lock: bool | None = None, + inline_array: bool | None = None, +) -> dict[Any, Any]: + """ + Merge dask-specific kwargs with arbitrary from_array_kwargs dict. + + Temporary function, to be deleted once explicitly passing dask-specific kwargs to .chunk() is deprecated. + """ + + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="name", + passed_kwarg_value=name, + default=None, + err_msg_dict_name="from_array_kwargs", + ) + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="lock", + passed_kwarg_value=lock, + default=False, + err_msg_dict_name="from_array_kwargs", + ) + from_array_kwargs = _resolve_doubly_passed_kwarg( + from_array_kwargs, + kwarg_name="inline_array", + passed_kwarg_value=inline_array, + default=False, + err_msg_dict_name="from_array_kwargs", + ) + + return from_array_kwargs + + +def _resolve_doubly_passed_kwarg( + kwargs_dict: dict[Any, Any], + kwarg_name: str, + passed_kwarg_value: str | bool | None, + default: bool | None, + err_msg_dict_name: str, +) -> dict[Any, Any]: + # if in kwargs_dict but not passed explicitly then just pass kwargs_dict through unaltered + if kwarg_name in kwargs_dict and passed_kwarg_value is None: + pass + # if passed explicitly but not in kwargs_dict then use that + elif kwarg_name not in kwargs_dict and passed_kwarg_value is not None: + kwargs_dict[kwarg_name] = passed_kwarg_value + # if in neither then use default + elif kwarg_name not in kwargs_dict and passed_kwarg_value is None: + kwargs_dict[kwarg_name] = default + # if in both then raise + else: + raise ValueError( + f"argument {kwarg_name} cannot be passed both as a keyword argument and within " + f"the {err_msg_dict_name} dictionary" + ) + + return kwargs_dict diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 555bff47913..a5cf594c386 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,14 +5,14 @@ import math import numbers import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast import numpy as np import pandas as pd from numpy.typing import ArrayLike -from packaging.version import Version import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils @@ -26,27 +26,28 @@ as_indexable, ) from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.pycompat import ( - array_type, - integer_types, - is_0d_dask_array, - is_duck_dask_array, -) from xarray.core.rolling import get_pads from xarray.core.utils import ( - Frozen, - NdimSizeLenMixin, OrderedSet, _default, + consolidate_dask_from_array_kwargs, decode_numpy_dict_values, drop_dims_from_indexers, either_dict_or_kwargs, ensure_us_time_resolution, expand_args_to_dims, infix_dims, + is_dict_like, is_duck_array, maybe_coerce_to_str, ) +from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions +from xarray.namedarray.pycompat import ( + integer_types, + is_0d_dask_array, + is_duck_dask_array, + to_duck_array, +) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, @@ -62,8 +63,11 @@ PadModeOptions, PadReflectOptions, QuantileMethods, - T_Variable, + Self, + T_DuckArray, ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + NON_NANOSECOND_WARNING = ( "Converting non-nanosecond precision {case} values to nanosecond precision. " @@ -82,7 +86,7 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? -def as_variable(obj, name=None) -> Variable | IndexVariable: +def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable: """Convert an object into a Variable. Parameters @@ -121,17 +125,15 @@ def as_variable(obj, name=None) -> Variable | IndexVariable: elif isinstance(obj, tuple): if isinstance(obj[1], DataArray): raise TypeError( - "Using a DataArray object to construct a variable is" + f"Variable {name!r}: Using a DataArray object to construct a variable is" " ambiguous, please extract the data using the .data property." ) try: obj = Variable(*obj) except (TypeError, ValueError) as error: - # use .format() instead of % because it handles tuples consistently raise error.__class__( - "Could not convert tuple of form " - "(dims, data[, attrs, encoding]): " - "{} to Variable.".format(obj) + f"Variable {name!r}: Could not convert tuple of form " + f"(dims, data[, attrs, encoding]): {obj} to Variable." ) elif utils.is_scalar(obj): obj = Variable([], obj) @@ -140,7 +142,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable: elif isinstance(obj, (set, dict)): raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}") elif name is not None: - data = as_compatible_data(obj) + data: T_DuckArray = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " @@ -150,18 +152,12 @@ def as_variable(obj, name=None) -> Variable | IndexVariable: obj = Variable(name, data, fastpath=True) else: raise TypeError( - "unable to convert object into a variable without an " + f"Variable {name!r}: unable to convert object into a variable without an " f"explicit list of dimensions: {obj!r}" ) - if name is not None and name in obj.dims: - # convert the Variable into an Index - if obj.ndim != 1: - raise MissingDimensionsError( - f"{name!r} has more than 1-dimension and the same name as one of its " - f"dimensions {obj.dims!r}. xarray disallows such variables because they " - "conflict with the coordinates used to label dimensions." - ) + if name is not None and name in obj.dims and obj.ndim == 1: + # automatically convert the Variable into an Index obj = obj.to_index_variable() return obj @@ -196,10 +192,10 @@ def _as_nanosecond_precision(data): nanosecond_precision_dtype = pd.DatetimeTZDtype("ns", dtype.tz) else: nanosecond_precision_dtype = "datetime64[ns]" - return data.astype(nanosecond_precision_dtype) + return duck_array_ops.astype(data, nanosecond_precision_dtype) elif dtype.kind == "m" and dtype != np.dtype("timedelta64[ns]"): utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="timedelta")) - return data.astype("timedelta64[ns]") + return duck_array_ops.astype(data, "timedelta64[ns]") else: return data @@ -219,7 +215,14 @@ def _possibly_convert_objects(values): as_series = pd.Series(values.ravel(), copy=False) if as_series.dtype.kind in "mM": as_series = _as_nanosecond_precision(as_series) - return np.asarray(as_series).reshape(values.shape) + result = np.asarray(as_series).reshape(values.shape) + if not result.flags.writeable: + # GH8843, pandas copy-on-write mode creates read-only arrays by default + try: + result.flags.writeable = True + except ValueError: + result = result.copy() + return result def _possibly_convert_datetime_or_timedelta_index(data): @@ -228,13 +231,17 @@ def _possibly_convert_datetime_or_timedelta_index(data): this in version 2.0.0, in xarray we will need to make sure we are ready to handle non-nanosecond precision datetimes or timedeltas in our code before allowing such values to pass through unchanged.""" - if isinstance(data, (pd.DatetimeIndex, pd.TimedeltaIndex)): - return _as_nanosecond_precision(data) - else: - return data + if isinstance(data, PandasIndexingAdapter): + if isinstance(data.array, (pd.DatetimeIndex, pd.TimedeltaIndex)): + data = PandasIndexingAdapter(_as_nanosecond_precision(data.array)) + elif isinstance(data, (pd.DatetimeIndex, pd.TimedeltaIndex)): + data = _as_nanosecond_precision(data) + return data -def as_compatible_data(data, fastpath=False): +def as_compatible_data( + data: T_DuckArray | ArrayLike, fastpath: bool = False +) -> T_DuckArray: """Prepare and wrap data to put in a Variable. - If data does not have the necessary attributes, convert it to ndarray. @@ -247,7 +254,7 @@ def as_compatible_data(data, fastpath=False): """ if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) from xarray.core.dataarray import DataArray @@ -256,7 +263,7 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): data = _possibly_convert_datetime_or_timedelta_index(data) - return _maybe_wrap_data(data) + return cast("T_DuckArray", _maybe_wrap_data(data)) if isinstance(data, tuple): data = utils.to_0d_object_array(data) @@ -269,22 +276,21 @@ def as_compatible_data(data, fastpath=False): data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)): + if isinstance(data, (pd.Series, pd.DataFrame)): data = data.values if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) if mask.any(): dtype, fill_value = dtypes.maybe_promote(data.dtype) - data = np.asarray(data, dtype=dtype) - data[mask] = fill_value + data = duck_array_ops.where_method(data, ~mask, fill_value) else: data = np.asarray(data) if not isinstance(data, np.ndarray) and ( hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") ): - return data + return cast("T_DuckArray", data) # validate whether the data is valid data types. data = np.asarray(data) @@ -317,7 +323,7 @@ def _as_array_or_item(data): return data -class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): +class Variable(NamedArray, AbstractArray, VariableArithmetic): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -340,7 +346,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): __slots__ = ("_dims", "_data", "_attrs", "_encoding") - def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + def __init__( + self, + dims, + data: T_DuckArray | ArrayLike, + attrs=None, + encoding=None, + fastpath=False, + ): """ Parameters ---------- @@ -360,50 +373,32 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): Well-behaved code to serialize a Variable should ignore unrecognized encoding items. """ - self._data = as_compatible_data(data, fastpath=fastpath) - self._dims = self._parse_dimensions(dims) - self._attrs = None + super().__init__( + dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs + ) + self._encoding = None - if attrs is not None: - self.attrs = attrs if encoding is not None: self.encoding = encoding - @property - def dtype(self): - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def shape(self): - """ - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape + def _new( + self, + dims=_default, + data=_default, + attrs=_default, + ): + dims_ = copy.copy(self._dims) if dims is _default else dims - @property - def nbytes(self) -> int: - """ - Total bytes consumed by the elements of the data array. + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs - If the underlying data array does not include ``nbytes``, estimates - the bytes consumed based on the ``size`` and ``dtype``. - """ - if hasattr(self._data, "nbytes"): - return self._data.nbytes + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) else: - return self.size * self.dtype.itemsize + cls_ = type(self) + return cls_(dims_, data, attrs_) @property def _in_memory(self): @@ -415,7 +410,7 @@ def _in_memory(self): ) @property - def data(self) -> Any: + def data(self): """ The Variable's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. @@ -428,21 +423,19 @@ def data(self) -> Any: """ if is_duck_array(self._data): return self._data + elif isinstance(self._data, indexing.ExplicitlyIndexed): + return self._data.get_duck_array() else: return self.values @data.setter - def data(self, data): + def data(self, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) - if data.shape != self.shape: - raise ValueError( - f"replacement data must match the Variable's shape. " - f"replacement data has shape {data.shape}; Variable has shape {self.shape}" - ) + self._check_shape(data) self._data = data def astype( - self: T_Variable, + self, dtype, *, order=None, @@ -450,7 +443,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_Variable: + ) -> Self: """ Copy of the Variable object, with data cast to a specified type. @@ -516,85 +509,6 @@ def astype( dask="allowed", ) - def load(self, **kwargs): - """Manually trigger loading of this variable's data from disk or a - remote source into memory and return this variable. - - Normally, it should not be necessary to call this method in user code, - because all xarray functions should either work on deferred data or - load data automatically. - - Parameters - ---------- - **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. - - See Also - -------- - dask.array.compute - """ - if is_duck_dask_array(self._data): - self._data = as_compatible_data(self._data.compute(**kwargs)) - elif not is_duck_array(self._data): - self._data = np.asarray(self._data) - return self - - def compute(self, **kwargs): - """Manually trigger loading of this variable's data from disk or a - remote source into memory and return a new variable. The original is - left unaltered. - - Normally, it should not be necessary to call this method in user code, - because all xarray functions should either work on deferred data or - load data automatically. - - Parameters - ---------- - **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. - - See Also - -------- - dask.array.compute - """ - new = self.copy(deep=False) - return new.load(**kwargs) - - def __dask_tokenize__(self): - # Use v.data, instead of v._data, in order to cope with the wrappers - # around NetCDF and the like - from dask.base import normalize_token - - return normalize_token((type(self), self._dims, self.data, self._attrs)) - - def __dask_graph__(self): - if is_duck_dask_array(self._data): - return self._data.__dask_graph__() - else: - return None - - def __dask_keys__(self): - return self._data.__dask_keys__() - - def __dask_layers__(self): - return self._data.__dask_layers__() - - @property - def __dask_optimize__(self): - return self._data.__dask_optimize__ - - @property - def __dask_scheduler__(self): - return self._data.__dask_scheduler__ - - def __dask_postcompute__(self): - array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func,) + array_args - - def __dask_postpersist__(self): - array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func,) + array_args - def _dask_finalize(self, results, array_func, *args, **kwargs): data = array_func(results, *args, **kwargs) return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @@ -631,11 +545,23 @@ def to_index(self) -> pd.Index: """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() - def to_dict(self, data: bool = True, encoding: bool = False) -> dict: + def to_dict( + self, data: bool | str = "list", encoding: bool = False + ) -> dict[str, Any]: """Dictionary representation of variable.""" - item = {"dims": self.dims, "attrs": decode_numpy_dict_values(self.attrs)} - if data: - item["data"] = ensure_us_time_resolution(self.values).tolist() + item: dict[str, Any] = { + "dims": self.dims, + "attrs": decode_numpy_dict_values(self.attrs), + } + if data is not False: + if data in [True, "list"]: + item["data"] = ensure_us_time_resolution(self.to_numpy()).tolist() + elif data == "array": + item["data"] = ensure_us_time_resolution(self.data) + else: + msg = 'data argument must be bool, "list", or "array"' + raise ValueError(msg) + else: item.update({"dtype": str(self.dtype), "shape": self.shape}) @@ -644,28 +570,8 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict: return item - @property - def dims(self) -> tuple[Hashable, ...]: - """Tuple of dimension names with which this variable is associated.""" - return self._dims - - @dims.setter - def dims(self, value: str | Iterable[Hashable]) -> None: - self._dims = self._parse_dimensions(value) - - def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: - if isinstance(dims, str): - dims = (dims,) - dims = tuple(dims) - if len(dims) != self.ndim: - raise ValueError( - f"dimensions {dims} must have the same length as the " - f"number of data dimensions, ndim={self.ndim}" - ) - return dims - def _item_key_to_tuple(self, key): - if utils.is_dict_like(key): + if is_dict_like(key): return tuple(key.get(dim, slice(None)) for dim in self.dims) else: return key @@ -745,18 +651,18 @@ def _validate_indexers(self, key): if k.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " - "used for indexing: {}".format(k) + f"used for indexing: {k}" ) if k.dtype.kind == "b": if self.shape[self.get_axis_num(dim)] != len(k): raise IndexError( - "Boolean array size {:d} is used to index array " - "with shape {:s}.".format(len(k), str(self.shape)) + f"Boolean array size {len(k):d} is used to index array " + f"with shape {str(self.shape):s}." ) if k.ndim > 1: raise IndexError( - "{}-dimensional boolean indexing is " - "not supported. ".format(k.ndim) + f"{k.ndim}-dimensional boolean indexing is " + "not supported. " ) if is_duck_dask_array(k.data): raise KeyError( @@ -769,9 +675,7 @@ def _validate_indexers(self, key): raise IndexError( "Boolean indexer should be unlabeled or on the " "same dimension to the indexed array. Indexer is " - "on {:s} but the target dimension is {:s}.".format( - str(k.dims), dim - ) + f"on {str(k.dims):s} but the target dimension is {dim:s}." ) def _broadcast_indexes_outer(self, key): @@ -798,13 +702,6 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None - def _nonzero(self): - """Equivalent numpy's nonzero but returns a tuple of Variables.""" - # TODO we should replace dask's native nonzero - # after https://github.com/dask/dask/issues/1076 is implemented. - nonzeros = np.nonzero(self.data) - return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims)) - def _broadcast_indexes_vectorized(self, key): variables = [] out_dims_set = OrderedSet() @@ -861,7 +758,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: T_Variable, key) -> T_Variable: + def __getitem__(self, key) -> Self: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -875,12 +772,15 @@ def __getitem__(self: T_Variable, key) -> T_Variable: array `x.values` directly. """ dims, indexer, new_order = self._broadcast_indexes(key) - data = as_indexable(self._data)[indexer] + indexable = as_indexable(self._data) + + data = indexing.apply_indexer(indexable, indexer) + if new_order: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: + def _finalize_indexing_result(self, dims, data) -> Self: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -900,6 +800,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): dims, indexer, new_order = self._broadcast_indexes(key) if self.size: + if is_duck_dask_array(self._data): # dask's indexing is faster this way; also vindex does not # support negative indices yet: @@ -908,7 +809,9 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): else: actual_indexer = indexer - data = as_indexable(self._data)[actual_indexer] + indexable = as_indexable(self._data) + data = indexing.apply_indexer(indexable, actual_indexer) + mask = indexing.create_mask(indexer, self.shape, data) # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit @@ -952,18 +855,7 @@ def __setitem__(self, key, value): value = np.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) - indexable[index_tuple] = value - - @property - def attrs(self) -> dict[Any, Any]: - """Dictionary of local attributes on this variable.""" - if self._attrs is None: - self._attrs = {} - return self._attrs - - @attrs.setter - def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) + indexing.set_with_indexer(indexable, index_tuple, value) @property def encoding(self) -> dict[Any, Any]: @@ -979,89 +871,40 @@ def encoding(self, value): except ValueError: raise ValueError("encoding must be castable to a dictionary") - def copy( - self: T_Variable, deep: bool = True, data: ArrayLike | None = None - ) -> T_Variable: - """Returns a copy of this object. - - If `deep=True`, the data array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. - - Use `data` to create a new object with the same structure as - original but entirely new data. - - Parameters - ---------- - deep : bool, default: True - Whether the data array is loaded into memory and copied onto - the new object. Default is True. - data : array_like, optional - Data to use in the new object. Must have same shape as original. - When `data` is used, `deep` is ignored. - - Returns - ------- - object : Variable - New object with dimensions, attributes, encodings, and optionally - data copied from original. - - Examples - -------- - Shallow copy versus deep copy - - >>> var = xr.Variable(data=[1, 2, 3], dims="x") - >>> var.copy() - - array([1, 2, 3]) - >>> var_0 = var.copy(deep=False) - >>> var_0[0] = 7 - >>> var_0 - - array([7, 2, 3]) - >>> var - - array([7, 2, 3]) - - Changing the data using the ``data`` argument maintains the - structure of the original object, but with the new data. Original - object is unaffected. + def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() - >>> var.copy(data=[0.1, 0.2, 0.3]) - - array([0.1, 0.2, 0.3]) - >>> var - - array([7, 2, 3]) - - See Also - -------- - pandas.DataFrame.copy - """ - return self._copy(deep=deep, data=data) + def drop_encoding(self) -> Self: + """Return a new Variable without encoding.""" + return self._replace(encoding={}) def _copy( - self: T_Variable, + self, deep: bool = True, - data: ArrayLike | None = None, + data: T_DuckArray | ArrayLike | None = None, memo: dict[int, Any] | None = None, - ) -> T_Variable: + ) -> Self: if data is None: - ndata = self._data + data_old = self._data - if isinstance(ndata, indexing.MemoryCachedArray): + if not isinstance(data_old, indexing.MemoryCachedArray): + ndata = data_old + else: # don't share caching between copies - ndata = indexing.MemoryCachedArray(ndata.array) + # TODO: MemoryCachedArray doesn't match the array api: + ndata = indexing.MemoryCachedArray(data_old.array) # type: ignore[assignment] if deep: ndata = copy.deepcopy(ndata, memo) else: ndata = as_compatible_data(data) - if self.shape != ndata.shape: + if self.shape != ndata.shape: # type: ignore[attr-defined] raise ValueError( - "Data shape {} must match shape of object {}".format( - ndata.shape, self.shape - ) + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] ) attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) @@ -1073,233 +916,70 @@ def _copy( return self._replace(data=ndata, attrs=attrs, encoding=encoding) def _replace( - self: T_Variable, + self, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> T_Variable: + ) -> Self: if dims is _default: dims = copy.copy(self._dims) if data is _default: data = copy.copy(self.data) if attrs is _default: attrs = copy.copy(self._attrs) + if encoding is _default: encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self: T_Variable) -> T_Variable: - return self._copy(deep=False) - - def __deepcopy__( - self: T_Variable, memo: dict[int, Any] | None = None - ) -> T_Variable: - return self._copy(deep=True, memo=memo) - - # mutable objects should not be hashable - # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore[assignment] - - @property - def chunks(self) -> tuple[tuple[int, ...], ...] | None: - """ - Tuple of block lengths for this dataarray's data, in order of dimensions, or None if - the underlying data is not a dask array. - - See Also - -------- - Variable.chunk - Variable.chunksizes - xarray.unify_chunks - """ - return getattr(self._data, "chunks", None) + def load(self, **kwargs): + """Manually trigger loading of this variable's data from disk or a + remote source into memory and return this variable. - @property - def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: - """ - Mapping from dimension names to block lengths for this variable's data, or None if - the underlying data is not a dask array. - Cannot be modified directly, but can be modified by calling .chunk(). + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. - Differs from variable.chunks because it returns a mapping of dimensions to chunk shapes - instead of a tuple of chunk shapes. + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. See Also -------- - Variable.chunk - Variable.chunks - xarray.unify_chunks + dask.array.compute """ - if hasattr(self._data, "chunks"): - return Frozen({dim: c for dim, c in zip(self.dims, self.data.chunks)}) - else: - return {} - - _array_counter = itertools.count() - - def chunk( - self, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, - name: str | None = None, - lock: bool = False, - inline_array: bool = False, - **chunks_kwargs: Any, - ) -> Variable: - """Coerce this array's data into a dask array with the given chunks. + self._data = to_duck_array(self._data, **kwargs) + return self - If this variable is a non-dask array, it will be converted to dask - array. If it's a dask array, it will be rechunked to the given chunk - sizes. + def compute(self, **kwargs): + """Manually trigger loading of this variable's data from disk or a + remote source into memory and return a new variable. The original is + left unaltered. - If neither chunks is not provided for one or more dimensions, chunk - sizes along that dimension will not be updated; non-dask arrays will be - converted into dask arrays with a single block. + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. Parameters ---------- - chunks : int, tuple or dict, optional - Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or - ``{'x': 5, 'y': 5}``. - name : str, optional - Used to generate the name for this array in the internal dask - graph. Does not need not be unique. - lock : optional - Passed on to :py:func:`dask.array.from_array`, if the array is not - already as dask array. - inline_array: optional - Passed on to :py:func:`dask.array.from_array`, if the array is not - already as dask array. - **chunks_kwargs : {dim: chunks, ...}, optional - The keyword arguments form of ``chunks``. - One of chunks or chunks_kwargs must be provided. - - Returns - ------- - chunked : xarray.Variable + **kwargs : dict + Additional keyword arguments passed on to ``dask.array.compute``. See Also -------- - Variable.chunks - Variable.chunksizes - xarray.unify_chunks - dask.array.from_array - """ - import dask.array as da - - if chunks is None: - warnings.warn( - "None value for 'chunks' is deprecated. " - "It will raise an error in the future. Use instead '{}'", - category=FutureWarning, - ) - chunks = {} - - if isinstance(chunks, (float, str, int, tuple, list)): - pass # dask.array.from_array can handle these directly - else: - chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") - - if utils.is_dict_like(chunks): - chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()} - - data = self._data - if is_duck_dask_array(data): - data = data.rechunk(chunks) - else: - if isinstance(data, indexing.ExplicitlyIndexed): - # Unambiguously handle array storage backends (like NetCDF4 and h5py) - # that can't handle general array indexing. For example, in netCDF4 you - # can do "outer" indexing along two dimensions independent, which works - # differently from how NumPy handles it. - # da.from_array works by using lazy indexing with a tuple of slices. - # Using OuterIndexer is a pragmatic choice: dask does not yet handle - # different indexing types in an explicit way: - # https://github.com/dask/dask/issues/2883 - data = indexing.ImplicitToExplicitIndexingAdapter( - data, indexing.OuterIndexer - ) - - # All of our lazily loaded backend array classes should use NumPy - # array operations. - kwargs = {"meta": np.ndarray} - else: - kwargs = {} - - if utils.is_dict_like(chunks): - chunks = tuple(chunks.get(n, s) for n, s in enumerate(self.shape)) - - data = da.from_array( - data, chunks, name=name, lock=lock, inline_array=inline_array, **kwargs - ) - - return self._replace(data=data) - - def to_numpy(self) -> np.ndarray: - """Coerces wrapped data to numpy and returns a numpy.ndarray""" - # TODO an entrypoint so array libraries can choose coercion method? - data = self.data - - # TODO first attempt to call .to_numpy() once some libraries implement it - if hasattr(data, "chunks"): - data = data.compute() - if isinstance(data, array_type("cupy")): - data = data.get() - # pint has to be imported dynamically as pint imports xarray - if isinstance(data, array_type("pint")): - data = data.magnitude - if isinstance(data, array_type("sparse")): - data = data.todense() - data = np.asarray(data) - - return data - - def as_numpy(self: T_Variable) -> T_Variable: - """Coerces wrapped data into a numpy array, returning a Variable.""" - return self._replace(data=self.to_numpy()) - - def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): - """ - use sparse-array as backend. - """ - import sparse - - # TODO: what to do if dask-backended? - if fill_value is dtypes.NA: - dtype, fill_value = dtypes.maybe_promote(self.dtype) - else: - dtype = dtypes.result_type(self.dtype, fill_value) - - if sparse_format is _default: - sparse_format = "coo" - try: - as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") - except AttributeError: - raise ValueError(f"{sparse_format} is not a valid sparse format") - - data = as_sparse(self.data.astype(dtype), fill_value=fill_value) - return self._replace(data=data) - - def _to_dense(self): - """ - Change backend from sparse to np.array + dask.array.compute """ - if hasattr(self._data, "todense"): - return self._replace(data=self._data.todense()) - return self.copy(deep=False) + new = self.copy(deep=False) + return new.load(**kwargs) def isel( - self: T_Variable, + self, indexers: Mapping[Any, Any] | None = None, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Variable: + ) -> Self: """Return a new array indexed along the specified dimension(s). Parameters @@ -1374,7 +1054,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): pads = [(0, 0) if d != dim else dim_pad for d in self.dims] data = np.pad( - trimmed_data.astype(dtype), + duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", constant_values=fill_value, @@ -1431,14 +1111,12 @@ def pad( self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", - stat_length: int - | tuple[int, int] - | Mapping[Any, tuple[int, int]] - | None = None, - constant_values: float - | tuple[float, float] - | Mapping[Any, tuple[float, float]] - | None = None, + stat_length: ( + int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None + ) = None, + constant_values: ( + float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None + ) = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, @@ -1523,7 +1201,7 @@ def pad( pad_option_kwargs["reflect_type"] = reflect_type array = np.pad( - self.data.astype(dtype, copy=False), + duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, **pad_option_kwargs, @@ -1586,7 +1264,7 @@ def transpose( self, *dims: Hashable | ellipsis, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> Variable: + ) -> Self: """Return a new Variable object with transposed dimensions. Parameters @@ -1631,7 +1309,7 @@ def transpose( return self._replace(dims=dims, data=data) @property - def T(self) -> Variable: + def T(self) -> Self: return self.transpose() def set_dims(self, dims, shape=None): @@ -1654,7 +1332,7 @@ def set_dims(self, dims, shape=None): if isinstance(dims, str): dims = [dims] - if shape is None and utils.is_dict_like(dims): + if shape is None and is_dict_like(dims): shape = dims.values() missing_dims = set(self.dims) - set(dims) @@ -1676,7 +1354,8 @@ def set_dims(self, dims, shape=None): tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) else: - expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)] + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] expanded_var = Variable( expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True @@ -1705,7 +1384,9 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable): new_data = duck_array_ops.reshape(reordered.data, new_shape) new_dims = reordered.dims[: len(other_dims)] + (new_dim,) - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def stack(self, dimensions=None, **dimensions_kwargs): """ @@ -1739,17 +1420,15 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full( - self, dims: Mapping[Any, int], old_dim: Hashable - ) -> Variable: + def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. Unlike `_unstack_once`, this function requires the existing dimension to contain the full product of the new dimensions. """ - new_dim_names = tuple(dims.keys()) - new_dim_sizes = tuple(dims.values()) + new_dim_names = tuple(dim.keys()) + new_dim_sizes = tuple(dim.values()) if old_dim not in self.dims: raise ValueError(f"invalid existing dimension: {old_dim}") @@ -1771,10 +1450,12 @@ def _unstack_once_full( reordered = self.transpose(*dim_order) new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes - new_data = reordered.data.reshape(new_shape) + new_data = duck_array_ops.reshape(reordered.data, new_shape) new_dims = reordered.dims[: len(other_dims)] + new_dim_names - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def _unstack_once( self, @@ -1782,7 +1463,7 @@ def _unstack_once( dim: Hashable, fill_value=dtypes.NA, sparse: bool = False, - ) -> Variable: + ) -> Self: """ Unstacks this variable given an index to unstack and the name of the dimension to which the index refers. @@ -1799,15 +1480,20 @@ def _unstack_once( new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) new_dims = reordered.dims[: len(other_dims)] + new_dim_names + create_template: Callable if fill_value is dtypes.NA: is_missing_values = math.prod(new_shape) > math.prod(self.shape) if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) + + create_template = partial(np.full_like, fill_value=fill_value) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) + create_template = np.empty_like else: dtype = self.dtype + create_template = partial(np.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1830,12 +1516,7 @@ def _unstack_once( ) else: - data = np.full_like( - self.data, - fill_value=fill_value, - shape=new_shape, - dtype=dtype, - ) + data = create_template(self.data, shape=new_shape, dtype=dtype) # Indexer is a list of lists of locations. Each list is the locations # on the new dimension. This is robust to the data being sparse; in that @@ -1903,7 +1584,7 @@ def clip(self, min=None, max=None): return apply_ufunc(np.clip, self, min, max, dask="allowed") - def reduce( + def reduce( # type: ignore[override] self, func: Callable[..., Any], dim: Dims = None, @@ -1944,57 +1625,21 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - if dim == ...: - dim = None - if dim is not None and axis is not None: - raise ValueError("cannot supply both 'axis' and 'dim' arguments") - - if dim is not None: - axis = self.get_axis_num(dim) - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", r"Mean of empty slice", category=RuntimeWarning - ) - if axis is not None: - if isinstance(axis, tuple) and len(axis) == 1: - # unpack axis for the benefit of functions - # like np.argmin which can't handle tuple arguments - axis = axis[0] - data = func(self.data, axis=axis, **kwargs) - else: - data = func(self.data, **kwargs) - - if getattr(data, "shape", ()) == self.shape: - dims = self.dims - else: - removed_axes: Iterable[int] - if axis is None: - removed_axes = range(self.ndim) - else: - removed_axes = np.atleast_1d(axis) % self.ndim - if keepdims: - # Insert np.newaxis for removed dims - slices = tuple( - np.newaxis if i in removed_axes else slice(None, None) - for i in range(self.ndim) - ) - if getattr(data, "shape", None) is None: - # Reduce has produced a scalar value, not an array-like - data = np.asanyarray(data)[slices] - else: - data = data[slices] - dims = self.dims - else: - dims = tuple( - adim for n, adim in enumerate(self.dims) if n not in removed_axes - ) + keep_attrs_ = ( + _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs + ) - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - attrs = self._attrs if keep_attrs else None + # Noe that the call order for Variable.mean is + # Variable.mean -> NamedArray.mean -> Variable.reduce + # -> NamedArray.reduce + result = super().reduce( + func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs + ) - return Variable(dims, data, attrs=attrs) + # return Variable always to support IndexVariable + return Variable( + result.dims, result._data, attrs=result._attrs if keep_attrs_ else None + ) @classmethod def concat( @@ -2055,12 +1700,13 @@ def concat( # twice variables = list(variables) first_var = variables[0] + first_var_dims = first_var.dims - arrays = [v.data for v in variables] + arrays = [v._data for v in variables] - if dim in first_var.dims: + if dim in first_var_dims: axis = first_var.get_axis_num(dim) - dims = first_var.dims + dims = first_var_dims data = duck_array_ops.concatenate(arrays, axis=axis) if positions is not None: # TODO: deprecate this option -- we don't need it for groupby @@ -2069,7 +1715,7 @@ def concat( data = duck_array_ops.take(data, indices, axis=axis) else: axis = 0 - dims = (dim,) + first_var.dims + dims = (dim,) + first_var_dims data = duck_array_ops.stack(arrays, axis=axis) attrs = merge_attrs( @@ -2078,12 +1724,12 @@ def concat( encoding = dict(first_var.encoding) if not shortcut: for var in variables: - if var.dims != first_var.dims: + if var.dims != first_var_dims: raise ValueError( - f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var.dims)}" + f"Variable has dimensions {tuple(var.dims)} but first Variable has dimensions {tuple(first_var_dims)}" ) - return cls(dims, data, attrs, encoding) + return cls(dims, data, attrs, encoding, fastpath=True) def equals(self, other, equiv=duck_array_ops.array_equiv): """True if two Variables have the same dimensions and values; @@ -2142,7 +1788,7 @@ def quantile( keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> Variable: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. Returns the qth quantiles(s) of the array elements. @@ -2159,15 +1805,15 @@ def quantile( desired quantile lies between two data points. The options sorted by their R type as summarized in the H&F paper [1]_ are: - 1. "inverted_cdf" (*) - 2. "averaged_inverted_cdf" (*) - 3. "closest_observation" (*) - 4. "interpolated_inverted_cdf" (*) - 5. "hazen" (*) - 6. "weibull" (*) + 1. "inverted_cdf" + 2. "averaged_inverted_cdf" + 3. "closest_observation" + 4. "interpolated_inverted_cdf" + 5. "hazen" + 6. "weibull" 7. "linear" (default) - 8. "median_unbiased" (*) - 9. "normal_unbiased" (*) + 8. "median_unbiased" + 9. "normal_unbiased" The first three methods are discontiuous. The following discontinuous variations of the default "linear" (7.) option are also available: @@ -2181,8 +1827,6 @@ def quantile( was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - (*) These methods require numpy version 1.22 or newer. - keep_attrs : bool, optional If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -2228,7 +1872,7 @@ def quantile( method = interpolation if skipna or (skipna is None and self.dtype.kind in "cfO"): - _quantile_func = np.nanquantile + _quantile_func = nputils.nanquantile else: _quantile_func = np.quantile @@ -2250,14 +1894,7 @@ def _wrapper(npa, **kwargs): axis = np.arange(-1, -1 * len(dim) - 1, -1) - if Version(np.__version__) >= Version("1.22.0"): - kwargs = {"q": q, "axis": axis, "method": method} - else: - if method not in ("linear", "lower", "higher", "midpoint", "nearest"): - raise ValueError( - f"Interpolation method '{method}' requires numpy >= 1.22 or is not supported." - ) - kwargs = {"q": q, "axis": axis, "interpolation": method} + kwargs = {"q": q, "axis": axis, "method": method} result = apply_ufunc( _wrapper, @@ -2305,6 +1942,7 @@ def rank(self, dim, pct=False): -------- Dataset.rank, DataArray.rank """ + # This could / should arguably be implemented at the DataArray & Dataset level if not OPTIONS["use_bottleneck"]: raise RuntimeError( "rank requires bottleneck to be enabled." @@ -2313,24 +1951,20 @@ def rank(self, dim, pct=False): import bottleneck as bn - data = self.data - - if is_duck_dask_array(data): - raise TypeError( - "rank does not work for arrays stored as dask " - "arrays. Load the data via .compute() or .load() " - "prior to calling this method." - ) - elif not isinstance(data, np.ndarray): - raise TypeError(f"rank is not implemented for {type(data)} objects.") - - axis = self.get_axis_num(dim) func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata - ranked = func(data, axis=axis) + ranked = xr.apply_ufunc( + func, + self, + input_core_dims=[[dim]], + output_core_dims=[[dim]], + dask="parallelized", + kwargs=dict(axis=-1), + ).transpose(*self.dims) + if pct: - count = np.sum(~np.isnan(data), axis=axis, keepdims=True) + count = self.notnull().sum(dim) ranked /= count - return Variable(self.dims, ranked) + return ranked def rolling_window( self, dim, window, window_dim, center=False, pad=True, fill_value=dtypes.NA @@ -2370,7 +2004,7 @@ def rolling_window( -------- >>> v = Variable(("a", "b"), np.arange(8).reshape((2, 4))) >>> v.rolling_window("b", 3, "window_dim") - + Size: 192B array([[[nan, nan, 0.], [nan, 0., 1.], [ 0., 1., 2.], @@ -2382,7 +2016,7 @@ def rolling_window( [ 5., 6., 7.]]]) >>> v.rolling_window("b", 3, "window_dim", center=True) - + Size: 192B array([[[nan, 0., 1.], [ 0., 1., 2.], [ 1., 2., 3.], @@ -2402,7 +2036,7 @@ def rolling_window( """ if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) - var = self.astype(dtype, copy=False) + var = duck_array_ops.astype(self, dtype, copy=False) else: dtype = self.dtype var = self @@ -2460,10 +2094,10 @@ def coarsen_reshape(self, windows, boundary, side): """ Construct a reshaped-array for coarsen """ - if not utils.is_dict_like(boundary): + if not is_dict_like(boundary): boundary = {d: boundary for d in windows.keys()} - if not utils.is_dict_like(side): + if not is_dict_like(side): side = {d: side for d in windows.keys()} # remove unrelated dimensions @@ -2504,8 +2138,8 @@ def coarsen_reshape(self, windows, boundary, side): variable = variable.pad(pad_width, mode="constant") else: raise TypeError( - "{} is invalid for boundary. Valid option is 'exact', " - "'trim' and 'pad'".format(boundary[d]) + f"{boundary[d]} is invalid for boundary. Valid option is 'exact', " + "'trim' and 'pad'" ) shape = [] @@ -2521,7 +2155,7 @@ def coarsen_reshape(self, windows, boundary, side): else: shape.append(variable.shape[i]) - return variable.data.reshape(shape), tuple(axes) + return duck_array_ops.reshape(variable.data, shape), tuple(axes) def isnull(self, keep_attrs: bool | None = None): """Test each value in the array for whether it is a missing value. @@ -2539,10 +2173,10 @@ def isnull(self, keep_attrs: bool | None = None): -------- >>> var = xr.Variable("x", [1, np.nan, 3]) >>> var - + Size: 24B array([ 1., nan, 3.]) >>> var.isnull() - + Size: 3B array([False, True, False]) """ from xarray.core.computation import apply_ufunc @@ -2573,10 +2207,10 @@ def notnull(self, keep_attrs: bool | None = None): -------- >>> var = xr.Variable("x", [1, np.nan, 3]) >>> var - + Size: 24B array([ 1., nan, 3.]) >>> var.notnull() - + Size: 3B array([ True, False, True]) """ from xarray.core.computation import apply_ufunc @@ -2592,26 +2226,26 @@ def notnull(self, keep_attrs: bool | None = None): ) @property - def real(self): + def imag(self) -> Variable: """ - The real part of the variable. + The imaginary part of the variable. See Also -------- - numpy.ndarray.real + numpy.ndarray.imag """ - return self._replace(data=self.data.real) + return self._new(data=self.data.imag) @property - def imag(self): + def real(self) -> Variable: """ - The imaginary part of the variable. + The real part of the variable. See Also -------- - numpy.ndarray.imag + numpy.ndarray.real """ - return self._replace(data=self.data.imag) + return self._new(data=self.data.real) def __array_wrap__(self, obj, context=None): return Variable(self.dims, obj) @@ -2821,6 +2455,105 @@ def argmax( """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: + """ + Use sparse-array as backend. + """ + from xarray.namedarray._typing import _default as _default_named + + if sparse_format is _default: + sparse_format = _default_named + + if fill_value is _default: + fill_value = _default_named + + out = super()._as_sparse(sparse_format, fill_value) + return cast("Variable", out) + + def _to_dense(self) -> Variable: + """ + Change backend from sparse to np.array. + """ + out = super()._to_dense() + return cast("Variable", out) + + def chunk( # type: ignore[override] + self, + chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + name: str | None = None, + lock: bool | None = None, + inline_array: bool | None = None, + chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, + from_array_kwargs: Any = None, + **chunks_kwargs: Any, + ) -> Self: + """Coerce this array's data into a dask array with the given chunks. + + If this variable is a non-dask array, it will be converted to dask + array. If it's a dask array, it will be rechunked to the given chunk + sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, tuple or dict, optional + Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or + ``{'x': 5, 'y': 5}``. + name : str, optional + Used to generate the name for this array in the internal dask + graph. Does not need not be unique. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + inline_array : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + + if from_array_kwargs is None: + from_array_kwargs = {} + + # TODO deprecate passing these dask-specific arguments explicitly. In future just pass everything via from_array_kwargs + _from_array_kwargs = consolidate_dask_from_array_kwargs( + from_array_kwargs, + name=name, + lock=lock, + inline_array=inline_array, + ) + + return super().chunk( + chunks=chunks, + chunked_array_type=chunked_array_type, + from_array_kwargs=_from_array_kwargs, + **chunks_kwargs, + ) + class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. @@ -2835,6 +2568,9 @@ class IndexVariable(Variable): __slots__ = () + # TODO: PandasIndexingAdapter doesn't match the array api: + _data: PandasIndexingAdapter # type: ignore[assignment] + def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: @@ -2844,11 +2580,13 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): if not isinstance(self._data, PandasIndexingAdapter): self._data = PandasIndexingAdapter(self._data) - def __dask_tokenize__(self): + def __dask_tokenize__(self) -> object: from dask.base import normalize_token # Don't waste time converting pd.Index to np.ndarray - return normalize_token((type(self), self._dims, self._data.array, self._attrs)) + return normalize_token( + (type(self), self._dims, self._data.array, self._attrs or None) + ) def load(self): # data is already loaded into memory for IndexVariable @@ -2869,7 +2607,15 @@ def values(self, values): f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." ) - def chunk(self, chunks={}, name=None, lock=False, inline_array=False): + def chunk( + self, + chunks={}, + name=None, + lock=False, + inline_array=False, + chunked_array_type=None, + from_array_kwargs=None, + ): # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk() return self.copy(deep=False) @@ -2943,7 +2689,7 @@ def concat( return cls(first_var.dims, data, attrs) - def copy(self, deep: bool = True, data: ArrayLike | None = None): + def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None): """Returns a copy of this object. `deep` is ignored since data is stored in the form of @@ -2968,14 +2714,16 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None): data copied from original. """ if data is None: - ndata = self._data.copy(deep=deep) + ndata = self._data + + if deep: + ndata = copy.deepcopy(ndata, None) + else: ndata = as_compatible_data(data) - if self.shape != ndata.shape: + if self.shape != ndata.shape: # type: ignore[attr-defined] raise ValueError( - "Data shape {} must match shape of object {}".format( - ndata.shape, self.shape - ) + f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined] ) attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) @@ -3065,20 +2813,13 @@ def _inplace_binary_op(self, other, f): ) -# for backwards compatibility -Coordinate = utils.alias(IndexVariable, "Coordinate") - - def _unified_dims(variables): # validate dimensions all_dims = {} for var in variables: var_dims = var.dims - if len(set(var_dims)) < len(var_dims): - raise ValueError( - "broadcasting cannot handle duplicate " - f"dimensions: {list(var_dims)!r}" - ) + _raise_if_any_duplicate_dimensions(var_dims, err_context="Broadcasting") + for d, s in zip(var_dims, var.shape): if d not in all_dims: all_dims[d] = s @@ -3118,6 +2859,16 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]: def _broadcast_compat_data(self, other): + if not OPTIONS["arithmetic_broadcast"]: + if (isinstance(other, Variable) and self.dims != other.dims) or ( + is_duck_array(other) and self.ndim != other.ndim + ): + raise ValueError( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ) + if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]): # `other` satisfies the necessary Variable API for broadcast_variables new_self, new_other = _broadcast_compat_variables(self, other) diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 904c6a4d980..ae9521309e0 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -9,8 +9,9 @@ from xarray.core import duck_array_ops, utils from xarray.core.alignment import align, broadcast from xarray.core.computation import apply_ufunc, dot -from xarray.core.pycompat import is_duck_dask_array -from xarray.core.types import Dims, T_Xarray +from xarray.core.types import Dims, T_DataArray, T_Xarray +from xarray.namedarray.utils import is_duck_dask_array +from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. QUANTILE_METHODS = Literal[ @@ -144,7 +145,7 @@ class Weighted(Generic[T_Xarray]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_Xarray, weights: DataArray) -> None: + def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: """ Create a Weighted object @@ -188,7 +189,7 @@ def _weight_check(w): _weight_check(weights.data) self.obj: T_Xarray = obj - self.weights: DataArray = weights + self.weights: T_DataArray = weights def _check_dim(self, dim: Dims): """raise an error if any dimension is missing""" @@ -198,19 +199,20 @@ def _check_dim(self, dim: Dims): dims = [dim] if dim else [] else: dims = list(dim) - missing_dims = set(dims) - set(self.obj.dims) - set(self.weights.dims) + all_dims = set(self.obj.dims).union(set(self.weights.dims)) + missing_dims = set(dims) - all_dims if missing_dims: raise ValueError( - f"{self.__class__.__name__} does not contain the dimensions: {missing_dims}" + f"Dimensions {tuple(missing_dims)} not found in {self.__class__.__name__} dimensions {tuple(all_dims)}" ) @staticmethod def _reduce( - da: DataArray, - weights: DataArray, + da: T_DataArray, + weights: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """reduce using dot; equivalent to (da * weights).sum(dim, skipna) for internal use only @@ -226,9 +228,9 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - return dot(da, weights, dims=dim) + return dot(da, weights, dim=dim) - def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: + def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray: """Calculate the sum of weights, accounting for missing values""" # we need to mask data values that are nan; else the weights are wrong @@ -238,7 +240,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: # (and not 2); GH4074 if self.weights.dtype == bool: sum_of_weights = self._reduce( - mask, self.weights.astype(int), dim=dim, skipna=False + mask, + duck_array_ops.astype(self.weights, dtype=int), + dim=dim, + skipna=False, ) else: sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) @@ -250,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: def _sum_of_squares( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s).""" demeaned = da - da.weighted(self.weights).mean(dim=dim) @@ -262,20 +267,20 @@ def _sum_of_squares( def _weighted_sum( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" return self._reduce(da, self.weights, dim=dim, skipna=skipna) def _weighted_mean( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``mean`` along some dimension(s).""" weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna) @@ -286,10 +291,10 @@ def _weighted_mean( def _weighted_var( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``var`` along some dimension(s).""" sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna) @@ -300,26 +305,27 @@ def _weighted_var( def _weighted_std( self, - da: DataArray, + da: T_DataArray, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Reduce a DataArray by a weighted ``std`` along some dimension(s).""" - return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) + return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna))) def _weighted_quantile( self, - da: DataArray, + da: T_DataArray, q: ArrayLike, dim: Dims = None, skipna: bool | None = None, - ) -> DataArray: + ) -> T_DataArray: """Apply a weighted ``quantile`` to a DataArray along some dimension(s).""" def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: """Return the interpolation parameter.""" # Note that options are not yet exposed in the public API. + h: np.ndarray if method == "linear": h = (n - 1) * q + 1 elif method == "interpolated_inverted_cdf": @@ -445,18 +451,22 @@ def _weighted_quantile_1d( def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + @_deprecate_positional_args("v2023.10.0") def sum_of_weights( self, dim: Dims = None, + *, keep_attrs: bool | None = None, ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum_of_squares( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -464,9 +474,11 @@ def sum_of_squares( self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -474,9 +486,11 @@ def sum( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def mean( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -484,9 +498,11 @@ def mean( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def var( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -494,9 +510,11 @@ def var( self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def std( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: diff --git a/xarray/datatree_/.flake8 b/xarray/datatree_/.flake8 new file mode 100644 index 00000000000..f1e3f9271e1 --- /dev/null +++ b/xarray/datatree_/.flake8 @@ -0,0 +1,15 @@ +[flake8] +ignore = + # whitespace before ':' - doesn't work well with black + E203 + # module level import not at top of file + E402 + # line too long - let black worry about that + E501 + # do not assign a lambda expression, use a def + E731 + # line break before binary operator + W503 +exclude= + .eggs + doc diff --git a/xarray/datatree_/.git_archival.txt b/xarray/datatree_/.git_archival.txt new file mode 100644 index 00000000000..3994ec0a83e --- /dev/null +++ b/xarray/datatree_/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true)$ +ref-names: $Format:%D$ diff --git a/xarray/datatree_/.github/dependabot.yml b/xarray/datatree_/.github/dependabot.yml new file mode 100644 index 00000000000..d1d1190be70 --- /dev/null +++ b/xarray/datatree_/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: pip + directory: "/" + schedule: + interval: daily + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every weekday + interval: "daily" diff --git a/xarray/datatree_/.github/pull_request_template.md b/xarray/datatree_/.github/pull_request_template.md new file mode 100644 index 00000000000..8270498108a --- /dev/null +++ b/xarray/datatree_/.github/pull_request_template.md @@ -0,0 +1,7 @@ + + +- [ ] Closes #xxxx +- [ ] Tests added +- [ ] Passes `pre-commit run --all-files` +- [ ] New functions/methods are listed in `api.rst` +- [ ] Changes are summarized in `docs/source/whats-new.rst` diff --git a/xarray/datatree_/.github/workflows/main.yaml b/xarray/datatree_/.github/workflows/main.yaml new file mode 100644 index 00000000000..37034fc5900 --- /dev/null +++ b/xarray/datatree_/.github/workflows/main.yaml @@ -0,0 +1,97 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + schedule: + - cron: "0 0 * * *" + +jobs: + + test: + name: ${{ matrix.python-version }}-build + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + + - name: Conda info + run: conda info + + - name: Install datatree + run: | + python -m pip install -e . --no-deps --force-reinstall + + - name: Conda list + run: conda list + + - name: Running Tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose + + - name: Upload code coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false + + + test-upstream: + name: ${{ matrix.python-version }}-dev-build + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + + - name: Conda info + run: conda info + + - name: Install dev reqs + run: | + python -m pip install --no-deps --upgrade \ + git+https://github.com/pydata/xarray \ + git+https://github.com/Unidata/netcdf4-python + + python -m pip install -e . --no-deps --force-reinstall + + - name: Conda list + run: conda list + + - name: Running Tests + run: | + python -m pytest --verbose diff --git a/xarray/datatree_/.github/workflows/pypipublish.yaml b/xarray/datatree_/.github/workflows/pypipublish.yaml new file mode 100644 index 00000000000..7dc36d87691 --- /dev/null +++ b/xarray/datatree_/.github/workflows/pypipublish.yaml @@ -0,0 +1,84 @@ +name: Build distribution +on: + release: + types: + - published + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-artifacts: + runs-on: ubuntu-latest + if: github.repository == 'xarray-contrib/datatree' + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + name: Install Python + with: + python-version: 3.9 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build + + - name: Build tarball and wheels + run: | + git clean -xdf + git restore -SW . + python -m build --sdist --wheel . + + + - uses: actions/upload-artifact@v4 + with: + name: releases + path: dist + + test-built-dist: + needs: build-artifacts + runs-on: ubuntu-latest + steps: + - uses: actions/setup-python@v5 + name: Install Python + with: + python-version: '3.10' + - uses: actions/download-artifact@v4 + with: + name: releases + path: dist + - name: List contents of built dist + run: | + ls -ltrh + ls -ltrh dist + + - name: Verify the built dist/wheel is valid + run: | + python -m pip install --upgrade pip + python -m pip install dist/xarray_datatree*.whl + python -c "import datatree; print(datatree.__version__)" + + upload-to-pypi: + needs: test-built-dist + if: github.event_name == 'release' + runs-on: ubuntu-latest + steps: + - uses: actions/download-artifact@v4 + with: + name: releases + path: dist + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@v1.8.11 + with: + user: ${{ secrets.PYPI_USERNAME }} + password: ${{ secrets.PYPI_PASSWORD }} + verbose: true diff --git a/xarray/datatree_/.gitignore b/xarray/datatree_/.gitignore new file mode 100644 index 00000000000..88af9943a90 --- /dev/null +++ b/xarray/datatree_/.gitignore @@ -0,0 +1,136 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/source/generated + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# version +_version.py + +# Ignore vscode specific settings +.vscode/ diff --git a/xarray/datatree_/.pre-commit-config.yaml b/xarray/datatree_/.pre-commit-config.yaml new file mode 100644 index 00000000000..ea73c38d73e --- /dev/null +++ b/xarray/datatree_/.pre-commit-config.yaml @@ -0,0 +1,58 @@ +# https://pre-commit.com/ +ci: + autoupdate_schedule: monthly +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + # isort should run before black as black sometimes tweaks the isort output + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + # https://github.com/python/black#version-control-integration + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + - repo: https://github.com/keewis/blackdoc + rev: v0.3.9 + hooks: + - id: blackdoc + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + # - repo: https://github.com/Carreau/velin + # rev: 0.0.8 + # hooks: + # - id: velin + # args: ["--write", "--compact"] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + # Copied from setup.cfg + exclude: "properties|asv_bench|docs" + additional_dependencies: [ + # Type stubs + types-python-dateutil, + types-pkg_resources, + types-PyYAML, + types-pytz, + # Dependencies that are typed + numpy, + typing-extensions>=4.1.0, + ] + # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 + # - repo: https://github.com/asottile/pyupgrade + # rev: v1.22.1 + # hooks: + # - id: pyupgrade + # args: + # - "--py3-only" + # # remove on f-strings in Py3.7 + # - "--keep-percent-format" diff --git a/xarray/datatree_/LICENSE b/xarray/datatree_/LICENSE new file mode 100644 index 00000000000..d68e7230919 --- /dev/null +++ b/xarray/datatree_/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2022 onwards, datatree developers + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/xarray/datatree_/README.md b/xarray/datatree_/README.md new file mode 100644 index 00000000000..e41a13b4cb6 --- /dev/null +++ b/xarray/datatree_/README.md @@ -0,0 +1,95 @@ +# datatree + +| CI | [![GitHub Workflow Status][github-ci-badge]][github-ci-link] [![Code Coverage Status][codecov-badge]][codecov-link] [![pre-commit.ci status][pre-commit.ci-badge]][pre-commit.ci-link] | +| :---------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| **Docs** | [![Documentation Status][rtd-badge]][rtd-link] | +| **Package** | [![Conda][conda-badge]][conda-link] [![PyPI][pypi-badge]][pypi-link] | +| **License** | [![License][license-badge]][repo-link] | + + +**Datatree is a prototype implementation of a tree-like hierarchical data structure for xarray.** + +Datatree was born after the xarray team recognised a [need for a new hierarchical data structure](https://github.com/pydata/xarray/issues/4118), +that was more flexible than a single `xarray.Dataset` object. +The initial motivation was to represent netCDF files / Zarr stores with multiple nested groups in a single in-memory object, +but `datatree.DataTree` objects have many other uses. + +### DEPRECATION NOTICE + +Datatree is in the process of being merged upstream into xarray (as of [v0.0.14](https://github.com/xarray-contrib/datatree/releases/tag/v0.0.14), see xarray issue [#8572](https://github.com/pydata/xarray/issues/8572)). We are aiming to preserve the record of contributions to this repository during the migration process. However whilst we will hapily accept new PRs to this repository, this repo will be deprecated and any PRs since [v0.0.14](https://github.com/xarray-contrib/datatree/releases/tag/v0.0.14) might be later copied across to xarray without full git attribution. + +Hopefully for users the disruption will be minimal - and just mean that in some future version of xarray you only need to do `from xarray import DataTree` rather than `from datatree import DataTree`. Once the migration is complete this repository will be archived. + +### Installation +You can install datatree via pip: +```shell +pip install xarray-datatree +``` + +or via conda-forge +```shell +conda install -c conda-forge xarray-datatree +``` + +### Why Datatree? + +You might want to use datatree for: + +- Organising many related datasets, e.g. results of the same experiment with different parameters, or simulations of the same system using different models, +- Analysing similar data at multiple resolutions simultaneously, such as when doing a convergence study, +- Comparing heterogenous but related data, such as experimental and theoretical data, +- I/O with nested data formats such as netCDF / Zarr groups. + +[**Talk slides on Datatree from AMS-python 2023**](https://speakerdeck.com/tomnicholas/xarray-datatree-hierarchical-data-structures-for-multi-model-science) + +### Features + +The approach used here is based on benbovy's [`DatasetNode` example](https://gist.github.com/benbovy/92e7c76220af1aaa4b3a0b65374e233a) - the basic idea is that each tree node wraps a up to a single `xarray.Dataset`. The differences are that this effort: +- Uses a node structure inspired by [anytree](https://github.com/xarray-contrib/datatree/issues/7) for the tree, +- Implements path-like getting and setting, +- Has functions for mapping user-supplied functions over every node in the tree, +- Automatically dispatches *some* of `xarray.Dataset`'s API over every node in the tree (such as `.isel`), +- Has a bunch of tests, +- Has a printable representation that currently looks like this: +drawing + +### Get Started + +You can create a `DataTree` object in 3 ways: +1) Load from a netCDF file (or Zarr store) that has groups via `open_datatree()`. +2) Using the init method of `DataTree`, which creates an individual node. + You can then specify the nodes' relationships to one other, either by setting `.parent` and `.children` attributes, + or through `__get/setitem__` access, e.g. `dt['path/to/node'] = DataTree()`. +3) Create a tree from a dictionary of paths to datasets using `DataTree.from_dict()`. + +### Development Roadmap + +Datatree currently lives in a separate repository to the main xarray package. +This allows the datatree developers to make changes to it, experiment, and improve it faster. + +Eventually we plan to fully integrate datatree upstream into xarray's main codebase, at which point the [github.com/xarray-contrib/datatree](https://github.com/xarray-contrib/datatree>) repository will be archived. +This should not cause much disruption to code that depends on datatree - you will likely only have to change the import line (i.e. from ``from datatree import DataTree`` to ``from xarray import DataTree``). + +However, until this full integration occurs, datatree's API should not be considered to have the same [level of stability as xarray's](https://docs.xarray.dev/en/stable/contributing.html#backwards-compatibility). + +### User Feedback + +We really really really want to hear your opinions on datatree! +At this point in development, user feedback is critical to help us create something that will suit everyone's needs. +Please raise any thoughts, issues, suggestions or bugs, no matter how small or large, on the [github issue tracker](https://github.com/xarray-contrib/datatree/issues). + + +[github-ci-badge]: https://img.shields.io/github/actions/workflow/status/xarray-contrib/datatree/main.yaml?branch=main&label=CI&logo=github +[github-ci-link]: https://github.com/xarray-contrib/datatree/actions?query=workflow%3ACI +[codecov-badge]: https://img.shields.io/codecov/c/github/xarray-contrib/datatree.svg?logo=codecov +[codecov-link]: https://codecov.io/gh/xarray-contrib/datatree +[rtd-badge]: https://img.shields.io/readthedocs/xarray-datatree/latest.svg +[rtd-link]: https://xarray-datatree.readthedocs.io/en/latest/?badge=latest +[pypi-badge]: https://img.shields.io/pypi/v/xarray-datatree?logo=pypi +[pypi-link]: https://pypi.org/project/xarray-datatree +[conda-badge]: https://img.shields.io/conda/vn/conda-forge/xarray-datatree?logo=anaconda +[conda-link]: https://anaconda.org/conda-forge/xarray-datatree +[license-badge]: https://img.shields.io/github/license/xarray-contrib/datatree +[repo-link]: https://github.com/xarray-contrib/datatree +[pre-commit.ci-badge]: https://results.pre-commit.ci/badge/github/xarray-contrib/datatree/main.svg +[pre-commit.ci-link]: https://results.pre-commit.ci/latest/github/xarray-contrib/datatree/main diff --git a/xarray/datatree_/ci/doc.yml b/xarray/datatree_/ci/doc.yml new file mode 100644 index 00000000000..f3b95f71bd4 --- /dev/null +++ b/xarray/datatree_/ci/doc.yml @@ -0,0 +1,25 @@ +name: datatree-doc +channels: + - conda-forge +dependencies: + - pip + - python>=3.9 + - netcdf4 + - scipy + - sphinx>=4.2.0 + - sphinx-copybutton + - sphinx-panels + - sphinx-autosummary-accessors + - sphinx-book-theme >= 0.0.38 + - nbsphinx + - sphinxcontrib-srclinks + - pickleshare + - pydata-sphinx-theme>=0.4.3 + - ipython + - h5netcdf + - zarr + - xarray + - pip: + - -e .. + - sphinxext-rediraffe + - sphinxext-opengraph diff --git a/xarray/datatree_/ci/environment.yml b/xarray/datatree_/ci/environment.yml new file mode 100644 index 00000000000..fc0c6d97e9f --- /dev/null +++ b/xarray/datatree_/ci/environment.yml @@ -0,0 +1,16 @@ +name: datatree-test +channels: + - conda-forge + - nodefaults +dependencies: + - python>=3.9 + - netcdf4 + - pytest + - flake8 + - black + - codecov + - pytest-cov + - h5netcdf + - zarr + - pip: + - xarray>=2022.05.0.dev0 diff --git a/xarray/datatree_/codecov.yml b/xarray/datatree_/codecov.yml new file mode 100644 index 00000000000..44fd739d417 --- /dev/null +++ b/xarray/datatree_/codecov.yml @@ -0,0 +1,21 @@ +codecov: + require_ci_to_pass: false + max_report_age: off + +comment: false + +ignore: + - 'datatree/tests/*' + - 'setup.py' + - 'conftest.py' + +coverage: + precision: 2 + round: down + status: + project: + default: + target: 95 + informational: true + patch: off + changes: false diff --git a/xarray/datatree_/conftest.py b/xarray/datatree_/conftest.py new file mode 100644 index 00000000000..7ef19174298 --- /dev/null +++ b/xarray/datatree_/conftest.py @@ -0,0 +1,3 @@ +import pytest + +pytest.register_assert_rewrite("datatree.testing") diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py new file mode 100644 index 00000000000..071dcbecf8c --- /dev/null +++ b/xarray/datatree_/datatree/__init__.py @@ -0,0 +1,15 @@ +# import public API +from .datatree import DataTree +from .extensions import register_datatree_accessor +from .mapping import TreeIsomorphismError, map_over_subtree +from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError + + +__all__ = ( + "DataTree", + "TreeIsomorphismError", + "InvalidTreeError", + "NotFoundInTreeError", + "map_over_subtree", + "register_datatree_accessor", +) diff --git a/xarray/datatree_/datatree/common.py b/xarray/datatree_/datatree/common.py new file mode 100644 index 00000000000..e4d52925ede --- /dev/null +++ b/xarray/datatree_/datatree/common.py @@ -0,0 +1,105 @@ +""" +This file and class only exists because it was easier to copy the code for AttrAccessMixin from xarray.core.common +with some slight modifications than it was to change the behaviour of an inherited xarray internal here. + +The modifications are marked with # TODO comments. +""" + +import warnings +from contextlib import suppress +from typing import Any, Hashable, Iterable, List, Mapping + + +class TreeAttrAccessMixin: + """Mixin class that allows getting keys with attribute access""" + + __slots__ = () + + def __init_subclass__(cls, **kwargs): + """Verify that all subclasses explicitly define ``__slots__``. If they don't, + raise error in the core xarray module and a FutureWarning in third-party + extensions. + """ + if not hasattr(object.__new__(cls), "__dict__"): + pass + # TODO reinstate this once integrated upstream + # elif cls.__module__.startswith("datatree."): + # raise AttributeError(f"{cls.__name__} must explicitly define __slots__") + # else: + # cls.__setattr__ = cls._setattr_dict + # warnings.warn( + # f"xarray subclass {cls.__name__} should explicitly define __slots__", + # FutureWarning, + # stacklevel=2, + # ) + super().__init_subclass__(**kwargs) + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from () + + @property + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-autocompletion""" + yield from () + + def __getattr__(self, name: str) -> Any: + if name not in {"__dict__", "__setstate__"}: + # this avoids an infinite loop when pickle looks for the + # __setstate__ attribute before the xarray object is initialized + for source in self._attr_sources: + with suppress(KeyError): + return source[name] + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {name!r}" + ) + + # This complicated two-method design boosts overall performance of simple operations + # - particularly DataArray methods that perform a _to_temp_dataset() round-trip - by + # a whopping 8% compared to a single method that checks hasattr(self, "__dict__") at + # runtime before every single assignment. All of this is just temporary until the + # FutureWarning can be changed into a hard crash. + def _setattr_dict(self, name: str, value: Any) -> None: + """Deprecated third party subclass (see ``__init_subclass__`` above)""" + object.__setattr__(self, name, value) + if name in self.__dict__: + # Custom, non-slotted attr, or improperly assigned variable? + warnings.warn( + f"Setting attribute {name!r} on a {type(self).__name__!r} object. Explicitly define __slots__ " + "to suppress this warning for legitimate custom attributes and " + "raise an error when attempting variables assignments.", + FutureWarning, + stacklevel=2, + ) + + def __setattr__(self, name: str, value: Any) -> None: + """Objects with ``__slots__`` raise AttributeError if you try setting an + undeclared attribute. This is desirable, but the error message could use some + improvement. + """ + try: + object.__setattr__(self, name, value) + except AttributeError as e: + # Don't accidentally shadow custom AttributeErrors, e.g. + # DataArray.dims.setter + if str(e) != "{!r} object has no attribute {!r}".format( + type(self).__name__, name + ): + raise + raise AttributeError( + f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" + "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." + ) from e + + def __dir__(self) -> List[str]: + """Provide method name lookup and completion. Only provide 'public' + methods. + """ + extra_attrs = { + item + for source in self._attr_sources + for item in source + if isinstance(item, str) + } + return sorted(set(dir(type(self))) | extra_attrs) diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/datatree_/datatree/datatree.py new file mode 100644 index 00000000000..10133052185 --- /dev/null +++ b/xarray/datatree_/datatree/datatree.py @@ -0,0 +1,1542 @@ +from __future__ import annotations + +import copy +import itertools +from collections import OrderedDict +from html import escape +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Union, + overload, +) + +from xarray.core import utils +from xarray.core.coordinates import DatasetCoordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset, DataVariables +from xarray.core.indexes import Index, Indexes +from xarray.core.merge import dataset_update_method +from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.utils import ( + Default, + Frozen, + HybridMappingProxy, + _default, + either_dict_or_kwargs, + maybe_wrap_array, +) +from xarray.core.variable import Variable + +from . import formatting, formatting_html +from .common import TreeAttrAccessMixin +from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree +from .ops import ( + DataTreeArithmeticMixin, + MappedDatasetMethodsMixin, + MappedDataWithCoords, +) +from .render import RenderTree +from xarray.core.treenode import NamedNode, NodePath, Tree + +try: + from xarray.core.variable import calculate_dimensions +except ImportError: + # for xarray versions 2022.03.0 and earlier + from xarray.core.dataset import calculate_dimensions + +if TYPE_CHECKING: + import pandas as pd + from xarray.core.merge import CoercibleValue + from xarray.core.types import ErrorOptions + +# """ +# DEVELOPERS' NOTE +# ---------------- +# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies +# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every +# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin +# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API. +# +# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered +# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new +# tree) and some will get overridden by the class definition of DataTree. +# """ + + +T_Path = Union[str, NodePath] + + +def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: + if isinstance(data, DataArray): + ds = data.to_dataset() + elif isinstance(data, Dataset): + ds = data + elif data is None: + ds = Dataset() + else: + raise TypeError( + f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" + ) + return ds + + +def _check_for_name_collisions( + children: Iterable[str], variables: Iterable[Hashable] +) -> None: + colliding_names = set(children).intersection(set(variables)) + if colliding_names: + raise KeyError( + f"Some names would collide between variables and children: {list(colliding_names)}" + ) + + +class DatasetView(Dataset): + """ + An immutable Dataset-like view onto the data in a single DataTree node. + + In-place operations modifying this object should raise an AttributeError. + This requires overriding all inherited constructors. + + Operations returning a new result will return a new xarray.Dataset object. + This includes all API on Dataset, which will be inherited. + """ + + # TODO what happens if user alters (in-place) a DataArray they extracted from this object? + + __slots__ = ( + "_attrs", + "_cache", + "_coord_names", + "_dims", + "_encoding", + "_close", + "_indexes", + "_variables", + ) + + def __init__( + self, + data_vars: Optional[Mapping[Any, Any]] = None, + coords: Optional[Mapping[Any, Any]] = None, + attrs: Optional[Mapping[Any, Any]] = None, + ): + raise AttributeError("DatasetView objects are not to be initialized directly") + + @classmethod + def _from_node( + cls, + wrapping_node: DataTree, + ) -> DatasetView: + """Constructor, using dataset attributes from wrapping node""" + + obj: DatasetView = object.__new__(cls) + obj._variables = wrapping_node._variables + obj._coord_names = wrapping_node._coord_names + obj._dims = wrapping_node._dims + obj._indexes = wrapping_node._indexes + obj._attrs = wrapping_node._attrs + obj._close = wrapping_node._close + obj._encoding = wrapping_node._encoding + + return obj + + def __setitem__(self, key, val) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, " + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "use `.copy()` first to get a mutable version of the input dataset." + ) + + def update(self, other) -> NoReturn: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, " + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "use `.copy()` first to get a mutable version of the input dataset." + ) + + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Any) -> Dataset: + ... + + def __getitem__(self, key) -> DataArray: + # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes + # For now just call Dataset.__getitem__ + return Dataset.__getitem__(self, key) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: Optional[dict[Any, int]] = None, + attrs: Optional[dict] = None, + indexes: Optional[dict[Any, Index]] = None, + encoding: Optional[dict] = None, + close: Optional[Callable[[], None]] = None, + ) -> Dataset: + """ + Overriding this method (along with ._replace) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + obj = object.__new__(Dataset) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + def _replace( + self, + variables: Optional[dict[Hashable, Variable]] = None, + coord_names: Optional[set[Hashable]] = None, + dims: Optional[dict[Any, int]] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: Optional[dict[Hashable, Index]] = None, + encoding: dict | None | Default = _default, + inplace: bool = False, + ) -> Dataset: + """ + Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + + if inplace: + raise AttributeError("In-place mutation of the DatasetView is not allowed") + + return Dataset._replace( + self, + variables=variables, + coord_names=coord_names, + dims=dims, + attrs=attrs, + indexes=indexes, + encoding=encoding, + inplace=inplace, + ) + + def map( + self, + func: Callable, + keep_attrs: bool | None = None, + args: Iterable[Any] = (), + **kwargs: Any, + ) -> Dataset: + """Apply a function to each data variable in this dataset + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, *args, **kwargs)` + to transform each DataArray `x` in this dataset into another + DataArray. + keep_attrs : bool or None, optional + If True, both the dataset's and variables' attributes (`attrs`) will be + copied from the original objects to the new ones. If False, the new dataset + and variables will be returned without copying the attributes. + args : iterable, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + applied : Dataset + Resulting dataset from applying ``func`` to each data variable. + + Examples + -------- + >>> da = xr.DataArray(np.random.randn(2, 3)) + >>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])}) + >>> ds + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 -0.9773 + bar (x) int64 16B -1 2 + >>> ds.map(np.fabs) + Size: 64B + Dimensions: (dim_0: 2, dim_1: 3, x: 2) + Dimensions without coordinates: dim_0, dim_1, x + Data variables: + foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773 + bar (x) float64 16B 1.0 2.0 + """ + + # Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188). + # TODO Refactor xarray upstream to avoid needing to overwrite this. + # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated + variables = { + k: maybe_wrap_array(v, func(v, *args, **kwargs)) + for k, v in self.data_vars.items() + } + # return type(self)(variables, attrs=attrs) + return Dataset(variables) + + +class DataTree( + NamedNode, + MappedDatasetMethodsMixin, + MappedDataWithCoords, + DataTreeArithmeticMixin, + TreeAttrAccessMixin, + Generic[Tree], + Mapping, +): + """ + A tree-like hierarchical collection of xarray objects. + + Attempts to present an API like that of xarray.Dataset, but methods are wrapped to also update all the tree's child nodes. + """ + + # TODO Some way of sorting children by depth + + # TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes? + + # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array + + # TODO .loc method + + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO all groupby classes + + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO __slots__ + + # TODO all groupby classes + + _name: Optional[str] + _parent: Optional[DataTree] + _children: OrderedDict[str, DataTree] + _attrs: Optional[Dict[Hashable, Any]] + _cache: Dict[str, Any] + _coord_names: Set[Hashable] + _dims: Dict[Hashable, int] + _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] + _indexes: Dict[Hashable, Index] + _variables: Dict[Hashable, Variable] + + __slots__ = ( + "_name", + "_parent", + "_children", + "_attrs", + "_cache", + "_coord_names", + "_dims", + "_encoding", + "_close", + "_indexes", + "_variables", + ) + + def __init__( + self, + data: Optional[Dataset | DataArray] = None, + parent: Optional[DataTree] = None, + children: Optional[Mapping[str, DataTree]] = None, + name: Optional[str] = None, + ): + """ + Create a single node of a DataTree. + + The node may optionally contain data in the form of data and coordinate variables, stored in the same way as + data is stored in an xarray.Dataset. + + Parameters + ---------- + data : Dataset, DataArray, or None, optional + Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. + Default is None. + parent : DataTree, optional + Parent node to this node. Default is None. + children : Mapping[str, DataTree], optional + Any child nodes of this node. Default is None. + name : str, optional + Name for this node of the tree. Default is None. + + Returns + ------- + DataTree + + See Also + -------- + DataTree.from_dict + """ + + # validate input + if children is None: + children = {} + ds = _coerce_to_dataset(data) + _check_for_name_collisions(children, ds.variables) + + super().__init__(name=name) + + # set data attributes + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close + + # set tree attributes (must happen after variables set to avoid initialization errors) + self.children = children + self.parent = parent + + @property + def parent(self: DataTree) -> DataTree | None: + """Parent of this node.""" + return self._parent + + @parent.setter + def parent(self: DataTree, new_parent: DataTree) -> None: + if new_parent and self.name is None: + raise ValueError("Cannot set an unnamed node as a child of another node") + self._set_parent(new_parent, self.name) + + @property + def ds(self) -> DatasetView: + """ + An immutable Dataset-like view onto the data in this node. + + For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + + See Also + -------- + DataTree.to_dataset + """ + return DatasetView._from_node(self) + + @ds.setter + def ds(self, data: Optional[Union[Dataset, DataArray]] = None) -> None: + ds = _coerce_to_dataset(data) + + _check_for_name_collisions(self.children, ds.variables) + + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close + + def _pre_attach(self: DataTree, parent: DataTree) -> None: + """ + Method which superclass calls before setting parent, here used to prevent having two + children with duplicate names (or a data variable with the same name as a child). + """ + super()._pre_attach(parent) + if self.name in list(parent.ds.variables): + raise KeyError( + f"parent {parent.name} already contains a data variable named {self.name}" + ) + + def to_dataset(self) -> Dataset: + """ + Return the data in this node as a new xarray.Dataset object. + + See Also + -------- + DataTree.ds + """ + return Dataset._construct_direct( + self._variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + + @property + def has_data(self): + """Whether or not there are any data variables in this node.""" + return len(self._variables) > 0 + + @property + def has_attrs(self) -> bool: + """Whether or not there are any metadata attributes in this node.""" + return len(self.attrs.keys()) > 0 + + @property + def is_empty(self) -> bool: + """False if node contains any data or attrs. Does not look at children.""" + return not (self.has_data or self.has_attrs) + + @property + def is_hollow(self) -> bool: + """True if only leaf nodes contain data.""" + return not any(node.has_data for node in self.subtree if not node.is_leaf) + + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to node contents as dict of Variable objects. + + This ordered dictionary is frozen to prevent mutation that could + violate Dataset invariants. It contains all variable objects + constituting this DataTree node, including both data variables and + coordinates. + """ + return Frozen(self._variables) + + @property + def attrs(self) -> Dict[Hashable, Any]: + """Dictionary of global attributes on this node object.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + @property + def encoding(self) -> Dict: + """Dictionary of global encoding attributes on this node object.""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value: Mapping) -> None: + self._encoding = dict(value) + + @property + def dims(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + Note that type of this object differs from `DataArray.dims`. + See `DataTree.sizes`, `Dataset.sizes`, and `DataArray.sizes` for consistently named + properties. + """ + return Frozen(self._dims) + + @property + def sizes(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + This is an alias for `DataTree.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See Also + -------- + DataArray.sizes + """ + return self.dims + + @property + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs + + @property + def _item_sources(self) -> Iterable[Mapping[Any, Any]]: + """Places to look-up items for key-completion""" + yield self.data_vars + yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + + # virtual coordinates + yield HybridMappingProxy(keys=self.dims, mapping=self) + + # immediate child nodes + yield self.children + + def _ipython_key_completions_(self) -> List[str]: + """Provide method for the key-autocompletions in IPython. + See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion + For the details. + """ + + # TODO allow auto-completing relative string paths, e.g. `dt['path/to/../ node'` + # Would require changes to ipython's autocompleter, see https://github.com/ipython/ipython/issues/12420 + # Instead for now we only list direct paths to all node in subtree explicitly + + items_on_this_node = self._item_sources + full_file_like_paths_to_all_nodes_in_subtree = { + node.path[1:]: node for node in self.subtree + } + + all_item_sources = itertools.chain( + items_on_this_node, [full_file_like_paths_to_all_nodes_in_subtree] + ) + + items = { + item + for source in all_item_sources + for item in source + if isinstance(item, str) + } + return list(items) + + def __contains__(self, key: object) -> bool: + """The 'in' operator will return true or false depending on whether + 'key' is either an array stored in the datatree or a child node, or neither. + """ + return key in self.variables or key in self.children + + def __bool__(self) -> bool: + return bool(self.ds.data_vars) or bool(self.children) + + def __iter__(self) -> Iterator[Hashable]: + return itertools.chain(self.ds.data_vars, self.children) + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert a DataTree into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the DataTree or by " + "invoking the `to_array()` method." + ) + + def __repr__(self) -> str: + return formatting.datatree_repr(self) + + def __str__(self) -> str: + return formatting.datatree_repr(self) + + def _repr_html_(self): + """Make html representation of datatree object""" + if XR_OPTS["display_style"] == "text": + return f"
    {escape(repr(self))}
    " + return formatting_html.datatree_repr(self) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: Optional[dict[Any, int]] = None, + attrs: Optional[dict] = None, + indexes: Optional[dict[Any, Index]] = None, + encoding: Optional[dict] = None, + name: str | None = None, + parent: DataTree | None = None, + children: Optional[OrderedDict[str, DataTree]] = None, + close: Optional[Callable[[], None]] = None, + ) -> DataTree: + """Shortcut around __init__ for internal use when we want to skip costly validation.""" + + # data attributes + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + if children is None: + children = OrderedDict() + + obj: DataTree = object.__new__(cls) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + + # tree attributes + obj._name = name + obj._children = children + obj._parent = parent + + return obj + + def _replace( + self: DataTree, + variables: Optional[dict[Hashable, Variable]] = None, + coord_names: Optional[set[Hashable]] = None, + dims: Optional[dict[Any, int]] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: Optional[dict[Hashable, Index]] = None, + encoding: dict | None | Default = _default, + name: str | None | Default = _default, + parent: DataTree | None = _default, + children: Optional[OrderedDict[str, DataTree]] = None, + inplace: bool = False, + ) -> DataTree: + """ + Fastpath constructor for internal use. + + Returns an object with optionally replaced attributes. + + Explicitly passed arguments are *not* copied when placed on the new + datatree. It is up to the caller to ensure that they have the right type + and are not used elsewhere. + """ + # TODO Adding new children inplace using this method will cause bugs. + # You will end up with an inconsistency between the name of the child node and the key the child is stored under. + # Use ._set() instead for now + if inplace: + if variables is not None: + self._variables = variables + if coord_names is not None: + self._coord_names = coord_names + if dims is not None: + self._dims = dims + if attrs is not _default: + self._attrs = attrs + if indexes is not None: + self._indexes = indexes + if encoding is not _default: + self._encoding = encoding + if name is not _default: + self._name = name + if parent is not _default: + self._parent = parent + if children is not None: + self._children = children + obj = self + else: + if variables is None: + variables = self._variables.copy() + if coord_names is None: + coord_names = self._coord_names.copy() + if dims is None: + dims = self._dims.copy() + if attrs is _default: + attrs = copy.copy(self._attrs) + if indexes is None: + indexes = self._indexes.copy() + if encoding is _default: + encoding = copy.copy(self._encoding) + if name is _default: + name = self._name # no need to copy str objects or None + if parent is _default: + parent = copy.copy(self._parent) + if children is _default: + children = copy.copy(self._children) + obj = self._construct_direct( + variables, + coord_names, + dims, + attrs, + indexes, + encoding, + name, + parent, + children, + ) + return obj + + def copy( + self: DataTree, + deep: bool = False, + ) -> DataTree: + """ + Returns a copy of this subtree. + + Copies this node and all child nodes. + + If `deep=True`, a deep copy is made of each of the component variables. + Otherwise, a shallow copy of each of the component variable is made, so + that the underlying memory region of the new datatree is the same as in + the original datatree. + + Parameters + ---------- + deep : bool, default: False + Whether each component variable is loaded into memory and copied onto + the new object. Default is False. + + Returns + ------- + object : DataTree + New object with dimensions, attributes, coordinates, name, encoding, + and data of this node and all child nodes copied from original. + + See Also + -------- + xarray.Dataset.copy + pandas.DataFrame.copy + """ + return self._copy_subtree(deep=deep) + + def _copy_subtree( + self: DataTree, + deep: bool = False, + memo: dict[int, Any] | None = None, + ) -> DataTree: + """Copy entire subtree""" + new_tree = self._copy_node(deep=deep) + for node in self.descendants: + path = node.relative_to(self) + new_tree[path] = node._copy_node(deep=deep) + return new_tree + + def _copy_node( + self: DataTree, + deep: bool = False, + ) -> DataTree: + """Copy just one node of a tree""" + new_node: DataTree = DataTree() + new_node.name = self.name + new_node.ds = self.to_dataset().copy(deep=deep) + return new_node + + def __copy__(self: DataTree) -> DataTree: + return self._copy_subtree(deep=False) + + def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree: + return self._copy_subtree(deep=True, memo=memo) + + def get( + self: DataTree, key: str, default: Optional[DataTree | DataArray] = None + ) -> Optional[DataTree | DataArray]: + """ + Access child nodes, variables, or coordinates stored in this node. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. + + Parameters + ---------- + key : str + Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). + default : DataTree | DataArray, optional + A value to return if the specified key does not exist. Default return value is None. + """ + if key in self.children: + return self.children[key] + elif key in self.ds: + return self.ds[key] + else: + return default + + def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: + """ + Access child nodes, variables, or coordinates stored anywhere in this tree. + + Returned object will be either a DataTree or DataArray object depending on whether the key given points to a + child or variable. + + Parameters + ---------- + key : str + Name of variable / child within this node, or unix-like path to variable / child within another node. + + Returns + ------- + Union[DataTree, DataArray] + """ + + # Either: + if utils.is_dict_like(key): + # dict-like indexing + raise NotImplementedError("Should this index over whole tree?") + elif isinstance(key, str): + # TODO should possibly deal with hashables in general? + # path-like: a name of a node/variable, or path to a node/variable + path = NodePath(key) + return self._get_item(path) + elif utils.is_list_like(key): + # iterable of variable names + raise NotImplementedError( + "Selecting via tags is deprecated, and selecting multiple items should be " + "implemented via .subset" + ) + else: + raise ValueError(f"Invalid format for key: {key}") + + def _set(self, key: str, val: DataTree | CoercibleValue) -> None: + """ + Set the child node or variable with the specified key to value. + + Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree. + """ + if isinstance(val, DataTree): + # create and assign a shallow copy here so as not to alter original name of node in grafted tree + new_node = val.copy(deep=False) + new_node.name = key + new_node.parent = self + else: + if not isinstance(val, (DataArray, Variable)): + # accommodate other types that can be coerced into Variables + val = DataArray(val) + + self.update({key: val}) + + def __setitem__( + self, + key: str, + value: Any, + ) -> None: + """ + Add either a child node or an array to the tree, at any position. + + Data can be added anywhere, and new nodes will be created to cross the path to the new location if necessary. + + If there is already a node at the given location, then if value is a Node class or Dataset it will overwrite the + data already present at that node, and if value is a single array, it will be merged with it. + """ + # TODO xarray.Dataset accepts other possibilities, how do we exactly replicate all the behaviour? + if utils.is_dict_like(key): + raise NotImplementedError + elif isinstance(key, str): + # TODO should possibly deal with hashables in general? + # path-like: a name of a node/variable, or path to a node/variable + path = NodePath(key) + return self._set_item(path, value, new_nodes_along_path=True) + else: + raise ValueError("Invalid format for key") + + def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: + """ + Update this node's children and / or variables. + + Just like `dict.update` this is an in-place operation. + """ + # TODO separate by type + new_children = {} + new_variables = {} + for k, v in other.items(): + if isinstance(v, DataTree): + # avoid named node being stored under inconsistent key + new_child = v.copy() + new_child.name = k + new_children[k] = new_child + elif isinstance(v, (DataArray, Variable)): + # TODO this should also accommodate other types that can be coerced into Variables + new_variables[k] = v + else: + raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") + + vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + # TODO are there any subtleties with preserving order of children like this? + merged_children = OrderedDict({**self.children, **new_children}) + self._replace( + inplace=True, children=merged_children, **vars_merge_result._asdict() + ) + + def assign( + self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any + ) -> DataTree: + """ + Assign new data variables or child nodes to a DataTree, returning a new object + with all the original items in addition to the new ones. + + Parameters + ---------- + items : mapping of hashable to Any + Mapping from variable or child node names to the new values. If the new values + are callable, they are computed on the Dataset and assigned to new + data variables. If the values are not callable, (e.g. a DataTree, DataArray, + scalar, or array), they are simply assigned. + **items_kwargs + The keyword arguments form of ``variables``. + One of variables or variables_kwargs must be provided. + + Returns + ------- + dt : DataTree + A new DataTree with the new variables or children in addition to all the + existing items. + + Notes + ----- + Since ``kwargs`` is a dictionary, the order of your arguments may not + be preserved, and so the order of the new variables is not well-defined. + Assigning multiple items within the same ``assign`` is + possible, but you cannot reference other variables created within the + same ``assign`` call. + + See Also + -------- + xarray.Dataset.assign + pandas.DataFrame.assign + """ + items = either_dict_or_kwargs(items, items_kwargs, "assign") + dt = self.copy() + dt.update(items) + return dt + + def drop_nodes( + self: DataTree, names: str | Iterable[str], *, errors: ErrorOptions = "raise" + ) -> DataTree: + """ + Drop child nodes from this node. + + Parameters + ---------- + names : str or iterable of str + Name(s) of nodes to drop. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a KeyError if any of the node names + passed are not present as children of this node. If 'ignore', + any given names that are present are dropped and no error is raised. + + Returns + ------- + dropped : DataTree + A copy of the node with the specified children dropped. + """ + # the Iterable check is required for mypy + if isinstance(names, str) or not isinstance(names, Iterable): + names = {names} + else: + names = set(names) + + if errors == "raise": + extra = names - set(self.children) + if extra: + raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") + + children_to_keep = OrderedDict( + {name: child for name, child in self.children.items() if name not in names} + ) + return self._replace(children=children_to_keep) + + @classmethod + def from_dict( + cls, + d: MutableMapping[str, Dataset | DataArray | DataTree | None], + name: Optional[str] = None, + ) -> DataTree: + """ + Create a datatree from a dictionary of data objects, organised by paths into the tree. + + Parameters + ---------- + d : dict-like + A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects. + + Path names are to be given as unix-like path. If path names containing more than one part are given, new + tree nodes will be constructed as necessary. + + To assign data to the root node of the tree use "/" as the path. + name : Hashable, optional + Name for the root node of the tree. Default is None. + + Returns + ------- + DataTree + + Notes + ----- + If your dictionary is nested you will need to flatten it before using this method. + """ + + # First create the root node + root_data = d.pop("/", None) + obj = cls(name=name, data=root_data, parent=None, children=None) + + if d: + # Populate tree with children determined from data_objects mapping + for path, data in d.items(): + # Create and set new node + node_name = NodePath(path).name + if isinstance(data, cls): + new_node = data.copy() + new_node.orphan() + else: + new_node = cls(name=node_name, data=data) + obj._set_item( + path, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) + + return obj + + def to_dict(self) -> Dict[str, Dataset]: + """ + Create a dictionary mapping of absolute node paths to the data contained in those nodes. + + Returns + ------- + Dict[str, Dataset] + """ + return {node.path: node.to_dataset() for node in self.subtree} + + @property + def nbytes(self) -> int: + return sum(node.to_dataset().nbytes for node in self.subtree) + + def __len__(self) -> int: + return len(self.children) + len(self.data_vars) + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this DataTree node has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + DataTree.xindexes + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of xarray Index objects used for label based indexing.""" + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + + @property + def coords(self) -> DatasetCoordinates: + """Dictionary of xarray.DataArray objects corresponding to coordinate + variables + """ + return DatasetCoordinates(self.to_dataset()) + + @property + def data_vars(self) -> DataVariables: + """Dictionary of DataArray objects corresponding to data variables""" + return DataVariables(self.to_dataset()) + + def isomorphic( + self, + other: DataTree, + from_root: bool = False, + strict_names: bool = False, + ) -> bool: + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as ``tree1 + tree2``. + + By default this method does not check any part of the tree above the given node. + Therefore this method can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + strict_names : bool, optional, default is False + Whether or not to also check that every node in the tree has the same name as its counterpart in the other + tree. + + See Also + -------- + DataTree.equals + DataTree.identical + """ + try: + check_isomorphic( + self, + other, + require_names_equal=strict_names, + check_from_root=from_root, + ) + return True + except (TypeError, TreeIsomorphismError): + return False + + def equals(self, other: DataTree, from_root: bool = True) -> bool: + """ + Two DataTrees are equal if they have isomorphic node structures, with matching node names, + and if they have matching variables and coordinates, all of which are equal. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + + See Also + -------- + Dataset.equals + DataTree.isomorphic + DataTree.identical + """ + if not self.isomorphic(other, from_root=from_root, strict_names=True): + return False + + return all( + [ + node.ds.equals(other_node.ds) + for node, other_node in zip(self.subtree, other.subtree) + ] + ) + + def identical(self, other: DataTree, from_root=True) -> bool: + """ + Like equals, but will also check all dataset attributes and the attributes on + all variables and coordinates. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + other : DataTree + The other tree object to compare to. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the two trees before checking for isomorphism. + If neither tree has a parent then this has no effect. + + See Also + -------- + Dataset.identical + DataTree.isomorphic + DataTree.equals + """ + if not self.isomorphic(other, from_root=from_root, strict_names=True): + return False + + return all( + node.ds.identical(other_node.ds) + for node, other_node in zip(self.subtree, other.subtree) + ) + + def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: + """ + Filter nodes according to a specified condition. + + Returns a new tree containing only the nodes in the original tree for which `fitlerfunc(node)` is True. + Will also contain empty nodes at intermediate positions if required to support leaves. + + Parameters + ---------- + filterfunc: function + A function which accepts only one DataTree - the node on which filterfunc will be called. + + Returns + ------- + DataTree + + See Also + -------- + match + pipe + map_over_subtree + """ + filtered_nodes = { + node.path: node.ds for node in self.subtree if filterfunc(node) + } + return DataTree.from_dict(filtered_nodes, name=self.root.name) + + def match(self, pattern: str) -> DataTree: + """ + Return nodes with paths matching pattern. + + Uses unix glob-like syntax for pattern-matching. + + Parameters + ---------- + pattern: str + A pattern to match each node path against. + + Returns + ------- + DataTree + + See Also + -------- + filter + pipe + map_over_subtree + + Examples + -------- + >>> dt = DataTree.from_dict( + ... { + ... "/a/A": None, + ... "/a/B": None, + ... "/b/A": None, + ... "/b/B": None, + ... } + ... ) + >>> dt.match("*/B") + DataTree('None', parent=None) + ├── DataTree('a') + │ └── DataTree('B') + └── DataTree('b') + └── DataTree('B') + """ + matching_nodes = { + node.path: node.ds + for node in self.subtree + if NodePath(node.path).match(pattern) + } + return DataTree.from_dict(matching_nodes, name=self.root.name) + + def map_over_subtree( + self, + func: Callable, + *args: Iterable[Any], + **kwargs: Any, + ) -> DataTree | Tuple[DataTree]: + """ + Apply a function to every dataset in this subtree, returning a new tree which stores the results. + + The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the + descendant nodes. The returned tree will have the same structure as the original subtree. + + func needs to return a Dataset in order to rebuild the subtree. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(node.ds, *args, **kwargs) -> Dataset`. + + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + subtrees : DataTree, Tuple of DataTrees + One or more subtrees containing results from applying ``func`` to the data at each node. + """ + # TODO this signature means that func has no way to know which node it is being called upon - change? + + # TODO fix this typing error + return map_over_subtree(func)(self, *args, **kwargs) # type: ignore[operator] + + def map_over_subtree_inplace( + self, + func: Callable, + *args: Iterable[Any], + **kwargs: Any, + ) -> None: + """ + Apply a function to every dataset in this subtree, updating data in place. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(node.ds, *args, **kwargs) -> Dataset`. + + Function will not be applied to any nodes without datasets, + *args : tuple, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + """ + + # TODO if func fails on some node then the previous nodes will still have been updated... + + for node in self.subtree: + if node.has_data: + node.ds = func(node.ds, *args, **kwargs) + + def pipe( + self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any + ) -> Any: + """Apply ``func(self, *args, **kwargs)`` + + This method replicates the pandas method of the same name. + + Parameters + ---------- + func : callable + function to apply to this xarray object (Dataset/DataArray). + ``args``, and ``kwargs`` are passed into ``func``. + Alternatively a ``(callable, data_keyword)`` tuple where + ``data_keyword`` is a string indicating the keyword of + ``callable`` that expects the xarray object. + *args + positional arguments passed into ``func``. + **kwargs + a dictionary of keyword arguments passed into ``func``. + + Returns + ------- + object : Any + the return type of ``func``. + + Notes + ----- + Use ``.pipe`` when chaining together functions that expect + xarray or pandas objects, e.g., instead of writing + + .. code:: python + + f(g(h(dt), arg1=a), arg2=b, arg3=c) + + You can write + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) + + If you have a function that takes the data as (say) the second + argument, pass a tuple indicating which keyword expects the + data. For example, suppose ``f`` takes its data as ``arg2``: + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) + + """ + if isinstance(func, tuple): + func, target = func + if target in kwargs: + raise ValueError( + f"{target} is both the pipe target and a keyword argument" + ) + kwargs[target] = self + else: + args = (self,) + args + return func(*args, **kwargs) + + def render(self): + """Print tree structure, including any data stored at each node.""" + for pre, fill, node in RenderTree(self): + print(f"{pre}DataTree('{self.name}')") + for ds_line in repr(node.ds)[1:]: + print(f"{fill}{ds_line}") + + def merge(self, datatree: DataTree) -> DataTree: + """Merge all the leaves of a second DataTree into this one.""" + raise NotImplementedError + + def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree: + """Merge a set of child nodes into a single new node.""" + raise NotImplementedError + + # TODO some kind of .collapse() or .flatten() method to merge a subtree + + def as_array(self) -> DataArray: + return self.ds.as_dataarray() + + @property + def groups(self): + """Return all netCDF4 groups in the tree, given as a tuple of path-like strings.""" + return tuple(node.path for node in self.subtree) + + def to_netcdf( + self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs + ): + """ + Write datatree contents to a netCDF file. + + Parameters + ---------- + filepath : str or Path + Path to which to save this datatree. + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. Only appies to the root group. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available + options. + unlimited_dims : dict, optional + Mapping of unlimited dimensions per group that that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + kwargs : + Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + """ + from .io import _datatree_to_netcdf + + _datatree_to_netcdf( + self, + filepath, + mode=mode, + encoding=encoding, + unlimited_dims=unlimited_dims, + **kwargs, + ) + + def to_zarr( + self, + store, + mode: str = "w-", + encoding=None, + consolidated: bool = True, + **kwargs, + ): + """ + Write datatree contents to a Zarr store. + + Parameters + ---------- + store : MutableMapping, str or Path, optional + Store or path to directory in file system + mode : {{"w", "w-", "a", "r+", None}, default: "w-" + Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); + “a” means override existing variables (create if does not exist); “r+” means modify existing + array values only (raise an error if any metadata or shapes would change). The default mode + is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1}, ...}, ...}``. + See ``xarray.Dataset.to_zarr`` for available options. + consolidated : bool + If True, apply zarr's `consolidate_metadata` function to the store + after writing metadata for all groups. + kwargs : + Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` + """ + from .io import _datatree_to_zarr + + _datatree_to_zarr( + self, + store, + mode=mode, + encoding=encoding, + consolidated=consolidated, + **kwargs, + ) + + def plot(self): + raise NotImplementedError diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py new file mode 100644 index 00000000000..f6f4e985a79 --- /dev/null +++ b/xarray/datatree_/datatree/extensions.py @@ -0,0 +1,20 @@ +from xarray.core.extensions import _register_accessor + +from .datatree import DataTree + + +def register_datatree_accessor(name): + """Register a custom accessor on DataTree objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + See Also + -------- + xarray.register_dataarray_accessor + xarray.register_dataset_accessor + """ + return _register_accessor(name, DataTree) diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py new file mode 100644 index 00000000000..deba57eb09d --- /dev/null +++ b/xarray/datatree_/datatree/formatting.py @@ -0,0 +1,91 @@ +from typing import TYPE_CHECKING + +from xarray.core.formatting import _compat_to_str, diff_dataset_repr + +from .mapping import diff_treestructure +from .render import RenderTree + +if TYPE_CHECKING: + from .datatree import DataTree + + +def diff_nodewise_summary(a, b, compat): + """Iterates over all corresponding nodes, recording differences between data at each location.""" + + compat_str = _compat_to_str(compat) + + summary = [] + for node_a, node_b in zip(a.subtree, b.subtree): + a_ds, b_ds = node_a.ds, node_b.ds + + if not a_ds._all_compat(b_ds, compat): + dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str) + data_diff = "\n".join(dataset_diff.split("\n", 1)[1:]) + + nodediff = ( + f"\nData in nodes at position '{node_a.path}' do not match:" + f"{data_diff}" + ) + summary.append(nodediff) + + return "\n".join(summary) + + +def diff_tree_repr(a, b, compat): + summary = [ + f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}" + ] + + # TODO check root parents? + + strict_names = True if compat in ["equals", "identical"] else False + treestructure_diff = diff_treestructure(a, b, strict_names) + + # If the trees structures are different there is no point comparing each node + # TODO we could show any differences in nodes up to the first place that structure differs? + if treestructure_diff or compat == "isomorphic": + summary.append("\n" + treestructure_diff) + else: + nodewise_diff = diff_nodewise_summary(a, b, compat) + summary.append("\n" + nodewise_diff) + + return "\n".join(summary) + + +def datatree_repr(dt): + """A printable representation of the structure of this entire tree.""" + renderer = RenderTree(dt) + + lines = [] + for pre, fill, node in renderer: + node_repr = _single_node_repr(node) + + node_line = f"{pre}{node_repr.splitlines()[0]}" + lines.append(node_line) + + if node.has_data or node.has_attrs: + ds_repr = node_repr.splitlines()[2:] + for line in ds_repr: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + + # Tack on info about whether or not root node has a parent at the start + first_line = lines[0] + parent = f'"{dt.parent.name}"' if dt.parent is not None else "None" + first_line_with_parent = first_line[:-1] + f", parent={parent})" + lines[0] = first_line_with_parent + + return "\n".join(lines) + + +def _single_node_repr(node: "DataTree") -> str: + """Information about this node, not including its relationships to other nodes.""" + node_info = f"DataTree('{node.name}')" + + if node.has_data or node.has_attrs: + ds_info = "\n" + repr(node.ds) + else: + ds_info = "" + return node_info + ds_info diff --git a/xarray/datatree_/datatree/formatting_html.py b/xarray/datatree_/datatree/formatting_html.py new file mode 100644 index 00000000000..547b567a396 --- /dev/null +++ b/xarray/datatree_/datatree/formatting_html.py @@ -0,0 +1,135 @@ +from functools import partial +from html import escape +from typing import Any, Mapping + +from xarray.core.formatting_html import ( + _mapping_section, + _obj_repr, + attr_section, + coord_section, + datavar_section, + dim_section, +) + + +def summarize_children(children: Mapping[str, Any]) -> str: + N_CHILDREN = len(children) - 1 + + # Get result from node_repr and wrap it + lines_callback = lambda n, c, end: _wrap_repr(node_repr(n, c), end=end) + + children_html = "".join( + lines_callback(n, c, end=False) # Long lines + if i < N_CHILDREN + else lines_callback(n, c, end=True) # Short lines + for i, (n, c) in enumerate(children.items()) + ) + + return "".join( + [ + "
    ", + children_html, + "
    ", + ] + ) + + +children_section = partial( + _mapping_section, + name="Groups", + details_func=summarize_children, + max_items_collapse=1, + expand_option_name="display_expand_groups", +) + + +def node_repr(group_title: str, dt: Any) -> str: + header_components = [f"
    {escape(group_title)}
    "] + + ds = dt.ds + + sections = [ + children_section(dt.children), + dim_section(ds), + coord_section(ds.coords), + datavar_section(ds.data_vars), + attr_section(ds.attrs), + ] + + return _obj_repr(ds, header_components, sections) + + +def _wrap_repr(r: str, end: bool = False) -> str: + """ + Wrap HTML representation with a tee to the left of it. + + Enclosing HTML tag is a
    with :code:`display: inline-grid` style. + + Turns: + [ title ] + | details | + |_____________| + + into (A): + |─ [ title ] + | | details | + | |_____________| + + or (B): + └─ [ title ] + | details | + |_____________| + + Parameters + ---------- + r: str + HTML representation to wrap. + end: bool + Specify if the line on the left should continue or end. + + Default is True. + + Returns + ------- + str + Wrapped HTML representation. + + Tee color is set to the variable :code:`--xr-border-color`. + """ + # height of line + end = bool(end) + height = "100%" if end is False else "1.2em" + return "".join( + [ + "
    ", + "
    ", + "
    ", + "
    ", + "
    ", + "
    ", + "
      ", + r, + "
    " "
    ", + "
    ", + ] + ) + + +def datatree_repr(dt: Any) -> str: + obj_type = f"datatree.{type(dt).__name__}" + return node_repr(obj_type, dt) diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py new file mode 100644 index 00000000000..d3d533ee71e --- /dev/null +++ b/xarray/datatree_/datatree/io.py @@ -0,0 +1,134 @@ +from xarray.datatree_.datatree import DataTree + + +def _get_nc_dataset_class(engine): + if engine == "netcdf4": + from netCDF4 import Dataset # type: ignore + elif engine == "h5netcdf": + from h5netcdf.legacyapi import Dataset # type: ignore + elif engine is None: + try: + from netCDF4 import Dataset + except ImportError: + from h5netcdf.legacyapi import Dataset # type: ignore + else: + raise ValueError(f"unsupported engine: {engine}") + return Dataset + + +def _create_empty_netcdf_group(filename, group, mode, engine): + ncDataset = _get_nc_dataset_class(engine) + + with ncDataset(filename, mode=mode) as rootgrp: + rootgrp.createGroup(group) + + +def _datatree_to_netcdf( + dt: DataTree, + filepath, + mode: str = "w", + encoding=None, + unlimited_dims=None, + **kwargs, +): + if kwargs.get("format", None) not in [None, "NETCDF4"]: + raise ValueError("to_netcdf only supports the NETCDF4 format") + + engine = kwargs.get("engine", None) + if engine not in [None, "netcdf4", "h5netcdf"]: + raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") + + if kwargs.get("group", None) is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not kwargs.get("compute", True): + raise NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + # In the future, we may want to expand this check to insure all the provided encoding + # options are valid. For now, this simply checks that all provided encoding keys are + # groups in the datatree. + if set(encoding) - set(dt.groups): + raise ValueError( + f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" + ) + + if unlimited_dims is None: + unlimited_dims = {} + + for node in dt.subtree: + ds = node.ds + group_path = node.path + if ds is None: + _create_empty_netcdf_group(filepath, group_path, mode, engine) + else: + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + unlimited_dims=unlimited_dims.get(node.path), + **kwargs, + ) + mode = "r+" + + +def _create_empty_zarr_group(store, group, mode): + import zarr # type: ignore + + root = zarr.open_group(store, mode=mode) + root.create_group(group, overwrite=True) + + +def _datatree_to_zarr( + dt: DataTree, + store, + mode: str = "w-", + encoding=None, + consolidated: bool = True, + **kwargs, +): + from zarr.convenience import consolidate_metadata # type: ignore + + if kwargs.get("group", None) is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + + if not kwargs.get("compute", True): + raise NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + # In the future, we may want to expand this check to insure all the provided encoding + # options are valid. For now, this simply checks that all provided encoding keys are + # groups in the datatree. + if set(encoding) - set(dt.groups): + raise ValueError( + f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" + ) + + for node in dt.subtree: + ds = node.ds + group_path = node.path + if ds is None: + _create_empty_zarr_group(store, group_path, mode) + else: + ds.to_zarr( + store, + group=group_path, + mode=mode, + encoding=encoding.get(node.path), + consolidated=False, + **kwargs, + ) + if "w" in mode: + mode = "a" + + if consolidated: + consolidate_metadata(store) diff --git a/xarray/datatree_/datatree/iterators.py b/xarray/datatree_/datatree/iterators.py new file mode 100644 index 00000000000..68e75c4f612 --- /dev/null +++ b/xarray/datatree_/datatree/iterators.py @@ -0,0 +1,116 @@ +from abc import abstractmethod +from collections import abc +from typing import Callable, Iterator, List, Optional + +from xarray.core.treenode import Tree + +"""These iterators are copied from anytree.iterators, with minor modifications.""" + + +class AbstractIter(abc.Iterator): + def __init__( + self, + node: Tree, + filter_: Optional[Callable] = None, + stop: Optional[Callable] = None, + maxlevel: Optional[int] = None, + ): + """ + Iterate over tree starting at `node`. + Base class for all iterators. + Keyword Args: + filter_: function called with every `node` as argument, `node` is returned if `True`. + stop: stop iteration at `node` if `stop` function returns `True` for `node`. + maxlevel (int): maximum descending in the node hierarchy. + """ + self.node = node + self.filter_ = filter_ + self.stop = stop + self.maxlevel = maxlevel + self.__iter = None + + def __init(self): + node = self.node + maxlevel = self.maxlevel + filter_ = self.filter_ or AbstractIter.__default_filter + stop = self.stop or AbstractIter.__default_stop + children = ( + [] + if AbstractIter._abort_at_level(1, maxlevel) + else AbstractIter._get_children([node], stop) + ) + return self._iter(children, filter_, stop, maxlevel) + + @staticmethod + def __default_filter(node): + return True + + @staticmethod + def __default_stop(node): + return False + + def __iter__(self) -> Iterator[Tree]: + return self + + def __next__(self) -> Iterator[Tree]: + if self.__iter is None: + self.__iter = self.__init() + item = next(self.__iter) # type: ignore[call-overload] + return item + + @staticmethod + @abstractmethod + def _iter(children: List[Tree], filter_, stop, maxlevel) -> Iterator[Tree]: + ... + + @staticmethod + def _abort_at_level(level, maxlevel): + return maxlevel is not None and level > maxlevel + + @staticmethod + def _get_children(children: List[Tree], stop) -> List[Tree]: + return [child for child in children if not stop(child)] + + +class PreOrderIter(AbstractIter): + """ + Iterate over tree applying pre-order strategy starting at `node`. + Start at root and go-down until reaching a leaf node. + Step upwards then, and search for the next leafs. + """ + + @staticmethod + def _iter(children, filter_, stop, maxlevel): + for child_ in children: + if stop(child_): + continue + if filter_(child_): + yield child_ + if not AbstractIter._abort_at_level(2, maxlevel): + descendantmaxlevel = maxlevel - 1 if maxlevel else None + for descendant_ in PreOrderIter._iter( + list(child_.children.values()), filter_, stop, descendantmaxlevel + ): + yield descendant_ + + +class LevelOrderIter(AbstractIter): + """ + Iterate over tree applying level-order strategy starting at `node`. + """ + + @staticmethod + def _iter(children, filter_, stop, maxlevel): + level = 1 + while children: + next_children = [] + for child in children: + if filter_(child): + yield child + next_children += AbstractIter._get_children( + list(child.children.values()), stop + ) + children = next_children + level += 1 + if AbstractIter._abort_at_level(level, maxlevel): + break diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/datatree_/datatree/mapping.py new file mode 100644 index 00000000000..355149060a9 --- /dev/null +++ b/xarray/datatree_/datatree/mapping.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import functools +import sys +from itertools import repeat +from textwrap import dedent +from typing import TYPE_CHECKING, Callable, Tuple + +from xarray import DataArray, Dataset + +from .iterators import LevelOrderIter +from xarray.core.treenode import NodePath, TreeNode + +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + + +class TreeIsomorphismError(ValueError): + """Error raised if two tree objects do not share the same node structure.""" + + pass + + +def check_isomorphic( + a: DataTree, + b: DataTree, + require_names_equal: bool = False, + check_from_root: bool = True, +): + """ + Check that two trees have the same structure, raising an error if not. + + Does not compare the actual data in the nodes. + + By default this function only checks that subtrees are isomorphic, not the entire tree above (if it exists). + Can instead optionally check the entire trees starting from the root, which will ensure all + + Can optionally check if corresponding nodes should have the same name. + + Parameters + ---------- + a : DataTree + b : DataTree + require_names_equal : Bool + Whether or not to also check that each node has the same name as its counterpart. + check_from_root : Bool + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + Raises + ------ + TypeError + If either a or b are not tree objects. + TreeIsomorphismError + If a and b are tree objects, but are not isomorphic to one another. + Also optionally raised if their structure is isomorphic, but the names of any two + respective nodes are not equal. + """ + + if not isinstance(a, TreeNode): + raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}") + if not isinstance(b, TreeNode): + raise TypeError(f"Argument `b` is not a tree, it is of type {type(b)}") + + if check_from_root: + a = a.root + b = b.root + + diff = diff_treestructure(a, b, require_names_equal=require_names_equal) + + if diff: + raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) + + +def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str: + """ + Return a summary of why two trees are not isomorphic. + If they are isomorphic return an empty string. + """ + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree + # (which it is so long as children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): + path_a, path_b = node_a.path, node_b.path + + if require_names_equal: + if node_a.name != node_b.name: + diff = dedent( + f"""\ + Node '{path_a}' in the left object has name '{node_a.name}' + Node '{path_b}' in the right object has name '{node_b.name}'""" + ) + return diff + + if len(node_a.children) != len(node_b.children): + diff = dedent( + f"""\ + Number of children on node '{path_a}' of the left object: {len(node_a.children)} + Number of children on node '{path_b}' of the right object: {len(node_b.children)}""" + ) + return diff + + return "" + + +def map_over_subtree(func: Callable) -> Callable: + """ + Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. + + Applies a function to every dataset in one or more subtrees, returning new trees which store the results. + + The function will be applied to any data-containing dataset stored in any of the nodes in the trees. The returned + trees will have the same structure as the supplied trees. + + `func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after + mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any + returned value that is one of these types will be stacked into a separate tree before returning all of them. + + The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named + similarly, but all the output trees will have nodes named in the same way as the first tree passed. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + + `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + + (i.e. func must accept at least one Dataset and return at least one Dataset.) + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets + via .ds . + **kwargs : Any + Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets + via .ds . + + Returns + ------- + mapped : callable + Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at + each node. + + See also + -------- + DataTree.map_over_subtree + DataTree.map_over_subtree_inplace + DataTree.subtree + """ + + # TODO examples in the docstring + + # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? + + @functools.wraps(func) + def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: + """Internal function which maps func over every node in tree, returning a tree of the results.""" + from .datatree import DataTree + + all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ + a for a in kwargs.values() if isinstance(a, DataTree) + ] + + if len(all_tree_inputs) > 0: + first_tree, *other_trees = all_tree_inputs + else: + raise TypeError("Must pass at least one tree object") + + for other_tree in other_trees: + # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic + check_isomorphic( + first_tree, other_tree, require_names_equal=False, check_from_root=False + ) + + # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees + # We don't know which arguments are DataTrees so we zip all arguments together as iterables + # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return + out_data_objects = {} + args_as_tree_length_iterables = [ + a.subtree if isinstance(a, DataTree) else repeat(a) for a in args + ] + n_args = len(args_as_tree_length_iterables) + kwargs_as_tree_length_iterables = { + k: v.subtree if isinstance(v, DataTree) else repeat(v) + for k, v in kwargs.items() + } + for node_of_first_tree, *all_node_args in zip( + first_tree.subtree, + *args_as_tree_length_iterables, + *list(kwargs_as_tree_length_iterables.values()), + ): + node_args_as_datasetviews = [ + a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args] + ] + node_kwargs_as_datasetviews = dict( + zip( + [k for k in kwargs_as_tree_length_iterables.keys()], + [ + v.ds if isinstance(v, DataTree) else v + for v in all_node_args[n_args:] + ], + ) + ) + func_with_error_context = _handle_errors_with_path_context( + node_of_first_tree.path + )(func) + + if node_of_first_tree.has_data: + # call func on the data in this particular set of corresponding nodes + results = func_with_error_context( + *node_args_as_datasetviews, **node_kwargs_as_datasetviews + ) + elif node_of_first_tree.has_attrs: + # propagate attrs + results = node_of_first_tree.ds + else: + # nothing to propagate so use fastpath to create empty node in new tree + results = None + + # TODO implement mapping over multiple trees in-place using if conditions from here on? + out_data_objects[node_of_first_tree.path] = results + + # Find out how many return values we received + num_return_values = _check_all_return_values(out_data_objects) + + # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees + original_root_path = first_tree.path + result_trees = [] + for i in range(num_return_values): + out_tree_contents = {} + for n in first_tree.subtree: + p = n.path + if p in out_data_objects.keys(): + if isinstance(out_data_objects[p], tuple): + output_node_data = out_data_objects[p][i] + else: + output_node_data = out_data_objects[p] + else: + output_node_data = None + + # Discard parentage so that new trees don't include parents of input nodes + relative_path = str(NodePath(p).relative_to(original_root_path)) + relative_path = "/" if relative_path == "." else relative_path + out_tree_contents[relative_path] = output_node_data + + new_tree = DataTree.from_dict( + out_tree_contents, + name=first_tree.name, + ) + result_trees.append(new_tree) + + # If only one result then don't wrap it in a tuple + if len(result_trees) == 1: + return result_trees[0] + else: + return tuple(result_trees) + + return _map_over_subtree + + +def _handle_errors_with_path_context(path): + """Wraps given function so that if it fails it also raises path to node on which it failed.""" + + def decorator(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if sys.version_info >= (3, 11): + # Add the context information to the error message + e.add_note( + f"Raised whilst mapping function over node with path {path}" + ) + raise + + return wrapper + + return decorator + + +def add_note(err: BaseException, msg: str) -> None: + # TODO: remove once python 3.10 can be dropped + if sys.version_info < (3, 11): + err.__notes__ = getattr(err, "__notes__", []) + [msg] # type: ignore[attr-defined] + else: + err.add_note(msg) + + +def _check_single_set_return_values(path_to_node, obj): + """Check types returned from single evaluation of func, and return number of return values received from func.""" + if isinstance(obj, (Dataset, DataArray)): + return 1 + elif isinstance(obj, tuple): + for r in obj: + if not isinstance(r, (Dataset, DataArray)): + raise TypeError( + f"One of the results of calling func on datasets on the nodes at position {path_to_node} is " + f"of type {type(r)}, not Dataset or DataArray." + ) + return len(obj) + else: + raise TypeError( + f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not " + f"Dataset or DataArray, nor a tuple of such types." + ) + + +def _check_all_return_values(returned_objects): + """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" + + if all(r is None for r in returned_objects.values()): + raise TypeError( + "Called supplied function on all nodes but found a return value of None for" + "all of them." + ) + + result_data_objects = [ + (path_to_node, r) + for path_to_node, r in returned_objects.items() + if r is not None + ] + + if len(result_data_objects) == 1: + # Only one node in the tree: no need to check consistency of results between nodes + path_to_node, result = result_data_objects[0] + num_return_values = _check_single_set_return_values(path_to_node, result) + else: + prev_path, _ = result_data_objects[0] + prev_num_return_values, num_return_values = None, None + for path_to_node, obj in result_data_objects[1:]: + num_return_values = _check_single_set_return_values(path_to_node, obj) + + if ( + num_return_values != prev_num_return_values + and prev_num_return_values is not None + ): + raise TypeError( + f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return " + f"values, whereas calling func on the nodes at position {prev_path} instead returns " + f"{prev_num_return_values} separate return values." + ) + + prev_path, prev_num_return_values = path_to_node, num_return_values + + return num_return_values diff --git a/xarray/datatree_/datatree/ops.py b/xarray/datatree_/datatree/ops.py new file mode 100644 index 00000000000..d6ac4f83e7c --- /dev/null +++ b/xarray/datatree_/datatree/ops.py @@ -0,0 +1,262 @@ +import textwrap + +from xarray import Dataset + +from .mapping import map_over_subtree + +""" +Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. + +Structured to mirror the way xarray defines Dataset's various operations internally, but does not actually import from +xarray's internals directly, only the public-facing xarray.Dataset class. +""" + + +_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill( + "This method was copied from xarray.Dataset, but has been altered to " + "call the method on the Datasets stored in every node of the subtree. " + "See the `map_over_subtree` function for more details.", + width=117, +) + +# TODO equals, broadcast_equals etc. +# TODO do dask-related private methods need to be exposed? +_DATASET_DASK_METHODS_TO_MAP = [ + "load", + "compute", + "persist", + "unify_chunks", + "chunk", + "map_blocks", +] +_DATASET_METHODS_TO_MAP = [ + "as_numpy", + "set_coords", + "reset_coords", + "info", + "isel", + "sel", + "head", + "tail", + "thin", + "broadcast_like", + "reindex_like", + "reindex", + "interp", + "interp_like", + "rename", + "rename_dims", + "rename_vars", + "swap_dims", + "expand_dims", + "set_index", + "reset_index", + "reorder_levels", + "stack", + "unstack", + "merge", + "drop_vars", + "drop_sel", + "drop_isel", + "drop_dims", + "transpose", + "dropna", + "fillna", + "interpolate_na", + "ffill", + "bfill", + "combine_first", + "reduce", + "map", + "diff", + "shift", + "roll", + "sortby", + "quantile", + "rank", + "differentiate", + "integrate", + "cumulative_integrate", + "filter_by_attrs", + "polyfit", + "pad", + "idxmin", + "idxmax", + "argmin", + "argmax", + "query", + "curvefit", +] +_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + +_DATA_WITH_COORDS_METHODS_TO_MAP = [ + "squeeze", + "clip", + "assign_coords", + "where", + "close", + "isnull", + "notnull", + "isin", + "astype", +] + +REDUCE_METHODS = ["all", "any"] +NAN_REDUCE_METHODS = [ + "max", + "min", + "mean", + "prod", + "sum", + "std", + "var", + "median", +] +NAN_CUM_METHODS = ["cumsum", "cumprod"] +_TYPED_DATASET_OPS_TO_MAP = [ + "__add__", + "__sub__", + "__mul__", + "__pow__", + "__truediv__", + "__floordiv__", + "__mod__", + "__and__", + "__xor__", + "__or__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__eq__", + "__ne__", + "__radd__", + "__rsub__", + "__rmul__", + "__rpow__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rand__", + "__rxor__", + "__ror__", + "__iadd__", + "__isub__", + "__imul__", + "__ipow__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__iand__", + "__ixor__", + "__ior__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + "round", + "argsort", + "conj", + "conjugate", +] +# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere... +_ARITHMETIC_METHODS_TO_MAP = ( + REDUCE_METHODS + + NAN_REDUCE_METHODS + + NAN_CUM_METHODS + + _TYPED_DATASET_OPS_TO_MAP + + ["__array_ufunc__"] +) + + +def _wrap_then_attach_to_cls( + target_cls_dict, source_cls, methods_to_set, wrap_func=None +): + """ + Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree) + + Result is like having written this in the classes' definition: + ``` + @wrap_func + def method_name(self, *args, **kwargs): + return self.method(*args, **kwargs) + ``` + + Every method attached here needs to have a return value of Dataset or DataArray in order to construct a new tree. + + Parameters + ---------- + target_cls_dict : MappingProxy + The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also + be accessed by calling vars() from within that classes' definition.) This will be updated by this function. + source_cls : class + Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object + (or instance), not just the __dict__. + methods_to_set : Iterable[Tuple[str, callable]] + The method names and definitions supplied as a list of (method_name_string, method) pairs. + This format matches the output of inspect.getmembers(). + wrap_func : callable, optional + Function to decorate each method with. Must have the same return type as the method. + """ + for method_name in methods_to_set: + orig_method = getattr(source_cls, method_name) + wrapped_method = ( + wrap_func(orig_method) if wrap_func is not None else orig_method + ) + target_cls_dict[method_name] = wrapped_method + + if wrap_func is map_over_subtree: + # Add a paragraph to the method's docstring explaining how it's been mapped + orig_method_docstring = orig_method.__doc__ + # if orig_method_docstring is not None: + # if "\n" in orig_method_docstring: + # new_method_docstring = orig_method_docstring.replace( + # "\n", _MAPPED_DOCSTRING_ADDENDUM, 1 + # ) + # else: + # new_method_docstring = ( + # orig_method_docstring + f"\n\n{_MAPPED_DOCSTRING_ADDENDUM}" + # ) + setattr(target_cls_dict[method_name], "__doc__", orig_method_docstring) + + +class MappedDatasetMethodsMixin: + """ + Mixin to add methods defined specifically on the Dataset class such as .query(), but wrapped to map over all nodes + in the subtree. + """ + + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_ALL_DATASET_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) + + +class MappedDataWithCoords: + """ + Mixin to add coordinate-aware Dataset methods such as .where(), but wrapped to map over all nodes in the subtree. + """ + + # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) + + +class DataTreeArithmeticMixin: + """ + Mixin to add Dataset arithmetic operations such as __add__, reduction methods such as .mean(), and enable numpy + ufuncs such as np.sin(), but wrapped to map over all nodes in the subtree. + """ + + _wrap_then_attach_to_cls( + target_cls_dict=vars(), + source_cls=Dataset, + methods_to_set=_ARITHMETIC_METHODS_TO_MAP, + wrap_func=map_over_subtree, + ) diff --git a/xarray/datatree_/datatree/py.typed b/xarray/datatree_/datatree/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py new file mode 100644 index 00000000000..aef327c5c47 --- /dev/null +++ b/xarray/datatree_/datatree/render.py @@ -0,0 +1,271 @@ +""" +String Tree Rendering. Copied from anytree. +""" + +import collections +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .datatree import DataTree + +Row = collections.namedtuple("Row", ("pre", "fill", "node")) + + +class AbstractStyle(object): + def __init__(self, vertical, cont, end): + """ + Tree Render Style. + Args: + vertical: Sign for vertical line. + cont: Chars for a continued branch. + end: Chars for the last branch. + """ + super(AbstractStyle, self).__init__() + self.vertical = vertical + self.cont = cont + self.end = end + assert ( + len(cont) == len(vertical) == len(end) + ), f"'{vertical}', '{cont}' and '{end}' need to have equal length" + + @property + def empty(self): + """Empty string as placeholder.""" + return " " * len(self.end) + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +class ContStyle(AbstractStyle): + def __init__(self): + """ + Continued style, without gaps. + + >>> from anytree import Node, RenderTree + >>> root = Node("root") + >>> s0 = Node("sub0", parent=root) + >>> s0b = Node("sub0B", parent=s0) + >>> s0a = Node("sub0A", parent=s0) + >>> s1 = Node("sub1", parent=root) + >>> print(RenderTree(root, style=ContStyle())) + + Node('/root') + ├── Node('/root/sub0') + │ ├── Node('/root/sub0/sub0B') + │ └── Node('/root/sub0/sub0A') + └── Node('/root/sub1') + """ + super(ContStyle, self).__init__( + "\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 " + ) + + +class RenderTree(object): + def __init__( + self, node: "DataTree", style=ContStyle(), childiter=list, maxlevel=None + ): + """ + Render tree starting at `node`. + Keyword Args: + style (AbstractStyle): Render Style. + childiter: Child iterator. + maxlevel: Limit rendering to this depth. + :any:`RenderTree` is an iterator, returning a tuple with 3 items: + `pre` + tree prefix. + `fill` + filling for multiline entries. + `node` + :any:`NodeMixin` object. + It is up to the user to assemble these parts to a whole. + >>> from anytree import Node, RenderTree + >>> root = Node("root", lines=["c0fe", "c0de"]) + >>> s0 = Node("sub0", parent=root, lines=["ha", "ba"]) + >>> s0b = Node("sub0B", parent=s0, lines=["1", "2", "3"]) + >>> s0a = Node("sub0A", parent=s0, lines=["a", "b"]) + >>> s1 = Node("sub1", parent=root, lines=["Z"]) + Simple one line: + >>> for pre, _, node in RenderTree(root): + ... print("%s%s" % (pre, node.name)) + ... + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + Multiline: + >>> for pre, fill, node in RenderTree(root): + ... print("%s%s" % (pre, node.lines[0])) + ... for line in node.lines[1:]: + ... print("%s%s" % (fill, line)) + ... + c0fe + c0de + ├── ha + │ ba + │ ├── 1 + │ │ 2 + │ │ 3 + │ └── a + │ b + └── Z + `maxlevel` limits the depth of the tree: + >>> print(RenderTree(root, maxlevel=2)) + Node('/root', lines=['c0fe', 'c0de']) + ├── Node('/root/sub0', lines=['ha', 'ba']) + └── Node('/root/sub1', lines=['Z']) + The `childiter` is responsible for iterating over child nodes at the + same level. An reversed order can be achived by using `reversed`. + >>> for row in RenderTree(root, childiter=reversed): + ... print("%s%s" % (row.pre, row.node.name)) + ... + root + ├── sub1 + └── sub0 + ├── sub0A + └── sub0B + Or writing your own sort function: + >>> def mysort(items): + ... return sorted(items, key=lambda item: item.name) + ... + >>> for row in RenderTree(root, childiter=mysort): + ... print("%s%s" % (row.pre, row.node.name)) + ... + root + ├── sub0 + │ ├── sub0A + │ └── sub0B + └── sub1 + :any:`by_attr` simplifies attribute rendering and supports multiline: + >>> print(RenderTree(root).by_attr()) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + >>> print(RenderTree(root).by_attr("lines")) + c0fe + c0de + ├── ha + │ ba + │ ├── 1 + │ │ 2 + │ │ 3 + │ └── a + │ b + └── Z + And can be a function: + >>> print(RenderTree(root).by_attr(lambda n: " ".join(n.lines))) + c0fe c0de + ├── ha ba + │ ├── 1 2 3 + │ └── a b + └── Z + """ + if not isinstance(style, AbstractStyle): + style = style() + self.node = node + self.style = style + self.childiter = childiter + self.maxlevel = maxlevel + + def __iter__(self): + return self.__next(self.node, tuple()) + + def __next(self, node, continues, level=0): + yield RenderTree.__item(node, continues, self.style) + children = node.children.values() + level += 1 + if children and (self.maxlevel is None or level < self.maxlevel): + children = self.childiter(children) + for child, is_last in _is_last(children): + for grandchild in self.__next( + child, continues + (not is_last,), level=level + ): + yield grandchild + + @staticmethod + def __item(node, continues, style): + if not continues: + return Row("", "", node) + else: + items = [style.vertical if cont else style.empty for cont in continues] + indent = "".join(items[:-1]) + branch = style.cont if continues[-1] else style.end + pre = indent + branch + fill = "".join(items) + return Row(pre, fill, node) + + def __str__(self): + lines = ["%s%r" % (pre, node) for pre, _, node in self] + return "\n".join(lines) + + def __repr__(self): + classname = self.__class__.__name__ + args = [ + repr(self.node), + "style=%s" % repr(self.style), + "childiter=%s" % repr(self.childiter), + ] + return "%s(%s)" % (classname, ", ".join(args)) + + def by_attr(self, attrname="name"): + """ + Return rendered tree with node attribute `attrname`. + >>> from anytree import AnyNode, RenderTree + >>> root = AnyNode(id="root") + >>> s0 = AnyNode(id="sub0", parent=root) + >>> s0b = AnyNode(id="sub0B", parent=s0, foo=4, bar=109) + >>> s0a = AnyNode(id="sub0A", parent=s0) + >>> s1 = AnyNode(id="sub1", parent=root) + >>> s1a = AnyNode(id="sub1A", parent=s1) + >>> s1b = AnyNode(id="sub1B", parent=s1, bar=8) + >>> s1c = AnyNode(id="sub1C", parent=s1) + >>> s1ca = AnyNode(id="sub1Ca", parent=s1c) + >>> print(RenderTree(root).by_attr("id")) + root + ├── sub0 + │ ├── sub0B + │ └── sub0A + └── sub1 + ├── sub1A + ├── sub1B + └── sub1C + └── sub1Ca + """ + + def get(): + for pre, fill, node in self: + attr = ( + attrname(node) + if callable(attrname) + else getattr(node, attrname, "") + ) + if isinstance(attr, (list, tuple)): + lines = attr + else: + lines = str(attr).split("\n") + yield "%s%s" % (pre, lines[0]) + for line in lines[1:]: + yield "%s%s" % (fill, line) + + return "\n".join(get()) + + +def _is_last(iterable): + iter_ = iter(iterable) + try: + nextitem = next(iter_) + except StopIteration: + pass + else: + item = nextitem + while True: + try: + nextitem = next(iter_) + yield item, False + except StopIteration: + yield nextitem, True + break + item = nextitem diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py new file mode 100644 index 00000000000..1cbcdf2d4e3 --- /dev/null +++ b/xarray/datatree_/datatree/testing.py @@ -0,0 +1,120 @@ +from xarray.testing.assertions import ensure_warnings + +from .datatree import DataTree +from .formatting import diff_tree_repr + + +@ensure_warnings +def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): + """ + Two DataTrees are considered isomorphic if every node has the same number of children. + + Nothing about the data in each node is checked. + + Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation, + such as tree1 + tree2. + + By default this function does not check any part of the tree above the given node. + Therefore this function can be used as default to check that two subtrees are isomorphic. + + Parameters + ---------- + a : DataTree + The first object to compare. + b : DataTree + The second object to compare. + from_root : bool, optional, default is False + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.isomorphic + assert_equals + assert_identical + """ + __tracebackhide__ = True + assert isinstance(a, type(b)) + + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.isomorphic(b, from_root=from_root), diff_tree_repr(a, b, "isomorphic") + else: + raise TypeError(f"{type(a)} not of type DataTree") + + +@ensure_warnings +def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): + """ + Two DataTrees are equal if they have isomorphic node structures, with matching node names, + and if they have matching variables and coordinates, all of which are equal. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + a : DataTree + The first object to compare. + b : DataTree + The second object to compare. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.equals + assert_isomorphic + assert_identical + """ + __tracebackhide__ = True + assert isinstance(a, type(b)) + + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.equals(b, from_root=from_root), diff_tree_repr(a, b, "equals") + else: + raise TypeError(f"{type(a)} not of type DataTree") + + +@ensure_warnings +def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): + """ + Like assert_equals, but will also check all dataset attributes and the attributes on + all variables and coordinates. + + By default this method will check the whole tree above the given node. + + Parameters + ---------- + a : xarray.DataTree + The first object to compare. + b : xarray.DataTree + The second object to compare. + from_root : bool, optional, default is True + Whether or not to first traverse to the root of the trees before checking for isomorphism. + If a & b have no parents then this has no effect. + + See Also + -------- + DataTree.identical + assert_isomorphic + assert_equal + """ + + __tracebackhide__ = True + assert isinstance(a, type(b)) + if isinstance(a, DataTree): + if from_root: + a = a.root + b = b.root + + assert a.identical(b, from_root=from_root), diff_tree_repr(a, b, "identical") + else: + raise TypeError(f"{type(a)} not of type DataTree") diff --git a/xarray/datatree_/datatree/tests/__init__.py b/xarray/datatree_/datatree/tests/__init__.py new file mode 100644 index 00000000000..64961158b13 --- /dev/null +++ b/xarray/datatree_/datatree/tests/__init__.py @@ -0,0 +1,29 @@ +import importlib + +import pytest +from packaging import version + + +def _importorskip(modname, minversion=None): + try: + mod = importlib.import_module(modname) + has = True + if minversion is not None: + if LooseVersion(mod.__version__) < LooseVersion(minversion): + raise ImportError("Minimum version not satisfied") + except ImportError: + has = False + func = pytest.mark.skipif(not has, reason=f"requires {modname}") + return has, func + + +def LooseVersion(vstring): + # Our development version is something like '0.10.9+aac7bfc' + # This function just ignores the git commit id. + vstring = vstring.split("+")[0] + return version.parse(vstring) + + +has_zarr, requires_zarr = _importorskip("zarr") +has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") +has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py new file mode 100644 index 00000000000..bd2e7ba3247 --- /dev/null +++ b/xarray/datatree_/datatree/tests/conftest.py @@ -0,0 +1,65 @@ +import pytest +import xarray as xr + +from xarray.datatree_.datatree import DataTree + + +@pytest.fixture(scope="module") +def create_test_datatree(): + """ + Create a test datatree with this structure: + + + |-- set1 + | |-- + | | Dimensions: () + | | Data variables: + | | a int64 0 + | | b int64 1 + | |-- set1 + | |-- set2 + |-- set2 + | |-- + | | Dimensions: (x: 2) + | | Data variables: + | | a (x) int64 2, 3 + | | b (x) int64 0.1, 0.2 + | |-- set1 + |-- set3 + |-- + | Dimensions: (x: 2, y: 3) + | Data variables: + | a (y) int64 6, 7, 8 + | set0 (x) int64 9, 10 + + The structure has deliberately repeated names of tags, variables, and + dimensions in order to better check for bugs caused by name conflicts. + """ + + def _create_test_datatree(modify=lambda ds: ds): + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) + + # Avoid using __init__ so we can independently test it + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2 = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + return root + + return _create_test_datatree + + +@pytest.fixture(scope="module") +def simple_datatree(create_test_datatree): + """ + Invoke create_test_datatree fixture (callback). + + Returns a DataTree. + """ + return create_test_datatree() diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py new file mode 100644 index 00000000000..c3eb74451a6 --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_dataset_api.py @@ -0,0 +1,98 @@ +import numpy as np +import xarray as xr + +from xarray.datatree_.datatree import DataTree +from xarray.datatree_.datatree.testing import assert_equal + + +class TestDSMethodInheritance: + def test_dataset_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected = DataTree(data=ds.isel(x=1)) + DataTree(name="results", parent=expected, data=ds.isel(x=1)) + + result = dt.isel(x=1) + assert_equal(result, expected) + + def test_reduce_method(self): + ds = xr.Dataset({"a": ("x", [False, True, False])}) + dt = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected = DataTree(data=ds.any()) + DataTree(name="results", parent=expected, data=ds.any()) + + result = dt.any() + assert_equal(result, expected) + + def test_nan_reduce_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected = DataTree(data=ds.mean()) + DataTree(name="results", parent=expected, data=ds.mean()) + + result = dt.mean() + assert_equal(result, expected) + + def test_cum_method(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataTree(data=ds) + DataTree(name="results", parent=dt, data=ds) + + expected = DataTree(data=ds.cumsum()) + DataTree(name="results", parent=expected, data=ds.cumsum()) + + result = dt.cumsum() + assert_equal(result, expected) + + +class TestOps: + def test_binary_op_on_int(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected = DataTree(data=ds1 * 5) + DataTree(name="subnode", data=ds2 * 5, parent=expected) + + result = dt * 5 + assert_equal(result, expected) + + def test_binary_op_on_dataset(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) + + expected = DataTree(data=ds1 * other_ds) + DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + + result = dt * other_ds + assert_equal(result, expected) + + def test_binary_op_on_datatree(self): + ds1 = xr.Dataset({"a": [5], "b": [3]}) + ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) + dt = DataTree(data=ds1) + DataTree(name="subnode", data=ds2, parent=dt) + + expected = DataTree(data=ds1 * ds1) + DataTree(name="subnode", data=ds2 * ds2, parent=expected) + + result = dt * dt + assert_equal(result, expected) + + +class TestUFuncs: + def test_tree(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: np.sin(ds)) + result_tree = np.sin(dt) + assert_equal(result_tree, expected) diff --git a/xarray/datatree_/datatree/tests/test_datatree.py b/xarray/datatree_/datatree/tests/test_datatree.py new file mode 100644 index 00000000000..cfb57470651 --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_datatree.py @@ -0,0 +1,731 @@ +from copy import copy, deepcopy + +import numpy as np +import pytest +import xarray as xr +import xarray.testing as xrt +from xarray.tests import create_test_data, source_ndarray + +import xarray.datatree_.datatree.testing as dtt +from xarray.datatree_.datatree import DataTree, NotFoundInTreeError + + +class TestTreeCreation: + def test_empty(self): + dt = DataTree(name="root") + assert dt.name == "root" + assert dt.parent is None + assert dt.children == {} + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + + def test_unnamed(self): + dt = DataTree() + assert dt.name is None + + def test_bad_names(self): + with pytest.raises(TypeError): + DataTree(name=5) + + with pytest.raises(ValueError): + DataTree(name="folder/data") + + +class TestFamilyTree: + def test_setparent_unnamed_child_node_fails(self): + john = DataTree(name="john") + with pytest.raises(ValueError, match="unnamed"): + DataTree(parent=john) + + def test_create_two_children(self): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=root) + DataTree(name="set2", parent=set1) + + def test_create_full_tree(self, simple_datatree): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2 = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + expected = simple_datatree + assert root.identical(expected) + + +class TestNames: + def test_child_gets_named_on_attach(self): + sue = DataTree() + mary = DataTree(children={"Sue": sue}) # noqa + assert sue.name == "Sue" + + +class TestPaths: + def test_path_property(self): + sue = DataTree() + mary = DataTree(children={"Sue": sue}) + john = DataTree(children={"Mary": mary}) # noqa + assert sue.path == "/Mary/Sue" + assert john.path == "/" + + def test_path_roundtrip(self): + sue = DataTree() + mary = DataTree(children={"Sue": sue}) + john = DataTree(children={"Mary": mary}) # noqa + assert john[sue.path] is sue + + def test_same_tree(self): + mary = DataTree() + kate = DataTree() + john = DataTree(children={"Mary": mary, "Kate": kate}) # noqa + assert mary.same_tree(kate) + + def test_relative_paths(self): + sue = DataTree() + mary = DataTree(children={"Sue": sue}) + annie = DataTree() + john = DataTree(children={"Mary": mary, "Annie": annie}) + + result = sue.relative_to(john) + assert result == "Mary/Sue" + assert john.relative_to(sue) == "../.." + assert annie.relative_to(sue) == "../../Annie" + assert sue.relative_to(annie) == "../Mary/Sue" + assert sue.relative_to(sue) == "." + + evil_kate = DataTree() + with pytest.raises( + NotFoundInTreeError, match="nodes do not lie within the same tree" + ): + sue.relative_to(evil_kate) + + +class TestStoreDatasets: + def test_create_with_data(self): + dat = xr.Dataset({"a": 0}) + john = DataTree(name="john", data=dat) + + xrt.assert_identical(john.to_dataset(), dat) + + with pytest.raises(TypeError): + DataTree(name="mary", parent=john, data="junk") # noqa + + def test_set_data(self): + john = DataTree(name="john") + dat = xr.Dataset({"a": 0}) + john.ds = dat + + xrt.assert_identical(john.to_dataset(), dat) + + with pytest.raises(TypeError): + john.ds = "junk" + + def test_has_data(self): + john = DataTree(name="john", data=xr.Dataset({"a": 0})) + assert john.has_data + + john = DataTree(name="john", data=None) + assert not john.has_data + + def test_is_hollow(self): + john = DataTree(data=xr.Dataset({"a": 0})) + assert john.is_hollow + + eve = DataTree(children={"john": john}) + assert eve.is_hollow + + eve.ds = xr.Dataset({"a": 1}) + assert not eve.is_hollow + + +class TestVariablesChildrenNameCollisions: + def test_parent_already_has_variable_with_childs_name(self): + dt = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + with pytest.raises(KeyError, match="already contains a data variable named a"): + DataTree(name="a", data=None, parent=dt) + + def test_assign_when_already_child_with_variables_name(self): + dt = DataTree(data=None) + DataTree(name="a", data=None, parent=dt) + with pytest.raises(KeyError, match="names would collide"): + dt.ds = xr.Dataset({"a": 0}) + + dt.ds = xr.Dataset() + + new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) + with pytest.raises(KeyError, match="names would collide"): + dt.ds = new_ds + + +class TestGet: + ... + + +class TestGetItem: + def test_getitem_node(self): + folder1 = DataTree(name="folder1") + results = DataTree(name="results", parent=folder1) + highres = DataTree(name="highres", parent=results) + assert folder1["results"] is results + assert folder1["results/highres"] is highres + + def test_getitem_self(self): + dt = DataTree() + assert dt["."] is dt + + def test_getitem_single_data_variable(self): + data = xr.Dataset({"temp": [0, 50]}) + results = DataTree(name="results", data=data) + xrt.assert_identical(results["temp"], data["temp"]) + + def test_getitem_single_data_variable_from_node(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1 = DataTree(name="folder1") + results = DataTree(name="results", parent=folder1) + DataTree(name="highres", parent=results, data=data) + xrt.assert_identical(folder1["results/highres/temp"], data["temp"]) + + def test_getitem_nonexistent_node(self): + folder1 = DataTree(name="folder1") + DataTree(name="results", parent=folder1) + with pytest.raises(KeyError): + folder1["results/highres"] + + def test_getitem_nonexistent_variable(self): + data = xr.Dataset({"temp": [0, 50]}) + results = DataTree(name="results", data=data) + with pytest.raises(KeyError): + results["pressure"] + + @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") + def test_getitem_multiple_data_variables(self): + data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) + results = DataTree(name="results", data=data) + xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) + + @pytest.mark.xfail(reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)") + def test_getitem_dict_like_selection_access_to_dataset(self): + data = xr.Dataset({"temp": [0, 50]}) + results = DataTree(name="results", data=data) + xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) + + +class TestUpdate: + def test_update(self): + dt = DataTree() + dt.update({"foo": xr.DataArray(0), "a": DataTree()}) + expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) + print(dt) + print(dt.children) + print(dt._children) + print(dt["a"]) + print(expected) + dtt.assert_equal(dt, expected) + + def test_update_new_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1 = DataTree(name="folder1") + folder1.update({"results": da}) + expected = da.rename("results") + xrt.assert_equal(folder1["results"], expected) + + def test_update_doesnt_alter_child_name(self): + dt = DataTree() + dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) + assert "a" in dt.children + child = dt["a"] + assert child.name == "a" + + def test_update_overwrite(self): + actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))}) + actual.update({"a": DataTree(xr.Dataset({"x": 2}))}) + + expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))}) + + print(actual) + print(expected) + + dtt.assert_equal(actual, expected) + + +class TestCopy: + def test_copy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=False), copy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is copied_node.attrs["Test"] + + def test_copy_subtree(self): + dt = DataTree.from_dict({"/level1/level2/level3": xr.Dataset()}) + + actual = dt["/level1/level2"].copy() + expected = DataTree.from_dict({"/level3": xr.Dataset()}, name="level2") + + dtt.assert_identical(actual, expected) + + def test_deepcopy(self, create_test_datatree): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=True), deepcopy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + for k in node.data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is not copied_node.attrs["Test"] + + @pytest.mark.xfail(reason="data argument not yet implemented") + def test_copy_with_data(self, create_test_datatree): + orig = create_test_datatree() + # TODO use .data_vars once that property is available + data_vars = { + k: v for k, v in orig.variables.items() if k not in orig._coord_names + } + new_data = {k: np.random.randn(*v.shape) for k, v in data_vars.items()} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + dtt.assert_identical(expected, actual) + + # TODO test parents and children? + + +class TestSetItem: + def test_setitem_new_child_node(self): + john = DataTree(name="john") + mary = DataTree(name="mary") + john["mary"] = mary + + grafted_mary = john["mary"] + assert grafted_mary.parent is john + assert grafted_mary.name == "mary" + + def test_setitem_unnamed_child_node_becomes_named(self): + john2 = DataTree(name="john2") + john2["sonny"] = DataTree() + assert john2["sonny"].name == "sonny" + + def test_setitem_new_grandchild_node(self): + john = DataTree(name="john") + mary = DataTree(name="mary", parent=john) + rose = DataTree(name="rose") + john["mary/rose"] = rose + + grafted_rose = john["mary/rose"] + assert grafted_rose.parent is mary + assert grafted_rose.name == "rose" + + def test_grafted_subtree_retains_name(self): + subtree = DataTree(name="original_subtree_name") + root = DataTree(name="root") + root["new_subtree_name"] = subtree # noqa + assert subtree.name == "original_subtree_name" + + def test_setitem_new_empty_node(self): + john = DataTree(name="john") + john["mary"] = DataTree() + mary = john["mary"] + assert isinstance(mary, DataTree) + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + + def test_setitem_overwrite_data_in_node_with_none(self): + john = DataTree(name="john") + mary = DataTree(name="mary", parent=john, data=xr.Dataset()) + john["mary"] = DataTree() + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) + + john.ds = xr.Dataset() + with pytest.raises(ValueError, match="has no name"): + john["."] = DataTree() + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_on_this_node(self): + data = xr.Dataset({"temp": [0, 50]}) + results = DataTree(name="results") + results["."] = data + xrt.assert_identical(results.to_dataset(), data) + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_as_new_node(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1 = DataTree(name="folder1") + folder1["results"] = data + xrt.assert_identical(folder1["results"].to_dataset(), data) + + @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") + def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): + data = xr.Dataset({"temp": [0, 50]}) + folder1 = DataTree(name="folder1") + folder1["results/highres"] = data + xrt.assert_identical(folder1["results/highres"].to_dataset(), data) + + def test_setitem_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1 = DataTree(name="folder1") + folder1["results"] = da + expected = da.rename("results") + xrt.assert_equal(folder1["results"], expected) + + def test_setitem_unnamed_dataarray(self): + data = xr.DataArray([0, 50]) + folder1 = DataTree(name="folder1") + folder1["results"] = data + xrt.assert_equal(folder1["results"], data) + + def test_setitem_variable(self): + var = xr.Variable(data=[0, 50], dims="x") + folder1 = DataTree(name="folder1") + folder1["results"] = var + xrt.assert_equal(folder1["results"], xr.DataArray(var)) + + def test_setitem_coerce_to_dataarray(self): + folder1 = DataTree(name="folder1") + folder1["results"] = 0 + xrt.assert_equal(folder1["results"], xr.DataArray(0)) + + def test_setitem_add_new_variable_to_empty_node(self): + results = DataTree(name="results") + results["pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results.ds + results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results.ds + + # What if there is a path to traverse first? + results = DataTree(name="results") + results["highres/pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results["highres"].ds + results["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results["highres"].ds + + def test_setitem_dataarray_replace_existing_node(self): + t = xr.Dataset({"temp": [0, 50]}) + results = DataTree(name="results", data=t) + p = xr.DataArray(data=[2, 3]) + results["pressure"] = p + expected = t.assign(pressure=p) + xrt.assert_identical(results.to_dataset(), expected) + + +class TestDictionaryInterface: + ... + + +class TestTreeFromDict: + def test_data_in_root(self): + dat = xr.Dataset() + dt = DataTree.from_dict({"/": dat}) + assert dt.name is None + assert dt.parent is None + assert dt.children == {} + xrt.assert_identical(dt.to_dataset(), dat) + + def test_one_layer(self): + dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) + dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) + assert dt.name is None + xrt.assert_identical(dt["run1"].to_dataset(), dat1) + assert dt["run1"].children == {} + xrt.assert_identical(dt["run2"].to_dataset(), dat2) + assert dt["run2"].children == {} + + def test_two_layers(self): + dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"a": [1, 2]}) + dt = DataTree.from_dict({"highres/run": dat1, "lowres/run": dat2}) + assert "highres" in dt.children + assert "lowres" in dt.children + highres_run = dt["highres/run"] + xrt.assert_identical(highres_run.to_dataset(), dat1) + + def test_nones(self): + dt = DataTree.from_dict({"d": None, "d/e": None}) + assert [node.name for node in dt.subtree] == [None, "d", "e"] + assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] + xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) + + def test_full(self, simple_datatree): + dt = simple_datatree + paths = list(node.path for node in dt.subtree) + assert paths == [ + "/", + "/set1", + "/set1/set1", + "/set1/set2", + "/set2", + "/set2/set1", + "/set3", + ] + + def test_datatree_values(self): + dat1 = DataTree(data=xr.Dataset({"a": 1})) + expected = DataTree() + expected["a"] = dat1 + + actual = DataTree.from_dict({"a": dat1}) + + dtt.assert_identical(actual, expected) + + def test_roundtrip(self, simple_datatree): + dt = simple_datatree + roundtrip = DataTree.from_dict(dt.to_dict()) + assert roundtrip.equals(dt) + + @pytest.mark.xfail + def test_roundtrip_unnamed_root(self, simple_datatree): + # See GH81 + + dt = simple_datatree + dt.name = "root" + roundtrip = DataTree.from_dict(dt.to_dict()) + assert roundtrip.equals(dt) + + +class TestDatasetView: + def test_view_contents(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.identical( + dt.ds + ) # this only works because Dataset.identical doesn't check types + assert isinstance(dt.ds, xr.Dataset) + + def test_immutability(self): + # See issue https://github.com/xarray-contrib/datatree/issues/38 + dt = DataTree(name="root", data=None) + DataTree(name="a", data=None, parent=dt) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds["a"] = xr.DataArray(0) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds.update({"a": 0}) + + # TODO are there any other ways you can normally modify state (in-place)? + # (not attribute-like assignment because that doesn't work on Dataset anyway) + + def test_methods(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.mean().identical(dt.ds.mean()) + assert type(dt.ds.mean()) == xr.Dataset + + def test_arithmetic(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + result = 10.0 * dt["set1"].ds + assert result.identical(expected) + + def test_init_via_type(self): + # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188 + # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray + + a = xr.DataArray( + np.random.rand(3, 4, 10), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(3, 4))}, + ).to_dataset(name="data") + dt = DataTree(data=a) + + def weighted_mean(ds): + return ds.weighted(ds.area).mean(["x", "y"]) + + weighted_mean(dt.ds) + + +class TestAccess: + def test_attribute_access(self, create_test_datatree): + dt = create_test_datatree() + + # vars / coords + for key in ["a", "set0"]: + xrt.assert_equal(dt[key], getattr(dt, key)) + assert key in dir(dt) + + # dims + xrt.assert_equal(dt["a"]["y"], getattr(dt.a, "y")) + assert "y" in dir(dt["a"]) + + # children + for key in ["set1", "set2", "set3"]: + dtt.assert_equal(dt[key], getattr(dt, key)) + assert key in dir(dt) + + # attrs + dt.attrs["meta"] = "NASA" + assert dt.attrs["meta"] == "NASA" + assert "meta" in dir(dt) + + def test_ipython_key_completions(self, create_test_datatree): + dt = create_test_datatree() + key_completions = dt._ipython_key_completions_() + + node_keys = [node.path[1:] for node in dt.subtree] + assert all(node_key in key_completions for node_key in node_keys) + + var_keys = list(dt.variables.keys()) + assert all(var_key in key_completions for var_key in var_keys) + + def test_operation_with_attrs_but_no_data(self): + # tests bug from xarray-datatree GH262 + xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) + dt = DataTree.from_dict({"node1": xs, "node2": xs}) + dt.attrs["test_key"] = 1 # sel works fine without this line + dt.sel(dim_0=0) + + +class TestRestructuring: + def test_drop_nodes(self): + sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) + + # test drop just one node + dropped_one = sue.drop_nodes(names="Mary") + assert "Mary" not in dropped_one.children + + # test drop multiple nodes + dropped = sue.drop_nodes(names=["Mary", "Kate"]) + assert not set(["Mary", "Kate"]).intersection(set(dropped.children)) + assert "Ashley" in dropped.children + + # test raise + with pytest.raises(KeyError, match="nodes {'Mary'} not present"): + dropped.drop_nodes(names=["Mary", "Ashley"]) + + # test ignore + childless = dropped.drop_nodes(names=["Mary", "Ashley"], errors="ignore") + assert childless.children == {} + + def test_assign(self): + dt = DataTree() + expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) + + # kwargs form + result = dt.assign(foo=xr.DataArray(0), a=DataTree()) + dtt.assert_equal(result, expected) + + # dict form + result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()}) + dtt.assert_equal(result, expected) + + +class TestPipe: + def test_noop(self, create_test_datatree): + dt = create_test_datatree() + + actual = dt.pipe(lambda tree: tree) + assert actual.identical(dt) + + def test_params(self, create_test_datatree): + dt = create_test_datatree() + + def f(tree, **attrs): + return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) + + attrs = {"x": 1, "y": 2, "z": 3} + + actual = dt.pipe(f, **attrs) + assert actual["arr_with_attrs"].attrs == attrs + + def test_named_self(self, create_test_datatree): + dt = create_test_datatree() + + def f(x, tree, y): + tree.attrs.update({"x": x, "y": y}) + return tree + + attrs = {"x": 1, "y": 2} + + actual = dt.pipe((f, "tree"), **attrs) + + assert actual is dt and actual.attrs == attrs + + +class TestSubset: + def test_match(self): + # TODO is this example going to cause problems with case sensitivity? + dt = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/b/A": None, + "/b/B": None, + } + ) + result = dt.match("*/B") + expected = DataTree.from_dict( + { + "/a/B": None, + "/b/B": None, + } + ) + dtt.assert_identical(result, expected) + + def test_filter(self): + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + expected = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + }, + name="Abe", + ) + elders = simpsons.filter(lambda node: node["age"] > 18) + dtt.assert_identical(elders, expected) diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py new file mode 100644 index 00000000000..0241e496abf --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_extensions.py @@ -0,0 +1,40 @@ +import pytest + +from xarray.datatree_.datatree import DataTree, register_datatree_accessor + + +class TestAccessor: + def test_register(self) -> None: + @register_datatree_accessor("demo") + class DemoAccessor: + """Demo accessor.""" + + def __init__(self, xarray_obj): + self._obj = xarray_obj + + @property + def foo(self): + return "bar" + + dt: DataTree = DataTree() + assert dt.demo.foo == "bar" # type: ignore + + # accessor is cached + assert dt.demo is dt.demo # type: ignore + + # check descriptor + assert dt.demo.__doc__ == "Demo accessor." # type: ignore + # TODO: typing doesn't seem to work with accessors + assert DataTree.demo.__doc__ == "Demo accessor." # type: ignore + assert isinstance(dt.demo, DemoAccessor) # type: ignore + assert DataTree.demo is DemoAccessor # type: ignore + + with pytest.warns(Warning, match="overriding a preexisting attribute"): + + @register_datatree_accessor("demo") + class Foo: + pass + + # ensure we can remove it + del DataTree.demo # type: ignore + assert not hasattr(DataTree, "demo") diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py new file mode 100644 index 00000000000..b58c02282e7 --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_formatting.py @@ -0,0 +1,120 @@ +from textwrap import dedent + +from xarray import Dataset + +from xarray.datatree_.datatree import DataTree +from xarray.datatree_.datatree.formatting import diff_tree_repr + + +class TestRepr: + def test_print_empty_node(self): + dt = DataTree(name="root") + printout = dt.__str__() + assert printout == "DataTree('root', parent=None)" + + def test_print_empty_node_with_attrs(self): + dat = Dataset(attrs={"note": "has attrs"}) + dt = DataTree(name="root", data=dat) + printout = dt.__str__() + assert printout == dedent( + """\ + DataTree('root', parent=None) + Dimensions: () + Data variables: + *empty* + Attributes: + note: has attrs""" + ) + + def test_print_node_with_data(self): + dat = Dataset({"a": [0, 2]}) + dt = DataTree(name="root", data=dat) + printout = dt.__str__() + expected = [ + "DataTree('root', parent=None)", + "Dimensions", + "Coordinates", + "a", + "Data variables", + "*empty*", + ] + for expected_line, printed_line in zip(expected, printout.splitlines()): + assert expected_line in printed_line + + def test_nested_node(self): + dat = Dataset({"a": [0, 2]}) + root = DataTree(name="root") + DataTree(name="results", data=dat, parent=root) + printout = root.__str__() + assert printout.splitlines()[2].startswith(" ") + + def test_print_datatree(self, simple_datatree): + dt = simple_datatree + print(dt) + + # TODO work out how to test something complex like this + + def test_repr_of_node_with_data(self): + dat = Dataset({"a": [0, 2]}) + dt = DataTree(name="root", data=dat) + assert "Coordinates" in repr(dt) + + +class TestDiffFormatting: + def test_diff_structure(self): + dt_1 = DataTree.from_dict({"a": None, "a/b": None, "a/c": None}) + dt_2 = DataTree.from_dict({"d": None, "d/e": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not isomorphic + + Number of children on node '/a' of the left object: 2 + Number of children on node '/d' of the right object: 1""" + ) + actual = diff_tree_repr(dt_1, dt_2, "isomorphic") + assert actual == expected + + def test_diff_node_names(self): + dt_1 = DataTree.from_dict({"a": None}) + dt_2 = DataTree.from_dict({"b": None}) + + expected = dedent( + """\ + Left and right DataTree objects are not identical + + Node '/a' in the left object has name 'a' + Node '/b' in the right object has name 'b'""" + ) + actual = diff_tree_repr(dt_1, dt_2, "identical") + assert actual == expected + + def test_diff_node_data(self): + import numpy as np + + # casting to int64 explicitly ensures that int64s are created on all architectures + ds1 = Dataset({"u": np.int64(0), "v": np.int64(1)}) + ds3 = Dataset({"w": np.int64(5)}) + dt_1 = DataTree.from_dict({"a": ds1, "a/b": ds3}) + ds2 = Dataset({"u": np.int64(0)}) + ds4 = Dataset({"w": np.int64(6)}) + dt_2 = DataTree.from_dict({"a": ds2, "a/b": ds4}) + + expected = dedent( + """\ + Left and right DataTree objects are not equal + + + Data in nodes at position '/a' do not match: + + Data variables only on the left object: + v int64 8B 1 + + Data in nodes at position '/a/b' do not match: + + Differing data variables: + L w int64 8B 5 + R w int64 8B 6""" + ) + actual = diff_tree_repr(dt_1, dt_2, "equals") + assert actual == expected diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py new file mode 100644 index 00000000000..943bbab4154 --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_formatting_html.py @@ -0,0 +1,197 @@ +import pytest +import xarray as xr + +from xarray.datatree_.datatree import DataTree, formatting_html + + +@pytest.fixture(scope="module", params=["some html", "some other html"]) +def repr(request): + return request.param + + +class Test_summarize_children: + """ + Unit tests for summarize_children. + """ + + func = staticmethod(formatting_html.summarize_children) + + @pytest.fixture(scope="class") + def childfree_tree_factory(self): + """ + Fixture for a child-free DataTree factory. + """ + from random import randint + + def _childfree_tree_factory(): + return DataTree( + data=xr.Dataset({"z": ("y", [randint(1, 100) for _ in range(3)])}) + ) + + return _childfree_tree_factory + + @pytest.fixture(scope="class") + def childfree_tree(self, childfree_tree_factory): + """ + Fixture for a child-free DataTree. + """ + return childfree_tree_factory() + + @pytest.fixture(scope="function") + def mock_node_repr(self, monkeypatch): + """ + Apply mocking for node_repr. + """ + + def mock(group_title, dt): + """ + Mock with a simple result + """ + return group_title + " " + str(id(dt)) + + monkeypatch.setattr(formatting_html, "node_repr", mock) + + @pytest.fixture(scope="function") + def mock_wrap_repr(self, monkeypatch): + """ + Apply mocking for _wrap_repr. + """ + + def mock(r, *, end, **kwargs): + """ + Mock by appending "end" or "not end". + """ + return r + " " + ("end" if end else "not end") + "//" + + monkeypatch.setattr(formatting_html, "_wrap_repr", mock) + + def test_empty_mapping(self): + """ + Test with an empty mapping of children. + """ + children = {} + assert self.func(children) == ( + "
    " "
    " + ) + + def test_one_child(self, childfree_tree, mock_wrap_repr, mock_node_repr): + """ + Test with one child. + + Uses a mock of _wrap_repr and node_repr to essentially mock + the inline lambda function "lines_callback". + """ + # Create mapping of children + children = {"a": childfree_tree} + + # Expect first line to be produced from the first child, and + # wrapped as the last child + first_line = f"a {id(children['a'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + "
    " + ) + + def test_two_children(self, childfree_tree_factory, mock_wrap_repr, mock_node_repr): + """ + Test with two level deep children. + + Uses a mock of _wrap_repr and node_repr to essentially mock + the inline lambda function "lines_callback". + """ + + # Create mapping of children + children = {"a": childfree_tree_factory(), "b": childfree_tree_factory()} + + # Expect first line to be produced from the first child, and + # wrapped as _not_ the last child + first_line = f"a {id(children['a'])} not end//" + + # Expect second line to be produced from the second child, and + # wrapped as the last child + second_line = f"b {id(children['b'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + f"{second_line}" + "
    " + ) + + +class Test__wrap_repr: + """ + Unit tests for _wrap_repr. + """ + + func = staticmethod(formatting_html._wrap_repr) + + def test_end(self, repr): + """ + Test with end=True. + """ + r = self.func(repr, end=True) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + "
      " + f"{repr}" + "
    " + "
    " + "
    " + ) + + def test_not_end(self, repr): + """ + Test with end=False. + """ + r = self.func(repr, end=False) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + "
      " + f"{repr}" + "
    " + "
    " + "
    " + ) diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/datatree_/datatree/tests/test_mapping.py new file mode 100644 index 00000000000..53d6e085440 --- /dev/null +++ b/xarray/datatree_/datatree/tests/test_mapping.py @@ -0,0 +1,343 @@ +import numpy as np +import pytest +import xarray as xr + +from xarray.datatree_.datatree.datatree import DataTree +from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree +from xarray.datatree_.datatree.testing import assert_equal + +empty = xr.Dataset() + + +class TestCheckTreesIsomorphic: + def test_not_a_tree(self): + with pytest.raises(TypeError, match="not a tree"): + check_isomorphic("s", 1) + + def test_different_widths(self): + dt1 = DataTree.from_dict(d={"a": empty}) + dt2 = DataTree.from_dict(d={"b": empty, "c": empty}) + expected_err_str = ( + "Number of children on node '/' of the left object: 1\n" + "Number of children on node '/' of the right object: 2" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2) + + def test_different_heights(self): + dt1 = DataTree.from_dict({"a": empty}) + dt2 = DataTree.from_dict({"b": empty, "b/c": empty}) + expected_err_str = ( + "Number of children on node '/a' of the left object: 0\n" + "Number of children on node '/b' of the right object: 1" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2) + + def test_names_different(self): + dt1 = DataTree.from_dict({"a": xr.Dataset()}) + dt2 = DataTree.from_dict({"b": empty}) + expected_err_str = ( + "Node '/a' in the left object has name 'a'\n" + "Node '/b' in the right object has name 'b'" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_names_equal(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + dt2 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_ordering(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/d": empty, "b/c": empty}) + dt2 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + check_isomorphic(dt1, dt2, require_names_equal=False) + + def test_isomorphic_names_not_equal(self): + dt1 = DataTree.from_dict({"a": empty, "b": empty, "b/c": empty, "b/d": empty}) + dt2 = DataTree.from_dict({"A": empty, "B": empty, "B/C": empty, "B/D": empty}) + check_isomorphic(dt1, dt2) + + def test_not_isomorphic_complex_tree(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2["set1/set2/extra"] = DataTree(name="extra") + with pytest.raises(TreeIsomorphismError, match="/set1/set2"): + check_isomorphic(dt1, dt2) + + def test_checking_from_root(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + real_root = DataTree(name="real root") + dt2.name = "not_real_root" + dt2.parent = real_root + with pytest.raises(TreeIsomorphismError): + check_isomorphic(dt1, dt2, check_from_root=True) + + +class TestMapOverSubTree: + def test_no_trees_passed(self): + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + with pytest.raises(TypeError, match="Must pass at least one tree"): + times_ten("dt") + + def test_not_isomorphic(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2["set1/set2/extra"] = DataTree(name="extra") + + @map_over_subtree + def times_ten(ds1, ds2): + return ds1 * ds2 + + with pytest.raises(TreeIsomorphismError): + times_ten(dt1, dt2) + + def test_no_trees_returned(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def bad_func(ds1, ds2): + return None + + with pytest.raises(TypeError, match="return value of None"): + bad_func(dt1, dt2) + + def test_single_dt_arg(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + result_tree = times_ten(dt) + assert_equal(result_tree, expected) + + def test_single_dt_arg_plus_args_and_kwargs(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def multiply_then_add(ds, times, add=0.0): + return (times * ds) + add + + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + result_tree = multiply_then_add(dt, 10.0, add=2.0) + assert_equal(result_tree, expected) + + def test_multiple_dt_args(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def add(ds1, ds2): + return ds1 + ds2 + + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, dt2) + assert_equal(result, expected) + + def test_dt_as_kwarg(self, create_test_datatree): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + + @map_over_subtree + def add(ds1, value=0.0): + return ds1 + value + + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, value=dt2) + assert_equal(result, expected) + + def test_return_multiple_dts(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def minmax(ds): + return ds.min(), ds.max() + + dt_min, dt_max = minmax(dt) + expected_min = create_test_datatree(modify=lambda ds: ds.min()) + assert_equal(dt_min, expected_min) + expected_max = create_test_datatree(modify=lambda ds: ds.max()) + assert_equal(dt_max, expected_max) + + def test_return_wrong_type(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds1): + return "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) + + def test_return_tuple_of_wrong_types(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds1): + return xr.Dataset(), "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) + + @pytest.mark.xfail + def test_return_inconsistent_number_of_results(self, simple_datatree): + dt1 = simple_datatree + + @map_over_subtree + def bad_func(ds): + # Datasets in simple_datatree have different numbers of dims + # TODO need to instead return different numbers of Dataset objects for this test to catch the intended error + return tuple(ds.dims) + + with pytest.raises(TypeError, match="instead returns"): + bad_func(dt1) + + def test_wrong_number_of_arguments_for_func(self, simple_datatree): + dt = simple_datatree + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + times_ten(dt, dt) + + def test_map_single_dataset_against_whole_tree(self, create_test_datatree): + dt = create_test_datatree() + + @map_over_subtree + def nodewise_merge(node_ds, fixed_ds): + return xr.merge([node_ds, fixed_ds]) + + other_ds = xr.Dataset({"z": ("z", [0])}) + expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds])) + result_tree = nodewise_merge(dt, other_ds) + assert_equal(result_tree, expected) + + @pytest.mark.xfail + def test_trees_with_different_node_names(self): + # TODO test this after I've got good tests for renaming nodes + raise NotImplementedError + + def test_dt_method(self, create_test_datatree): + dt = create_test_datatree() + + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + assert_equal(result_tree, expected) + + def test_discard_ancestry(self, create_test_datatree): + # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48 + dt = create_test_datatree() + subtree = dt["set1"] + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + result_tree = times_ten(subtree) + assert_equal(result_tree, expected, from_root=False) + + def test_skip_empty_nodes_with_attrs(self, create_test_datatree): + # inspired by xarray-datatree GH262 + dt = create_test_datatree() + dt["set1/set2"].attrs["foo"] = "bar" + + def check_for_data(ds): + # fails if run on a node that has no data + assert len(ds.variables) != 0 + return ds + + dt.map_over_subtree(check_for_data) + + def test_keep_attrs_on_empty_nodes(self, create_test_datatree): + # GH278 + dt = create_test_datatree() + dt["set1/set2"].attrs["foo"] = "bar" + + def empty_func(ds): + return ds + + result = dt.map_over_subtree(empty_func) + assert result["set1/set2"].attrs == dt["set1/set2"].attrs + + @pytest.mark.xfail( + reason="probably some bug in pytests handling of exception notes" + ) + def test_error_contains_path_of_offending_node(self, create_test_datatree): + dt = create_test_datatree() + dt["set1"]["bad_var"] = 0 + print(dt) + + def fail_on_specific_node(ds): + if "bad_var" in ds: + raise ValueError("Failed because 'bar_var' present in dataset") + + with pytest.raises( + ValueError, match="Raised whilst mapping function over node /set1" + ): + dt.map_over_subtree(fail_on_specific_node) + + +class TestMutableOperations: + def test_construct_using_type(self): + # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188 + # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray + + a = xr.DataArray( + np.random.rand(3, 4, 10), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(3, 4))}, + ).to_dataset(name="data") + b = xr.DataArray( + np.random.rand(2, 6, 14), + dims=["x", "y", "time"], + coords={"area": (["x", "y"], np.random.rand(2, 6))}, + ).to_dataset(name="data") + dt = DataTree.from_dict({"a": a, "b": b}) + + def weighted_mean(ds): + return ds.weighted(ds.area).mean(["x", "y"]) + + dt.map_over_subtree(weighted_mean) + + def test_alter_inplace_forbidden(self): + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + + def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset: + """Add some years to the age, but by altering the given dataset""" + ds["age"] = ds["age"] + years + return ds + + with pytest.raises(AttributeError): + simpsons.map_over_subtree(fast_forward, years=10) + + +@pytest.mark.xfail +class TestMapOverSubTreeInplace: + def test_map_over_subtree_inplace(self): + raise NotImplementedError diff --git a/xarray/datatree_/docs/Makefile b/xarray/datatree_/docs/Makefile new file mode 100644 index 00000000000..6e9b4058414 --- /dev/null +++ b/xarray/datatree_/docs/Makefile @@ -0,0 +1,183 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source + +.PHONY: help clean html rtdhtml dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " rtdhtml Build html using same settings used on ReadtheDocs" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +rtdhtml: + $(SPHINXBUILD) -T -j auto -E -W --keep-going -b html -d $(BUILDDIR)/doctrees -D language=en . $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/complexity.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/complexity.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/complexity" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/complexity" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/xarray/datatree_/docs/README.md b/xarray/datatree_/docs/README.md new file mode 100644 index 00000000000..ca2bf72952e --- /dev/null +++ b/xarray/datatree_/docs/README.md @@ -0,0 +1,14 @@ +# README - docs + +## Build the documentation locally + +```bash +cd docs # From project's root +make clean +rm -rf source/generated # remove autodoc artefacts, that are not removed by `make clean` +make html +``` + +## Access the documentation locally + +Open `docs/_build/html/index.html` in a web browser diff --git a/xarray/datatree_/docs/make.bat b/xarray/datatree_/docs/make.bat new file mode 100644 index 00000000000..2df9a8cbbb6 --- /dev/null +++ b/xarray/datatree_/docs/make.bat @@ -0,0 +1,242 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\complexity.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\complexity.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/xarray/datatree_/docs/source/api.rst b/xarray/datatree_/docs/source/api.rst new file mode 100644 index 00000000000..d325d24f4a4 --- /dev/null +++ b/xarray/datatree_/docs/source/api.rst @@ -0,0 +1,362 @@ +.. currentmodule:: datatree + +############# +API reference +############# + +DataTree +======== + +Creating a DataTree +------------------- + +Methods of creating a datatree. + +.. autosummary:: + :toctree: generated/ + + DataTree + DataTree.from_dict + +Tree Attributes +--------------- + +Attributes relating to the recursive tree-like structure of a ``DataTree``. + +.. autosummary:: + :toctree: generated/ + + DataTree.parent + DataTree.children + DataTree.name + DataTree.path + DataTree.root + DataTree.is_root + DataTree.is_leaf + DataTree.leaves + DataTree.level + DataTree.depth + DataTree.width + DataTree.subtree + DataTree.descendants + DataTree.siblings + DataTree.lineage + DataTree.parents + DataTree.ancestors + DataTree.groups + +Data Contents +------------- + +Interface to the data objects (optionally) stored inside a single ``DataTree`` node. +This interface echoes that of ``xarray.Dataset``. + +.. autosummary:: + :toctree: generated/ + + DataTree.dims + DataTree.sizes + DataTree.data_vars + DataTree.coords + DataTree.attrs + DataTree.encoding + DataTree.indexes + DataTree.nbytes + DataTree.ds + DataTree.to_dataset + DataTree.has_data + DataTree.has_attrs + DataTree.is_empty + DataTree.is_hollow + +Dictionary Interface +-------------------- + +``DataTree`` objects also have a dict-like interface mapping keys to either ``xarray.DataArray``s or to child ``DataTree`` nodes. + +.. autosummary:: + :toctree: generated/ + + DataTree.__getitem__ + DataTree.__setitem__ + DataTree.__delitem__ + DataTree.update + DataTree.get + DataTree.items + DataTree.keys + DataTree.values + +Tree Manipulation +----------------- + +For manipulating, traversing, navigating, or mapping over the tree structure. + +.. autosummary:: + :toctree: generated/ + + DataTree.orphan + DataTree.same_tree + DataTree.relative_to + DataTree.iter_lineage + DataTree.find_common_ancestor + DataTree.map_over_subtree + map_over_subtree + DataTree.pipe + DataTree.match + DataTree.filter + +Pathlib-like Interface +---------------------- + +``DataTree`` objects deliberately echo some of the API of `pathlib.PurePath`. + +.. autosummary:: + :toctree: generated/ + + DataTree.name + DataTree.parent + DataTree.parents + DataTree.relative_to + +Missing: + +.. + + ``DataTree.glob`` + ``DataTree.joinpath`` + ``DataTree.with_name`` + ``DataTree.walk`` + ``DataTree.rename`` + ``DataTree.replace`` + +DataTree Contents +----------------- + +Manipulate the contents of all nodes in a tree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.copy + DataTree.assign_coords + DataTree.merge + DataTree.rename + DataTree.rename_vars + DataTree.rename_dims + DataTree.swap_dims + DataTree.expand_dims + DataTree.drop_vars + DataTree.drop_dims + DataTree.set_coords + DataTree.reset_coords + +DataTree Node Contents +---------------------- + +Manipulate the contents of a single DataTree node. + +.. autosummary:: + :toctree: generated/ + + DataTree.assign + DataTree.drop_nodes + +Comparisons +=========== + +Compare one ``DataTree`` object to another. + +.. autosummary:: + :toctree: generated/ + + DataTree.isomorphic + DataTree.equals + DataTree.identical + +Indexing +======== + +Index into all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.isel + DataTree.sel + DataTree.drop_sel + DataTree.drop_isel + DataTree.head + DataTree.tail + DataTree.thin + DataTree.squeeze + DataTree.interp + DataTree.interp_like + DataTree.reindex + DataTree.reindex_like + DataTree.set_index + DataTree.reset_index + DataTree.reorder_levels + DataTree.query + +.. + + Missing: + ``DataTree.loc`` + + +Missing Value Handling +====================== + +.. autosummary:: + :toctree: generated/ + + DataTree.isnull + DataTree.notnull + DataTree.combine_first + DataTree.dropna + DataTree.fillna + DataTree.ffill + DataTree.bfill + DataTree.interpolate_na + DataTree.where + DataTree.isin + +Computation +=========== + +Apply a computation to the data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.map + DataTree.reduce + DataTree.diff + DataTree.quantile + DataTree.differentiate + DataTree.integrate + DataTree.map_blocks + DataTree.polyfit + DataTree.curvefit + +Aggregation +=========== + +Aggregate data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + DataTree.all + DataTree.any + DataTree.argmax + DataTree.argmin + DataTree.idxmax + DataTree.idxmin + DataTree.max + DataTree.min + DataTree.mean + DataTree.median + DataTree.prod + DataTree.sum + DataTree.std + DataTree.var + DataTree.cumsum + DataTree.cumprod + +ndarray methods +=============== + +Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data in all nodes in the subtree. + +.. autosummary:: + :toctree: generated/ + + DataTree.argsort + DataTree.astype + DataTree.clip + DataTree.conj + DataTree.conjugate + DataTree.round + DataTree.rank + +Reshaping and reorganising +========================== + +Reshape or reorganise the data in all nodes in the subtree. + +.. autosummary:: + :toctree: generated/ + + DataTree.transpose + DataTree.stack + DataTree.unstack + DataTree.shift + DataTree.roll + DataTree.pad + DataTree.sortby + DataTree.broadcast_like + +Plotting +======== + +I/O +=== + +Open a datatree from an on-disk store or serialize the tree. + +.. autosummary:: + :toctree: generated/ + + open_datatree + DataTree.to_dict + DataTree.to_netcdf + DataTree.to_zarr + +.. + + Missing: + ``open_mfdatatree`` + +Tutorial +======== + +Testing +======= + +Test that two DataTree objects are similar. + +.. autosummary:: + :toctree: generated/ + + testing.assert_isomorphic + testing.assert_equal + testing.assert_identical + +Exceptions +========== + +Exceptions raised when manipulating trees. + +.. autosummary:: + :toctree: generated/ + + TreeIsomorphismError + InvalidTreeError + NotFoundInTreeError + +Advanced API +============ + +Relatively advanced API for users or developers looking to understand the internals, or extend functionality. + +.. autosummary:: + :toctree: generated/ + + DataTree.variables + register_datatree_accessor + +.. + + Missing: + ``DataTree.set_close`` diff --git a/xarray/datatree_/docs/source/conf.py b/xarray/datatree_/docs/source/conf.py new file mode 100644 index 00000000000..8a9224def5b --- /dev/null +++ b/xarray/datatree_/docs/source/conf.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# flake8: noqa +# Ignoring F401: imported but unused + +# complexity documentation build configuration file, created by +# sphinx-quickstart on Tue Jul 9 22:26:36 2013. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import inspect +import os +import sys + +import sphinx_autosummary_accessors + +import datatree + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# sys.path.insert(0, os.path.abspath('.')) + +cwd = os.getcwd() +parent = os.path.dirname(cwd) +sys.path.insert(0, parent) + + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.linkcode", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.napoleon", + "sphinx_copybutton", + "sphinxext.opengraph", + "sphinx_autosummary_accessors", + "IPython.sphinxext.ipython_console_highlighting", + "IPython.sphinxext.ipython_directive", + "nbsphinx", + "sphinxcontrib.srclinks", +] + +extlinks = { + "issue": ("https://github.com/xarray-contrib/datatree/issues/%s", "GH#%s"), + "pull": ("https://github.com/xarray-contrib/datatree/pull/%s", "GH#%s"), +} +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] + +# Generate the API documentation when building +autosummary_generate = True + + +# Napoleon configurations + +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_use_param = False +napoleon_use_rtype = False +napoleon_preprocess_types = True +napoleon_type_aliases = { + # general terms + "sequence": ":term:`sequence`", + "iterable": ":term:`iterable`", + "callable": ":py:func:`callable`", + "dict_like": ":term:`dict-like `", + "dict-like": ":term:`dict-like `", + "path-like": ":term:`path-like `", + "mapping": ":term:`mapping`", + "file-like": ":term:`file-like `", + # special terms + # "same type as caller": "*same type as caller*", # does not work, yet + # "same type as values": "*same type as values*", # does not work, yet + # stdlib type aliases + "MutableMapping": "~collections.abc.MutableMapping", + "sys.stdout": ":obj:`sys.stdout`", + "timedelta": "~datetime.timedelta", + "string": ":class:`string `", + # numpy terms + "array_like": ":term:`array_like`", + "array-like": ":term:`array-like `", + "scalar": ":term:`scalar`", + "array": ":term:`array`", + "hashable": ":term:`hashable `", + # matplotlib terms + "color-like": ":py:func:`color-like `", + "matplotlib colormap name": ":doc:`matplotlib colormap name `", + "matplotlib axes object": ":py:class:`matplotlib axes object `", + "colormap": ":py:class:`colormap `", + # objects without namespace: xarray + "DataArray": "~xarray.DataArray", + "Dataset": "~xarray.Dataset", + "Variable": "~xarray.Variable", + "DatasetGroupBy": "~xarray.core.groupby.DatasetGroupBy", + "DataArrayGroupBy": "~xarray.core.groupby.DataArrayGroupBy", + # objects without namespace: numpy + "ndarray": "~numpy.ndarray", + "MaskedArray": "~numpy.ma.MaskedArray", + "dtype": "~numpy.dtype", + "ComplexWarning": "~numpy.ComplexWarning", + # objects without namespace: pandas + "Index": "~pandas.Index", + "MultiIndex": "~pandas.MultiIndex", + "CategoricalIndex": "~pandas.CategoricalIndex", + "TimedeltaIndex": "~pandas.TimedeltaIndex", + "DatetimeIndex": "~pandas.DatetimeIndex", + "Series": "~pandas.Series", + "DataFrame": "~pandas.DataFrame", + "Categorical": "~pandas.Categorical", + "Path": "~~pathlib.Path", + # objects with abbreviated namespace (from pandas) + "pd.Index": "~pandas.Index", + "pd.NaT": "~pandas.NaT", +} + +# The suffix of source filenames. +source_suffix = ".rst" + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "Datatree" +copyright = "2021 onwards, Tom Nicholas and its Contributors" +author = "Tom Nicholas" + +html_show_sourcelink = True +srclink_project = "https://github.com/xarray-contrib/datatree" +srclink_branch = "main" +srclink_src_path = "docs/source" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = datatree.__version__ +# The full version, including alpha/beta/rc tags. +release = datatree.__version__ + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build"] + +# The reST default role (used for this markup: `text`) to use for all documents. +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + + +# -- Intersphinx links --------------------------------------------------------- + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.8/", None), + "numpy": ("https://numpy.org/doc/stable", None), + "xarray": ("https://xarray.pydata.org/en/stable/", None), +} + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "sphinx_book_theme" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + "repository_url": "https://github.com/xarray-contrib/datatree", + "repository_branch": "main", + "path_to_docs": "docs/source", + "use_repository_button": True, + "use_issues_button": True, + "use_edit_page_button": True, +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = "datatree_doc" + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ("index", "datatree.tex", "Datatree Documentation", author, "manual") +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [("index", "datatree", "Datatree Documentation", [author], 1)] + +# If true, show URL addresses after external links. +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + "index", + "datatree", + "Datatree Documentation", + author, + "datatree", + "Tree-like hierarchical data structure for xarray.", + "Miscellaneous", + ) +] + +# Documents to append as an appendix to all manuals. +# texinfo_appendices = [] + +# If false, no module index is generated. +# texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +# texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +# texinfo_no_detailmenu = False + + +# based on numpy doc/source/conf.py +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except AttributeError: + return None + + try: + fn = inspect.getsourcefile(inspect.unwrap(obj)) + except TypeError: + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except OSError: + lineno = None + + if lineno: + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" + else: + linespec = "" + + fn = os.path.relpath(fn, start=os.path.dirname(datatree.__file__)) + + if "+" in datatree.__version__: + return f"https://github.com/xarray-contrib/datatree/blob/main/datatree/{fn}{linespec}" + else: + return ( + f"https://github.com/xarray-contrib/datatree/blob/" + f"v{datatree.__version__}/datatree/{fn}{linespec}" + ) diff --git a/xarray/datatree_/docs/source/contributing.rst b/xarray/datatree_/docs/source/contributing.rst new file mode 100644 index 00000000000..b070c07c867 --- /dev/null +++ b/xarray/datatree_/docs/source/contributing.rst @@ -0,0 +1,136 @@ +======================== +Contributing to Datatree +======================== + +Contributions are highly welcomed and appreciated. Every little help counts, +so do not hesitate! + +.. contents:: Contribution links + :depth: 2 + +.. _submitfeedback: + +Feature requests and feedback +----------------------------- + +Do you like Datatree? Share some love on Twitter or in your blog posts! + +We'd also like to hear about your propositions and suggestions. Feel free to +`submit them as issues `_ and: + +* Explain in detail how they should work. +* Keep the scope as narrow as possible. This will make it easier to implement. + +.. _reportbugs: + +Report bugs +----------- + +Report bugs for Datatree in the `issue tracker `_. + +If you are reporting a bug, please include: + +* Your operating system name and version. +* Any details about your local setup that might be helpful in troubleshooting, + specifically the Python interpreter version, installed libraries, and Datatree + version. +* Detailed steps to reproduce the bug. + +If you can write a demonstration test that currently fails but should pass +(xfail), that is a very useful commit to make as well, even if you cannot +fix the bug itself. + +.. _fixbugs: + +Fix bugs +-------- + +Look through the `GitHub issues for bugs `_. + +Talk to developers to find out how you can fix specific bugs. + +Write documentation +------------------- + +Datatree could always use more documentation. What exactly is needed? + +* More complementary documentation. Have you perhaps found something unclear? +* Docstrings. There can never be too many of them. +* Blog posts, articles and such -- they're all very appreciated. + +You can also edit documentation files directly in the GitHub web interface, +without using a local copy. This can be convenient for small fixes. + +To build the documentation locally, you first need to install the following +tools: + +- `Sphinx `__ +- `sphinx_rtd_theme `__ +- `sphinx-autosummary-accessors `__ + +You can then build the documentation with the following commands:: + + $ cd docs + $ make html + +The built documentation should be available in the ``docs/_build/`` folder. + +.. _`pull requests`: +.. _pull-requests: + +Preparing Pull Requests +----------------------- + +#. Fork the + `Datatree GitHub repository `__. It's + fine to use ``Datatree`` as your fork repository name because it will live + under your user. + +#. Clone your fork locally using `git `_ and create a branch:: + + $ git clone git@github.com:{YOUR_GITHUB_USERNAME}/Datatree.git + $ cd Datatree + + # now, to fix a bug or add feature create your own branch off "master": + + $ git checkout -b your-bugfix-feature-branch-name master + +#. Install `pre-commit `_ and its hook on the Datatree repo:: + + $ pip install --user pre-commit + $ pre-commit install + + Afterwards ``pre-commit`` will run whenever you commit. + + https://pre-commit.com/ is a framework for managing and maintaining multi-language pre-commit hooks + to ensure code-style and code formatting is consistent. + +#. Install dependencies into a new conda environment:: + + $ conda env update -f ci/environment.yml + +#. Run all the tests + + Now running tests is as simple as issuing this command:: + + $ conda activate datatree-dev + $ pytest --junitxml=test-reports/junit.xml --cov=./ --verbose + + This command will run tests via the "pytest" tool. + +#. You can now edit your local working copy and run the tests again as necessary. Please follow PEP-8 for naming. + + When committing, ``pre-commit`` will re-format the files if necessary. + +#. Commit and push once your tests pass and you are happy with your change(s):: + + $ git commit -a -m "" + $ git push -u + +#. Finally, submit a pull request through the GitHub website using this data:: + + head-fork: YOUR_GITHUB_USERNAME/Datatree + compare: your-branch-name + + base-fork: TomNicholas/datatree + base: master diff --git a/xarray/datatree_/docs/source/data-structures.rst b/xarray/datatree_/docs/source/data-structures.rst new file mode 100644 index 00000000000..02e4a31f688 --- /dev/null +++ b/xarray/datatree_/docs/source/data-structures.rst @@ -0,0 +1,197 @@ +.. currentmodule:: datatree + +.. _data structures: + +Data Structures +=============== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + import datatree + + np.random.seed(123456) + np.set_printoptions(threshold=10) + + %xmode minimal + +.. note:: + + This page builds on the information given in xarray's main page on + `data structures `_, so it is suggested that you + are familiar with those first. + +DataTree +-------- + +:py:class:`DataTree` is xarray's highest-level data structure, able to organise heterogeneous data which +could not be stored inside a single :py:class:`Dataset` object. This includes representing the recursive structure of multiple +`groups`_ within a netCDF file or `Zarr Store`_. + +.. _groups: https://www.unidata.ucar.edu/software/netcdf/workshops/2011/groups-types/GroupsIntro.html +.. _Zarr Store: https://zarr.readthedocs.io/en/stable/tutorial.html#groups + +Each ``DataTree`` object (or "node") contains the same data that a single ``xarray.Dataset`` would (i.e. ``DataArray`` objects +stored under hashable keys), and so has the same key properties: + +- ``dims``: a dictionary mapping of dimension names to lengths, for the variables in this node, +- ``data_vars``: a dict-like container of DataArrays corresponding to variables in this node, +- ``coords``: another dict-like container of DataArrays, corresponding to coordinate variables in this node, +- ``attrs``: dict to hold arbitary metadata relevant to data in this node. + +A single ``DataTree`` object acts much like a single ``Dataset`` object, and has a similar set of dict-like methods +defined upon it. However, ``DataTree``'s can also contain other ``DataTree`` objects, so they can be thought of as nested dict-like +containers of both ``xarray.DataArray``'s and ``DataTree``'s. + +A single datatree object is known as a "node", and its position relative to other nodes is defined by two more key +properties: + +- ``children``: An ordered dictionary mapping from names to other ``DataTree`` objects, known as its' "child nodes". +- ``parent``: The single ``DataTree`` object whose children this datatree is a member of, known as its' "parent node". + +Each child automatically knows about its parent node, and a node without a parent is known as a "root" node +(represented by the ``parent`` attribute pointing to ``None``). +Nodes can have multiple children, but as each child node has at most one parent, there can only ever be one root node in a given tree. + +The overall structure is technically a `connected acyclic undirected rooted graph`, otherwise known as a +`"Tree" `_. + +.. note:: + + Technically a ``DataTree`` with more than one child node forms an `"Ordered Tree" `_, + because the children are stored in an Ordered Dictionary. However, this distinction only really matters for a few + edge cases involving operations on multiple trees simultaneously, and can safely be ignored by most users. + + +``DataTree`` objects can also optionally have a ``name`` as well as ``attrs``, just like a ``DataArray``. +Again these are not normally used unless explicitly accessed by the user. + + +.. _creating a datatree: + +Creating a DataTree +~~~~~~~~~~~~~~~~~~~ + +One way to create a ``DataTree`` from scratch is to create each node individually, +specifying the nodes' relationship to one another as you create each one. + +The ``DataTree`` constructor takes: + +- ``data``: The data that will be stored in this node, represented by a single ``xarray.Dataset``, or a named ``xarray.DataArray``. +- ``parent``: The parent node (if there is one), given as a ``DataTree`` object. +- ``children``: The various child nodes (if there are any), given as a mapping from string keys to ``DataTree`` objects. +- ``name``: A string to use as the name of this node. + +Let's make a single datatree node with some example data in it: + +.. ipython:: python + + from datatree import DataTree + + ds1 = xr.Dataset({"foo": "orange"}) + dt = DataTree(name="root", data=ds1) # create root node + + dt + +At this point our node is also the root node, as every tree has a root node. + +We can add a second node to this tree either by referring to the first node in the constructor of the second: + +.. ipython:: python + + ds2 = xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])}) + # add a child by referring to the parent node + node2 = DataTree(name="a", parent=dt, data=ds2) + +or by dynamically updating the attributes of one node to refer to another: + +.. ipython:: python + + # add a second child by first creating a new node ... + ds3 = xr.Dataset({"zed": np.NaN}) + node3 = DataTree(name="b", data=ds3) + # ... then updating its .parent property + node3.parent = dt + +Our tree now has three nodes within it: + +.. ipython:: python + + dt + +It is at tree construction time that consistency checks are enforced. For instance, if we try to create a `cycle` the constructor will raise an error: + +.. ipython:: python + :okexcept: + + dt.parent = node3 + +Alternatively you can also create a ``DataTree`` object from + +- An ``xarray.Dataset`` using ``Dataset.to_node()`` (not yet implemented), +- A dictionary mapping directory-like paths to either ``DataTree`` nodes or data, using :py:meth:`DataTree.from_dict()`, +- A netCDF or Zarr file on disk with :py:func:`open_datatree()`. See :ref:`reading and writing files `. + + +DataTree Contents +~~~~~~~~~~~~~~~~~ + +Like ``xarray.Dataset``, ``DataTree`` implements the python mapping interface, but with values given by either ``xarray.DataArray`` objects or other ``DataTree`` objects. + +.. ipython:: python + + dt["a"] + dt["foo"] + +Iterating over keys will iterate over both the names of variables and child nodes. + +We can also access all the data in a single node through a dataset-like view + +.. ipython:: python + + dt["a"].ds + +This demonstrates the fact that the data in any one node is equivalent to the contents of a single ``xarray.Dataset`` object. +The ``DataTree.ds`` property returns an immutable view, but we can instead extract the node's data contents as a new (and mutable) +``xarray.Dataset`` object via :py:meth:`DataTree.to_dataset()`: + +.. ipython:: python + + dt["a"].to_dataset() + +Like with ``Dataset``, you can access the data and coordinate variables of a node separately via the ``data_vars`` and ``coords`` attributes: + +.. ipython:: python + + dt["a"].data_vars + dt["a"].coords + + +Dictionary-like methods +~~~~~~~~~~~~~~~~~~~~~~~ + +We can update a datatree in-place using Python's standard dictionary syntax, similar to how we can for Dataset objects. +For example, to create this example datatree from scratch, we could have written: + +# TODO update this example using ``.coords`` and ``.data_vars`` as setters, + +.. ipython:: python + + dt = DataTree(name="root") + dt["foo"] = "orange" + dt["a"] = DataTree(data=xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])})) + dt["a/b/zed"] = np.NaN + dt + +To change the variables in a node of a ``DataTree``, you can use all the standard dictionary +methods, including ``values``, ``items``, ``__delitem__``, ``get`` and +:py:meth:`DataTree.update`. +Note that assigning a ``DataArray`` object to a ``DataTree`` variable using ``__setitem__`` or ``update`` will +:ref:`automatically align ` the array(s) to the original node's indexes. + +If you copy a ``DataTree`` using the :py:func:`copy` function or the :py:meth:`DataTree.copy` method it will copy the subtree, +meaning that node and children below it, but no parents above it. +Like for ``Dataset``, this copy is shallow by default, but you can copy all the underlying data arrays by calling ``dt.copy(deep=True)``. diff --git a/xarray/datatree_/docs/source/hierarchical-data.rst b/xarray/datatree_/docs/source/hierarchical-data.rst new file mode 100644 index 00000000000..d4f58847718 --- /dev/null +++ b/xarray/datatree_/docs/source/hierarchical-data.rst @@ -0,0 +1,639 @@ +.. currentmodule:: datatree + +.. _hierarchical-data: + +Working With Hierarchical Data +============================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + from datatree import DataTree + + np.random.seed(123456) + np.set_printoptions(threshold=10) + + %xmode minimal + +Why Hierarchical Data? +---------------------- + +Many real-world datasets are composed of multiple differing components, +and it can often be be useful to think of these in terms of a hierarchy of related groups of data. +Examples of data which one might want organise in a grouped or hierarchical manner include: + +- Simulation data at multiple resolutions, +- Observational data about the same system but from multiple different types of sensors, +- Mixed experimental and theoretical data, +- A systematic study recording the same experiment but with different parameters, +- Heterogenous data, such as demographic and metereological data, + +or even any combination of the above. + +Often datasets like this cannot easily fit into a single :py:class:`xarray.Dataset` object, +or are more usefully thought of as groups of related ``xarray.Dataset`` objects. +For this purpose we provide the :py:class:`DataTree` class. + +This page explains in detail how to understand and use the different features of the :py:class:`DataTree` class for your own hierarchical data needs. + +.. _node relationships: + +Node Relationships +------------------ + +.. _creating a family tree: + +Creating a Family Tree +~~~~~~~~~~~~~~~~~~~~~~ + +The three main ways of creating a ``DataTree`` object are described briefly in :ref:`creating a datatree`. +Here we go into more detail about how to create a tree node-by-node, using a famous family tree from the Simpsons cartoon as an example. + +Let's start by defining nodes representing the two siblings, Bart and Lisa Simpson: + +.. ipython:: python + + bart = DataTree(name="Bart") + lisa = DataTree(name="Lisa") + +Each of these node objects knows their own :py:class:`~DataTree.name`, but they currently have no relationship to one another. +We can connect them by creating another node representing a common parent, Homer Simpson: + +.. ipython:: python + + homer = DataTree(name="Homer", children={"Bart": bart, "Lisa": lisa}) + +Here we set the children of Homer in the node's constructor. +We now have a small family tree + +.. ipython:: python + + homer + +where we can see how these individual Simpson family members are related to one another. +The nodes representing Bart and Lisa are now connected - we can confirm their sibling rivalry by examining the :py:class:`~DataTree.siblings` property: + +.. ipython:: python + + list(bart.siblings) + +But oops, we forgot Homer's third daughter, Maggie! Let's add her by updating Homer's :py:class:`~DataTree.children` property to include her: + +.. ipython:: python + + maggie = DataTree(name="Maggie") + homer.children = {"Bart": bart, "Lisa": lisa, "Maggie": maggie} + homer + +Let's check that Maggie knows who her Dad is: + +.. ipython:: python + + maggie.parent.name + +That's good - updating the properties of our nodes does not break the internal consistency of our tree, as changes of parentage are automatically reflected on both nodes. + + These children obviously have another parent, Marge Simpson, but ``DataTree`` nodes can only have a maximum of one parent. + Genealogical `family trees are not even technically trees `_ in the mathematical sense - + the fact that distant relatives can mate makes it a directed acyclic graph. + Trees of ``DataTree`` objects cannot represent this. + +Homer is currently listed as having no parent (the so-called "root node" of this tree), but we can update his :py:class:`~DataTree.parent` property: + +.. ipython:: python + + abe = DataTree(name="Abe") + homer.parent = abe + +Abe is now the "root" of this tree, which we can see by examining the :py:class:`~DataTree.root` property of any node in the tree + +.. ipython:: python + + maggie.root.name + +We can see the whole tree by printing Abe's node or just part of the tree by printing Homer's node: + +.. ipython:: python + + abe + homer + +We can see that Homer is aware of his parentage, and we say that Homer and his children form a "subtree" of the larger Simpson family tree. + +In episode 28, Abe Simpson reveals that he had another son, Herbert "Herb" Simpson. +We can add Herbert to the family tree without displacing Homer by :py:meth:`~DataTree.assign`-ing another child to Abe: + +.. ipython:: python + + herbert = DataTree(name="Herb") + abe.assign({"Herbert": herbert}) + +.. note:: + This example shows a minor subtlety - the returned tree has Homer's brother listed as ``"Herbert"``, + but the original node was named "Herbert". Not only are names overriden when stored as keys like this, + but the new node is a copy, so that the original node that was reference is unchanged (i.e. ``herbert.name == "Herb"`` still). + In other words, nodes are copied into trees, not inserted into them. + This is intentional, and mirrors the behaviour when storing named ``xarray.DataArray`` objects inside datasets. + +Certain manipulations of our tree are forbidden, if they would create an inconsistent result. +In episode 51 of the show Futurama, Philip J. Fry travels back in time and accidentally becomes his own Grandfather. +If we try similar time-travelling hijinks with Homer, we get a :py:class:`InvalidTreeError` raised: + +.. ipython:: python + :okexcept: + + abe.parent = homer + +.. _evolutionary tree: + +Ancestry in an Evolutionary Tree +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let's use a different example of a tree to discuss more complex relationships between nodes - the phylogenetic tree, or tree of life. + +.. ipython:: python + + vertebrates = DataTree.from_dict( + name="Vertebrae", + d={ + "/Sharks": None, + "/Bony Skeleton/Ray-finned Fish": None, + "/Bony Skeleton/Four Limbs/Amphibians": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Primates": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Rodents & Rabbits": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Dinosaurs": None, + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Birds": None, + }, + ) + + primates = vertebrates["/Bony Skeleton/Four Limbs/Amniotic Egg/Hair/Primates"] + dinosaurs = vertebrates[ + "/Bony Skeleton/Four Limbs/Amniotic Egg/Two Fenestrae/Dinosaurs" + ] + +We have used the :py:meth:`~DataTree.from_dict` constructor method as an alternate way to quickly create a whole tree, +and :ref:`filesystem paths` (to be explained shortly) to select two nodes of interest. + +.. ipython:: python + + vertebrates + +This tree shows various families of species, grouped by their common features (making it technically a `"Cladogram" `_, +rather than an evolutionary tree). + +Here both the species and the features used to group them are represented by ``DataTree`` node objects - there is no distinction in types of node. +We can however get a list of only the nodes we used to represent species by using the fact that all those nodes have no children - they are "leaf nodes". +We can check if a node is a leaf with :py:meth:`~DataTree.is_leaf`, and get a list of all leaves with the :py:class:`~DataTree.leaves` property: + +.. ipython:: python + + primates.is_leaf + [node.name for node in vertebrates.leaves] + +Pretending that this is a true evolutionary tree for a moment, we can find the features of the evolutionary ancestors (so-called "ancestor" nodes), +the distinguishing feature of the common ancestor of all vertebrate life (the root node), +and even the distinguishing feature of the common ancestor of any two species (the common ancestor of two nodes): + +.. ipython:: python + + [node.name for node in primates.ancestors] + primates.root.name + primates.find_common_ancestor(dinosaurs).name + +We can only find a common ancestor between two nodes that lie in the same tree. +If we try to find the common evolutionary ancestor between primates and an Alien species that has no relationship to Earth's evolutionary tree, +an error will be raised. + +.. ipython:: python + :okexcept: + + alien = DataTree(name="Xenomorph") + primates.find_common_ancestor(alien) + + +.. _navigating trees: + +Navigating Trees +---------------- + +There are various ways to access the different nodes in a tree. + +Properties +~~~~~~~~~~ + +We can navigate trees using the :py:class:`~DataTree.parent` and :py:class:`~DataTree.children` properties of each node, for example: + +.. ipython:: python + + lisa.parent.children["Bart"].name + +but there are also more convenient ways to access nodes. + +Dictionary-like interface +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Children are stored on each node as a key-value mapping from name to child node. +They can be accessed and altered via the :py:class:`~DataTree.__getitem__` and :py:class:`~DataTree.__setitem__` syntax. +In general :py:class:`~DataTree.DataTree` objects support almost the entire set of dict-like methods, +including :py:meth:`~DataTree.keys`, :py:class:`~DataTree.values`, :py:class:`~DataTree.items`, +:py:meth:`~DataTree.__delitem__` and :py:meth:`~DataTree.update`. + +.. ipython:: python + + vertebrates["Bony Skeleton"]["Ray-finned Fish"] + +Note that the dict-like interface combines access to child ``DataTree`` nodes and stored ``DataArrays``, +so if we have a node that contains both children and data, calling :py:meth:`~DataTree.keys` will list both names of child nodes and +names of data variables: + +.. ipython:: python + + dt = DataTree( + data=xr.Dataset({"foo": 0, "bar": 1}), + children={"a": DataTree(), "b": DataTree()}, + ) + print(dt) + list(dt.keys()) + +This also means that the names of variables and of child nodes must be different to one another. + +Attribute-like access +~~~~~~~~~~~~~~~~~~~~~ + +You can also select both variables and child nodes through dot indexing + +.. ipython:: python + + dt.foo + dt.a + +.. _filesystem paths: + +Filesystem-like Paths +~~~~~~~~~~~~~~~~~~~~~ + +Hierarchical trees can be thought of as analogous to file systems. +Each node is like a directory, and each directory can contain both more sub-directories and data. + +.. note:: + + You can even make the filesystem analogy concrete by using :py:func:`~DataTree.open_mfdatatree` or :py:func:`~DataTree.save_mfdatatree` # TODO not yet implemented - see GH issue 51 + +Datatree objects support a syntax inspired by unix-like filesystems, +where the "path" to a node is specified by the keys of each intermediate node in sequence, +separated by forward slashes. +This is an extension of the conventional dictionary ``__getitem__`` syntax to allow navigation across multiple levels of the tree. + +Like with filepaths, paths within the tree can either be relative to the current node, e.g. + +.. ipython:: python + + abe["Homer/Bart"].name + abe["./Homer/Bart"].name # alternative syntax + +or relative to the root node. +A path specified from the root (as opposed to being specified relative to an arbitrary node in the tree) is sometimes also referred to as a +`"fully qualified name" `_, +or as an "absolute path". +The root node is referred to by ``"/"``, so the path from the root node to its grand-child would be ``"/child/grandchild"``, e.g. + +.. ipython:: python + + # absolute path will start from root node + lisa["/Homer/Bart"].name + +Relative paths between nodes also support the ``"../"`` syntax to mean the parent of the current node. +We can use this with ``__setitem__`` to add a missing entry to our evolutionary tree, but add it relative to a more familiar node of interest: + +.. ipython:: python + + primates["../../Two Fenestrae/Crocodiles"] = DataTree() + print(vertebrates) + +Given two nodes in a tree, we can also find their relative path: + +.. ipython:: python + + bart.relative_to(lisa) + +You can use this filepath feature to build a nested tree from a dictionary of filesystem-like paths and corresponding ``xarray.Dataset`` objects in a single step. +If we have a dictionary where each key is a valid path, and each value is either valid data or ``None``, +we can construct a complex tree quickly using the alternative constructor :py:meth:`DataTree.from_dict()`: + +.. ipython:: python + + d = { + "/": xr.Dataset({"foo": "orange"}), + "/a": xr.Dataset({"bar": 0}, coords={"y": ("y", [0, 1, 2])}), + "/a/b": xr.Dataset({"zed": np.NaN}), + "a/c/d": None, + } + dt = DataTree.from_dict(d) + dt + +.. note:: + + Notice that using the path-like syntax will also create any intermediate empty nodes necessary to reach the end of the specified path + (i.e. the node labelled `"c"` in this case.) + This is to help avoid lots of redundant entries when creating deeply-nested trees using :py:meth:`DataTree.from_dict`. + +.. _iterating over trees: + +Iterating over trees +~~~~~~~~~~~~~~~~~~~~ + +You can iterate over every node in a tree using the subtree :py:class:`~DataTree.subtree` property. +This returns an iterable of nodes, which yields them in depth-first order. + +.. ipython:: python + + for node in vertebrates.subtree: + print(node.path) + +A very useful pattern is to use :py:class:`~DataTree.subtree` conjunction with the :py:class:`~DataTree.path` property to manipulate the nodes however you wish, +then rebuild a new tree using :py:meth:`DataTree.from_dict()`. + +For example, we could keep only the nodes containing data by looping over all nodes, +checking if they contain any data using :py:class:`~DataTree.has_data`, +then rebuilding a new tree using only the paths of those nodes: + +.. ipython:: python + + non_empty_nodes = {node.path: node.ds for node in dt.subtree if node.has_data} + DataTree.from_dict(non_empty_nodes) + +You can see this tree is similar to the ``dt`` object above, except that it is missing the empty nodes ``a/c`` and ``a/c/d``. + +(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.root.name)``.) + +.. _manipulating trees: + +Manipulating Trees +------------------ + +Subsetting Tree Nodes +~~~~~~~~~~~~~~~~~~~~~ + +We can subset our tree to select only nodes of interest in various ways. + +Similarly to on a real filesystem, matching nodes by common patterns in their paths is often useful. +We can use :py:meth:`DataTree.match` for this: + +.. ipython:: python + + dt = DataTree.from_dict( + { + "/a/A": None, + "/a/B": None, + "/b/A": None, + "/b/B": None, + } + ) + result = dt.match("*/B") + result + +We can also subset trees by the contents of the nodes. +:py:meth:`DataTree.filter` retains only the nodes of a tree that meet a certain condition. +For example, we could recreate the Simpson's family tree with the ages of each individual, then filter for only the adults: +First lets recreate the tree but with an `age` data variable in every node: + +.. ipython:: python + + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + simpsons + +Now let's filter out the minors: + +.. ipython:: python + + simpsons.filter(lambda node: node["age"] > 18) + +The result is a new tree, containing only the nodes matching the condition. + +(Yes, under the hood :py:meth:`~DataTree.filter` is just syntactic sugar for the pattern we showed you in :ref:`iterating over trees` !) + +.. _Tree Contents: + +Tree Contents +------------- + +Hollow Trees +~~~~~~~~~~~~ + +A concept that can sometimes be useful is that of a "Hollow Tree", which means a tree with data stored only at the leaf nodes. +This is useful because certain useful tree manipulation operations only make sense for hollow trees. + +You can check if a tree is a hollow tree by using the :py:class:`~DataTree.is_hollow` property. +We can see that the Simpson's family is not hollow because the data variable ``"age"`` is present at some nodes which +have children (i.e. Abe and Homer). + +.. ipython:: python + + simpsons.is_hollow + +.. _tree computation: + +Computation +----------- + +`DataTree` objects are also useful for performing computations, not just for organizing data. + +Operations and Methods on Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To show how applying operations across a whole tree at once can be useful, +let's first create a example scientific dataset. + +.. ipython:: python + + def time_stamps(n_samples, T): + """Create an array of evenly-spaced time stamps""" + return xr.DataArray( + data=np.linspace(0, 2 * np.pi * T, n_samples), dims=["time"] + ) + + + def signal_generator(t, f, A, phase): + """Generate an example electrical-like waveform""" + return A * np.sin(f * t.data + phase) + + + time_stamps1 = time_stamps(n_samples=15, T=1.5) + time_stamps2 = time_stamps(n_samples=10, T=1.0) + + voltages = DataTree.from_dict( + { + "/oscilloscope1": xr.Dataset( + { + "potential": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=0.5), + ), + "current": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=1), + ), + }, + coords={"time": time_stamps1}, + ), + "/oscilloscope2": xr.Dataset( + { + "potential": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.2), + ), + "current": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7), + ), + }, + coords={"time": time_stamps2}, + ), + } + ) + voltages + +Most xarray computation methods also exist as methods on datatree objects, +so you can for example take the mean value of these two timeseries at once: + +.. ipython:: python + + voltages.mean(dim="time") + +This works by mapping the standard :py:meth:`xarray.Dataset.mean()` method over the dataset stored in each node of the +tree one-by-one. + +The arguments passed to the method are used for every node, so the values of the arguments you pass might be valid for one node and invalid for another + +.. ipython:: python + :okexcept: + + voltages.isel(time=12) + +Notice that the error raised helpfully indicates which node of the tree the operation failed on. + +Arithmetic Methods on Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Arithmetic methods are also implemented, so you can e.g. add a scalar to every dataset in the tree at once. +For example, we can advance the timeline of the Simpsons by a decade just by + +.. ipython:: python + + simpsons + 10 + +See that the same change (fast-forwarding by adding 10 years to the age of each character) has been applied to every node. + +Mapping Custom Functions Over Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can map custom computation over each node in a tree using :py:meth:`DataTree.map_over_subtree`. +You can map any function, so long as it takes `xarray.Dataset` objects as one (or more) of the input arguments, +and returns one (or more) xarray datasets. + +.. note:: + + Functions passed to :py:func:`map_over_subtree` cannot alter nodes in-place. + Instead they must return new `xarray.Dataset` objects. + +For example, we can define a function to calculate the Root Mean Square of a timeseries + +.. ipython:: python + + def rms(signal): + return np.sqrt(np.mean(signal**2)) + +Then calculate the RMS value of these signals: + +.. ipython:: python + + voltages.map_over_subtree(rms) + +.. _multiple trees: + +We can also use the :py:func:`map_over_subtree` decorator to promote a function which accepts datasets into one which +accepts datatrees. + +Operating on Multiple Trees +--------------------------- + +The examples so far have involved mapping functions or methods over the nodes of a single tree, +but we can generalize this to mapping functions over multiple trees at once. + +Comparing Trees for Isomorphism +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For it to make sense to map a single non-unary function over the nodes of multiple trees at once, +each tree needs to have the same structure. Specifically two trees can only be considered similar, or "isomorphic", +if they have the same number of nodes, and each corresponding node has the same number of children. +We can check if any two trees are isomorphic using the :py:meth:`DataTree.isomorphic` method. + +.. ipython:: python + :okexcept: + + dt1 = DataTree.from_dict({"a": None, "a/b": None}) + dt2 = DataTree.from_dict({"a": None}) + dt1.isomorphic(dt2) + + dt3 = DataTree.from_dict({"a": None, "b": None}) + dt1.isomorphic(dt3) + + dt4 = DataTree.from_dict({"A": None, "A/B": xr.Dataset({"foo": 1})}) + dt1.isomorphic(dt4) + +If the trees are not isomorphic a :py:class:`~TreeIsomorphismError` will be raised. +Notice that corresponding tree nodes do not need to have the same name or contain the same data in order to be considered isomorphic. + +Arithmetic Between Multiple Trees +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Arithmetic operations like multiplication are binary operations, so as long as we have two isomorphic trees, +we can do arithmetic between them. + +.. ipython:: python + + currents = DataTree.from_dict( + { + "/oscilloscope1": xr.Dataset( + { + "current": ( + "time", + signal_generator(time_stamps1, f=2, A=1.2, phase=1), + ), + }, + coords={"time": time_stamps1}, + ), + "/oscilloscope2": xr.Dataset( + { + "current": ( + "time", + signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7), + ), + }, + coords={"time": time_stamps2}, + ), + } + ) + currents + + currents.isomorphic(voltages) + +We could use this feature to quickly calculate the electrical power in our signal, P=IV. + +.. ipython:: python + + power = currents * voltages + power diff --git a/xarray/datatree_/docs/source/index.rst b/xarray/datatree_/docs/source/index.rst new file mode 100644 index 00000000000..a88a5747ada --- /dev/null +++ b/xarray/datatree_/docs/source/index.rst @@ -0,0 +1,61 @@ +.. currentmodule:: datatree + +Datatree +======== + +**Datatree is a prototype implementation of a tree-like hierarchical data structure for xarray.** + +Why Datatree? +~~~~~~~~~~~~~ + +Datatree was born after the xarray team recognised a `need for a new hierarchical data structure `_, +that was more flexible than a single :py:class:`xarray.Dataset` object. +The initial motivation was to represent netCDF files / Zarr stores with multiple nested groups in a single in-memory object, +but :py:class:`~datatree.DataTree` objects have many other uses. + +You might want to use datatree for: + +- Organising many related datasets, e.g. results of the same experiment with different parameters, or simulations of the same system using different models, +- Analysing similar data at multiple resolutions simultaneously, such as when doing a convergence study, +- Comparing heterogenous but related data, such as experimental and theoretical data, +- I/O with nested data formats such as netCDF / Zarr groups. + +Development Roadmap +~~~~~~~~~~~~~~~~~~~ + +Datatree currently lives in a separate repository to the main xarray package. +This allows the datatree developers to make changes to it, experiment, and improve it faster. + +Eventually we plan to fully integrate datatree upstream into xarray's main codebase, at which point the `github.com/xarray-contrib/datatree `_ repository will be archived. +This should not cause much disruption to code that depends on datatree - you will likely only have to change the import line (i.e. from ``from datatree import DataTree`` to ``from xarray import DataTree``). + +However, until this full integration occurs, datatree's API should not be considered to have the same `level of stability as xarray's `_. + +User Feedback +~~~~~~~~~~~~~ + +We really really really want to hear your opinions on datatree! +At this point in development, user feedback is critical to help us create something that will suit everyone's needs. +Please raise any thoughts, issues, suggestions or bugs, no matter how small or large, on the `github issue tracker `_. + +.. toctree:: + :maxdepth: 2 + :caption: Documentation Contents + + Installation + Quick Overview + Tutorial + Data Model + Hierarchical Data + Reading and Writing Files + API Reference + Terminology + Contributing Guide + What's New + GitHub repository + +Feedback +-------- + +If you encounter any errors, problems with **Datatree**, or have any suggestions, please open an issue +on `GitHub `_. diff --git a/xarray/datatree_/docs/source/installation.rst b/xarray/datatree_/docs/source/installation.rst new file mode 100644 index 00000000000..b2682743ade --- /dev/null +++ b/xarray/datatree_/docs/source/installation.rst @@ -0,0 +1,38 @@ +.. currentmodule:: datatree + +============ +Installation +============ + +Datatree can be installed in three ways: + +Using the `conda `__ package manager that comes with the +Anaconda/Miniconda distribution: + +.. code:: bash + + $ conda install xarray-datatree --channel conda-forge + +Using the `pip `__ package manager: + +.. code:: bash + + $ python -m pip install xarray-datatree + +To install a development version from source: + +.. code:: bash + + $ git clone https://github.com/xarray-contrib/datatree + $ cd datatree + $ python -m pip install -e . + + +You will just need xarray as a required dependency, with netcdf4, zarr, and h5netcdf as optional dependencies to allow file I/O. + +.. note:: + + Datatree is very much still in the early stages of development. There may be functions that are present but whose + internals are not yet implemented, or significant changes to the API in future. + That said, if you try it out and find some behaviour that looks like a bug to you, please report it on the + `issue tracker `_! diff --git a/xarray/datatree_/docs/source/io.rst b/xarray/datatree_/docs/source/io.rst new file mode 100644 index 00000000000..2f2dabf9948 --- /dev/null +++ b/xarray/datatree_/docs/source/io.rst @@ -0,0 +1,54 @@ +.. currentmodule:: datatree + +.. _io: + +Reading and Writing Files +========================= + +.. note:: + + This page builds on the information given in xarray's main page on + `reading and writing files `_, + so it is suggested that you are familiar with those first. + + +netCDF +------ + +Groups +~~~~~~ + +Whilst netCDF groups can only be loaded individually as Dataset objects, a whole file of many nested groups can be loaded +as a single :py:class:`DataTree` object. +To open a whole netCDF file as a tree of groups use the :py:func:`open_datatree` function. +To save a DataTree object as a netCDF file containing many groups, use the :py:meth:`DataTree.to_netcdf` method. + + +.. _netcdf.group.warning: + +.. warning:: + ``DataTree`` objects do not follow the exact same data model as netCDF files, which means that perfect round-tripping + is not always possible. + + In particular in the netCDF data model dimensions are entities that can exist regardless of whether any variable possesses them. + This is in contrast to `xarray's data model `_ + (and hence :ref:`datatree's data model `) in which the dimensions of a (Dataset/Tree) + object are simply the set of dimensions present across all variables in that dataset. + + This means that if a netCDF file contains dimensions but no variables which possess those dimensions, + these dimensions will not be present when that file is opened as a DataTree object. + Saving this DataTree object to file will therefore not preserve these "unused" dimensions. + +Zarr +---- + +Groups +~~~~~~ + +Nested groups in zarr stores can be represented by loading the store as a :py:class:`DataTree` object, similarly to netCDF. +To open a whole zarr store as a tree of groups use the :py:func:`open_datatree` function. +To save a DataTree object as a zarr store containing many groups, use the :py:meth:`DataTree.to_zarr()` method. + +.. note:: + Note that perfect round-tripping should always be possible with a zarr store (:ref:`unlike for netCDF files `), + as zarr does not support "unused" dimensions. diff --git a/xarray/datatree_/docs/source/quick-overview.rst b/xarray/datatree_/docs/source/quick-overview.rst new file mode 100644 index 00000000000..4743b0899fa --- /dev/null +++ b/xarray/datatree_/docs/source/quick-overview.rst @@ -0,0 +1,84 @@ +.. currentmodule:: datatree + +############## +Quick overview +############## + +DataTrees +--------- + +:py:class:`DataTree` is a tree-like container of :py:class:`xarray.DataArray` objects, organised into multiple mutually alignable groups. +You can think of it like a (recursive) ``dict`` of :py:class:`xarray.Dataset` objects. + +Let's first make some example xarray datasets (following on from xarray's +`quick overview `_ page): + +.. ipython:: python + + import numpy as np + import xarray as xr + + data = xr.DataArray(np.random.randn(2, 3), dims=("x", "y"), coords={"x": [10, 20]}) + ds = xr.Dataset(dict(foo=data, bar=("x", [1, 2]), baz=np.pi)) + ds + + ds2 = ds.interp(coords={"x": [10, 12, 14, 16, 18, 20]}) + ds2 + + ds3 = xr.Dataset( + dict(people=["alice", "bob"], heights=("people", [1.57, 1.82])), + coords={"species": "human"}, + ) + ds3 + +Now we'll put this data into a multi-group tree: + +.. ipython:: python + + from datatree import DataTree + + dt = DataTree.from_dict({"simulation/coarse": ds, "simulation/fine": ds2, "/": ds3}) + dt + +This creates a datatree with various groups. We have one root group, containing information about individual people. +(This root group can be named, but here is unnamed, so is referred to with ``"/"``, same as the root of a unix-like filesystem.) +The root group then has one subgroup ``simulation``, which contains no data itself but does contain another two subgroups, +named ``fine`` and ``coarse``. + +The (sub-)sub-groups ``fine`` and ``coarse`` contain two very similar datasets. +They both have an ``"x"`` dimension, but the dimension is of different lengths in each group, which makes the data in each group unalignable. +In the root group we placed some completely unrelated information, showing how we can use a tree to store heterogenous data. + +The constraints on each group are therefore the same as the constraint on dataarrays within a single dataset. + +We created the sub-groups using a filesystem-like syntax, and accessing groups works the same way. +We can access individual dataarrays in a similar fashion + +.. ipython:: python + + dt["simulation/coarse/foo"] + +and we can also pull out the data in a particular group as a ``Dataset`` object using ``.ds``: + +.. ipython:: python + + dt["simulation/coarse"].ds + +Operations map over subtrees, so we can take a mean over the ``x`` dimension of both the ``fine`` and ``coarse`` groups just by + +.. ipython:: python + + avg = dt["simulation"].mean(dim="x") + avg + +Here the ``"x"`` dimension used is always the one local to that sub-group. + +You can do almost everything you can do with ``Dataset`` objects with ``DataTree`` objects +(including indexing and arithmetic), as operations will be mapped over every sub-group in the tree. +This allows you to work with multiple groups of non-alignable variables at once. + +.. note:: + + If all of your variables are mutually alignable + (i.e. they live on the same grid, such that every common dimension name maps to the same length), + then you probably don't need :py:class:`DataTree`, and should consider just sticking with ``xarray.Dataset``. diff --git a/xarray/datatree_/docs/source/terminology.rst b/xarray/datatree_/docs/source/terminology.rst new file mode 100644 index 00000000000..e481a01a6b2 --- /dev/null +++ b/xarray/datatree_/docs/source/terminology.rst @@ -0,0 +1,34 @@ +.. currentmodule:: datatree + +.. _terminology: + +This page extends `xarray's page on terminology `_. + +Terminology +=========== + +.. glossary:: + + DataTree + A tree-like collection of ``Dataset`` objects. A *tree* is made up of one or more *nodes*, + each of which can store the same information as a single ``Dataset`` (accessed via `.ds`). + This data is stored in the same way as in a ``Dataset``, i.e. in the form of data variables + (see **Variable** in the `corresponding xarray terminology page `_), + dimensions, coordinates, and attributes. + + The nodes in a tree are linked to one another, and each node is it's own instance of ``DataTree`` object. + Each node can have zero or more *children* (stored in a dictionary-like manner under their corresponding *names*), + and those child nodes can themselves have children. + If a node is a child of another node that other node is said to be its *parent*. Nodes can have a maximum of one parent, + and if a node has no parent it is said to be the *root* node of that *tree*. + + Subtree + A section of a *tree*, consisting of a *node* along with all the child nodes below it + (and the child nodes below them, i.e. all so-called *descendant* nodes). + Excludes the parent node and all nodes above. + + Group + Another word for a subtree, reflecting how the hierarchical structure of a ``DataTree`` allows for grouping related data together. + Analogous to a single + `netCDF group `_ or + `Zarr group `_. diff --git a/xarray/datatree_/docs/source/tutorial.rst b/xarray/datatree_/docs/source/tutorial.rst new file mode 100644 index 00000000000..6e33bd36f91 --- /dev/null +++ b/xarray/datatree_/docs/source/tutorial.rst @@ -0,0 +1,7 @@ +.. currentmodule:: datatree + +======== +Tutorial +======== + +Coming soon! diff --git a/xarray/datatree_/docs/source/whats-new.rst b/xarray/datatree_/docs/source/whats-new.rst new file mode 100644 index 00000000000..2f6e4f88fe5 --- /dev/null +++ b/xarray/datatree_/docs/source/whats-new.rst @@ -0,0 +1,426 @@ +.. currentmodule:: datatree + +What's New +========== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xray + import xarray + import xarray as xr + import datatree + + np.random.seed(123456) + +.. _whats-new.v0.0.14: + +v0.0.14 (unreleased) +-------------------- + +New Features +~~~~~~~~~~~~ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Renamed `DataTree.lineage` to `DataTree.parents` to match `pathlib` vocabulary + (:issue:`283`, :pull:`286`) +- Minimum required version of xarray is now 2023.12.0, i.e. the latest version. + This is required to prevent recent changes to xarray's internals from breaking datatree. + (:issue:`293`, :pull:`294`) + By `Tom Nicholas `_. +- Change default write mode of :py:meth:`DataTree.to_zarr` to ``'w-'`` to match ``xarray`` + default and prevent accidental directory overwrites. (:issue:`274`, :pull:`275`) + By `Sam Levang `_. + +Deprecations +~~~~~~~~~~~~ + +- Renamed `DataTree.lineage` to `DataTree.parents` to match `pathlib` vocabulary + (:issue:`283`, :pull:`286`). `lineage` is now deprecated and use of `parents` is encouraged. + By `Etienne Schalk `_. + +Bug fixes +~~~~~~~~~ +- Keep attributes on nodes containing no data in :py:func:`map_over_subtree`. (:issue:`278`, :pull:`279`) + By `Sam Levang `_. + +Documentation +~~~~~~~~~~~~~ +- Use ``napoleon`` instead of ``numpydoc`` to align with xarray documentation + (:issue:`284`, :pull:`298`). + By `Etienne Schalk `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +.. _whats-new.v0.0.13: + +v0.0.13 (27/10/2023) +-------------------- + +New Features +~~~~~~~~~~~~ + +- New :py:meth:`DataTree.match` method for glob-like pattern matching of node paths. (:pull:`267`) + By `Tom Nicholas `_. +- New :py:meth:`DataTree.is_hollow` property for checking if data is only contained at the leaf nodes. (:pull:`272`) + By `Tom Nicholas `_. +- Indicate which node caused the problem if error encountered while applying user function using :py:func:`map_over_subtree` + (:issue:`190`, :pull:`264`). Only works when using python 3.11 or later. + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Nodes containing only attributes but no data are now ignored by :py:func:`map_over_subtree` (:issue:`262`, :pull:`263`) + By `Tom Nicholas `_. +- Disallow altering of given dataset inside function called by :py:func:`map_over_subtree` (:pull:`269`, reverts part of :pull:`194`). + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Fix unittests on i386. (:pull:`249`) + By `Antonio Valentino `_. +- Ensure nodepath class is compatible with python 3.12 (:pull:`260`) + By `Max Grover `_. + +Documentation +~~~~~~~~~~~~~ + +- Added new sections to page on ``Working with Hierarchical Data`` (:pull:`180`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +* No longer use the deprecated `distutils` package. + +.. _whats-new.v0.0.12: + +v0.0.12 (03/07/2023) +-------------------- + +New Features +~~~~~~~~~~~~ + +- Added a :py:func:`DataTree.level`, :py:func:`DataTree.depth`, and :py:func:`DataTree.width` property (:pull:`208`). + By `Tom Nicholas `_. +- Allow dot-style (or "attribute-like") access to child nodes and variables, with ipython autocomplete. (:issue:`189`, :pull:`98`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +- Dropped support for python 3.8 (:issue:`212`, :pull:`214`) + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Allow for altering of given dataset inside function called by :py:func:`map_over_subtree` (:issue:`188`, :pull:`194`). + By `Tom Nicholas `_. +- copy subtrees without creating ancestor nodes (:pull:`201`) + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +Internal Changes +~~~~~~~~~~~~~~~~ + +.. _whats-new.v0.0.11: + +v0.0.11 (01/09/2023) +-------------------- + +Big update with entirely new pages in the docs, +new methods (``.drop_nodes``, ``.filter``, ``.leaves``, ``.descendants``), and bug fixes! + +New Features +~~~~~~~~~~~~ + +- Added a :py:meth:`DataTree.drop_nodes` method (:issue:`161`, :pull:`175`). + By `Tom Nicholas `_. +- New, more specific exception types for tree-related errors (:pull:`169`). + By `Tom Nicholas `_. +- Added a new :py:meth:`DataTree.descendants` property (:pull:`170`). + By `Tom Nicholas `_. +- Added a :py:meth:`DataTree.leaves` property (:pull:`177`). + By `Tom Nicholas `_. +- Added a :py:meth:`DataTree.filter` method (:pull:`184`). + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataTree.copy` copy method now only copies the subtree, not the parent nodes (:pull:`171`). + By `Tom Nicholas `_. +- Grafting a subtree onto another tree now leaves name of original subtree object unchanged (:issue:`116`, :pull:`172`, :pull:`178`). + By `Tom Nicholas `_. +- Changed the :py:meth:`DataTree.assign` method to just work on the local node (:pull:`181`). + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Fix bug with :py:meth:`DataTree.relative_to` method (:issue:`133`, :pull:`160`). + By `Tom Nicholas `_. +- Fix links to API docs in all documentation (:pull:`183`). + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Changed docs theme to match xarray's main documentation. (:pull:`173`) + By `Tom Nicholas `_. +- Added ``Terminology`` page. (:pull:`174`) + By `Tom Nicholas `_. +- Added page on ``Working with Hierarchical Data`` (:pull:`179`) + By `Tom Nicholas `_. +- Added context content to ``Index`` page (:pull:`182`) + By `Tom Nicholas `_. +- Updated the README (:pull:`187`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.v0.0.10: + +v0.0.10 (12/07/2022) +-------------------- + +Adds accessors and a `.pipe()` method. + +New Features +~~~~~~~~~~~~ + +- Add the ability to register accessors on ``DataTree`` objects, by using ``register_datatree_accessor``. (:pull:`144`) + By `Tom Nicholas `_. +- Allow method chaining with a new :py:meth:`DataTree.pipe` method (:issue:`151`, :pull:`156`). + By `Justus Magin `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Allow ``Datatree`` objects as values in :py:meth:`DataTree.from_dict` (:pull:`159`). + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +- Added ``Reading and Writing Files`` page. (:pull:`158`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Avoid reading from same file twice with fsspec3 (:pull:`130`) + By `William Roberts `_. + + +.. _whats-new.v0.0.9: + +v0.0.9 (07/14/2022) +------------------- + +New Features +~~~~~~~~~~~~ + +Breaking changes +~~~~~~~~~~~~~~~~ + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +Documentation +~~~~~~~~~~~~~ +- Switch docs theme (:pull:`123`). + By `JuliusBusecke `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.v0.0.7: + +v0.0.7 (07/11/2022) +------------------- + +New Features +~~~~~~~~~~~~ + +- Improve the HTML repr by adding tree-style lines connecting groups and sub-groups (:pull:`109`). + By `Benjamin Woods `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The ``DataTree.ds`` attribute now returns a view onto an immutable Dataset-like object, instead of an actual instance + of ``xarray.Dataset``. This make break existing ``isinstance`` checks or ``assert`` comparisons. (:pull:`99`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + +- Modifying the contents of a ``DataTree`` object via the ``DataTree.ds`` attribute is now forbidden, which prevents + any possibility of the contents of a ``DataTree`` object and its ``.ds`` attribute diverging. (:issue:`38`, :pull:`99`) + By `Tom Nicholas `_. +- Fixed a bug so that names of children now always match keys under which parents store them (:pull:`99`). + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Added ``Data Structures`` page describing the internal structure of a ``DataTree`` object, and its relation to + ``xarray.Dataset`` objects. (:pull:`103`) + By `Tom Nicholas `_. +- API page updated with all the methods that are copied from ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Refactored ``DataTree`` class to store a set of ``xarray.Variable`` objects instead of a single ``xarray.Dataset``. + This approach means that the ``DataTree`` class now effectively copies and extends the internal structure of + ``xarray.Dataset``. (:pull:`41`) + By `Tom Nicholas `_. +- Refactored to use intermediate ``NamedNode`` class, separating implementation of methods requiring a ``name`` + attribute from those not requiring it. + By `Tom Nicholas `_. +- Made ``testing.test_datatree.create_test_datatree`` into a pytest fixture (:pull:`107`). + By `Benjamin Woods `_. + + + +.. _whats-new.v0.0.6: + +v0.0.6 (06/03/2022) +------------------- + +Various small bug fixes, in preparation for more significant changes in the next version. + +Bug fixes +~~~~~~~~~ + +- Fixed bug with checking that assigning parent or new children did not create a loop in the tree (:pull:`105`) + By `Tom Nicholas `_. +- Do not call ``__exit__`` on Zarr store when opening (:pull:`90`) + By `Matt McCormick `_. +- Fix netCDF encoding for compression (:pull:`95`) + By `Joe Hamman `_. +- Added validity checking for node names (:pull:`106`) + By `Tom Nicholas `_. + +.. _whats-new.v0.0.5: + +v0.0.5 (05/05/2022) +------------------- + +- Major refactor of internals, moving from the ``DataTree.children`` attribute being a ``Tuple[DataTree]`` to being a + ``OrderedDict[str, DataTree]``. This was necessary in order to integrate better with xarray's dictionary-like API, + solve several issues, simplify the code internally, remove dependencies, and enable new features. (:pull:`76`) + By `Tom Nicholas `_. + +New Features +~~~~~~~~~~~~ + +- Syntax for accessing nodes now supports file-like paths, including parent nodes via ``"../"``, relative paths, the + root node via ``"/"``, and the current node via ``"."``. (Internally it actually uses ``pathlib`` now.) + By `Tom Nicholas `_. +- New path-like API methods, such as ``.relative_to``, ``.find_common_ancestor``, and ``.same_tree``. +- Some new dictionary-like methods, such as ``DataTree.get`` and ``DataTree.update``. (:pull:`76`) + By `Tom Nicholas `_. +- New HTML repr, which will automatically display in a jupyter notebook. (:pull:`78`) + By `Tom Nicholas `_. +- New delitem method so you can delete nodes. (:pull:`88`) + By `Tom Nicholas `_. +- New ``to_dict`` method. (:pull:`82`) + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- Node names are now optional, which means that the root of the tree can be unnamed. This has knock-on effects for + a lot of the API. +- The ``__init__`` signature for ``DataTree`` has changed, so that ``name`` is now an optional kwarg. +- Files will now be loaded as a slightly different tree, because the root group no longer needs to be given a default + name. +- Removed tag-like access to nodes. +- Removes the option to delete all data in a node by assigning None to the node (in favour of deleting data by replacing + the node's ``.ds`` attribute with an empty Dataset), or to create a new empty node in the same way (in favour of + assigning an empty DataTree object instead). +- Removes the ability to create a new node by assigning a ``Dataset`` object to ``DataTree.__setitem__``. +- Several other minor API changes such as ``.pathstr`` -> ``.path``, and ``from_dict``'s dictionary argument now being + required. (:pull:`76`) + By `Tom Nicholas `_. + +Deprecations +~~~~~~~~~~~~ + +- No longer depends on the anytree library (:pull:`76`) + By `Tom Nicholas `_. + +Bug fixes +~~~~~~~~~ + +- Fixed indentation issue with the string repr (:pull:`86`) + By `Tom Nicholas `_. + +Documentation +~~~~~~~~~~~~~ + +- Quick-overview page updated to match change in path syntax (:pull:`76`) + By `Tom Nicholas `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Basically every file was changed in some way to accommodate (:pull:`76`). +- No longer need the utility functions for string manipulation that were defined in ``utils.py``. +- A considerable amount of code copied over from the internals of anytree (e.g. in ``render.py`` and ``iterators.py``). + The Apache license for anytree has now been bundled with datatree. (:pull:`76`). + By `Tom Nicholas `_. + +.. _whats-new.v0.0.4: + +v0.0.4 (31/03/2022) +------------------- + +- Ensure you get the pretty tree-like string representation by default in ipython (:pull:`73`). + By `Tom Nicholas `_. +- Now available on conda-forge (as xarray-datatree)! (:pull:`71`) + By `Anderson Banihirwe `_. +- Allow for python 3.8 (:pull:`70`). + By `Don Setiawan `_. + +.. _whats-new.v0.0.3: + +v0.0.3 (30/03/2022) +------------------- + +- First released version available on both pypi (as xarray-datatree)! diff --git a/xarray/datatree_/readthedocs.yml b/xarray/datatree_/readthedocs.yml new file mode 100644 index 00000000000..9b04939c898 --- /dev/null +++ b/xarray/datatree_/readthedocs.yml @@ -0,0 +1,7 @@ +version: 2 +conda: + environment: ci/doc.yml +build: + os: 'ubuntu-20.04' + tools: + python: 'mambaforge-4.10' diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index 143d7a58fda..b1bf7a1af11 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -2,6 +2,7 @@ DataArray objects. """ + from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex __all__ = ["Index", "PandasIndex", "PandasMultiIndex"] diff --git a/xarray/namedarray/__init__.py b/xarray/namedarray/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py new file mode 100644 index 00000000000..9f58aeb791d --- /dev/null +++ b/xarray/namedarray/_aggregations.py @@ -0,0 +1,950 @@ +"""Mixin classes with reduction operations.""" + +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self + + +class NamedArrayAggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError() + + def count( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``count`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``count``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``count`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``count`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + pandas.DataFrame.count + dask.dataframe.DataFrame.count + Dataset.count + DataArray.count + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.count() + Size: 8B + array(5) + """ + return self.reduce( + duck_array_ops.count, + dim=dim, + **kwargs, + ) + + def all( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``all`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``all``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``all`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``all`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.all + dask.array.all + Dataset.all + DataArray.all + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + Size: 6B + array([ True, True, True, True, True, False]) + + >>> na.all() + Size: 1B + array(False) + """ + return self.reduce( + duck_array_ops.array_all, + dim=dim, + **kwargs, + ) + + def any( + self, + dim: Dims = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``any`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``any``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``any`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``any`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.any + dask.array.any + Dataset.any + DataArray.any + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([True, True, True, True, True, False], dtype=bool), + ... ) + >>> na + Size: 6B + array([ True, True, True, True, True, False]) + + >>> na.any() + Size: 1B + array(True) + """ + return self.reduce( + duck_array_ops.array_any, + dim=dim, + **kwargs, + ) + + def max( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``max`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``max``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``max`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``max`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.max + dask.array.max + Dataset.max + DataArray.max + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.max() + Size: 8B + array(3.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.max(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def min( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``min`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``min``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``min`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``min`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.min + dask.array.min + Dataset.min + DataArray.min + :ref:`agg` + User guide on reduction or aggregation operations. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.min() + Size: 8B + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.min(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def mean( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``mean`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``mean``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``mean`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``mean`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.mean + dask.array.mean + Dataset.mean + DataArray.mean + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.mean() + Size: 8B + array(1.6) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.mean(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def prod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``prod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``prod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``prod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``prod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.prod + dask.array.prod + Dataset.prod + DataArray.prod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.prod() + Size: 8B + array(0.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.prod(skipna=False) + Size: 8B + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.prod(skipna=True, min_count=2) + Size: 8B + array(0.) + """ + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def sum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + min_count: int | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``sum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``sum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int or None, optional + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. Changed in version 0.17.0: if specified on an integer + array and skipna=True, the result will be a float array. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``sum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``sum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.sum + dask.array.sum + Dataset.sum + DataArray.sum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.sum() + Size: 8B + array(8.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.sum(skipna=False) + Size: 8B + array(nan) + + Specify ``min_count`` for finer control over when NaNs are ignored. + + >>> na.sum(skipna=True, min_count=2) + Size: 8B + array(8.) + """ + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + **kwargs, + ) + + def std( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``std`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``std``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``std`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``std`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.std + dask.array.std + Dataset.std + DataArray.std + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.std() + Size: 8B + array(1.0198039) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.std(skipna=False) + Size: 8B + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.std(skipna=True, ddof=1) + Size: 8B + array(1.14017543) + """ + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def var( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + ddof: int = 0, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``var`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``var``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + ddof : int, default: 0 + “Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``, + where ``N`` represents the number of elements. + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``var`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``var`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.var + dask.array.var + Dataset.var + DataArray.var + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.var() + Size: 8B + array(1.04) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.var(skipna=False) + Size: 8B + array(nan) + + Specify ``ddof=1`` for an unbiased estimate. + + >>> na.var(skipna=True, ddof=1) + Size: 8B + array(1.3) + """ + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + **kwargs, + ) + + def median( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``median`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``median``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``median`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``median`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.median + dask.array.median + Dataset.median + DataArray.median + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.median() + Size: 8B + array(2.) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.median(skipna=False) + Size: 8B + array(nan) + """ + return self.reduce( + duck_array_ops.median, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumsum( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumsum`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumsum``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumsum`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumsum`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumsum + dask.array.cumsum + Dataset.cumsum + DataArray.cumsum + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumsum() + Size: 48B + array([1., 3., 6., 6., 8., 8.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumsum(skipna=False) + Size: 48B + array([ 1., 3., 6., 6., 8., nan]) + """ + return self.reduce( + duck_array_ops.cumsum, + dim=dim, + skipna=skipna, + **kwargs, + ) + + def cumprod( + self, + dim: Dims = None, + *, + skipna: bool | None = None, + **kwargs: Any, + ) -> Self: + """ + Reduce this NamedArray's data by applying ``cumprod`` along some dimension(s). + + Parameters + ---------- + dim : str, Iterable of Hashable, "..." or None, default: None + Name of dimension[s] along which to apply ``cumprod``. For e.g. ``dim="x"`` + or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions. + skipna : bool or None, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + **kwargs : Any + Additional keyword arguments passed on to the appropriate array + function for calculating ``cumprod`` on this object's data. + These could include dask-specific kwargs like ``split_every``. + + Returns + ------- + reduced : NamedArray + New NamedArray with ``cumprod`` applied to its data and the + indicated dimension(s) removed + + See Also + -------- + numpy.cumprod + dask.array.cumprod + Dataset.cumprod + DataArray.cumprod + :ref:`agg` + User guide on reduction or aggregation operations. + + Notes + ----- + Non-numeric variables will be removed prior to reducing. + + Examples + -------- + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x", + ... np.array([1, 2, 3, 0, 2, np.nan]), + ... ) + >>> na + Size: 48B + array([ 1., 2., 3., 0., 2., nan]) + + >>> na.cumprod() + Size: 48B + array([1., 2., 6., 0., 0., 0.]) + + Use ``skipna`` to control whether NaNs are ignored. + + >>> na.cumprod(skipna=False) + Size: 48B + array([ 1., 2., 6., 0., 0., nan]) + """ + return self.reduce( + duck_array_ops.cumprod, + dim=dim, + skipna=skipna, + **kwargs, + ) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py new file mode 100644 index 00000000000..977d011c685 --- /dev/null +++ b/xarray/namedarray/_array_api.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import warnings +from types import ModuleType +from typing import Any + +import numpy as np + +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _Axes, + _Axis, + _default, + _Dim, + _DType, + _ScalarType, + _ShapeType, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.core import NamedArray + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp # noqa: F401 + + +def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: + if isinstance(x._data, _arrayapi): + return x._data.__array_namespace__() + + return np + + +# %% Creation Functions + + +def astype( + x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True +) -> NamedArray[_ShapeType, _DType]: + """ + Copies an array to a specified data type irrespective of Type Promotion Rules rules. + + Parameters + ---------- + x : NamedArray + Array to cast. + dtype : _DType + Desired data type. + copy : bool, optional + Specifies whether to copy an array when the specified dtype matches the data + type of the input array x. + If True, a newly allocated array must always be returned. + If False and the specified dtype matches the data type of the input array, + the input array must be returned; otherwise, a newly allocated array must be + returned. Default: True. + + Returns + ------- + out : NamedArray + An array having the specified data type. The returned array must have the + same shape as x. + + Examples + -------- + >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) + >>> narr + Size: 16B + Array([1.5, 2.5], dtype=float64) + >>> astype(narr, np.dtype(np.int32)) + Size: 8B + Array([1, 2], dtype=int32) + """ + if isinstance(x._data, _arrayapi): + xp = x._data.__array_namespace__() + return x._new(data=xp.astype(x._data, dtype, copy=copy)) + + # np.astype doesn't exist yet: + return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] + + +# %% Elementwise Functions + + +def imag( + x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the imaginary component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> imag(narr) + Size: 16B + array([2., 4.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.imag(x._data)) + return out + + +def real( + x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the real component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> real(narr) + Size: 16B + array([1., 2.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.real(x._data)) + return out + + +# %% Manipulation functions +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dim: _Dim | Default = _default, + axis: _Axis = 0, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dim : + Dimension name. New dimension will be stored in the axis position. + axis : + (Not recommended) Axis position (zero-based). Default is 0. + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x) + Size: 32B + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + >>> expand_dims(x, dim="z") + Size: 32B + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + """ + xp = _get_data_namespace(x) + dims = x.dims + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) + return out + + +def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: + """ + Permutes the dimensions of an array. + + Parameters + ---------- + x : + Array to permute. + axes : + Permutation of the dimensions of x. + + Returns + ------- + out : + An array with permuted dimensions. The returned array must have the same + data type as x. + + """ + + dims = x.dims + new_dims = tuple(dims[i] for i in axes) + if isinstance(x._data, _arrayapi): + xp = _get_data_namespace(x) + out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) + else: + out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] + return out diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py new file mode 100644 index 00000000000..b715973814f --- /dev/null +++ b/xarray/namedarray/_typing.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import sys +from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Final, + Literal, + Protocol, + SupportsIndex, + TypeVar, + Union, + overload, + runtime_checkable, +) + +import numpy as np + +try: + if sys.version_info >= (3, 11): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias +except ImportError: + if TYPE_CHECKING: + raise + else: + Self: Any = None + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + +# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +_dtype = np.dtype +_DType = TypeVar("_DType", bound=np.dtype[Any]) +_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) +# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` + +_ScalarType = TypeVar("_ScalarType", bound=np.generic) +_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) + + +# A protocol for anything with the dtype attribute +@runtime_checkable +class _SupportsDType(Protocol[_DType_co]): + @property + def dtype(self) -> _DType_co: ... + + +_DTypeLike = Union[ + np.dtype[_ScalarType], + type[_ScalarType], + _SupportsDType[np.dtype[_ScalarType]], +] + +# For unknown shapes Dask uses np.nan, array_api uses None: +_IntOrUnknown = int +_Shape = tuple[_IntOrUnknown, ...] +_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] +_ShapeType = TypeVar("_ShapeType", bound=Any) +_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) + +_Axis = int +_Axes = tuple[_Axis, ...] +_AxisLike = Union[_Axis, _Axes] + +_Chunks = tuple[_Shape, ...] +_NormalizedChunks = tuple[tuple[int, ...], ...] +# FYI in some cases we don't allow `None`, which this doesn't take account of. +T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] +# We allow the tuple form of this (though arguably we could transition to named dims only) +T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]] + +_Dim = Hashable +_Dims = tuple[_Dim, ...] + +_DimsLike = Union[str, Iterable[_Dim]] + +# https://data-apis.org/array-api/latest/API_specification/indexing.html +# TODO: np.array_api was bugged and didn't allow (None,), but should! +# https://github.com/numpy/numpy/pull/25022 +# https://github.com/data-apis/array-api/pull/674 +_IndexKey = Union[int, slice, "ellipsis"] +_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...] +_IndexKeyLike = Union[_IndexKey, _IndexKeys] + +_AttrsLike = Union[Mapping[Any, Any], None] + + +class _SupportsReal(Protocol[_T_co]): + @property + def real(self) -> _T_co: ... + + +class _SupportsImag(Protocol[_T_co]): + @property + def imag(self) -> _T_co: ... + + +@runtime_checkable +class _array(Protocol[_ShapeType_co, _DType_co]): + """ + Minimal duck array named array uses. + + Corresponds to np.ndarray. + """ + + @property + def shape(self) -> _Shape: ... + + @property + def dtype(self) -> _DType_co: ... + + +@runtime_checkable +class _arrayfunction( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @overload + def __getitem__( + self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], / + ) -> _arrayfunction[Any, _DType_co]: ... + + @overload + def __getitem__(self, key: _IndexKeyLike, /) -> Any: ... + + def __getitem__( + self, + key: ( + _IndexKeyLike + | _arrayfunction[Any, Any] + | tuple[_arrayfunction[Any, Any], ...] + ), + /, + ) -> _arrayfunction[Any, _DType_co] | Any: ... + + @overload + def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: ... + + @overload + def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]: ... + + def __array__( + self, dtype: _DType | None = ..., / + ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_ufunc__( + self, + ufunc: Any, + method: Any, + *inputs: Any, + **kwargs: Any, + ) -> Any: ... + + # TODO: Should return the same subclass but with a new dtype generic. + # https://github.com/python/typing/issues/548 + def __array_function__( + self, + func: Callable[..., Any], + types: Iterable[type], + args: Iterable[Any], + kwargs: Mapping[str, Any], + ) -> Any: ... + + @property + def imag(self) -> _arrayfunction[_ShapeType_co, Any]: ... + + @property + def real(self) -> _arrayfunction[_ShapeType_co, Any]: ... + + +@runtime_checkable +class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]): + """ + Duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def __getitem__( + self, + key: ( + _IndexKeyLike | Any + ), # TODO: Any should be _arrayapi[Any, _dtype[np.integer]] + /, + ) -> _arrayapi[Any, Any]: ... + + def __array_namespace__(self) -> ModuleType: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_arrayfunction_or_api = (_arrayfunction, _arrayapi) + +duckarray = Union[ + _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co] +] + +# Corresponds to np.typing.NDArray: +DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]] + + +@runtime_checkable +class _chunkedarray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal chunked duck array. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +@runtime_checkable +class _chunkedarrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +@runtime_checkable +class _chunkedarrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Chunked duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + @property + def chunks(self) -> _Chunks: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_chunkedarrayfunction_or_api = (_chunkedarrayfunction, _chunkedarrayapi) +chunkedduckarray = Union[ + _chunkedarrayfunction[_ShapeType_co, _DType_co], + _chunkedarrayapi[_ShapeType_co, _DType_co], +] + + +@runtime_checkable +class _sparsearray( + _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Minimal sparse duck array. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +@runtime_checkable +class _sparsearrayfunction( + _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 18. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +@runtime_checkable +class _sparsearrayapi( + _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co] +): + """ + Sparse duck array supporting NEP 47. + + Corresponds to np.ndarray. + """ + + def todense(self) -> np.ndarray[Any, _DType_co]: ... + + +# NamedArray can most likely use both __array_function__ and __array_namespace__: +_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi) +sparseduckarray = Union[ + _sparsearrayfunction[_ShapeType_co, _DType_co], + _sparsearrayapi[_ShapeType_co, _DType_co], +] + +ErrorOptions = Literal["raise", "ignore"] +ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py new file mode 100644 index 00000000000..135dabc0656 --- /dev/null +++ b/xarray/namedarray/core.py @@ -0,0 +1,1142 @@ +from __future__ import annotations + +import copy +import math +import sys +import warnings +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + TypeVar, + cast, + overload, +) + +import numpy as np + +# TODO: get rid of this after migrating this class to array API +from xarray.core import dtypes, formatting, formatting_html +from xarray.core.indexing import ( + ExplicitlyIndexed, + ImplicitToExplicitIndexingAdapter, + OuterIndexer, +) +from xarray.namedarray._aggregations import NamedArrayAggregations +from xarray.namedarray._typing import ( + ErrorOptionsWithWarn, + _arrayapi, + _arrayfunction_or_api, + _chunkedarray, + _default, + _dtype, + _DType_co, + _ScalarType_co, + _ShapeType_co, + _sparsearrayfunction_or_api, + _SupportsImag, + _SupportsReal, +) +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import to_numpy +from xarray.namedarray.utils import ( + either_dict_or_kwargs, + infix_dims, + is_dict_like, + is_duck_dask_array, + to_0d_object_array, +) + +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + from xarray.core.types import Dims + from xarray.namedarray._typing import ( + Default, + _AttrsLike, + _Chunks, + _Dim, + _Dims, + _DimsLike, + _DType, + _IntOrUnknown, + _ScalarType, + _Shape, + _ShapeType, + duckarray, + ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint + + try: + from dask.typing import ( + Graph, + NestedKeys, + PostComputeCallable, + PostPersistCallable, + SchedulerGetCallable, + ) + except ImportError: + Graph: Any # type: ignore[no-redef] + NestedKeys: Any # type: ignore[no-redef] + SchedulerGetCallable: Any # type: ignore[no-redef] + PostComputeCallable: Any # type: ignore[no-redef] + PostPersistCallable: Any # type: ignore[no-redef] + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + T_NamedArray = TypeVar("T_NamedArray", bound="_NamedArray[Any]") + T_NamedArrayInteger = TypeVar( + "T_NamedArrayInteger", bound="_NamedArray[np.integer[Any]]" + ) + + +@overload +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType, _DType]: ... + + +@overload +def _new( + x: NamedArray[_ShapeType_co, _DType_co], + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., +) -> NamedArray[_ShapeType_co, _DType_co]: ... + + +def _new( + x: NamedArray[Any, _DType_co], + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, _DType_co]: + """ + Create a new array with new typing information. + + Parameters + ---------- + x : NamedArray + Array to create a new array from + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + dims_ = copy.copy(x._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if x._attrs is None else x._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(x)(dims_, copy.copy(x._data), attrs_) + else: + cls_ = cast("type[NamedArray[_ShapeType, _DType]]", type(x)) + return cls_(dims_, data, attrs_) + + +@overload +def from_array( + dims: _DimsLike, + data: duckarray[_ShapeType, _DType], + attrs: _AttrsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... + + +@overload +def from_array( + dims: _DimsLike, + data: ArrayLike, + attrs: _AttrsLike = ..., +) -> NamedArray[Any, Any]: ... + + +def from_array( + dims: _DimsLike, + data: duckarray[_ShapeType, _DType] | ArrayLike, + attrs: _AttrsLike = None, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: + """ + Create a Named array from an array-like object. + + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or ArrayLike + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + """ + if isinstance(data, NamedArray): + raise TypeError( + "Array is already a Named array. Use 'data.data' to retrieve the data array" + ) + + # TODO: dask.array.ma.MaskedArray also exists, better way? + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and + # xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + + return NamedArray(dims, data, attrs) + + if isinstance(data, _arrayfunction_or_api): + return NamedArray(dims, data, attrs) + + if isinstance(data, tuple): + return NamedArray(dims, to_0d_object_array(data), attrs) + + # validate whether the data is valid data types. + return NamedArray(dims, np.asarray(data), attrs) + + +class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]): + """ + A wrapper around duck arrays with named dimensions + and attributes which describe a single Array. + Numeric operations on this object implement array broadcasting and + dimension alignment based on dimension names, + rather than axis order. + + + Parameters + ---------- + dims : str or iterable of hashable + Name(s) of the dimension(s). + data : array-like or duck-array + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + + Raises + ------ + ValueError + If the `dims` length does not match the number of data dimensions (ndim). + + + Examples + -------- + >>> data = np.array([1.5, 2, 3], dtype=float) + >>> narr = NamedArray(("x",), data, {"units": "m"}) # TODO: Better name than narr? + """ + + __slots__ = ("_data", "_dims", "_attrs") + + _data: duckarray[Any, _DType_co] + _dims: _Dims + _attrs: dict[Any, Any] | None + + def __init__( + self, + dims: _DimsLike, + data: duckarray[Any, _DType_co], + attrs: _AttrsLike = None, + ): + self._data = data + self._dims = self._parse_dimensions(dims) + self._attrs = dict(attrs) if attrs else None + + def __init_subclass__(cls, **kwargs: Any) -> None: + if NamedArray in cls.__bases__ and (cls._new == NamedArray._new): + # Type hinting does not work for subclasses unless _new is + # overridden with the correct class. + raise TypeError( + "Subclasses of `NamedArray` must override the `_new` method." + ) + super().__init_subclass__(**kwargs) + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[_ShapeType, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> NamedArray[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]: + """ + Create a new array with new typing information. + + _new has to be reimplemented each time NamedArray is subclassed, + otherwise type hints will not be correct. The same is likely true + for methods that relied on _new. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return _new(self, dims, data, attrs) + + def _replace( + self, + dims: _DimsLike | Default = _default, + data: duckarray[_ShapeType_co, _DType_co] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Self: + """ + Create a new array with the same typing information. + + The types for each argument cannot change, + use self._new if that is a risk. + + Parameters + ---------- + dims : Iterable of Hashable, optional + Name(s) of the dimension(s). + Will copy the dims from x by default. + data : duckarray, optional + The actual data that populates the array. Should match the + shape specified by `dims`. + Will copy the data from x by default. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Will copy the attrs from x by default. + """ + return cast("Self", self._new(dims, data, attrs)) + + def _copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + memo: dict[int, Any] | None = None, + ) -> Self: + if data is None: + ndata = self._data + if deep: + ndata = copy.deepcopy(ndata, memo=memo) + else: + ndata = data + self._check_shape(ndata) + + attrs = ( + copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs) + ) + + return self._replace(data=ndata, attrs=attrs) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def copy( + self, + deep: bool = True, + data: duckarray[_ShapeType_co, _DType_co] | None = None, + ) -> Self: + """Returns a copy of this object. + + If `deep=True`, the data array is loaded into memory and copied onto + the new object. Dimensions, attributes and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : NamedArray + New object with dimensions, attributes, and optionally + data copied from original. + + + """ + return self._copy(deep=deep, data=data) + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self) -> _IntOrUnknown: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self) -> _IntOrUnknown: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def nbytes(self) -> _IntOrUnknown: + """ + Total bytes consumed by the elements of the data array. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. + """ + if hasattr(self._data, "nbytes"): + return self._data.nbytes # type: ignore[no-any-return] + else: + return self.size * self.dtype.itemsize + + @property + def dims(self) -> _Dims: + """Tuple of dimension names with which this NamedArray is associated.""" + return self._dims + + @dims.setter + def dims(self, value: _DimsLike) -> None: + self._dims = self._parse_dimensions(value) + + def _parse_dimensions(self, dims: _DimsLike) -> _Dims: + dims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(dims) != self.ndim: + raise ValueError( + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" + ) + if len(set(dims)) < len(dims): + repeated_dims = {d for d in dims if dims.count(d) > 1} + warnings.warn( + f"Duplicate dimension names present: dimensions {repeated_dims} appear more than once in dims={dims}. " + "We do not yet support duplicate dimension names, but we do allow initial construction of the object. " + "We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. " + "To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.", + UserWarning, + ) + return dims + + @property + def attrs(self) -> dict[Any, Any]: + """Dictionary of local attributes on this NamedArray.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) if value else None + + def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None: + if new_data.shape != self.shape: + raise ValueError( + f"replacement data must match the {self.__class__.__name__}'s shape. " + f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}" + ) + + @property + def data(self) -> duckarray[Any, _DType_co]: + """ + The NamedArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + """ + + return self._data + + @data.setter + def data(self, data: duckarray[Any, _DType_co]) -> None: + self._check_shape(data) + self._data = data + + @property + def imag( + self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The imaginary part of the array. + + See Also + -------- + numpy.ndarray.imag + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import imag + + return imag(self) + + return self._new(data=self._data.imag) + + @property + def real( + self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var] + ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]: + """ + The real part of the array. + + See Also + -------- + numpy.ndarray.real + """ + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api import real + + return real(self) + return self._new(data=self._data.real) + + def __dask_tokenize__(self) -> object: + # Use v.data, instead of v._data, in order to cope with the wrappers + # around NetCDF and the like + from dask.base import normalize_token + + return normalize_token((type(self), self._dims, self.data, self._attrs or None)) + + def __dask_graph__(self) -> Graph | None: + if is_duck_dask_array(self._data): + return self._data.__dask_graph__() + else: + # TODO: Should this method just raise instead? + # raise NotImplementedError("Method requires self.data to be a dask array") + return None + + def __dask_keys__(self) -> NestedKeys: + if is_duck_dask_array(self._data): + return self._data.__dask_keys__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_layers__(self) -> Sequence[str]: + if is_duck_dask_array(self._data): + return self._data.__dask_layers__() + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_optimize__( + self, + ) -> Callable[..., dict[Any, Any]]: + if is_duck_dask_array(self._data): + return self._data.__dask_optimize__ # type: ignore[no-any-return] + else: + raise AttributeError("Method requires self.data to be a dask array.") + + @property + def __dask_scheduler__(self) -> SchedulerGetCallable: + if is_duck_dask_array(self._data): + return self._data.__dask_scheduler__ + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postcompute__( + self, + ) -> tuple[PostComputeCallable, tuple[Any, ...]]: + if is_duck_dask_array(self._data): + array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def __dask_postpersist__( + self, + ) -> tuple[ + Callable[ + [Graph, PostPersistCallable[Any], Any, Any], + Self, + ], + tuple[Any, ...], + ]: + if is_duck_dask_array(self._data): + a: tuple[PostPersistCallable[Any], tuple[Any, ...]] + a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] + array_func, array_args = a + + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") + + def _dask_finalize( + self, + results: Graph, + array_func: PostPersistCallable[Any], + *args: Any, + **kwargs: Any, + ) -> Self: + data = array_func(results, *args, **kwargs) + return type(self)(self._dims, data, attrs=self._attrs) + + @overload + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... + + @overload + def get_axis_num(self, dim: Hashable) -> int: ... + + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + """Return axis number(s) corresponding to dimension(s) in this array. + + Parameters + ---------- + dim : str or iterable of str + Dimension name(s) for which to lookup axes. + + Returns + ------- + int or tuple of int + Axis number or numbers corresponding to the given dimensions. + """ + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) + + def _get_axis_num(self: Any, dim: Hashable) -> int: + _raise_if_any_duplicate_dimensions(self.dims) + try: + return self.dims.index(dim) # type: ignore[no-any-return] + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + + @property + def chunks(self) -> _Chunks | None: + """ + Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if + the underlying data is not a dask array. + + See Also + -------- + NamedArray.chunk + NamedArray.chunksizes + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return data.chunks + else: + return None + + @property + def chunksizes( + self, + ) -> Mapping[_Dim, _Shape]: + """ + Mapping from dimension names to block lengths for this namedArray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + NamedArray.chunk + NamedArray.chunks + xarray.unify_chunks + """ + data = self._data + if isinstance(data, _chunkedarray): + return dict(zip(self.dims, data.chunks)) + else: + return {} + + @property + def sizes(self) -> dict[_Dim, _IntOrUnknown]: + """Ordered mapping from dimension names to lengths.""" + return dict(zip(self.dims, self.shape)) + + def chunk( + self, + chunks: int | Literal["auto"] | Mapping[Any, None | int | tuple[int, ...]] = {}, + chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None, + from_array_kwargs: Any = None, + **chunks_kwargs: Any, + ) -> Self: + """Coerce this array's data into a dask array with the given chunks. + + If this variable is a non-dask array, it will be converted to dask + array. If it's a dask array, it will be rechunked to the given chunk + sizes. + + If neither chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Parameters + ---------- + chunks : int, tuple or dict, optional + Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or + ``{'x': 5, 'y': 5}``. + chunked_array_type: str, optional + Which chunked array type to coerce this datasets' arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + + if from_array_kwargs is None: + from_array_kwargs = {} + + if chunks is None: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + ) + chunks = {} + + if isinstance(chunks, (float, str, int, tuple, list)): + # TODO we shouldn't assume here that other chunkmanagers can handle these types + # TODO should we call normalize_chunks here? + pass # dask.array.from_array can handle these directly + else: + chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + + if is_dict_like(chunks): + chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()} + + chunkmanager = guess_chunkmanager(chunked_array_type) + + data_old = self._data + if chunkmanager.is_chunked_array(data_old): + data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] + else: + if not isinstance(data_old, ExplicitlyIndexed): + ndata = data_old + else: + # Unambiguously handle array storage backends (like NetCDF4 and h5py) + # that can't handle general array indexing. For example, in netCDF4 you + # can do "outer" indexing along two dimensions independent, which works + # differently from how NumPy handles it. + # da.from_array works by using lazy indexing with a tuple of slices. + # Using OuterIndexer is a pragmatic choice: dask does not yet handle + # different indexing types in an explicit way: + # https://github.com/dask/dask/issues/2883 + ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] + + if is_dict_like(chunks): + chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape)) # type: ignore[assignment] + + data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type] + + return self._replace(data=data_chunked) + + def to_numpy(self) -> np.ndarray[Any, Any]: + """Coerces wrapped data to numpy and returns a numpy.ndarray""" + # TODO an entrypoint so array libraries can choose coercion method? + return to_numpy(self._data) + + def as_numpy(self) -> Self: + """Coerces wrapped data into a numpy array, returning a Variable.""" + return self._replace(data=self.to_numpy()) + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> NamedArray[Any, Any]: + """Reduce this array by applying `func` along some dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form + `func(x, axis=axis, **kwargs)` to return the result of reducing an + np.ndarray over an integer valued axis. + dim : "...", str, Iterable of Hashable or None, optional + Dimension(s) over which to apply `func`. By default `func` is + applied over all dimensions. + axis : int or Sequence of int, optional + Axis(es) over which to apply `func`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + the reduction is calculated over the flattened array (by calling + `func(x)` without an axis argument). + keepdims : bool, default: False + If True, the dimensions which are reduced are left in the result + as dimensions of size one + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : Array + Array with summarized data and the indicated dimension(s) + removed. + """ + if dim == ...: + dim = None + if dim is not None and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim' arguments") + + if dim is not None: + axis = self.get_axis_num(dim) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + if axis is not None: + if isinstance(axis, tuple) and len(axis) == 1: + # unpack axis for the benefit of functions + # like np.argmin which can't handle tuple arguments + axis = axis[0] + data = func(self.data, axis=axis, **kwargs) + else: + data = func(self.data, **kwargs) + + if getattr(data, "shape", ()) == self.shape: + dims = self.dims + else: + removed_axes: Iterable[int] + if axis is None: + removed_axes = range(self.ndim) + else: + removed_axes = np.atleast_1d(axis) % self.ndim + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(self.ndim) + ) + if getattr(data, "shape", None) is None: + # Reduce has produced a scalar value, not an array-like + data = np.asanyarray(data)[slices] + else: + data = data[slices] + dims = self.dims + else: + dims = tuple( + adim for n, adim in enumerate(self.dims) if n not in removed_axes + ) + + # Return NamedArray to handle IndexVariable when data is nD + return from_array(dims, data, attrs=self._attrs) + + def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: + """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" + # TODO: we should replace dask's native nonzero + # after https://github.com/dask/dask/issues/1076 is implemented. + # TODO: cast to ndarray and back to T_DuckArray is a workaround + nonzeros = np.nonzero(cast("NDArray[np.integer[Any]]", self.data)) + _attrs = self.attrs + return tuple( + cast("T_NamedArrayInteger", self._new((dim,), nz, _attrs)) + for nz, dim in zip(nonzeros, self.dims) + ) + + def __repr__(self) -> str: + return formatting.array_repr(self) + + def _repr_html_(self) -> str: + return formatting_html.array_repr(self) + + def _as_sparse( + self, + sparse_format: Literal["coo"] | Default = _default, + fill_value: ArrayLike | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Use sparse-array as backend. + """ + import sparse + + from xarray.namedarray._array_api import astype + + # TODO: what to do if dask-backended? + if fill_value is _default: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") + except AttributeError as exc: + raise ValueError(f"{sparse_format} is not a valid sparse format") from exc + + data = as_sparse(astype(self, dtype).data, fill_value=fill_value) + return self._new(data=data) + + def _to_dense(self) -> NamedArray[Any, _DType_co]: + """ + Change backend from sparse to np.array. + """ + if isinstance(self._data, _sparsearrayfunction_or_api): + data_dense: np.ndarray[Any, _DType_co] = self._data.todense() + return self._new(data=data_dense) + else: + raise TypeError("self.data is not a sparse array") + + def permute_dims( + self, + *dim: Iterable[_Dim] | ellipsis, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions. + + Parameters + ---------- + *dim : Hashable, optional + By default, reverse the order of the dimensions. Otherwise, reorder the + dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + NamedArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + NamedArray + The returned NamedArray has permuted dimensions and data with the + same attributes as the original. + + + See Also + -------- + numpy.transpose + """ + + from xarray.namedarray._array_api import permute_dims + + if not dim: + dims = self.dims[::-1] + else: + dims = tuple(infix_dims(dim, self.dims, missing_dims)) # type: ignore[arg-type] + + if len(dims) < 2 or dims == self.dims: + # no need to transpose if only one dimension + # or dims are in same order + return self.copy(deep=False) + + axes_result = self.get_axis_num(dims) + axes = (axes_result,) if isinstance(axes_result, int) else axes_result + + return permute_dims(self, axes) + + @property + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() + + def broadcast_to( + self, dim: Mapping[_Dim, int] | None = None, **dim_kwargs: Any + ) -> NamedArray[Any, _DType_co]: + """ + Broadcast the NamedArray to a new shape. New dimensions are not allowed. + + This method allows for the expansion of the array's dimensions to a specified shape. + It handles both positional and keyword arguments for specifying the dimensions to broadcast. + An error is raised if new dimensions are attempted to be added. + + Parameters + ---------- + dim : dict, str, sequence of str, optional + Dimensions to broadcast the array to. If a dict, keys are dimension names and values are the new sizes. + If a string or sequence of strings, existing dimensions are matched with a size of 1. + + **dim_kwargs : Any + Additional dimensions specified as keyword arguments. Each keyword argument specifies the name of an existing dimension and its size. + + Returns + ------- + NamedArray + A new NamedArray with the broadcasted dimensions. + + Examples + -------- + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + >>> array.sizes + {'x': 2, 'y': 2} + + >>> broadcasted = array.broadcast_to(x=2, y=2) + >>> broadcasted.sizes + {'x': 2, 'y': 2} + """ + + from xarray.core import duck_array_ops + + combined_dims = either_dict_or_kwargs(dim, dim_kwargs, "broadcast_to") + + # Check that no new dimensions are added + if new_dims := set(combined_dims) - set(self.dims): + raise ValueError( + f"Cannot add new dimensions: {new_dims}. Only existing dimensions are allowed. " + "Use `expand_dims` method to add new dimensions." + ) + + # Create a dictionary of the current dimensions and their sizes + current_shape = self.sizes + + # Update the current shape with the new dimensions, keeping the order of the original dimensions + broadcast_shape = {d: current_shape.get(d, 1) for d in self.dims} + broadcast_shape |= combined_dims + + # Ensure the dimensions are in the correct order + ordered_dims = list(broadcast_shape.keys()) + ordered_shape = tuple(broadcast_shape[d] for d in ordered_dims) + data = duck_array_ops.broadcast_to(self._data, ordered_shape) # type: ignore[no-untyped-call] # TODO: use array-api-compat function + return self._new(data=data, dims=ordered_dims) + + def expand_dims( + self, + dim: _Dim | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Expand the dimensions of the NamedArray. + + This method adds new dimensions to the object. The new dimensions are added at the beginning of the array. + + Parameters + ---------- + dim : Hashable, optional + Dimension name to expand the array to. This dimension will be added at the beginning of the array. + + Returns + ------- + NamedArray + A new NamedArray with expanded dimensions. + + + Examples + -------- + + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + + + # expand dimensions by specifying a new dimension name + >>> expanded = array.expand_dims(dim="z") + >>> expanded.dims + ('z', 'x', 'y') + + """ + + from xarray.namedarray._array_api import expand_dims + + return expand_dims(self, dim=dim) + + +_NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]] + + +def _raise_if_any_duplicate_dimensions( + dims: _Dims, err_context: str = "This function" +) -> None: + if len(set(dims)) < len(dims): + repeated_dims = {d for d in dims if dims.count(d) > 1} + raise ValueError( + f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" + ) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py new file mode 100644 index 00000000000..14744d2de6b --- /dev/null +++ b/xarray/namedarray/daskmanager.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np +from packaging.version import Version + +from xarray.core.indexing import ImplicitToExplicitIndexingAdapter +from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray +from xarray.namedarray.utils import is_duck_dask_array, module_available + +if TYPE_CHECKING: + from xarray.namedarray._typing import ( + T_Chunks, + _DType_co, + _NormalizedChunks, + duckarray, + ) + + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray[Any, Any] # type: ignore[assignment, misc] + + +dask_available = module_available("dask") + + +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] + array_cls: type[DaskArray] + available: bool = dask_available + + def __init__(self) -> None: + # TODO can we replace this with a class attribute instead? + + from dask.array import Array + + self.array_cls = Array + + def is_chunked_array(self, data: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(data) + + def chunks(self, data: Any) -> _NormalizedChunks: + return data.chunks # type: ignore[no-any-return] + + def normalize_chunks( + self, + chunks: T_Chunks | _NormalizedChunks, + shape: tuple[int, ...] | None = None, + limit: int | None = None, + dtype: _DType_co | None = None, + previous_chunks: _NormalizedChunks | None = None, + ) -> Any: + """Called by open_dataset""" + from dask.array.core import normalize_chunks + + return normalize_chunks( + chunks, + shape=shape, + limit=limit, + dtype=dtype, + previous_chunks=previous_chunks, + ) # type: ignore[no-untyped-call] + + def from_array( + self, data: Any, chunks: T_Chunks | _NormalizedChunks, **kwargs: Any + ) -> DaskArray | Any: + import dask.array as da + + if isinstance(data, ImplicitToExplicitIndexingAdapter): + # lazily loaded backend array classes should use NumPy array operations. + kwargs["meta"] = np.ndarray + + return da.from_array( + data, + chunks, + **kwargs, + ) # type: ignore[no-untyped-call] + + def compute( + self, *data: Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: + from dask.array import compute + + return compute(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return] + + @property + def array_api(self) -> Any: + from dask import array as da + + return da + + def reduction( # type: ignore[override] + self, + arr: T_ChunkedArray, + func: Callable[..., Any], + combine_func: Callable[..., Any] | None = None, + aggregate_func: Callable[..., Any] | None = None, + axis: int | Sequence[int] | None = None, + dtype: _DType_co | None = None, + keepdims: bool = False, + ) -> DaskArray | Any: + from dask.array import reduction + + return reduction( + arr, + chunk=func, + combine=combine_func, + aggregate=aggregate_func, + axis=axis, + dtype=dtype, + keepdims=keepdims, + ) # type: ignore[no-untyped-call] + + def scan( # type: ignore[override] + self, + func: Callable[..., Any], + binop: Callable[..., Any], + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: _DType_co | None = None, + **kwargs: Any, + ) -> DaskArray | Any: + from dask.array.reductions import cumreduction + + return cumreduction( + func, + binop, + ident, + arr, + axis=axis, + dtype=dtype, + **kwargs, + ) # type: ignore[no-untyped-call] + + def apply_gufunc( + self, + func: Callable[..., Any], + signature: str, + *args: Any, + axes: Sequence[tuple[int, ...]] | None = None, + axis: int | None = None, + keepdims: bool = False, + output_dtypes: Sequence[_DType_co] | None = None, + output_sizes: dict[str, int] | None = None, + vectorize: bool | None = None, + allow_rechunk: bool = False, + meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None, + **kwargs: Any, + ) -> Any: + from dask.array.gufunc import apply_gufunc + + return apply_gufunc( + func, + signature, + *args, + axes=axes, + axis=axis, + keepdims=keepdims, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + vectorize=vectorize, + allow_rechunk=allow_rechunk, + meta=meta, + **kwargs, + ) # type: ignore[no-untyped-call] + + def map_blocks( + self, + func: Callable[..., Any], + *args: Any, + dtype: _DType_co | None = None, + chunks: tuple[int, ...] | None = None, + drop_axis: int | Sequence[int] | None = None, + new_axis: int | Sequence[int] | None = None, + **kwargs: Any, + ) -> Any: + import dask + from dask.array import map_blocks + + if drop_axis is None and Version(dask.__version__) < Version("2022.9.1"): + # See https://github.com/pydata/xarray/pull/7019#discussion_r1196729489 + # TODO remove once dask minimum version >= 2022.9.1 + drop_axis = [] + + # pass through name, meta, token as kwargs + return map_blocks( + func, + *args, + dtype=dtype, + chunks=chunks, + drop_axis=drop_axis, + new_axis=new_axis, + **kwargs, + ) # type: ignore[no-untyped-call] + + def blockwise( + self, + func: Callable[..., Any], + out_ind: Iterable[Any], + *args: Any, + # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types + name: str | None = None, + token: Any | None = None, + dtype: _DType_co | None = None, + adjust_chunks: dict[Any, Callable[..., Any]] | None = None, + new_axes: dict[Any, int] | None = None, + align_arrays: bool = True, + concatenate: bool | None = None, + meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None, + **kwargs: Any, + ) -> DaskArray | Any: + from dask.array import blockwise + + return blockwise( + func, + out_ind, + *args, + name=name, + token=token, + dtype=dtype, + adjust_chunks=adjust_chunks, + new_axes=new_axes, + align_arrays=align_arrays, + concatenate=concatenate, + meta=meta, + **kwargs, + ) # type: ignore[no-untyped-call] + + def unify_chunks( + self, + *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types + **kwargs: Any, + ) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]: + from dask.array.core import unify_chunks + + return unify_chunks(*args, **kwargs) # type: ignore[no-any-return, no-untyped-call] + + def store( + self, + sources: Any | Sequence[Any], + targets: Any, + **kwargs: Any, + ) -> Any: + from dask.array import store + + return store( + sources=sources, + targets=targets, + **kwargs, + ) diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py new file mode 100644 index 00000000000..7a83bd17064 --- /dev/null +++ b/xarray/namedarray/dtypes.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import functools +import sys +from typing import Any, Literal + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +import numpy as np + +from xarray.namedarray import utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other: Any) -> Literal[True]: + return True + + def __eq__(self, other: Any) -> bool: + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]: + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any + if np.issubdtype(dtype, np.floating): + dtype_ = dtype + fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + dtype_ = dtype + elif np.issubdtype(dtype, np.integer): + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype + fill_value = np.datetime64("NaT") + else: + dtype_ = object + fill_value = np.nan + + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype: np.dtype[np.generic]) -> Any: + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity( + dtype: np.dtype[np.generic], max_for_int: bool = False +) -> float | complex | AlwaysGreaterThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).max if max_for_int else np.inf + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity( + dtype: np.dtype[np.generic], min_for_int: bool = False +) -> float | complex | AlwaysLessThan: + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return -np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype.type).min if min_for_int else -np.inf + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + +def is_datetime_like( + dtype: np.dtype[np.generic], +) -> TypeGuard[np.datetime64 | np.timedelta64]: + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype[np.generic]: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. + """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py new file mode 100644 index 00000000000..dd555fe200a --- /dev/null +++ b/xarray/namedarray/parallelcompat.py @@ -0,0 +1,708 @@ +""" +The code in this module is an experiment in going from N=1 to N=2 parallel computing frameworks in xarray. +It could later be used as the basis for a public interface allowing any N frameworks to interoperate with xarray, +but for now it is just a private experiment. +""" + +from __future__ import annotations + +import functools +import sys +from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence +from importlib.metadata import EntryPoint, entry_points +from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar + +import numpy as np + +from xarray.core.utils import emit_user_level_warning +from xarray.namedarray.pycompat import is_chunked_array + +if TYPE_CHECKING: + from xarray.namedarray._typing import ( + _Chunks, + _DType, + _DType_co, + _NormalizedChunks, + _ShapeType, + duckarray, + ) + + +class ChunkedArrayMixinProtocol(Protocol): + def rechunk(self, chunks: Any, **kwargs: Any) -> Any: ... + + @property + def dtype(self) -> np.dtype[Any]: ... + + @property + def chunks(self) -> _NormalizedChunks: ... + + def compute( + self, *data: Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: ... + + +T_ChunkedArray = TypeVar("T_ChunkedArray", bound=ChunkedArrayMixinProtocol) + + +@functools.lru_cache(maxsize=1) +def list_chunkmanagers() -> dict[str, ChunkManagerEntrypoint[Any]]: + """ + Return a dictionary of available chunk managers and their ChunkManagerEntrypoint subclass objects. + + Returns + ------- + chunkmanagers : dict + Dictionary whose values are registered ChunkManagerEntrypoint subclass instances, and whose values + are the strings under which they are registered. + + Notes + ----- + # New selection mechanism introduced with Python 3.10. See GH6514. + """ + if sys.version_info >= (3, 10): + entrypoints = entry_points(group="xarray.chunkmanagers") + else: + entrypoints = entry_points().get("xarray.chunkmanagers", ()) + + return load_chunkmanagers(entrypoints) + + +def load_chunkmanagers( + entrypoints: Sequence[EntryPoint], +) -> dict[str, ChunkManagerEntrypoint[Any]]: + """Load entrypoints and instantiate chunkmanagers only once.""" + + loaded_entrypoints = {} + for entrypoint in entrypoints: + try: + loaded_entrypoints[entrypoint.name] = entrypoint.load() + except ModuleNotFoundError as e: + emit_user_level_warning( + f"Failed to load chunk manager entrypoint {entrypoint.name} due to {e}. Skipping.", + ) + pass + + available_chunkmanagers = { + name: chunkmanager() + for name, chunkmanager in loaded_entrypoints.items() + if chunkmanager.available + } + return available_chunkmanagers + + +def guess_chunkmanager( + manager: str | ChunkManagerEntrypoint[Any] | None, +) -> ChunkManagerEntrypoint[Any]: + """ + Get namespace of chunk-handling methods, guessing from what's available. + + If the name of a specific ChunkManager is given (e.g. "dask"), then use that. + Else use whatever is installed, defaulting to dask if there are multiple options. + """ + + chunkmanagers = list_chunkmanagers() + + if manager is None: + if len(chunkmanagers) == 1: + # use the only option available + manager = next(iter(chunkmanagers.keys())) + else: + # default to trying to use dask + manager = "dask" + + if isinstance(manager, str): + if manager not in chunkmanagers: + raise ValueError( + f"unrecognized chunk manager {manager} - must be one of: {list(chunkmanagers)}" + ) + + return chunkmanagers[manager] + elif isinstance(manager, ChunkManagerEntrypoint): + # already a valid ChunkManager so just pass through + return manager + else: + raise TypeError( + f"manager must be a string or instance of ChunkManagerEntrypoint, but received type {type(manager)}" + ) + + +def get_chunked_array_type(*args: Any) -> ChunkManagerEntrypoint[Any]: + """ + Detects which parallel backend should be used for given set of arrays. + + Also checks that all arrays are of same chunking type (i.e. not a mix of cubed and dask). + """ + + # TODO this list is probably redundant with something inside xarray.apply_ufunc + ALLOWED_NON_CHUNKED_TYPES = {int, float, np.ndarray} + + chunked_arrays = [ + a + for a in args + if is_chunked_array(a) and type(a) not in ALLOWED_NON_CHUNKED_TYPES + ] + + # Asserts all arrays are the same type (or numpy etc.) + chunked_array_types = {type(a) for a in chunked_arrays} + if len(chunked_array_types) > 1: + raise TypeError( + f"Mixing chunked array types is not supported, but received multiple types: {chunked_array_types}" + ) + elif len(chunked_array_types) == 0: + raise TypeError("Expected a chunked array but none were found") + + # iterate over defined chunk managers, seeing if each recognises this array type + chunked_arr = chunked_arrays[0] + chunkmanagers = list_chunkmanagers() + selected = [ + chunkmanager + for chunkmanager in chunkmanagers.values() + if chunkmanager.is_chunked_array(chunked_arr) + ] + if not selected: + raise TypeError( + f"Could not find a Chunk Manager which recognises type {type(chunked_arr)}" + ) + elif len(selected) >= 2: + raise TypeError(f"Multiple ChunkManagers recognise type {type(chunked_arr)}") + else: + return selected[0] + + +class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]): + """ + Interface between a particular parallel computing framework and xarray. + + This abstract base class must be subclassed by libraries implementing chunked array types, and + registered via the ``chunkmanagers`` entrypoint. + + Abstract methods on this class must be implemented, whereas non-abstract methods are only required in order to + enable a subset of xarray functionality, and by default will raise a ``NotImplementedError`` if called. + + Attributes + ---------- + array_cls + Type of the array class this parallel computing framework provides. + + Parallel frameworks need to provide an array class that supports the array API standard. + This attribute is used for array instance type checking at runtime. + """ + + array_cls: type[T_ChunkedArray] + available: bool = True + + @abstractmethod + def __init__(self) -> None: + """Used to set the array_cls attribute at import time.""" + raise NotImplementedError() + + def is_chunked_array(self, data: duckarray[Any, Any]) -> bool: + """ + Check if the given object is an instance of this type of chunked array. + + Compares against the type stored in the array_cls attribute by default. + + Parameters + ---------- + data : Any + + Returns + ------- + is_chunked : bool + + See Also + -------- + dask.is_dask_collection + """ + return isinstance(data, self.array_cls) + + @abstractmethod + def chunks(self, data: T_ChunkedArray) -> _NormalizedChunks: + """ + Return the current chunks of the given array. + + Returns chunks explicitly as a tuple of tuple of ints. + + Used internally by xarray objects' .chunks and .chunksizes properties. + + Parameters + ---------- + data : chunked array + + Returns + ------- + chunks : tuple[tuple[int, ...], ...] + + See Also + -------- + dask.array.Array.chunks + cubed.Array.chunks + """ + raise NotImplementedError() + + @abstractmethod + def normalize_chunks( + self, + chunks: _Chunks | _NormalizedChunks, + shape: _ShapeType | None = None, + limit: int | None = None, + dtype: _DType | None = None, + previous_chunks: _NormalizedChunks | None = None, + ) -> _NormalizedChunks: + """ + Normalize given chunking pattern into an explicit tuple of tuples representation. + + Exposed primarily because different chunking backends may want to make different decisions about how to + automatically chunk along dimensions not given explicitly in the input chunks. + + Called internally by xarray.open_dataset. + + Parameters + ---------- + chunks : tuple, int, dict, or string + The chunks to be normalized. + shape : Tuple[int] + The shape of the array + limit : int (optional) + The maximum block size to target in bytes, + if freedom is given to choose + dtype : np.dtype + previous_chunks : Tuple[Tuple[int]], optional + Chunks from a previous array that we should use for inspiration when + rechunking dimensions automatically. + + See Also + -------- + dask.array.core.normalize_chunks + """ + raise NotImplementedError() + + @abstractmethod + def from_array( + self, data: duckarray[Any, Any], chunks: _Chunks, **kwargs: Any + ) -> T_ChunkedArray: + """ + Create a chunked array from a non-chunked numpy-like array. + + Generally input should have a ``.shape``, ``.ndim``, ``.dtype`` and support numpy-style slicing. + + Called when the .chunk method is called on an xarray object that is not already chunked. + Also called within open_dataset (when chunks is not None) to create a chunked array from + an xarray lazily indexed array. + + Parameters + ---------- + data : array_like + chunks : int, tuple + How to chunk the array. + + See Also + -------- + dask.array.from_array + cubed.from_array + """ + raise NotImplementedError() + + def rechunk( + self, + data: T_ChunkedArray, + chunks: _NormalizedChunks | tuple[int, ...] | _Chunks, + **kwargs: Any, + ) -> Any: + """ + Changes the chunking pattern of the given array. + + Called when the .chunk method is called on an xarray object that is already chunked. + + Parameters + ---------- + data : dask array + Array to be rechunked. + chunks : int, tuple, dict or str, optional + The new block dimensions to create. -1 indicates the full size of the + corresponding dimension. Default is "auto" which automatically + determines chunk sizes. + + Returns + ------- + chunked array + + See Also + -------- + dask.array.Array.rechunk + cubed.Array.rechunk + """ + return data.rechunk(chunks, **kwargs) + + @abstractmethod + def compute( + self, *data: T_ChunkedArray | Any, **kwargs: Any + ) -> tuple[np.ndarray[Any, _DType_co], ...]: + """ + Computes one or more chunked arrays, returning them as eager numpy arrays. + + Called anytime something needs to computed, including multiple arrays at once. + Used by `.compute`, `.persist`, `.values`. + + Parameters + ---------- + *data : object + Any number of objects. If an object is an instance of the chunked array type, it is computed + and the in-memory result returned as a numpy array. All other types should be passed through unchanged. + + Returns + ------- + objs + The input, but with all chunked arrays now computed. + + See Also + -------- + dask.compute + cubed.compute + """ + raise NotImplementedError() + + @property + def array_api(self) -> Any: + """ + Return the array_api namespace following the python array API standard. + + See https://data-apis.org/array-api/latest/ . Currently used to access the array API function + ``full_like``, which is called within the xarray constructors ``xarray.full_like``, ``xarray.ones_like``, + ``xarray.zeros_like``, etc. + + See Also + -------- + dask.array + cubed.array_api + """ + raise NotImplementedError() + + def reduction( + self, + arr: T_ChunkedArray, + func: Callable[..., Any], + combine_func: Callable[..., Any] | None = None, + aggregate_func: Callable[..., Any] | None = None, + axis: int | Sequence[int] | None = None, + dtype: _DType_co | None = None, + keepdims: bool = False, + ) -> T_ChunkedArray: + """ + A general version of array reductions along one or more axes. + + Used inside some reductions like nanfirst, which is used by ``groupby.first``. + + Parameters + ---------- + arr : chunked array + Data to be reduced along one or more axes. + func : Callable(x_chunk, axis, keepdims) + First function to be executed when resolving the dask graph. + This function is applied in parallel to all original chunks of x. + See below for function parameters. + combine_func : Callable(x_chunk, axis, keepdims), optional + Function used for intermediate recursive aggregation (see + split_every below). If omitted, it defaults to aggregate_func. + aggregate_func : Callable(x_chunk, axis, keepdims) + Last function to be executed, producing the final output. It is always invoked, even when the reduced + Array counts a single chunk along the reduced axes. + axis : int or sequence of ints, optional + Axis or axes to aggregate upon. If omitted, aggregate along all axes. + dtype : np.dtype + data type of output. This argument was previously optional, but + leaving as ``None`` will now raise an exception. + keepdims : boolean, optional + Whether the reduction function should preserve the reduced axes, + leaving them at size ``output_size``, or remove them. + + Returns + ------- + chunked array + + See Also + -------- + dask.array.reduction + cubed.core.reduction + """ + raise NotImplementedError() + + def scan( + self, + func: Callable[..., Any], + binop: Callable[..., Any], + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: _DType_co | None = None, + **kwargs: Any, + ) -> T_ChunkedArray: + """ + General version of a 1D scan, also known as a cumulative array reduction. + + Used in ``ffill`` and ``bfill`` in xarray. + + Parameters + ---------- + func: callable + Cumulative function like np.cumsum or np.cumprod + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + ident: Number + Associated identity like ``np.cumsum->0`` or ``np.cumprod->1`` + arr: dask Array + axis: int, optional + dtype: dtype + + Returns + ------- + Chunked array + + See also + -------- + dask.array.cumreduction + """ + raise NotImplementedError() + + @abstractmethod + def apply_gufunc( + self, + func: Callable[..., Any], + signature: str, + *args: Any, + axes: Sequence[tuple[int, ...]] | None = None, + keepdims: bool = False, + output_dtypes: Sequence[_DType_co] | None = None, + vectorize: bool | None = None, + **kwargs: Any, + ) -> Any: + """ + Apply a generalized ufunc or similar python function to arrays. + + ``signature`` determines if the function consumes or produces core + dimensions. The remaining dimensions in given input arrays (``*args``) + are considered loop dimensions and are required to broadcast + naturally against each other. + + In other terms, this function is like ``np.vectorize``, but for + the blocks of chunked arrays. If the function itself shall also + be vectorized use ``vectorize=True`` for convenience. + + Called inside ``xarray.apply_ufunc``, which is called internally for most xarray operations. + Therefore this method must be implemented for the vast majority of xarray computations to be supported. + + Parameters + ---------- + func : callable + Function to call like ``func(*args, **kwargs)`` on input arrays + (``*args``) that returns an array or tuple of arrays. If multiple + arguments with non-matching dimensions are supplied, this function is + expected to vectorize (broadcast) over axes of positional arguments in + the style of NumPy universal functions [1]_ (if this is not the case, + set ``vectorize=True``). If this function returns multiple outputs, + ``output_core_dims`` has to be set as well. + signature: string + Specifies what core dimensions are consumed and produced by ``func``. + According to the specification of numpy.gufunc signature [2]_ + *args : numeric + Input arrays or scalars to the callable function. + axes: List of tuples, optional, keyword only + A list of tuples with indices of axes a generalized ufunc should operate on. + For instance, for a signature of ``"(i,j),(j,k)->(i,k)"`` appropriate for + matrix multiplication, the base elements are two-dimensional matrices + and these are taken to be stored in the two last axes of each argument. The + corresponding axes keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``. + For simplicity, for generalized ufuncs that operate on 1-dimensional arrays + (vectors), a single integer is accepted instead of a single-element tuple, + and for generalized ufuncs for which all outputs are scalars, the output + tuples can be omitted. + keepdims: bool, optional, keyword only + If this is set to True, axes which are reduced over will be left in the result as + a dimension with size one, so that the result will broadcast correctly against the + inputs. This option can only be used for generalized ufuncs that operate on inputs + that all have the same number of core dimensions and with outputs that have no core + dimensions , i.e., with signatures like ``"(i),(i)->()"`` or ``"(m,m)->()"``. + If used, the location of the dimensions in the output can be controlled with axes + and axis. + output_dtypes : Optional, dtype or list of dtypes, keyword only + Valid numpy dtype specification or list thereof. + If not given, a call of ``func`` with a small set of data + is performed in order to try to automatically determine the + output dtypes. + vectorize: bool, keyword only + If set to ``True``, ``np.vectorize`` is applied to ``func`` for + convenience. Defaults to ``False``. + **kwargs : dict + Extra keyword arguments to pass to `func` + + Returns + ------- + Single chunked array or tuple of chunked arrays + + See Also + -------- + dask.array.gufunc.apply_gufunc + cubed.apply_gufunc + + References + ---------- + .. [1] https://docs.scipy.org/doc/numpy/reference/ufuncs.html + .. [2] https://docs.scipy.org/doc/numpy/reference/c-api/generalized-ufuncs.html + """ + raise NotImplementedError() + + def map_blocks( + self, + func: Callable[..., Any], + *args: Any, + dtype: _DType_co | None = None, + chunks: tuple[int, ...] | None = None, + drop_axis: int | Sequence[int] | None = None, + new_axis: int | Sequence[int] | None = None, + **kwargs: Any, + ) -> Any: + """ + Map a function across all blocks of a chunked array. + + Called in elementwise operations, but notably not (currently) called within xarray.map_blocks. + + Parameters + ---------- + func : callable + Function to apply to every block in the array. + If ``func`` accepts ``block_info=`` or ``block_id=`` + as keyword arguments, these will be passed dictionaries + containing information about input and output chunks/arrays + during computation. See examples for details. + args : dask arrays or other objects + dtype : np.dtype, optional + The ``dtype`` of the output array. It is recommended to provide this. + If not provided, will be inferred by applying the function to a small + set of fake data. + chunks : tuple, optional + Chunk shape of resulting blocks if the function does not preserve + shape. If not provided, the resulting array is assumed to have the same + block structure as the first input array. + drop_axis : number or iterable, optional + Dimensions lost by the function. + new_axis : number or iterable, optional + New dimensions created by the function. Note that these are applied + after ``drop_axis`` (if present). + **kwargs : + Other keyword arguments to pass to function. Values must be constants + (not dask.arrays) + + See Also + -------- + dask.array.map_blocks + cubed.map_blocks + """ + raise NotImplementedError() + + def blockwise( + self, + func: Callable[..., Any], + out_ind: Iterable[Any], + *args: Any, # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types + adjust_chunks: dict[Any, Callable[..., Any]] | None = None, + new_axes: dict[Any, int] | None = None, + align_arrays: bool = True, + **kwargs: Any, + ) -> Any: + """ + Tensor operation: Generalized inner and outer products. + + A broad class of blocked algorithms and patterns can be specified with a + concise multi-index notation. The ``blockwise`` function applies an in-memory + function across multiple blocks of multiple inputs in a variety of ways. + Many chunked array operations are special cases of blockwise including + elementwise, broadcasting, reductions, tensordot, and transpose. + + Currently only called explicitly in xarray when performing multidimensional interpolation. + + Parameters + ---------- + func : callable + Function to apply to individual tuples of blocks + out_ind : iterable + Block pattern of the output, something like 'ijk' or (1, 2, 3) + *args : sequence of Array, index pairs + You may also pass literal arguments, accompanied by None index + e.g. (x, 'ij', y, 'jk', z, 'i', some_literal, None) + **kwargs : dict + Extra keyword arguments to pass to function + adjust_chunks : dict + Dictionary mapping index to function to be applied to chunk sizes + new_axes : dict, keyword only + New indexes and their dimension lengths + align_arrays: bool + Whether or not to align chunks along equally sized dimensions when + multiple arrays are provided. This allows for larger chunks in some + arrays to be broken into smaller ones that match chunk sizes in other + arrays such that they are compatible for block function mapping. If + this is false, then an error will be thrown if arrays do not already + have the same number of blocks in each dimension. + + See Also + -------- + dask.array.blockwise + cubed.core.blockwise + """ + raise NotImplementedError() + + def unify_chunks( + self, + *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types + **kwargs: Any, + ) -> tuple[dict[str, _NormalizedChunks], list[T_ChunkedArray]]: + """ + Unify chunks across a sequence of arrays. + + Called by xarray.unify_chunks. + + Parameters + ---------- + *args: sequence of Array, index pairs + Sequence like (x, 'ij', y, 'jk', z, 'i') + + See Also + -------- + dask.array.core.unify_chunks + cubed.core.unify_chunks + """ + raise NotImplementedError() + + def store( + self, + sources: T_ChunkedArray | Sequence[T_ChunkedArray], + targets: Any, + **kwargs: dict[str, Any], + ) -> Any: + """ + Store chunked arrays in array-like objects, overwriting data in target. + + This stores chunked arrays into object that supports numpy-style setitem + indexing (e.g. a Zarr Store). Allows storing values chunk by chunk so that it does not have to + fill up memory. For best performance you likely want to align the block size of + the storage target with the block size of your array. + + Used when writing to any registered xarray I/O backend. + + Parameters + ---------- + sources: Array or collection of Arrays + targets: array-like or collection of array-likes + These should support setitem syntax ``target[10:20] = ...``. + If sources is a single item, targets must be a single item; if sources is a + collection of arrays, targets must be a matching collection. + kwargs: + Parameters passed to compute/persist (only used if compute=True) + + See Also + -------- + dask.array.store + cubed.store + """ + raise NotImplementedError() diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py new file mode 100644 index 00000000000..3ce33d4d8ea --- /dev/null +++ b/xarray/namedarray/pycompat.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from importlib import import_module +from types import ModuleType +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +from packaging.version import Version + +from xarray.core.utils import is_scalar +from xarray.namedarray.utils import is_duck_array, is_duck_dask_array + +integer_types = (int, np.integer) + +if TYPE_CHECKING: + ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"] + DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic + from xarray.namedarray._typing import _DType, _ShapeType, duckarray + + +class DuckArrayModule: + """ + Solely for internal isinstance and version checks. + + Motivated by having to only import pint when required (as pint currently imports xarray) + https://github.com/pydata/xarray/pull/5561#discussion_r664815718 + """ + + module: ModuleType | None + version: Version + type: DuckArrayTypes + available: bool + + def __init__(self, mod: ModType) -> None: + duck_array_module: ModuleType | None + duck_array_version: Version + duck_array_type: DuckArrayTypes + try: + duck_array_module = import_module(mod) + duck_array_version = Version(duck_array_module.__version__) + + if mod == "dask": + duck_array_type = (import_module("dask.array").Array,) + elif mod == "pint": + duck_array_type = (duck_array_module.Quantity,) + elif mod == "cupy": + duck_array_type = (duck_array_module.ndarray,) + elif mod == "sparse": + duck_array_type = (duck_array_module.SparseArray,) + elif mod == "cubed": + duck_array_type = (duck_array_module.Array,) + # Not a duck array module, but using this system regardless, to get lazy imports + elif mod == "numbagg": + duck_array_type = () + else: + raise NotImplementedError + + except (ImportError, AttributeError): # pragma: no cover + duck_array_module = None + duck_array_version = Version("0.0.0") + duck_array_type = () + + self.module = duck_array_module + self.version = duck_array_version + self.type = duck_array_type + self.available = duck_array_module is not None + + +_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {} + + +def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: + if mod not in _cached_duck_array_modules: + duckmod = DuckArrayModule(mod) + _cached_duck_array_modules[mod] = duckmod + return duckmod + else: + return _cached_duck_array_modules[mod] + + +def array_type(mod: ModType) -> DuckArrayTypes: + """Quick wrapper to get the array class of the module.""" + return _get_cached_duck_array_module(mod).type + + +def mod_version(mod: ModType) -> Version: + """Quick wrapper to get the version of the module.""" + return _get_cached_duck_array_module(mod).version + + +def is_chunked_array(x: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) + + +def is_0d_dask_array(x: duckarray[Any, Any]) -> bool: + return is_duck_dask_array(x) and is_scalar(x) + + +def to_numpy( + data: duckarray[Any, Any], **kwargs: dict[str, Any] +) -> np.ndarray[Any, np.dtype[Any]]: + from xarray.core.indexing import ExplicitlyIndexed + from xarray.namedarray.parallelcompat import get_chunked_array_type + + if isinstance(data, ExplicitlyIndexed): + data = data.get_duck_array() # type: ignore[no-untyped-call] + + # TODO first attempt to call .to_numpy() once some libraries implement it + if is_chunked_array(data): + chunkmanager = get_chunked_array_type(data) + data, *_ = chunkmanager.compute(data, **kwargs) + if isinstance(data, array_type("cupy")): + data = data.get() + # pint has to be imported dynamically as pint imports xarray + if isinstance(data, array_type("pint")): + data = data.magnitude + if isinstance(data, array_type("sparse")): + data = data.todense() + data = np.asarray(data) + + return data + + +def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]: + from xarray.core.indexing import ExplicitlyIndexed + from xarray.namedarray.parallelcompat import get_chunked_array_type + + if is_chunked_array(data): + chunkmanager = get_chunked_array_type(data) + loaded_data, *_ = chunkmanager.compute(data, **kwargs) # type: ignore[var-annotated] + return loaded_data + + if isinstance(data, ExplicitlyIndexed): + return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] + elif is_duck_array(data): + return data + else: + return np.asarray(data) # type: ignore[return-value] diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py new file mode 100644 index 00000000000..b82a80b546a --- /dev/null +++ b/xarray/namedarray/utils.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import importlib +import sys +import warnings +from collections.abc import Hashable, Iterable, Iterator, Mapping +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeVar, cast + +import numpy as np +from packaging.version import Version + +from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard + + from numpy.typing import NDArray + + try: + from dask.array.core import Array as DaskArray + from dask.typing import DaskCollection + except ImportError: + DaskArray = NDArray # type: ignore + DaskCollection: Any = NDArray # type: ignore + + from xarray.namedarray._typing import _Dim, duckarray + + +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") + + +@lru_cache +def module_available(module: str, minversion: str | None = None) -> bool: + """Checks whether a module is installed without importing it. + + Use this for a lightweight check and lazy imports. + + Parameters + ---------- + module : str + Name of the module. + minversion : str, optional + Minimum version of the module + + Returns + ------- + available : bool + Whether the module is installed. + """ + if importlib.util.find_spec(module) is None: + return False + + if minversion is not None: + version = importlib.metadata.version(module) + + return Version(version) >= Version(minversion) + + return True + + +def is_dask_collection(x: object) -> TypeGuard[DaskCollection]: + if module_available("dask"): + from dask.base import is_dask_collection + + # use is_dask_collection function instead of dask.typing.DaskCollection + # see https://github.com/pydata/xarray/pull/8241#discussion_r1476276023 + return is_dask_collection(x) + return False + + +def is_duck_array(value: Any) -> TypeGuard[duckarray[Any, Any]]: + # TODO: replace is_duck_array with runtime checks via _arrayfunction_or_api protocol on + # python 3.12 and higher (see https://github.com/pydata/xarray/issues/8696#issuecomment-1924588981) + if isinstance(value, np.ndarray): + return True + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) + ) + + +def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]: + return is_duck_array(x) and is_dask_collection(x) + + +def to_0d_object_array( + value: object, +) -> NDArray[np.object_]: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" + result = np.empty((), dtype=object) + result[()] = value + return result + + +def is_dict_like(value: Any) -> TypeGuard[Mapping[Any, Any]]: + return hasattr(value, "keys") and hasattr(value, "__getitem__") + + +def drop_missing_dims( + supplied_dims: Iterable[_Dim], + dims: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn, +) -> _DimsLike: + """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that + are not present in dims. + + Parameters + ---------- + supplied_dims : Iterable of Hashable + dims : Iterable of Hashable + missing_dims : {"raise", "warn", "ignore"} + """ + + if missing_dims == "raise": + supplied_dims_set = {val for val in supplied_dims if val is not ...} + if invalid := supplied_dims_set - set(dims): + raise ValueError( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return supplied_dims + + elif missing_dims == "warn": + if invalid := set(supplied_dims) - set(dims): + warnings.warn( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return [val for val in supplied_dims if val in dims or val is ...] + + elif missing_dims == "ignore": + return [val for val in supplied_dims if val in dims or val is ...] + + else: + raise ValueError( + f"Unrecognised option {missing_dims} for missing_dims argument" + ) + + +def infix_dims( + dims_supplied: Iterable[_Dim], + dims_all: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn = "raise", +) -> Iterator[_Dim]: + """ + Resolves a supplied list containing an ellipsis representing other items, to + a generator with the 'realized' list of all items + """ + if ... in dims_supplied: + dims_all_list = list(dims_all) + if len(set(dims_all)) != len(dims_all_list): + raise ValueError("Cannot use ellipsis with repeated dims") + if list(dims_supplied).count(...) > 1: + raise ValueError("More than one ellipsis supplied") + other_dims = [d for d in dims_all if d not in dims_supplied] + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + for d in existing_dims: + if d is ...: + yield from other_dims + else: + yield d + else: + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + if set(existing_dims) ^ set(dims_all): + raise ValueError( + f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" + ) + yield from existing_dims + + +def either_dict_or_kwargs( + pos_kwargs: Mapping[Any, T] | None, + kw_kwargs: Mapping[str, T], + func_name: str, +) -> Mapping[Hashable, T]: + if pos_kwargs is None or pos_kwargs == {}: + # Need an explicit cast to appease mypy due to invariance; see + # https://github.com/python/mypy/issues/6228 + return cast(Mapping[Hashable, T], kw_kwargs) + + if not is_dict_like(pos_kwargs): + raise ValueError(f"the first argument to .{func_name} must be a dictionary") + if kw_kwargs: + raise ValueError( + f"cannot specify both keyword and positional arguments to .{func_name}" + ) + return pos_kwargs + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + _value: str + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other: ReprObject | Any) -> bool: + # TODO: What type can other be? ArrayLike? + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self) -> object: + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 28aac6edd9e..ae7a0012b32 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -6,6 +6,7 @@ DataArray.plot._____ Dataset.plot._____ """ + from xarray.plot.dataarray_plot import ( contour, contourf, diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index ff707602545..9db4ae4e3f7 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -16,6 +16,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from matplotlib.quiver import Quiver from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -47,11 +48,13 @@ def __call__(self, **kwargs) -> Any: return dataarray_plot.plot(self._da, **kwargs) @functools.wraps(dataarray_plot.hist) - def hist(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, BarContainer]: + def hist( + self, *args, **kwargs + ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: return dataarray_plot.hist(self._da, *args, **kwargs) @overload - def line( # type: ignore[misc] # None is hashable :( + def line( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, row: None = None, # no wrap -> primitive @@ -69,13 +72,12 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, - ) -> list[Line3D]: - ... + ) -> list[Line3D]: ... @overload def line( @@ -96,13 +98,12 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def line( @@ -123,20 +124,19 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.line) + @functools.wraps(dataarray_plot.line, assigned=("__doc__",)) def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) @overload - def step( # type: ignore[misc] # None is hashable :( + def step( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -145,8 +145,7 @@ def step( # type: ignore[misc] # None is hashable :( row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive **kwargs: Any, - ) -> list[Line3D]: - ... + ) -> list[Line3D]: ... @overload def step( @@ -158,8 +157,7 @@ def step( row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def step( @@ -171,15 +169,14 @@ def step( row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.step) + @functools.wraps(dataarray_plot.step, assigned=("__doc__",)) def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.step(self._da, *args, **kwargs) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -207,8 +204,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -216,8 +213,7 @@ def scatter( extend=None, levels=None, **kwargs, - ) -> PathCollection: - ... + ) -> PathCollection: ... @overload def scatter( @@ -248,8 +244,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -257,8 +253,7 @@ def scatter( extend=None, levels=None, **kwargs, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def scatter( @@ -289,8 +284,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -298,15 +293,14 @@ def scatter( extend=None, levels=None, **kwargs, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.scatter) - def scatter(self, *args, **kwargs): + @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: return dataarray_plot.scatter(self._da, *args, **kwargs) @overload - def imshow( + def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -338,12 +332,11 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> AxesImage: - ... + ) -> AxesImage: ... @overload def imshow( @@ -378,12 +371,11 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def imshow( @@ -418,19 +410,18 @@ def imshow( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.imshow) - def imshow(self, *args, **kwargs) -> AxesImage: + @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",)) + def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]: return dataarray_plot.imshow(self._da, *args, **kwargs) @overload - def contour( + def contour( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -462,12 +453,11 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> QuadContourSet: - ... + ) -> QuadContourSet: ... @overload def contour( @@ -502,12 +492,11 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def contour( @@ -542,19 +531,18 @@ def contour( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.contour) - def contour(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contour, assigned=("__doc__",)) + def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contour(self._da, *args, **kwargs) @overload - def contourf( + def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -586,12 +574,11 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> QuadContourSet: - ... + ) -> QuadContourSet: ... @overload def contourf( @@ -626,12 +613,11 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[DataArray]: ... @overload def contourf( @@ -666,19 +652,18 @@ def contourf( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.contourf) - def contourf(self, *args, **kwargs) -> QuadContourSet: + @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",)) + def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]: return dataarray_plot.contourf(self._da, *args, **kwargs) @overload - def pcolormesh( + def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -710,12 +695,11 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> QuadMesh: - ... + ) -> QuadMesh: ... @overload def pcolormesh( @@ -750,12 +734,11 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[DataArray]: ... @overload def pcolormesh( @@ -790,15 +773,14 @@ def pcolormesh( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[DataArray]: ... - @functools.wraps(dataarray_plot.pcolormesh) - def pcolormesh(self, *args, **kwargs) -> QuadMesh: + @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",)) + def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]: return dataarray_plot.pcolormesh(self._da, *args, **kwargs) @overload @@ -834,12 +816,11 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> Poly3DCollection: - ... + ) -> Poly3DCollection: ... @overload def surface( @@ -874,12 +855,11 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid: ... @overload def surface( @@ -914,14 +894,13 @@ def surface( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid: ... - @functools.wraps(dataarray_plot.surface) + @functools.wraps(dataarray_plot.surface, assigned=("__doc__",)) def surface(self, *args, **kwargs) -> Poly3DCollection: return dataarray_plot.surface(self._da, *args, **kwargs) @@ -945,7 +924,7 @@ def __call__(self, *args, **kwargs) -> NoReturn: ) @overload - def scatter( + def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -973,8 +952,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -982,8 +961,7 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> PathCollection: - ... + ) -> PathCollection: ... @overload def scatter( @@ -1014,8 +992,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1023,8 +1001,7 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[Dataset]: ... @overload def scatter( @@ -1055,8 +1032,8 @@ def scatter( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap=None, vmin: float | None = None, vmax: float | None = None, @@ -1064,15 +1041,14 @@ def scatter( extend=None, levels=None, **kwargs: Any, - ) -> FacetGrid[DataArray]: - ... + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.scatter) - def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + @functools.wraps(dataset_plot.scatter, assigned=("__doc__",)) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]: return dataset_plot.scatter(self._ds, *args, **kwargs) @overload - def quiver( + def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1105,8 +1081,7 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> Quiver: - ... + ) -> Quiver: ... @overload def quiver( @@ -1142,8 +1117,7 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[Dataset]: ... @overload def quiver( @@ -1179,15 +1153,14 @@ def quiver( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.quiver) - def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: + @functools.wraps(dataset_plot.quiver, assigned=("__doc__",)) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]: return dataset_plot.quiver(self._ds, *args, **kwargs) @overload - def streamplot( + def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( self, *args: Any, x: Hashable | None = None, @@ -1220,8 +1193,7 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> LineCollection: - ... + ) -> LineCollection: ... @overload def streamplot( @@ -1257,8 +1229,7 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[Dataset]: ... @overload def streamplot( @@ -1294,9 +1265,8 @@ def streamplot( extend=None, cmap=None, **kwargs: Any, - ) -> FacetGrid: - ... + ) -> FacetGrid[Dataset]: ... - @functools.wraps(dataset_plot.streamplot) - def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid: + @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",)) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]: return dataset_plot.streamplot(self._ds, *args, **kwargs) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index a80db91562c..8386161bf29 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -3,7 +3,7 @@ import functools import warnings from collections.abc import Hashable, Iterable, MutableMapping -from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload import numpy as np import pandas as pd @@ -27,9 +27,9 @@ _rescale_imshow_rgb, _resolve_intervals_1dplot, _resolve_intervals_2dplot, + _set_concise_date, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, ) @@ -40,6 +40,7 @@ from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet from matplotlib.image import AxesImage + from matplotlib.patches import Polygon from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection from numpy.typing import ArrayLike @@ -53,7 +54,7 @@ ) from xarray.plot.facetgrid import FacetGrid -_styles: MutableMapping[str, Any] = { +_styles: dict[str, Any] = { # Add a white border to make it easier seeing overlapping markers: "scatter.edgecolors": "w", } @@ -186,7 +187,7 @@ def _prepare_plot1d_data( # dimensions so the plotter can plot anything: if darray.ndim > 1: # When stacking dims the lines will continue connecting. For floats - # this can be solved by adding a nan element inbetween the flattening + # this can be solved by adding a nan element in between the flattening # points: dims_T = [] if np.issubdtype(darray.dtype, np.floating): @@ -264,7 +265,9 @@ def plot( -------- xarray.DataArray.squeeze """ - darray = darray.squeeze().compute() + darray = darray.squeeze( + d for d, s in darray.sizes.items() if s == 1 and d not in (row, col, hue) + ).compute() plot_dims = set(darray.dims) plot_dims.discard(row) @@ -307,7 +310,7 @@ def plot( @overload -def line( # type: ignore[misc] # None is hashable :( +def line( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, row: None = None, # no wrap -> primitive @@ -325,18 +328,17 @@ def line( # type: ignore[misc] # None is hashable :( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> list[Line3D]: - ... +) -> list[Line3D]: ... @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, @@ -353,18 +355,17 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def line( - darray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid @@ -381,19 +382,18 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... # This function signature should not change so that it can use # matplotlib format strings def line( - darray: DataArray, + darray: T_DataArray, *args: Any, row: Hashable | None = None, col: Hashable | None = None, @@ -410,12 +410,12 @@ def line( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, add_legend: bool = True, _labels: bool = True, **kwargs: Any, -) -> list[Line3D] | FacetGrid[DataArray]: +) -> list[Line3D] | FacetGrid[T_DataArray]: """ Line plot of DataArray values. @@ -459,7 +459,7 @@ def line( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. add_legend : bool, default: True Add legend with *y* axis coordinates (2D inputs only). @@ -486,8 +486,8 @@ def line( if ndims > 2: raise ValueError( "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) + f"Passed DataArray has {ndims} " + "dimensions" ) # The allargs dict passed to _easy_facetgrid above contains args @@ -523,14 +523,8 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") + _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -538,7 +532,7 @@ def line( @overload -def step( # type: ignore[misc] # None is hashable :( +def step( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, where: Literal["pre", "post", "mid"] = "pre", @@ -547,8 +541,7 @@ def step( # type: ignore[misc] # None is hashable :( row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive **kwargs: Any, -) -> list[Line3D]: - ... +) -> list[Line3D]: ... @overload @@ -561,8 +554,7 @@ def step( row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[DataArray]: ... @overload @@ -575,8 +567,7 @@ def step( row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[DataArray]: ... def step( @@ -654,10 +645,10 @@ def hist( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, **kwargs: Any, -) -> tuple[np.ndarray, np.ndarray, BarContainer]: +) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]: """ Histogram of DataArray. @@ -691,7 +682,7 @@ def hist( Specifies scaling for the *x*- and *y*-axis, respectively. xticks, yticks : array-like, optional Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional + xlim, ylim : tuple[float, float], optional Specify *x*- and *y*-axis limits. **kwargs : optional Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. @@ -708,14 +699,17 @@ def hist( no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] - primitive = ax.hist(no_nan, **kwargs) + n, bins, patches = cast( + tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]], + ax.hist(no_nan, **kwargs), + ) ax.set_title(darray._title_for_slice()) ax.set_xlabel(label_from_attrs(darray)) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - return primitive + return n, bins, patches def _plot1d(plotfunc): @@ -733,15 +727,6 @@ def _plot1d(plotfunc): If specified plot 3D and use this coordinate for *z* axis. hue : Hashable or None, optional Dimension or coordinate for which you want multiple lines plotted. - hue_style: {'discrete', 'continuous'} or None, optional - How to use the ``hue`` variable: - - - ``'continuous'`` -- continuous color scale - (default for numeric ``hue`` variables) - - ``'discrete'`` -- a color for each unique value, - using the default color cycle - (default for non-numeric ``hue`` variables) - markersize: Hashable or None, optional scatter only. Variable by which to vary size of scattered points. linewidth: Hashable or None, optional @@ -788,9 +773,9 @@ def _plot1d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a @@ -798,7 +783,7 @@ def _plot1d(plotfunc): be either ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a @@ -875,8 +860,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, cmap: str | Colormap | None = None, vmin: float | None = None, vmax: float | None = None, @@ -888,7 +873,7 @@ def newplotfunc( # All 1d plots in xarray share this function signature. # Method signature below should be consistent. - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if subplot_kws is None: subplot_kws = dict() @@ -935,11 +920,30 @@ def newplotfunc( warnings.warn(msg, DeprecationWarning, stacklevel=2) del args + if hue_style is not None: + # TODO: Not used since 2022.10. Deprecated since 2023.07. + warnings.warn( + ( + "hue_style is no longer used for plot1d plots " + "and the argument will eventually be removed. " + "Convert numbers to string for a discrete hue " + "and use add_legend or add_colorbar to control which guide to display." + ), + DeprecationWarning, + stacklevel=2, + ) + _is_facetgrid = kwargs.pop("_is_facetgrid", False) if plotfunc.__name__ == "scatter": size_ = kwargs.pop("_size", markersize) size_r = _MARKERSIZE_RANGE + + # Remove any nulls, .where(m, drop=True) doesn't work when m is + # a dask array, so load the array to memory. + # It will have to be loaded to memory at some point anyway: + darray = darray.load() + darray = darray.where(darray.notnull(), drop=True) else: size_ = kwargs.pop("_size", linewidth) size_r = _LINEWIDTH_RANGE @@ -988,9 +992,13 @@ def newplotfunc( with plt.rc_context(_styles): if z is not None: + import mpl_toolkits + if ax is None: subplot_kws.update(projection="3d") ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + # Using 30, 30 minimizes rotation of the plot. Making it easier to # build on your intuition from 2D plots: ax.view_init(azim=30, elev=30, vertical_axis="y") @@ -1028,9 +1036,11 @@ def newplotfunc( if add_legend_: if plotfunc.__name__ in ["scatter", "line"]: _add_legend( - hueplt_norm - if add_legend or not add_colorbar_ - else _Normalize(None), + ( + hueplt_norm + if add_legend or not add_colorbar_ + else _Normalize(None) + ), sizeplt_norm, primitive, legend_ax=ax, @@ -1039,9 +1049,7 @@ def newplotfunc( else: hueplt_norm_values: list[np.ndarray | None] if hueplt_norm.data is not None: - hueplt_norm_values = list( - cast("DataArray", hueplt_norm.data).to_numpy() - ) + hueplt_norm_values = list(hueplt_norm.data.to_numpy()) else: hueplt_norm_values = [hueplt_norm.data] @@ -1074,16 +1082,14 @@ def newplotfunc( def _add_labels( add_labels: bool | Iterable[bool], - darrays: Iterable[DataArray], + darrays: Iterable[DataArray | None], suffixes: Iterable[str], - rotate_labels: Iterable[bool], ax: Axes, ) -> None: - # Set x, y, z labels: + """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels - for axis, add_label, darray, suffix, rotate_label in zip( - ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels - ): + axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z") + for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes): if darray is None: continue @@ -1092,18 +1098,12 @@ def _add_labels( if label is not None: getattr(ax, f"set_{axis}label")(label) - if rotate_label and np.issubdtype(darray.dtype, np.datetime64): - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - for labels in getattr(ax, f"get_{axis}ticklabels")(): - labels.set_rotation(30) - labels.set_ha("right") + if np.issubdtype(darray.dtype, np.datetime64): + _set_concise_date(ax, axis=axis) @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, *args: Any, x: Hashable | None = None, @@ -1140,13 +1140,12 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> PathCollection: - ... +) -> PathCollection: ... @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1182,13 +1181,12 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def scatter( - darray: DataArray, + darray: T_DataArray, *args: Any, x: Hashable | None = None, y: Hashable | None = None, @@ -1224,8 +1222,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot1d @@ -1253,15 +1250,24 @@ def scatter( if sizeplt is not None: kwargs.update(s=sizeplt.to_numpy().ravel()) - axis_order = ["x", "y", "z"] + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), ax) - plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) - plts_or_none = [plts_dict[v] for v in axis_order] - plts = [p for p in plts_or_none if p is not None] - primitive = ax.scatter(*[p.to_numpy().ravel() for p in plts], **kwargs) - _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) - return primitive + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs) + + if len(plts_np) == 2: + return ax.scatter(plts_np[0], plts_np[1], **kwargs) + + raise ValueError("At least two variables required for a scatter plot.") def _plot2d(plotfunc): @@ -1321,14 +1327,14 @@ def _plot2d(plotfunc): The mapping from data values to color space. If not provided, this will be either be ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a `seaborn color palette `_. Note: if ``cmap`` is a seaborn color palette and the plot type is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. - center : float, optional + center : float or False, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a diverging colormap. @@ -1370,9 +1376,9 @@ def _plot2d(plotfunc): Specify tick locations for x-axes. yticks : ArrayLike or None, optional Specify tick locations for y-axes. - xlim : ArrayLike or None, optional + xlim : tuple[float, float] or None, optional Specify x-axes limits. - ylim : ArrayLike or None, optional + ylim : tuple[float, float] or None, optional Specify y-axes limits. norm : matplotlib.colors.Normalize, optional If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding @@ -1412,7 +1418,7 @@ def newplotfunc( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1425,8 +1431,8 @@ def newplotfunc( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, norm: Normalize | None = None, **kwargs: Any, ) -> Any: @@ -1498,8 +1504,6 @@ def newplotfunc( # TypeError to be consistent with pandas raise TypeError("No numeric data to plot.") - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) @@ -1612,6 +1616,9 @@ def newplotfunc( ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) if plotfunc.__name__ == "surface": + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: @@ -1632,14 +1639,8 @@ def newplotfunc( ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim ) - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") + _set_concise_date(ax, "x") return primitive @@ -1652,7 +1653,7 @@ def newplotfunc( @overload -def imshow( +def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1671,7 +1672,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1688,13 +1689,12 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> AxesImage: - ... +) -> AxesImage: ... @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1712,7 +1712,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1729,13 +1729,12 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def imshow( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1753,7 +1752,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1770,8 +1769,7 @@ def imshow( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot2d @@ -1871,7 +1869,7 @@ def _center_pixels(x): @overload -def contour( +def contour( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -1890,7 +1888,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1907,13 +1905,12 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> QuadContourSet: - ... +) -> QuadContourSet: ... @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1931,7 +1928,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1948,13 +1945,12 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def contour( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -1972,7 +1968,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1989,8 +1985,7 @@ def contour( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot2d @@ -2007,7 +2002,7 @@ def contour( @overload -def contourf( +def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2026,7 +2021,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2043,13 +2038,12 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> QuadContourSet: - ... +) -> QuadContourSet: ... @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2067,7 +2061,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2084,13 +2078,12 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def contourf( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2108,7 +2101,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2125,8 +2118,7 @@ def contourf( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot2d @@ -2143,7 +2135,7 @@ def contourf( @overload -def pcolormesh( +def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, x: Hashable | None = None, y: Hashable | None = None, @@ -2162,7 +2154,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2179,13 +2171,12 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> QuadMesh: - ... +) -> QuadMesh: ... @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2203,7 +2194,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2220,13 +2211,12 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def pcolormesh( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2244,7 +2234,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2261,8 +2251,7 @@ def pcolormesh( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot2d @@ -2349,7 +2338,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2366,13 +2355,12 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> Poly3DCollection: - ... +) -> Poly3DCollection: ... @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2390,7 +2378,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2407,13 +2395,12 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @overload def surface( - darray: DataArray, + darray: T_DataArray, x: Hashable | None = None, y: Hashable | None = None, *, @@ -2431,7 +2418,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2448,8 +2435,7 @@ def surface( ylim: ArrayLike | None = None, norm: Normalize | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[T_DataArray]: ... @_plot2d @@ -2461,5 +2447,8 @@ def surface( Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. """ + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) primitive = ax.plot_surface(x, y, z, **kwargs) return primitive diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 0d9898a6e9a..edc2bf43629 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -103,7 +103,7 @@ def _dsplot(plotfunc): be either ``'viridis'`` (if the function infers a sequential dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` + See :doc:`Choosing Colormaps in Matplotlib ` for more information. If *seaborn* is installed, ``cmap`` may also be a @@ -128,7 +128,7 @@ def _dsplot(plotfunc): If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding kwarg must be ``None``. infer_intervals: bool | None - If True the intervals are infered. + If True the intervals are inferred. center : float, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a @@ -321,7 +321,7 @@ def newplotfunc( @overload -def quiver( +def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -354,8 +354,7 @@ def quiver( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> Quiver: - ... +) -> Quiver: ... @overload @@ -392,8 +391,7 @@ def quiver( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> FacetGrid[Dataset]: - ... +) -> FacetGrid[Dataset]: ... @overload @@ -430,8 +428,7 @@ def quiver( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> FacetGrid[Dataset]: - ... +) -> FacetGrid[Dataset]: ... @_dsplot @@ -475,7 +472,7 @@ def quiver( @overload -def streamplot( +def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -508,8 +505,7 @@ def streamplot( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> LineCollection: - ... +) -> LineCollection: ... @overload @@ -546,8 +542,7 @@ def streamplot( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> FacetGrid[Dataset]: - ... +) -> FacetGrid[Dataset]: ... @overload @@ -584,8 +579,7 @@ def streamplot( extend: ExtendOptions = None, cmap: str | Colormap | None = None, **kwargs: Any, -) -> FacetGrid[Dataset]: - ... +) -> FacetGrid[Dataset]: ... @_dsplot @@ -632,7 +626,6 @@ def streamplot( du = du.transpose(ydim, xdim) dv = dv.transpose(ydim, xdim) - args = [dx.values, dy.values, du.values, dv.values] hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -646,7 +639,9 @@ def streamplot( ) kwargs.pop("hue_style") - hdl = ax.streamplot(*args, **kwargs, **cmap_params) + hdl = ax.streamplot( + dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params + ) # Return .lines so colorbar creation works properly return hdl.lines @@ -730,7 +725,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr coords = dict(ds.coords) # Add extra coords to the DataArray from valid kwargs, if using all - # kwargs there is a risk that we add unneccessary dataarrays as + # kwargs there is a risk that we add unnecessary dataarrays as # coords straining RAM further for example: # ds.both and extend="both" would add ds.both to the coords: valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"} @@ -748,7 +743,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr @overload -def scatter( +def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ds: Dataset, *args: Any, x: Hashable | None = None, @@ -785,8 +780,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, -) -> PathCollection: - ... +) -> PathCollection: ... @overload @@ -827,8 +821,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[DataArray]: ... @overload @@ -869,8 +862,7 @@ def scatter( extend: ExtendOptions = None, levels: ArrayLike | None = None, **kwargs: Any, -) -> FacetGrid[DataArray]: - ... +) -> FacetGrid[DataArray]: ... @_update_doc_to_dataset(dataarray_plot.scatter) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 93a328836d0..faf809a8a74 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,7 +9,7 @@ import numpy as np from xarray.core.formatting import format_item -from xarray.core.types import HueStyleOptions, T_Xarray +from xarray.core.types import HueStyleOptions, T_DataArrayOrSet from xarray.plot.utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, @@ -21,7 +21,6 @@ _Normalize, _parse_size, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, ) @@ -60,7 +59,7 @@ def _nicetitle(coord, value, maxchar, template): T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") -class FacetGrid(Generic[T_Xarray]): +class FacetGrid(Generic[T_DataArrayOrSet]): """ Initialize the Matplotlib figure and FacetGrid object. @@ -76,7 +75,7 @@ class FacetGrid(Generic[T_Xarray]): The general approach to plotting here is called "small multiples", where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship - conditioned on one ore more other variables is often called a "trellis + conditioned on one or more other variables is often called a "trellis plot". The basic workflow is to initialize the :class:`FacetGrid` object with @@ -101,7 +100,7 @@ class FacetGrid(Generic[T_Xarray]): sometimes the rightmost grid positions in the bottom row. """ - data: T_Xarray + data: T_DataArrayOrSet name_dicts: np.ndarray fig: Figure axs: np.ndarray @@ -126,7 +125,7 @@ class FacetGrid(Generic[T_Xarray]): def __init__( self, - data: T_Xarray, + data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, col_wrap: int | None = None, @@ -166,7 +165,7 @@ def __init__( """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique @@ -681,7 +680,10 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: def _adjust_fig_for_guide(self, guide) -> None: # Draw the plot to set the bounding boxes correctly - renderer = self.fig.canvas.get_renderer() + if hasattr(self.fig.canvas, "get_renderer"): + renderer = self.fig.canvas.get_renderer() + else: + raise RuntimeError("MPL backend has no renderer") self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits @@ -731,6 +733,9 @@ def add_colorbar(self, **kwargs: Any) -> None: if hasattr(self._mappables[-1], "extend"): kwargs.pop("extend", None) if "label" not in kwargs: + from xarray import DataArray + + assert isinstance(self.data, DataArray) kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( self._mappables[-1], ax=list(self.axs.flat), **kwargs @@ -985,7 +990,7 @@ def map( self : FacetGrid object """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): if namedict is not None: @@ -1004,7 +1009,7 @@ def map( def _easy_facetgrid( - data: T_Xarray, + data: T_DataArrayOrSet, plotfunc: Callable, kind: Literal["line", "dataarray", "dataset", "plot1d"], x: Hashable | None = None, @@ -1020,7 +1025,7 @@ def _easy_facetgrid( ax: Axes | None = None, figsize: Iterable[float] | None = None, **kwargs: Any, -) -> FacetGrid[T_Xarray]: +) -> FacetGrid[T_DataArrayOrSet]: """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index e807081f838..804e1cfd795 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -6,15 +6,15 @@ from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from inspect import getfullargspec -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, overload import numpy as np import pandas as pd from xarray.core.indexes import PandasMultiIndex from xarray.core.options import OPTIONS -from xarray.core.pycompat import DuckArrayModule from xarray.core.utils import is_scalar, module_available +from xarray.namedarray.pycompat import DuckArrayModule nc_time_axis_available = module_available("nc_time_axis") @@ -47,14 +47,6 @@ _LINEWIDTH_RANGE = (1.5, 1.5, 6.0) -def import_matplotlib_pyplot(): - """import pyplot""" - # TODO: This function doesn't do anything (after #6109), remove it? - import matplotlib.pyplot as plt - - return plt - - def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -505,28 +497,29 @@ def _maybe_gca(**subplot_kws: Any) -> Axes: return plt.axes(**subplot_kws) -def _get_units_from_attrs(da) -> str: +def _get_units_from_attrs(da: DataArray) -> str: """Extracts and formats the unit/units from a attributes.""" pint_array_type = DuckArrayModule("pint").type units = " [{}]" if isinstance(da.data, pint_array_type): - units = units.format(str(da.data.units)) - elif da.attrs.get("units"): - units = units.format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = units.format(da.attrs["unit"]) - else: - units = "" - return units + return units.format(str(da.data.units)) + if "units" in da.attrs: + return units.format(da.attrs["units"]) + if "unit" in da.attrs: + return units.format(da.attrs["unit"]) + return "" -def label_from_attrs(da, extra: str = "") -> str: +def label_from_attrs(da: DataArray | None, extra: str = "") -> str: """Makes informative labels if variable metadata (attrs) follows CF conventions.""" + if da is None: + return "" + name: str = "{}" - if da.attrs.get("long_name"): + if "long_name" in da.attrs: name = name.format(da.attrs["long_name"]) - elif da.attrs.get("standard_name"): + elif "standard_name" in da.attrs: name = name.format(da.attrs["standard_name"]) elif da.name is not None: name = name.format(da.name) @@ -774,8 +767,8 @@ def _update_axes( yscale: ScaleOptions = None, xticks: ArrayLike | None = None, yticks: ArrayLike | None = None, - xlim: ArrayLike | None = None, - ylim: ArrayLike | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ) -> None: """ Update axes with provided parameters @@ -1131,7 +1124,7 @@ def _get_color_and_size(value): # Labels are not numerical so modifying label_values is not # possible, instead filter the array with nicely distributed # indexes: - if type(num) == int: + if type(num) == int: # noqa: E721 loc = mpl.ticker.LinearLocator(num) else: raise ValueError("`num` only supports integers for non-numeric labels.") @@ -1166,7 +1159,7 @@ def _get_color_and_size(value): def _legend_add_subtitle(handles, labels, text): """Add a subtitle to legend handles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt if text and len(handles) > 1: # Create a blank handle that's not visible, the @@ -1184,7 +1177,7 @@ def _legend_add_subtitle(handles, labels, text): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) @@ -1298,16 +1291,14 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): def _parse_size( data: None, norm: tuple[float | None, float | None, bool] | Normalize | None, -) -> None: - ... +) -> None: ... @overload def _parse_size( data: DataArray, norm: tuple[float | None, float | None, bool] | Normalize | None, -) -> pd.Series: - ... +) -> pd.Series: ... # copied from seaborn @@ -1438,20 +1429,28 @@ def data_is_numeric(self) -> bool: >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) >>> _Normalize(a).data_is_numeric True + + >>> # TODO: Datetime should be numeric right? + >>> a = xr.DataArray(pd.date_range("2000-1-1", periods=4)) + >>> _Normalize(a).data_is_numeric + False + + # TODO: Timedelta should be numeric right? + >>> a = xr.DataArray(pd.timedelta_range("-1D", periods=4, freq="D")) + >>> _Normalize(a).data_is_numeric + True """ return self._data_is_numeric @overload - def _calc_widths(self, y: np.ndarray) -> np.ndarray: - ... + def _calc_widths(self, y: np.ndarray) -> np.ndarray: ... @overload - def _calc_widths(self, y: DataArray) -> DataArray: - ... + def _calc_widths(self, y: DataArray) -> DataArray: ... def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: """ - Normalize the values so they're inbetween self._width. + Normalize the values so they're in between self._width. """ if self._width is None: return y @@ -1463,18 +1462,16 @@ def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: # Use default with if y is constant: widths = xdefault + 0 * y else: - # Normalize inbetween xmin and xmax: + # Normalize in between xmin and xmax: k = (y - np.min(y)) / diff_maxy_miny widths = xmin + k * (xmax - xmin) return widths @overload - def _indexes_centered(self, x: np.ndarray) -> np.ndarray: - ... + def _indexes_centered(self, x: np.ndarray) -> np.ndarray: ... @overload - def _indexes_centered(self, x: DataArray) -> DataArray: - ... + def _indexes_centered(self, x: DataArray) -> DataArray: ... def _indexes_centered(self, x: np.ndarray | DataArray) -> np.ndarray | DataArray: """ @@ -1492,28 +1489,28 @@ def values(self) -> DataArray | None: -------- >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) >>> _Normalize(a).values - + Size: 40B array([3, 1, 1, 3, 5]) Dimensions without coordinates: dim_0 >>> _Normalize(a, width=(18, 36, 72)).values - + Size: 40B array([45., 18., 18., 45., 72.]) Dimensions without coordinates: dim_0 >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) >>> _Normalize(a).values - + Size: 48B array([0.5, 0. , 0. , 0.5, 2. , 3. ]) Dimensions without coordinates: dim_0 >>> _Normalize(a, width=(18, 36, 72)).values - + Size: 48B array([27., 18., 18., 27., 54., 72.]) Dimensions without coordinates: dim_0 >>> _Normalize(a * 0, width=(18, 36, 72)).values - + Size: 48B array([36., 36., 36., 36., 36., 36.]) Dimensions without coordinates: dim_0 @@ -1630,7 +1627,7 @@ def format(self) -> FuncFormatter: >>> aa.format(1) '3.0' """ - plt = import_matplotlib_pyplot() + import matplotlib.pyplot as plt def _func(x: Any, pos: None | Any = None): return f"{self._lookup_arr([x])[0]}" @@ -1811,8 +1808,8 @@ def _guess_coords_to_plot( ) # If dims_plot[k] isn't defined then fill with one of the available dims, unless - # one of related mpl kwargs has been used. This should have similiar behaviour as - # * plt.plot(x, y) -> Multple lines with different colors if y is 2d. + # one of related mpl kwargs has been used. This should have similar behaviour as + # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d. # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d. for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs): if coords_to_plot.get(k, None) is None and all( @@ -1824,3 +1821,27 @@ def _guess_coords_to_plot( _assert_valid_xy(darray, dim, k) return coords_to_plot + + +def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None: + """ + Use ConciseDateFormatter which is meant to improve the + strings chosen for the ticklabels, and to minimize the + strings used in those tick labels as much as possible. + + https://matplotlib.org/stable/gallery/ticks/date_concise_formatter.html + + Parameters + ---------- + ax : Axes + Figure axes. + axis : Literal["x", "y", "z"], optional + Which axis to make concise. The default is "x". + """ + import matplotlib.dates as mdates + + locator = mdates.AutoDateLocator() + formatter = mdates.ConciseDateFormatter(locator) + _axis = getattr(ax, f"{axis}axis") + _axis.set_major_locator(locator) + _axis.set_major_formatter(formatter) diff --git a/xarray/testing/__init__.py b/xarray/testing/__init__.py new file mode 100644 index 00000000000..ab2f8ba4357 --- /dev/null +++ b/xarray/testing/__init__.py @@ -0,0 +1,23 @@ +from xarray.testing.assertions import ( # noqa: F401 + _assert_dataarray_invariants, + _assert_dataset_invariants, + _assert_indexes_invariants_checks, + _assert_internal_invariants, + _assert_variable_invariants, + _data_allclose_or_equiv, + assert_allclose, + assert_chunks_equal, + assert_duckarray_allclose, + assert_duckarray_equal, + assert_equal, + assert_identical, +) + +__all__ = [ + "assert_allclose", + "assert_chunks_equal", + "assert_duckarray_equal", + "assert_duckarray_allclose", + "assert_equal", + "assert_identical", +] diff --git a/xarray/testing.py b/xarray/testing/assertions.py similarity index 90% rename from xarray/testing.py rename to xarray/testing/assertions.py index b6a88135ee1..6418eb79b8b 100644 --- a/xarray/testing.py +++ b/xarray/testing/assertions.py @@ -1,4 +1,5 @@ """Testing functions exposed to the user API""" + import functools import warnings from collections.abc import Hashable @@ -8,20 +9,12 @@ import pandas as pd from xarray.core import duck_array_ops, formatting, utils +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes from xarray.core.variable import IndexVariable, Variable -__all__ = ( - "assert_allclose", - "assert_chunks_equal", - "assert_duckarray_equal", - "assert_duckarray_allclose", - "assert_equal", - "assert_identical", -) - def ensure_warnings(func): # sometimes tests elevate warnings to errors @@ -68,9 +61,9 @@ def assert_equal(a, b): Parameters ---------- - a : xarray.Dataset, xarray.DataArray or xarray.Variable + a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. - b : xarray.Dataset, xarray.DataArray or xarray.Variable + b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. See Also @@ -79,11 +72,15 @@ def assert_equal(a, b): numpy.testing.assert_array_equal """ __tracebackhide__ = True - assert type(a) == type(b) + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) if isinstance(a, (Variable, DataArray)): assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals") + elif isinstance(a, Coordinates): + assert a.equals(b), formatting.diff_coords_repr(a, b, "equals") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -97,9 +94,9 @@ def assert_identical(a, b): Parameters ---------- - a : xarray.Dataset, xarray.DataArray or xarray.Variable + a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The first object to compare. - b : xarray.Dataset, xarray.DataArray or xarray.Variable + b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates The second object to compare. See Also @@ -107,7 +104,9 @@ def assert_identical(a, b): assert_equal, assert_allclose, Dataset.equals, DataArray.equals """ __tracebackhide__ = True - assert type(a) == type(b) + assert ( + type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) + ) if isinstance(a, Variable): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, DataArray): @@ -115,6 +114,8 @@ def assert_identical(a, b): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, (Dataset, Variable)): assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") + elif isinstance(a, Coordinates): + assert a.identical(b), formatting.diff_coords_repr(a, b, "identical") else: raise TypeError(f"{type(a)} not supported by assertion comparison") @@ -355,7 +356,7 @@ def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): set(ds._variables), ) - assert type(ds._dims) is dict, ds._dims + assert type(ds._dims) is dict, ds._dims # noqa: E721 assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims var_dims: set[Hashable] = set() for v in ds._variables.values(): @@ -364,14 +365,13 @@ def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool): assert all( ds._dims[k] == v.sizes[k] for v in ds._variables.values() for k in v.sizes ), (ds._dims, {k: v.sizes for k, v in ds._variables.items()}) - assert all( - isinstance(v, IndexVariable) - for (k, v) in ds._variables.items() - if v.dims == (k,) - ), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} - assert all(v.dims == (k,) for (k, v) in ds._variables.items() if k in ds._dims), { - k: v.dims for k, v in ds._variables.items() if k in ds._dims - } + + if check_default_indexes: + assert all( + isinstance(v, IndexVariable) + for (k, v) in ds._variables.items() + if v.dims == (k,) + ), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)} if ds._indexes is not None: _assert_indexes_invariants_checks( @@ -401,9 +401,11 @@ def _assert_internal_invariants( _assert_dataset_invariants( xarray_obj, check_default_indexes=check_default_indexes ) + elif isinstance(xarray_obj, Coordinates): + _assert_dataset_invariants( + xarray_obj.to_dataset(), check_default_indexes=check_default_indexes + ) else: raise TypeError( - "{} is not a supported type for xarray invariant checks".format( - type(xarray_obj) - ) + f"{type(xarray_obj)} is not a supported type for xarray invariant checks" ) diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py new file mode 100644 index 00000000000..c5a7afdf54e --- /dev/null +++ b/xarray/testing/strategies.py @@ -0,0 +1,444 @@ +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Protocol, Union, overload + +try: + import hypothesis.strategies as st +except ImportError as e: + raise ImportError( + "`xarray.testing.strategies` requires `hypothesis` to be installed." + ) from e + +import hypothesis.extra.numpy as npst +import numpy as np +from hypothesis.errors import InvalidArgument + +import xarray as xr +from xarray.core.types import T_DuckArray + +if TYPE_CHECKING: + from xarray.core.types import _DTypeLikeNested, _ShapeLike + + +__all__ = [ + "supported_dtypes", + "names", + "dimension_names", + "dimension_sizes", + "attrs", + "variables", + "unique_subset_of", +] + + +class ArrayStrategyFn(Protocol[T_DuckArray]): + def __call__( + self, + *, + shape: "_ShapeLike", + dtype: "_DTypeLikeNested", + ) -> st.SearchStrategy[T_DuckArray]: ... + + +def supported_dtypes() -> st.SearchStrategy[np.dtype]: + """ + Generates only those numpy dtypes which xarray can handle. + + Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes. + Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + # TODO should this be exposed publicly? + # We should at least decide what the set of numpy dtypes that xarray officially supports is. + return ( + npst.integer_dtypes() + | npst.unsigned_integer_dtypes() + | npst.floating_dtypes() + | npst.complex_number_dtypes() + ) + + +# TODO Generalize to all valid unicode characters once formatting bugs in xarray's reprs are fixed + docs can handle it. +_readable_characters = st.characters( + categories=["L", "N"], max_codepoint=0x017F +) # only use characters within the "Latin Extended-A" subset of unicode + + +def names() -> st.SearchStrategy[str]: + """ + Generates arbitrary string names for dimensions / variables. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + return st.text( + _readable_characters, + min_size=1, + max_size=5, + ) + + +def dimension_names( + *, + min_dims: int = 0, + max_dims: int = 3, +) -> st.SearchStrategy[list[Hashable]]: + """ + Generates an arbitrary list of valid dimension names. + + Requires the hypothesis package to be installed. + + Parameters + ---------- + min_dims + Minimum number of dimensions in generated list. + max_dims + Maximum number of dimensions in generated list. + """ + + return st.lists( + elements=names(), + min_size=min_dims, + max_size=max_dims, + unique=True, + ) + + +def dimension_sizes( + *, + dim_names: st.SearchStrategy[Hashable] = names(), + min_dims: int = 0, + max_dims: int = 3, + min_side: int = 1, + max_side: Union[int, None] = None, +) -> st.SearchStrategy[Mapping[Hashable, int]]: + """ + Generates an arbitrary mapping from dimension names to lengths. + + Requires the hypothesis package to be installed. + + Parameters + ---------- + dim_names: strategy generating strings, optional + Strategy for generating dimension names. + Defaults to the `names` strategy. + min_dims: int, optional + Minimum number of dimensions in generated list. + Default is 1. + max_dims: int, optional + Maximum number of dimensions in generated list. + Default is 3. + min_side: int, optional + Minimum size of a dimension. + Default is 1. + max_side: int, optional + Minimum size of a dimension. + Default is `min_length` + 5. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + + if max_side is None: + max_side = min_side + 3 + + return st.dictionaries( + keys=dim_names, + values=st.integers(min_value=min_side, max_value=max_side), + min_size=min_dims, + max_size=max_dims, + ) + + +_readable_strings = st.text( + _readable_characters, + max_size=5, +) +_attr_keys = _readable_strings +_small_arrays = npst.arrays( + shape=npst.array_shapes( + max_side=2, + max_dims=2, + ), + dtype=npst.scalar_dtypes(), +) +_attr_values = st.none() | st.booleans() | _readable_strings | _small_arrays + + +def attrs() -> st.SearchStrategy[Mapping[Hashable, Any]]: + """ + Generates arbitrary valid attributes dictionaries for xarray objects. + + The generated dictionaries can potentially be recursive. + + Requires the hypothesis package to be installed. + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + return st.recursive( + st.dictionaries(_attr_keys, _attr_values), + lambda children: st.dictionaries(_attr_keys, children), + max_leaves=3, + ) + + +@st.composite +def variables( + draw: st.DrawFn, + *, + array_strategy_fn: Union[ArrayStrategyFn, None] = None, + dims: Union[ + st.SearchStrategy[Union[Sequence[Hashable], Mapping[Hashable, int]]], + None, + ] = None, + dtype: st.SearchStrategy[np.dtype] = supported_dtypes(), + attrs: st.SearchStrategy[Mapping] = attrs(), +) -> xr.Variable: + """ + Generates arbitrary xarray.Variable objects. + + Follows the basic signature of the xarray.Variable constructor, but allows passing alternative strategies to + generate either numpy-like array data or dimensions. Also allows specifying the shape or dtype of the wrapped array + up front. + + Passing nothing will generate a completely arbitrary Variable (containing a numpy array). + + Requires the hypothesis package to be installed. + + Parameters + ---------- + array_strategy_fn: Callable which returns a strategy generating array-likes, optional + Callable must only accept shape and dtype kwargs, and must generate results consistent with its input. + If not passed the default is to generate a small numpy array with one of the supported_dtypes. + dims: Strategy for generating the dimensions, optional + Can either be a strategy for generating a sequence of string dimension names, + or a strategy for generating a mapping of string dimension names to integer lengths along each dimension. + If provided as a mapping the array shape will be passed to array_strategy_fn. + Default is to generate arbitrary dimension names for each axis in data. + dtype: Strategy which generates np.dtype objects, optional + Will be passed in to array_strategy_fn. + Default is to generate any scalar dtype using supported_dtypes. + Be aware that this default set of dtypes includes some not strictly allowed by the array API standard. + attrs: Strategy which generates dicts, optional + Default is to generate a nested attributes dictionary containing arbitrary strings, booleans, integers, Nones, + and numpy arrays. + + Returns + ------- + variable_strategy + Strategy for generating xarray.Variable objects. + + Raises + ------ + ValueError + If a custom array_strategy_fn returns a strategy which generates an example array inconsistent with the shape + & dtype input passed to it. + + Examples + -------- + Generate completely arbitrary Variable objects backed by a numpy array: + + >>> variables().example() # doctest: +SKIP + + array([43506, -16, -151], dtype=int32) + >>> variables().example() # doctest: +SKIP + + array([[[-10000000., -10000000.], + [-10000000., -10000000.]], + [[-10000000., -10000000.], + [ 0., -10000000.]], + [[ 0., -10000000.], + [-10000000., inf]], + [[ -0., -10000000.], + [-10000000., -0.]]], dtype=float32) + Attributes: + śřĴ: {'ĉ': {'iĥf': array([-30117, -1740], dtype=int16)}} + + Generate only Variable objects with certain dimension names: + + >>> variables(dims=st.just(["a", "b"])).example() # doctest: +SKIP + + array([[ 248, 4294967295, 4294967295], + [2412855555, 3514117556, 4294967295], + [ 111, 4294967295, 4294967295], + [4294967295, 1084434988, 51688], + [ 47714, 252, 11207]], dtype=uint32) + + Generate only Variable objects with certain dimension names and lengths: + + >>> variables(dims=st.just({"a": 2, "b": 1})).example() # doctest: +SKIP + + array([[-1.00000000e+007+3.40282347e+038j], + [-2.75034266e-225+2.22507386e-311j]]) + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + + if not isinstance(dims, st.SearchStrategy) and dims is not None: + raise InvalidArgument( + f"dims must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dims)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + if not isinstance(dtype, st.SearchStrategy) and dtype is not None: + raise InvalidArgument( + f"dtype must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dtype)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + if not isinstance(attrs, st.SearchStrategy) and attrs is not None: + raise InvalidArgument( + f"attrs must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(attrs)}. " + "To specify fixed contents, use hypothesis.strategies.just()." + ) + + _array_strategy_fn: ArrayStrategyFn + if array_strategy_fn is None: + # For some reason if I move the default value to the function signature definition mypy incorrectly says the ignore is no longer necessary, making it impossible to satisfy mypy + _array_strategy_fn = npst.arrays # type: ignore[assignment] # npst.arrays has extra kwargs that we aren't using later + elif not callable(array_strategy_fn): + raise InvalidArgument( + "array_strategy_fn must be a Callable that accepts the kwargs dtype and shape and returns a hypothesis " + "strategy which generates corresponding array-like objects." + ) + else: + _array_strategy_fn = ( + array_strategy_fn # satisfy mypy that this new variable cannot be None + ) + + _dtype = draw(dtype) + + if dims is not None: + # generate dims first then draw data to match + _dims = draw(dims) + if isinstance(_dims, Sequence): + dim_names = list(_dims) + valid_shapes = npst.array_shapes(min_dims=len(_dims), max_dims=len(_dims)) + _shape = draw(valid_shapes) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + elif isinstance(_dims, (Mapping, dict)): + # should be a mapping of form {dim_names: lengths} + dim_names, _shape = list(_dims.keys()), tuple(_dims.values()) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + else: + raise InvalidArgument( + f"Invalid type returned by dims strategy - drew an object of type {type(dims)}" + ) + else: + # nothing provided, so generate everything consistently + # We still generate the shape first here just so that we always pass shape to array_strategy_fn + _shape = draw(npst.array_shapes()) + array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype) + dim_names = draw(dimension_names(min_dims=len(_shape), max_dims=len(_shape))) + + _data = draw(array_strategy) + + if _data.shape != _shape: + raise ValueError( + "array_strategy_fn returned an array object with a different shape than it was passed." + f"Passed {_shape}, but returned {_data.shape}." + "Please either specify a consistent shape via the dims kwarg or ensure the array_strategy_fn callable " + "obeys the shape argument passed to it." + ) + if _data.dtype != _dtype: + raise ValueError( + "array_strategy_fn returned an array object with a different dtype than it was passed." + f"Passed {_dtype}, but returned {_data.dtype}" + "Please either specify a consistent dtype via the dtype kwarg or ensure the array_strategy_fn callable " + "obeys the dtype argument passed to it." + ) + + return xr.Variable(dims=dim_names, data=_data, attrs=draw(attrs)) + + +@overload +def unique_subset_of( + objs: Sequence[Hashable], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> st.SearchStrategy[Sequence[Hashable]]: ... + + +@overload +def unique_subset_of( + objs: Mapping[Hashable, Any], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> st.SearchStrategy[Mapping[Hashable, Any]]: ... + + +@st.composite +def unique_subset_of( + draw: st.DrawFn, + objs: Union[Sequence[Hashable], Mapping[Hashable, Any]], + *, + min_size: int = 0, + max_size: Union[int, None] = None, +) -> Union[Sequence[Hashable], Mapping[Hashable, Any]]: + """ + Return a strategy which generates a unique subset of the given objects. + + Each entry in the output subset will be unique (if input was a sequence) or have a unique key (if it was a mapping). + + Requires the hypothesis package to be installed. + + Parameters + ---------- + objs: Union[Sequence[Hashable], Mapping[Hashable, Any]] + Objects from which to sample to produce the subset. + min_size: int, optional + Minimum size of the returned subset. Default is 0. + max_size: int, optional + Maximum size of the returned subset. Default is the full length of the input. + If set to 0 the result will be an empty mapping. + + Returns + ------- + unique_subset_strategy + Strategy generating subset of the input. + + Examples + -------- + >>> unique_subset_of({"x": 2, "y": 3}).example() # doctest: +SKIP + {'y': 3} + >>> unique_subset_of(["x", "y"]).example() # doctest: +SKIP + ['x'] + + See Also + -------- + :ref:`testing.hypothesis`_ + """ + if not isinstance(objs, Iterable): + raise TypeError( + f"Object to sample from must be an Iterable or a Mapping, but received type {type(objs)}" + ) + + if len(objs) == 0: + raise ValueError("Can't sample from a length-zero object.") + + keys = list(objs.keys()) if isinstance(objs, Mapping) else objs + + subset_keys = draw( + st.lists( + st.sampled_from(keys), + unique=True, + min_size=min_size, + max_size=max_size, + ) + ) + + return ( + {k: objs[k] for k in subset_keys} if isinstance(objs, Mapping) else subset_keys + ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 71cd175b99c..e99f9ec3a22 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -2,6 +2,7 @@ import importlib import platform +import string import warnings from contextlib import contextmanager, nullcontext from unittest import mock # noqa: F401 @@ -19,6 +20,7 @@ from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options +from xarray.core.variable import IndexVariable from xarray.testing import ( # noqa: F401 assert_chunks_equal, assert_duckarray_allclose, @@ -46,6 +48,15 @@ ) +def assert_writeable(ds): + readonly = [ + name + for name, var in ds.variables.items() + if not isinstance(var, IndexVariable) and not var.data.flags.writeable + ] + assert not readonly, readonly + + def _importorskip( modname: str, minversion: str | None = None ) -> tuple[bool, pytest.MarkDecorator]: @@ -53,36 +64,69 @@ def _importorskip( mod = importlib.import_module(modname) has = True if minversion is not None: - if Version(mod.__version__) < Version(minversion): + v = getattr(mod, "__version__", "999") + if Version(v) < Version(minversion): raise ImportError("Minimum version not satisfied") except ImportError: has = False - func = pytest.mark.skipif(not has, reason=f"requires {modname}") + + reason = f"requires {modname}" + if minversion is not None: + reason += f">={minversion}" + func = pytest.mark.skipif(not has, reason=reason) return has, func has_matplotlib, requires_matplotlib = _importorskip("matplotlib") has_scipy, requires_scipy = _importorskip("scipy") -has_pydap, requires_pydap = _importorskip("pydap.client") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="'cgi' is deprecated and slated for removal in Python 3.13", + category=DeprecationWarning, + ) + has_pydap, requires_pydap = _importorskip("pydap.client") has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") -has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") +with warnings.catch_warnings(): + # see https://github.com/pydata/xarray/issues/8537 + warnings.filterwarnings( + "ignore", + message="h5py is running against HDF5 1.14.3", + category=UserWarning, + ) + + has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_pynio, requires_pynio = _importorskip("Nio") -has_pseudonetcdf, requires_pseudonetcdf = _importorskip("PseudoNetCDF") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The current Dask DataFrame implementation is deprecated.", + category=DeprecationWarning, + ) + has_dask_expr, requires_dask_expr = _importorskip("dask_expr") has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") -has_numbagg, requires_numbagg = _importorskip("numbagg") -has_seaborn, requires_seaborn = _importorskip("seaborn") +has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="is_categorical_dtype is deprecated and will be removed in a future version.", + category=DeprecationWarning, + ) + # seaborn uses the deprecated `pandas.is_categorical_dtype` + has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") has_cupy, requires_cupy = _importorskip("cupy") has_cartopy, requires_cartopy = _importorskip("cartopy") has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") +has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") # some special cases @@ -90,11 +134,21 @@ def _importorskip( requires_scipy_or_netCDF4 = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) +has_numbagg_or_bottleneck = has_numbagg or has_bottleneck +requires_numbagg_or_bottleneck = pytest.mark.skipif( + not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" +) # _importorskip does not work for development versions has_pandas_version_two = Version(pd.__version__).major >= 2 requires_pandas_version_two = pytest.mark.skipif( not has_pandas_version_two, reason="requires pandas 2.0.0" ) +has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0") +has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0") + +has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip( + "netCDF4", "1.6.2" +) # change some global options for tests set_options(warn_for_unclosed_files=True) @@ -139,13 +193,18 @@ class UnexpectedDataAccess(Exception): class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed): + """Disallows any loading.""" + def __init__(self, array): self.array = array - def __getitem__(self, key): - raise UnexpectedDataAccess("Tried accessing data.") + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") - def __array__(self): + def __getitem__(self, key): raise UnexpectedDataAccess("Tried accessing data.") @@ -157,6 +216,23 @@ def __getitem__(self, key): return self.array[tuple_idxr] +class DuckArrayWrapper(utils.NDArrayMixin): + """Array-like that prevents casting to array. + Modeled after cupy.""" + + def __init__(self, array: np.ndarray): + self.array = array + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __array_namespace__(self): + """Present to satisfy is_duck_array test.""" + + class ReturnItem: def __getitem__(self, key): return key @@ -185,11 +261,18 @@ def source_ndarray(array): return base +def format_record(record) -> str: + """Format warning record like `FutureWarning('Function will be deprecated...')`""" + return f"{str(record.category)[8:-2]}('{record.message}'))" + + @contextmanager def assert_no_warnings(): with warnings.catch_warnings(record=True) as record: yield record - assert len(record) == 0, "got unexpected warning(s)" + assert ( + len(record) == 0 + ), f"Got {len(record)} unexpected warning(s): {[format_record(r) for r in record]}" # Internal versions of xarray's test functions that validate additional @@ -217,30 +300,43 @@ def assert_allclose(a, b, check_default_indexes=True, **kwargs): xarray.testing._assert_internal_invariants(b, check_default_indexes) -def create_test_data(seed: int | None = None, add_attrs: bool = True) -> Dataset: +_DEFAULT_TEST_DIM_SIZES = (8, 9, 10) + + +def create_test_data( + seed: int | None = None, + add_attrs: bool = True, + dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, +) -> Dataset: rs = np.random.RandomState(seed) _vars = { "var1": ["dim1", "dim2"], "var2": ["dim1", "dim2"], "var3": ["dim3", "dim1"], } - _dims = {"dim1": 8, "dim2": 9, "dim3": 10} + _dims = {"dim1": dim_sizes[0], "dim2": dim_sizes[1], "dim3": dim_sizes[2]} obj = Dataset() obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) - obj["dim3"] = ("dim3", list("abcdefghij")) + if _dims["dim3"] > 26: + raise RuntimeError( + f'Not enough letters for filling this dimension size ({_dims["dim3"]})' + ) + obj["dim3"] = ("dim3", list(string.ascii_lowercase[0 : _dims["dim3"]])) obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) for v, dims in sorted(_vars.items()): data = rs.normal(size=tuple(_dims[d] for d in dims)) obj[v] = (dims, data) if add_attrs: obj[v].attrs = {"foo": "variable"} - obj.coords["numbers"] = ( - "dim3", - np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), - ) + + if dim_sizes == _DEFAULT_TEST_DIM_SIZES: + numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") + else: + numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") + obj.coords["numbers"] = ("dim3", numbers_values) obj.encoding = {"foo": "bar"} - assert all(obj.data.flags.writeable for obj in obj.variables.values()) + assert_writeable(obj) return obj diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 6a8cf008f9f..8590c9fb4e7 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import numpy as np import pandas as pd import pytest +import xarray as xr from xarray import DataArray, Dataset +from xarray.datatree_.datatree import DataTree from xarray.tests import create_test_data, requires_dask @@ -11,6 +15,21 @@ def backend(request): return request.param +@pytest.fixture(params=["numbagg", "bottleneck", None]) +def compute_backend(request): + if request.param is None: + options = dict(use_bottleneck=False, use_numbagg=False) + elif request.param == "bottleneck": + options = dict(use_bottleneck=True, use_numbagg=False) + elif request.param == "numbagg": + options = dict(use_bottleneck=False, use_numbagg=True) + else: + raise ValueError + + with xr.set_options(**options): + yield request.param + + @pytest.fixture(params=[1]) def ds(request, backend): if request.param == 1: @@ -77,3 +96,105 @@ def da(request, backend): return da else: raise ValueError + + +@pytest.fixture(params=[Dataset, DataArray]) +def type(request): + return request.param + + +@pytest.fixture(params=[1]) +def d(request, backend, type) -> DataArray | Dataset: + """ + For tests which can test either a DataArray or a Dataset. + """ + result: DataArray | Dataset + if request.param == 1: + ds = Dataset( + dict( + a=(["x", "z"], np.arange(24).reshape(2, 12)), + b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)), + ), + dict( + x=("x", np.linspace(0, 1.0, 2)), + y=range(3), + z=("z", pd.date_range("2000-01-01", periods=12)), + w=("x", ["a", "b"]), + ), + ) + if type == DataArray: + result = ds["a"].assign_coords(w=ds.coords["w"]) + elif type == Dataset: + result = ds + else: + raise ValueError + else: + raise ValueError + + if backend == "dask": + return result.chunk() + elif backend == "numpy": + return result + else: + raise ValueError + + +@pytest.fixture(scope="module") +def create_test_datatree(): + """ + Create a test datatree with this structure: + + + |-- set1 + | |-- + | | Dimensions: () + | | Data variables: + | | a int64 0 + | | b int64 1 + | |-- set1 + | |-- set2 + |-- set2 + | |-- + | | Dimensions: (x: 2) + | | Data variables: + | | a (x) int64 2, 3 + | | b (x) int64 0.1, 0.2 + | |-- set1 + |-- set3 + |-- + | Dimensions: (x: 2, y: 3) + | Data variables: + | a (y) int64 6, 7, 8 + | set0 (x) int64 9, 10 + + The structure has deliberately repeated names of tags, variables, and + dimensions in order to better check for bugs caused by name conflicts. + """ + + def _create_test_datatree(modify=lambda ds: ds): + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) + + # Avoid using __init__ so we can independently test it + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + return root + + return _create_test_datatree + + +@pytest.fixture(scope="module") +def simple_datatree(create_test_datatree): + """ + Invoke create_test_datatree fixture (callback). + + Returns a DataTree. + """ + return create_test_datatree() diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index ef91257c4d9..686bce943fa 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -6,6 +6,7 @@ import xarray as xr from xarray.tests import ( + assert_allclose, assert_array_equal, assert_chunks_equal, assert_equal, @@ -23,7 +24,7 @@ def setup(self): data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) lats = np.linspace(0, 20, 10) - self.times = pd.date_range(start="2000/01/01", freq="H", periods=nt) + self.times = pd.date_range(start="2000/01/01", freq="h", periods=nt) self.data = xr.DataArray( data, @@ -59,6 +60,8 @@ def setup(self): "quarter", "date", "time", + "daysinmonth", + "days_in_month", "is_month_start", "is_month_end", "is_quarter_start", @@ -74,7 +77,18 @@ def test_field_access(self, field) -> None: else: data = getattr(self.times, field) - expected = xr.DataArray(data, name=field, coords=[self.times], dims=["time"]) + if data.dtype.kind != "b" and field not in ("date", "time"): + # pandas 2.0 returns int32 for integer fields now + data = data.astype("int64") + + translations = { + "weekday": "dayofweek", + "daysinmonth": "days_in_month", + "weekofyear": "week", + } + name = translations.get(field, field) + + expected = xr.DataArray(data, name=name, coords=[self.times], dims=["time"]) if field in ["week", "weekofyear"]: with pytest.warns( @@ -84,7 +98,21 @@ def test_field_access(self, field) -> None: else: actual = getattr(self.data.time.dt, field) - assert_equal(expected, actual) + assert expected.dtype == actual.dtype + assert_identical(expected, actual) + + def test_total_seconds(self) -> None: + # Subtract a value in the middle of the range to ensure that some values + # are negative + delta = self.data.time - np.datetime64("2000-01-03") + actual = delta.dt.total_seconds() + expected = xr.DataArray( + np.arange(-48, 52, dtype=np.float64) * 3600, + name="total_seconds", + coords=[self.data.time], + ) + # This works with assert_identical when pandas is >=1.5.0. + assert_allclose(expected, actual) @pytest.mark.parametrize( "field, pandas_field", @@ -117,7 +145,7 @@ def test_not_datetime_type(self) -> None: nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) - with pytest.raises(TypeError, match=r"dt"): + with pytest.raises(AttributeError, match=r"dt"): nontime_data.time.dt @pytest.mark.filterwarnings("ignore:dt.weekofyear and dt.week have been deprecated") @@ -220,7 +248,9 @@ def test_dask_accessor_method(self, method, parameters) -> None: assert_equal(actual.compute(), expected.compute()) def test_seasons(self) -> None: - dates = pd.date_range(start="2000/01/01", freq="M", periods=12) + dates = xr.date_range( + start="2000/01/01", freq="ME", periods=12, use_cftime=False + ) dates = dates.append(pd.Index([np.datetime64("NaT")])) dates = xr.DataArray(dates) seasons = xr.DataArray( @@ -247,7 +277,7 @@ def test_seasons(self) -> None: "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] ) def test_accessor_method(self, method, parameters) -> None: - dates = pd.date_range("2014-01-01", "2014-05-01", freq="H") + dates = pd.date_range("2014-01-01", "2014-05-01", freq="h") xdates = xr.DataArray(dates, dims=["time"]) expected = getattr(dates, method)(parameters) actual = getattr(xdates.dt, method)(parameters) @@ -261,7 +291,7 @@ def setup(self): data = np.random.rand(10, 10, nt) lons = np.linspace(0, 11, 10) lats = np.linspace(0, 20, 10) - self.times = pd.timedelta_range(start="1 day", freq="6H", periods=nt) + self.times = pd.timedelta_range(start="1 day", freq="6h", periods=nt) self.data = xr.DataArray( data, @@ -282,7 +312,7 @@ def test_not_datetime_type(self) -> None: nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) - with pytest.raises(TypeError, match=r"dt"): + with pytest.raises(AttributeError, match=r"dt"): nontime_data.time.dt @pytest.mark.parametrize( @@ -299,7 +329,7 @@ def test_field_access(self, field) -> None: "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] ) def test_accessor_methods(self, method, parameters) -> None: - dates = pd.timedelta_range(start="1 day", end="30 days", freq="6H") + dates = pd.timedelta_range(start="1 day", end="30 days", freq="6h") xdates = xr.DataArray(dates, dims=["time"]) expected = getattr(dates, method)(parameters) actual = getattr(xdates.dt, method)(parameters) diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 168d3232f81..e0c9619b4e7 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -279,23 +279,19 @@ def test_case_bytes() -> None: def test_case_str() -> None: # This string includes some unicode characters # that are common case management corner cases - value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) - exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) - exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) - exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.unicode_) - exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype( - np.unicode_ - ) + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.str_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.str_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.str_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(np.str_) - exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.unicode_) - exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) - exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype( - np.unicode_ - ) + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.str_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_) + exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.str_) res_capitalized = value.str.capitalize() res_casefolded = value.str.casefold() @@ -680,7 +676,7 @@ def test_extract_extractall_name_collision_raises(dtype) -> None: def test_extract_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -728,7 +724,7 @@ def test_extract_single_case(dtype) -> None: def test_extract_single_nocase(dtype) -> None: pat_str = r"(\w+)?_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) @@ -770,7 +766,7 @@ def test_extract_single_nocase(dtype) -> None: def test_extract_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -810,7 +806,7 @@ def test_extract_multi_case(dtype) -> None: def test_extract_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.IGNORECASE) @@ -876,7 +872,7 @@ def test_extract_broadcast(dtype) -> None: def test_extractall_single_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -908,7 +904,7 @@ def test_extractall_single_single_case(dtype) -> None: def test_extractall_single_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -937,7 +933,7 @@ def test_extractall_single_single_nocase(dtype) -> None: def test_extractall_single_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -983,7 +979,7 @@ def test_extractall_single_multi_case(dtype) -> None: def test_extractall_single_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -1030,7 +1026,7 @@ def test_extractall_single_multi_nocase(dtype) -> None: def test_extractall_multi_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -1065,7 +1061,7 @@ def test_extractall_multi_single_case(dtype) -> None: def test_extractall_multi_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -1097,7 +1093,7 @@ def test_extractall_multi_single_nocase(dtype) -> None: def test_extractall_multi_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re) @@ -1147,7 +1143,7 @@ def test_extractall_multi_multi_case(dtype) -> None: def test_extractall_multi_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re: str | bytes = ( - pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_str if dtype == np.str_ else bytes(pat_str, encoding="UTF-8") ) pat_compiled = re.compile(pat_re, flags=re.I) @@ -3419,12 +3415,12 @@ def test_cat_multi() -> None: values_4 = "" - values_5 = np.array("", dtype=np.unicode_) + values_5 = np.array("", dtype=np.str_) sep = xr.DataArray( [" ", ", "], dims=["ZZ"], - ).astype(np.unicode_) + ).astype(np.str_) expected = xr.DataArray( [ @@ -3440,7 +3436,7 @@ def test_cat_multi() -> None: ], ], dims=["X", "Y", "ZZ"], - ).astype(np.unicode_) + ).astype(np.str_) res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) @@ -3561,7 +3557,7 @@ def test_format_scalar() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3574,7 +3570,7 @@ def test_format_scalar() -> None: expected = xr.DataArray( ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) @@ -3586,7 +3582,7 @@ def test_format_broadcast() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3608,7 +3604,7 @@ def test_format_broadcast() -> None: ["'test'-X-None", "'test'-X-None"], ], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) @@ -3620,7 +3616,7 @@ def test_mod_scalar() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3629,7 +3625,7 @@ def test_mod_scalar() -> None: expected = xr.DataArray( ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % (pos0, pos1, pos2) @@ -3641,7 +3637,7 @@ def test_mod_dict() -> None: values = xr.DataArray( ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) a = 1 b = 1.2 @@ -3650,7 +3646,7 @@ def test_mod_dict() -> None: expected = xr.DataArray( ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % {"a": a, "b": b, "c": c} @@ -3662,7 +3658,7 @@ def test_mod_broadcast_single() -> None: values = xr.DataArray( ["%s_1", "%s_2", "%s_3"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos = xr.DataArray( ["2.3", "3.44444"], @@ -3672,7 +3668,7 @@ def test_mod_broadcast_single() -> None: expected = xr.DataArray( [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % pos @@ -3684,7 +3680,7 @@ def test_mod_broadcast_multi() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], - ).astype(np.unicode_) + ).astype(np.str_) pos0 = 1 pos1 = 1.2 @@ -3701,7 +3697,7 @@ def test_mod_broadcast_multi() -> None: ["1-1.2-2.3", "1-1.2-3.44444"], ], dims=["X", "YY"], - ).astype(np.unicode_) + ).astype(np.str_) res = values.str % (pos0, pos1, pos2) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index fddaa120970..a5ffb37a109 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - import pytest import xarray as xr @@ -9,10 +7,19 @@ np = pytest.importorskip("numpy", minversion="1.22") -with warnings.catch_warnings(): - warnings.simplefilter("ignore") - import numpy.array_api as xp # isort:skip - from numpy.array_api._array_object import Array # isort:skip +try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + import numpy.array_api as xp + from numpy.array_api._array_object import Array +except ImportError: + # for `numpy>=2.0` + xp = pytest.importorskip("array_api_strict") + + from array_api_strict._array_object import Array # type: ignore[no-redef] @pytest.fixture @@ -77,6 +84,22 @@ def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(a, e) +def test_broadcast_during_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = np_arr * np_arr2 + actual = xp_arr * xp_arr2 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + expected = np_arr2 * np_arr + actual = xp_arr2 * xp_arr + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = xr.concat((np_arr, np_arr), dim="x") @@ -115,6 +138,14 @@ def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) +def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.stack(z=("x", "y")).unstack() + actual = xp_arr.stack(z=("x", "y")).unstack() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_where() -> None: np_arr = xr.DataArray(np.array([1, 0]), dims="x") xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x") diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_assertions.py similarity index 100% rename from xarray/tests/test_testing.py rename to xarray/tests/test_assertions.py diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1bf72d1243b..3fb137977e8 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -13,11 +13,13 @@ import tempfile import uuid import warnings -from collections.abc import Iterator +from collections.abc import Generator, Iterator from contextlib import ExitStack from io import BytesIO +from os import listdir from pathlib import Path from typing import TYPE_CHECKING, Any, Final, cast +from unittest.mock import patch import numpy as np import pandas as pd @@ -46,13 +48,14 @@ ) from xarray.backends.pydap_ import PydapDataStore from xarray.backends.scipy_ import ScipyBackendEntrypoint +from xarray.coding.cftime_offsets import cftime_range +from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing from xarray.core.options import set_options -from xarray.core.pycompat import array_type +from xarray.namedarray.pycompat import array_type from xarray.tests import ( - arm_xfail, assert_allclose, assert_array_equal, assert_equal, @@ -67,12 +70,12 @@ requires_dask, requires_fsspec, requires_h5netcdf, + requires_h5netcdf_ros3, requires_iris, requires_netCDF4, - requires_pseudonetcdf, + requires_netCDF4_1_6_2_or_above, requires_pydap, requires_pynio, - requires_rasterio, requires_scipy, requires_scipy_or_netCDF4, requires_zarr, @@ -139,96 +142,100 @@ def open_example_mfdataset(names, *args, **kwargs) -> Dataset: ) -def create_masked_and_scaled_data() -> Dataset: - x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=np.float32) +def create_masked_and_scaled_data(dtype: np.dtype) -> Dataset: + x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=dtype) encoding = { "_FillValue": -1, - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), "dtype": "i2", } return Dataset({"x": ("t", x, {}, encoding)}) -def create_encoded_masked_and_scaled_data() -> Dataset: - attributes = {"_FillValue": -1, "add_offset": 10, "scale_factor": np.float32(0.1)} +def create_encoded_masked_and_scaled_data(dtype: np.dtype) -> Dataset: + attributes = { + "_FillValue": -1, + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), + } return Dataset( {"x": ("t", np.array([-1, -1, 0, 1, 2], dtype=np.int16), attributes)} ) -def create_unsigned_masked_scaled_data() -> Dataset: +def create_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: encoding = { "_FillValue": 255, "_Unsigned": "true", "dtype": "i1", - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } - x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32) + x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=dtype) return Dataset({"x": ("t", x, {}, encoding)}) -def create_encoded_unsigned_masked_scaled_data() -> Dataset: +def create_encoded_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: # These are values as written to the file: the _FillValue will # be represented in the signed form. attributes = { "_FillValue": -1, "_Unsigned": "true", - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") return Dataset({"x": ("t", sb, attributes)}) -def create_bad_unsigned_masked_scaled_data() -> Dataset: +def create_bad_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: encoding = { "_FillValue": 255, "_Unsigned": True, "dtype": "i1", - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } - x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32) + x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=dtype) return Dataset({"x": ("t", x, {}, encoding)}) -def create_bad_encoded_unsigned_masked_scaled_data() -> Dataset: +def create_bad_encoded_unsigned_masked_scaled_data(dtype: np.dtype) -> Dataset: # These are values as written to the file: the _FillValue will # be represented in the signed form. attributes = { "_FillValue": -1, "_Unsigned": True, - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned sb = np.asarray([0, 1, 127, -128, -1], dtype="i1") return Dataset({"x": ("t", sb, attributes)}) -def create_signed_masked_scaled_data() -> Dataset: +def create_signed_masked_scaled_data(dtype: np.dtype) -> Dataset: encoding = { "_FillValue": -127, "_Unsigned": "false", "dtype": "i1", - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } - x = np.array([-1.0, 10.1, 22.7, np.nan], dtype=np.float32) + x = np.array([-1.0, 10.1, 22.7, np.nan], dtype=dtype) return Dataset({"x": ("t", x, {}, encoding)}) -def create_encoded_signed_masked_scaled_data() -> Dataset: +def create_encoded_signed_masked_scaled_data(dtype: np.dtype) -> Dataset: # These are values as written to the file: the _FillValue will # be represented in the signed form. attributes = { "_FillValue": -127, "_Unsigned": "false", - "add_offset": 10, - "scale_factor": np.float32(0.1), + "add_offset": dtype.type(10), + "scale_factor": dtype.type(0.1), } # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned sb = np.asarray([-110, 1, 127, -127], dtype="i1") @@ -430,8 +437,6 @@ def test_dataset_compute(self) -> None: assert_identical(expected, computed) def test_pickle(self) -> None: - if not has_dask: - pytest.xfail("pickling requires dask for SerializableLock") expected = Dataset({"foo": ("x", [42])}) with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: @@ -442,8 +447,6 @@ def test_pickle(self) -> None: @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_pickle_dataarray(self) -> None: - if not has_dask: - pytest.xfail("pickling requires dask for SerializableLock") expected = Dataset({"foo": ("x", [42])}) with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: with roundtripped: @@ -525,7 +528,6 @@ def test_roundtrip_string_encoded_characters(self) -> None: assert_identical(expected, actual) assert actual["x"].encoding["_Encoding"] == "ascii" - @arm_xfail def test_roundtrip_numpy_datetime_data(self) -> None: times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"]) expected = Dataset({"t": ("t", times), "t0": times[0]}) @@ -631,6 +633,11 @@ def test_roundtrip_boolean_dtype(self) -> None: with self.roundtrip(original) as actual: assert_identical(original, actual) assert actual["x"].dtype == "bool" + # this checks for preserving dtype during second roundtrip + # see https://github.com/pydata/xarray/issues/7652#issuecomment-1476956975 + with self.roundtrip(actual) as actual2: + assert_identical(original, actual2) + assert actual2["x"].dtype == "bool" def test_orthogonal_indexing(self) -> None: in_memory = create_test_data() @@ -710,9 +717,6 @@ def multiple_indexing(indexers): ] multiple_indexing(indexers5) - @pytest.mark.xfail( - reason="zarr without dask handles negative steps in slices incorrectly", - ) def test_vectorized_indexing_negative_step(self) -> None: # use dask explicitly when present open_kwargs: dict[str, Any] | None @@ -808,7 +812,7 @@ def test_array_type_after_indexing(self) -> None: def test_dropna(self) -> None: # regression test for GH:issue:1694 a = np.random.randn(4, 3) - a[1, 1] = np.NaN + a[1, 1] = np.nan in_memory = xr.Dataset( {"a": (("y", "x"), a)}, coords={"y": np.arange(4), "x": np.arange(3)} ) @@ -855,6 +859,21 @@ def test_roundtrip_string_with_fill_value_nchar(self) -> None: with self.roundtrip(original) as actual: assert_identical(expected, actual) + def test_roundtrip_empty_vlen_string_array(self) -> None: + # checks preserving vlen dtype for empty arrays GH7862 + dtype = create_vlen_dtype(str) + original = Dataset({"a": np.array([], dtype=dtype)}) + assert check_vlen_dtype(original["a"].dtype) == str + with self.roundtrip(original) as actual: + assert_identical(original, actual) + if np.issubdtype(actual["a"].dtype, object): + # only check metadata for capable backends + # eg. NETCDF3 based backends do not roundtrip metadata + if actual["a"].dtype.metadata is not None: + assert check_vlen_dtype(actual["a"].dtype) == str + else: + assert actual["a"].dtype == np.dtype(" None: (create_masked_and_scaled_data, create_encoded_masked_and_scaled_data), ], ) - def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn) -> None: - decoded = decoded_fn() - encoded = encoded_fn() - + @pytest.mark.parametrize("dtype", [np.dtype("float64"), np.dtype("float32")]) + def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn, dtype) -> None: + if hasattr(self, "zarr_version") and dtype == np.float32: + pytest.skip("float32 will be treated as float64 in zarr") + decoded = decoded_fn(dtype) + encoded = encoded_fn(dtype) with self.roundtrip(decoded) as actual: for k in decoded.variables: assert decoded.variables[k].dtype == actual.variables[k].dtype @@ -897,7 +918,7 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn) -> None: # make sure roundtrip encoding didn't change the # original dataset. - assert_allclose(encoded, encoded_fn(), decode_bytes=False) + assert_allclose(encoded, encoded_fn(dtype), decode_bytes=False) with self.roundtrip(encoded) as actual: for k in decoded.variables: @@ -1231,6 +1252,11 @@ def test_multiindex_not_implemented(self) -> None: with self.roundtrip(ds): pass + # regression GH8628 (can serialize reset multi-index level coordinates) + ds_reset = ds.reset_index("x") + with self.roundtrip(ds_reset) as actual: + assert_identical(actual, ds_reset) + class NetCDFBase(CFEncodedBase): """Tests for all netCDF3 and netCDF4 backends.""" @@ -1358,32 +1384,39 @@ def test_write_groups(self) -> None: with self.open(tmp_file, group="data/2") as actual2: assert_identical(data2, actual2) - def test_encoding_kwarg_vlen_string(self) -> None: - for input_strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]: - original = Dataset({"x": input_strings}) - expected = Dataset({"x": ["foo", "bar", "baz"]}) - kwargs = dict(encoding={"x": {"dtype": str}}) - with self.roundtrip(original, save_kwargs=kwargs) as actual: - assert actual["x"].encoding["dtype"] is str - assert_identical(actual, expected) - - def test_roundtrip_string_with_fill_value_vlen(self) -> None: + @pytest.mark.parametrize( + "input_strings, is_bytes", + [ + ([b"foo", b"bar", b"baz"], True), + (["foo", "bar", "baz"], False), + (["foó", "bár", "baź"], False), + ], + ) + def test_encoding_kwarg_vlen_string( + self, input_strings: list[str], is_bytes: bool + ) -> None: + original = Dataset({"x": input_strings}) + + expected_string = ["foo", "bar", "baz"] if is_bytes else input_strings + expected = Dataset({"x": expected_string}) + kwargs = dict(encoding={"x": {"dtype": str}}) + with self.roundtrip(original, save_kwargs=kwargs) as actual: + assert actual["x"].encoding["dtype"] == " None: values = np.array(["ab", "cdef", np.nan], dtype=object) expected = Dataset({"x": ("t", values)}) - # netCDF4-based backends don't support an explicit fillvalue - # for variable length strings yet. - # https://github.com/Unidata/netcdf4-python/issues/730 - # https://github.com/h5netcdf/h5netcdf/issues/37 - original = Dataset({"x": ("t", values, {}, {"_FillValue": "XXX"})}) - with pytest.raises(NotImplementedError): - with self.roundtrip(original) as actual: - assert_identical(expected, actual) + original = Dataset({"x": ("t", values, {}, {"_FillValue": fill_value})}) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})}) - with pytest.raises(NotImplementedError): - with self.roundtrip(original) as actual: - assert_identical(expected, actual) + with self.roundtrip(original) as actual: + assert_identical(expected, actual) def test_roundtrip_character_array(self) -> None: with create_tmp_file() as tmp_file: @@ -1462,7 +1495,7 @@ def test_dump_and_open_encodings(self) -> None: assert ds.variables["time"].getncattr("units") == units assert_array_equal(ds.variables["time"], np.arange(10) + 4) - def test_compression_encoding(self) -> None: + def test_compression_encoding_legacy(self) -> None: data = create_test_data() data["var2"].encoding.update( { @@ -1516,6 +1549,83 @@ def test_keep_chunksizes_if_no_original_shape(self) -> None: ds["x"].encoding["chunksizes"], actual["x"].encoding["chunksizes"] ) + def test_preferred_chunks_is_present(self) -> None: + ds = Dataset({"x": [1, 2, 3]}) + chunksizes = (2,) + ds.variables["x"].encoding = {"chunksizes": chunksizes} + + with self.roundtrip(ds) as actual: + assert actual["x"].encoding["preferred_chunks"] == {"x": 2} + + @requires_dask + def test_auto_chunking_is_based_on_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with dask.config.set({"array.chunk-size": "100KiB"}): + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": "auto"}, + ) as ds: + t_chunks, y_chunks, x_chunks = ds["image"].data.chunks + assert all(np.asanyarray(y_chunks) == y_chunksize) + # Check that the chunk size is a multiple of the file chunk size + assert all(np.asanyarray(x_chunks) % x_chunksize == 0) + + @requires_dask + def test_base_chunking_uses_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), + (1, y_chunksize, x_chunksize), + open_kwargs={"chunks": {}}, + ) as ds: + for chunksizes, expected in zip( + ds["image"].data.chunks, (1, y_chunksize, x_chunksize) + ): + assert all(np.asanyarray(chunksizes) == expected) + + @contextlib.contextmanager + def chunked_roundtrip( + self, + array_shape: tuple[int, int, int], + chunk_sizes: tuple[int, int, int], + open_kwargs: dict[str, Any] | None = None, + ) -> Generator[Dataset, None, None]: + t_size, y_size, x_size = array_shape + t_chunksize, y_chunksize, x_chunksize = chunk_sizes + + image = xr.DataArray( + np.arange(t_size * x_size * y_size, dtype=np.int16).reshape( + (t_size, y_size, x_size) + ), + dims=["t", "y", "x"], + ) + image.encoding = {"chunksizes": (t_chunksize, y_chunksize, x_chunksize)} + dataset = xr.Dataset(dict(image=image)) + + with self.roundtrip(dataset, open_kwargs=open_kwargs) as ds: + yield ds + + def test_preferred_chunks_are_disk_chunk_sizes(self) -> None: + x_size = y_size = 1000 + y_chunksize = y_size + x_chunksize = 10 + + with self.chunked_roundtrip( + (1, y_size, x_size), (1, y_chunksize, x_chunksize) + ) as ds: + assert ds["image"].encoding["preferred_chunks"] == { + "t": 1, + "y": y_chunksize, + "x": x_chunksize, + } + def test_encoding_chunksizes_unlimited(self) -> None: # regression test for GH1225 ds = Dataset({"x": [1, 2, 3], "y": ("x", [2, 3, 4])}) @@ -1541,6 +1651,7 @@ def test_mask_and_scale(self) -> None: v.add_offset = 10 v.scale_factor = 0.1 v[:] = np.array([-1, -1, 0, 1, 2]) + dtype = type(v.scale_factor) # first make sure netCDF4 reads the masked and scaled data # correctly @@ -1553,7 +1664,7 @@ def test_mask_and_scale(self) -> None: # now check xarray with open_dataset(tmp_file) as ds: - expected = create_masked_and_scaled_data() + expected = create_masked_and_scaled_data(np.dtype(dtype)) assert_identical(expected, ds) def test_0dimensional_variable(self) -> None: @@ -1592,6 +1703,140 @@ def test_encoding_unlimited_dims(self) -> None: assert actual.encoding["unlimited_dims"] == set("y") assert_equal(ds, actual) + def test_raise_on_forward_slashes_in_names(self) -> None: + # test for forward slash in variable names and dimensions + # see GH 7943 + data_vars: list[dict[str, Any]] = [ + {"PASS/FAIL": (["PASSFAIL"], np.array([0]))}, + {"PASS/FAIL": np.array([0])}, + {"PASSFAIL": (["PASS/FAIL"], np.array([0]))}, + ] + for dv in data_vars: + ds = Dataset(data_vars=dv) + with pytest.raises(ValueError, match="Forward slashes '/' are not allowed"): + with self.roundtrip(ds): + pass + + @requires_netCDF4 + def test_encoding_enum__no_fill_value(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + v = nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=None, + ) + v[:] = 1 + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__multiple_variable_with_enum(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"] + ) + assert ( + actual.clouds.encoding["dtype"].metadata + == actual.tifa.encoding["dtype"].metadata + ) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__error_multiple_variable_with_changing_enum(self): + """ + Given 2 variables, if they share the same enum type, + the 2 enum definition should be identical. + """ + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + assert ( + original.clouds.encoding["dtype"].metadata + == original.tifa.encoding["dtype"].metadata + ) + modified_enum = original.clouds.encoding["dtype"].metadata["enum"] + modified_enum.update({"neblig": 2}) + original.clouds.encoding["dtype"] = np.dtype( + "u1", + metadata={"enum": modified_enum, "enum_name": "cloud_type"}, + ) + if self.engine != "h5netcdf": + # not implemented yet in h5netcdf + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + with self.roundtrip(original): + pass + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base): @@ -1652,6 +1897,74 @@ def test_setncattr_string(self) -> None: assert_array_equal(one_element_list_of_strings, totest.attrs["bar"]) assert one_string == totest.attrs["baz"] + @pytest.mark.parametrize( + "compression", + [ + None, + "zlib", + "szip", + "zstd", + "blosc_lz", + "blosc_lz4", + "blosc_lz4hc", + "blosc_zlib", + "blosc_zstd", + ], + ) + @requires_netCDF4_1_6_2_or_above + @pytest.mark.xfail(ON_WINDOWS, reason="new compression not yet implemented") + def test_compression_encoding(self, compression: str | None) -> None: + data = create_test_data(dim_sizes=(20, 80, 10)) + encoding_params: dict[str, Any] = dict(compression=compression, blosc_shuffle=1) + data["var2"].encoding.update(encoding_params) + data["var2"].encoding.update( + { + "chunksizes": (20, 40), + "original_shape": data.var2.shape, + "blosc_shuffle": 1, + "fletcher32": False, + } + ) + with self.roundtrip(data) as actual: + expected_encoding = data["var2"].encoding.copy() + # compression does not appear in the retrieved encoding, that differs + # from the input encoding. shuffle also chantges. Here we modify the + # expected encoding to account for this + compression = expected_encoding.pop("compression") + blosc_shuffle = expected_encoding.pop("blosc_shuffle") + if compression is not None: + if "blosc" in compression and blosc_shuffle: + expected_encoding["blosc"] = { + "compressor": compression, + "shuffle": blosc_shuffle, + } + expected_encoding["shuffle"] = False + elif compression == "szip": + expected_encoding["szip"] = { + "coding": "nn", + "pixels_per_block": 8, + } + expected_encoding["shuffle"] = False + else: + # This will set a key like zlib=true which is what appears in + # the encoding when we read it. + expected_encoding[compression] = True + if compression == "zstd": + expected_encoding["shuffle"] = False + else: + expected_encoding["shuffle"] = False + + actual_encoding = actual["var2"].encoding + assert expected_encoding.items() <= actual_encoding.items() + if ( + encoding_params["compression"] is not None + and "blosc" not in encoding_params["compression"] + ): + # regression test for #156 + expected = data.isel(dim1=0) + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) + @pytest.mark.skip(reason="https://github.com/Unidata/netcdf4-python/issues/1195") def test_refresh_from_disk(self) -> None: super().test_refresh_from_disk() @@ -1733,8 +2046,8 @@ def test_unsorted_index_raises(self) -> None: # dask first pulls items by block. pass + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -1827,9 +2140,14 @@ def test_non_existent_store(self) -> None: def test_with_chunkstore(self) -> None: expected = create_test_data() - with self.create_zarr_target() as store_target, self.create_zarr_target() as chunk_store: + with ( + self.create_zarr_target() as store_target, + self.create_zarr_target() as chunk_store, + ): save_kwargs = {"chunk_store": chunk_store} self.save(expected, store_target, **save_kwargs) + # the chunk store must have been populated with some entries + assert len(chunk_store) > 0 open_kwargs = {"backend_kwargs": {"chunk_store": chunk_store}} with self.open(store_target, **open_kwargs) as ds: assert_equal(ds, expected) @@ -1853,7 +2171,7 @@ def test_auto_chunk(self) -> None: assert v.chunks == original[k].chunks @requires_dask - @pytest.mark.filterwarnings("ignore:The specified Dask chunks separate") + @pytest.mark.filterwarnings("ignore:The specified chunks separate:UserWarning") def test_manual_chunk(self) -> None: original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3}) @@ -1963,6 +2281,10 @@ def test_chunk_encoding(self) -> None: pass @requires_dask + @pytest.mark.skipif( + ON_WINDOWS, + reason="Very flaky on Windows CI. Can re-enable assuming it starts consistently passing.", + ) def test_chunk_encoding_with_dask(self) -> None: # These datasets DO have dask chunks. Need to check for various # interactions between dask and zarr chunks @@ -2046,10 +2368,10 @@ def test_chunk_encoding_with_dask(self) -> None: pass def test_drop_encoding(self): - ds = open_example_dataset("example_1.nc") - encodings = {v: {**ds[v].encoding} for v in ds.data_vars} - with self.create_zarr_target() as store: - ds.to_zarr(store, encoding=encodings) + with open_example_dataset("example_1.nc") as ds: + encodings = {v: {**ds[v].encoding} for v in ds.data_vars} + with self.create_zarr_target() as store: + ds.to_zarr(store, encoding=encodings) def test_hidden_zarr_keys(self) -> None: expected = create_test_data() @@ -2150,9 +2472,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: # not relevant for zarr, since we don't use EncodedStringCoder pass - # TODO: someone who understand caching figure out whether caching - # makes sense for Zarr backend - @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self) -> None: super().test_dataset_caching() @@ -2276,6 +2595,33 @@ def test_append_with_new_variable(self) -> None: xr.open_dataset(store_target, engine="zarr", **self.version_kwargs), ) + def test_append_with_append_dim_no_overwrite(self) -> None: + ds, ds_to_append, _ = create_append_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w", **self.version_kwargs) + original = xr.concat([ds, ds_to_append], dim="time") + original2 = xr.concat([original, ds_to_append], dim="time") + + # overwrite a coordinate; + # for mode='a-', this will not get written to the store + # because it does not have the append_dim as a dim + lon = ds_to_append.lon.to_numpy().copy() + lon[:] = -999 + ds_to_append["lon"] = lon + ds_to_append.to_zarr( + store_target, mode="a-", append_dim="time", **self.version_kwargs + ) + actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs) + assert_identical(original, actual) + + # by default, mode="a" will overwrite all coordinates. + ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs) + actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs) + lon = original2.lon.to_numpy().copy() + lon[:] = -999 + original2["lon"] = lon + assert_identical(original2, actual) + @requires_dask def test_to_zarr_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed @@ -2353,7 +2699,8 @@ def test_no_warning_from_open_emptydim_with_chunks(self) -> None: @pytest.mark.parametrize("consolidated", [False, True, None]) @pytest.mark.parametrize("compute", [False, True]) @pytest.mark.parametrize("use_dask", [False, True]) - def test_write_region(self, consolidated, compute, use_dask) -> None: + @pytest.mark.parametrize("write_empty", [False, True, None]) + def test_write_region(self, consolidated, compute, use_dask, write_empty) -> None: if (use_dask or not compute) and not has_dask: pytest.skip("requires dask") if consolidated and self.zarr_version > 2: @@ -2385,6 +2732,7 @@ def test_write_region(self, consolidated, compute, use_dask) -> None: store, region=region, consolidated=consolidated, + write_empty_chunks=write_empty, **self.version_kwargs, ) with xr.open_zarr( @@ -2470,7 +2818,7 @@ def setup_and_verify_store(expected=data): with pytest.raises( ValueError, match=re.escape( - "cannot set region unless mode='a', mode='r+' or mode=None" + "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" ), ): data.to_zarr( @@ -2580,10 +2928,10 @@ def test_write_read_select_write(self) -> None: ds.to_zarr(initial_store, mode="w", **self.version_kwargs) ds1 = xr.open_zarr(initial_store, **self.version_kwargs) - # Combination of where+squeeze triggers error on write. - ds_sel = ds1.where(ds1.coords["dim3"] == "a", drop=True).squeeze("dim3") - with self.create_zarr_target() as final_store: - ds_sel.to_zarr(final_store, mode="w", **self.version_kwargs) + # Combination of where+squeeze triggers error on write. + ds_sel = ds1.where(ds1.coords["dim3"] == "a", drop=True).squeeze("dim3") + with self.create_zarr_target() as final_store: + ds_sel.to_zarr(final_store, mode="w", **self.version_kwargs) @pytest.mark.parametrize("obj", [Dataset(), DataArray(name="foo")]) def test_attributes(self, obj) -> None: @@ -2601,6 +2949,36 @@ def test_attributes(self, obj) -> None: with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."): ds.to_zarr(store_target, **self.version_kwargs) + @requires_dask + @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"]) + def test_chunked_datetime64_or_timedelta64(self, dtype) -> None: + # Generalized from @malmans2's test in PR #8253 + original = create_test_data().astype(dtype).chunk(1) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for name, actual_var in actual.variables.items(): + assert original[name].chunks == actual_var.chunks + assert original.chunks == actual.chunks + + @requires_cftime + @requires_dask + def test_chunked_cftime_datetime(self) -> None: + # Based on @malmans2's test in PR #8253 + times = cftime_range("2000", freq="D", periods=3) + original = xr.Dataset(data_vars={"chunked_times": (["time"], times)}) + original = original.chunk({"time": 1}) + with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual: + for name, actual_var in actual.variables.items(): + assert original[name].chunks == actual_var.chunks + assert original.chunks == actual.chunks + + def test_vectorized_indexing_negative_step(self) -> None: + if not has_dask: + pytest.xfail( + reason="zarr without dask handles negative steps in slices incorrectly" + ) + + super().test_vectorized_indexing_negative_step() + @requires_zarr class TestZarrDictStore(ZarrBase): @@ -2613,6 +2991,10 @@ def create_zarr_target(self): @requires_zarr +@pytest.mark.skipif( + ON_WINDOWS, + reason="Very flaky on Windows CI. Can re-enable assuming it starts consistently passing.", +) class TestZarrDirectoryStore(ZarrBase): @contextlib.contextmanager def create_zarr_target(self): @@ -2630,6 +3012,126 @@ def create_store(self): yield group +@requires_zarr +class TestZarrWriteEmpty(TestZarrDirectoryStore): + @contextlib.contextmanager + def temp_dir(self) -> Iterator[tuple[str, str]]: + with tempfile.TemporaryDirectory() as d: + store = os.path.join(d, "test.zarr") + yield d, store + + @contextlib.contextmanager + def roundtrip_dir( + self, + data, + store, + save_kwargs=None, + open_kwargs=None, + allow_cleanup_failure=False, + ) -> Iterator[Dataset]: + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} + + data.to_zarr(store, **save_kwargs, **self.version_kwargs) + with xr.open_dataset( + store, engine="zarr", **open_kwargs, **self.version_kwargs + ) as ds: + yield ds + + @pytest.mark.parametrize("consolidated", [True, False, None]) + @pytest.mark.parametrize("write_empty", [True, False, None]) + def test_write_empty( + self, consolidated: bool | None, write_empty: bool | None + ) -> None: + if write_empty is False: + expected = ["0.1.0", "1.1.0"] + else: + expected = [ + "0.0.0", + "0.0.1", + "0.1.0", + "0.1.1", + "1.0.0", + "1.0.1", + "1.1.0", + "1.1.1", + ] + + ds = xr.Dataset( + data_vars={ + "test": ( + ("Z", "Y", "X"), + np.array([np.nan, np.nan, 1.0, np.nan]).reshape((1, 2, 2)), + ) + } + ) + + if has_dask: + ds["test"] = ds["test"].chunk(1) + encoding = None + else: + encoding = {"test": {"chunks": (1, 1, 1)}} + + with self.temp_dir() as (d, store): + ds.to_zarr( + store, + mode="w", + encoding=encoding, + write_empty_chunks=write_empty, + ) + + with self.roundtrip_dir( + ds, + store, + {"mode": "a", "append_dim": "Z", "write_empty_chunks": write_empty}, + ) as a_ds: + expected_ds = xr.concat([ds, ds], dim="Z") + + assert_identical(a_ds, expected_ds) + + ls = listdir(os.path.join(store, "test")) + assert set(expected) == set([file for file in ls if file[0] != "."]) + + def test_avoid_excess_metadata_calls(self) -> None: + """Test that chunk requests do not trigger redundant metadata requests. + + This test targets logic in backends.zarr.ZarrArrayWrapper, asserting that calls + to retrieve chunk data after initialization do not trigger additional + metadata requests. + + https://github.com/pydata/xarray/issues/8290 + """ + + import zarr + + ds = xr.Dataset(data_vars={"test": (("Z",), np.array([123]).reshape(1))}) + + # The call to retrieve metadata performs a group lookup. We patch Group.__getitem__ + # so that we can inspect calls to this method - specifically count of calls. + # Use of side_effect means that calls are passed through to the original method + # rather than a mocked method. + Group = zarr.hierarchy.Group + with ( + self.create_zarr_target() as store, + patch.object( + Group, "__getitem__", side_effect=Group.__getitem__, autospec=True + ) as mock, + ): + ds.to_zarr(store, mode="w") + + # We expect this to request array metadata information, so call_count should be == 1, + xrds = xr.open_zarr(store) + call_count = mock.call_count + assert call_count == 1 + + # compute() requests array data, which should not trigger additional metadata requests + # we assert that the number of calls has not increased after fetchhing the array + xrds.test.compute(scheduler="sync") + assert mock.call_count == call_count + + class ZarrBaseV3(ZarrBase): zarr_version = 3 @@ -2893,8 +3395,9 @@ def create_store(self): def test_complex(self) -> None: expected = Dataset({"x": ("y", np.ones(5) + 1j * np.ones(5))}) save_kwargs = {"invalid_netcdf": True} - with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: - assert_equal(expected, actual) + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_equal(expected, actual) @pytest.mark.parametrize("invalid_netcdf", [None, False]) def test_complex_error(self, invalid_netcdf) -> None: @@ -2908,14 +3411,14 @@ def test_complex_error(self, invalid_netcdf) -> None: with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: assert_equal(expected, actual) - @pytest.mark.filterwarnings("ignore:You are writing invalid netcdf features") def test_numpy_bool_(self) -> None: # h5netcdf loads booleans as numpy.bool_, this type needs to be supported # when writing invalid_netcdf datasets in order to support a roundtrip expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) save_kwargs = {"invalid_netcdf": True} - with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: - assert_identical(expected, actual) + with pytest.warns(UserWarning, match="You are writing invalid netcdf features"): + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_identical(expected, actual) def test_cross_engine_read_write_netcdf4(self) -> None: # Drop dim3, because its labels include strings. These appear to be @@ -3186,8 +3689,8 @@ def roundtrip( ) as ds: yield ds + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -3208,6 +3711,36 @@ def test_write_inconsistent_chunks(self) -> None: assert actual["y"].encoding["chunksizes"] == (100, 50) +@requires_h5netcdf_ros3 +class TestH5NetCDFDataRos3Driver(TestCommon): + engine: T_NetcdfEngine = "h5netcdf" + test_remote_dataset: str = ( + "https://www.unidata.ucar.edu/software/netcdf/examples/OMI-Aura_L2-example.nc" + ) + + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") + def test_get_variable_list(self) -> None: + with open_dataset( + self.test_remote_dataset, + engine="h5netcdf", + backend_kwargs={"driver": "ros3"}, + ) as actual: + assert "Temperature" in list(actual) + + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") + def test_get_variable_list_empty_driver_kwds(self) -> None: + driver_kwds = { + "secret_id": b"", + "secret_key": b"", + } + backend_kwargs = {"driver": "ros3", "driver_kwds": driver_kwds} + + with open_dataset( + self.test_remote_dataset, engine="h5netcdf", backend_kwargs=backend_kwargs + ) as actual: + assert "Temperature" in list(actual) + + @pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "pynio", "zarr"]) def readengine(request): return request.param @@ -3238,6 +3771,21 @@ def chunks(request): return request.param +@pytest.fixture(params=["tmp_path", "ZipStore", "Dict"]) +def tmp_store(request, tmp_path): + if request.param == "tmp_path": + return tmp_path + elif request.param == "ZipStore": + from zarr.storage import ZipStore + + path = tmp_path / "store.zip" + return ZipStore(path) + elif request.param == "Dict": + return dict() + else: + raise ValueError("not supported") + + # using pytest.mark.skipif does not work so this a work around def skip_if_not_engine(engine): if engine == "netcdf4": @@ -3250,6 +3798,7 @@ def skip_if_not_engine(engine): @requires_dask @pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +@pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") def test_open_mfdataset_manyfiles( readengine, nfiles, parallel, chunks, file_cache_maxsize ): @@ -3775,7 +4324,6 @@ def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: with pytest.raises(ValueError, match="`concat_dim` has no effect"): open_mfdataset([tmp1, tmp2], concat_dim="x") - @pytest.mark.xfail(reason="mfdataset loses encoding currently.") def test_encoding_mfdataset(self) -> None: original = Dataset( { @@ -3953,6 +4501,7 @@ def test_open_multi_dataset(self) -> None: ) as actual: assert_identical(expected, actual) + @pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") def test_dask_roundtrip(self) -> None: with create_tmp_file() as tmp: data = create_test_data() @@ -3988,7 +4537,6 @@ def test_dataarray_compute(self) -> None: assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - @pytest.mark.xfail def test_save_mfdataset_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed @@ -4036,13 +4584,19 @@ def test_inline_array(self) -> None: def num_graph_nodes(obj): return len(obj.__dask_graph__()) - not_inlined_ds = open_dataset(tmp, inline_array=False, chunks=chunks) - inlined_ds = open_dataset(tmp, inline_array=True, chunks=chunks) - assert num_graph_nodes(inlined_ds) < num_graph_nodes(not_inlined_ds) + with ( + open_dataset(tmp, inline_array=False, chunks=chunks) as not_inlined_ds, + open_dataset(tmp, inline_array=True, chunks=chunks) as inlined_ds, + ): + assert num_graph_nodes(inlined_ds) < num_graph_nodes(not_inlined_ds) - not_inlined_da = open_dataarray(tmp, inline_array=False, chunks=chunks) - inlined_da = open_dataarray(tmp, inline_array=True, chunks=chunks) - assert num_graph_nodes(inlined_da) < num_graph_nodes(not_inlined_da) + with ( + open_dataarray( + tmp, inline_array=False, chunks=chunks + ) as not_inlined_da, + open_dataarray(tmp, inline_array=True, chunks=chunks) as inlined_da, + ): + assert num_graph_nodes(inlined_da) < num_graph_nodes(not_inlined_da) @requires_scipy_or_netCDF4 @@ -4185,782 +4739,6 @@ def test_weakrefs(self) -> None: assert_identical(actual, expected) -@requires_pseudonetcdf -@pytest.mark.filterwarnings("ignore:IOAPI_ISPH is assumed to be 6370000") -class TestPseudoNetCDFFormat: - def open(self, path, **kwargs): - return open_dataset(path, engine="pseudonetcdf", **kwargs) - - @contextlib.contextmanager - def roundtrip( - self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False - ): - if save_kwargs is None: - save_kwargs = {} - if open_kwargs is None: - open_kwargs = {} - with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path: - self.save(data, path, **save_kwargs) - with self.open(path, **open_kwargs) as ds: - yield ds - - def test_ict_format(self) -> None: - """ - Open a CAMx file and test data variables - """ - stdattr = { - "fill_value": -9999.0, - "missing_value": -9999, - "scale": 1, - "llod_flag": -8888, - "llod_value": "N/A", - "ulod_flag": -7777, - "ulod_value": "N/A", - } - - def myatts(**attrs): - outattr = stdattr.copy() - outattr.update(attrs) - return outattr - - input = { - "coords": {}, - "attrs": { - "fmt": "1001", - "n_header_lines": 29, - "PI_NAME": "Henderson, Barron", - "ORGANIZATION_NAME": "U.S. EPA", - "SOURCE_DESCRIPTION": "Example file with artificial data", - "MISSION_NAME": "JUST_A_TEST", - "VOLUME_INFO": "1, 1", - "SDATE": "2018, 04, 27", - "WDATE": "2018, 04, 27", - "TIME_INTERVAL": "0", - "INDEPENDENT_VARIABLE_DEFINITION": "Start_UTC", - "INDEPENDENT_VARIABLE": "Start_UTC", - "INDEPENDENT_VARIABLE_UNITS": "Start_UTC", - "ULOD_FLAG": "-7777", - "ULOD_VALUE": "N/A", - "LLOD_FLAG": "-8888", - "LLOD_VALUE": ("N/A, N/A, N/A, N/A, 0.025"), - "OTHER_COMMENTS": ( - "www-air.larc.nasa.gov/missions/etc/" + "IcarttDataFormat.htm" - ), - "REVISION": "R0", - "R0": "No comments for this revision.", - "TFLAG": "Start_UTC", - }, - "dims": {"POINTS": 4}, - "data_vars": { - "Start_UTC": { - "data": [43200.0, 46800.0, 50400.0, 50400.0], - "dims": ("POINTS",), - "attrs": myatts(units="Start_UTC", standard_name="Start_UTC"), - }, - "lat": { - "data": [41.0, 42.0, 42.0, 42.0], - "dims": ("POINTS",), - "attrs": myatts(units="degrees_north", standard_name="lat"), - }, - "lon": { - "data": [-71.0, -72.0, -73.0, -74.0], - "dims": ("POINTS",), - "attrs": myatts(units="degrees_east", standard_name="lon"), - }, - "elev": { - "data": [5.0, 15.0, 20.0, 25.0], - "dims": ("POINTS",), - "attrs": myatts(units="meters", standard_name="elev"), - }, - "TEST_ppbv": { - "data": [1.2345, 2.3456, 3.4567, 4.5678], - "dims": ("POINTS",), - "attrs": myatts(units="ppbv", standard_name="TEST_ppbv"), - }, - "TESTM_ppbv": { - "data": [2.22, -9999.0, -7777.0, -8888.0], - "dims": ("POINTS",), - "attrs": myatts( - units="ppbv", standard_name="TESTM_ppbv", llod_value=0.025 - ), - }, - }, - } - chkfile = Dataset.from_dict(input) - with open_example_dataset( - "example.ict", engine="pseudonetcdf", backend_kwargs={"format": "ffi1001"} - ) as ictfile: - assert_identical(ictfile, chkfile) - - def test_ict_format_write(self) -> None: - fmtkw = {"format": "ffi1001"} - with open_example_dataset( - "example.ict", engine="pseudonetcdf", backend_kwargs=fmtkw - ) as expected: - with self.roundtrip( - expected, save_kwargs=fmtkw, open_kwargs={"backend_kwargs": fmtkw} - ) as actual: - assert_identical(expected, actual) - - def test_uamiv_format_read(self) -> None: - """ - Open a CAMx file and test data variables - """ - - camxfile = open_example_dataset( - "example.uamiv", engine="pseudonetcdf", backend_kwargs={"format": "uamiv"} - ) - data = np.arange(20, dtype="f").reshape(1, 1, 4, 5) - expected = xr.Variable( - ("TSTEP", "LAY", "ROW", "COL"), - data, - dict(units="ppm", long_name="O3".ljust(16), var_desc="O3".ljust(80)), - ) - actual = camxfile.variables["O3"] - assert_allclose(expected, actual) - - data = np.array([[[2002154, 0]]], dtype="i") - expected = xr.Variable( - ("TSTEP", "VAR", "DATE-TIME"), - data, - dict( - long_name="TFLAG".ljust(16), - var_desc="TFLAG".ljust(80), - units="DATE-TIME".ljust(16), - ), - ) - actual = camxfile.variables["TFLAG"] - assert_allclose(expected, actual) - camxfile.close() - - @requires_dask - def test_uamiv_format_mfread(self) -> None: - """ - Open a CAMx file and test data variables - """ - - camxfile = open_example_mfdataset( - ["example.uamiv", "example.uamiv"], - engine="pseudonetcdf", - concat_dim="TSTEP", - combine="nested", - backend_kwargs={"format": "uamiv"}, - ) - - data1 = np.arange(20, dtype="f").reshape(1, 1, 4, 5) - data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable( - ("TSTEP", "LAY", "ROW", "COL"), - data, - dict(units="ppm", long_name="O3".ljust(16), var_desc="O3".ljust(80)), - ) - actual = camxfile.variables["O3"] - assert_allclose(expected, actual) - - data = np.array([[[2002154, 0]]], dtype="i").repeat(2, 0) - attrs = dict( - long_name="TFLAG".ljust(16), - var_desc="TFLAG".ljust(80), - units="DATE-TIME".ljust(16), - ) - dims = ("TSTEP", "VAR", "DATE-TIME") - expected = xr.Variable(dims, data, attrs) - actual = camxfile.variables["TFLAG"] - assert_allclose(expected, actual) - camxfile.close() - - @pytest.mark.xfail(reason="Flaky; see GH3711") - def test_uamiv_format_write(self) -> None: - fmtkw = {"format": "uamiv"} - - expected = open_example_dataset( - "example.uamiv", engine="pseudonetcdf", backend_kwargs=fmtkw - ) - with self.roundtrip( - expected, - save_kwargs=fmtkw, - open_kwargs={"backend_kwargs": fmtkw}, - allow_cleanup_failure=True, - ) as actual: - assert_identical(expected, actual) - - expected.close() - - def save(self, dataset, path, **save_kwargs): - import PseudoNetCDF as pnc - - pncf = pnc.PseudoNetCDFFile() - pncf.dimensions = { - k: pnc.PseudoNetCDFDimension(pncf, k, v) for k, v in dataset.dims.items() - } - pncf.variables = { - k: pnc.PseudoNetCDFVariable( - pncf, k, v.dtype.char, v.dims, values=v.data[...], **v.attrs - ) - for k, v in dataset.variables.items() - } - for pk, pv in dataset.attrs.items(): - setattr(pncf, pk, pv) - - pnc.pncwrite(pncf, path, **save_kwargs) - - -@requires_rasterio -@contextlib.contextmanager -def create_tmp_geotiff( - nx=4, - ny=3, - nz=3, - transform=None, - transform_args=default_value, - crs=default_value, - open_kwargs=None, - additional_attrs=None, -): - if transform_args is default_value: - transform_args = [5000, 80000, 1000, 2000.0] - if crs is default_value: - crs = { - "units": "m", - "no_defs": True, - "ellps": "WGS84", - "proj": "utm", - "zone": 18, - } - # yields a temporary geotiff file and a corresponding expected DataArray - import rasterio - from rasterio.transform import from_origin - - if open_kwargs is None: - open_kwargs = {} - - with create_tmp_file(suffix=".tif", allow_cleanup_failure=ON_WINDOWS) as tmp_file: - # allow 2d or 3d shapes - if nz == 1: - data_shape = ny, nx - write_kwargs = {"indexes": 1} - else: - data_shape = nz, ny, nx - write_kwargs = {} - data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape) - if transform is None: - transform = from_origin(*transform_args) - if additional_attrs is None: - additional_attrs = { - "descriptions": tuple(f"d{n + 1}" for n in range(nz)), - "units": tuple(f"u{n + 1}" for n in range(nz)), - } - with rasterio.open( - tmp_file, - "w", - driver="GTiff", - height=ny, - width=nx, - count=nz, - crs=crs, - transform=transform, - dtype=rasterio.float32, - **open_kwargs, - ) as s: - for attr, val in additional_attrs.items(): - setattr(s, attr, val) - s.write(data, **write_kwargs) - dx, dy = s.res[0], -s.res[1] - - a, b, c, d = transform_args - data = data[np.newaxis, ...] if nz == 1 else data - expected = DataArray( - data, - dims=("band", "y", "x"), - coords={ - "band": np.arange(nz) + 1, - "y": -np.arange(ny) * d + b + dy / 2, - "x": np.arange(nx) * c + a + dx / 2, - }, - ) - yield tmp_file, expected - - -@requires_rasterio -class TestRasterio: - @requires_scipy_or_netCDF4 - def test_serialization(self) -> None: - with create_tmp_geotiff(additional_attrs={}) as (tmp_file, expected): - # Write it to a netcdf and read again (roundtrip) - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - with create_tmp_file(suffix=".nc") as tmp_nc_file: - rioda.to_netcdf(tmp_nc_file) - with xr.open_dataarray(tmp_nc_file) as ncds: - assert_identical(rioda, ncds) - - def test_utm(self) -> None: - with create_tmp_geotiff() as (tmp_file, expected): - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert_allclose(rioda, expected) - assert rioda.attrs["scales"] == (1.0, 1.0, 1.0) - assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0) - assert rioda.attrs["descriptions"] == ("d1", "d2", "d3") - assert rioda.attrs["units"] == ("u1", "u2", "u3") - assert isinstance(rioda.attrs["crs"], str) - assert isinstance(rioda.attrs["res"], tuple) - assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 - np.testing.assert_array_equal( - rioda.attrs["nodatavals"], [np.NaN, np.NaN, np.NaN] - ) - - # Check no parse coords - with pytest.warns(DeprecationWarning), xr.open_rasterio( - tmp_file, parse_coordinates=False - ) as rioda: - assert "x" not in rioda.coords - assert "y" not in rioda.coords - - def test_non_rectilinear(self) -> None: - from rasterio.transform import from_origin - - # Create a geotiff file with 2d coordinates - with create_tmp_geotiff( - transform=from_origin(0, 3, 1, 1).rotation(45), crs=None - ) as (tmp_file, _): - # Default is to not parse coords - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert "x" not in rioda.coords - assert "y" not in rioda.coords - assert "crs" not in rioda.attrs - assert rioda.attrs["scales"] == (1.0, 1.0, 1.0) - assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0) - assert rioda.attrs["descriptions"] == ("d1", "d2", "d3") - assert rioda.attrs["units"] == ("u1", "u2", "u3") - assert isinstance(rioda.attrs["res"], tuple) - assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 - - # See if a warning is raised if we force it - with pytest.warns(Warning, match="transformation isn't rectilinear"): - with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda: - assert "x" not in rioda.coords - assert "y" not in rioda.coords - - def test_platecarree(self) -> None: - with create_tmp_geotiff( - 8, - 10, - 1, - transform_args=[1, 2, 0.5, 2.0], - crs="+proj=latlong", - open_kwargs={"nodata": -9765}, - ) as (tmp_file, expected): - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert_allclose(rioda, expected) - assert rioda.attrs["scales"] == (1.0,) - assert rioda.attrs["offsets"] == (0.0,) - assert isinstance(rioda.attrs["descriptions"], tuple) - assert isinstance(rioda.attrs["units"], tuple) - assert isinstance(rioda.attrs["crs"], str) - assert isinstance(rioda.attrs["res"], tuple) - assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 - np.testing.assert_array_equal(rioda.attrs["nodatavals"], [-9765.0]) - - # rasterio throws a Warning, which is expected since we test rasterio's defaults - @pytest.mark.filterwarnings("ignore:Dataset has no geotransform") - def test_notransform(self) -> None: - # regression test for https://github.com/pydata/xarray/issues/1686 - - import rasterio - - # Create a geotiff file - with create_tmp_file(suffix=".tif") as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx * ny * nz, dtype=rasterio.float32).reshape(nz, ny, nx) - with rasterio.open( - tmp_file, - "w", - driver="GTiff", - height=ny, - width=nx, - count=nz, - dtype=rasterio.float32, - ) as s: - s.descriptions = ("nx", "ny", "nz") - s.units = ("cm", "m", "km") - s.write(data) - - # Tests - expected = DataArray( - data, - dims=("band", "y", "x"), - coords={ - "band": [1, 2, 3], - "y": [0.5, 1.5, 2.5], - "x": [0.5, 1.5, 2.5, 3.5], - }, - ) - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert_allclose(rioda, expected) - assert rioda.attrs["scales"] == (1.0, 1.0, 1.0) - assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0) - assert rioda.attrs["descriptions"] == ("nx", "ny", "nz") - assert rioda.attrs["units"] == ("cm", "m", "km") - assert isinstance(rioda.attrs["res"], tuple) - assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 - - def test_indexing(self) -> None: - with create_tmp_geotiff( - 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" - ) as (tmp_file, expected): - with pytest.warns(DeprecationWarning), xr.open_rasterio( - tmp_file, cache=False - ) as actual: - # tests - # assert_allclose checks all data + coordinates - assert_allclose(actual, expected) - assert not actual.variable._in_memory - - # Basic indexer - ind = {"x": slice(2, 5), "y": slice(5, 7)} - assert_allclose(expected.isel(**ind), actual.isel(**ind)) - assert not actual.variable._in_memory - - ind2 = {"band": slice(1, 2), "x": slice(2, 5), "y": slice(5, 7)} - assert_allclose(expected.isel(**ind2), actual.isel(**ind2)) - assert not actual.variable._in_memory - - ind3 = {"band": slice(1, 2), "x": slice(2, 5), "y": 0} - assert_allclose(expected.isel(**ind3), actual.isel(**ind3)) - assert not actual.variable._in_memory - - # orthogonal indexer - ind4 = { - "band": np.array([2, 1, 0]), - "x": np.array([1, 0]), - "y": np.array([0, 2]), - } - assert_allclose(expected.isel(**ind4), actual.isel(**ind4)) - assert not actual.variable._in_memory - - ind5 = {"band": np.array([2, 1, 0]), "x": np.array([1, 0]), "y": 0} - assert_allclose(expected.isel(**ind5), actual.isel(**ind5)) - assert not actual.variable._in_memory - - ind6 = {"band": 0, "x": np.array([0, 0]), "y": np.array([1, 1, 1])} - assert_allclose(expected.isel(**ind6), actual.isel(**ind6)) - assert not actual.variable._in_memory - - # minus-stepped slice - ind7 = {"band": np.array([2, 1, 0]), "x": slice(-1, None, -1), "y": 0} - assert_allclose(expected.isel(**ind7), actual.isel(**ind7)) - assert not actual.variable._in_memory - - ind8 = {"band": np.array([2, 1, 0]), "x": 1, "y": slice(-1, 1, -2)} - assert_allclose(expected.isel(**ind8), actual.isel(**ind8)) - assert not actual.variable._in_memory - - # empty selection - ind9 = {"band": np.array([2, 1, 0]), "x": 1, "y": slice(2, 2, 1)} - assert_allclose(expected.isel(**ind9), actual.isel(**ind9)) - assert not actual.variable._in_memory - - ind10 = {"band": slice(0, 0), "x": 1, "y": 2} - assert_allclose(expected.isel(**ind10), actual.isel(**ind10)) - assert not actual.variable._in_memory - - # vectorized indexer - ind11 = { - "band": DataArray([2, 1, 0], dims="a"), - "x": DataArray([1, 0, 0], dims="a"), - "y": np.array([0, 2]), - } - assert_allclose(expected.isel(**ind11), actual.isel(**ind11)) - assert not actual.variable._in_memory - - ind12 = { - "band": DataArray([[2, 1, 0], [1, 0, 2]], dims=["a", "b"]), - "x": DataArray([[1, 0, 0], [0, 1, 0]], dims=["a", "b"]), - "y": 0, - } - assert_allclose(expected.isel(**ind12), actual.isel(**ind12)) - assert not actual.variable._in_memory - - # Selecting lists of bands is fine - ex = expected.isel(band=[1, 2]) - ac = actual.isel(band=[1, 2]) - assert_allclose(ac, ex) - ex = expected.isel(band=[0, 2]) - ac = actual.isel(band=[0, 2]) - assert_allclose(ac, ex) - - # Integer indexing - ex = expected.isel(band=1) - ac = actual.isel(band=1) - assert_allclose(ac, ex) - - ex = expected.isel(x=1, y=2) - ac = actual.isel(x=1, y=2) - assert_allclose(ac, ex) - - ex = expected.isel(band=0, x=1, y=2) - ac = actual.isel(band=0, x=1, y=2) - assert_allclose(ac, ex) - - # Mixed - ex = actual.isel(x=slice(2), y=slice(2)) - ac = actual.isel(x=[0, 1], y=[0, 1]) - assert_allclose(ac, ex) - - ex = expected.isel(band=0, x=1, y=slice(5, 7)) - ac = actual.isel(band=0, x=1, y=slice(5, 7)) - assert_allclose(ac, ex) - - ex = expected.isel(band=0, x=slice(2, 5), y=2) - ac = actual.isel(band=0, x=slice(2, 5), y=2) - assert_allclose(ac, ex) - - # One-element lists - ex = expected.isel(band=[0], x=slice(2, 5), y=[2]) - ac = actual.isel(band=[0], x=slice(2, 5), y=[2]) - assert_allclose(ac, ex) - - def test_caching(self) -> None: - with create_tmp_geotiff( - 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" - ) as (tmp_file, expected): - # Cache is the default - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as actual: - # This should cache everything - assert_allclose(actual, expected) - - # once cached, non-windowed indexing should become possible - ac = actual.isel(x=[2, 4]) - ex = expected.isel(x=[2, 4]) - assert_allclose(ac, ex) - - @requires_dask - def test_chunks(self) -> None: - with create_tmp_geotiff( - 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" - ) as (tmp_file, expected): - # Chunk at open time - with pytest.warns(DeprecationWarning), xr.open_rasterio( - tmp_file, chunks=(1, 2, 2) - ) as actual: - import dask.array as da - - assert isinstance(actual.data, da.Array) - assert "open_rasterio" in actual.data.name - - # do some arithmetic - ac = actual.mean() - ex = expected.mean() - assert_allclose(ac, ex) - - ac = actual.sel(band=1).mean(dim="x") - ex = expected.sel(band=1).mean(dim="x") - assert_allclose(ac, ex) - - @pytest.mark.xfail( - not has_dask, reason="without dask, a non-serializable lock is used" - ) - def test_pickle_rasterio(self) -> None: - # regression test for https://github.com/pydata/xarray/issues/2121 - with create_tmp_geotiff() as (tmp_file, expected): - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - temp = pickle.dumps(rioda) - with pickle.loads(temp) as actual: - assert_equal(actual, rioda) - - def test_ENVI_tags(self) -> None: - rasterio = pytest.importorskip("rasterio") - from rasterio.transform import from_origin - - # Create an ENVI file with some tags in the ENVI namespace - # this test uses a custom driver, so we can't use create_tmp_geotiff - with create_tmp_file(suffix=".dat") as tmp_file: - # data - nx, ny, nz = 4, 3, 3 - data = np.arange(nx * ny * nz, dtype=rasterio.float32).reshape(nz, ny, nx) - transform = from_origin(5000, 80000, 1000, 2000.0) - with rasterio.open( - tmp_file, - "w", - driver="ENVI", - height=ny, - width=nx, - count=nz, - crs={ - "units": "m", - "no_defs": True, - "ellps": "WGS84", - "proj": "utm", - "zone": 18, - }, - transform=transform, - dtype=rasterio.float32, - ) as s: - s.update_tags( - ns="ENVI", - description="{Tagged file}", - wavelength="{123.000000, 234.234000, 345.345678}", - fwhm="{1.000000, 0.234000, 0.000345}", - ) - s.write(data) - dx, dy = s.res[0], -s.res[1] - - # Tests - coords = { - "band": [1, 2, 3], - "y": -np.arange(ny) * 2000 + 80000 + dy / 2, - "x": np.arange(nx) * 1000 + 5000 + dx / 2, - "wavelength": ("band", np.array([123, 234.234, 345.345678])), - "fwhm": ("band", np.array([1, 0.234, 0.000345])), - } - expected = DataArray(data, dims=("band", "y", "x"), coords=coords) - - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert_allclose(rioda, expected) - assert isinstance(rioda.attrs["crs"], str) - assert isinstance(rioda.attrs["res"], tuple) - assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 - # from ENVI tags - assert isinstance(rioda.attrs["description"], str) - assert isinstance(rioda.attrs["map_info"], str) - assert isinstance(rioda.attrs["samples"], str) - - def test_geotiff_tags(self) -> None: - # Create a geotiff file with some tags - with create_tmp_geotiff() as (tmp_file, _): - with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: - assert isinstance(rioda.attrs["AREA_OR_POINT"], str) - - @requires_dask - def test_no_mftime(self) -> None: - # rasterio can accept "filename" urguments that are actually urls, - # including paths to remote files. - # In issue #1816, we found that these caused dask to break, because - # the modification time was used to determine the dask token. This - # tests ensure we can still chunk such files when reading with - # rasterio. - with create_tmp_geotiff( - 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" - ) as (tmp_file, expected): - with mock.patch("os.path.getmtime", side_effect=OSError): - with pytest.warns(DeprecationWarning), xr.open_rasterio( - tmp_file, chunks=(1, 2, 2) - ) as actual: - import dask.array as da - - assert isinstance(actual.data, da.Array) - assert_allclose(actual, expected) - - @network - def test_http_url(self) -> None: - # more examples urls here - # http://download.osgeo.org/geotiff/samples/ - url = "http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif" - with pytest.warns(DeprecationWarning), xr.open_rasterio(url) as actual: - assert actual.shape == (1, 512, 512) - # make sure chunking works - with pytest.warns(DeprecationWarning), xr.open_rasterio( - url, chunks=(1, 256, 256) - ) as actual: - import dask.array as da - - assert isinstance(actual.data, da.Array) - - def test_rasterio_environment(self) -> None: - import rasterio - - with create_tmp_geotiff() as (tmp_file, expected): - # Should fail with error since suffix not allowed - with pytest.raises(Exception): - with rasterio.Env(GDAL_SKIP="GTiff"): - with pytest.warns(DeprecationWarning), xr.open_rasterio( - tmp_file - ) as actual: - assert_allclose(actual, expected) - - @pytest.mark.xfail(reason="rasterio 1.1.1 is broken. GH3573") - def test_rasterio_vrt(self) -> None: - import rasterio - - # tmp_file default crs is UTM: CRS({'init': 'epsg:32618'} - with create_tmp_geotiff() as (tmp_file, expected): - with rasterio.open(tmp_file) as src: - with rasterio.vrt.WarpedVRT(src, crs="epsg:4326") as vrt: - expected_shape = (vrt.width, vrt.height) - expected_crs = vrt.crs - expected_res = vrt.res - # Value of single pixel in center of image - lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2) - expected_val = next(vrt.sample([(lon, lat)])) - with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: - actual_shape = (da.sizes["x"], da.sizes["y"]) - actual_crs = da.crs - actual_res = da.res - actual_val = da.sel(dict(x=lon, y=lat), method="nearest").data - - assert actual_crs == expected_crs - assert actual_res == expected_res - assert actual_shape == expected_shape - assert expected_val.all() == actual_val.all() - - @pytest.mark.filterwarnings( - "ignore:open_rasterio is Deprecated in favor of rioxarray." - ) - def test_rasterio_vrt_with_transform_and_size(self) -> None: - # Test open_rasterio() support of WarpedVRT with transform, width and - # height (issue #2864) - - rasterio = pytest.importorskip("rasterio") - from affine import Affine - from rasterio.warp import calculate_default_transform - - with create_tmp_geotiff() as (tmp_file, expected): - with rasterio.open(tmp_file) as src: - # Estimate the transform, width and height - # for a change of resolution - # tmp_file initial res is (1000,2000) (default values) - trans, w, h = calculate_default_transform( - src.crs, src.crs, src.width, src.height, resolution=500, *src.bounds - ) - with rasterio.vrt.WarpedVRT( - src, transform=trans, width=w, height=h - ) as vrt: - expected_shape = (vrt.width, vrt.height) - expected_res = vrt.res - expected_transform = vrt.transform - with xr.open_rasterio(vrt) as da: - actual_shape = (da.sizes["x"], da.sizes["y"]) - actual_res = da.res - actual_transform = Affine(*da.transform) - assert actual_res == expected_res - assert actual_shape == expected_shape - assert actual_transform == expected_transform - - def test_rasterio_vrt_with_src_crs(self) -> None: - # Test open_rasterio() support of WarpedVRT with specified src_crs - - rasterio = pytest.importorskip("rasterio") - - # create geotiff with no CRS and specify it manually - with create_tmp_geotiff(crs=None) as (tmp_file, expected): - src_crs = rasterio.crs.CRS({"init": "epsg:32618"}) - with rasterio.open(tmp_file) as src: - assert src.crs is None - with rasterio.vrt.WarpedVRT(src, src_crs=src_crs) as vrt: - with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: - assert da.crs == src_crs - - class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self) -> None: var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"}) @@ -4982,7 +4760,7 @@ def test_extract_nc4_variable_encoding(self) -> None: assert {} == encoding @requires_netCDF4 - def test_extract_nc4_variable_encoding_netcdf4(self, monkeypatch): + def test_extract_nc4_variable_encoding_netcdf4(self): # New netCDF4 1.6.0 compression argument. var = xr.Variable(("x",), [1, 2, 3], {}, {"compression": "szlib"}) _extract_nc4_variable_encoding(var, backend="netCDF4", raise_on_invalid=True) @@ -5148,6 +4926,56 @@ def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: assert_identical(original_da, loaded_da) +@requires_zarr +class TestDataArrayToZarr: + def test_dataarray_to_zarr_no_name(self, tmp_store) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4))) + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_zarr_with_name(self, tmp_store) -> None: + original_da = DataArray(np.arange(12).reshape((3, 4)), name="test") + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_dataarray_to_zarr_coord_name_clash(self, tmp_store) -> None: + original_da = DataArray( + np.arange(12).reshape((3, 4)), dims=["x", "y"], name="x" + ) + + original_da.to_zarr(tmp_store) + + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + def test_open_dataarray_options(self, tmp_store) -> None: + data = DataArray(np.arange(5), coords={"y": ("x", range(5))}, dims=["x"]) + + data.to_zarr(tmp_store) + + expected = data.drop_vars("y") + with open_dataarray(tmp_store, engine="zarr", drop_variables=["y"]) as loaded: + assert_identical(expected, loaded) + + @requires_dask + def test_dataarray_to_zarr_compute_false(self, tmp_store) -> None: + from dask.delayed import Delayed + + original_da = DataArray(np.arange(12).reshape((3, 4))) + + output = original_da.to_zarr(tmp_store, compute=False) + assert isinstance(output, Delayed) + output.compute() + with open_dataarray(tmp_store, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get() -> None: with create_tmp_file() as tmpfile: @@ -5424,15 +5252,17 @@ def test_open_fsspec() -> None: ds2 = open_dataset(url, engine="zarr") xr.testing.assert_equal(ds0, ds2) - # multi dataset - url = "memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # open_mfdataset requires dask + if has_dask: + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) - # multi dataset with caching - url = "simplecache::memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) @requires_h5netcdf @@ -5491,7 +5321,7 @@ def test_open_dataset_chunking_zarr(chunks, tmp_path: Path) -> None: @pytest.mark.parametrize( "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}] ) -@pytest.mark.filterwarnings("ignore:The specified Dask chunks separate") +@pytest.mark.filterwarnings("ignore:The specified chunks separate") def test_chunking_consintency(chunks, tmp_path: Path) -> None: encoded_chunks: dict[str, Any] = {} dask_arr = da.from_array( @@ -5580,7 +5410,7 @@ def test_scipy_entrypoint(tmp_path: Path) -> None: assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc.gz") assert not entrypoint.guess_can_open("not-found-and-no-extension") - assert not entrypoint.guess_can_open(b"not-a-netcdf-file") + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") # type: ignore[arg-type] @requires_h5netcdf @@ -5629,7 +5459,7 @@ def test_write_file_from_np_str(str_type, tmpdir) -> None: class TestNCZarr: @property def netcdfc_version(self): - return Version(nc4.getlibversion().split()[0]) + return Version(nc4.getlibversion().split()[0].split("-development")[0]) def _create_nczarr(self, filename): if self.netcdfc_version < Version("4.8.1"): @@ -5678,5 +5508,215 @@ def test_raise_writing_to_nczarr(self, mode) -> None: @requires_netCDF4 @requires_dask def test_pickle_open_mfdataset_dataset(): - ds = open_example_mfdataset(["bears.nc"]) - assert_identical(ds, pickle.loads(pickle.dumps(ds))) + with open_example_mfdataset(["bears.nc"]) as ds: + assert_identical(ds, pickle.loads(pickle.dumps(ds))) + + +@requires_zarr +def test_zarr_closing_internal_zip_store(): + store_name = "tmp.zarr.zip" + original_da = DataArray(np.arange(12).reshape((3, 4))) + original_da.to_zarr(store_name, mode="w") + + with open_dataarray(store_name, engine="zarr") as loaded_da: + assert_identical(original_da, loaded_da) + + +@requires_zarr +class TestZarrRegionAuto: + def test_zarr_region_auto_all(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + ds_region.to_zarr(tmp_path / "test.zarr", region="auto") + + ds_updated = xr.open_zarr(tmp_path / "test.zarr") + + expected = ds.copy() + expected["test"][2:4, 6:8] += 1 + assert_identical(ds_updated, expected) + + def test_zarr_region_auto_mixed(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + ds_region.to_zarr( + tmp_path / "test.zarr", region={"x": "auto", "y": slice(6, 8)} + ) + + ds_updated = xr.open_zarr(tmp_path / "test.zarr") + + expected = ds.copy() + expected["test"][2:4, 6:8] += 1 + assert_identical(ds_updated, expected) + + def test_zarr_region_auto_noncontiguous(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_region = 1 + ds.isel(x=[0, 2, 3], y=[5, 6]) + with pytest.raises(ValueError): + ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) + + def test_zarr_region_auto_new_coord_vals(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + x = np.arange(5, 55, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + with pytest.raises(KeyError): + ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"}) + + def test_zarr_region_index_write(self, tmp_path): + from xarray.backends.zarr import ZarrStore + + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + + ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8)) + + ds.to_zarr(tmp_path / "test.zarr") + + with patch.object( + ZarrStore, + "set_variables", + side_effect=ZarrStore.set_variables, + autospec=True, + ) as mock: + ds_region.to_zarr(tmp_path / "test.zarr", region="auto", mode="r+") + + # should write the data vars but never the index vars with auto mode + for call in mock.call_args_list: + written_variables = call.args[1].keys() + assert "test" in written_variables + assert "x" not in written_variables + assert "y" not in written_variables + + def test_zarr_region_append(self, tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + x_new = np.arange(40, 70, 10) + data_new = np.ones((3, 10)) + ds_new = xr.Dataset( + { + "test": xr.DataArray( + data_new, + dims=("x", "y"), + coords={"x": x_new, "y": y}, + ) + } + ) + + # Don't allow auto region detection in append mode due to complexities in + # implementing the overlap logic and lack of safety with parallel writes + with pytest.raises(ValueError): + ds_new.to_zarr( + tmp_path / "test.zarr", mode="a", append_dim="x", region="auto" + ) + + +@requires_zarr +def test_zarr_region(tmp_path): + x = np.arange(0, 50, 10) + y = np.arange(0, 20, 2) + data = np.ones((5, 10)) + ds = xr.Dataset( + { + "test": xr.DataArray( + data, + dims=("x", "y"), + coords={"x": x, "y": y}, + ) + } + ) + ds.to_zarr(tmp_path / "test.zarr") + + ds_transposed = ds.transpose("y", "x") + + ds_region = 1 + ds_transposed.isel(x=[0], y=[0]) + ds_region.to_zarr( + tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)} + ) + + # Write without region + ds_transposed.to_zarr(tmp_path / "test.zarr", mode="r+") diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index befc4cbaf04..d4f8b7ed31d 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -79,11 +79,13 @@ def explicit_chunks(chunks, shape): # Emulate `dask.array.core.normalize_chunks` but for simpler inputs. return tuple( ( - (size // chunk) * (chunk,) - + ((size % chunk,) if size % chunk or size == 0 else ()) + ( + (size // chunk) * (chunk,) + + ((size % chunk,) if size % chunk or size == 0 else ()) + ) + if isinstance(chunk, Number) + else chunk ) - if isinstance(chunk, Number) - else chunk for chunk, size in zip(chunks, shape) ) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py new file mode 100644 index 00000000000..7bdb2b532d9 --- /dev/null +++ b/xarray/tests/test_backends_datatree.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from xarray.backends.api import open_datatree +from xarray.datatree_.datatree.testing import assert_equal +from xarray.tests import ( + requires_h5netcdf, + requires_netCDF4, + requires_zarr, +) + +if TYPE_CHECKING: + from xarray.backends.api import T_NetcdfEngine + + +class DatatreeIOBase: + engine: T_NetcdfEngine | None = None + + def test_to_netcdf(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.nc" + original_dt = simple_datatree + original_dt.to_netcdf(filepath, engine=self.engine) + + roundtrip_dt = open_datatree(filepath, engine=self.engine) + assert_equal(original_dt, roundtrip_dt) + + def test_netcdf_encoding(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.nc" + original_dt = simple_datatree + + # add compression + comp = dict(zlib=True, complevel=9) + enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}} + + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + roundtrip_dt = open_datatree(filepath, engine=self.engine) + + assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"] + assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"] + + enc["/not/a/group"] = {"foo": "bar"} # type: ignore + with pytest.raises(ValueError, match="unexpected encoding group.*"): + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + + +@requires_netCDF4 +class TestNetCDF4DatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "netcdf4" + + +@requires_h5netcdf +class TestH5NetCDFDatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "h5netcdf" + + +@requires_zarr +class TestZarrDatatreeIO: + engine = "zarr" + + def test_to_zarr(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.zarr" + original_dt = simple_datatree + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_zarr_encoding(self, tmpdir, simple_datatree): + import zarr + + filepath = tmpdir / "test.zarr" + original_dt = simple_datatree + + comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)} + enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}} + original_dt.to_zarr(filepath, encoding=enc) + roundtrip_dt = open_datatree(filepath, engine="zarr") + + print(roundtrip_dt["/set2/a"].encoding) + assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"] + + enc["/not/a/group"] = {"foo": "bar"} # type: ignore + with pytest.raises(ValueError, match="unexpected encoding group.*"): + original_dt.to_zarr(filepath, encoding=enc, engine="zarr") + + def test_to_zarr_zip_store(self, tmpdir, simple_datatree): + from zarr.storage import ZipStore + + filepath = tmpdir / "test.zarr.zip" + original_dt = simple_datatree + store = ZipStore(filepath) + original_dt.to_zarr(store) + + roundtrip_dt = open_datatree(store, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree): + filepath = tmpdir / "test.zarr" + zmetadata = filepath / ".zmetadata" + s1zmetadata = filepath / "set1" / ".zmetadata" + filepath = str(filepath) # casting to str avoids a pathlib bug in xarray + original_dt = simple_datatree + original_dt.to_zarr(filepath, consolidated=False) + assert not zmetadata.exists() + assert not s1zmetadata.exists() + + with pytest.warns(RuntimeWarning, match="consolidated"): + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + + def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): + import zarr + + simple_datatree.to_zarr(tmpdir) + + # with default settings, to_zarr should not overwrite an existing dir + with pytest.raises(zarr.errors.ContainsGroupError): + simple_datatree.to_zarr(tmpdir) diff --git a/xarray/tests/test_calendar_ops.py b/xarray/tests/test_calendar_ops.py index d118ccf4556..d2792034876 100644 --- a/xarray/tests/test_calendar_ops.py +++ b/xarray/tests/test_calendar_ops.py @@ -18,7 +18,7 @@ ("standard", "noleap", None, "D"), ("noleap", "proleptic_gregorian", True, "D"), ("noleap", "all_leap", None, "D"), - ("all_leap", "proleptic_gregorian", False, "4H"), + ("all_leap", "proleptic_gregorian", False, "4h"), ], ) def test_convert_calendar(source, target, use_cftime, freq): @@ -67,7 +67,7 @@ def test_convert_calendar(source, target, use_cftime, freq): [ ("standard", "360_day", "D"), ("360_day", "proleptic_gregorian", "D"), - ("proleptic_gregorian", "360_day", "4H"), + ("proleptic_gregorian", "360_day", "4h"), ], ) @pytest.mark.parametrize("align_on", ["date", "year"]) @@ -87,17 +87,17 @@ def test_convert_calendar_360_days(source, target, freq, align_on): if align_on == "date": np.testing.assert_array_equal( - conv.time.resample(time="M").last().dt.day, + conv.time.resample(time="ME").last().dt.day, [30, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30], ) elif target == "360_day": np.testing.assert_array_equal( - conv.time.resample(time="M").last().dt.day, + conv.time.resample(time="ME").last().dt.day, [30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 29], ) else: np.testing.assert_array_equal( - conv.time.resample(time="M").last().dt.day, + conv.time.resample(time="ME").last().dt.day, [30, 29, 30, 30, 31, 30, 30, 31, 30, 31, 29, 31], ) if source == "360_day" and align_on == "year": @@ -111,8 +111,8 @@ def test_convert_calendar_360_days(source, target, freq, align_on): "source,target,freq", [ ("standard", "noleap", "D"), - ("noleap", "proleptic_gregorian", "4H"), - ("noleap", "all_leap", "M"), + ("noleap", "proleptic_gregorian", "4h"), + ("noleap", "all_leap", "ME"), ("360_day", "noleap", "D"), ("noleap", "360_day", "D"), ], @@ -132,7 +132,9 @@ def test_convert_calendar_missing(source, target, freq): np.linspace(0, 1, src.size), dims=("time",), coords={"time": src} ) out = convert_calendar(da_src, target, missing=np.nan, align_on="date") - assert infer_freq(out.time) == freq + + expected_freq = freq + assert infer_freq(out.time) == expected_freq expected = date_range( "2004-01-01", @@ -142,7 +144,7 @@ def test_convert_calendar_missing(source, target, freq): ) np.testing.assert_array_equal(out.time, expected) - if freq != "M": + if freq != "ME": out_without_missing = convert_calendar(da_src, target, align_on="date") expected_nan = out.isel(time=~out.time.isin(out_without_missing.time)) assert expected_nan.isnull().all() @@ -181,7 +183,7 @@ def test_convert_calendar_errors(): def test_convert_calendar_same_calendar(): src = DataArray( - date_range("2000-01-01", periods=12, freq="6H", use_cftime=False), + date_range("2000-01-01", periods=12, freq="6h", use_cftime=False), dims=("time",), name="time", ) diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 6b628c15488..0110afe40ac 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -1,6 +1,7 @@ from __future__ import annotations from itertools import product +from typing import Callable, Literal import numpy as np import pandas as pd @@ -24,6 +25,8 @@ YearBegin, YearEnd, _days_in_month, + _legacy_to_new_freq, + _new_to_legacy_freq, cftime_range, date_range, date_range_like, @@ -33,7 +36,13 @@ ) from xarray.coding.frequencies import infer_freq from xarray.core.dataarray import DataArray -from xarray.tests import _CFTIME_CALENDARS, has_cftime, requires_cftime +from xarray.tests import ( + _CFTIME_CALENDARS, + assert_no_warnings, + has_cftime, + has_pandas_ge_2_2, + requires_cftime, +) cftime = pytest.importorskip("cftime") @@ -153,8 +162,17 @@ def test_year_offset_constructor_invalid_month(offset, invalid_month, exception) [ (BaseCFTimeOffset(), None), (MonthBegin(), "MS"), - (YearBegin(), "AS-JAN"), + (MonthEnd(), "ME"), + (YearBegin(), "YS-JAN"), + (YearEnd(), "YE-DEC"), (QuarterBegin(), "QS-MAR"), + (QuarterEnd(), "QE-MAR"), + (Day(), "D"), + (Hour(), "h"), + (Minute(), "min"), + (Second(), "s"), + (Millisecond(), "ms"), + (Microsecond(), "us"), ], ids=_id_func, ) @@ -190,12 +208,16 @@ def test_to_offset_offset_input(offset): [ ("M", MonthEnd()), ("2M", MonthEnd(n=2)), + ("ME", MonthEnd()), + ("2ME", MonthEnd(n=2)), ("MS", MonthBegin()), ("2MS", MonthBegin(n=2)), ("D", Day()), ("2D", Day(n=2)), ("H", Hour()), ("2H", Hour(n=2)), + ("h", Hour()), + ("2h", Hour(n=2)), ("T", Minute()), ("2T", Minute(n=2)), ("min", Minute()), @@ -210,21 +232,43 @@ def test_to_offset_offset_input(offset): ("2U", Microsecond(n=2)), ("us", Microsecond(n=1)), ("2us", Microsecond(n=2)), + # negative + ("-2M", MonthEnd(n=-2)), + ("-2ME", MonthEnd(n=-2)), + ("-2MS", MonthBegin(n=-2)), + ("-2D", Day(n=-2)), + ("-2H", Hour(n=-2)), + ("-2h", Hour(n=-2)), + ("-2T", Minute(n=-2)), + ("-2min", Minute(n=-2)), + ("-2S", Second(n=-2)), + ("-2L", Millisecond(n=-2)), + ("-2ms", Millisecond(n=-2)), + ("-2U", Microsecond(n=-2)), + ("-2us", Microsecond(n=-2)), ], ids=_id_func, ) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "M" etc. def test_to_offset_sub_annual(freq, expected): assert to_offset(freq) == expected -_ANNUAL_OFFSET_TYPES = {"A": YearEnd, "AS": YearBegin} +_ANNUAL_OFFSET_TYPES = { + "A": YearEnd, + "AS": YearBegin, + "Y": YearEnd, + "YS": YearBegin, + "YE": YearEnd, +} @pytest.mark.parametrize( ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] ) -@pytest.mark.parametrize("multiple", [None, 2]) -@pytest.mark.parametrize("offset_str", ["AS", "A"]) +@pytest.mark.parametrize("multiple", [None, 2, -1]) +@pytest.mark.parametrize("offset_str", ["AS", "A", "YS", "Y"]) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "A" etc. def test_to_offset_annual(month_label, month_int, multiple, offset_str): freq = offset_str offset_type = _ANNUAL_OFFSET_TYPES[offset_str] @@ -245,14 +289,15 @@ def test_to_offset_annual(month_label, month_int, multiple, offset_str): assert result == expected -_QUARTER_OFFSET_TYPES = {"Q": QuarterEnd, "QS": QuarterBegin} +_QUARTER_OFFSET_TYPES = {"Q": QuarterEnd, "QS": QuarterBegin, "QE": QuarterEnd} @pytest.mark.parametrize( ("month_int", "month_label"), list(_MONTH_ABBREVIATIONS.items()) + [(0, "")] ) -@pytest.mark.parametrize("multiple", [None, 2]) -@pytest.mark.parametrize("offset_str", ["QS", "Q"]) +@pytest.mark.parametrize("multiple", [None, 2, -1]) +@pytest.mark.parametrize("offset_str", ["QS", "Q", "QE"]) +@pytest.mark.filterwarnings("ignore::FutureWarning") # Deprecation of "Q" etc. def test_to_offset_quarter(month_label, month_int, multiple, offset_str): freq = offset_str offset_type = _QUARTER_OFFSET_TYPES[offset_str] @@ -385,6 +430,7 @@ def test_eq(a, b): _MUL_TESTS = [ (BaseCFTimeOffset(), 3, BaseCFTimeOffset(n=3)), + (BaseCFTimeOffset(), -3, BaseCFTimeOffset(n=-3)), (YearEnd(), 3, YearEnd(n=3)), (YearBegin(), 3, YearBegin(n=3)), (QuarterEnd(), 3, QuarterEnd(n=3)), @@ -400,6 +446,7 @@ def test_eq(a, b): (Microsecond(), 3, Microsecond(n=3)), (Day(), 0.5, Hour(n=12)), (Hour(), 0.5, Minute(n=30)), + (Hour(), -0.5, Minute(n=-30)), (Minute(), 0.5, Second(n=30)), (Second(), 0.5, Millisecond(n=500)), (Millisecond(), 0.5, Microsecond(n=500)), @@ -1129,7 +1176,7 @@ def test_rollback(calendar, offset, initial_date_args, partial_expected_date_arg "0001-01-30", "0011-02-01", None, - "3AS-JUN", + "3YS-JUN", "both", False, [(1, 6, 1), (4, 6, 1), (7, 6, 1), (10, 6, 1)], @@ -1144,6 +1191,15 @@ def test_rollback(calendar, offset, initial_date_args, partial_expected_date_arg False, [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)], ), + ( + "0010", + None, + 4, + "-2YS", + "both", + False, + [(10, 1, 1), (8, 1, 1), (6, 1, 1), (4, 1, 1)], + ), ( "0001-01-01", "0001-01-04", @@ -1162,6 +1218,24 @@ def test_rollback(calendar, offset, initial_date_args, partial_expected_date_arg False, [(1, 6, 1), (2, 3, 1), (2, 12, 1), (3, 9, 1)], ), + ( + "0001-06-01", + None, + 4, + "-1MS", + "both", + False, + [(1, 6, 1), (1, 5, 1), (1, 4, 1), (1, 3, 1)], + ), + ( + "0001-01-30", + None, + 4, + "-1D", + "both", + False, + [(1, 1, 30), (1, 1, 29), (1, 1, 28), (1, 1, 27)], + ), ] @@ -1215,47 +1289,81 @@ def test_cftime_range_name(): @pytest.mark.parametrize( - ("start", "end", "periods", "freq", "closed"), + ("start", "end", "periods", "freq", "inclusive"), [ - (None, None, 5, "A", None), - ("2000", None, None, "A", None), - (None, "2000", None, "A", None), - ("2000", "2001", None, None, None), + (None, None, 5, "YE", None), + ("2000", None, None, "YE", None), + (None, "2000", None, "YE", None), (None, None, None, None, None), - ("2000", "2001", None, "A", "up"), - ("2000", "2001", 5, "A", None), + ("2000", "2001", None, "YE", "up"), + ("2000", "2001", 5, "YE", None), ], ) -def test_invalid_cftime_range_inputs(start, end, periods, freq, closed): +def test_invalid_cftime_range_inputs( + start: str | None, + end: str | None, + periods: int | None, + freq: str | None, + inclusive: Literal["up", None], +) -> None: with pytest.raises(ValueError): - cftime_range(start, end, periods, freq, closed=closed) + cftime_range(start, end, periods, freq, inclusive=inclusive) # type: ignore[arg-type] + + +def test_invalid_cftime_arg() -> None: + with pytest.warns( + FutureWarning, match="Following pandas, the `closed` parameter is deprecated" + ): + cftime_range("2000", "2001", None, "YE", closed="left") _CALENDAR_SPECIFIC_MONTH_END_TESTS = [ - ("2M", "noleap", [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ("2M", "all_leap", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ("2M", "360_day", [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), - ("2M", "standard", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ("2M", "gregorian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), - ("2M", "julian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("noleap", [(2, 28), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("all_leap", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("360_day", [(2, 30), (4, 30), (6, 30), (8, 30), (10, 30), (12, 30)]), + ("standard", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("gregorian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), + ("julian", [(2, 29), (4, 30), (6, 30), (8, 31), (10, 31), (12, 31)]), ] @pytest.mark.parametrize( - ("freq", "calendar", "expected_month_day"), + ("calendar", "expected_month_day"), _CALENDAR_SPECIFIC_MONTH_END_TESTS, ids=_id_func, ) -def test_calendar_specific_month_end(freq, calendar, expected_month_day): +def test_calendar_specific_month_end( + calendar: str, expected_month_day: list[tuple[int, int]] +) -> None: year = 2000 # Use a leap-year to highlight calendar differences result = cftime_range( - start="2000-02", end="2001", freq=freq, calendar=calendar + start="2000-02", end="2001", freq="2ME", calendar=calendar ).values date_type = get_date_type(calendar) expected = [date_type(year, *args) for args in expected_month_day] np.testing.assert_equal(result, expected) +@pytest.mark.parametrize( + ("calendar", "expected_month_day"), + _CALENDAR_SPECIFIC_MONTH_END_TESTS, + ids=_id_func, +) +def test_calendar_specific_month_end_negative_freq( + calendar: str, expected_month_day: list[tuple[int, int]] +) -> None: + year = 2000 # Use a leap-year to highlight calendar differences + result = cftime_range( + start="2000-12", + end="2000", + freq="-2ME", + calendar=calendar, + ).values + date_type = get_date_type(calendar) + expected = [date_type(year, *args) for args in expected_month_day[::-1]] + np.testing.assert_equal(result, expected) + + @pytest.mark.parametrize( ("calendar", "start", "end", "expected_number_of_days"), [ @@ -1273,26 +1381,32 @@ def test_calendar_specific_month_end(freq, calendar, expected_month_day): ("julian", "2001", "2002", 365), ], ) -def test_calendar_year_length(calendar, start, end, expected_number_of_days): - result = cftime_range(start, end, freq="D", closed="left", calendar=calendar) +def test_calendar_year_length( + calendar: str, start: str, end: str, expected_number_of_days: int +) -> None: + result = cftime_range(start, end, freq="D", inclusive="left", calendar=calendar) assert len(result) == expected_number_of_days -@pytest.mark.parametrize("freq", ["A", "M", "D"]) -def test_dayofweek_after_cftime_range(freq): +@pytest.mark.parametrize("freq", ["YE", "ME", "D"]) +def test_dayofweek_after_cftime_range(freq: str) -> None: result = cftime_range("2000-02-01", periods=3, freq=freq).dayofweek + # TODO: remove once requiring pandas 2.2+ + freq = _new_to_legacy_freq(freq) expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofweek np.testing.assert_array_equal(result, expected) -@pytest.mark.parametrize("freq", ["A", "M", "D"]) -def test_dayofyear_after_cftime_range(freq): +@pytest.mark.parametrize("freq", ["YE", "ME", "D"]) +def test_dayofyear_after_cftime_range(freq: str) -> None: result = cftime_range("2000-02-01", periods=3, freq=freq).dayofyear + # TODO: remove once requiring pandas 2.2+ + freq = _new_to_legacy_freq(freq) expected = pd.date_range("2000-02-01", periods=3, freq=freq).dayofyear np.testing.assert_array_equal(result, expected) -def test_cftime_range_standard_calendar_refers_to_gregorian(): +def test_cftime_range_standard_calendar_refers_to_gregorian() -> None: from cftime import DatetimeGregorian (result,) = cftime_range("2000", periods=1) @@ -1310,7 +1424,9 @@ def test_cftime_range_standard_calendar_refers_to_gregorian(): ("3400-01-01", "standard", None, CFTimeIndex), ], ) -def test_date_range(start, calendar, use_cftime, expected_type): +def test_date_range( + start: str, calendar: str, use_cftime: bool | None, expected_type +) -> None: dr = date_range( start, periods=14, freq="D", calendar=calendar, use_cftime=use_cftime ) @@ -1318,7 +1434,7 @@ def test_date_range(start, calendar, use_cftime, expected_type): assert isinstance(dr, expected_type) -def test_date_range_errors(): +def test_date_range_errors() -> None: with pytest.raises(ValueError, match="Date range is invalid"): date_range( "1400-01-01", periods=1, freq="D", calendar="standard", use_cftime=False @@ -1343,20 +1459,28 @@ def test_date_range_errors(): @pytest.mark.parametrize( "start,freq,cal_src,cal_tgt,use_cftime,exp0,exp_pd", [ - ("2020-02-01", "4M", "standard", "noleap", None, "2020-02-28", False), - ("2020-02-01", "M", "noleap", "gregorian", True, "2020-02-29", True), - ("2020-02-28", "3H", "all_leap", "gregorian", False, "2020-02-28", True), - ("2020-03-30", "M", "360_day", "gregorian", False, "2020-03-31", True), - ("2020-03-31", "M", "gregorian", "360_day", None, "2020-03-30", False), + ("2020-02-01", "4ME", "standard", "noleap", None, "2020-02-28", False), + ("2020-02-01", "ME", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-01", "QE-DEC", "noleap", "gregorian", True, "2020-03-31", True), + ("2020-02-01", "YS-FEB", "noleap", "gregorian", True, "2020-02-01", True), + ("2020-02-01", "YE-FEB", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-01", "-1YE-FEB", "noleap", "gregorian", True, "2020-02-29", True), + ("2020-02-28", "3h", "all_leap", "gregorian", False, "2020-02-28", True), + ("2020-03-30", "ME", "360_day", "gregorian", False, "2020-03-31", True), + ("2020-03-31", "ME", "gregorian", "360_day", None, "2020-03-30", False), + ("2020-03-31", "-1ME", "gregorian", "360_day", None, "2020-03-30", False), ], ) def test_date_range_like(start, freq, cal_src, cal_tgt, use_cftime, exp0, exp_pd): + expected_freq = freq + source = date_range(start, periods=12, freq=freq, calendar=cal_src) out = date_range_like(source, cal_tgt, use_cftime=use_cftime) assert len(out) == 12 - assert infer_freq(out) == freq + + assert infer_freq(out) == expected_freq assert out[0].isoformat().startswith(exp0) @@ -1367,12 +1491,28 @@ def test_date_range_like(start, freq, cal_src, cal_tgt, use_cftime, exp0, exp_pd assert out.calendar == cal_tgt +@requires_cftime +@pytest.mark.parametrize( + "freq", ("YE", "YS", "YE-MAY", "MS", "ME", "QS", "h", "min", "s") +) +@pytest.mark.parametrize("use_cftime", (True, False)) +def test_date_range_like_no_deprecation(freq, use_cftime): + # ensure no internal warnings + # TODO: remove once freq string deprecation is finished + + source = date_range("2000", periods=3, freq=freq, use_cftime=False) + + with assert_no_warnings(): + date_range_like(source, "standard", use_cftime=use_cftime) + + def test_date_range_like_same_calendar(): - src = date_range("2000-01-01", periods=12, freq="6H", use_cftime=False) + src = date_range("2000-01-01", periods=12, freq="6h", use_cftime=False) out = date_range_like(src, "standard", use_cftime=False) assert src is out +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_date_range_like_errors(): src = date_range("1899-02-03", periods=20, freq="D", use_cftime=False) src = src[np.arange(20) != 10] # Remove 1 day so the frequency is not inferable. @@ -1411,7 +1551,7 @@ def as_timedelta_not_implemented_error(): @pytest.mark.parametrize("function", [cftime_range, date_range]) -def test_cftime_or_date_range_closed_and_inclusive_error(function) -> None: +def test_cftime_or_date_range_closed_and_inclusive_error(function: Callable) -> None: if function == cftime_range and not has_cftime: pytest.skip("requires cftime") @@ -1420,22 +1560,29 @@ def test_cftime_or_date_range_closed_and_inclusive_error(function) -> None: @pytest.mark.parametrize("function", [cftime_range, date_range]) -def test_cftime_or_date_range_invalid_closed_value(function) -> None: +def test_cftime_or_date_range_invalid_inclusive_value(function: Callable) -> None: if function == cftime_range and not has_cftime: pytest.skip("requires cftime") - with pytest.raises(ValueError, match="Argument `closed` must be"): - function("2000", periods=3, closed="foo") + with pytest.raises(ValueError, match="nclusive"): + function("2000", periods=3, inclusive="foo") -@pytest.mark.parametrize("function", [cftime_range, date_range]) +@pytest.mark.parametrize( + "function", + [ + pytest.param(cftime_range, id="cftime", marks=requires_cftime), + pytest.param(date_range, id="date"), + ], +) @pytest.mark.parametrize( ("closed", "inclusive"), [(None, "both"), ("left", "left"), ("right", "right")] ) -def test_cftime_or_date_range_closed(function, closed, inclusive) -> None: - if function == cftime_range and not has_cftime: - pytest.skip("requires cftime") - +def test_cftime_or_date_range_closed( + function: Callable, + closed: Literal["left", "right", None], + inclusive: Literal["left", "right", "both"], +) -> None: with pytest.warns(FutureWarning, match="Following pandas"): result_closed = function("2000-01-01", "2000-01-04", freq="D", closed=closed) result_inclusive = function( @@ -1452,3 +1599,180 @@ def test_cftime_or_date_range_inclusive_None(function) -> None: result_None = function("2000-01-01", "2000-01-04") result_both = function("2000-01-01", "2000-01-04", inclusive="both") np.testing.assert_equal(result_None.values, result_both.values) + + +@pytest.mark.parametrize( + "freq", ["A", "AS", "Q", "M", "H", "T", "S", "L", "U", "Y", "A-MAY"] +) +def test_to_offset_deprecation_warning(freq): + # Test for deprecations outlined in GitHub issue #8394 + with pytest.warns(FutureWarning, match="is deprecated"): + to_offset(freq) + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + ["Y", "YE"], + ["A", "YE"], + ["Q", "QE"], + ["M", "ME"], + ["AS", "YS"], + ["YE", "YE"], + ["QE", "QE"], + ["ME", "ME"], + ["YS", "YS"], + ), +) +@pytest.mark.parametrize("n", ("", "2")) +def test_legacy_to_new_freq(freq, expected, n): + freq = f"{n}{freq}" + result = _legacy_to_new_freq(freq) + + expected = f"{n}{expected}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.parametrize("year_alias", ("YE", "Y", "A")) +@pytest.mark.parametrize("n", ("", "2")) +def test_legacy_to_new_freq_anchored(year_alias, n): + for month in _MONTH_ABBREVIATIONS.values(): + freq = f"{n}{year_alias}-{month}" + result = _legacy_to_new_freq(freq) + + expected = f"{n}YE-{month}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.filterwarnings("ignore:'[AY]' is deprecated") +@pytest.mark.parametrize( + "freq, expected", + (["A", "A"], ["YE", "A"], ["Y", "A"], ["QE", "Q"], ["ME", "M"], ["YS", "AS"]), +) +@pytest.mark.parametrize("n", ("", "2")) +def test_new_to_legacy_freq(freq, expected, n): + freq = f"{n}{freq}" + result = _new_to_legacy_freq(freq) + + expected = f"{n}{expected}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only relevant for pandas lt 2.2") +@pytest.mark.filterwarnings("ignore:'[AY]-.{3}' is deprecated") +@pytest.mark.parametrize("year_alias", ("A", "Y", "YE")) +@pytest.mark.parametrize("n", ("", "2")) +def test_new_to_legacy_freq_anchored(year_alias, n): + for month in _MONTH_ABBREVIATIONS.values(): + freq = f"{n}{year_alias}-{month}" + result = _new_to_legacy_freq(freq) + + expected = f"{n}A-{month}" + + assert result == expected + + +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + # pandas-only freq strings are passed through + ("BH", "BH"), + ("CBH", "CBH"), + ("N", "N"), + ), +) +def test_legacy_to_new_freq_pd_freq_passthrough(freq, expected): + + result = _legacy_to_new_freq(freq) + assert result == expected + + +@pytest.mark.filterwarnings("ignore:'.' is deprecated ") +@pytest.mark.skipif(has_pandas_ge_2_2, reason="only for pandas lt 2.2") +@pytest.mark.parametrize( + "freq, expected", + ( + # these are each valid in pandas lt 2.2 + ("T", "T"), + ("min", "min"), + ("S", "S"), + ("s", "s"), + ("L", "L"), + ("ms", "ms"), + ("U", "U"), + ("us", "us"), + # pandas-only freq strings are passed through + ("bh", "bh"), + ("cbh", "cbh"), + ("ns", "ns"), + ), +) +def test_new_to_legacy_freq_pd_freq_passthrough(freq, expected): + + result = _new_to_legacy_freq(freq) + assert result == expected + + +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex with:") +@pytest.mark.parametrize("start", ("2000", "2001")) +@pytest.mark.parametrize("end", ("2000", "2001")) +@pytest.mark.parametrize( + "freq", ("MS", "-1MS", "YS", "-1YS", "ME", "-1ME", "YE", "-1YE") +) +def test_cftime_range_same_as_pandas(start, end, freq): + result = date_range(start, end, freq=freq, calendar="standard", use_cftime=True) + result = result.to_datetimeindex() + expected = date_range(start, end, freq=freq, use_cftime=False) + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex with:") +@pytest.mark.parametrize( + "start, end, periods", + [ + ("2022-01-01", "2022-01-10", 2), + ("2022-03-01", "2022-03-31", 2), + ("2022-01-01", "2022-01-10", None), + ("2022-03-01", "2022-03-31", None), + ], +) +def test_cftime_range_no_freq(start, end, periods): + """ + Test whether cftime_range produces the same result as Pandas + when freq is not provided, but start, end and periods are. + """ + # Generate date ranges using cftime_range + result = cftime_range(start=start, end=end, periods=periods) + result = result.to_datetimeindex() + expected = pd.date_range(start=start, end=end, periods=periods) + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize( + "start, end, periods", + [ + ("2022-01-01", "2022-01-10", 2), + ("2022-03-01", "2022-03-31", 2), + ("2022-01-01", "2022-01-10", None), + ("2022-03-01", "2022-03-31", None), + ], +) +def test_date_range_no_freq(start, end, periods): + """ + Test whether date_range produces the same result as Pandas + when freq is not provided, but start, end and periods are. + """ + # Generate date ranges using date_range + result = date_range(start=start, end=end, periods=periods) + expected = pd.date_range(start=start, end=end, periods=periods) + + np.testing.assert_array_equal(result, expected) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index a676b1f07f1..f6eb15fa373 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd import pytest -from packaging.version import Version import xarray as xr from xarray.coding.cftimeindex import ( @@ -33,12 +32,7 @@ # cftime 1.5.2 renames "gregorian" to "standard" standard_or_gregorian = "" if has_cftime: - import cftime - - if Version(cftime.__version__) >= Version("1.5.2"): - standard_or_gregorian = "standard" - else: - standard_or_gregorian = "gregorian" + standard_or_gregorian = "standard" def date_dict(year=None, month=None, day=None, hour=None, minute=None, second=None): @@ -244,28 +238,57 @@ def test_assert_all_valid_date_type(date_type, index): ) def test_cftimeindex_field_accessors(index, field, expected): result = getattr(index, field) + expected = np.array(expected, dtype=np.int64) assert_array_equal(result, expected) + assert result.dtype == expected.dtype + + +@requires_cftime +@pytest.mark.parametrize( + ("field"), + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "dayofyear", + "dayofweek", + "days_in_month", + ], +) +def test_empty_cftimeindex_field_accessors(field): + index = CFTimeIndex([]) + result = getattr(index, field) + expected = np.array([], dtype=np.int64) + assert_array_equal(result, expected) + assert result.dtype == expected.dtype @requires_cftime def test_cftimeindex_dayofyear_accessor(index): result = index.dayofyear - expected = [date.dayofyr for date in index] + expected = np.array([date.dayofyr for date in index], dtype=np.int64) assert_array_equal(result, expected) + assert result.dtype == expected.dtype @requires_cftime def test_cftimeindex_dayofweek_accessor(index): result = index.dayofweek - expected = [date.dayofwk for date in index] + expected = np.array([date.dayofwk for date in index], dtype=np.int64) assert_array_equal(result, expected) + assert result.dtype == expected.dtype @requires_cftime def test_cftimeindex_days_in_month_accessor(index): result = index.days_in_month - expected = [date.daysinmonth for date in index] + expected = np.array([date.daysinmonth for date in index], dtype=np.int64) assert_array_equal(result, expected) + assert result.dtype == expected.dtype @requires_cftime @@ -747,10 +770,10 @@ def test_cftimeindex_add_timedeltaindex(calendar) -> None: "freq,units", [ ("D", "D"), - ("H", "H"), - ("T", "min"), - ("S", "S"), - ("L", "ms"), + ("h", "h"), + ("min", "min"), + ("s", "s"), + ("ms", "ms"), ], ) @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) @@ -772,7 +795,7 @@ def test_cftimeindex_shift_float_us() -> None: @requires_cftime -@pytest.mark.parametrize("freq", ["AS", "A", "YS", "Y", "QS", "Q", "MS", "M"]) +@pytest.mark.parametrize("freq", ["YS", "YE", "QS", "QE", "MS", "ME"]) def test_cftimeindex_shift_float_fails_for_non_tick_freqs(freq) -> None: a = xr.cftime_range("2000", periods=3, freq="D") with pytest.raises(TypeError, match="unsupported operand type"): @@ -965,6 +988,31 @@ def test_cftimeindex_calendar_property(calendar, expected): assert index.calendar == expected +@requires_cftime +def test_empty_cftimeindex_calendar_property(): + index = CFTimeIndex([]) + assert index.calendar is None + + +@requires_cftime +@pytest.mark.parametrize( + "calendar", + [ + "noleap", + "365_day", + "360_day", + "julian", + "gregorian", + "standard", + "proleptic_gregorian", + ], +) +def test_cftimeindex_freq_property_none_size_lt_3(calendar): + for periods in range(3): + index = xr.cftime_range(start="2000", periods=periods, calendar=calendar) + assert index.freq is None + + @requires_cftime @pytest.mark.parametrize( ("calendar", "expected"), @@ -997,7 +1045,7 @@ def test_cftimeindex_periods_repr(periods): @requires_cftime @pytest.mark.parametrize("calendar", ["noleap", "360_day", "standard"]) -@pytest.mark.parametrize("freq", ["D", "H"]) +@pytest.mark.parametrize("freq", ["D", "h"]) def test_cftimeindex_freq_in_repr(freq, calendar): """Test that cftimeindex has frequency property in repr.""" index = xr.cftime_range(start="2000", periods=3, freq=freq, calendar=calendar) @@ -1141,7 +1189,6 @@ def test_to_datetimeindex_feb_29(calendar): @requires_cftime -@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263") def test_multiindex(): index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") mindex = pd.MultiIndex.from_arrays([index]) @@ -1149,20 +1196,32 @@ def test_multiindex(): @requires_cftime -@pytest.mark.parametrize("freq", ["3663S", "33T", "2H"]) +@pytest.mark.parametrize("freq", ["3663s", "33min", "2h"]) @pytest.mark.parametrize("method", ["floor", "ceil", "round"]) def test_rounding_methods_against_datetimeindex(freq, method): - expected = pd.date_range("2000-01-02T01:03:51", periods=10, freq="1777S") + expected = pd.date_range("2000-01-02T01:03:51", periods=10, freq="1777s") expected = getattr(expected, method)(freq) - result = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777S") + result = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777s") result = getattr(result, method)(freq).to_datetimeindex() assert result.equals(expected) +@requires_cftime +@pytest.mark.parametrize("method", ["floor", "ceil", "round"]) +def test_rounding_methods_empty_cftimindex(method): + index = CFTimeIndex([]) + result = getattr(index, method)("2s") + + expected = CFTimeIndex([]) + + assert result.equals(expected) + assert result is not index + + @requires_cftime @pytest.mark.parametrize("method", ["floor", "ceil", "round"]) def test_rounding_methods_invalid_freq(method): - index = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777S") + index = xr.cftime_range("2000-01-02T01:03:51", periods=10, freq="1777s") with pytest.raises(ValueError, match="fixed"): getattr(index, method)("MS") @@ -1180,7 +1239,7 @@ def rounding_index(date_type): @requires_cftime def test_ceil(rounding_index, date_type): - result = rounding_index.ceil("S") + result = rounding_index.ceil("s") expected = xr.CFTimeIndex( [ date_type(1, 1, 1, 2, 0, 0, 0), @@ -1193,7 +1252,7 @@ def test_ceil(rounding_index, date_type): @requires_cftime def test_floor(rounding_index, date_type): - result = rounding_index.floor("S") + result = rounding_index.floor("s") expected = xr.CFTimeIndex( [ date_type(1, 1, 1, 1, 59, 59, 0), @@ -1206,7 +1265,7 @@ def test_floor(rounding_index, date_type): @requires_cftime def test_round(rounding_index, date_type): - result = rounding_index.round("S") + result = rounding_index.round("s") expected = xr.CFTimeIndex( [ date_type(1, 1, 1, 2, 0, 0, 0), @@ -1237,6 +1296,14 @@ def test_asi8_distant_date(): np.testing.assert_array_equal(result, expected) +@requires_cftime +def test_asi8_empty_cftimeindex(): + index = xr.CFTimeIndex([]) + result = index.asi8 + expected = np.array([], dtype=np.int64) + np.testing.assert_array_equal(result, expected) + + @requires_cftime def test_infer_freq_valid_types(): cf_indx = xr.cftime_range("2000-01-01", periods=3, freq="D") @@ -1285,19 +1352,19 @@ def test_infer_freq_invalid_inputs(): @pytest.mark.parametrize( "freq", [ - "300AS-JAN", - "A-DEC", - "AS-JUL", - "2AS-FEB", - "Q-NOV", + "300YS-JAN", + "YE-DEC", + "YS-JUL", + "2YS-FEB", + "QE-NOV", "3QS-DEC", "MS", - "4M", + "4ME", "7D", "D", - "30H", - "5T", - "40S", + "30h", + "5min", + "40s", ], ) @pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 07bc14f8983..98d4377706c 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -6,8 +6,10 @@ import numpy as np import pandas as pd import pytest +from packaging.version import Version import xarray as xr +from xarray.coding.cftime_offsets import _new_to_legacy_freq from xarray.core.pdcompat import _convert_base_to_offset from xarray.core.resample_cftime import CFTimeGrouper @@ -24,31 +26,31 @@ FREQS = [ ("8003D", "4001D"), ("8003D", "16006D"), - ("8003D", "21AS"), - ("6H", "3H"), - ("6H", "12H"), - ("6H", "400T"), + ("8003D", "21YS"), + ("6h", "3h"), + ("6h", "12h"), + ("6h", "400min"), ("3D", "D"), ("3D", "6D"), ("11D", "MS"), ("3MS", "MS"), ("3MS", "6MS"), ("3MS", "85D"), - ("7M", "3M"), - ("7M", "14M"), - ("7M", "2QS-APR"), + ("7ME", "3ME"), + ("7ME", "14ME"), + ("7ME", "2QS-APR"), ("43QS-AUG", "21QS-AUG"), ("43QS-AUG", "86QS-AUG"), - ("43QS-AUG", "11A-JUN"), - ("11Q-JUN", "5Q-JUN"), - ("11Q-JUN", "22Q-JUN"), - ("11Q-JUN", "51MS"), - ("3AS-MAR", "AS-MAR"), - ("3AS-MAR", "6AS-MAR"), - ("3AS-MAR", "14Q-FEB"), - ("7A-MAY", "3A-MAY"), - ("7A-MAY", "14A-MAY"), - ("7A-MAY", "85M"), + ("43QS-AUG", "11YE-JUN"), + ("11QE-JUN", "5QE-JUN"), + ("11QE-JUN", "22QE-JUN"), + ("11QE-JUN", "51MS"), + ("3YS-MAR", "YS-MAR"), + ("3YS-MAR", "6YS-MAR"), + ("3YS-MAR", "14QE-FEB"), + ("7YE-MAY", "3YE-MAY"), + ("7YE-MAY", "14YE-MAY"), + ("7YE-MAY", "85ME"), ] @@ -114,20 +116,33 @@ def da(index) -> xr.DataArray: ) +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize("freqs", FREQS, ids=lambda x: "{}->{}".format(*x)) @pytest.mark.parametrize("closed", [None, "left", "right"]) @pytest.mark.parametrize("label", [None, "left", "right"]) @pytest.mark.parametrize( - ("base", "offset"), [(24, None), (31, None), (None, "5S")], ids=lambda x: f"{x}" + ("base", "offset"), [(24, None), (31, None), (None, "5s")], ids=lambda x: f"{x}" ) def test_resample(freqs, closed, label, base, offset) -> None: initial_freq, resample_freq = freqs + if ( + resample_freq == "4001D" + and closed == "right" + and Version(pd.__version__) < Version("2.2") + ): + pytest.skip( + "Pandas fixed a bug in this test case in version 2.2, which we " + "ported to xarray, so this test no longer produces the same " + "result as pandas for earlier pandas versions." + ) start = "2000-01-01T12:07:01" - loffset = "12H" + loffset = "12h" origin = "start" - index_kwargs = dict(start=start, periods=5, freq=initial_freq) - datetime_index = pd.date_range(**index_kwargs) - cftime_index = xr.cftime_range(**index_kwargs) + + datetime_index = pd.date_range( + start=start, periods=5, freq=_new_to_legacy_freq(initial_freq) + ) + cftime_index = xr.cftime_range(start=start, periods=5, freq=initial_freq) da_datetimeindex = da(datetime_index) da_cftimeindex = da(cftime_index) @@ -148,16 +163,16 @@ def test_resample(freqs, closed, label, base, offset) -> None: @pytest.mark.parametrize( ("freq", "expected"), [ - ("S", "left"), - ("T", "left"), - ("H", "left"), + ("s", "left"), + ("min", "left"), + ("h", "left"), ("D", "left"), - ("M", "right"), + ("ME", "right"), ("MS", "left"), - ("Q", "right"), + ("QE", "right"), ("QS", "left"), - ("A", "right"), - ("AS", "left"), + ("YE", "right"), + ("YS", "left"), ], ) def test_closed_label_defaults(freq, expected) -> None: @@ -166,12 +181,13 @@ def test_closed_label_defaults(freq, expected) -> None: @pytest.mark.filterwarnings("ignore:Converting a CFTimeIndex") +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize( "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] ) -def test_calendars(calendar) -> None: +def test_calendars(calendar: str) -> None: # Limited testing for non-standard calendars - freq, closed, label, base = "8001T", None, None, 17 + freq, closed, label, base = "8001min", None, None, 17 loffset = datetime.timedelta(hours=12) xr_index = xr.cftime_range( start="2004-01-01T12:07:01", periods=7, freq="3D", calendar=calendar @@ -198,6 +214,7 @@ class DateRangeKwargs(TypedDict): freq: str +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") @pytest.mark.parametrize("closed", ["left", "right"]) @pytest.mark.parametrize( "origin", @@ -205,7 +222,7 @@ class DateRangeKwargs(TypedDict): ids=lambda x: f"{x}", ) def test_origin(closed, origin) -> None: - initial_freq, resample_freq = ("3H", "9H") + initial_freq, resample_freq = ("3h", "9h") start = "1969-12-31T12:07:01" index_kwargs: DateRangeKwargs = dict(start=start, periods=12, freq=initial_freq) datetime_index = pd.date_range(**index_kwargs) @@ -222,11 +239,12 @@ def test_origin(closed, origin) -> None: ) +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") def test_base_and_offset_error(): cftime_index = xr.cftime_range("2000", periods=5) da_cftime = da(cftime_index) with pytest.raises(ValueError, match="base and offset cannot"): - da_cftime.resample(time="2D", base=3, offset="5S") + da_cftime.resample(time="2D", base=3, offset="5s") @pytest.mark.parametrize("offset", ["foo", "5MS", 10]) @@ -239,7 +257,7 @@ def test_invalid_offset_error(offset) -> None: def test_timedelta_offset() -> None: timedelta = datetime.timedelta(seconds=5) - string = "5S" + string = "5s" cftime_index = xr.cftime_range("2000", periods=5) da_cftime = da(cftime_index) @@ -249,31 +267,32 @@ def test_timedelta_offset() -> None: xr.testing.assert_identical(timedelta_result, string_result) -@pytest.mark.parametrize("loffset", ["12H", datetime.timedelta(hours=-12)]) +@pytest.mark.parametrize("loffset", ["MS", "12h", datetime.timedelta(hours=-12)]) def test_resample_loffset_cftimeindex(loffset) -> None: - datetimeindex = pd.date_range("2000-01-01", freq="6H", periods=10) + datetimeindex = pd.date_range("2000-01-01", freq="6h", periods=10) da_datetimeindex = xr.DataArray(np.arange(10), [("time", datetimeindex)]) - cftimeindex = xr.cftime_range("2000-01-01", freq="6H", periods=10) + cftimeindex = xr.cftime_range("2000-01-01", freq="6h", periods=10) da_cftimeindex = xr.DataArray(np.arange(10), [("time", cftimeindex)]) with pytest.warns(FutureWarning, match="`loffset` parameter"): - result = da_cftimeindex.resample(time="24H", loffset=loffset).mean() - expected = da_datetimeindex.resample(time="24H", loffset=loffset).mean() + result = da_cftimeindex.resample(time="24h", loffset=loffset).mean() + expected = da_datetimeindex.resample(time="24h", loffset=loffset).mean() result["time"] = result.xindexes["time"].to_pandas_index().to_datetimeindex() xr.testing.assert_identical(result, expected) +@pytest.mark.filterwarnings("ignore:.*the `(base|loffset)` parameter to resample") def test_resample_invalid_loffset_cftimeindex() -> None: - times = xr.cftime_range("2000-01-01", freq="6H", periods=10) + times = xr.cftime_range("2000-01-01", freq="6h", periods=10) da = xr.DataArray(np.arange(10), [("time", times)]) with pytest.raises(ValueError): - da.resample(time="24H", loffset=1) # type: ignore + da.resample(time="24h", loffset=1) # type: ignore -@pytest.mark.parametrize(("base", "freq"), [(1, "10S"), (17, "3H"), (15, "5U")]) +@pytest.mark.parametrize(("base", "freq"), [(1, "10s"), (17, "3h"), (15, "5us")]) def test__convert_base_to_offset(base, freq): # Verify that the cftime_offset adapted version of _convert_base_to_offset # produces the same result as the pandas version. @@ -286,4 +305,4 @@ def test__convert_base_to_offset(base, freq): def test__convert_base_to_offset_invalid_index(): with pytest.raises(ValueError, match="Can only resample"): - _convert_base_to_offset(1, "12H", pd.Index([0])) + _convert_base_to_offset(1, "12h", pd.Index([0])) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index d58361afdd3..01d5393e289 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -6,6 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset, set_options +from xarray.core import duck_array_ops from xarray.tests import ( assert_allclose, assert_equal, @@ -17,7 +18,10 @@ def test_coarsen_absent_dims_error(ds: Dataset) -> None: - with pytest.raises(ValueError, match=r"not found in Dataset."): + with pytest.raises( + ValueError, + match=r"Window dimensions \('foo',\) not found in Dataset dimensions", + ): ds.coarsen(foo=2) @@ -269,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None: expected = xr.Dataset(attrs={"foo": "bar"}) expected["vart"] = ( ("year", "month"), - ds.vart.data.reshape((-1, 12)), + duck_array_ops.reshape(ds.vart.data, (-1, 12)), {"a": "b"}, ) expected["varx"] = ( ("x", "x_reshaped"), - ds.varx.data.reshape((-1, 5)), + duck_array_ops.reshape(ds.varx.data, (-1, 5)), {"a": "b"}, ) expected["vartx"] = ( ("x", "x_reshaped", "year", "month"), - ds.vartx.data.reshape(2, 5, 4, 12), + duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)), {"a": "b"}, ) expected["vary"] = ds.vary - expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12))) + expected.coords["time"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.time.data, (-1, 12)), + ) with raise_if_dask_computes(): actual = ds.coarsen(time=12, x=5).construct( diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index f7579c4b488..6d81d6f5dc8 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -51,9 +51,8 @@ def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding) -> None assert encoded.dtype == encoded.attrs["missing_value"].dtype assert encoded.dtype == encoded.attrs["_FillValue"].dtype - with pytest.warns(variables.SerializationWarning): - roundtripped = decode_cf_variable("foo", encoded) - assert_identical(roundtripped, original) + roundtripped = decode_cf_variable("foo", encoded) + assert_identical(roundtripped, original) def test_CFMaskCoder_missing_value() -> None: @@ -96,16 +95,18 @@ def test_coder_roundtrip() -> None: @pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split()) -def test_scaling_converts_to_float32(dtype) -> None: +@pytest.mark.parametrize("dtype2", "f4 f8".split()) +def test_scaling_converts_to_float(dtype: str, dtype2: str) -> None: + dt = np.dtype(dtype2) original = xr.Variable( - ("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=10) + ("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=dt.type(10)) ) coder = variables.CFScaleOffsetCoder() encoded = coder.encode(original) - assert encoded.dtype == np.float32 + assert encoded.dtype == dt roundtripped = coder.decode(encoded) assert_identical(original, roundtripped) - assert roundtripped.dtype == np.float32 + assert roundtripped.dtype == dt @pytest.mark.parametrize("scale_factor", (10, [10])) diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index cb9595f4a64..51f63ea72dd 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -32,6 +32,10 @@ def test_vlen_dtype() -> None: assert strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is bytes + # check h5py variant ("vlen") + dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload,unused-ignore] + assert strings.check_vlen_dtype(dtype) is str + assert strings.check_vlen_dtype(np.dtype(object)) is None @@ -177,7 +181,7 @@ def test_StackedBytesArray_vectorized_indexing() -> None: V = IndexerMaker(indexing.VectorizedIndexer) indexer = V[np.array([[0, 1], [1, 0]])] - actual = stacked[indexer] + actual = stacked.vindex[indexer] assert_array_equal(actual, expected) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 580de878fe6..9a5589ff872 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -16,19 +16,24 @@ cftime_range, coding, conventions, + date_range, decode_cf, ) from xarray.coding.times import ( _encode_datetime_with_cftime, + _numpy_to_netcdf_timeunit, _should_cftime_be_used, cftime_to_nptime, decode_cf_datetime, + decode_cf_timedelta, encode_cf_datetime, + encode_cf_timedelta, to_timedelta_unboxed, ) from xarray.coding.variables import SerializationWarning from xarray.conventions import _update_bounds_attributes, cf_encoder from xarray.core.common import contains_cftime_datetimes +from xarray.core.utils import is_duck_dask_array from xarray.testing import assert_equal, assert_identical from xarray.tests import ( FirstElementAccessibleArray, @@ -110,6 +115,7 @@ def _all_cftime_date_types(): @requires_cftime @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) def test_cf_datetime(num_dates, units, calendar) -> None: import cftime @@ -201,7 +207,7 @@ def test_decode_standard_calendar_inside_timestamp_range(calendar) -> None: import cftime units = "days since 0001-01-01" - times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="H") + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="h") time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) expected = times.values expected_dtype = np.dtype("M8[ns]") @@ -221,7 +227,7 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar) -> None: import cftime units = "days since 0001-01-01" - times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="H") + times = pd.date_range("2001-04-01-00", end="2001-04-30-23", freq="h") non_standard_time = cftime.date2num(times.to_pydatetime(), units, calendar=calendar) expected = cftime.num2date( @@ -511,12 +517,12 @@ def test_decoded_cf_datetime_array_2d() -> None: FREQUENCIES_TO_ENCODING_UNITS = { - "N": "nanoseconds", - "U": "microseconds", - "L": "milliseconds", - "S": "seconds", - "T": "minutes", - "H": "hours", + "ns": "nanoseconds", + "us": "microseconds", + "ms": "milliseconds", + "s": "seconds", + "min": "minutes", + "h": "hours", "D": "days", } @@ -567,6 +573,7 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: assert expected == coding.times.infer_datetime_units(dates) +@pytest.mark.filterwarnings("ignore:Timedeltas can't be serialized faithfully") @pytest.mark.parametrize( ["timedeltas", "units", "numbers"], [ @@ -576,10 +583,10 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: ("1ms", "milliseconds", np.int64(1)), ("1us", "microseconds", np.int64(1)), ("1ns", "nanoseconds", np.int64(1)), - (["NaT", "0s", "1s"], None, [np.nan, 0, 1]), + (["NaT", "0s", "1s"], None, [np.iinfo(np.int64).min, 0, 1]), (["30m", "60m"], "hours", [0.5, 1.0]), - ("NaT", "days", np.nan), - (["NaT", "NaT"], "days", [np.nan, np.nan]), + ("NaT", "days", np.iinfo(np.int64).min), + (["NaT", "NaT"], "days", [np.iinfo(np.int64).min, np.iinfo(np.int64).min]), ], ) def test_cf_timedelta(timedeltas, units, numbers) -> None: @@ -730,7 +737,7 @@ def test_encode_time_bounds() -> None: # if time_bounds attrs are same as time attrs, it doesn't matter ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"} - encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs) + encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs) assert_equal(encoded["time_bounds"], expected["time_bounds"]) assert "calendar" not in encoded["time_bounds"].attrs assert "units" not in encoded["time_bounds"].attrs @@ -738,7 +745,7 @@ def test_encode_time_bounds() -> None: # for CF-noncompliant case of time_bounds attrs being different from # time attrs; preserve them for faithful roundtrip ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"} - encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs) + encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs) with pytest.raises(AssertionError): assert_equal(encoded["time_bounds"], expected["time_bounds"]) assert "calendar" not in encoded["time_bounds"].attrs @@ -1020,6 +1027,7 @@ def test_decode_ambiguous_time_warns(calendar) -> None: np.testing.assert_array_equal(result, expected) +@pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) @pytest.mark.parametrize("date_range", [pd.date_range, cftime_range]) @@ -1028,11 +1036,11 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( ) -> None: if not has_cftime and date_range == cftime_range: pytest.skip("Test requires cftime") - if (freq == "N" or encoding_units == "nanoseconds") and date_range == cftime_range: + if (freq == "ns" or encoding_units == "nanoseconds") and date_range == cftime_range: pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = coding.times.encode_cf_datetime(times, units) numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) @@ -1045,7 +1053,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) def test_encode_decode_roundtrip_datetime64(freq) -> None: # See GH 4045. Prior to GH 4684 this test would fail for frequencies of - # "S", "L", "U", and "N". + # "s", "ms", "us", and "ns". initial_time = pd.date_range("1678-01-01", periods=1) times = initial_time.append(pd.date_range("1968", periods=2, freq=freq)) variable = Variable(["time"], times) @@ -1055,7 +1063,7 @@ def test_encode_decode_roundtrip_datetime64(freq) -> None: @requires_cftime -@pytest.mark.parametrize("freq", ["U", "L", "S", "T", "H", "D"]) +@pytest.mark.parametrize("freq", ["us", "ms", "s", "min", "h", "D"]) def test_encode_decode_roundtrip_cftime(freq) -> None: initial_time = cftime_range("0001", periods=1) times = initial_time.append( @@ -1191,3 +1199,402 @@ def test_contains_cftime_lazy() -> None: ) array = FirstElementAccessibleArray(times) assert _contains_cftime_datetimes(array) + + +@pytest.mark.parametrize( + "timestr, timeunit, dtype, fill_value, use_encoding", + [ + ("1677-09-21T00:12:43.145224193", "ns", np.int64, 20, True), + ("1970-09-21T00:12:44.145224808", "ns", np.float64, 1e30, True), + ( + "1677-09-21T00:12:43.145225216", + "ns", + np.float64, + -9.223372036854776e18, + True, + ), + ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False), + ("1677-09-21T00:12:43.145225", "us", np.int64, None, False), + ("1970-01-01T00:00:01.000001", "us", np.int64, None, False), + ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True), + ], +) +def test_roundtrip_datetime64_nanosecond_precision( + timestr: str, + timeunit: str, + dtype: np.typing.DTypeLike, + fill_value: int | float | None, + use_encoding: bool, +) -> None: + # test for GH7817 + time = np.datetime64(timestr, timeunit) + times = [np.datetime64("1970-01-01T00:00:00", timeunit), np.datetime64("NaT"), time] + + if use_encoding: + encoding = dict(dtype=dtype, _FillValue=fill_value) + else: + encoding = {} + + var = Variable(["time"], times, encoding=encoding) + assert var.dtype == np.dtype(" None: + # test warning if times can't be serialized faithfully + times = [ + np.datetime64("1970-01-01T00:01:00", "ns"), + np.datetime64("NaT"), + np.datetime64("1970-01-02T00:01:00", "ns"), + ] + units = "days since 1970-01-10T01:01:00" + needed_units = "hours" + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." + ): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + +@pytest.mark.parametrize( + "dtype, fill_value", + [(np.int64, 20), (np.int64, np.iinfo(np.int64).min), (np.float64, 1e30)], +) +def test_roundtrip_timedelta64_nanosecond_precision( + dtype: np.typing.DTypeLike, fill_value: int | float +) -> None: + # test for GH7942 + one_day = np.timedelta64(1, "ns") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = nat + + encoding = dict(dtype=dtype, _FillValue=fill_value) + var = Variable(["time"], timedelta_values, encoding=encoding) + + encoded_var = conventions.encode_cf_variable(var) + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + + assert_identical(var, decoded_var) + + +def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: + # test warning if timedeltas can't be serialized faithfully + one_day = np.timedelta64(1, "D") + nat = np.timedelta64("nat", "ns") + timedelta_values = (np.arange(5) * one_day).astype("timedelta64[ns]") + timedelta_values[2] = nat + timedelta_values[4] = np.timedelta64(12, "h").astype("timedelta64[ns]") + + units = "days" + needed_units = "hours" + wmsg = ( + f"Timedeltas can't be serialized faithfully with requested units {units!r}. " + f"Serializing with units {needed_units!r} instead." + ) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) + var = Variable(["time"], timedelta_values, encoding=encoding) + with pytest.warns(UserWarning, match=wmsg): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 + + +def test_roundtrip_float_times() -> None: + # Regression test for GitHub issue #8271 + fill_value = 20.0 + times = [ + np.datetime64("1970-01-01 00:00:00", "ns"), + np.datetime64("1970-01-01 06:00:00", "ns"), + np.datetime64("NaT", "ns"), + ] + + units = "days since 1960-01-01" + var = Variable( + ["time"], + times, + encoding=dict(dtype=np.float64, _FillValue=fill_value, units=units), + ) + + encoded_var = conventions.encode_cf_variable(var) + np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0])) + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == fill_value + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["units"] == units + assert decoded_var.encoding["_FillValue"] == fill_value + + +_ENCODE_DATETIME64_VIA_DASK_TESTS = { + "pandas-encoding-with-prescribed-units-and-dtype": ( + "D", + "days since 1700-01-01", + np.dtype("int32"), + ), + "mixed-cftime-pandas-encoding-with-prescribed-units-and-dtype": ( + "250YS", + "days since 1700-01-01", + np.dtype("int32"), + ), + "pandas-encoding-with-default-units-and-dtype": ("250YS", None, None), +} + + +@requires_dask +@pytest.mark.parametrize( + ("freq", "units", "dtype"), + _ENCODE_DATETIME64_VIA_DASK_TESTS.values(), + ids=_ENCODE_DATETIME64_VIA_DASK_TESTS.keys(), +) +def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype) -> None: + import dask.array + + times = pd.date_range(start="1700", freq=freq, periods=3) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( + times, units, None, dtype + ) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "nanoseconds since 1970-01-01" + assert encoded_times.dtype == np.dtype("int64") + + assert encoding_calendar == "proleptic_gregorian" + + decoded_times = decode_cf_datetime(encoded_times, encoding_units, encoding_calendar) + np.testing.assert_equal(decoded_times, times) + + +@requires_dask +@pytest.mark.parametrize( + ("range_function", "start", "units", "dtype"), + [ + (pd.date_range, "2000", None, np.dtype("int32")), + (pd.date_range, "2000", "days since 2000-01-01", None), + (pd.timedelta_range, "0D", None, np.dtype("int32")), + (pd.timedelta_range, "0D", "days", None), + ], +) +def test_encode_via_dask_cannot_infer_error( + range_function, start, units, dtype +) -> None: + values = range_function(start=start, freq="D", periods=3) + encoding = dict(units=units, dtype=dtype) + variable = Variable(["time"], values, encoding=encoding).chunk({"time": 1}) + with pytest.raises(ValueError, match="When encoding chunked arrays"): + conventions.encode_cf_variable(variable) + + +@requires_cftime +@requires_dask +@pytest.mark.parametrize( + ("units", "dtype"), [("days since 1700-01-01", np.dtype("int32")), (None, None)] +) +def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None: + import dask.array + + calendar = "standard" + times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( + times, units, None, dtype + ) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "microseconds since 1970-01-01" + assert encoded_times.dtype == np.int64 + + assert encoding_calendar == calendar + + decoded_times = decode_cf_datetime( + encoded_times, encoding_units, encoding_calendar, use_cftime=True + ) + np.testing.assert_equal(decoded_times, times) + + +@pytest.mark.parametrize( + "use_cftime", [False, pytest.param(True, marks=requires_cftime)] +) +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_datetime_casting_value_error(use_cftime, use_dask) -> None: + times = date_range(start="2000", freq="12h", periods=3, use_cftime=use_cftime) + encoding = dict(units="days since 2000-01-01", dtype=np.dtype("int64")) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_cftime and not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. For all other cases we raise. + with pytest.warns(UserWarning, match="Times can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours since 2000-01-01" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@pytest.mark.parametrize( + "use_cftime", [False, pytest.param(True, marks=requires_cftime)] +) +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) -> None: + # Regression test for GitHub issue #8542 + times = date_range(start="2018", freq="5h", periods=3, use_cftime=use_cftime) + encoding = dict(units="microseconds since 2018-01-01", dtype=dtype) + variable = Variable(["time"], times, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + with pytest.raises(OverflowError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@requires_dask +@pytest.mark.parametrize( + ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] +) +def test_encode_cf_timedelta_via_dask(units, dtype) -> None: + import dask.array + + times = pd.timedelta_range(start="0D", freq="D", periods=3) + times = dask.array.from_array(times, chunks=1) + encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) + + assert is_duck_dask_array(encoded_times) + assert encoded_times.chunks == times.chunks + + if units is not None and dtype is not None: + assert encoding_units == units + assert encoded_times.dtype == dtype + else: + assert encoding_units == "nanoseconds" + assert encoded_times.dtype == np.dtype("int64") + + decoded_times = decode_cf_timedelta(encoded_times, encoding_units) + np.testing.assert_equal(decoded_times, times) + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +def test_encode_cf_timedelta_casting_value_error(use_dask) -> None: + timedeltas = pd.timedelta_range(start="0h", freq="12h", periods=3) + encoding = dict(units="days", dtype=np.dtype("int64")) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + if not use_dask: + # In this particular case we automatically modify the encoding units to + # continue encoding with integer values. + with pytest.warns(UserWarning, match="Timedeltas can't be serialized"): + encoded = conventions.encode_cf_variable(variable) + assert encoded.attrs["units"] == "hours" + decoded = conventions.decode_cf_variable("name", encoded) + assert_equal(variable, decoded) + else: + with pytest.raises(ValueError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() + + +@pytest.mark.parametrize("use_dask", [False, pytest.param(True, marks=requires_dask)]) +@pytest.mark.parametrize("dtype", [np.dtype("int16"), np.dtype("float16")]) +def test_encode_cf_timedelta_casting_overflow_error(use_dask, dtype) -> None: + timedeltas = pd.timedelta_range(start="0h", freq="5h", periods=3) + encoding = dict(units="microseconds", dtype=dtype) + variable = Variable(["time"], timedeltas, encoding=encoding) + + if use_dask: + variable = variable.chunk({"time": 1}) + + with pytest.raises(OverflowError, match="Not possible"): + encoded = conventions.encode_cf_variable(variable) + encoded.compute() diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 3c10e3f27ab..820fcd48bd3 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -118,8 +118,10 @@ def test_apply_identity() -> None: assert_identical(variable, apply_identity(variable)) assert_identical(data_array, apply_identity(data_array)) assert_identical(data_array, apply_identity(data_array.groupby("x"))) + assert_identical(data_array, apply_identity(data_array.groupby("x", squeeze=False))) assert_identical(dataset, apply_identity(dataset)) assert_identical(dataset, apply_identity(dataset.groupby("x"))) + assert_identical(dataset, apply_identity(dataset.groupby("x", squeeze=False))) def add(a, b): @@ -257,6 +259,168 @@ def func(x): assert_identical(out1, dataset) +def test_apply_missing_dims() -> None: + ## Single arg + + def add_one(a, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda x: x + 1, + a, + input_core_dims=core_dims, + output_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + array = np.arange(6).reshape(2, 3) + variable = xr.Variable(["x", "y"], array) + variable_no_y = xr.Variable(["x", "z"], array) + + ds = xr.Dataset({"x_y": variable, "x_z": variable_no_y}) + + # Check the standard stuff works OK + assert_identical( + add_one(ds[["x_y"]], core_dims=[["y"]], on_missing_core_dim="raise"), + ds[["x_y"]] + 1, + ) + + # `raise` — should raise on a missing dim + with pytest.raises(ValueError): + add_one(ds, core_dims=[["y"]], on_missing_core_dim="raise") + + # `drop` — should drop the var with the missing dim + assert_identical( + add_one(ds, core_dims=[["y"]], on_missing_core_dim="drop"), + (ds + 1).drop_vars("x_z"), + ) + + # `copy` — should not add one to the missing with `copy` + copy_result = add_one(ds, core_dims=[["y"]], on_missing_core_dim="copy") + assert_identical(copy_result["x_y"], (ds + 1)["x_y"]) + assert_identical(copy_result["x_z"], ds["x_z"]) + + ## Multiple args + + def sum_add(a, b, core_dims, on_missing_core_dim): + return apply_ufunc( + lambda a, b, axis=None: a.sum(axis) + b.sum(axis), + a, + b, + input_core_dims=core_dims, + on_missing_core_dim=on_missing_core_dim, + ) + + # Check the standard stuff works OK + assert_identical( + sum_add( + ds[["x_y"]], + ds[["x_y"]], + core_dims=[["x", "y"], ["x", "y"]], + on_missing_core_dim="raise", + ), + ds[["x_y"]].sum() * 2, + ) + + # `raise` — should raise on a missing dim + with pytest.raises( + ValueError, + match=r".*Missing core dims \{'y'\} from arg number 1 on a variable named `x_z`:\n.* None: data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) @@ -357,8 +521,10 @@ def func(x): assert_identical(stacked_variable, stack_negative(variable)) assert_identical(stacked_data_array, stack_negative(data_array)) assert_identical(stacked_dataset, stack_negative(dataset)) - assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) - assert_identical(stacked_dataset, stack_negative(dataset.groupby("x"))) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(stacked_data_array, stack_negative(data_array.groupby("x"))) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(stacked_dataset, stack_negative(dataset.groupby("x"))) def original_and_stack_negative(obj): def func(x): @@ -385,11 +551,13 @@ def func(x): assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) - out0, out1 = original_and_stack_negative(data_array.groupby("x")) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + out0, out1 = original_and_stack_negative(data_array.groupby("x")) assert_identical(data_array, out0) assert_identical(stacked_data_array, out1) - out0, out1 = original_and_stack_negative(dataset.groupby("x")) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + out0, out1 = original_and_stack_negative(dataset.groupby("x")) assert_identical(dataset, out0) assert_identical(stacked_dataset, out1) @@ -1028,7 +1196,7 @@ def test_apply_dask() -> None: # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask="unknown") + apply_ufunc(identity, array, dask="unknown") # type: ignore def dask_safe_identity(x): return apply_ufunc(identity, x, dask="allowed") @@ -1087,7 +1255,7 @@ def check(x, y): assert actual.data.chunks == array.chunks assert_identical(data_array, actual) - check(data_array, 0), + check(data_array, 0) check(0, data_array) check(data_array, xr.DataArray(0)) check(data_array, 0 * data_array) @@ -1216,7 +1384,7 @@ def func(da): expected = extract(ds) actual = extract(ds.chunk()) - assert actual.dims == {"lon_new": 3, "lat_new": 6} + assert actual.sizes == {"lon_new": 3, "lat_new": 6} assert_identical(expected.chunk(), actual) @@ -1613,6 +1781,97 @@ def test_complex_cov() -> None: assert abs(actual.item()) == 2 +@pytest.mark.parametrize("weighted", [True, False]) +def test_bilinear_cov_corr(weighted: bool) -> None: + # Test the bilinear properties of covariance and correlation + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + db = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + dc = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + if weighted: + weights = xr.DataArray( + np.abs(np.random.random(4)), + dims=("x"), + ) + else: + weights = None + k = np.random.random(1)[0] + + # Test covariance properties + assert_allclose( + xr.cov(da + k, db, weights=weights), xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da, db + k, weights=weights), xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da + dc, db, weights=weights), + xr.cov(da, db, weights=weights) + xr.cov(dc, db, weights=weights), + ) + assert_allclose( + xr.cov(da, db + dc, weights=weights), + xr.cov(da, db, weights=weights) + xr.cov(da, dc, weights=weights), + ) + assert_allclose( + xr.cov(k * da, db, weights=weights), k * xr.cov(da, db, weights=weights) + ) + assert_allclose( + xr.cov(da, k * db, weights=weights), k * xr.cov(da, db, weights=weights) + ) + + # Test correlation properties + assert_allclose( + xr.corr(da + k, db, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(da, db + k, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(k * da, db, weights=weights), xr.corr(da, db, weights=weights) + ) + assert_allclose( + xr.corr(da, k * db, weights=weights), xr.corr(da, db, weights=weights) + ) + + +def test_equally_weighted_cov_corr() -> None: + # Test that equal weights for all values produces same results as weights=None + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + db = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + # + assert_allclose( + xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(1)) + ) + assert_allclose( + xr.cov(da, db, weights=None), xr.cov(da, db, weights=xr.DataArray(2)) + ) + assert_allclose( + xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(1)) + ) + assert_allclose( + xr.corr(da, db, weights=None), xr.corr(da, db, weights=xr.DataArray(2)) + ) + + @requires_dask def test_vectorize_dask_new_output_dims() -> None: # regression test for GH3574 @@ -1666,7 +1925,10 @@ def identity(x): def tuple3x(x): return (x, x, x) - with pytest.raises(ValueError, match=r"number of outputs"): + with pytest.raises( + ValueError, + match=r"number of outputs.* Received a with 10 elements. Expected a tuple of 2 elements:\n\narray\(\[0", + ): apply_ufunc(identity, variable, output_core_dims=[(), ()]) with pytest.raises(ValueError, match=r"number of outputs"): @@ -1682,7 +1944,10 @@ def add_dim(x): def remove_dim(x): return x[..., 0] - with pytest.raises(ValueError, match=r"unexpected number of dimensions"): + with pytest.raises( + ValueError, + match=r"unexpected number of dimensions.*from:\n\n.*array\(\[\[0", + ): apply_ufunc(add_dim, variable, output_core_dims=[("y", "z")]) with pytest.raises(ValueError, match=r"unexpected number of dimensions"): @@ -1768,7 +2033,7 @@ def test_dot(use_dask: bool) -> None: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) da_c = da_c.chunk({"c": 3}) - actual = xr.dot(da_a, da_b, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, dim=["a", "b"]) assert actual.dims == ("c",) assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) @@ -1792,33 +2057,33 @@ def test_dot(use_dask: bool) -> None: if use_dask: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="b") + actual = xr.dot(da_a, da_b, dim="b") assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="a") + actual = xr.dot(da_a, da_b, dim="a") assert actual.dims == ("b", "c") assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all() - actual = xr.dot(da_a, da_b, dims="c") + actual = xr.dot(da_a, da_b, dim="c") assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"]) assert actual.dims == ("c", "e") assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all() # should work with tuple - actual = xr.dot(da_a, da_b, dims=("c",)) + actual = xr.dot(da_a, da_b, dim=("c",)) assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() @@ -1828,47 +2093,47 @@ def test_dot(use_dask: bool) -> None: assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all() # 1 array summation - actual = xr.dot(da_a, dims="a") + actual = xr.dot(da_a, dim="a") assert actual.dims == ("b",) assert (actual.data == np.einsum("ij->j ", a)).all() # empty dim - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a") + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a") assert actual.dims == ("b",) assert (actual.data == np.zeros(actual.shape)).all() # Ellipsis (...) sums over all dimensions - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk->", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=...) + actual = xr.dot(da_a, da_b, da_c, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all() - actual = xr.dot(da_a, dims=...) + actual = xr.dot(da_a, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij-> ", a)).all() - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...) + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...) assert actual.dims == () assert (actual.data == np.zeros(actual.shape)).all() # Invalid cases if not use_dask: with pytest.raises(TypeError): - xr.dot(da_a, dims="a", invalid=None) + xr.dot(da_a, dim="a", invalid=None) with pytest.raises(TypeError): - xr.dot(da_a.to_dataset(name="da"), dims="a") + xr.dot(da_a.to_dataset(name="da"), dim="a") with pytest.raises(TypeError): - xr.dot(dims="a") + xr.dot(dim="a") # einsum parameters - actual = xr.dot(da_a, da_b, dims=["b"], order="C") + actual = xr.dot(da_a, da_b, dim=["b"], order="C") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert actual.values.flags["C_CONTIGUOUS"] assert not actual.values.flags["F_CONTIGUOUS"] - actual = xr.dot(da_a, da_b, dims=["b"], order="F") + actual = xr.dot(da_a, da_b, dim=["b"], order="F") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() # dask converts Fortran arrays to C order when merging the final array if not use_dask: @@ -1910,7 +2175,7 @@ def test_dot_align_coords(use_dask: bool) -> None: expected = (da_a * da_b).sum(["a", "b"]) xr.testing.assert_allclose(expected, actual) - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) expected = (da_a * da_b).sum() xr.testing.assert_allclose(expected, actual) @@ -2151,7 +2416,7 @@ def test_polyval_cftime(use_dask: bool, date: str) -> None: import cftime x = xr.DataArray( - xr.date_range(date, freq="1S", periods=3, use_cftime=True), + xr.date_range(date, freq="1s", periods=3, use_cftime=True), dims="x", ) coeffs = xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}) @@ -2171,7 +2436,7 @@ def test_polyval_cftime(use_dask: bool, date: str) -> None: xr.DataArray( [0, 1e9, 2e9], dims="x", - coords={"x": xr.date_range(date, freq="1S", periods=3, use_cftime=True)}, + coords={"x": xr.date_range(date, freq="1s", periods=3, use_cftime=True)}, ) + offset ) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index f60308f8863..0cf4cc03a09 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -9,6 +9,7 @@ from xarray import DataArray, Dataset, Variable, concat from xarray.core import dtypes, merge +from xarray.core.coordinates import Coordinates from xarray.core.indexes import PandasIndex from xarray.tests import ( InaccessibleArray, @@ -297,6 +298,7 @@ def test_concat_multiple_datasets_with_multiple_missing_variables() -> None: assert_identical(actual, expected) +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_concat_type_of_missing_fill() -> None: datasets = create_typed_datasets(2, seed=123) expected1 = concat(datasets, dim="day", fill_value=dtypes.NA) @@ -492,7 +494,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: def test_concat_2(self, data) -> None: dim = "dim2" - datasets = [g for _, g in data.groupby(dim, squeeze=True)] + datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] actual = concat(datasets, data[dim], coords=concat_over) assert_identical(data, self.rectify_dim_order(data, actual)) @@ -503,11 +505,11 @@ def test_concat_coords_kwarg(self, data, dim, coords) -> None: data = data.copy(deep=True) # make sure the coords argument behaves as expected data.coords["extra"] = ("dim4", np.arange(3)) - datasets = [g for _, g in data.groupby(dim, squeeze=True)] + datasets = [g.squeeze() for _, g in data.groupby(dim, squeeze=False)] actual = concat(datasets, data[dim], coords=coords) if coords == "all": - expected = np.array([data["extra"].values for _ in range(data.dims[dim])]) + expected = np.array([data["extra"].values for _ in range(data.sizes[dim])]) assert_array_equal(actual["extra"].values, expected) else: @@ -537,8 +539,7 @@ def test_concat_data_vars_typing(self) -> None: actual = concat(objs, dim="x", data_vars="minimal") assert_identical(data, actual) - def test_concat_data_vars(self): - # TODO: annotating this func fails + def test_concat_data_vars(self) -> None: data = Dataset({"foo": ("x", np.random.randn(10))}) objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))] for data_vars in ["minimal", "different", "all", [], ["foo"]]: @@ -614,11 +615,14 @@ def test_concat_errors(self): with pytest.raises(ValueError, match=r"must supply at least one"): concat([], "dim1") - with pytest.raises(ValueError, match=r"are not coordinates"): + with pytest.raises(ValueError, match=r"are not found in the coordinates"): concat([data, data], "new_dim", coords=["not_found"]) + with pytest.raises(ValueError, match=r"are not found in the data variables"): + concat([data, data], "new_dim", data_vars=["not_found"]) + with pytest.raises(ValueError, match=r"global attributes not"): - # call deepcopy seperately to get unique attrs + # call deepcopy separately to get unique attrs data0 = deepcopy(split_data[0]) data1 = deepcopy(split_data[1]) data1.attrs["foo"] = "bar" @@ -906,8 +910,9 @@ def test_concat_dim_is_dataarray(self) -> None: assert_identical(actual, expected) def test_concat_multiindex(self) -> None: - x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) - expected = Dataset(coords={"x": x}) + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset(coords=midx_coords) actual = concat( [expected.isel(x=slice(2)), expected.isel(x=slice(2, None))], "x" ) @@ -917,8 +922,9 @@ def test_concat_multiindex(self) -> None: def test_concat_along_new_dim_multiindex(self) -> None: # see https://github.com/pydata/xarray/issues/6881 level_names = ["x_level_0", "x_level_1"] - x = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]], names=level_names) - ds = Dataset(coords={"x": x}) + midx = pd.MultiIndex.from_product([[1, 2, 3], ["a", "b"]], names=level_names) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) concatenated = concat([ds], "new") actual = list(concatenated.xindexes.get_all_coords("x")) expected = ["x"] + level_names @@ -994,7 +1000,7 @@ def test_concat(self) -> None: actual = concat([foo, bar], "w") assert_equal(expected, actual) # from iteration: - grouped = [g for _, g in foo.groupby("x")] + grouped = [g.squeeze() for _, g in foo.groupby("x", squeeze=False)] stacked = concat(grouped, ds["x"]) assert_identical(foo, stacked) # with an index as the 'dim' argument @@ -1064,10 +1070,10 @@ def test_concat_fill_value(self, fill_value) -> None: def test_concat_join_kwarg(self) -> None: ds1 = Dataset( {"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]} - ).to_array() + ).to_dataarray() ds2 = Dataset( {"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]} - ).to_array() + ).to_dataarray() expected: dict[JoinOptions, Any] = {} expected["outer"] = Dataset( @@ -1095,7 +1101,7 @@ def test_concat_join_kwarg(self) -> None: for join in expected: actual = concat([ds1, ds2], join=join, dim="x") - assert_equal(actual, expected[join].to_array()) + assert_equal(actual, expected[join].to_dataarray()) def test_concat_combine_attrs_kwarg(self) -> None: da1 = DataArray([0], coords=[("x", [0])], attrs={"b": 42}) @@ -1208,7 +1214,7 @@ def test_concat_preserve_coordinate_order() -> None: # check dimension order for act, exp in zip(actual.dims, expected.dims): assert act == exp - assert actual.dims[act] == expected.dims[exp] + assert actual.sizes[act] == expected.sizes[exp] # check coordinate order for act, exp in zip(actual.coords, expected.coords): @@ -1218,7 +1224,7 @@ def test_concat_preserve_coordinate_order() -> None: def test_concat_typing_check() -> None: ds = Dataset({"foo": 1}, {"bar": 2}) - da = Dataset({"foo": 3}, {"bar": 4}).to_array(dim="foo") + da = Dataset({"foo": 3}, {"bar": 4}).to_dataarray(dim="foo") # concatenate a list of non-homogeneous types must raise TypeError with pytest.raises( diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 9485b506b89..fdfea3c3fe8 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -32,7 +32,7 @@ class TestBoolTypeArray: def test_booltype_array(self) -> None: x = np.array([1, 0, 1, 1, 0], dtype="i1") - bx = conventions.BoolTypeArray(x) + bx = coding.variables.BoolTypeArray(x) assert bx.dtype == bool assert_array_equal(bx, np.array([True, False, True, True, False], dtype=bool)) @@ -41,7 +41,7 @@ class TestNativeEndiannessArray: def test(self) -> None: x = np.arange(5, dtype=">i8") expected = np.arange(5, dtype="int64") - a = conventions.NativeEndiannessArray(x) + a = coding.variables.NativeEndiannessArray(x) assert a.dtype == expected.dtype assert a.dtype == expected[:].dtype assert_array_equal(a, expected) @@ -52,10 +52,9 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: var = Variable( ["t"], np.arange(3), {"units": "foobar", "missing_value": 0, "_FillValue": 1} ) - with warnings.catch_warnings(record=True) as w: + with pytest.warns(SerializationWarning, match="has multiple fill"): actual = conventions.decode_cf_variable("t", var) assert_identical(actual, expected) - assert "has multiple fill" in str(w[0].message) expected = Variable(["t"], np.arange(10), {"units": "foobar"}) @@ -64,7 +63,13 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: np.arange(10), {"units": "foobar", "missing_value": np.nan, "_FillValue": np.nan}, ) - actual = conventions.decode_cf_variable("t", var) + + # the following code issues two warnings, so we need to check for both + with pytest.warns(SerializationWarning) as winfo: + actual = conventions.decode_cf_variable("t", var) + for aw in winfo: + assert "non-conforming" in str(aw.message) + assert_identical(actual, expected) var = Variable( @@ -76,10 +81,37 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: "_FillValue": np.float32(np.nan), }, ) - actual = conventions.decode_cf_variable("t", var) + + # the following code issues two warnings, so we need to check for both + with pytest.warns(SerializationWarning) as winfo: + actual = conventions.decode_cf_variable("t", var) + for aw in winfo: + assert "non-conforming" in str(aw.message) assert_identical(actual, expected) +def test_decode_cf_variable_with_mismatched_coordinates() -> None: + # tests for decoding mismatched coordinates attributes + # see GH #1809 + zeros1 = np.zeros((1, 5, 3)) + orig = Dataset( + { + "XLONG": (["x", "y"], zeros1.squeeze(0), {}), + "XLAT": (["x", "y"], zeros1.squeeze(0), {}), + "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + decoded = conventions.decode_cf(orig, decode_coords=True) + assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"] + + decoded = conventions.decode_cf(orig, decode_coords=False) + assert "coordinates" not in decoded["foo"].encoding + assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["time"] + + @requires_cftime class TestEncodeCFVariable: def test_incompatible_attributes(self) -> None: @@ -129,9 +161,9 @@ def test_multidimensional_coordinates(self) -> None: foo1_coords = enc["foo1"].attrs.get("coordinates", "") foo2_coords = enc["foo2"].attrs.get("coordinates", "") foo3_coords = enc["foo3"].attrs.get("coordinates", "") - assert set(foo1_coords.split()) == {"lat1", "lon1"} - assert set(foo2_coords.split()) == {"lat2", "lon2"} - assert set(foo3_coords.split()) == {"lat3", "lon3"} + assert foo1_coords == "lon1 lat1" + assert foo2_coords == "lon2 lat2" + assert foo3_coords == "lon3 lat3" # Should not have any global coordinates. assert "coordinates" not in attrs @@ -150,11 +182,12 @@ def test_var_with_coord_attr(self) -> None: enc, attrs = conventions.encode_dataset_coordinates(orig) # Make sure we have the right coordinates for each variable. values_coords = enc["values"].attrs.get("coordinates", "") - assert set(values_coords.split()) == {"time", "lat", "lon"} + assert values_coords == "time lon lat" # Should not have any global coordinates. assert "coordinates" not in attrs def test_do_not_overwrite_user_coordinates(self) -> None: + # don't overwrite user-defined "coordinates" encoding orig = Dataset( coords={"x": [0, 1, 2], "y": ("x", [5, 6, 7]), "z": ("x", [8, 9, 10])}, data_vars={"a": ("x", [1, 2, 3]), "b": ("x", [3, 5, 6])}, @@ -168,6 +201,19 @@ def test_do_not_overwrite_user_coordinates(self) -> None: with pytest.raises(ValueError, match=r"'coordinates' found in both attrs"): conventions.encode_dataset_coordinates(orig) + def test_deterministic_coords_encoding(self) -> None: + # the coordinates attribute is sorted when set by xarray.conventions ... + # ... on a variable's coordinates attribute + ds = Dataset({"foo": 0}, coords={"baz": 0, "bar": 0}) + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert vars["foo"].attrs["coordinates"] == "bar baz" + assert attrs.get("coordinates") is None + # ... on the global coordinates attribute + ds = ds.drop_vars("foo") + vars, attrs = conventions.encode_dataset_coordinates(ds) + assert attrs["coordinates"] == "bar baz" + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_emit_coordinates_attribute_in_attrs(self) -> None: orig = Dataset( {"a": 1, "b": 1}, @@ -185,6 +231,7 @@ def test_emit_coordinates_attribute_in_attrs(self) -> None: assert enc["b"].attrs.get("coordinates") == "t" assert "coordinates" not in enc["b"].encoding + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_emit_coordinates_attribute_in_encoding(self) -> None: orig = Dataset( {"a": 1, "b": 1}, @@ -231,9 +278,12 @@ def test_dataset(self) -> None: assert_identical(expected, actual) def test_invalid_coordinates(self) -> None: - # regression test for GH308 + # regression test for GH308, GH1809 original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) + decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) + assert_identical(decoded, actual) + actual = conventions.decode_cf(original, decode_coords=False) assert_identical(original, actual) def test_decode_coordinates(self) -> None: @@ -247,16 +297,15 @@ def test_decode_coordinates(self) -> None: def test_0d_int32_encoding(self) -> None: original = Variable((), np.int32(0), encoding={"dtype": "int64"}) expected = Variable((), np.int64(0)) - actual = conventions.maybe_encode_nonstring_dtype(original) + actual = coding.variables.NonStringCoder().encode(original) assert_identical(expected, actual) def test_decode_cf_with_multiple_missing_values(self) -> None: original = Variable(["t"], [0, 1, 2], {"missing_value": np.array([0, 1])}) expected = Variable(["t"], [np.nan, np.nan, 2], {}) - with warnings.catch_warnings(record=True) as w: + with pytest.warns(SerializationWarning, match="has multiple fill"): actual = conventions.decode_cf_variable("t", original) assert_identical(expected, actual) - assert "has multiple fill" in str(w[0].message) def test_decode_cf_with_drop_variables(self) -> None: original = Dataset( @@ -293,6 +342,17 @@ def test_invalid_time_units_raises_eagerly(self) -> None: with pytest.raises(ValueError, match=r"unable to decode time"): decode_cf(ds) + @pytest.mark.parametrize("decode_times", [True, False]) + def test_invalid_timedelta_units_do_not_decode(self, decode_times) -> None: + # regression test for #8269 + ds = Dataset( + {"time": ("time", [0, 1, 20], {"units": "days invalid", "_FillValue": 20})} + ) + expected = Dataset( + {"time": ("time", [0.0, 1.0, np.nan], {"units": "days invalid"})} + ) + assert_identical(expected, decode_cf(ds, decode_times=decode_times)) + @requires_cftime def test_dataset_repr_with_netcdf4_datetimes(self) -> None: # regression test for #347 @@ -336,7 +396,6 @@ def test_decode_cf_with_dask(self) -> None: } ).chunk() decoded = conventions.decode_cf(original) - print(decoded) assert all( isinstance(var.data, da.Array) for name, var in decoded.variables.items() @@ -444,6 +503,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: pass +@pytest.mark.parametrize( + "data", + [ + np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), + np.array([["x", 1], ["y", 2]], dtype="object"), + ], +) +def test_infer_dtype_error_on_mixed_types(data): + with pytest.raises(ValueError, match="unable to infer dtype on variable"): + conventions._infer_dtype(data, "test") + + class TestDecodeCFVariableWithArrayUnits: def test_decode_cf_variable_with_array_units(self) -> None: v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)}) @@ -485,3 +556,18 @@ def test_decode_cf_error_includes_variable_name(): ds = Dataset({"invalid": ([], 1e36, {"units": "days since 2000-01-01"})}) with pytest.raises(ValueError, match="Failed to decode variable 'invalid'"): decode_cf(ds) + + +def test_encode_cf_variable_with_vlen_dtype() -> None: + v = Variable( + ["x"], np.array(["a", "b"], dtype=coding.strings.create_vlen_dtype(str)) + ) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str + + # empty array + v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str))) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str diff --git a/xarray/tests/test_coordinates.py b/xarray/tests/test_coordinates.py new file mode 100644 index 00000000000..68ce55b05da --- /dev/null +++ b/xarray/tests/test_coordinates.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.indexes import PandasIndex, PandasMultiIndex +from xarray.tests import assert_identical, source_ndarray + + +class TestCoordinates: + def test_init_noindex(self) -> None: + coords = Coordinates(coords={"foo": ("x", [0, 1, 2])}) + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + assert_identical(coords.to_dataset(), expected) + + def test_init_default_index(self) -> None: + coords = Coordinates(coords={"x": [1, 2]}) + expected = Dataset(coords={"x": [1, 2]}) + assert_identical(coords.to_dataset(), expected) + assert "x" in coords.xindexes + + def test_init_no_default_index(self) -> None: + # dimension coordinate with no default index (explicit) + coords = Coordinates(coords={"x": [1, 2]}, indexes={}) + assert "x" not in coords.xindexes + + def test_init_from_coords(self) -> None: + expected = Dataset(coords={"foo": ("x", [0, 1, 2])}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + + # test variables copied + assert coords.variables["foo"] is not expected.variables["foo"] + + # test indexes are extracted + expected = Dataset(coords={"x": [0, 1, 2]}) + coords = Coordinates(coords=expected.coords) + assert_identical(coords.to_dataset(), expected) + assert expected.xindexes == coords.xindexes + + # coords + indexes not supported + with pytest.raises( + ValueError, match="passing both.*Coordinates.*indexes.*not allowed" + ): + coords = Coordinates( + coords=expected.coords, indexes={"x": PandasIndex([0, 1, 2], "x")} + ) + + def test_init_empty(self) -> None: + coords = Coordinates() + assert len(coords) == 0 + + def test_init_index_error(self) -> None: + idx = PandasIndex([1, 2, 3], "x") + with pytest.raises(ValueError, match="no coordinate variables found"): + Coordinates(indexes={"x": idx}) + + with pytest.raises(TypeError, match=".* is not an `xarray.indexes.Index`"): + Coordinates(coords={"x": ("x", [1, 2, 3])}, indexes={"x": "not_an_xarray_index"}) # type: ignore + + def test_init_dim_sizes_conflict(self) -> None: + with pytest.raises(ValueError): + Coordinates(coords={"foo": ("x", [1, 2]), "bar": ("x", [1, 2, 3, 4])}) + + def test_from_pandas_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + assert isinstance(coords.xindexes["x"], PandasMultiIndex) + assert coords.xindexes["x"].index.equals(midx) + assert coords.xindexes["x"].dim == "x" + + expected = PandasMultiIndex(midx, "x").create_variables() + assert list(coords.variables) == list(expected) + for name in ("x", "one", "two"): + assert_identical(expected[name], coords.variables[name]) + + @pytest.mark.filterwarnings("ignore:return type") + def test_dims(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert set(coords.dims) == {"x"} + + def test_sizes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.sizes == {"x": 3} + + def test_dtypes(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert coords.dtypes == {"x": int} + + def test_getitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + assert_identical( + coords["x"], + DataArray([0, 1, 2], coords={"x": [0, 1, 2]}, name="x"), + ) + + def test_delitem(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + del coords["x"] + assert "x" not in coords + + with pytest.raises( + KeyError, match="'nonexistent' is not in coordinate variables" + ): + del coords["nonexistent"] + + def test_update(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + coords.update({"y": ("y", [4, 5, 6])}) + assert "y" in coords + assert "y" in coords.xindexes + expected = DataArray([4, 5, 6], coords={"y": [4, 5, 6]}, name="y") + assert_identical(coords["y"], expected) + + def test_equals(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.equals(coords) + assert not coords.equals("not_a_coords") + + def test_identical(self): + coords = Coordinates(coords={"x": [0, 1, 2]}) + + assert coords.identical(coords) + assert not coords.identical("not_a_coords") + + def test_assign(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + expected = Coordinates(coords={"x": [0, 1, 2], "y": [3, 4]}) + + actual = coords.assign(y=[3, 4]) + assert_identical(actual, expected) + + actual = coords.assign({"y": [3, 4]}) + assert_identical(actual, expected) + + def test_copy(self) -> None: + no_index_coords = Coordinates({"foo": ("x", [1, 2, 3])}) + copied = no_index_coords.copy() + assert_identical(no_index_coords, copied) + v0 = no_index_coords.variables["foo"] + v1 = copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is source_ndarray(v1.data) + + deep_copied = no_index_coords.copy(deep=True) + assert_identical(no_index_coords.to_dataset(), deep_copied.to_dataset()) + v0 = no_index_coords.variables["foo"] + v1 = deep_copied.variables["foo"] + assert v0 is not v1 + assert source_ndarray(v0.data) is not source_ndarray(v1.data) + + def test_align(self) -> None: + coords = Coordinates(coords={"x": [0, 1, 2]}) + + left = coords + + # test Coordinates._reindex_callback + right = coords.to_dataset().isel(x=[0, 1]).coords + left2, right2 = align(left, right, join="inner") + assert_identical(left2, right2) + + # test Coordinates._overwrite_indexes + right.update({"x": ("x", [4, 5, 6])}) + left2, right2 = align(left, right, join="override") + assert_identical(left2, left) + assert_identical(left2, right2) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 21f0ab93d78..517fc0c2d62 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -178,13 +178,24 @@ def test_binary_op(self): self.assertLazyAndIdentical(u + u, v + v) self.assertLazyAndIdentical(u[0] + u, v[0] + v) + def test_binary_op_bitshift(self) -> None: + # bit shifts only work on ints so we need to generate + # new eager and lazy vars + rng = np.random.default_rng(0) + values = rng.integers(low=-10000, high=10000, size=(4, 6)) + data = da.from_array(values, chunks=(2, 2)) + u = Variable(("x", "y"), values) + v = Variable(("x", "y"), data) + self.assertLazyAndIdentical(u << 2, v << 2) + self.assertLazyAndIdentical(u << 5, v << 5) + self.assertLazyAndIdentical(u >> 2, v >> 2) + self.assertLazyAndIdentical(u >> 5, v >> 5) + def test_repr(self): expected = dedent( - """\ - - {!r}""".format( - self.lazy_var.data - ) + f"""\ + Size: 192B + {self.lazy_var.data!r}""" ) assert expected == repr(self.lazy_var) @@ -586,11 +597,11 @@ def test_to_dataset_roundtrip(self): v = self.lazy_array expected = u.assign_coords(x=u["x"]) - self.assertLazyAndEqual(expected, v.to_dataset("x").to_array("x")) + self.assertLazyAndEqual(expected, v.to_dataset("x").to_dataarray("x")) def test_merge(self): def duplicate_and_merge(array): - return xr.merge([array, array.rename("bar")]).to_array() + return xr.merge([array, array.rename("bar")]).to_dataarray() expected = duplicate_and_merge(self.eager_array) actual = duplicate_and_merge(self.lazy_array) @@ -643,14 +654,12 @@ def test_dataarray_repr(self): nonindex_coord = build_dask_array("coord") a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) expected = dedent( - """\ - - {!r} + f"""\ + Size: 8B + {data!r} Coordinates: - y (x) int64 dask.array - Dimensions without coordinates: x""".format( - data - ) + y (x) int64 8B dask.array + Dimensions without coordinates: x""" ) assert expected == repr(a) assert kernel_call_count == 0 # should not evaluate dask array @@ -661,13 +670,13 @@ def test_dataset_repr(self): ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) expected = dedent( """\ - + Size: 16B Dimensions: (x: 1) Coordinates: - y (x) int64 dask.array + y (x) int64 8B dask.array Dimensions without coordinates: x Data variables: - a (x) int64 dask.array""" + a (x) int64 8B dask.array""" ) assert expected == repr(ds) assert kernel_call_count == 0 # should not evaluate dask array @@ -782,7 +791,7 @@ def test_to_dask_dataframe(self): assert isinstance(actual, dd.DataFrame) # use the .equals from pandas to check dataframes are equivalent - assert_frame_equal(expected.compute(), actual.compute()) + assert_frame_equal(actual.compute(), expected.compute()) # test if no index is given expected = dd.from_pandas(expected_pd.reset_index(drop=False), chunksize=4) @@ -790,8 +799,12 @@ def test_to_dask_dataframe(self): actual = ds.to_dask_dataframe(set_index=False) assert isinstance(actual, dd.DataFrame) - assert_frame_equal(expected.compute(), actual.compute()) + assert_frame_equal(actual.compute(), expected.compute()) + @pytest.mark.xfail( + reason="Currently pandas with pyarrow installed will return a `string[pyarrow]` type, " + "which causes the `y` column to have a different type depending on whether pyarrow is installed" + ) def test_to_dask_dataframe_2D(self): # Test if 2-D dataset is supplied w = np.random.randn(2, 3) @@ -810,7 +823,7 @@ def test_to_dask_dataframe_2D(self): actual = ds.to_dask_dataframe(set_index=False) assert isinstance(actual, dd.DataFrame) - assert_frame_equal(expected, actual.compute()) + assert_frame_equal(actual.compute(), expected) @pytest.mark.xfail(raises=NotImplementedError) def test_to_dask_dataframe_2D_set_index(self): @@ -843,6 +856,10 @@ def test_to_dask_dataframe_coordinates(self): assert isinstance(actual, dd.DataFrame) assert_frame_equal(expected.compute(), actual.compute()) + @pytest.mark.xfail( + reason="Currently pandas with pyarrow installed will return a `string[pyarrow]` type, " + "which causes the index to have a different type depending on whether pyarrow is installed" + ) def test_to_dask_dataframe_not_daskarray(self): # Test if DataArray is not a dask array x = np.random.randn(10) @@ -891,13 +908,12 @@ def test_to_dask_dataframe_dim_order(self): @pytest.mark.parametrize("method", ["load", "compute"]) def test_dask_kwargs_variable(method): - x = Variable("y", da.from_array(np.arange(3), chunks=(2,))) - # args should be passed on to da.Array.compute() - with mock.patch.object( - da.Array, "compute", return_value=np.arange(3) - ) as mock_compute: + chunked_array = da.from_array(np.arange(3), chunks=(2,)) + x = Variable("y", chunked_array) + # args should be passed on to dask.compute() (via DaskManager.compute()) + with mock.patch.object(da, "compute", return_value=(np.arange(3),)) as mock_compute: getattr(x, method)(foo="bar") - mock_compute.assert_called_with(foo="bar") + mock_compute.assert_called_with(chunked_array, foo="bar") @pytest.mark.parametrize("method", ["load", "compute", "persist"]) @@ -1287,12 +1303,12 @@ def test_map_blocks_kwargs(obj): assert_identical(actual, expected) -def test_map_blocks_to_array(map_ds): +def test_map_blocks_to_dataarray(map_ds): with raise_if_dask_computes(): - actual = xr.map_blocks(lambda x: x.to_array(), map_ds) + actual = xr.map_blocks(lambda x: x.to_dataarray(), map_ds) - # to_array does not preserve name, so cannot use assert_identical - assert_equal(actual, map_ds.to_array()) + # to_dataarray does not preserve name, so cannot use assert_identical + assert_equal(actual, map_ds.to_dataarray()) @pytest.mark.parametrize( @@ -1348,6 +1364,25 @@ def test_map_blocks_da_ds_with_template(obj): assert_identical(actual, template) +def test_map_blocks_roundtrip_string_index(): + ds = xr.Dataset( + {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]} + ).chunk(label=1) + assert ds.label.dtype == np.dtype(" None: v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) - coords = {"x": np.arange(3, dtype=np.int64), "other": np.int64(0)} + v = v.astype(np.uint64) + coords = {"x": np.arange(3, dtype=np.uint64), "other": np.uint64(0)} data_array = DataArray(v, coords, name="my_variable") expected = dedent( """\ - + Size: 48B array([[1, 2, 3], - [4, 5, 6]]) + [4, 5, 6]], dtype=uint64) Coordinates: - * x (x) int64 0 1 2 - other int64 0 + * x (x) uint64 24B 0 1 2 + other uint64 8B 0 Dimensions without coordinates: time Attributes: foo: bar""" @@ -101,12 +112,12 @@ def test_repr(self) -> None: def test_repr_multiindex(self) -> None: expected = dedent( """\ - - array([0, 1, 2, 3]) + Size: 32B + array([0, 1, 2, 3], dtype=uint64) Coordinates: - * x (x) object MultiIndex - * level_1 (x) object 'a' 'a' 'b' 'b' - * level_2 (x) int64 1 2 1 2""" + * x (x) object 32B MultiIndex + * level_1 (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2""" ) assert expected == repr(self.mda) @@ -115,16 +126,19 @@ def test_repr_multiindex_long(self) -> None: [["a", "b", "c", "d"], [1, 2, 3, 4, 5, 6, 7, 8]], names=("level_1", "level_2"), ) - mda_long = DataArray(list(range(32)), coords={"x": mindex_long}, dims="x") + mda_long = DataArray( + list(range(32)), coords={"x": mindex_long}, dims="x" + ).astype(np.uint64) expected = dedent( """\ - + Size: 256B array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + dtype=uint64) Coordinates: - * x (x) object MultiIndex - * level_1 (x) object 'a' 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' - * level_2 (x) int64 1 2 3 4 5 6 7 8 1 2 3 4 5 6 ... 4 5 6 7 8 1 2 3 4 5 6 7 8""" + * x (x) object 256B MultiIndex + * level_1 (x) object 256B 'a' 'a' 'a' 'a' 'a' 'a' ... 'd' 'd' 'd' 'd' 'd' 'd' + * level_2 (x) int64 256B 1 2 3 4 5 6 7 8 1 2 3 4 ... 5 6 7 8 1 2 3 4 5 6 7 8""" ) assert expected == repr(mda_long) @@ -278,6 +292,25 @@ def test_encoding(self) -> None: self.dv.encoding = expected2 assert expected2 is not self.dv.encoding + def test_drop_encoding(self) -> None: + array = self.mda + encoding = {"scale_factor": 10} + array.encoding = encoding + array["x"].encoding = encoding + + assert array.encoding == encoding + assert array["x"].encoding == encoding + + actual = array.drop_encoding() + + # did not modify in place + assert array.encoding == encoding + assert array["x"].encoding == encoding + + # variable and coord encoding is empty + assert actual.encoding == {} + assert actual["x"].encoding == {} + def test_constructor(self) -> None: data = np.random.random((2, 3)) @@ -375,8 +408,8 @@ def test_constructor_invalid(self) -> None: with pytest.raises(ValueError, match=r"not a subset of the .* dim"): DataArray(data, {"x": [0, 1, 2]}) - with pytest.raises(TypeError, match=r"is not a string"): - DataArray(data, dims=["x", None]) + with pytest.raises(TypeError, match=r"is not hashable"): + DataArray(data, dims=["x", []]) # type: ignore[list-item] with pytest.raises(ValueError, match=r"conflicting sizes for dim"): DataArray([1, 2, 3], coords=[("x", [0, 1])]) @@ -388,9 +421,6 @@ def test_constructor_invalid(self) -> None: with pytest.raises(ValueError, match=r"conflicting MultiIndex"): DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) - with pytest.raises(ValueError, match=r"matching the dimension size"): - DataArray(data, coords={"x": 0}, dims=["x", "y"]) - def test_constructor_from_self_described(self) -> None: data = [[-0.1, 21], [0, 2]] expected = DataArray( @@ -467,6 +497,31 @@ def test_constructor_dask_coords(self) -> None: expected = DataArray(data, coords={"x": ecoord, "y": ecoord}, dims=["x", "y"]) assert_equal(actual, expected) + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + da = DataArray(range(3), coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in da.coords + assert "x" not in da.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + da = DataArray(range(4), coords=coords, dims="x") + assert_identical(da.coords, coords) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = DataArray(range(3), coords=coords) + assert isinstance(da.xindexes["x"], CustomIndex) + + # test coordinate variables copied + assert da.coords["x"] is not coords.variables["x"] + def test_equals_and_identical(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") @@ -788,6 +843,27 @@ def get_data(): ) da[dict(x=ind)] = value # should not raise + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v[index] = w + assert (v[index] == w).all() + + # Indexing with coordinates + v = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + v.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + v.loc[index] = w + assert (v.loc[index] == w).all() + def test_contains(self) -> None: data_array = DataArray([1, 2]) assert 1 in data_array @@ -807,13 +883,14 @@ def test_chunk(self) -> None: assert blocked.chunks == ((3,), (4,)) first_dask_name = blocked.data.name - blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) - assert blocked.chunks == ((2, 1), (2, 2)) - assert blocked.data.name != first_dask_name + with pytest.warns(DeprecationWarning): + blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) # type: ignore + assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name - blocked = unblocked.chunk(chunks=(3, 3)) - assert blocked.chunks == ((3,), (3, 1)) - assert blocked.data.name != first_dask_name + blocked = unblocked.chunk(chunks=(3, 3)) + assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name # name doesn't change when rechunking by same amount # this fails if ReprObject doesn't have __dask_tokenize__ defined @@ -1004,32 +1081,53 @@ def test_sel_dataarray_datetime_slice(self) -> None: result = array.sel(delta=slice(array.delta[0], array.delta[-1])) assert_equal(result, array) - def test_sel_float(self) -> None: + @pytest.mark.parametrize( + ["coord_values", "indices"], + ( + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float64"), + slice(1, 3), + id="float64", + ), + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"), + slice(1, 3), + id="float32", + ), + pytest.param( + np.array([0.0, 0.111, 0.222, 0.333], dtype="float32"), [2], id="scalar" + ), + ), + ) + def test_sel_float(self, coord_values, indices) -> None: data_values = np.arange(4) - # case coords are float32 and label is list of floats - float_values = [0.0, 0.111, 0.222, 0.333] - coord_values = np.asarray(float_values, dtype="float32") - array = DataArray(data_values, [("float32_coord", coord_values)]) - expected = DataArray(data_values[1:3], [("float32_coord", coord_values[1:3])]) - actual = array.sel(float32_coord=float_values[1:3]) - # case coords are float16 and label is list of floats - coord_values_16 = np.asarray(float_values, dtype="float16") - expected_16 = DataArray( - data_values[1:3], [("float16_coord", coord_values_16[1:3])] - ) - array_16 = DataArray(data_values, [("float16_coord", coord_values_16)]) - actual_16 = array_16.sel(float16_coord=float_values[1:3]) + arr = DataArray(data_values, coords={"x": coord_values}, dims="x") - # case coord, label are scalars - expected_scalar = DataArray( - data_values[2], coords={"float32_coord": coord_values[2]} + actual = arr.sel(x=coord_values[indices]) + expected = DataArray( + data_values[indices], coords={"x": coord_values[indices]}, dims="x" ) - actual_scalar = array.sel(float32_coord=float_values[2]) - assert_equal(expected, actual) - assert_equal(expected_scalar, actual_scalar) - assert_equal(expected_16, actual_16) + assert_equal(actual, expected) + + def test_sel_float16(self) -> None: + data_values = np.arange(4) + coord_values = np.array([0.0, 0.111, 0.222, 0.333], dtype="float16") + indices = slice(1, 3) + + message = "`pandas.Index` does not support the `float16` dtype.*" + + with pytest.warns(DeprecationWarning, match=message): + arr = DataArray(data_values, coords={"x": coord_values}, dims="x") + with pytest.warns(DeprecationWarning, match=message): + expected = DataArray( + data_values[indices], coords={"x": coord_values[indices]}, dims="x" + ) + + actual = arr.sel(x=coord_values[indices]) + + assert_equal(actual, expected) def test_sel_float_multiindex(self) -> None: # regression test https://github.com/pydata/xarray/issues/5691 @@ -1353,8 +1451,8 @@ def test_coords(self) -> None: expected_repr = dedent( """\ Coordinates: - * x (x) int64 -1 -2 - * y (y) int64 0 1 2""" + * x (x) int64 16B -1 -2 + * y (y) int64 24B 0 1 2""" ) actual = repr(da.coords) assert expected_repr == actual @@ -1368,7 +1466,7 @@ def test_coords(self) -> None: assert_identical(da, expected) with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda["level_1"] = ("x", np.arange(4)) self.mda.coords["level_1"] = ("x", np.arange(4)) @@ -1488,7 +1586,7 @@ def test_assign_coords(self) -> None: assert_identical(actual, expected) with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda.assign_coords(level_1=("x", range(4))) @@ -1503,9 +1601,29 @@ def test_assign_coords(self) -> None: def test_assign_coords_existing_multiindex(self) -> None: data = self.mda - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): data.assign_coords(x=range(4)) + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + da = xr.DataArray([0, 1, 2], dims="x") + actual = da.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + da = DataArray([1, 2, 3], dims="y") + actual = da.assign_coords(coords) + assert_identical(actual.coords, coords, check_default_indexes=False) + assert "y" not in actual.xindexes + def test_coords_alignment(self) -> None: lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])]) rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])]) @@ -1523,7 +1641,7 @@ def test_set_coords_update_index(self) -> None: def test_set_coords_multiindex_level(self) -> None: with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " + ValueError, match=r"cannot drop or update coordinate.*corrupt.*index " ): self.mda["level_1"] = range(4) @@ -1658,6 +1776,19 @@ def test_reindex_str_dtype(self, dtype) -> None: assert_identical(expected, actual) assert actual.dtype == expected.dtype + def test_reindex_empty_array_dtype(self) -> None: + # Dtype of reindex result should match dtype of the original DataArray. + # See GH issue #7299 + x = xr.DataArray([], dims=("x",), coords={"x": []}).astype("float32") + y = x.reindex(x=[1.0, 2.0]) + + assert ( + x.dtype == y.dtype + ), "Dtype of reindexed DataArray should match dtype of the original DataArray" + assert ( + y.dtype == np.float32 + ), "Dtype of reindexed DataArray should remain float32" + def test_rename(self) -> None: da = xr.DataArray( [1, 2, 3], dims="dim", name="name", coords={"coord": ("dim", [5, 6, 7])} @@ -1757,6 +1888,16 @@ def test_rename_dimension_coord_warnings(self) -> None: ): da.rename(x="y") + # No operation should not raise a warning + da = xr.DataArray( + data=np.ones((2, 3)), + dims=["x", "y"], + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + da.rename(x="x") + def test_init_value(self) -> None: expected = DataArray( np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)] @@ -2392,6 +2533,28 @@ def test_unstack_pandas_consistency(self) -> None: actual = DataArray(s, dims="z").unstack("z") assert_identical(expected, actual) + def test_unstack_requires_unique(self) -> None: + df = pd.DataFrame({"foo": range(2), "x": ["a", "a"], "y": [0, 0]}) + s = df.set_index(["x", "y"])["foo"] + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + DataArray(s, dims="z").unstack("z") + + @pytest.mark.filterwarnings("error") + def test_unstack_roundtrip_integer_array(self) -> None: + arr = xr.DataArray( + np.arange(6).reshape(2, 3), + coords={"x": ["a", "b"], "y": [0, 1, 2]}, + dims=["x", "y"], + ) + + stacked = arr.stack(z=["x", "y"]) + roundtripped = stacked.unstack() + + assert_identical(arr, roundtripped) + def test_stack_nonunique_consistency(self, da) -> None: da = da.isel(time=0, drop=True) # 2D actual = da.stack(z=["a", "x"]) @@ -2504,6 +2667,14 @@ def test_drop_coordinates(self) -> None: actual = renamed.drop_vars("foo", errors="ignore") assert_identical(actual, renamed) + def test_drop_vars_callable(self) -> None: + A = DataArray( + np.random.randn(2, 3), dims=["x", "y"], coords={"x": [1, 2], "y": [3, 4, 5]} + ) + expected = A.drop_vars(["x", "y"]) + actual = A.drop_vars(lambda x: x.indexes) + assert_identical(expected, actual) + def test_drop_multiindex_level(self) -> None: # GH6505 expected = self.mda.drop_vars(["x", "level_1", "level_2"]) @@ -2578,6 +2749,14 @@ def test_where_lambda(self) -> None: actual = arr.where(lambda x: x.y < 2, drop=True) assert_identical(actual, expected) + def test_where_other_lambda(self) -> None: + arr = DataArray(np.arange(4), dims="y") + expected = xr.concat( + [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y" + ) + actual = arr.where(lambda x: x.y < 2, lambda x: x + 1) + assert_identical(actual, expected) + def test_where_string(self) -> None: array = DataArray(["a", "b"]) expected = DataArray(np.array(["a", np.nan], dtype=object)) @@ -2725,14 +2904,15 @@ def test_reduce_out(self) -> None: with pytest.raises(TypeError): orig.mean(out=np.ones(orig.shape)) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) @pytest.mark.parametrize( "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) ) - def test_quantile(self, q, axis, dim, skipna) -> None: + def test_quantile(self, q, axis, dim, skipna, compute_backend) -> None: va = self.va.copy(deep=True) - va[0, 0] = np.NaN + va[0, 0] = np.nan actual = DataArray(va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna) _percentile_func = np.nanpercentile if skipna in (True, None) else np.percentile @@ -2750,10 +2930,7 @@ def test_quantile_method(self, method) -> None: q = [0.25, 0.5, 0.75] actual = DataArray(self.va).quantile(q, method=method) - if Version(np.__version__) >= Version("1.22.0"): - expected = np.nanquantile(self.dv.values, np.array(q), method=method) - else: - expected = np.nanquantile(self.dv.values, np.array(q), interpolation=method) + expected = np.nanquantile(self.dv.values, np.array(q), method=method) np.testing.assert_allclose(actual.values, expected) @@ -3013,10 +3190,10 @@ def test_align_str_dtype(self) -> None: b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]}) expected_a = DataArray( - [0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]} + [0, 1, np.nan], dims=["x"], coords={"x": ["a", "b", "c"]} ) expected_b = DataArray( - [np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} + [np.nan, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -3027,6 +3204,42 @@ def test_align_str_dtype(self) -> None: assert_identical(expected_b, actual_b) assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast_on_vs_off_global_option_different_dims(self) -> None: + xda_1 = xr.DataArray([1], dims="x1") + xda_2 = xr.DataArray([1], dims="x2") + + with xr.set_options(arithmetic_broadcast=True): + expected_xda = xr.DataArray([[1.0]], dims=("x1", "x2")) + actual_xda = xda_1 / xda_2 + assert_identical(actual_xda, expected_xda) + + with xr.set_options(arithmetic_broadcast=False): + with pytest.raises( + ValueError, + match=re.escape( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ), + ): + xda_1 / xda_2 + + @pytest.mark.parametrize("arithmetic_broadcast", [True, False]) + def test_broadcast_on_vs_off_global_option_same_dims( + self, arithmetic_broadcast: bool + ) -> None: + # Ensure that no error is raised when arithmetic broadcasting is disabled, + # when broadcasting is not needed. The two DataArrays have the same + # dimensions of the same size. + xda_1 = xr.DataArray([1], dims="x") + xda_2 = xr.DataArray([1], dims="x") + expected_xda = xr.DataArray([2.0], dims=("x",)) + + with xr.set_options(arithmetic_broadcast=arithmetic_broadcast): + assert_identical(xda_1 + xda_2, expected_xda) + assert_identical(xda_1 + np.array([1.0]), expected_xda) + assert_identical(np.array([1.0]) + xda_1, expected_xda) + def test_broadcast_arrays(self) -> None: x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") @@ -3205,6 +3418,41 @@ def test_to_dataframe_0length(self) -> None: assert len(actual) == 0 assert_array_equal(actual.index.names, list("ABC")) + @requires_dask_expr + @requires_dask + @pytest.mark.xfail(reason="dask-expr is broken") + def test_to_dask_dataframe(self) -> None: + arr_np = np.arange(3 * 4).reshape(3, 4) + arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") + expected = arr.to_series() + actual = arr.to_dask_dataframe()["foo"] + + assert_array_equal(actual.values, expected.values) + + actual = arr.to_dask_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + + # regression test for coords with different dimensions + + arr.coords["C"] = ("B", [-1, -2, -3]) + expected = arr.to_series().to_frame() + expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected = expected[["C", "foo"]] + actual = arr.to_dask_dataframe()[["C", "foo"]] + + assert_array_equal(expected.values, actual.values) + assert_array_equal(expected.columns.values, actual.columns.values) + + with pytest.raises(ValueError, match="does not match the set of dimensions"): + arr.to_dask_dataframe(dim_order=["B", "A", "C"]) + + arr.name = None + with pytest.raises( + ValueError, + match="Cannot convert an unnamed DataArray", + ): + arr.to_dask_dataframe() + def test_to_pandas_name_matches_coordinate(self) -> None: # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x") @@ -3305,46 +3553,70 @@ def test_series_categorical_index(self) -> None: arr = DataArray(s) assert "'a'" in repr(arr) # should not error + @pytest.mark.parametrize("use_dask", [True, False]) + @pytest.mark.parametrize("data", ["list", "array", True]) @pytest.mark.parametrize("encoding", [True, False]) - def test_to_and_from_dict(self, encoding) -> None: + def test_to_and_from_dict( + self, encoding: bool, data: bool | Literal["list", "array"], use_dask: bool + ) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + encoding_data = {"bar": "spam"} array = DataArray( np.random.randn(2, 3), {"x": ["a", "b"]}, ["x", "y"], name="foo" ) - array.encoding = {"bar": "spam"} - expected = { + array.encoding = encoding_data + + return_data = array.to_numpy() + coords_data = np.array(["a", "b"]) + if data == "list" or data is True: + return_data = return_data.tolist() + coords_data = coords_data.tolist() + + expected: dict[str, Any] = { "name": "foo", "dims": ("x", "y"), - "data": array.values.tolist(), + "data": return_data, "attrs": {}, - "coords": {"x": {"dims": ("x",), "data": ["a", "b"], "attrs": {}}}, + "coords": {"x": {"dims": ("x",), "data": coords_data, "attrs": {}}}, } if encoding: - expected["encoding"] = {"bar": "spam"} - actual = array.to_dict(encoding=encoding) + expected["encoding"] = encoding_data + + if has_dask: + da = array.chunk() + else: + da = array + + if data == "array" or data is False: + with raise_if_dask_computes(): + actual = da.to_dict(encoding=encoding, data=data) + else: + actual = da.to_dict(encoding=encoding, data=data) # check that they are identical - assert expected == actual + np.testing.assert_equal(expected, actual) # check roundtrip - assert_identical(array, DataArray.from_dict(actual)) + assert_identical(da, DataArray.from_dict(actual)) # a more bare bones representation still roundtrips d = { "name": "foo", "dims": ("x", "y"), - "data": array.values.tolist(), + "data": da.values.tolist(), "coords": {"x": {"dims": "x", "data": ["a", "b"]}}, } - assert_identical(array, DataArray.from_dict(d)) + assert_identical(da, DataArray.from_dict(d)) # and the most bare bones representation still roundtrips - d = {"name": "foo", "dims": ("x", "y"), "data": array.values} - assert_identical(array.drop_vars("x"), DataArray.from_dict(d)) + d = {"name": "foo", "dims": ("x", "y"), "data": da.values} + assert_identical(da.drop_vars("x"), DataArray.from_dict(d)) # missing a dims in the coords d = { "dims": ("x", "y"), - "data": array.values, + "data": da.values, "coords": {"x": {"data": ["a", "b"]}}, } with pytest.raises( @@ -3367,7 +3639,7 @@ def test_to_and_from_dict(self, encoding) -> None: endiantype = "U1" expected_no_data["coords"]["x"].update({"dtype": endiantype, "shape": (2,)}) expected_no_data.update({"dtype": "float64", "shape": (2, 3)}) - actual_no_data = array.to_dict(data=False, encoding=encoding) + actual_no_data = da.to_dict(data=False, encoding=encoding) assert expected_no_data == actual_no_data def test_to_and_from_dict_with_time_dim(self) -> None: @@ -3451,95 +3723,6 @@ def test_to_masked_array(self) -> None: ma = da.to_masked_array() assert len(ma.mask) == N - def test_to_and_from_cdms2_classic(self) -> None: - """Classic with 1D axes""" - pytest.importorskip("cdms2") - - original = DataArray( - np.arange(6).reshape(2, 3), - [ - ("distance", [-2, 2], {"units": "meters"}), - ("time", pd.date_range("2000-01-01", periods=3)), - ], - name="foo", - attrs={"baz": 123}, - ) - expected_coords = [ - IndexVariable("distance", [-2, 2]), - IndexVariable("time", [0, 1, 2]), - ] - actual = original.to_cdms2() - assert_array_equal(actual.asma(), original) - assert actual.id == original.name - assert tuple(actual.getAxisIds()) == original.dims - for axis, coord in zip(actual.getAxisList(), expected_coords): - assert axis.id == coord.name - assert_array_equal(axis, coord.values) - assert actual.baz == original.attrs["baz"] - - component_times = actual.getAxis(1).asComponentTime() - assert len(component_times) == 3 - assert str(component_times[0]) == "2000-1-1 0:0:0.0" - - roundtripped = DataArray.from_cdms2(actual) - assert_identical(original, roundtripped) - - back = from_cdms2(actual) - assert original.dims == back.dims - assert original.coords.keys() == back.coords.keys() - for coord_name in original.coords.keys(): - assert_array_equal(original.coords[coord_name], back.coords[coord_name]) - - def test_to_and_from_cdms2_sgrid(self) -> None: - """Curvilinear (structured) grid - - The rectangular grid case is covered by the classic case - """ - pytest.importorskip("cdms2") - - lonlat = np.mgrid[:3, :4] - lon = DataArray(lonlat[1], dims=["y", "x"], name="lon") - lat = DataArray(lonlat[0], dims=["y", "x"], name="lat") - x = DataArray(np.arange(lon.shape[1]), dims=["x"], name="x") - y = DataArray(np.arange(lon.shape[0]), dims=["y"], name="y") - original = DataArray( - lonlat.sum(axis=0), - dims=["y", "x"], - coords=dict(x=x, y=y, lon=lon, lat=lat), - name="sst", - ) - actual = original.to_cdms2() - assert tuple(actual.getAxisIds()) == original.dims - assert_array_equal(original.coords["lon"], actual.getLongitude().asma()) - assert_array_equal(original.coords["lat"], actual.getLatitude().asma()) - - back = from_cdms2(actual) - assert original.dims == back.dims - assert set(original.coords.keys()) == set(back.coords.keys()) - assert_array_equal(original.coords["lat"], back.coords["lat"]) - assert_array_equal(original.coords["lon"], back.coords["lon"]) - - def test_to_and_from_cdms2_ugrid(self) -> None: - """Unstructured grid""" - pytest.importorskip("cdms2") - - lon = DataArray(np.random.uniform(size=5), dims=["cell"], name="lon") - lat = DataArray(np.random.uniform(size=5), dims=["cell"], name="lat") - cell = DataArray(np.arange(5), dims=["cell"], name="cell") - original = DataArray( - np.arange(5), dims=["cell"], coords={"lon": lon, "lat": lat, "cell": cell} - ) - actual = original.to_cdms2() - assert tuple(actual.getAxisIds()) == original.dims - assert_array_equal(original.coords["lon"], actual.getLongitude().getValue()) - assert_array_equal(original.coords["lat"], actual.getLatitude().getValue()) - - back = from_cdms2(actual) - assert set(original.dims) == set(back.dims) - assert set(original.coords.keys()) == set(back.coords.keys()) - assert_array_equal(original.coords["lat"], back.coords["lat"]) - assert_array_equal(original.coords["lon"], back.coords["lon"]) - def test_to_dataset_whole(self) -> None: unnamed = DataArray([1, 2], dims="x") with pytest.raises(ValueError, match=r"unable to convert unnamed"): @@ -3565,15 +3748,23 @@ def test_to_dataset_whole(self) -> None: actual = named.to_dataset("bar") def test_to_dataset_split(self) -> None: - array = DataArray([1, 2, 3], coords=[("x", list("abc"))], attrs={"a": 1}) - expected = Dataset({"a": 1, "b": 2, "c": 3}, attrs={"a": 1}) + array = DataArray( + [[1, 2], [3, 4], [5, 6]], + coords=[("x", list("abc")), ("y", [0.0, 0.1])], + attrs={"a": 1}, + ) + expected = Dataset( + {"a": ("y", [1, 2]), "b": ("y", [3, 4]), "c": ("y", [5, 6])}, + coords={"y": [0.0, 0.1]}, + attrs={"a": 1}, + ) actual = array.to_dataset("x") assert_identical(expected, actual) with pytest.raises(TypeError): array.to_dataset("x", name="foo") - roundtripped = actual.to_array(dim="x") + roundtripped = actual.to_dataarray(dim="x") assert_identical(array, roundtripped) array = DataArray([1, 2, 3], dims="x") @@ -3590,10 +3781,55 @@ def test_to_dataset_retains_keys(self) -> None: array = DataArray([1, 2, 3], coords=[("x", dates)], attrs={"a": 1}) # convert to dateset and back again - result = array.to_dataset("x").to_array(dim="x") + result = array.to_dataset("x").to_dataarray(dim="x") assert_equal(array, result) + def test_to_dataset_coord_value_is_dim(self) -> None: + # github issue #7823 + + array = DataArray( + np.zeros((3, 3)), + coords={ + # 'a' is both a coordinate value and the name of a coordinate + "x": ["a", "b", "c"], + "a": [1, 2, 3], + }, + ) + + with pytest.raises( + ValueError, + match=( + re.escape("dimension 'x' would produce the variables ('a',)") + + ".*" + + re.escape("DataArray.rename(a=...) or DataArray.assign_coords(x=...)") + ), + ): + array.to_dataset("x") + + # test error message formatting when there are multiple ambiguous + # values/coordinates + array2 = DataArray( + np.zeros((3, 3, 2)), + coords={ + "x": ["a", "b", "c"], + "a": [1, 2, 3], + "b": [0.0, 0.1], + }, + ) + + with pytest.raises( + ValueError, + match=( + re.escape("dimension 'x' would produce the variables ('a', 'b')") + + ".*" + + re.escape( + "DataArray.rename(a=..., b=...) or DataArray.assign_coords(x=...)" + ) + ), + ): + array2.to_dataset("x") + def test__title_for_slice(self) -> None: array = DataArray( np.ones((4, 3, 2)), @@ -3790,17 +4026,17 @@ def test_dot(self) -> None: assert_equal(expected3, actual3) # Ellipsis: all dims are shared - actual4 = da.dot(da, dims=...) + actual4 = da.dot(da, dim=...) expected4 = da.dot(da) assert_equal(expected4, actual4) # Ellipsis: not all dims are shared - actual5 = da.dot(dm3, dims=...) - expected5 = da.dot(dm3, dims=("j", "x", "y", "z")) + actual5 = da.dot(dm3, dim=...) + expected5 = da.dot(dm3, dim=("j", "x", "y", "z")) assert_equal(expected5, actual5) with pytest.raises(NotImplementedError): - da.dot(dm3.to_dataset(name="dm")) # type: ignore + da.dot(dm3.to_dataset(name="dm")) with pytest.raises(TypeError): da.dot(dm3.values) # type: ignore @@ -3886,6 +4122,11 @@ def test_binary_op_propagate_indexes(self) -> None: actual = (self.dv > 10).xindexes["x"] assert expected is actual + # use mda for bitshift test as it's type int + actual = (self.mda << 2).xindexes["x"] + expected = self.mda.xindexes["x"] + assert expected is actual + def test_binary_op_join_setting(self) -> None: dim = "x" align_type: Final = "outer" @@ -4000,9 +4241,7 @@ def test_polyfit(self, use_dask, use_datetime) -> None: xcoord = x da_raw = DataArray( - np.stack( - (10 + 1e-15 * x + 2e-28 * x**2, 30 + 2e-14 * x + 1e-29 * x**2) - ), + np.stack((10 + 1e-15 * x + 2e-28 * x**2, 30 + 2e-14 * x + 1e-29 * x**2)), dims=("d", "x"), coords={"x": xcoord, "d": [0, 1]}, ) @@ -4022,7 +4261,7 @@ def test_polyfit(self, use_dask, use_datetime) -> None: # Full output and deficient rank with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) out = da.polyfit("x", 12, full=True) assert out.polyfit_residuals.isnull().all() @@ -4043,7 +4282,7 @@ def test_polyfit(self, use_dask, use_datetime) -> None: np.testing.assert_almost_equal(out.polyfit_residuals, [0, 0]) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.RankWarning) + warnings.simplefilter("ignore", RankWarning) out = da.polyfit("x", 8, full=True) np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False]) @@ -4064,7 +4303,7 @@ def test_pad_constant(self) -> None: ar = xr.DataArray([9], dims="x") actual = ar.pad(x=1) - expected = xr.DataArray([np.NaN, 9, np.NaN], dims="x") + expected = xr.DataArray([np.nan, 9, np.nan], dims="x") assert_identical(actual, expected) actual = ar.pad(x=1, constant_values=1.23456) @@ -4072,7 +4311,7 @@ def test_pad_constant(self) -> None: assert_identical(actual, expected) with pytest.raises(ValueError, match="cannot convert float NaN to integer"): - ar.pad(x=1, constant_values=np.NaN) + ar.pad(x=1, constant_values=np.nan) def test_pad_coords(self) -> None: ar = DataArray( @@ -4297,7 +4536,7 @@ def exp_decay(t, n0, tau=1): da = da.chunk({"x": 1}) fit = da.curvefit( - coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": [2, 6]} + coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": (2, 6)} ) assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) @@ -4318,12 +4557,183 @@ def exp_decay(t, n0, tau=1): assert param_defaults == {"n0": 4, "tau": 6} assert bounds_defaults == {"n0": (-np.inf, np.inf), "tau": (5, np.inf)} + # DataArray as bound + param_defaults, bounds_defaults = xr.core.dataset._initialize_curvefit_params( + params=params, + p0={"n0": 4}, + bounds={"tau": [DataArray([3, 4], coords=[("x", [1, 2])]), np.inf]}, + func_args=func_args, + ) + assert param_defaults["n0"] == 4 + assert ( + param_defaults["tau"] == xr.DataArray([4, 5], coords=[("x", [1, 2])]) + ).all() + assert bounds_defaults["n0"] == (-np.inf, np.inf) + assert ( + bounds_defaults["tau"][0] == DataArray([3, 4], coords=[("x", [1, 2])]) + ).all() + assert bounds_defaults["tau"][1] == np.inf + param_names = ["a"] params, func_args = xr.core.dataset._get_func_args(np.power, param_names) assert params == param_names with pytest.raises(ValueError): xr.core.dataset._get_func_args(np.power, []) + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_multidimensional_guess(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + # Fitting to a sine curve produces a different result depending on the + # initial guess: either the phase is zero and the amplitude is positive + # or the phase is 0.5 * 2pi and the amplitude is negative. + + expected = DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + # Different initial guesses for different values of x + a_guess = DataArray([1, -1], coords=[da.x]) + p_guess = DataArray([0, 0.5], coords=[da.x]) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], + func=sine, + p0={"a": a_guess, "p": p_guess, "f": 2}, + ) + assert_allclose(fit.curvefit_coefficients, expected) + + with pytest.raises( + ValueError, + match=r"Initial guess for 'a' has unexpected dimensions .* should only have " + "dimensions that are in data dimensions", + ): + # initial guess with additional dimensions should be an error + da.curvefit( + coords=[da.t], + func=sine, + p0={"a": DataArray([1, 2], coords={"foo": [1, 2]})}, + ) + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_multidimensional_bounds(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = xr.DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + # Fit a sine with different bounds: positive amplitude should result in a fit with + # phase 0 and negative amplitude should result in phase 0.5 * 2pi. + + expected = DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], + func=sine, + p0={"f": 2, "p": 0.25}, # this guess is needed to get the expected result + bounds={ + "a": ( + DataArray([0, -2], coords=[da.x]), + DataArray([2, 0], coords=[da.x]), + ), + }, + ) + assert_allclose(fit.curvefit_coefficients, expected) + + # Scalar lower bound with array upper bound + fit2 = da.curvefit( + coords=[da.t], + func=sine, + p0={"f": 2, "p": 0.25}, # this guess is needed to get the expected result + bounds={ + "a": (-2, DataArray([2, 0], coords=[da.x])), + }, + ) + assert_allclose(fit2.curvefit_coefficients, expected) + + with pytest.raises( + ValueError, + match=r"Upper bound for 'a' has unexpected dimensions .* should only have " + "dimensions that are in data dimensions", + ): + # bounds with additional dimensions should be an error + da.curvefit( + coords=[da.t], + func=sine, + bounds={"a": (0, DataArray([1], coords={"foo": [1]}))}, + ) + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_ignore_errors(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + # nonsense function to make the optimization fail + def line(x, a, b): + if a > 10: + return 0 + return a * x + b + + da = DataArray( + [[1, 3, 5], [0, 20, 40]], + coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]}, + ) + + if use_dask: + da = da.chunk({"i": 1}) + + expected = DataArray( + [[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]} + ) + + with pytest.raises(RuntimeError, match="calls to function has reached maxfev"): + da.curvefit( + coords="x", + func=line, + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() # have to compute to raise the error + + fit = da.curvefit( + coords="x", + func=line, + errors="ignore", + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() + + assert_allclose(fit.curvefit_coefficients, expected) + class TestReduce: @pytest.fixture(autouse=True) @@ -4339,10 +4749,10 @@ def setup(self): np.array([0.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0]), 5, 2, None, id="float" ), pytest.param( - np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]), 5, 2, 1, id="nan" + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]), 5, 2, 1, id="nan" ), pytest.param( - np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]).astype("object"), + np.array([1.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0]).astype("object"), 5, 2, 1, @@ -4351,7 +4761,7 @@ def setup(self): ), id="obj", ), - pytest.param(np.array([np.NaN, np.NaN]), np.NaN, np.NaN, 0, id="allnan"), + pytest.param(np.array([np.nan, np.nan]), np.nan, np.nan, 0, id="allnan"), pytest.param( np.array( ["2015-12-31", "2020-01-02", "2020-01-01", "2016-01-01"], @@ -4527,8 +4937,10 @@ def test_idxmin( else: ar0 = ar0_raw - # dim doesn't exist - with pytest.raises(KeyError): + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): ar0.idxmin(dim="spam") # Scalar Dataarray @@ -4544,7 +4956,7 @@ def test_idxmin( if hasna: coordarr1[...] = 1 - fill_value_0 = np.NaN + fill_value_0 = np.nan else: fill_value_0 = 1 @@ -4558,7 +4970,7 @@ def test_idxmin( assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmin(fill_value=np.NaN) + result1 = ar0.idxmin(fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -4640,8 +5052,10 @@ def test_idxmax( else: ar0 = ar0_raw - # dim doesn't exist - with pytest.raises(KeyError): + with pytest.raises( + KeyError, + match=r"'spam' not found in array dimensions", + ): ar0.idxmax(dim="spam") # Scalar Dataarray @@ -4657,7 +5071,7 @@ def test_idxmax( if hasna: coordarr1[...] = 1 - fill_value_0 = np.NaN + fill_value_0 = np.nan else: fill_value_0 = 1 @@ -4671,7 +5085,7 @@ def test_idxmax( assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmax(fill_value=np.NaN) + result1 = ar0.idxmax(fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -4836,12 +5250,12 @@ def test_argmax_dim( np.array( [ [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], - [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], - [np.NaN] * 7, + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, ] ), - [5, 0, np.NaN], - [0, 2, np.NaN], + [5, 0, np.nan], + [0, 2, np.nan], [None, 1, 0], id="nan", ), @@ -4849,12 +5263,12 @@ def test_argmax_dim( np.array( [ [2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0], - [-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0], - [np.NaN] * 7, + [-4.0, np.nan, 2.0, np.nan, -2.0, -4.0, 2.0], + [np.nan] * 7, ] ).astype("object"), - [5, 0, np.NaN], - [0, 2, np.NaN], + [5, 0, np.nan], + [0, 2, np.nan], [None, 1, 0], marks=pytest.mark.filterwarnings( "ignore:invalid value encountered in reduce:RuntimeWarning:" @@ -5129,7 +5543,7 @@ def test_idxmin( coordarr1[hasna, :] = 1 minindex0 = [x if not np.isnan(x) else 0 for x in minindex] - nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] expected0list = [ (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex0) @@ -5144,7 +5558,7 @@ def test_idxmin( # Manually specify NaN fill_value with raise_if_dask_computes(max_computes=max_computes): - result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + result1 = ar0.idxmin(dim="x", fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -5271,7 +5685,7 @@ def test_idxmax( coordarr1[hasna, :] = 1 maxindex0 = [x if not np.isnan(x) else 0 for x in maxindex] - nan_mult_0 = np.array([np.NaN if x else 1 for x in hasna])[:, None] + nan_mult_0 = np.array([np.nan if x else 1 for x in hasna])[:, None] expected0list = [ (coordarr1 * nan_mult_0).isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex0) @@ -5286,7 +5700,7 @@ def test_idxmax( # Manually specify NaN fill_value with raise_if_dask_computes(max_computes=max_computes): - result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + result1 = ar0.idxmax(dim="x", fill_value=np.nan) assert_identical(result1, expected0) # keep_attrs @@ -5545,31 +5959,31 @@ def test_argmax_dim( np.array( [ [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], - [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], - [[np.NaN] * 4, [np.NaN] * 4], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], ] ), {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, {"x": np.array([1, 0]), "z": np.array([0, 1])}, - {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, {"x": np.array([0, 0]), "z": np.array([2, 2])}, - {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, { @@ -5588,31 +6002,31 @@ def test_argmax_dim( np.array( [ [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], - [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], - [[np.NaN] * 4, [np.NaN] * 4], + [[-4.0, np.nan, 2.0, np.nan], [-2.0, -4.0, 2.0, 0.0]], + [[np.nan] * 4, [np.nan] * 4], ] ).astype("object"), {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[1, 1, 0, 0], [0, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"z": np.array([[3, 1], [0, 1], [np.nan, np.nan]])}, {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, {"x": np.array([1, 0]), "z": np.array([0, 1])}, - {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"y": np.array([1, 0, np.nan]), "z": np.array([1, 0, np.nan])}, {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, { "y": np.array( - [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + [[0, 0, 0, 0], [1, 1, 0, 1], [np.nan, np.nan, np.nan, np.nan]] ) }, - {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"z": np.array([[0, 2], [2, 2], [np.nan, np.nan]])}, {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, {"x": np.array([0, 0]), "z": np.array([2, 2])}, - {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"y": np.array([0, 0, np.nan]), "z": np.array([0, 2, np.nan])}, {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, { @@ -6158,12 +6572,12 @@ def test_isin(da) -> None: def test_raise_no_warning_for_nan_in_binary_ops() -> None: with assert_no_warnings(): - xr.DataArray([1, 2, np.NaN]) > 0 + xr.DataArray([1, 2, np.nan]) > 0 @pytest.mark.filterwarnings("error") def test_no_warning_for_all_nan() -> None: - _ = xr.DataArray([np.NaN, np.NaN]).mean() + _ = xr.DataArray([np.nan, np.nan]).mean() def test_name_in_masking() -> None: @@ -6203,7 +6617,7 @@ def test_to_and_from_iris(self) -> None: ) # Set a bad value to test the masking logic - original.data[0, 2] = np.NaN + original.data[0, 2] = np.nan original.attrs["cell_methods"] = "height: mean (comment: A cell method)" actual = original.to_iris() @@ -6595,7 +7009,12 @@ def test_drop_duplicates_1d(self, keep) -> None: result = da.drop_duplicates("time", keep=keep) assert_equal(expected, result) - with pytest.raises(ValueError, match="['space'] not found"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): da.drop_duplicates("space", keep=keep) def test_drop_duplicates_2d(self) -> None: @@ -6716,3 +7135,35 @@ def test_error_on_ellipsis_without_list(self) -> None: da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) with pytest.raises(ValueError): da.stack(flat=...) # type: ignore + + +def test_nD_coord_dataarray() -> None: + # should succeed + da = DataArray( + np.ones((2, 4)), + dims=("x", "y"), + coords={ + "x": (("x", "y"), np.arange(8).reshape((2, 4))), + "y": ("y", np.arange(4)), + }, + ) + _assert_internal_invariants(da, check_default_indexes=True) + + da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))}) + da3 = DataArray(np.ones(4), dims=("z")) + + _, actual = xr.align(da, da2) + assert_identical(da2, actual) + + expected = da.drop_vars("x") + _, actual = xr.broadcast(da, da2) + assert_identical(expected, actual) + + actual, _ = xr.broadcast(da, da3) + expected = da.expand_dims(z=4, axis=-1) + assert_identical(actual, expected) + + da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"]) + _assert_internal_invariants(da4, check_default_indexes=True) + assert "x" not in da4.xindexes + assert "x" in da4.coords diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2e23d02a261..d2b8634b8b9 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -8,13 +8,19 @@ from copy import copy, deepcopy from io import StringIO from textwrap import dedent -from typing import Any +from typing import Any, Literal import numpy as np import pandas as pd import pytest from pandas.core.indexes.datetimes import DatetimeIndex +# remove once numpy 2.0 is the oldest supported version +try: + from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] +except ImportError: + from numpy import RankWarning + import xarray as xr from xarray import ( DataArray, @@ -31,11 +37,13 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import dtypes, indexing, utils from xarray.core.common import duck_array_ops, full_like -from xarray.core.coordinates import DatasetCoordinates +from xarray.core.coordinates import Coordinates, DatasetCoordinates from xarray.core.indexes import Index, PandasIndex -from xarray.core.pycompat import array_type, integer_types from xarray.core.utils import is_scalar +from xarray.namedarray.pycompat import array_type, integer_types +from xarray.testing import _assert_internal_invariants from xarray.tests import ( + DuckArrayWrapper, InaccessibleArray, UnexpectedDataAccess, assert_allclose, @@ -43,6 +51,7 @@ assert_equal, assert_identical, assert_no_warnings, + assert_writeable, create_test_data, has_cftime, has_dask, @@ -52,6 +61,7 @@ requires_cupy, requires_dask, requires_numexpr, + requires_pandas_version_two, requires_pint, requires_scipy, requires_sparse, @@ -87,11 +97,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: nt2 = 2 time1 = pd.date_range("2000-01-01", periods=nt1) time2 = pd.date_range("2000-02-01", periods=nt2) - string_var = np.array(["ae", "bc", "df"], dtype=object) + string_var = np.array(["a", "bc", "def"], dtype=object) string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2") string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2") - unicode_var = ["áó", "áó", "áó"] + unicode_var = np.array(["áó", "áó", "áó"]) datetime_var = np.array( ["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]" ) @@ -101,60 +111,51 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]: bool_var = np.array([True, False, True], dtype=bool) bool_var_to_append = np.array([False, True], dtype=bool) - ds = xr.Dataset( - data_vars={ - "da": xr.DataArray( - rs.rand(3, 3, nt1), - coords=[lat, lon, time1], - dims=["lat", "lon", "time"], - ), - "string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]), - "string_var_fixed_length": xr.DataArray( - string_var_fixed_length, coords=[time1], dims=["time"] - ), - "unicode_var": xr.DataArray( - unicode_var, coords=[time1], dims=["time"] - ).astype(np.unicode_), - "datetime_var": xr.DataArray(datetime_var, coords=[time1], dims=["time"]), - "bool_var": xr.DataArray(bool_var, coords=[time1], dims=["time"]), - } - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Converting non-nanosecond") + ds = xr.Dataset( + data_vars={ + "da": xr.DataArray( + rs.rand(3, 3, nt1), + coords=[lat, lon, time1], + dims=["lat", "lon", "time"], + ), + "string_var": ("time", string_var), + "string_var_fixed_length": ("time", string_var_fixed_length), + "unicode_var": ("time", unicode_var), + "datetime_var": ("time", datetime_var), + "bool_var": ("time", bool_var), + } + ) - ds_to_append = xr.Dataset( - data_vars={ - "da": xr.DataArray( - rs.rand(3, 3, nt2), - coords=[lat, lon, time2], - dims=["lat", "lon", "time"], - ), - "string_var": xr.DataArray( - string_var_to_append, coords=[time2], dims=["time"] - ), - "string_var_fixed_length": xr.DataArray( - string_var_fixed_length_to_append, coords=[time2], dims=["time"] - ), - "unicode_var": xr.DataArray( - unicode_var[:nt2], coords=[time2], dims=["time"] - ).astype(np.unicode_), - "datetime_var": xr.DataArray( - datetime_var_to_append, coords=[time2], dims=["time"] - ), - "bool_var": xr.DataArray(bool_var_to_append, coords=[time2], dims=["time"]), - } - ) + ds_to_append = xr.Dataset( + data_vars={ + "da": xr.DataArray( + rs.rand(3, 3, nt2), + coords=[lat, lon, time2], + dims=["lat", "lon", "time"], + ), + "string_var": ("time", string_var_to_append), + "string_var_fixed_length": ("time", string_var_fixed_length_to_append), + "unicode_var": ("time", unicode_var[:nt2]), + "datetime_var": ("time", datetime_var_to_append), + "bool_var": ("time", bool_var_to_append), + } + ) - ds_with_new_var = xr.Dataset( - data_vars={ - "new_var": xr.DataArray( - rs.rand(3, 3, nt1 + nt2), - coords=[lat, lon, time1.append(time2)], - dims=["lat", "lon", "time"], - ) - } - ) + ds_with_new_var = xr.Dataset( + data_vars={ + "new_var": xr.DataArray( + rs.rand(3, 3, nt1 + nt2), + coords=[lat, lon, time1.append(time2)], + dims=["lat", "lon", "time"], + ) + } + ) - assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) + assert_writeable(ds) + assert_writeable(ds_to_append) + assert_writeable(ds_with_new_var) return ds, ds_to_append, ds_with_new_var @@ -167,10 +168,8 @@ def make_datasets(data, data_to_append) -> tuple[Dataset, Dataset]: ds_to_append = xr.Dataset( {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]} ) - assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all( - objp.data.flags.writeable for objp in ds_to_append.variables.values() - ) + assert_writeable(ds) + assert_writeable(ds_to_append) return ds, ds_to_append u2_strings = ["ab", "cd", "ef"] @@ -191,7 +190,7 @@ def create_test_multiindex() -> Dataset: mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") ) - return Dataset({}, {"x": mindex}) + return Dataset({}, Coordinates.from_pandas_multiindex(mindex, "x")) def create_test_stacked_array() -> tuple[DataArray, DataArray]: @@ -203,6 +202,10 @@ def create_test_stacked_array() -> tuple[DataArray, DataArray]: class InaccessibleVariableDataStore(backends.InMemoryDataStore): + """ + Store that does not allow any data access. + """ + def __init__(self): super().__init__() self._indexvars = set() @@ -223,6 +226,47 @@ def lazy_inaccessible(k, v): return {k: lazy_inaccessible(k, v) for k, v in self._variables.items()} +class DuckBackendArrayWrapper(backends.common.BackendArray): + """Mimic a BackendArray wrapper around DuckArrayWrapper""" + + def __init__(self, array): + self.array = DuckArrayWrapper(array) + self.shape = array.shape + self.dtype = array.dtype + + def get_array(self): + return self.array + + def __getitem__(self, key): + return self.array[key.tuple] + + +class AccessibleAsDuckArrayDataStore(backends.InMemoryDataStore): + """ + Store that returns a duck array, not convertible to numpy array, + on read. Modeled after nVIDIA's kvikio. + """ + + def __init__(self): + super().__init__() + self._indexvars = set() + + def store(self, variables, *args, **kwargs) -> None: + super().store(variables, *args, **kwargs) + for k, v in variables.items(): + if isinstance(v, IndexVariable): + self._indexvars.add(k) + + def get_variables(self) -> dict[Any, xr.Variable]: + def lazy_accessible(k, v) -> xr.Variable: + if k in self._indexvars: + return v + data = indexing.LazilyIndexedArray(DuckBackendArrayWrapper(v.values)) + return Variable(v.dims, data, v.attrs) + + return {k: lazy_accessible(k, v) for k, v in self._variables.items()} + + class TestDataset: def test_repr(self) -> None: data = create_test_data(seed=123) @@ -230,18 +274,18 @@ def test_repr(self) -> None: # need to insert str dtype at runtime to handle different endianness expected = dedent( """\ - + Size: 2kB Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) Coordinates: - * dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 - * dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20 - numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3 + * dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 + * dim3 (dim3) %s 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20 + numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: - var1 (dim1, dim2) float64 -1.086 0.9973 0.283 ... 0.1995 0.4684 -0.8312 - var2 (dim1, dim2) float64 1.162 -1.097 -2.123 ... 0.1302 1.267 0.3328 - var3 (dim3, dim1) float64 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 + var1 (dim1, dim2) float64 576B -1.086 0.9973 0.283 ... 0.4684 -0.8312 + var2 (dim1, dim2) float64 576B 1.162 -1.097 -2.123 ... 1.267 0.3328 + var3 (dim3, dim1) float64 640B 0.5565 -0.2121 0.4563 ... -0.2452 -0.3616 Attributes: foo: bar""" % data["dim3"].dtype @@ -256,7 +300,7 @@ def test_repr(self) -> None: expected = dedent( """\ - + Size: 0B Dimensions: () Data variables: *empty*""" @@ -269,10 +313,10 @@ def test_repr(self) -> None: data = Dataset({"foo": ("x", np.ones(10))}).mean() expected = dedent( """\ - + Size: 8B Dimensions: () Data variables: - foo float64 1.0""" + foo float64 8B 1.0""" ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) print(actual) @@ -286,12 +330,12 @@ def test_repr_multiindex(self) -> None: data = create_test_multiindex() expected = dedent( """\ - + Size: 96B Dimensions: (x: 4) Coordinates: - * x (x) object MultiIndex - * level_1 (x) object 'a' 'a' 'b' 'b' - * level_2 (x) int64 1 2 1 2 + * x (x) object 32B MultiIndex + * level_1 (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2 Data variables: *empty*""" ) @@ -300,18 +344,19 @@ def test_repr_multiindex(self) -> None: assert expected == actual # verify that long level names are not truncated - mindex = pd.MultiIndex.from_product( + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("a_quite_long_level_name", "level_2") ) - data = Dataset({}, {"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset({}, midx_coords) expected = dedent( """\ - + Size: 96B Dimensions: (x: 4) Coordinates: - * x (x) object MultiIndex - * a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' - * level_2 (x) int64 1 2 1 2 + * x (x) object 32B MultiIndex + * a_quite_long_level_name (x) object 32B 'a' 'a' 'b' 'b' + * level_2 (x) int64 32B 1 2 1 2 Data variables: *empty*""" ) @@ -321,7 +366,7 @@ def test_repr_multiindex(self) -> None: def test_repr_period_index(self) -> None: data = create_test_data(seed=456) - data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="B") + data.coords["time"] = pd.period_range("2000-01-01", periods=20, freq="D") # check that creating the repr doesn't raise an error #GH645 repr(data) @@ -334,10 +379,10 @@ def test_unicode_data(self) -> None: byteorder = "<" if sys.byteorder == "little" else ">" expected = dedent( """\ - + Size: 12B Dimensions: (foø: 1) Coordinates: - * foø (foø) %cU3 %r + * foø (foø) %cU3 12B %r Data variables: *empty* Attributes: @@ -351,10 +396,14 @@ def test_repr_nep18(self) -> None: class Array: def __init__(self): self.shape = (2,) + self.ndim = 1 self.dtype = np.dtype(np.float64) def __array_function__(self, *args, **kwargs): - pass + return NotImplemented + + def __array_ufunc__(self, *args, **kwargs): + return NotImplemented def __repr__(self): return "Custom\nArray" @@ -362,11 +411,11 @@ def __repr__(self): dataset = Dataset({"foo": ("x", Array())}) expected = dedent( """\ - + Size: 16B Dimensions: (x: 2) Dimensions without coordinates: x Data variables: - foo (x) float64 Custom Array""" + foo (x) float64 16B Custom Array""" ) assert expected == repr(dataset) @@ -415,13 +464,16 @@ def test_constructor(self) -> None: with pytest.raises(ValueError, match=r"conflicting sizes"): Dataset({"a": x1, "b": x2}) - with pytest.raises(ValueError, match=r"disallows such variables"): - Dataset({"a": x1, "x": z}) with pytest.raises(TypeError, match=r"tuple of form"): Dataset({"x": (1, 2, 3, 4, 5, 6, 7)}) with pytest.raises(ValueError, match=r"already exists as a scalar"): Dataset({"x": 0, "y": ("x", [1, 2, 3])}) + # nD coordinate variable "x" sharing name with dimension + actual = Dataset({"a": x1, "x": z}) + assert "x" not in actual.xindexes + _assert_internal_invariants(actual, check_default_indexes=True) + # verify handling of DataArrays expected = Dataset({"x": x1, "z": z}) actual = Dataset({"z": expected["z"]}) @@ -443,6 +495,7 @@ def test_constructor_1d(self) -> None: actual = Dataset({"x": [5, 6, 7, 8, 9]}) assert_identical(expected, actual) + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_constructor_0d(self) -> None: expected = Dataset({"x": ([], 1)}) for arg in [1, np.array(1), expected["x"]]: @@ -574,9 +627,51 @@ def test_constructor_with_coords(self) -> None: [["a", "b"], [1, 2]], names=("level_1", "level_2") ) with pytest.raises(ValueError, match=r"conflicting MultiIndex"): - Dataset({}, {"x": mindex, "y": mindex}) - Dataset({}, {"x": mindex, "level_1": range(4)}) + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset({}, {"x": mindex, "y": mindex}) + Dataset({}, {"x": mindex, "level_1": range(4)}) + + def test_constructor_no_default_index(self) -> None: + # explicitly passing a Coordinates object skips the creation of default index + ds = Dataset(coords=Coordinates({"x": [1, 2, 3]}, indexes={})) + assert "x" in ds + assert "x" not in ds.xindexes + + def test_constructor_multiindex(self) -> None: + midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords=coords) + assert_identical(ds, coords.to_dataset()) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(data_vars={"x": midx}) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + Dataset(coords={"x": midx}) + + def test_constructor_custom_index(self) -> None: + class CustomIndex(Index): ... + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset(coords=coords) + assert isinstance(ds.xindexes["x"], CustomIndex) + # test coordinate variables copied + assert ds.variables["x"] is not coords.variables["x"] + + @pytest.mark.filterwarnings("ignore:return type") def test_properties(self) -> None: ds = create_test_data() @@ -584,10 +679,15 @@ def test_properties(self) -> None: # These exact types aren't public API, but this makes sure we don't # change them inadvertently: assert isinstance(ds.dims, utils.Frozen) + # TODO change after deprecation cycle in GH #8500 is complete assert isinstance(ds.dims.mapping, dict) - assert type(ds.dims.mapping) is dict - assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} - assert ds.sizes == ds.dims + assert type(ds.dims.mapping) is dict # noqa: E721 + with pytest.warns( + FutureWarning, + match=" To access a mapping from dimension names to lengths, please use `Dataset.sizes`", + ): + assert ds.dims == ds.sizes + assert ds.sizes == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} # dtypes assert isinstance(ds.dtypes, utils.Frozen) @@ -639,6 +739,27 @@ def test_properties(self) -> None: == 16 ) + def test_warn_ds_dims_deprecation(self) -> None: + # TODO remove after deprecation cycle in GH #8500 is complete + ds = create_test_data() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims["dim1"] + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.keys() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.values() + + with pytest.warns(FutureWarning, match="return type"): + ds.dims.items() + + with assert_no_warnings(): + len(ds.dims) + ds.dims.__iter__() + "dim1" in ds.dims + def test_asarray(self) -> None: ds = Dataset({"x": 0}) with pytest.raises(TypeError, match=r"cannot directly convert"): @@ -694,7 +815,7 @@ def test_modify_inplace(self) -> None: b = Dataset() b["x"] = ("x", vec, attributes) assert_identical(a["x"], b["x"]) - assert a.dims == b.dims + assert a.sizes == b.sizes # this should work a["x"] = ("x", vec[:5]) a["z"] = ("x", np.arange(5)) @@ -746,16 +867,16 @@ def test_coords_properties(self) -> None: expected = dedent( """\ Coordinates: - * x (x) int64 -1 -2 - * y (y) int64 0 1 2 - a (x) int64 4 5 - b int64 -10""" + * x (x) int64 16B -1 -2 + * y (y) int64 24B 0 1 2 + a (x) int64 16B 4 5 + b int64 8B -10""" ) actual = repr(coords) assert expected == actual # dims - assert coords.dims == {"x": 2, "y": 3} + assert coords.sizes == {"x": 2, "y": 3} # dtypes assert coords.dtypes == { @@ -784,7 +905,7 @@ def test_coords_modify(self) -> None: assert_array_equal(actual["z"], ["a", "b"]) actual = data.copy(deep=True) - with pytest.raises(ValueError, match=r"conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting dimension sizes"): actual.coords["x"] = ("x", [-1]) assert_identical(actual, data) # should not be modified @@ -821,9 +942,7 @@ def test_coords_setitem_with_new_dimension(self) -> None: def test_coords_setitem_multiindex(self) -> None: data = create_test_multiindex() - with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " - ): + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): data.coords["level_1"] = range(4) def test_coords_set(self) -> None: @@ -941,8 +1060,8 @@ def test_data_vars_properties(self) -> None: expected = dedent( """\ Data variables: - foo (x) float64 1.0 - bar float64 2.0""" + foo (x) float64 8B 1.0 + bar float64 8B 2.0""" ) actual = repr(ds.data_vars) assert expected == actual @@ -953,6 +1072,17 @@ def test_data_vars_properties(self) -> None: "bar": np.dtype("float64"), } + # len + ds.coords["x"] = [1] + assert len(ds.data_vars) == 2 + + # https://github.com/pydata/xarray/issues/7588 + with pytest.raises( + AssertionError, match="something is wrong with Dataset._coord_names" + ): + ds._coord_names = {"w", "x", "y", "z"} + len(ds.data_vars) + def test_equals_and_identical(self) -> None: data = create_test_data(seed=42) assert data.equals(data) @@ -1054,7 +1184,12 @@ def get_dask_names(ds): for k, v in new_dask_names.items(): assert v == orig_dask_names[k] - with pytest.raises(ValueError, match=r"some chunks"): + with pytest.raises( + ValueError, + match=re.escape( + "chunks keys ('foo',) not found in data dimensions ('dim2', 'dim3', 'time', 'dim1')" + ), + ): data.chunk({"foo": 10}) @requires_dask @@ -1091,9 +1226,9 @@ def test_isel(self) -> None: assert list(data.dims) == list(ret.dims) for d in data.dims: if d in slicers: - assert ret.dims[d] == np.arange(data.dims[d])[slicers[d]].size + assert ret.sizes[d] == np.arange(data.sizes[d])[slicers[d]].size else: - assert data.dims[d] == ret.dims[d] + assert data.sizes[d] == ret.sizes[d] # Verify that the data is what we expect for v in data.variables: assert data[v].dims == ret[v].dims @@ -1127,19 +1262,19 @@ def test_isel(self) -> None: assert_identical(data, data.isel(not_a_dim=slice(0, 2), missing_dims="ignore")) ret = data.isel(dim1=0) - assert {"time": 20, "dim2": 9, "dim3": 10} == ret.dims + assert {"time": 20, "dim2": 9, "dim3": 10} == ret.sizes assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) - assert {"time": 2, "dim2": 5, "dim3": 10} == ret.dims + assert {"time": 2, "dim2": 5, "dim3": 10} == ret.sizes assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) - assert {"dim2": 5, "dim3": 10} == ret.dims + assert {"dim2": 5, "dim3": 10} == ret.sizes assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) assert set(data.xindexes) == set(list(ret.xindexes) + ["time"]) @@ -1500,9 +1635,11 @@ def test_sel_dataarray(self) -> None: def test_sel_dataarray_mindex(self) -> None: midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + midx_coords["y"] = range(3) + mds = xr.Dataset( - {"var": (("x", "y"), np.random.rand(6, 3))}, - coords={"x": midx, "y": range(3)}, + {"var": (("x", "y"), np.random.rand(6, 3))}, coords=midx_coords ) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) @@ -1620,7 +1757,8 @@ def test_sel_drop(self) -> None: def test_sel_drop_mindex(self) -> None: midx = pd.MultiIndex.from_arrays([["a", "a"], [1, 2]], names=("foo", "bar")) - data = Dataset(coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + data = Dataset(coords=midx_coords) actual = data.sel(foo="a", drop=True) assert "foo" not in actual.coords @@ -1841,10 +1979,11 @@ def test_loc(self) -> None: data.loc["a"] # type: ignore[index] def test_selection_multiindex(self) -> None: - mindex = pd.MultiIndex.from_product( + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") ) - mdata = Dataset(data_vars={"var": ("x", range(8))}, coords={"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + mdata = Dataset(data_vars={"var": ("x", range(8))}, coords=midx_coords) def test_sel( lab_indexer, pos_indexer, replaced_idx=False, renamed_dim=None @@ -2204,9 +2343,9 @@ def test_align(self) -> None: assert np.isnan(left2["var3"][-2:]).all() with pytest.raises(ValueError, match=r"invalid value for join"): - align(left, right, join="foobar") # type: ignore[arg-type] + align(left, right, join="foobar") # type: ignore[call-overload] with pytest.raises(TypeError): - align(left, right, foo="bar") # type: ignore[call-arg] + align(left, right, foo="bar") # type: ignore[call-overload] def test_align_exact(self) -> None: left = xr.Dataset(coords={"x": [0, 1]}) @@ -2323,10 +2462,10 @@ def test_align_str_dtype(self) -> None: b = Dataset({"foo": ("x", [1, 2])}, coords={"x": ["b", "c"]}) expected_a = Dataset( - {"foo": ("x", [0, 1, np.NaN])}, coords={"x": ["a", "b", "c"]} + {"foo": ("x", [0, 1, np.nan])}, coords={"x": ["a", "b", "c"]} ) expected_b = Dataset( - {"foo": ("x", [np.NaN, 1, 2])}, coords={"x": ["a", "b", "c"]} + {"foo": ("x", [np.nan, 1, 2])}, coords={"x": ["a", "b", "c"]} ) actual_a, actual_b = xr.align(a, b, join="outer") @@ -2523,19 +2662,19 @@ def test_drop_variables(self) -> None: # deprecated approach with `drop` works (straight copy paste from above) - with pytest.warns(PendingDeprecationWarning): + with pytest.warns(DeprecationWarning): actual = data.drop("not_found_here", errors="ignore") assert_identical(data, actual) - with pytest.warns(PendingDeprecationWarning): + with pytest.warns(DeprecationWarning): actual = data.drop(["not_found_here"], errors="ignore") assert_identical(data, actual) - with pytest.warns(PendingDeprecationWarning): + with pytest.warns(DeprecationWarning): actual = data.drop(["time", "not_found_here"], errors="ignore") assert_identical(expected, actual) - with pytest.warns(PendingDeprecationWarning): + with pytest.warns(DeprecationWarning): actual = data.drop({"time", "not_found_here"}, errors="ignore") assert_identical(expected, actual) @@ -2569,8 +2708,7 @@ def test_drop_index_labels(self) -> None: assert_identical(data, actual) with pytest.raises(ValueError): - with pytest.warns(DeprecationWarning): - data.drop(["c"], dim="x", errors="wrong_value") # type: ignore[arg-type] + data.drop(["c"], dim="x", errors="wrong_value") # type: ignore[arg-type] with pytest.warns(DeprecationWarning): actual = data.drop(["a", "b", "c"], "x", errors="ignore") @@ -2608,9 +2746,9 @@ def test_drop_labels_by_keyword(self) -> None: ds5 = data.drop_sel(x=["a", "b"], y=range(0, 6, 2)) arr = DataArray(range(3), dims=["c"]) - with pytest.warns(FutureWarning): + with pytest.warns(DeprecationWarning): data.drop(arr.coords) - with pytest.warns(FutureWarning): + with pytest.warns(DeprecationWarning): data.drop(arr.xindexes) assert_array_equal(ds1.coords["x"], ["b"]) @@ -2676,7 +2814,10 @@ def test_drop_indexes(self) -> None: assert type(actual.x.variable) is Variable assert type(actual.y.variable) is Variable - with pytest.raises(ValueError, match="those coordinates don't exist"): + with pytest.raises( + ValueError, + match=r"The coordinates \('not_a_coord',\) are not found in the dataset coordinates", + ): ds.drop_indexes("not_a_coord") with pytest.raises(ValueError, match="those coordinates do not have an index"): @@ -2686,8 +2827,9 @@ def test_drop_indexes(self) -> None: assert_identical(actual, ds) # test index corrupted - mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) - ds = Dataset(coords={"x": mindex}) + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = Dataset(coords=midx_coords) with pytest.raises(ValueError, match=".*would corrupt the following index.*"): ds.drop_indexes("a") @@ -2806,10 +2948,11 @@ def test_copy_coords(self, deep, expected_orig) -> None: name="value", ).to_dataset() ds_cp = ds.copy(deep=deep) - ds_cp.coords["a"].data[0] = 999 + new_a = np.array([999, 2]) + ds_cp.coords["a"] = ds_cp.a.copy(data=new_a) expected_cp = xr.DataArray( - xr.IndexVariable("a", np.array([999, 2])), + xr.IndexVariable("a", new_a), coords={"a": [999, 2]}, dims=["a"], ) @@ -2827,6 +2970,21 @@ def test_copy_with_data_errors(self) -> None: with pytest.raises(ValueError, match=r"contain all variables in original"): orig.copy(data={"var1": new_var1}) + def test_drop_encoding(self) -> None: + orig = create_test_data() + vencoding = {"scale_factor": 10} + orig.encoding = {"foo": "bar"} + + for k, v in orig.variables.items(): + orig[k].encoding = vencoding + + actual = orig.drop_encoding() + assert actual.encoding == {} + for k, v in actual.variables.items(): + assert v.encoding == {} + + assert_equal(actual, orig) + def test_rename(self) -> None: data = create_test_data() newnames = { @@ -2885,8 +3043,7 @@ def test_rename_old_name(self) -> None: def test_rename_same_name(self) -> None: data = create_test_data() newnames = {"var1": "var1", "dim2": "dim2"} - with pytest.warns(UserWarning, match="does not create an index anymore"): - renamed = data.rename(newnames) + renamed = data.rename(newnames) assert_identical(renamed, data) def test_rename_dims(self) -> None: @@ -2956,10 +3113,23 @@ def test_rename_dimension_coord_warnings(self) -> None: ): ds.rename(x="y") + # No operation should not raise a warning + ds = Dataset( + data_vars={"data": (("x", "y"), np.ones((2, 3)))}, + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ds.rename(x="x") + def test_rename_multiindex(self) -> None: - mindex = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) - original = Dataset({}, {"x": mindex}) - expected = Dataset({}, {"x": mindex.rename(["a", "c"])}) + midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + original = Dataset({}, midx_coords) + + midx_renamed = midx.rename(["a", "c"]) + midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x") + expected = Dataset({}, midx_coords_renamed) actual = original.rename({"b": "c"}) assert_identical(expected, actual) @@ -2973,8 +3143,7 @@ def test_rename_multiindex(self) -> None: original.rename({"a": "x"}) with pytest.raises(ValueError, match=r"'b' conflicts"): - with pytest.warns(UserWarning, match="does not create an index anymore"): - original.rename({"a": "b"}) + original.rename({"a": "b"}) def test_rename_perserve_attrs_encoding(self) -> None: # test propagate attrs/encoding to new variable(s) created from Index object @@ -3069,9 +3238,14 @@ def test_swap_dims(self) -> None: assert_identical(expected, actual) # handle multiindex case - idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) - original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42}) - expected = Dataset({"z": 42}, {"x": ("y", [1, 2, 3]), "y": idx}) + midx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"]) + + original = Dataset({"x": [1, 2, 3], "y": ("x", midx), "z": 42}) + + midx_coords = Coordinates.from_pandas_multiindex(midx, "y") + midx_coords["x"] = ("y", [1, 2, 3]) + expected = Dataset({"z": 42}, midx_coords) + actual = original.swap_dims({"x": "y"}) assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) @@ -3258,6 +3432,13 @@ def test_expand_dims_kwargs_python36plus(self) -> None: ) assert_identical(other_way_expected, other_way) + @requires_pandas_version_two + def test_expand_dims_non_nanosecond_conversion(self) -> None: + # Regression test for https://github.com/pydata/xarray/issues/7493#issuecomment-1953091000 + with pytest.warns(UserWarning, match="non-nanosecond precision"): + ds = Dataset().expand_dims({"time": [np.datetime64("2018-01-01", "s")]}) + assert ds.time.dtype == np.dtype("datetime64[ns]") + def test_set_index(self) -> None: expected = create_test_multiindex() mindex = expected["x"].to_index() @@ -3289,6 +3470,12 @@ def test_set_index(self) -> None: with pytest.raises(ValueError, match=r"dimension mismatch.*"): ds.set_index(y="x_var") + ds = Dataset(coords={"x": 1}) + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + ds.set_index(x="x") + def test_set_index_deindexed_coords(self) -> None: # test de-indexed coordinates are converted to base variable # https://github.com/pydata/xarray/issues/6969 @@ -3297,16 +3484,20 @@ def test_set_index_deindexed_coords(self) -> None: three = ["c", "c", "d", "d"] four = [3, 4, 3, 4] - mindex_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"]) - mindex_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"]) + midx_12 = pd.MultiIndex.from_arrays([one, two], names=["one", "two"]) + midx_34 = pd.MultiIndex.from_arrays([three, four], names=["three", "four"]) - ds = xr.Dataset( - coords={"x": mindex_12, "three": ("x", three), "four": ("x", four)} - ) + coords = Coordinates.from_pandas_multiindex(midx_12, "x") + coords["three"] = ("x", three) + coords["four"] = ("x", four) + ds = xr.Dataset(coords=coords) actual = ds.set_index(x=["three", "four"]) - expected = xr.Dataset( - coords={"x": mindex_34, "one": ("x", one), "two": ("x", two)} - ) + + coords_expected = Coordinates.from_pandas_multiindex(midx_34, "x") + coords_expected["one"] = ("x", one) + coords_expected["two"] = ("x", two) + expected = xr.Dataset(coords=coords_expected) + assert_identical(actual, expected) def test_reset_index(self) -> None: @@ -3337,7 +3528,7 @@ def test_reset_index_drop_dims(self) -> None: assert len(reset.dims) == 0 @pytest.mark.parametrize( - "arg,drop,dropped,converted,renamed", + ["arg", "drop", "dropped", "converted", "renamed"], [ ("foo", False, [], [], {"bar": "x"}), ("foo", True, ["foo"], [], {"bar": "x"}), @@ -3350,14 +3541,20 @@ def test_reset_index_drop_dims(self) -> None: ], ) def test_reset_index_drop_convert( - self, arg, drop, dropped, converted, renamed + self, + arg: str | list[str], + drop: bool, + dropped: list[str], + converted: list[str], + renamed: dict[str, str], ) -> None: # regressions https://github.com/pydata/xarray/issues/6946 and # https://github.com/pydata/xarray/issues/6989 # check that multi-index dimension or level coordinates are dropped, converted # from IndexVariable to Variable or renamed to dimension as expected midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("foo", "bar")) - ds = xr.Dataset(coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + ds = xr.Dataset(coords=midx_coords) reset = ds.reset_index(arg, drop=drop) for name in dropped: @@ -3371,7 +3568,8 @@ def test_reorder_levels(self) -> None: ds = create_test_multiindex() mindex = ds["x"].to_index() midx = mindex.reorder_levels(["level_2", "level_1"]) - expected = Dataset({}, coords={"x": midx}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + expected = Dataset({}, coords=midx_coords) # check attrs propagated ds["level_1"].attrs["foo"] = "bar" @@ -3397,8 +3595,7 @@ def test_set_xindex(self) -> None: expected_mindex = ds.set_index(x=["foo", "bar"]) assert_identical(actual_mindex, expected_mindex) - class NotAnIndex: - ... + class NotAnIndex: ... with pytest.raises(TypeError, match=".*not a subclass of xarray.Index"): ds.set_xindex("foo", NotAnIndex) # type: ignore @@ -3436,10 +3633,12 @@ def test_stack(self) -> None: coords={"x": ("x", [0, 1]), "y": ["a", "b"]}, ) - exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) + midx_expected = pd.MultiIndex.from_product( + [[0, 1], ["a", "b"]], names=["x", "y"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") expected = Dataset( - data_vars={"b": ("z", [0, 1, 2, 3])}, - coords={"z": exp_index}, + data_vars={"b": ("z", [0, 1, 2, 3])}, coords=midx_coords_expected ) # check attrs propagated ds["x"].attrs["foo"] = "bar" @@ -3460,10 +3659,12 @@ def test_stack(self) -> None: actual = ds.stack(z=[..., "y"]) assert_identical(expected, actual) - exp_index = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["y", "x"]) + midx_expected = pd.MultiIndex.from_product( + [["a", "b"], [0, 1]], names=["y", "x"] + ) + midx_coords_expected = Coordinates.from_pandas_multiindex(midx_expected, "z") expected = Dataset( - data_vars={"b": ("z", [0, 2, 1, 3])}, - coords={"z": exp_index}, + data_vars={"b": ("z", [0, 2, 1, 3])}, coords=midx_coords_expected ) expected["x"].attrs["foo"] = "bar" @@ -3494,9 +3695,11 @@ def test_stack_create_index(self, create_index, expected_keys) -> None: def test_stack_multi_index(self) -> None: # multi-index on a dimension to stack is discarded too midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("lvl1", "lvl2")) + coords = Coordinates.from_pandas_multiindex(midx, "x") + coords["y"] = [0, 1] ds = xr.Dataset( data_vars={"b": (("x", "y"), [[0, 1], [2, 3], [4, 5], [6, 7]])}, - coords={"x": midx, "y": [0, 1]}, + coords=coords, ) expected = Dataset( data_vars={"b": ("z", [0, 1, 2, 3, 4, 5, 6, 7])}, @@ -3521,10 +3724,8 @@ def test_stack_non_dim_coords(self) -> None: ).rename_vars(x="xx") exp_index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["xx", "y"]) - expected = Dataset( - data_vars={"b": ("z", [0, 1, 2, 3])}, - coords={"z": exp_index}, - ) + exp_coords = Coordinates.from_pandas_multiindex(exp_index, "z") + expected = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=exp_coords) actual = ds.stack(z=["x", "y"]) assert_identical(expected, actual) @@ -3532,7 +3733,8 @@ def test_stack_non_dim_coords(self) -> None: def test_unstack(self) -> None: index = pd.MultiIndex.from_product([[0, 1], ["a", "b"]], names=["x", "y"]) - ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords={"z": index}) + coords = Coordinates.from_pandas_multiindex(index, "z") + ds = Dataset(data_vars={"b": ("z", [0, 1, 2, 3])}, coords=coords) expected = Dataset( {"b": (("x", "y"), [[0, 1], [2, 3]]), "x": [0, 1], "y": ["a", "b"]} ) @@ -3547,11 +3749,22 @@ def test_unstack(self) -> None: def test_unstack_errors(self) -> None: ds = Dataset({"x": [1, 2, 3]}) - with pytest.raises(ValueError, match=r"does not contain the dimensions"): + with pytest.raises( + ValueError, + match=re.escape("Dimensions ('foo',) not found in data dimensions ('x',)"), + ): ds.unstack("foo") with pytest.raises(ValueError, match=r".*do not have exactly one multi-index"): ds.unstack("x") + ds = Dataset({"da": [1, 2]}, coords={"y": ("x", [1, 1]), "z": ("x", [0, 0])}) + ds = ds.set_index(x=("y", "z")) + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + ds.unstack("x") + def test_unstack_fill_value(self) -> None: ds = xr.Dataset( {"var": (("x",), np.arange(6)), "other_var": (("x",), np.arange(3, 9))}, @@ -3594,12 +3807,12 @@ def test_unstack_sparse(self) -> None: assert actual2.variable._to_dense().equals(expected2.variable) assert actual2.data.density < 1.0 - mindex = pd.MultiIndex.from_arrays( - [np.arange(3), np.arange(3)], names=["a", "b"] - ) + midx = pd.MultiIndex.from_arrays([np.arange(3), np.arange(3)], names=["a", "b"]) + coords = Coordinates.from_pandas_multiindex(midx, "z") + coords["foo"] = np.arange(4) + coords["bar"] = np.arange(5) ds_eye = Dataset( - {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))}, - coords={"z": mindex, "foo": np.arange(4), "bar": np.arange(5)}, + {"var": (("z", "foo", "bar"), np.ones((3, 4, 5)))}, coords=coords ) actual3 = ds_eye.unstack(sparse=True, fill_value=0) assert isinstance(actual3["var"].data, sparse_array_type) @@ -3656,7 +3869,10 @@ def test_to_stacked_array_invalid_sample_dims(self) -> None: data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, coords={"y": ["u", "v", "w"]}, ) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=r"Variables in the dataset must contain all ``sample_dims`` \(\['y'\]\) but 'b' misses \['y'\]", + ): data.to_stacked_array("features", sample_dims=["y"]) def test_to_stacked_array_name(self) -> None: @@ -3845,14 +4061,15 @@ def test_virtual_variables_time(self) -> None: def test_virtual_variable_same_name(self) -> None: # regression test for GH367 - times = pd.date_range("2000-01-01", freq="H", periods=5) + times = pd.date_range("2000-01-01", freq="h", periods=5) data = Dataset({"time": times}) actual = data["time.time"] expected = DataArray(times.time, [("time", times)], name="time") assert_identical(actual, expected) def test_time_season(self) -> None: - ds = Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) + time = xr.date_range("2000-01-01", periods=12, freq="ME", use_cftime=False) + ds = Dataset({"t": time}) seas = ["DJF"] * 2 + ["MAM"] * 3 + ["JJA"] * 3 + ["SON"] * 3 + ["DJF"] assert_array_equal(seas, ds["t.season"]) @@ -3936,7 +4153,8 @@ def test_setitem(self) -> None: data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values data5 = data4.astype(str) data5["var4"] = data4["var1"] - err_msg = "could not convert string to float: 'a'" + # convert to `np.str_('a')` once `numpy<2.0` has been dropped + err_msg = "could not convert string to float: .*'a'.*" with pytest.raises(ValueError, match=err_msg): data5[{"dim2": 1}] = "a" @@ -4057,6 +4275,29 @@ def test_setitem_align_new_indexes(self) -> None: ) assert_identical(ds, expected) + def test_setitem_vectorized(self) -> None: + # Regression test for GH:7030 + # Positional indexing + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + b = xr.DataArray([[0, 0], [1, 0]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds[index] = xr.Dataset({"da": w}) + assert (ds[index]["da"] == w).all() + + # Indexing with coordinates + da = xr.DataArray(np.r_[:120].reshape(2, 3, 4, 5), dims=["a", "b", "c", "d"]) + ds = xr.Dataset({"da": da}) + ds.coords["b"] = [2, 4, 6] + b = xr.DataArray([[2, 2], [4, 2]], dims=["u", "v"]) + c = xr.DataArray([[0, 1], [2, 3]], dims=["u", "v"]) + w = xr.DataArray([-1, -2], dims=["u"]) + index = dict(b=b, c=c) + ds.loc[index] = xr.Dataset({"da": w}, coords={"b": ds.coords["b"]}) + assert (ds.loc[index]["da"] == w).all() + @pytest.mark.parametrize("dtype", [str, bytes]) def test_setitem_str_dtype(self, dtype) -> None: ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) @@ -4141,22 +4382,58 @@ def test_assign_attrs(self) -> None: def test_assign_multiindex_level(self) -> None: data = create_test_multiindex() - with pytest.raises( - ValueError, match=r"cannot set or update variable.*corrupt.*index " - ): + with pytest.raises(ValueError, match=r"cannot drop or update.*corrupt.*index "): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) + def test_assign_new_multiindex(self) -> None: + midx = pd.MultiIndex.from_arrays([["a", "a", "b", "b"], [0, 1, 0, 1]]) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + ds = Dataset(coords={"x": [1, 2]}) + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign(x=midx) + assert_identical(actual, expected) + + @pytest.mark.parametrize("orig_coords", [{}, {"x": range(4)}]) + def test_assign_coords_new_multiindex(self, orig_coords) -> None: + ds = Dataset(coords=orig_coords) + midx = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b"], [0, 1, 0, 1]], names=("one", "two") + ) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + + expected = Dataset(coords=midx_coords) + + with pytest.warns( + FutureWarning, + match=".*`pandas.MultiIndex`.*no longer be implicitly promoted.*", + ): + actual = ds.assign_coords({"x": midx}) + assert_identical(actual, expected) + + actual = ds.assign_coords(midx_coords) + assert_identical(actual, expected) + def test_assign_coords_existing_multiindex(self) -> None: data = create_test_multiindex() - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): - data.assign_coords(x=range(4)) - - with pytest.warns(FutureWarning, match=r"Updating MultiIndexed coordinate"): - data.assign(x=range(4)) + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign_coords(x=range(4)) + # https://github.com/pydata/xarray/issues/7097 (coord names updated) + assert len(updated.coords) == 1 + with pytest.warns( + FutureWarning, match=r"updating coordinate.*MultiIndex.*inconsistent" + ): + updated = data.assign(x=range(4)) # https://github.com/pydata/xarray/issues/7097 (coord names updated) - updated = data.assign_coords(x=range(4)) assert len(updated.coords) == 1 def test_assign_all_multiindex_coords(self) -> None: @@ -4183,6 +4460,25 @@ class CustomIndex(PandasIndex): actual = ds.assign_coords(y=[4, 5, 6]) assert isinstance(actual.xindexes["x"], CustomIndex) + def test_assign_coords_custom_index(self) -> None: + class CustomIndex(Index): + pass + + coords = Coordinates( + coords={"x": ("x", [1, 2, 3])}, indexes={"x": CustomIndex()} + ) + ds = Dataset() + actual = ds.assign_coords(coords) + assert isinstance(actual.xindexes["x"], CustomIndex) + + def test_assign_coords_no_default_index(self) -> None: + coords = Coordinates({"y": [1, 2, 3]}, indexes={}) + ds = Dataset() + actual = ds.assign_coords(coords) + expected = coords.to_dataset() + assert_identical(expected, actual, check_default_indexes=False) + assert "y" not in actual.xindexes + def test_merge_multiindex_level(self) -> None: data = create_test_multiindex() @@ -4298,7 +4594,7 @@ def test_squeeze_drop(self) -> None: selected = data.squeeze(drop=True) assert_identical(data, selected) - def test_to_array(self) -> None: + def test_to_dataarray(self) -> None: ds = Dataset( {"a": 1, "b": ("x", [1, 2, 3])}, coords={"c": 42}, @@ -4308,10 +4604,10 @@ def test_to_array(self) -> None: coords = {"c": 42, "variable": ["a", "b"]} dims = ("variable", "x") expected = DataArray(data, coords, dims, attrs=ds.attrs) - actual = ds.to_array() + actual = ds.to_dataarray() assert_identical(expected, actual) - actual = ds.to_array("abc", name="foo") + actual = ds.to_dataarray("abc", name="foo") expected = expected.rename({"variable": "abc"}).rename("foo") assert_identical(expected, actual) @@ -4426,6 +4722,17 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 + def test_from_dataframe_categorical_string_categories(self) -> None: + cat = pd.CategoricalIndex( + pd.Categorical.from_codes( + np.array([1, 1, 0, 2]), + categories=pd.Index(["foo", "bar", "baz"], dtype="string"), + ) + ) + ser = pd.Series(1, index=cat) + ds = ser.to_xarray() + assert ds.coords.dtypes["index"] == np.dtype("O") + @requires_sparse def test_from_dataframe_sparse(self) -> None: import sparse @@ -4518,7 +4825,7 @@ def test_convert_dataframe_with_many_types_and_multiindex(self) -> None: "e": [True, False, True], "f": pd.Categorical(list("abc")), "g": pd.date_range("20130101", periods=3), - "h": pd.date_range("20130101", periods=3, tz="US/Eastern"), + "h": pd.date_range("20130101", periods=3, tz="America/New_York"), } ) df.index = pd.MultiIndex.from_product([["a"], range(3)], names=["one", "two"]) @@ -4528,7 +4835,11 @@ def test_convert_dataframe_with_many_types_and_multiindex(self) -> None: expected = df.apply(np.asarray) assert roundtripped.equals(expected) - def test_to_and_from_dict(self) -> None: + @pytest.mark.parametrize("encoding", [True, False]) + @pytest.mark.parametrize("data", [True, "list", "array"]) + def test_to_and_from_dict( + self, encoding: bool, data: bool | Literal["list", "array"] + ) -> None: # # Dimensions: (t: 10) # Coordinates: @@ -4549,14 +4860,25 @@ def test_to_and_from_dict(self) -> None: "b": {"dims": ("t",), "data": y.tolist(), "attrs": {}}, }, } + if encoding: + ds.t.encoding.update({"foo": "bar"}) + expected["encoding"] = {} + expected["coords"]["t"]["encoding"] = ds.t.encoding + for vvs in ["a", "b"]: + expected["data_vars"][vvs]["encoding"] = {} - actual = ds.to_dict() + actual = ds.to_dict(data=data, encoding=encoding) # check that they are identical - assert expected == actual + np.testing.assert_equal(expected, actual) # check roundtrip - assert_identical(ds, Dataset.from_dict(actual)) + ds_rt = Dataset.from_dict(actual) + assert_identical(ds, ds_rt) + if encoding: + assert set(ds_rt.variables) == set(ds.variables) + for vv in ds.variables: + np.testing.assert_equal(ds_rt[vv].encoding, ds[vv].encoding) # check the data=False option expected_no_data = expected.copy() @@ -4567,14 +4889,18 @@ def test_to_and_from_dict(self) -> None: expected_no_data["coords"]["t"].update({"dtype": endiantype, "shape": (10,)}) expected_no_data["data_vars"]["a"].update({"dtype": "float64", "shape": (10,)}) expected_no_data["data_vars"]["b"].update({"dtype": "float64", "shape": (10,)}) - actual_no_data = ds.to_dict(data=False) + actual_no_data = ds.to_dict(data=False, encoding=encoding) assert expected_no_data == actual_no_data # verify coords are included roundtrip expected_ds = ds.set_coords("b") - actual2 = Dataset.from_dict(expected_ds.to_dict()) + actual2 = Dataset.from_dict(expected_ds.to_dict(data=data, encoding=encoding)) assert_identical(expected_ds, actual2) + if encoding: + assert set(expected_ds.variables) == set(actual2.variables) + for vv in ds.variables: + np.testing.assert_equal(expected_ds[vv].encoding, actual2[vv].encoding) # test some incomplete dicts: # this one has no attrs field, the dims are strings, and x, y are @@ -4622,7 +4948,10 @@ def test_to_and_from_dict_with_time_dim(self) -> None: roundtripped = Dataset.from_dict(ds.to_dict()) assert_identical(ds, roundtripped) - def test_to_and_from_dict_with_nan_nat(self) -> None: + @pytest.mark.parametrize("data", [True, "list", "array"]) + def test_to_and_from_dict_with_nan_nat( + self, data: bool | Literal["list", "array"] + ) -> None: x = np.random.randn(10, 3) y = np.random.randn(10, 3) y[2] = np.nan @@ -4638,7 +4967,7 @@ def test_to_and_from_dict_with_nan_nat(self) -> None: "lat": ("lat", lat), } ) - roundtripped = Dataset.from_dict(ds.to_dict()) + roundtripped = Dataset.from_dict(ds.to_dict(data=data)) assert_identical(ds, roundtripped) def test_to_dict_with_numpy_attrs(self) -> None: @@ -4667,7 +4996,7 @@ def test_pickle(self) -> None: roundtripped = pickle.loads(pickle.dumps(data)) assert_identical(data, roundtripped) # regression test for #167: - assert data.dims == roundtripped.dims + assert data.sizes == roundtripped.sizes def test_lazy_load(self) -> None: store = InaccessibleVariableDataStore() @@ -4684,6 +5013,29 @@ def test_lazy_load(self) -> None: ds.isel(time=10) ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + def test_lazy_load_duck_array(self) -> None: + store = AccessibleAsDuckArrayDataStore() + create_test_data().dump_to_store(store) + + for decode_cf in [True, False]: + ds = open_dataset(store, decode_cf=decode_cf) + with pytest.raises(UnexpectedDataAccess): + ds["var1"].values + + # these should not raise UnexpectedDataAccess: + ds.var1.data + ds.isel(time=10) + ds.isel(time=slice(10), dim1=[0]).isel(dim1=0, dim2=-1) + repr(ds) + + # preserve the duck array type and don't cast to array + assert isinstance(ds["var1"].load().data, DuckArrayWrapper) + assert isinstance( + ds["var1"].isel(dim2=0, dim1=0).load().data, DuckArrayWrapper + ) + + ds.close() + def test_dropna(self) -> None: x = np.random.randn(4, 4) x[::2, 0] = np.nan @@ -4737,12 +5089,15 @@ def test_dropna(self) -> None: expected = ds.isel(a=[1, 3]) assert_identical(actual, ds) - with pytest.raises(ValueError, match=r"a single dataset dimension"): + with pytest.raises( + ValueError, + match=r"'foo' not found in data dimensions \('a', 'b'\)", + ): ds.dropna("foo") with pytest.raises(ValueError, match=r"invalid how"): - ds.dropna("a", how="somehow") # type: ignore + ds.dropna("a", how="somehow") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"must specify how or thresh"): - ds.dropna("a", how=None) # type: ignore + ds.dropna("a", how=None) # type: ignore[arg-type] def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) @@ -5055,7 +5410,10 @@ def test_mean_uint_dtype(self) -> None: def test_reduce_bad_dim(self) -> None: data = create_test_data() - with pytest.raises(ValueError, match=r"Dataset does not contain"): + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): data.mean(dim="bad_dim") def test_reduce_cumsum(self) -> None: @@ -5081,7 +5439,10 @@ def test_reduce_cumsum(self) -> None: @pytest.mark.parametrize("func", ["cumsum", "cumprod"]) def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: data = create_test_data() - with pytest.raises(ValueError, match=r"Dataset does not contain"): + with pytest.raises( + ValueError, + match=r"Dimensions \('bad_dim',\) not found in data dimensions", + ): getattr(data, func)(dim="bad_dim") # ensure dimensions are correct @@ -5093,7 +5454,7 @@ def test_reduce_non_numeric(self) -> None: data2 = create_test_data(seed=44) add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} for v, dims in sorted(add_vars.items()): - size = tuple(data1.dims[d] for d in dims) + size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) @@ -5253,11 +5614,12 @@ def test_reduce_keepdims(self) -> None: ) assert_identical(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) - def test_quantile(self, q, skipna) -> None: + def test_quantile(self, q, skipna, compute_backend) -> None: ds = create_test_data(seed=123) - ds.var1.data[0, 0] = np.NaN + ds.var1.data[0, 0] = np.nan for dim in [None, "dim1", ["dim1"]]: ds_quantile = ds.quantile(q, dim=dim, skipna=skipna) @@ -5276,8 +5638,9 @@ def test_quantile(self, q, skipna) -> None: assert "dim3" in ds_quantile.dims assert all(d not in ds_quantile.dims for d in dim) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize("skipna", [True, False]) - def test_quantile_skipna(self, skipna) -> None: + def test_quantile_skipna(self, skipna, compute_backend) -> None: q = 0.1 dim = "time" ds = Dataset({"a": ([dim], np.arange(0, 11))}) @@ -5329,7 +5692,12 @@ def test_rank(self) -> None: assert list(z.coords) == list(ds.coords) assert list(x.coords) == list(y.coords) # invalid dim - with pytest.raises(ValueError, match=r"does not contain"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimension 'invalid_dim' not found in data dimensions ('dim3', 'dim1')" + ), + ): x.rank("invalid_dim") def test_rank_use_bottleneck(self) -> None: @@ -5495,6 +5863,7 @@ def test_dataset_math_auto_align(self) -> None: expected = ds + other.reindex_like(ds) assert_identical(expected, actual) + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_dataset_math_errors(self) -> None: ds = self.make_example_math_dataset() @@ -6083,6 +6452,13 @@ def test_ipython_key_completion(self) -> None: ds["var3"].coords[item] # should not raise assert sorted(actual) == sorted(expected) + coords = Coordinates(ds.coords) + actual = coords._ipython_key_completions_() + expected = ["time", "dim2", "dim3", "numbers"] + for item in actual: + coords[item] # should not raise + assert sorted(actual) == sorted(expected) + # data_vars actual = ds.data_vars._ipython_key_completions_() expected = ["var1", "var2", "var3", "dim1"] @@ -6103,13 +6479,21 @@ def test_polyfit_output(self) -> None: out = ds.polyfit("time", 2) assert len(out.data_vars) == 0 + def test_polyfit_weighted(self) -> None: + # Make sure weighted polyfit does not change the original object (issue #5644) + ds = create_test_data(seed=1) + ds_copy = ds.copy(deep=True) + + ds.polyfit("dim2", 2, w=np.arange(ds.sizes["dim2"])) + xr.testing.assert_identical(ds, ds_copy) + def test_polyfit_warnings(self) -> None: ds = create_test_data(seed=1) with warnings.catch_warnings(record=True) as ws: ds.var1.polyfit("dim2", 10, full=False) assert len(ws) == 1 - assert ws[0].category == np.RankWarning + assert ws[0].category == RankWarning ds.var1.polyfit("dim2", 10, full=True) assert len(ws) == 1 @@ -6121,7 +6505,7 @@ def test_pad(self) -> None: assert padded["var1"].shape == (8, 11) assert padded["var2"].shape == (8, 11) assert padded["var3"].shape == (10, 8) - assert dict(padded.dims) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} + assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) @@ -6338,6 +6722,23 @@ def test_query(self, backend, engine, parser) -> None: # pytest tests — new tests should go here, rather than in the class. +@pytest.mark.parametrize("parser", ["pandas", "python"]) +def test_eval(ds, parser) -> None: + """Currently much more minimal testing that `query` above, and much of the setup + isn't used. But the risks are fairly low — `query` shares much of the code, and + the method is currently experimental.""" + + actual = ds.eval("z1 + 5", parser=parser) + expect = ds["z1"] + 5 + assert_identical(expect, actual) + + # check pandas query syntax is supported + if parser == "pandas": + actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser) + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + @pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) def test_isin(test_elements, backend) -> None: expected = Dataset( @@ -6436,7 +6837,7 @@ def test_dir_unicode(ds) -> None: def test_raise_no_warning_for_nan_in_binary_ops() -> None: with assert_no_warnings(): - Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 + Dataset(data_vars={"x": ("y", [1, 2, np.nan])}) > 0 @pytest.mark.filterwarnings("error") @@ -6494,6 +6895,7 @@ def test_differentiate(dask, edge_order) -> None: da.differentiate("x2d") +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.parametrize("dask", [True, False]) def test_differentiate_datetime(dask) -> None: rs = np.random.RandomState(42) @@ -6547,7 +6949,7 @@ def test_differentiate_datetime(dask) -> None: @pytest.mark.parametrize("dask", [True, False]) def test_differentiate_cftime(dask) -> None: rs = np.random.RandomState(42) - coord = xr.cftime_range("2000", periods=8, freq="2M") + coord = xr.cftime_range("2000", periods=8, freq="2ME") da = xr.DataArray( rs.randn(8, 6), @@ -6690,6 +7092,7 @@ def test_cumulative_integrate(dask) -> None: da.cumulative_integrate("x2d") +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) def test_trapz_datetime(dask, which_datetime) -> None: @@ -6814,7 +7217,7 @@ def test_clip(ds) -> None: assert all((result.max(...) <= 0.75).values()) result = ds.clip(min=ds.mean("y"), max=ds.mean("y")) - assert result.dims == ds.dims + assert result.sizes == ds.sizes class TestDropDuplicates: @@ -6844,7 +7247,12 @@ def test_drop_duplicates_1d(self, keep) -> None: result = ds.drop_duplicates("time", keep=keep) assert_equal(expected, result) - with pytest.raises(ValueError, match="['space'] not found"): + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions ('space',) not found in data dimensions ('time',)" + ), + ): ds.drop_duplicates("space", keep=keep) diff --git a/xarray/tests/test_deprecation_helpers.py b/xarray/tests/test_deprecation_helpers.py index 35128829073..f21c8097060 100644 --- a/xarray/tests/test_deprecation_helpers.py +++ b/xarray/tests/test_deprecation_helpers.py @@ -15,15 +15,15 @@ def f1(a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = f1(1, 2, 3, 4) + result = f1(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) @_deprecate_positional_args("v0.1") @@ -31,7 +31,7 @@ def f2(a="a", *, b="b", c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f2(1, 2) + result = f2(1, 2) # type: ignore[misc] assert result == (1, 2, "c", "d") @_deprecate_positional_args("v0.1") @@ -39,11 +39,11 @@ def f3(a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2) + result = f3(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2, f="f") + result = f3(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) @_deprecate_positional_args("v0.1") @@ -57,7 +57,7 @@ def f4(a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f4(1, 2, f="f") + result = f4(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): @@ -80,15 +80,15 @@ def method(self, a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A1().method(1, 2, 3, 4) + result = A1().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A2: @@ -97,11 +97,11 @@ def method(self, a=1, b=1, *, c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A2().method(1, 2, 3) + result = A2().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A2().method(1, 2, 3, 4) + result = A2().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A3: @@ -110,11 +110,11 @@ def method(self, a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2) + result = A3().method(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2, f="f") + result = A3().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) class A4: @@ -129,7 +129,7 @@ def method(self, a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A4().method(1, 2, f="f") + result = A4().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index a29cccd0f50..d223bce2098 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -1,4 +1,5 @@ """ isort:skip_file """ + from __future__ import annotations import pickle @@ -6,7 +7,6 @@ import numpy as np import pytest -from packaging.version import Version if TYPE_CHECKING: import dask @@ -28,7 +28,7 @@ ) import xarray as xr -from xarray.backends.locks import HDF5_LOCK, CombinedLock +from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock from xarray.tests import ( assert_allclose, assert_identical, @@ -37,13 +37,11 @@ has_scipy, requires_cftime, requires_netCDF4, - requires_rasterio, requires_zarr, ) from xarray.tests.test_backends import ( ON_WINDOWS, create_tmp_file, - create_tmp_geotiff, ) from xarray.tests.test_dataset import create_test_data @@ -129,7 +127,8 @@ def test_dask_distributed_write_netcdf_with_dimensionless_variables( @requires_cftime @requires_netCDF4 -def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) @@ -138,9 +137,59 @@ def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path): da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): - for parallel in (False, True): - with xr.open_mfdataset(file_path, parallel=parallel) as tf: - assert_identical(tf["test"], da) + with xr.open_mfdataset(file_path, parallel=parallel) as tf: + assert_identical(tf["test"], da) + + +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + with cluster() as (s, [a, b]): + with Client(s["address"]): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) + + +# TODO: move this to test_backends.py +@requires_cftime +@requires_netCDF4 +@pytest.mark.parametrize("parallel", (True, False)) +def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + if parallel: + pytest.skip( + "Flaky in CI. Would be a welcome contribution to make a similar test reliable." + ) + lon = np.arange(100) + time = xr.cftime_range("20010101", periods=100, calendar="360_day") + data = np.random.random((time.size, lon.size)) + da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test") + + fnames = [] + for i in range(0, 100, 10): + fname = tmp_path / f"test_{i}.nc" + da.isel(time=slice(i, i + 10)).to_netcdf(fname) + fnames.append(fname) + + for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]: + with dask.config.set(scheduler=get): + with xr.open_mfdataset( + fnames, parallel=parallel, concat_dim="time", combine="nested" + ) as tf: + assert_identical(tf["test"], da) @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) @@ -196,22 +245,6 @@ def test_dask_distributed_zarr_integration_test( assert_allclose(original, computed) -@requires_rasterio -@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") -def test_dask_distributed_rasterio_integration_test(loop) -> None: - with create_tmp_geotiff() as (tmp_file, expected): - with cluster() as (s, [a, b]): - with pytest.warns(DeprecationWarning), Client(s["address"], loop=loop): - da_tiff = xr.open_rasterio(tmp_file, chunks={"band": 1}) - assert isinstance(da_tiff.data, da.Array) - actual = da_tiff.compute() - assert_allclose(actual, expected) - - -@pytest.mark.xfail( - condition=Version(distributed.__version__) < Version("2022.02.0"), - reason="https://github.com/dask/distributed/pull/5739", -) @gen_cluster(client=True) async def test_async(c, s, a, b) -> None: x = create_test_data() @@ -241,13 +274,9 @@ async def test_async(c, s, a, b) -> None: def test_hdf5_lock() -> None: - assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) + assert isinstance(HDF5_LOCK, SerializableLock) -@pytest.mark.xfail( - condition=Version(distributed.__version__) < Version("2022.02.0"), - reason="https://github.com/dask/distributed/pull/5739", -) @gen_cluster(client=True) async def test_serializable_locks(c, s, a, b) -> None: def f(x, lock=None): diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 1c942a1e6c8..3c2ee5e8f6f 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -10,25 +10,25 @@ "args, expected", [ ([bool], bool), - ([bool, np.string_], np.object_), + ([bool, np.bytes_], np.object_), ([np.float32, np.float64], np.float64), - ([np.float32, np.string_], np.object_), - ([np.unicode_, np.int64], np.object_), - ([np.unicode_, np.unicode_], np.unicode_), - ([np.bytes_, np.unicode_], np.object_), + ([np.float32, np.bytes_], np.object_), + ([np.str_, np.int64], np.object_), + ([np.str_, np.str_], np.str_), + ([np.bytes_, np.str_], np.object_), ], ) -def test_result_type(args, expected): +def test_result_type(args, expected) -> None: actual = dtypes.result_type(*args) assert actual == expected -def test_result_type_scalar(): +def test_result_type_scalar() -> None: actual = dtypes.result_type(np.arange(3, dtype=np.float32), np.nan) assert actual == np.float32 -def test_result_type_dask_array(): +def test_result_type_dask_array() -> None: # verify it works without evaluating dask arrays da = pytest.importorskip("dask.array") dask = pytest.importorskip("dask") @@ -50,7 +50,7 @@ def error(): @pytest.mark.parametrize("obj", [1.0, np.inf, "ab", 1.0 + 1.0j, True]) -def test_inf(obj): +def test_inf(obj) -> None: assert dtypes.INF > obj assert dtypes.NINF < obj @@ -85,7 +85,7 @@ def test_inf(obj): ("V", (np.dtype("O"), "nan")), # dtype('V') ], ) -def test_maybe_promote(kind, expected): +def test_maybe_promote(kind, expected) -> None: # 'g': np.float128 is not tested : not available on all platforms # 'G': np.complex256 is not tested : not available on all platforms @@ -94,7 +94,7 @@ def test_maybe_promote(kind, expected): assert str(actual[1]) == expected[1] -def test_nat_types_membership(): +def test_nat_types_membership() -> None: assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES assert np.float64 not in dtypes.NAT_TYPES diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 0d6efa2a8d3..df1ab1f40f9 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -27,7 +27,7 @@ timedelta_to_numeric, where, ) -from xarray.core.pycompat import array_type +from xarray.namedarray.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( arm_xfail, @@ -500,7 +500,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): expected = getattr(da.compute(), func)(skipna=skipna, dim=aggdim) assert_allclose(actual, expected, rtol=rtol) - # make sure the compatiblility with pandas' results. + # make sure the compatibility with pandas' results. if func in ["var", "std"]: expected = series_reduce(da, func, skipna=skipna, dim=aggdim, ddof=0) assert_allclose(actual, expected, rtol=rtol) @@ -577,17 +577,39 @@ def test_argmin_max_error(): @pytest.mark.parametrize( - "array", + ["array", "expected"], [ - np.array([np.datetime64("2000-01-01"), np.datetime64("NaT")]), - np.array([np.timedelta64(1, "h"), np.timedelta64("NaT")]), - np.array([0.0, np.nan]), - np.array([1j, np.nan]), - np.array(["foo", np.nan], dtype=object), + ( + np.array([np.datetime64("2000-01-01"), np.datetime64("NaT")]), + np.array([False, True]), + ), + ( + np.array([np.timedelta64(1, "h"), np.timedelta64("NaT")]), + np.array([False, True]), + ), + ( + np.array([0.0, np.nan]), + np.array([False, True]), + ), + ( + np.array([1j, np.nan]), + np.array([False, True]), + ), + ( + np.array(["foo", np.nan], dtype=object), + np.array([False, True]), + ), + ( + np.array([1, 2], dtype=int), + np.array([False, False]), + ), + ( + np.array([True, False], dtype=bool), + np.array([False, False]), + ), ], ) -def test_isnull(array): - expected = np.array([False, True]) +def test_isnull(array, expected): actual = duck_array_ops.isnull(array) np.testing.assert_equal(expected, actual) diff --git a/xarray/tests/test_error_messages.py b/xarray/tests/test_error_messages.py new file mode 100644 index 00000000000..b5840aafdfa --- /dev/null +++ b/xarray/tests/test_error_messages.py @@ -0,0 +1,17 @@ +""" +This new file is intended to test the quality & friendliness of error messages that are +raised by xarray. It's currently separate from the standard tests, which are more +focused on the functions working (though we could consider integrating them.). +""" + +import pytest + + +def test_no_var_in_dataset(ds): + with pytest.raises( + KeyError, + match=( + r"No variable named 'foo'. Variables on the dataset include \['z1', 'z2', 'x', 'time', 'c', 'y'\]" + ), + ): + ds["foo"] diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 3cba5b965f9..6923d26b79a 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -10,7 +10,9 @@ import xarray as xr from xarray.core import formatting -from xarray.tests import requires_dask, requires_netCDF4 +from xarray.tests import requires_cftime, requires_dask, requires_netCDF4 + +ON_WINDOWS = sys.platform == "win32" class TestFormatting: @@ -218,31 +220,70 @@ def test_attribute_repr(self) -> None: assert "\n" not in newlines assert "\t" not in tabs - def test_index_repr(self): + def test_index_repr(self) -> None: from xarray.core.indexes import Index class CustomIndex(Index): - def __init__(self, names): + names: tuple[str, ...] + + def __init__(self, names: tuple[str, ...]): self.names = names def __repr__(self): return f"CustomIndex(coords={self.names})" - coord_names = ["x", "y"] + coord_names = ("x", "y") index = CustomIndex(coord_names) - name = "x" + names = ("x",) - normal = formatting.summarize_index(name, index, col_width=20) - assert name in normal + normal = formatting.summarize_index(names, index, col_width=20) + assert names[0] in normal + assert len(normal.splitlines()) == len(names) assert "CustomIndex" in normal - CustomIndex._repr_inline_ = ( - lambda self, max_width: f"CustomIndex[{', '.join(self.names)}]" - ) - inline = formatting.summarize_index(name, index, col_width=20) - assert name in inline + class IndexWithInlineRepr(CustomIndex): + def _repr_inline_(self, max_width: int): + return f"CustomIndex[{', '.join(self.names)}]" + + index = IndexWithInlineRepr(coord_names) + inline = formatting.summarize_index(names, index, col_width=20) + assert names[0] in inline assert index._repr_inline_(max_width=40) in inline + @pytest.mark.parametrize( + "names", + ( + ("x",), + ("x", "y"), + ("x", "y", "z"), + ("x", "y", "z", "a"), + ), + ) + def test_index_repr_grouping(self, names) -> None: + from xarray.core.indexes import Index + + class CustomIndex(Index): + def __init__(self, names): + self.names = names + + def __repr__(self): + return f"CustomIndex(coords={self.names})" + + index = CustomIndex(names) + + normal = formatting.summarize_index(names, index, col_width=20) + assert all(name in normal for name in names) + assert len(normal.splitlines()) == len(names) + assert "CustomIndex" in normal + + hint_chars = [line[2] for line in normal.splitlines()] + + if len(names) <= 1: + assert hint_chars == [" "] + else: + assert hint_chars[0] == "┌" and hint_chars[-1] == "└" + assert len(names) == 2 or hint_chars[1:-1] == ["│"] * (len(names) - 2) + def test_diff_array_repr(self) -> None: da_a = xr.DataArray( np.array([[1, 2, 3], [4, 5, 6]], dtype="int64"), @@ -277,12 +318,12 @@ def test_diff_array_repr(self) -> None: R array([1, 2], dtype=int64) Differing coordinates: - L * x (x) %cU1 'a' 'b' - R * x (x) %cU1 'a' 'c' + L * x (x) %cU1 8B 'a' 'b' + R * x (x) %cU1 8B 'a' 'c' Coordinates only on the left object: - * y (y) int64 1 2 3 + * y (y) int64 24B 1 2 3 Coordinates only on the right object: - label (x) int64 1 2 + label (x) int64 16B 1 2 Differing attributes: L units: m R units: kg @@ -367,19 +408,27 @@ def test_diff_dataset_repr(self) -> None: "var2": ("x", np.array([3, 4], dtype="int64")), }, coords={ - "x": np.array(["a", "b"], dtype="U1"), + "x": ( + "x", + np.array(["a", "b"], dtype="U1"), + {"foo": "bar", "same": "same"}, + ), "y": np.array([1, 2, 3], dtype="int64"), }, - attrs={"units": "m", "description": "desc"}, + attrs={"title": "mytitle", "description": "desc"}, ) ds_b = xr.Dataset( data_vars={"var1": ("x", np.array([1, 2], dtype="int64"))}, coords={ - "x": ("x", np.array(["a", "c"], dtype="U1"), {"source": 0}), + "x": ( + "x", + np.array(["a", "c"], dtype="U1"), + {"source": 0, "foo": "baz", "same": "same"}, + ), "label": ("x", np.array([1, 2], dtype="int64")), }, - attrs={"units": "kg"}, + attrs={"title": "newtitle"}, ) byteorder = "<" if sys.byteorder == "little" else ">" @@ -389,21 +438,25 @@ def test_diff_dataset_repr(self) -> None: Differing dimensions: (x: 2, y: 3) != (x: 2) Differing coordinates: - L * x (x) %cU1 'a' 'b' - R * x (x) %cU1 'a' 'c' - source: 0 + L * x (x) %cU1 8B 'a' 'b' + Differing variable attributes: + foo: bar + R * x (x) %cU1 8B 'a' 'c' + Differing variable attributes: + source: 0 + foo: baz Coordinates only on the left object: - * y (y) int64 1 2 3 + * y (y) int64 24B 1 2 3 Coordinates only on the right object: - label (x) int64 1 2 + label (x) int64 16B 1 2 Differing data variables: - L var1 (x, y) int64 1 2 3 4 5 6 - R var1 (x) int64 1 2 + L var1 (x, y) int64 48B 1 2 3 4 5 6 + R var1 (x) int64 16B 1 2 Data variables only on the left object: - var2 (x) int64 3 4 + var2 (x) int64 16B 3 4 Differing attributes: - L units: m - R units: kg + L title: mytitle + R title: newtitle Attributes only on the left object: description: desc""" % (byteorder, byteorder) @@ -413,16 +466,22 @@ def test_diff_dataset_repr(self) -> None: assert actual == expected def test_array_repr(self) -> None: - ds = xr.Dataset(coords={"foo": [1, 2, 3], "bar": [1, 2, 3]}) - ds[(1, 2)] = xr.DataArray([0], dims="test") + ds = xr.Dataset( + coords={ + "foo": np.array([1, 2, 3], dtype=np.uint64), + "bar": np.array([1, 2, 3], dtype=np.uint64), + } + ) + ds[(1, 2)] = xr.DataArray(np.array([0], dtype=np.uint64), dims="test") ds_12 = ds[(1, 2)] # Test repr function behaves correctly: actual = formatting.array_repr(ds_12) + expected = dedent( """\ - - array([0]) + Size: 8B + array([0], dtype=uint64) Dimensions without coordinates: test""" ) @@ -440,7 +499,7 @@ def test_array_repr(self) -> None: actual = formatting.array_repr(ds[(1, 2)]) expected = dedent( """\ - + Size: 8B 0 Dimensions without coordinates: test""" ) @@ -458,7 +517,7 @@ def test_array_repr_variable(self) -> None: def test_array_repr_recursive(self) -> None: # GH:issue:7111 - # direct recurion + # direct recursion var = xr.Variable("x", [0, 1]) var.attrs["x"] = var formatting.array_repr(var) @@ -510,7 +569,7 @@ def _repr_inline_(self, width): return formatted - def __array_function__(self, *args, **kwargs): + def __array_namespace__(self, *args, **kwargs): return NotImplemented @property @@ -542,7 +601,7 @@ def test_set_numpy_options() -> None: assert np.get_printoptions() == original_options -def test_short_numpy_repr() -> None: +def test_short_array_repr() -> None: cases = [ np.random.randn(500), np.random.randn(20, 20), @@ -552,16 +611,16 @@ def test_short_numpy_repr() -> None: ] # number of lines: # for default numpy repr: 167, 140, 254, 248, 599 - # for short_numpy_repr: 1, 7, 24, 19, 25 + # for short_array_repr: 1, 7, 24, 19, 25 for array in cases: - num_lines = formatting.short_numpy_repr(array).count("\n") + 1 + num_lines = formatting.short_array_repr(array).count("\n") + 1 assert num_lines < 30 # threshold option (default: 200) array2 = np.arange(100) - assert "..." not in formatting.short_numpy_repr(array2) + assert "..." not in formatting.short_array_repr(array2) with xr.set_options(display_values_threshold=10): - assert "..." in formatting.short_numpy_repr(array2) + assert "..." in formatting.short_array_repr(array2) def test_large_array_repr_length() -> None: @@ -576,13 +635,14 @@ def test_repr_file_collapsed(tmp_path) -> None: arr_to_store = xr.DataArray(np.arange(300, dtype=np.int64), dims="test") arr_to_store.to_netcdf(tmp_path / "test.nc", engine="netcdf4") - with xr.open_dataarray(tmp_path / "test.nc") as arr, xr.set_options( - display_expand_data=False + with ( + xr.open_dataarray(tmp_path / "test.nc") as arr, + xr.set_options(display_expand_data=False), ): actual = repr(arr) expected = dedent( """\ - + Size: 2kB [300 values with dtype=int64] Dimensions without coordinates: test""" ) @@ -593,7 +653,7 @@ def test_repr_file_collapsed(tmp_path) -> None: actual = arr_loaded.__repr__() expected = dedent( """\ - + Size: 2kB 0 1 2 3 4 5 6 7 8 9 10 11 12 ... 288 289 290 291 292 293 294 295 296 297 298 299 Dimensions without coordinates: test""" ) @@ -611,15 +671,16 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: b = defchararray.add("attr_", np.arange(0, n_attr).astype(str)) c = defchararray.add("coord", np.arange(0, n_vars).astype(str)) attrs = {k: 2 for k in b} - coords = {_c: np.array([0, 1]) for _c in c} + coords = {_c: np.array([0, 1], dtype=np.uint64) for _c in c} data_vars = dict() for v, _c in zip(a, coords.items()): data_vars[v] = xr.DataArray( name=v, - data=np.array([3, 4]), + data=np.array([3, 4], dtype=np.uint64), dims=[_c[0]], coords=dict([_c]), ) + ds = xr.Dataset(data_vars) ds.attrs = attrs @@ -656,8 +717,9 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: dims_values = formatting.dim_summary_limited( ds, col_width=col_width + 1, max_rows=display_max_rows ) + expected_size = "1kB" expected = f"""\ - + Size: {expected_size} {dims_start}({dims_values}) Coordinates: ({n_vars}) Data variables: ({n_vars}) @@ -722,3 +784,259 @@ def __array__(self, dtype=None): # These will crash if var.data are converted to numpy arrays: var.__repr__() var._repr_html_() + + +@pytest.mark.parametrize("as_dataset", (False, True)) +def test_format_xindexes_none(as_dataset: bool) -> None: + # ensure repr for empty xindexes can be displayed #8367 + + expected = """\ + Indexes: + *empty*""" + expected = dedent(expected) + + obj: xr.DataArray | xr.Dataset = xr.DataArray() + obj = obj._to_temp_dataset() if as_dataset else obj + + actual = repr(obj.xindexes) + assert actual == expected + + +@pytest.mark.parametrize("as_dataset", (False, True)) +def test_format_xindexes(as_dataset: bool) -> None: + expected = """\ + Indexes: + x PandasIndex""" + expected = dedent(expected) + + obj: xr.DataArray | xr.Dataset = xr.DataArray([1], coords={"x": [1]}) + obj = obj._to_temp_dataset() if as_dataset else obj + + actual = repr(obj.xindexes) + assert actual == expected + + +@requires_cftime +def test_empty_cftimeindex_repr() -> None: + index = xr.coding.cftimeindex.CFTimeIndex([]) + + expected = """\ + Indexes: + time CFTimeIndex([], dtype='object', length=0, calendar=None, freq=None)""" + expected = dedent(expected) + + da = xr.DataArray([], coords={"time": index}) + + actual = repr(da.indexes) + assert actual == expected + + +def test_display_nbytes() -> None: + xds = xr.Dataset( + { + "foo": np.arange(1200, dtype=np.int16), + "bar": np.arange(111, dtype=np.int16), + } + ) + + # Note: int16 is used to ensure that dtype is shown in the + # numpy array representation for all OSes included Windows + + actual = repr(xds) + expected = """ + Size: 3kB +Dimensions: (foo: 1200, bar: 111) +Coordinates: + * foo (foo) int16 2kB 0 1 2 3 4 5 6 ... 1194 1195 1196 1197 1198 1199 + * bar (bar) int16 222B 0 1 2 3 4 5 6 7 ... 104 105 106 107 108 109 110 +Data variables: + *empty* + """.strip() + assert actual == expected + + actual = repr(xds["foo"]) + expected = """ + Size: 2kB +array([ 0, 1, 2, ..., 1197, 1198, 1199], dtype=int16) +Coordinates: + * foo (foo) int16 2kB 0 1 2 3 4 5 6 ... 1194 1195 1196 1197 1198 1199 +""".strip() + assert actual == expected + + +def test_array_repr_dtypes(): + + # These dtypes are expected to be represented similarly + # on Ubuntu, macOS and Windows environments of the CI. + # Unsigned integer could be used as easy replacements + # for tests where the data-type does not matter, + # but the repr does, including the size + # (size of a int == size of an uint) + + # Signed integer dtypes + + ds = xr.DataArray(np.array([0], dtype="int8"), dims="x") + actual = repr(ds) + expected = """ + Size: 1B +array([0], dtype=int8) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0], dtype=int16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + # Unsigned integer dtypes + + ds = xr.DataArray(np.array([0], dtype="uint8"), dims="x") + actual = repr(ds) + expected = """ + Size: 1B +array([0], dtype=uint8) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0], dtype=uint16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0], dtype=uint32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="uint64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0], dtype=uint64) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + # Float dtypes + + ds = xr.DataArray(np.array([0.0]), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0.]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float16"), dims="x") + actual = repr(ds) + expected = """ + Size: 2B +array([0.], dtype=float16) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0.], dtype=float32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="float64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0.]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + +@pytest.mark.skipif( + ON_WINDOWS, + reason="Default numpy's dtypes vary according to OS", +) +def test_array_repr_dtypes_unix() -> None: + + # Signed integer dtypes + + ds = xr.DataArray(np.array([0]), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0], dtype=int32) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + +@pytest.mark.skipif( + not ON_WINDOWS, + reason="Default numpy's dtypes vary according to OS", +) +def test_array_repr_dtypes_on_windows() -> None: + + # Integer dtypes + + ds = xr.DataArray(np.array([0]), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int32"), dims="x") + actual = repr(ds) + expected = """ + Size: 4B +array([0]) +Dimensions without coordinates: x + """.strip() + assert actual == expected + + ds = xr.DataArray(np.array([0], dtype="int64"), dims="x") + actual = repr(ds) + expected = """ + Size: 8B +array([0], dtype=int64) +Dimensions without coordinates: x + """.strip() + assert actual == expected diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 7ea5c19019b..6540406e914 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -6,29 +6,31 @@ import xarray as xr from xarray.core import formatting_html as fh +from xarray.core.coordinates import Coordinates @pytest.fixture -def dataarray(): +def dataarray() -> xr.DataArray: return xr.DataArray(np.random.RandomState(0).randn(4, 6)) @pytest.fixture -def dask_dataarray(dataarray): +def dask_dataarray(dataarray: xr.DataArray) -> xr.DataArray: pytest.importorskip("dask") return dataarray.chunk() @pytest.fixture -def multiindex(): - mindex = pd.MultiIndex.from_product( +def multiindex() -> xr.Dataset: + midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") ) - return xr.Dataset({}, {"x": mindex}) + midx_coords = Coordinates.from_pandas_multiindex(midx, "x") + return xr.Dataset({}, midx_coords) @pytest.fixture -def dataset(): +def dataset() -> xr.Dataset: times = pd.date_range("2000-01-01", "2001-12-31", name="time") annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) @@ -46,17 +48,17 @@ def dataset(): ) -def test_short_data_repr_html(dataarray) -> None: +def test_short_data_repr_html(dataarray: xr.DataArray) -> None: data_repr = fh.short_data_repr_html(dataarray) assert data_repr.startswith("
    array")
     
     
    -def test_short_data_repr_html_non_str_keys(dataset) -> None:
    +def test_short_data_repr_html_non_str_keys(dataset: xr.Dataset) -> None:
         ds = dataset.assign({2: lambda x: x["tmin"]})
         fh.dataset_repr(ds)
     
     
    -def test_short_data_repr_html_dask(dask_dataarray) -> None:
    +def test_short_data_repr_html_dask(dask_dataarray: xr.DataArray) -> None:
         assert hasattr(dask_dataarray.data, "_repr_html_")
         data_repr = fh.short_data_repr_html(dask_dataarray)
         assert data_repr == dask_dataarray.data._repr_html_()
    @@ -97,7 +99,7 @@ def test_summarize_attrs_with_unsafe_attr_name_and_value() -> None:
         assert "
    <pd.DataFrame>
    " in formatted -def test_repr_of_dataarray(dataarray) -> None: +def test_repr_of_dataarray(dataarray: xr.DataArray) -> None: formatted = fh.array_repr(dataarray) assert "dim_0" in formatted # has an expanded data section @@ -119,12 +121,12 @@ def test_repr_of_dataarray(dataarray) -> None: ) -def test_repr_of_multiindex(multiindex) -> None: +def test_repr_of_multiindex(multiindex: xr.Dataset) -> None: formatted = fh.dataset_repr(multiindex) assert "(x)" in formatted -def test_repr_of_dataset(dataset) -> None: +def test_repr_of_dataset(dataset: xr.Dataset) -> None: formatted = fh.dataset_repr(dataset) # coords, attrs, and data_vars are expanded assert ( @@ -151,7 +153,7 @@ def test_repr_of_dataset(dataset) -> None: assert "<IA>" in formatted -def test_repr_text_fallback(dataset) -> None: +def test_repr_text_fallback(dataset: xr.Dataset) -> None: formatted = fh.dataset_repr(dataset) # Just test that the "pre" block used for fallback to plain text is present. @@ -170,7 +172,7 @@ def test_variable_repr_html() -> None: assert "xarray.Variable" in html -def test_repr_of_nonstr_dataset(dataset) -> None: +def test_repr_of_nonstr_dataset(dataset: xr.Dataset) -> None: ds = dataset.copy() ds.attrs[1] = "Test value" ds[2] = ds["tmin"] @@ -179,7 +181,7 @@ def test_repr_of_nonstr_dataset(dataset) -> None: assert "
    2" in formatted -def test_repr_of_nonstr_dataarray(dataarray) -> None: +def test_repr_of_nonstr_dataarray(dataarray: xr.DataArray) -> None: da = dataarray.rename(dim_0=15) da.attrs[1] = "value" formatted = fh.array_repr(da) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ccbead9dbc4..d927550e424 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1,22 +1,27 @@ from __future__ import annotations import datetime +import operator import warnings +from unittest import mock import numpy as np import pandas as pd import pytest +from packaging.version import Version import xarray as xr from xarray import DataArray, Dataset, Variable from xarray.core.groupby import _consolidate_slices from xarray.tests import ( + InaccessibleArray, assert_allclose, assert_array_equal, assert_equal, assert_identical, create_test_data, has_cftime, + has_flox, has_pandas_version_two, requires_dask, requires_flox, @@ -51,28 +56,59 @@ def test_consolidate_slices() -> None: slices = [slice(2, 3), slice(5, 6)] assert _consolidate_slices(slices) == slices + # ignore type because we're checking for an error anyway with pytest.raises(ValueError): - _consolidate_slices([slice(3), 4]) + _consolidate_slices([slice(3), 4]) # type: ignore[list-item] -def test_groupby_dims_property(dataset) -> None: - assert dataset.groupby("x").dims == dataset.isel(x=1).dims - assert dataset.groupby("y").dims == dataset.isel(y=1).dims +@pytest.mark.filterwarnings("ignore:return type") +def test_groupby_dims_property(dataset, recwarn) -> None: + # dims is sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("x").dims == dataset.isel(x=1).dims + assert dataset.groupby("y").dims == dataset.isel(y=1).dims + # in pytest-8, pytest.warns() no longer clears all warnings + recwarn.clear() + + # when squeeze=False, no warning should be raised + assert tuple(dataset.groupby("x", squeeze=False).dims) == tuple( + dataset.isel(x=slice(1, 2)).dims + ) + assert tuple(dataset.groupby("y", squeeze=False).dims) == tuple( + dataset.isel(y=slice(1, 2)).dims + ) + assert len(recwarn) == 0 + + stacked = dataset.stack({"xy": ("x", "y")}) + assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple( + stacked.isel(xy=[0]).dims + ) + assert len(recwarn) == 0 + + +def test_groupby_sizes_property(dataset) -> None: + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes stacked = dataset.stack({"xy": ("x", "y")}) - assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes def test_multi_index_groupby_map(dataset) -> None: # regression test for GH873 ds = dataset.isel(z=1, drop=True)[["foo"]] expected = 2 * ds - actual = ( - ds.stack(space=["x", "y"]) - .groupby("space") - .map(lambda x: 2 * x) - .unstack("space") - ) + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = ( + ds.stack(space=["x", "y"]) + .groupby("space") + .map(lambda x: 2 * x) + .unstack("space") + ) assert_equal(expected, actual) @@ -134,6 +170,18 @@ def test_groupby_input_mutation() -> None: assert_identical(array, array_copy) # should not modify inputs +@pytest.mark.parametrize("use_flox", [True, False]) +def test_groupby_indexvariable(use_flox: bool) -> None: + # regression test for GH7919 + array = xr.DataArray([1, 2, 3], [("x", [2, 2, 1])]) + iv = xr.IndexVariable(dims="x", data=pd.Index(array.x.values)) + with xr.set_options(use_flox=use_flox): + actual = array.groupby(iv).sum() + actual = array.groupby(iv).sum() + expected = xr.DataArray([3, 3], [("x", [1, 2])]) + assert_identical(expected, actual) + + @pytest.mark.parametrize( "obj", [ @@ -173,7 +221,8 @@ def func(arg1, arg2, arg3=0): array = xr.DataArray([1, 1, 1], [("x", [1, 2, 3])]) expected = xr.DataArray([3, 3, 3], [("x", [1, 2, 3])]) - actual = array.groupby("x").map(func, args=(1,), arg3=1) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = array.groupby("x").map(func, args=(1,), arg3=1) assert_identical(expected, actual) @@ -183,7 +232,9 @@ def func(arg1, arg2, arg3=0): dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) expected = xr.Dataset({"foo": ("x", [3, 3, 3])}, {"x": [1, 2, 3]}) - actual = dataset.groupby("x").map(func, args=(1,), arg3=1) + # The function in `map` may be sensitive to squeeze, always warn + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual = dataset.groupby("x").map(func, args=(1,), arg3=1) assert_identical(expected, actual) @@ -216,11 +267,11 @@ def test_da_groupby_quantile() -> None: assert_identical(expected, actual) array = xr.DataArray( - data=[np.NaN, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + data=[np.nan, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" ) for skipna in (True, False, None): - e = [np.NaN, 5] if skipna is False else [2.5, 5] + e = [np.nan, 5] if skipna is False else [2.5, 5] expected = xr.DataArray(data=e, coords={"x": [1, 2], "quantile": 0.5}, dims="x") actual = array.groupby("x").quantile(0.5, skipna=skipna) @@ -330,12 +381,12 @@ def test_ds_groupby_quantile() -> None: assert_identical(expected, actual) ds = xr.Dataset( - data_vars={"a": ("x", [np.NaN, 2, 3, 4, 5, 6])}, + data_vars={"a": ("x", [np.nan, 2, 3, 4, 5, 6])}, coords={"x": [1, 1, 1, 2, 2, 2]}, ) for skipna in (True, False, None): - e = [np.NaN, 5] if skipna is False else [2.5, 5] + e = [np.nan, 5] if skipna is False else [2.5, 5] expected = xr.Dataset( data_vars={"a": ("x", e)}, coords={"quantile": 0.5, "x": [1, 2]} @@ -452,8 +503,10 @@ def test_da_groupby_assign_coords() -> None: actual = xr.DataArray( [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": range(2), "x": range(3)} ) - actual1 = actual.groupby("x").assign_coords({"y": [-1, -2]}) - actual2 = actual.groupby("x").assign_coords(y=[-1, -2]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual1 = actual.groupby("x").assign_coords({"y": [-1, -2]}) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + actual2 = actual.groupby("x").assign_coords(y=[-1, -2]) expected = xr.DataArray( [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": [-1, -2], "x": range(3)} ) @@ -467,7 +520,7 @@ def test_da_groupby_assign_coords() -> None: coords={ "z": ["a", "b", "c", "a", "b", "c"], "x": [1, 1, 1, 2, 2, 3, 4, 5, 3, 4], - "t": pd.date_range("2001-01-01", freq="M", periods=24), + "t": xr.date_range("2001-01-01", freq="ME", periods=24, use_cftime=False), "month": ("t", list(range(1, 13)) * 2), }, ) @@ -501,6 +554,7 @@ def test_groupby_repr_datetime(obj) -> None: assert actual == expected +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning") def test_groupby_drops_nans() -> None: # GH2383 @@ -537,7 +591,7 @@ def test_groupby_drops_nans() -> None: .reset_index("id", drop=True) .assign(id=stacked.id.values) .dropna("id") - .transpose(*actual2.dims) + .transpose(*actual2.variable.dims) ) assert_identical(actual2, expected2) @@ -583,25 +637,24 @@ def test_groupby_grouping_errors() -> None: with pytest.raises( ValueError, match=r"None of the data falls within bins with edges" ): - dataset.to_array().groupby_bins("x", bins=[0.1, 0.2, 0.3]) + dataset.to_dataarray().groupby_bins("x", bins=[0.1, 0.2, 0.3]) with pytest.raises(ValueError, match=r"All bin edges are NaN."): dataset.groupby_bins("x", bins=[np.nan, np.nan, np.nan]) with pytest.raises(ValueError, match=r"All bin edges are NaN."): - dataset.to_array().groupby_bins("x", bins=[np.nan, np.nan, np.nan]) + dataset.to_dataarray().groupby_bins("x", bins=[np.nan, np.nan, np.nan]) with pytest.raises(ValueError, match=r"Failed to group data."): dataset.groupby(dataset.foo * np.nan) with pytest.raises(ValueError, match=r"Failed to group data."): - dataset.to_array().groupby(dataset.foo * np.nan) + dataset.to_dataarray().groupby(dataset.foo * np.nan) def test_groupby_reduce_dimension_error(array) -> None: grouped = array.groupby("y") - with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): - grouped.mean() + # assert_identical(array, grouped.mean()) with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): grouped.mean("huh") @@ -609,6 +662,10 @@ def test_groupby_reduce_dimension_error(array) -> None: with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): grouped.mean(("x", "y", "asd")) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(array.mean("x"), grouped.reduce(np.mean, "x")) + assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"])) + grouped = array.groupby("y", squeeze=False) assert_identical(array, grouped.mean()) @@ -627,7 +684,7 @@ def test_groupby_bins_timeseries() -> None: pd.date_range("2010-08-01", "2010-08-15", freq="15min"), dims="time" ) ds["val"] = xr.DataArray(np.ones(ds["time"].shape), dims="time") - time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24H") + time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24h") actual = ds.groupby_bins("time", time_bins).sum() expected = xr.DataArray( 96 * np.ones((14,)), @@ -650,13 +707,26 @@ def test_groupby_none_group_name() -> None: def test_groupby_getitem(dataset) -> None: - assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"]) - assert_identical(dataset.sel(z=1), dataset.groupby("z")[1]) - - assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) - assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.sel(z=1), dataset.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + + assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"]) + assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1]) + + assert_identical( + dataset.foo.sel(x=["a"]), dataset.foo.groupby("x", squeeze=False)["a"] + ) + assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1]) - actual = dataset.groupby("boo")["f"].unstack().transpose("x", "y", "z") + actual = ( + dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z") + ) expected = dataset.sel(y=[1], z=[1, 2]).transpose("x", "y", "z") assert_identical(expected, actual) @@ -666,14 +736,14 @@ def test_groupby_dataset() -> None: {"z": (["x", "y"], np.random.randn(3, 5))}, {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, ) - groupby = data.groupby("x") + groupby = data.groupby("x", squeeze=False) assert len(groupby) == 3 - expected_groups = {"a": 0, "b": 1, "c": 2} + expected_groups = {"a": slice(0, 1), "b": slice(1, 2), "c": slice(2, 3)} assert groupby.groups == expected_groups expected_items = [ - ("a", data.isel(x=0)), - ("b", data.isel(x=1)), - ("c", data.isel(x=2)), + ("a", data.isel(x=[0])), + ("b", data.isel(x=[1])), + ("c", data.isel(x=[2])), ] for actual1, expected1 in zip(groupby, expected_items): assert actual1[0] == expected1[0] @@ -687,31 +757,61 @@ def identity(x): assert_equal(data, actual2) +def test_groupby_dataset_squeeze_None() -> None: + """Delete when removing squeeze.""" + data = Dataset( + {"z": (["x", "y"], np.random.randn(3, 5))}, + {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, + ) + groupby = data.groupby("x") + assert len(groupby) == 3 + expected_groups = {"a": 0, "b": 1, "c": 2} + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert groupby.groups == expected_groups + expected_items = [ + ("a", data.isel(x=0)), + ("b", data.isel(x=1)), + ("c", data.isel(x=2)), + ] + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for actual1, expected1 in zip(groupby, expected_items): + assert actual1[0] == expected1[0] + assert_equal(actual1[1], expected1[1]) + + def identity(x): + return x + + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for k in ["x", "c"]: + actual2 = data.groupby(k).map(identity) + assert_equal(data, actual2) + + def test_groupby_dataset_returns_new_type() -> None: data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) - actual1 = data.groupby("x").map(lambda ds: ds["z"]) + actual1 = data.groupby("x", squeeze=False).map(lambda ds: ds["z"]) expected1 = data["z"] assert_identical(expected1, actual1) - actual2 = data["z"].groupby("x").map(lambda x: x.to_dataset()) + actual2 = data["z"].groupby("x", squeeze=False).map(lambda x: x.to_dataset()) expected2 = data assert_identical(expected2, actual2) def test_groupby_dataset_iter() -> None: data = create_test_data() - for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]): + for n, (t, sub) in enumerate(list(data.groupby("dim1", squeeze=False))[:3]): assert data["dim1"][n] == t - assert_equal(data["var1"][n], sub["var1"]) - assert_equal(data["var2"][n], sub["var2"]) - assert_equal(data["var3"][:, n], sub["var3"]) + assert_equal(data["var1"][[n]], sub["var1"]) + assert_equal(data["var2"][[n]], sub["var2"]) + assert_equal(data["var3"][:, [n]], sub["var3"]) def test_groupby_dataset_errors() -> None: data = create_test_data() with pytest.raises(TypeError, match=r"`group` must be"): - data.groupby(np.arange(10)) + data.groupby(np.arange(10)) # type: ignore[arg-type,unused-ignore] with pytest.raises(ValueError, match=r"length does not match"): data.groupby(data["dim1"][:3]) with pytest.raises(TypeError, match=r"`group` must be"): @@ -793,9 +893,9 @@ def test_groupby_math_more() -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 # type: ignore[operator] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[operator] with pytest.raises(TypeError, match=r"in-place operations"): - ds += grouped + ds += grouped # type: ignore[arg-type] ds = Dataset( { @@ -807,6 +907,56 @@ def test_groupby_math_more() -> None: ds + ds.groupby("time.month") +def test_groupby_math_bitshift() -> None: + # create new dataset of int's only + ds = Dataset( + { + "x": ("index", np.ones(4, dtype=int)), + "y": ("index", np.ones(4, dtype=int) * -1), + "level": ("index", [1, 1, 2, 2]), + "index": [0, 1, 2, 3], + } + ) + shift = DataArray([1, 2, 1], [("level", [1, 2, 8])]) + + left_expected = Dataset( + { + "x": ("index", [2, 2, 4, 4]), + "y": ("index", [-2, -2, -4, -4]), + "level": ("index", [2, 2, 8, 8]), + "index": [0, 1, 2, 3], + } + ) + + left_manual = [] + for lev, group in ds.groupby("level"): + shifter = shift.sel(level=lev) + left_manual.append(group << shifter) + left_actual = xr.concat(left_manual, dim="index").reset_coords(names="level") + assert_equal(left_expected, left_actual) + + left_actual = (ds.groupby("level") << shift).reset_coords(names="level") + assert_equal(left_expected, left_actual) + + right_expected = Dataset( + { + "x": ("index", [0, 0, 2, 2]), + "y": ("index", [-1, -1, -2, -2]), + "level": ("index", [0, 0, 4, 4]), + "index": [0, 1, 2, 3], + } + ) + right_manual = [] + for lev, group in left_expected.groupby("level"): + shifter = shift.sel(level=lev) + right_manual.append(group >> shifter) + right_actual = xr.concat(right_manual, dim="index").reset_coords(names="level") + assert_equal(right_expected, right_actual) + + right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level") + assert_equal(right_expected, right_actual) + + @pytest.mark.parametrize("use_flox", [True, False]) def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y")) @@ -814,7 +964,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None: with xr.set_options(use_flox=use_flox): actual = da.groupby_bins( - "x", bins=x_bins, include_lowest=True, right=False + "x", bins=x_bins, include_lowest=True, right=False, squeeze=False ).mean() expected = xr.DataArray( np.array([[1.0, 2.0], [5.0, 6.0], [9.0, 10.0]]), @@ -890,7 +1040,7 @@ def test_groupby_math_dim_order() -> None: da = DataArray( np.ones((10, 10, 12)), dims=("x", "y", "time"), - coords={"time": pd.date_range("2001-01-01", periods=12, freq="6H")}, + coords={"time": pd.date_range("2001-01-01", periods=12, freq="6h")}, ) grouped = da.groupby("time.day") result = grouped - grouped.mean() @@ -1040,12 +1190,15 @@ def test_stack_groupby_unsorted_coord(self): def test_groupby_iter(self): for (act_x, act_dv), (exp_x, exp_ds) in zip( - self.dv.groupby("y"), self.ds.groupby("y") + self.dv.groupby("y", squeeze=False), self.ds.groupby("y", squeeze=False) ): assert exp_x == act_x assert_identical(exp_ds["foo"], act_dv) - for (_, exp_dv), act_dv in zip(self.dv.groupby("x"), self.dv): - assert_identical(exp_dv, act_dv) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + for (_, exp_dv), (_, act_dv) in zip( + self.dv.groupby("x"), self.dv.groupby("x") + ): + assert_identical(exp_dv, act_dv) def test_groupby_properties(self): grouped = self.da.groupby("abc") @@ -1059,8 +1212,8 @@ def test_groupby_properties(self): "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] ) @pytest.mark.parametrize("shortcut", [True, False]) - @pytest.mark.parametrize("squeeze", [True, False]) - def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None: + @pytest.mark.parametrize("squeeze", [None, True, False]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze, recwarn) -> None: expected = self.da if use_da: by = expected.coords[by] @@ -1072,6 +1225,10 @@ def identity(x): actual = grouped.map(identity, shortcut=shortcut) assert_identical(expected, actual) + # abc is not a dim coordinate so no warnings expected! + if (by.name if use_da else by) != "abc": + assert len(recwarn) == (1 if squeeze in [None, True] else 0) + def test_groupby_sum(self): array = self.da grouped = array.groupby("abc") @@ -1273,8 +1430,15 @@ def test_groupby_math_not_aligned(self): expected = DataArray([10, 11, np.nan, np.nan], array.coords) assert_identical(expected, actual) + # regression test for #7797 + other = array.groupby("b").sum() + actual = array.sel(x=[0, 1]).groupby("b") - other + expected = DataArray([-1, 0], {"b": ("x", [0, 0]), "x": [0, 1]}, dims="x") + assert_identical(expected, actual) + other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b") actual = array.groupby("b") + other + expected = DataArray([10, 11, np.nan, np.nan], array.coords) expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) assert_identical(expected, actual) @@ -1295,7 +1459,7 @@ def test_groupby_restore_dim_order(self): ("a", ("a", "y")), ("b", ("x", "b")), ]: - result = array.groupby(by).map(lambda x: x.squeeze()) + result = array.groupby(by, squeeze=False).map(lambda x: x.squeeze()) assert result.dims == expected_dims def test_groupby_restore_coord_dims(self): @@ -1315,7 +1479,7 @@ def test_groupby_restore_coord_dims(self): ("a", ("a", "y")), ("b", ("x", "b")), ]: - result = array.groupby(by, restore_coord_dims=True).map( + result = array.groupby(by, squeeze=False, restore_coord_dims=True).map( lambda x: x.squeeze() )["c"] assert result.dims == expected_dims @@ -1370,29 +1534,58 @@ def test_groupby_multidim_map(self): ) assert_identical(expected, actual) - def test_groupby_bins(self): - array = DataArray(np.arange(4), dims="dim_0") + @pytest.mark.parametrize("use_flox", [True, False]) + @pytest.mark.parametrize("coords", [np.arange(4), np.arange(4)[::-1], [2, 0, 3, 1]]) + @pytest.mark.parametrize( + "cut_kwargs", + ( + {"labels": None, "include_lowest": True}, + {"labels": None, "include_lowest": False}, + {"labels": ["a", "b"]}, + {"labels": [1.2, 3.5]}, + {"labels": ["b", "a"]}, + ), + ) + def test_groupby_bins( + self, + coords: np.typing.ArrayLike, + use_flox: bool, + cut_kwargs: dict, + ) -> None: + array = DataArray( + np.arange(4), dims="dim_0", coords={"dim_0": coords}, name="a" + ) # the first value should not be part of any group ("right" binning) array[0] = 99 # bins follow conventions for pandas.cut # http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html bins = [0, 1.5, 5] - bin_coords = pd.cut(array["dim_0"], bins).categories - expected = DataArray( - [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} + + df = array.to_dataframe() + df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) + + expected_df = df.groupby("dim_0_bins", observed=True).sum() + # TODO: can't convert df with IntervalIndex to Xarray + expected = ( + expected_df.reset_index(drop=True) + .to_xarray() + .assign_coords(index=np.array(expected_df.index)) + .rename({"index": "dim_0_bins"})["a"] ) - actual = array.groupby_bins("dim_0", bins=bins).sum() - assert_identical(expected, actual) - actual = array.groupby_bins("dim_0", bins=bins, labels=[1.2, 3.5]).sum() - assert_identical(expected.assign_coords(dim_0_bins=[1.2, 3.5]), actual) + with xr.set_options(use_flox=use_flox): + actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum() + assert_identical(expected, actual) - actual = array.groupby_bins("dim_0", bins=bins).map(lambda x: x.sum()) - assert_identical(expected, actual) + actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map( + lambda x: x.sum() + ) + assert_identical(expected, actual) - # make sure original array dims are unchanged - assert len(array.dim_0) == 4 + # make sure original array dims are unchanged + assert len(array.dim_0) == 4 + def test_groupby_bins_ellipsis(self): da = xr.DataArray(np.ones((2, 3, 4))) bins = [-1, 0, 1, 2] with xr.set_options(use_flox=False): @@ -1401,6 +1594,36 @@ def test_groupby_bins(self): expected = da.groupby_bins("dim_0", bins).mean(...) assert_allclose(actual, expected) + @pytest.mark.parametrize("use_flox", [True, False]) + def test_groupby_bins_gives_correct_subset(self, use_flox: bool) -> None: + # GH7766 + rng = np.random.default_rng(42) + coords = rng.normal(5, 5, 1000) + bins = np.logspace(-4, 1, 10) + labels = [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + ] + # xArray + # Make a mock dataarray + darr = xr.DataArray(coords, coords=[coords], dims=["coords"]) + expected = xr.DataArray( + [np.nan, np.nan, 1, 1, 1, 8, 31, 104, 542], + dims="coords_bins", + coords={"coords_bins": labels}, + ) + gb = darr.groupby_bins("coords", bins, labels=labels) + with xr.set_options(use_flox=use_flox): + actual = gb.count() + assert_identical(actual, expected) + def test_groupby_bins_empty(self): array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty @@ -1490,7 +1713,7 @@ def test_resample(self, use_cftime: bool) -> None: if use_cftime and not has_cftime: pytest.skip() times = xr.date_range( - "2000-01-01", freq="6H", periods=10, use_cftime=use_cftime + "2000-01-01", freq="6h", periods=10, use_cftime=use_cftime ) def resample_as_pandas(array, *args, **kwargs): @@ -1508,15 +1731,15 @@ def resample_as_pandas(array, *args, **kwargs): array = DataArray(np.arange(10), [("time", times)]) - actual = array.resample(time="24H").mean() - expected = resample_as_pandas(array, "24H") + actual = array.resample(time="24h").mean() + expected = resample_as_pandas(array, "24h") assert_identical(expected, actual) - actual = array.resample(time="24H").reduce(np.mean) + actual = array.resample(time="24h").reduce(np.mean) assert_identical(expected, actual) - actual = array.resample(time="24H", closed="right").mean() - expected = resample_as_pandas(array, "24H", closed="right") + actual = array.resample(time="24h", closed="right").mean() + expected = resample_as_pandas(array, "24h", closed="right") assert_identical(expected, actual) with pytest.raises(ValueError, match=r"index must be monotonic"): @@ -1535,19 +1758,19 @@ def test_resample_doctest(self, use_cftime: bool) -> None: time=( "time", xr.date_range( - "2001-01-01", freq="M", periods=6, use_cftime=use_cftime + "2001-01-01", freq="ME", periods=6, use_cftime=use_cftime ), ), labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), ), ) - actual = da.resample(time="3M").count() + actual = da.resample(time="3ME").count() expected = DataArray( [1, 3, 1], dims="time", coords={ "time": xr.date_range( - "2001-01-01", freq="3M", periods=3, use_cftime=use_cftime + "2001-01-01", freq="3ME", periods=3, use_cftime=use_cftime ) }, ) @@ -1564,16 +1787,20 @@ def func(arg1, arg2, arg3=0.0): assert_identical(actual, expected) def test_resample_first(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) + # resample to same frequency + actual = array.resample(time="6h").first() + assert_identical(array, actual) + actual = array.resample(time="1D").first() expected = DataArray([0, 4, 8], [("time", times[::4])]) assert_identical(expected, actual) # verify that labels don't use the first value - actual = array.resample(time="24H").first() - expected = DataArray(array.to_series().resample("24H").first()) + actual = array.resample(time="24h").first() + expected = DataArray(array.to_series().resample("24h").first()) assert_identical(expected, actual) # missing values @@ -1597,7 +1824,7 @@ def test_resample_first(self): assert_identical(expected, actual) def test_resample_bad_resample_dim(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("__resample_dim__", times)]) with pytest.raises(ValueError, match=r"Proxy resampling dimension"): array.resample(**{"__resample_dim__": "1D"}).first() @@ -1606,7 +1833,7 @@ def test_resample_bad_resample_dim(self): def test_resample_drop_nondim_coords(self): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) xx, yy = np.meshgrid(xs * 5, ys * 2.5) tt = np.arange(len(times), dtype=int) @@ -1621,21 +1848,21 @@ def test_resample_drop_nondim_coords(self): array = ds["data"] # Re-sample - actual = array.resample(time="12H", restore_coord_dims=True).mean("time") + actual = array.resample(time="12h", restore_coord_dims=True).mean("time") assert "tc" not in actual.coords # Up-sample - filling - actual = array.resample(time="1H", restore_coord_dims=True).ffill() + actual = array.resample(time="1h", restore_coord_dims=True).ffill() assert "tc" not in actual.coords # Up-sample - interpolation - actual = array.resample(time="1H", restore_coord_dims=True).interpolate( + actual = array.resample(time="1h", restore_coord_dims=True).interpolate( "linear" ) assert "tc" not in actual.coords def test_resample_keep_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.ones(10), [("time", times)]) array.attrs["meta"] = "data" @@ -1643,13 +1870,8 @@ def test_resample_keep_attrs(self): expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) assert_identical(result, expected) - with pytest.warns( - UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." - ): - array.resample(time="1D", keep_attrs=True) - def test_resample_skipna(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.ones(10), [("time", times)]) array[1] = np.nan @@ -1658,33 +1880,33 @@ def test_resample_skipna(self): assert_identical(result, expected) def test_upsample(self): - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) array = DataArray(np.arange(5), [("time", times)]) # Forward-fill - actual = array.resample(time="3H").ffill() - expected = DataArray(array.to_series().resample("3H").ffill()) + actual = array.resample(time="3h").ffill() + expected = DataArray(array.to_series().resample("3h").ffill()) assert_identical(expected, actual) # Backward-fill - actual = array.resample(time="3H").bfill() - expected = DataArray(array.to_series().resample("3H").bfill()) + actual = array.resample(time="3h").bfill() + expected = DataArray(array.to_series().resample("3h").bfill()) assert_identical(expected, actual) # As frequency - actual = array.resample(time="3H").asfreq() - expected = DataArray(array.to_series().resample("3H").asfreq()) + actual = array.resample(time="3h").asfreq() + expected = DataArray(array.to_series().resample("3h").asfreq()) assert_identical(expected, actual) # Pad - actual = array.resample(time="3H").pad() - expected = DataArray(array.to_series().resample("3H").ffill()) + actual = array.resample(time="3h").pad() + expected = DataArray(array.to_series().resample("3h").ffill()) assert_identical(expected, actual) # Nearest - rs = array.resample(time="3H") + rs = array.resample(time="3h") actual = rs.nearest() - new_times = rs._full_index + new_times = rs.groupers[0].full_index expected = DataArray(array.reindex(time=new_times, method="nearest")) assert_identical(expected, actual) @@ -1692,14 +1914,14 @@ def test_upsample_nd(self): # Same as before, but now we try on multi-dimensional DataArrays. xs = np.arange(6) ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) # Forward-fill - actual = array.resample(time="3H").ffill() + actual = array.resample(time="3h").ffill() expected_data = np.repeat(data, 2, axis=-1) - expected_times = times.to_series().resample("3H").asfreq().index + expected_times = times.to_series().resample("3h").asfreq().index expected_data = expected_data[..., : len(expected_times)] expected = DataArray( expected_data, @@ -1709,10 +1931,10 @@ def test_upsample_nd(self): assert_identical(expected, actual) # Backward-fill - actual = array.resample(time="3H").ffill() + actual = array.resample(time="3h").ffill() expected_data = np.repeat(np.flipud(data.T).T, 2, axis=-1) expected_data = np.flipud(expected_data.T).T - expected_times = times.to_series().resample("3H").asfreq().index + expected_times = times.to_series().resample("3h").asfreq().index expected_data = expected_data[..., : len(expected_times)] expected = DataArray( expected_data, @@ -1722,10 +1944,10 @@ def test_upsample_nd(self): assert_identical(expected, actual) # As frequency - actual = array.resample(time="3H").asfreq() + actual = array.resample(time="3h").asfreq() expected_data = np.repeat(data, 2, axis=-1).astype(float)[..., :-1] expected_data[..., 1::2] = np.nan - expected_times = times.to_series().resample("3H").asfreq().index + expected_times = times.to_series().resample("3h").asfreq().index expected = DataArray( expected_data, {"time": expected_times, "x": xs, "y": ys}, @@ -1734,11 +1956,11 @@ def test_upsample_nd(self): assert_identical(expected, actual) # Pad - actual = array.resample(time="3H").pad() + actual = array.resample(time="3h").pad() expected_data = np.repeat(data, 2, axis=-1) expected_data[..., 1::2] = expected_data[..., ::2] expected_data = expected_data[..., :-1] - expected_times = times.to_series().resample("3H").asfreq().index + expected_times = times.to_series().resample("3h").asfreq().index expected = DataArray( expected_data, {"time": expected_times, "x": xs, "y": ys}, @@ -1749,21 +1971,21 @@ def test_upsample_nd(self): def test_upsample_tolerance(self): # Test tolerance keyword for upsample methods bfill, pad, nearest times = pd.date_range("2000-01-01", freq="1D", periods=2) - times_upsampled = pd.date_range("2000-01-01", freq="6H", periods=5) + times_upsampled = pd.date_range("2000-01-01", freq="6h", periods=5) array = DataArray(np.arange(2), [("time", times)]) # Forward fill - actual = array.resample(time="6H").ffill(tolerance="12H") + actual = array.resample(time="6h").ffill(tolerance="12h") expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Backward fill - actual = array.resample(time="6H").bfill(tolerance="12H") + actual = array.resample(time="6h").bfill(tolerance="12h") expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) assert_identical(expected, actual) # Nearest - actual = array.resample(time="6H").nearest(tolerance="6H") + actual = array.resample(time="6h").nearest(tolerance="6h") expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) assert_identical(expected, actual) @@ -1773,18 +1995,18 @@ def test_upsample_interpolate(self): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) z = np.arange(5) ** 2 data = np.tile(z, (6, 3, 1)) array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - expected_times = times.to_series().resample("1H").asfreq().index + expected_times = times.to_series().resample("1h").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.resample(time="1H").interpolate(kind) + actual = array.resample(time="1h").interpolate(kind) f = interp1d( np.arange(len(times)), data, @@ -1805,10 +2027,11 @@ def test_upsample_interpolate(self): assert_allclose(expected, actual, rtol=1e-16) @requires_scipy + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_upsample_interpolate_bug_2197(self): dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) - result = da.resample(time="M").interpolate("linear") + result = da.resample(time="ME").interpolate("linear") expected_times = np.array( [np.datetime64("2007-02-28"), np.datetime64("2007-03-31")] ) @@ -1834,7 +2057,7 @@ def test_upsample_interpolate_dask(self, chunked_time): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) z = np.arange(5) ** 2 data = np.tile(z, (6, 3, 1)) @@ -1843,12 +2066,12 @@ def test_upsample_interpolate_dask(self, chunked_time): if chunked_time: chunks["time"] = 3 - expected_times = times.to_series().resample("1H").asfreq().index + expected_times = times.to_series().resample("1h").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour # to 1 hour up-sampling new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.chunk(chunks).resample(time="1H").interpolate(kind) + actual = array.chunk(chunks).resample(time="1h").interpolate(kind) actual = actual.compute() f = interp1d( np.arange(len(times)), @@ -1871,34 +2094,34 @@ def test_upsample_interpolate_dask(self, chunked_time): @pytest.mark.skipif(has_pandas_version_two, reason="requires pandas < 2.0.0") def test_resample_base(self) -> None: - times = pd.date_range("2000-01-01T02:03:01", freq="6H", periods=10) + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) base = 11 with pytest.warns(FutureWarning, match="the `base` parameter to resample"): - actual = array.resample(time="24H", base=base).mean() + actual = array.resample(time="24h", base=base).mean() expected = DataArray( - array.to_series().resample("24H", offset=f"{base}H").mean() + array.to_series().resample("24h", offset=f"{base}h").mean() ) assert_identical(expected, actual) def test_resample_offset(self) -> None: - times = pd.date_range("2000-01-01T02:03:01", freq="6H", periods=10) + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) - offset = pd.Timedelta("11H") - actual = array.resample(time="24H", offset=offset).mean() - expected = DataArray(array.to_series().resample("24H", offset=offset).mean()) + offset = pd.Timedelta("11h") + actual = array.resample(time="24h", offset=offset).mean() + expected = DataArray(array.to_series().resample("24h", offset=offset).mean()) assert_identical(expected, actual) def test_resample_origin(self) -> None: - times = pd.date_range("2000-01-01T02:03:01", freq="6H", periods=10) + times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) origin = "start" - actual = array.resample(time="24H", origin=origin).mean() - expected = DataArray(array.to_series().resample("24H", origin=origin).mean()) + actual = array.resample(time="24h", origin=origin).mean() + expected = DataArray(array.to_series().resample("24h", origin=origin).mean()) assert_identical(expected, actual) @pytest.mark.skipif(has_pandas_version_two, reason="requires pandas < 2.0.0") @@ -1912,12 +2135,12 @@ def test_resample_origin(self) -> None: ], ) def test_resample_loffset(self, loffset) -> None: - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) with pytest.warns(FutureWarning, match="`loffset` parameter"): - actual = array.resample(time="24H", loffset=loffset).mean() - series = array.to_series().resample("24H").mean() + actual = array.resample(time="24h", loffset=loffset).mean() + series = array.to_series().resample("24h").mean() if not isinstance(loffset, pd.DateOffset): loffset = pd.Timedelta(loffset) series.index = series.index + loffset @@ -1925,19 +2148,19 @@ def test_resample_loffset(self, loffset) -> None: assert_identical(actual, expected) def test_resample_invalid_loffset(self) -> None: - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) with pytest.warns( FutureWarning, match="Following pandas, the `loffset` parameter" ): with pytest.raises(ValueError, match="`loffset` must be"): - array.resample(time="24H", loffset=1).mean() # type: ignore + array.resample(time="24h", loffset=1).mean() # type: ignore class TestDatasetResample: def test_resample_and_first(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -1951,9 +2174,9 @@ def test_resample_and_first(self): assert_identical(expected, actual) # upsampling - expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) + expected_time = pd.date_range("2000-01-01", freq="3h", periods=19) expected = ds.reindex(time=expected_time) - actual = ds.resample(time="3H") + actual = ds.resample(time="3h") for how in ["mean", "sum", "first", "last"]: method = getattr(actual, how) result = method() @@ -1963,7 +2186,7 @@ def test_resample_and_first(self): assert_equal(expected, result) def test_resample_min_count(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -1985,7 +2208,7 @@ def test_resample_min_count(self): assert_allclose(expected, actual) def test_resample_by_mean_with_keep_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -2004,13 +2227,8 @@ def test_resample_by_mean_with_keep_attrs(self): expected = ds.attrs assert expected == actual - with pytest.warns( - UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." - ): - ds.resample(time="1D", keep_attrs=True) - def test_resample_loffset(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -2021,7 +2239,7 @@ def test_resample_loffset(self): ds.attrs["dsmeta"] = "dsdata" def test_resample_by_mean_discarding_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -2037,7 +2255,7 @@ def test_resample_by_mean_discarding_attrs(self): assert resampled_ds.attrs == {} def test_resample_by_last_discarding_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -2056,7 +2274,7 @@ def test_resample_by_last_discarding_attrs(self): def test_resample_drop_nondim_coords(self): xs = np.arange(6) ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) + times = pd.date_range("2000-01-01", freq="6h", periods=5) data = np.tile(np.arange(5), (6, 3, 1)) xx, yy = np.meshgrid(xs * 5, ys * 2.5) tt = np.arange(len(times), dtype=int) @@ -2068,19 +2286,19 @@ def test_resample_drop_nondim_coords(self): ds = ds.set_coords(["xc", "yc", "tc"]) # Re-sample - actual = ds.resample(time="12H").mean("time") + actual = ds.resample(time="12h").mean("time") assert "tc" not in actual.coords # Up-sample - filling - actual = ds.resample(time="1H").ffill() + actual = ds.resample(time="1h").ffill() assert "tc" not in actual.coords # Up-sample - interpolation - actual = ds.resample(time="1H").interpolate("linear") + actual = ds.resample(time="1h").interpolate("linear") assert "tc" not in actual.coords def test_resample_old_api(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) + times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( { "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), @@ -2099,7 +2317,7 @@ def test_resample_old_api(self): ds.resample("1D", dim="time") def test_resample_ds_da_are_the_same(self): - time = pd.date_range("2000-01-01", freq="6H", periods=365 * 4) + time = pd.date_range("2000-01-01", freq="6h", periods=365 * 4) ds = xr.Dataset( { "foo": (("time", "x"), np.random.randn(365 * 4, 5)), @@ -2108,7 +2326,7 @@ def test_resample_ds_da_are_the_same(self): } ) assert_allclose( - ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() + ds.resample(time="ME").mean()["foo"], ds.foo.resample(time="ME").mean() ) def test_ds_resample_apply_func_args(self): @@ -2183,20 +2401,122 @@ def test_resample_cumsum(method: str, expected_array: list[float]) -> None: ds = xr.Dataset( {"foo": ("time", [1, 2, 3, 1, 2, np.nan])}, coords={ - "time": pd.date_range("01-01-2001", freq="M", periods=6), + "time": xr.date_range("01-01-2001", freq="ME", periods=6, use_cftime=False), }, ) - actual = getattr(ds.resample(time="3M"), method)(dim="time") + actual = getattr(ds.resample(time="3ME"), method)(dim="time") expected = xr.Dataset( {"foo": (("time",), expected_array)}, coords={ - "time": pd.date_range("01-01-2001", freq="M", periods=6), + "time": xr.date_range("01-01-2001", freq="ME", periods=6, use_cftime=False), }, ) # TODO: Remove drop_vars when GH6528 is fixed # when Dataset.cumsum propagates indexes, and the group variable? assert_identical(expected.drop_vars(["time"]), actual) - actual = getattr(ds.foo.resample(time="3M"), method)(dim="time") + actual = getattr(ds.foo.resample(time="3ME"), method)(dim="time") expected.coords["time"] = ds.time assert_identical(expected.drop_vars(["time"]).foo, actual) + + +def test_groupby_binary_op_regression() -> None: + # regression test for #7797 + # monthly timeseries that should return "zero anomalies" everywhere + time = xr.date_range("2023-01-01", "2023-12-31", freq="MS") + data = np.linspace(-1, 1, 12) + x = xr.DataArray(data, coords={"time": time}) + clim = xr.DataArray(data, coords={"month": np.arange(1, 13, 1)}) + + # seems to give the correct result if we use the full x, but not with a slice + x_slice = x.sel(time=["2023-04-01"]) + + # two typical ways of computing anomalies + anom_gb = x_slice.groupby("time.month") - clim + + assert_identical(xr.zeros_like(anom_gb), anom_gb) + + +def test_groupby_multiindex_level() -> None: + # GH6836 + midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) + mda = xr.DataArray(np.random.rand(6, 3), [("x", midx), ("y", range(3))]) + groups = mda.groupby("one").groups + assert groups == {"a": [0, 1], "b": [2, 3], "c": [4, 5]} + + +@requires_flox +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("min_count", [None, 1]) +def test_min_count_vs_flox(func: str, min_count: int | None, skipna: bool) -> None: + da = DataArray( + data=np.array([np.nan, 1, 1, np.nan, 1, 1]), + dims="x", + coords={"labels": ("x", np.array([1, 2, 3, 1, 2, 3]))}, + ) + + gb = da.groupby("labels") + method = operator.methodcaller(func, min_count=min_count, skipna=skipna) + with xr.set_options(use_flox=True): + actual = method(gb) + with xr.set_options(use_flox=False): + expected = method(gb) + assert_identical(actual, expected) + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_min_count_error(use_flox: bool) -> None: + if use_flox and not has_flox: + pytest.skip() + da = DataArray( + data=np.array([np.nan, 1, 1, np.nan, 1, 1]), + dims="x", + coords={"labels": ("x", np.array([1, 2, 3, 1, 2, 3]))}, + ) + with xr.set_options(use_flox=use_flox): + with pytest.raises(TypeError): + da.groupby("labels").mean(min_count=1) + + +@requires_dask +def test_groupby_math_auto_chunk(): + da = xr.DataArray( + [[1, 2, 3], [1, 2, 3], [1, 2, 3]], + dims=("y", "x"), + coords={"label": ("x", [2, 2, 1])}, + ) + sub = xr.DataArray( + InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} + ) + actual = da.chunk(x=1, y=2).groupby("label") - sub + assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_groupby_dim_no_dim_equal(use_flox): + # https://github.com/pydata/xarray/issues/8263 + da = DataArray( + data=[1, 2, 3, 4], dims="lat", coords={"lat": np.linspace(0, 1.01, 4)} + ) + with xr.set_options(use_flox=use_flox): + actual1 = da.drop_vars("lat").groupby("lat", squeeze=False).sum() + actual2 = da.groupby("lat", squeeze=False).sum() + assert_identical(actual1, actual2.drop_vars("lat")) + + +@requires_flox +def test_default_flox_method(): + import flox.xarray + + da = xr.DataArray([1, 2, 3], dims="x", coords={"label": ("x", [2, 2, 1])}) + + result = xr.DataArray([3, 3], dims="label", coords={"label": [1, 2]}) + with mock.patch("flox.xarray.xarray_reduce", return_value=result) as mocked_reduce: + da.groupby("label").sum() + + kwargs = mocked_reduce.call_args.kwargs + if Version(flox.__version__) < Version("0.9.0"): + assert kwargs["method"] == "cohorts" + else: + assert "method" not in kwargs diff --git a/xarray/tests/test_hashable.py b/xarray/tests/test_hashable.py new file mode 100644 index 00000000000..9f92c604dc3 --- /dev/null +++ b/xarray/tests/test_hashable.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Union + +import pytest + +from xarray import DataArray, Dataset, Variable + +if TYPE_CHECKING: + from xarray.core.types import TypeAlias + + DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"] + + +class DEnum(Enum): + dim = "dim" + + +class CustomHashable: + def __init__(self, a: int) -> None: + self.a = a + + def __hash__(self) -> int: + return self.a + + +parametrize_dim = pytest.mark.parametrize( + "dim", + [ + pytest.param(5, id="int"), + pytest.param(("a", "b"), id="tuple"), + pytest.param(DEnum.dim, id="enum"), + pytest.param(CustomHashable(3), id="HashableObject"), + ], +) + + +@parametrize_dim +def test_hashable_dims(dim: DimT) -> None: + v = Variable([dim], [1, 2, 3]) + da = DataArray([1, 2, 3], dims=[dim]) + Dataset({"a": ([dim], [1, 2, 3])}) + + # alternative constructors + DataArray(v) + Dataset({"a": v}) + Dataset({"a": da}) + + +@parametrize_dim +def test_dataset_variable_hashable_names(dim: DimT) -> None: + Dataset({dim: ("x", [1, 2, 3])}) diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 27b5cf2119c..5ebdfd5da6e 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -145,6 +145,11 @@ def test_from_variables(self) -> None: with pytest.raises(ValueError, match=r".*only accepts one variable.*"): PandasIndex.from_variables({"x": var, "foo": var2}, options={}) + with pytest.raises( + ValueError, match=r".*cannot set a PandasIndex.*scalar variable.*" + ): + PandasIndex.from_variables({"foo": xr.Variable((), 1)}, options={}) + with pytest.raises( ValueError, match=r".*only accepts a 1-dimensional variable.*" ): @@ -347,7 +352,7 @@ def test_constructor(self) -> None: # default level names pd_idx = pd.MultiIndex.from_arrays([foo_data, bar_data]) index = PandasMultiIndex(pd_idx, "x") - assert index.index.names == ("x_level_0", "x_level_1") + assert list(index.index.names) == ["x_level_0", "x_level_1"] def test_from_variables(self) -> None: v_level1 = xr.Variable( @@ -365,7 +370,7 @@ def test_from_variables(self) -> None: assert index.dim == "x" assert index.index.equals(expected_idx) assert index.index.name == "x" - assert index.index.names == ["level1", "level2"] + assert list(index.index.names) == ["level1", "level2"] var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) with pytest.raises( @@ -408,7 +413,8 @@ def test_stack(self) -> None: index = PandasMultiIndex.stack(prod_vars, "z") assert index.dim == "z" - assert index.index.names == ["x", "y"] + # TODO: change to tuple when pandas 3 is minimum + assert list(index.index.names) == ["x", "y"] np.testing.assert_array_equal( index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) @@ -447,6 +453,15 @@ def test_unstack(self) -> None: assert new_indexes["two"].equals(PandasIndex([1, 2, 3], "two")) assert new_pd_idx.equals(pd_midx) + def test_unstack_requires_unique(self) -> None: + pd_midx = pd.MultiIndex.from_product([["a", "a"], [1, 2]], names=["one", "two"]) + index = PandasMultiIndex(pd_midx, "x") + + with pytest.raises( + ValueError, match="Cannot unstack MultiIndex containing duplicates" + ): + index.unstack() + def test_create_variables(self) -> None: foo_data = np.array([0, 0, 1], dtype="int64") bar_data = np.array([1.1, 1.2, 1.3], dtype="float64") @@ -482,7 +497,10 @@ def test_sel(self) -> None: index.sel({"x": 0}) with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): index.sel({"one": 0, "x": "a"}) - with pytest.raises(ValueError, match=r"invalid multi-index level names"): + with pytest.raises( + ValueError, + match=r"multi-index level names \('three',\) not found in indexes", + ): index.sel({"x": {"three": 0}}) with pytest.raises(IndexError): index.sel({"x": (slice(None), 1, "no_level")}) @@ -514,12 +532,12 @@ def test_rename(self) -> None: assert new_index is index new_index = index.rename({"two": "three"}, {}) - assert new_index.index.names == ["one", "three"] + assert list(new_index.index.names) == ["one", "three"] assert new_index.dim == "x" assert new_index.level_coords_dtype == {"one": " None: x_idx = unique_indexes[0] diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 9f57b3b9056..f019d3c789c 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -23,6 +23,28 @@ B = IndexerMaker(indexing.BasicIndexer) +class TestIndexCallable: + def test_getitem(self): + def getter(key): + return key * 2 + + indexer = indexing.IndexCallable(getter) + assert indexer[3] == 6 + assert indexer[0] == 0 + assert indexer[-1] == -2 + + def test_setitem(self): + def getter(key): + return key * 2 + + def setter(key, value): + raise NotImplementedError("Setter not implemented") + + indexer = indexing.IndexCallable(getter, setter) + with pytest.raises(NotImplementedError): + indexer[3] = 6 + + class TestIndexers: def set_to_zero(self, x, i): x = x.copy() @@ -307,6 +329,7 @@ def test_lazily_indexed_array(self) -> None: assert expected.shape == actual.shape assert_array_equal(expected, actual) assert isinstance(actual._data, indexing.LazilyIndexedArray) + assert isinstance(v_lazy._data, indexing.LazilyIndexedArray) # make sure actual.key is appropriate type if all( @@ -327,6 +350,7 @@ def test_lazily_indexed_array(self) -> None: ([0, 3, 5], arr[:2]), ] for i, j in indexers: + expected_b = v[i][j] actual = v_lazy[i][j] assert expected_b.shape == actual.shape @@ -397,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers): ] check_indexing(v_eager, v_lazy, indexers) + def test_lazily_indexed_array_vindex_setitem(self) -> None: + + lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30)) + + # vectorized indexing + indexer = indexing.VectorizedIndexer( + (np.array([0, 1]), np.array([0, 1]), slice(None, None, None)) + ) + with pytest.raises( + NotImplementedError, + match=r"Lazy item assignment with the vectorized indexer is not yet", + ): + lazy.vindex[indexer] = 0 + + @pytest.mark.parametrize( + "indexer_class, key, value", + [ + (indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10), + (indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10), + ], + ) + def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None: + original = np.random.rand(10, 20, 30) + x = indexing.NumpyIndexingAdapter(original) + lazy = indexing.LazilyIndexedArray(x) + + if indexer_class is indexing.BasicIndexer: + indexer = indexer_class(key) + lazy[indexer] = value + elif indexer_class is indexing.OuterIndexer: + indexer = indexer_class(key) + lazy.oindex[indexer] = value + + assert_array_equal(original[key], value) + class TestCopyOnWriteArray: def test_setitem(self) -> None: @@ -555,7 +614,9 @@ def test_arrayize_vectorized_indexer(self) -> None: vindex_array = indexing._arrayize_vectorized_indexer( vindex, self.data.shape ) - np.testing.assert_array_equal(self.data[vindex], self.data[vindex_array]) + np.testing.assert_array_equal( + self.data.vindex[vindex], self.data.vindex[vindex_array] + ) actual = indexing._arrayize_vectorized_indexer( indexing.VectorizedIndexer((slice(None),)), shape=(5,) @@ -666,16 +727,39 @@ def test_decompose_indexers(shape, indexer_mode, indexing_support) -> None: indexer = get_indexers(shape, indexer_mode) backend_ind, np_ind = indexing.decompose_indexer(indexer, shape, indexing_support) + indexing_adapter = indexing.NumpyIndexingAdapter(data) + + # Dispatch to appropriate indexing method + if indexer_mode.startswith("vectorized"): + expected = indexing_adapter.vindex[indexer] + + elif indexer_mode.startswith("outer"): + expected = indexing_adapter.oindex[indexer] + + else: + expected = indexing_adapter[indexer] # Basic indexing + + if isinstance(backend_ind, indexing.VectorizedIndexer): + array = indexing_adapter.vindex[backend_ind] + elif isinstance(backend_ind, indexing.OuterIndexer): + array = indexing_adapter.oindex[backend_ind] + else: + array = indexing_adapter[backend_ind] - expected = indexing.NumpyIndexingAdapter(data)[indexer] - array = indexing.NumpyIndexingAdapter(data)[backend_ind] if len(np_ind.tuple) > 0: - array = indexing.NumpyIndexingAdapter(array)[np_ind] + array_indexing_adapter = indexing.NumpyIndexingAdapter(array) + if isinstance(np_ind, indexing.VectorizedIndexer): + array = array_indexing_adapter.vindex[np_ind] + elif isinstance(np_ind, indexing.OuterIndexer): + array = array_indexing_adapter.oindex[np_ind] + else: + array = array_indexing_adapter[np_ind] np.testing.assert_array_equal(expected, array) if not all(isinstance(k, indexing.integer_types) for k in np_ind.tuple): combined_ind = indexing._combine_indexers(backend_ind, shape, np_ind) - array = indexing.NumpyIndexingAdapter(data)[combined_ind] + assert isinstance(combined_ind, indexing.VectorizedIndexer) + array = indexing_adapter.vindex[combined_ind] np.testing.assert_array_equal(expected, array) @@ -796,7 +880,7 @@ def test_create_mask_dask() -> None: def test_create_mask_error() -> None: with pytest.raises(TypeError, match=r"unexpected key type"): - indexing.create_mask((1, 2), (3, 4)) + indexing.create_mask((1, 2), (3, 4)) # type: ignore[arg-type] @pytest.mark.parametrize( diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index e66045e978d..a7644ac9d2b 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -605,6 +605,7 @@ def test_interp_like() -> None: pytest.param("2000-01-01T12:00", 0.5, marks=pytest.mark.xfail), ], ) +@pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime(x_new, expected) -> None: da = xr.DataArray( np.arange(24), @@ -738,7 +739,7 @@ def test_datetime_interp_noerror() -> None: xi = xr.DataArray( np.linspace(1, 3, 50), dims=["time"], - coords={"time": pd.date_range("01-01-2001", periods=50, freq="H")}, + coords={"time": pd.date_range("01-01-2001", periods=50, freq="h")}, ) a.interp(x=xi, time=xi.time) # should not raise an error @@ -746,7 +747,7 @@ def test_datetime_interp_noerror() -> None: @requires_cftime @requires_scipy def test_3641() -> None: - times = xr.cftime_range("0001", periods=3, freq="500Y") + times = xr.cftime_range("0001", periods=3, freq="500YE") da = xr.DataArray(range(3), dims=["time"], coords=[times]) da.interp(time=["0002-05-01"]) @@ -837,8 +838,8 @@ def test_interpolate_chunk_1d( if chunked: dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) dest[dim] = dest[dim].chunk(2) - actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore - expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore + actual = da.interp(method=method, **dest, kwargs=kwargs) + expected = da.compute().interp(method=method, **dest, kwargs=kwargs) assert_identical(actual, expected) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 8957f9c829a..c6597d5abb0 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -235,6 +235,13 @@ def test_merge_dicts_dims(self): expected = xr.Dataset({"x": [12], "y": ("x", [13])}) assert_identical(actual, expected) + def test_merge_coordinates(self): + coords1 = xr.Coordinates({"x": ("x", [0, 1, 2])}) + coords2 = xr.Coordinates({"y": ("y", [3, 4, 5])}) + expected = xr.Dataset(coords={"x": [0, 1, 2], "y": [3, 4, 5]}) + actual = xr.merge([coords1, coords2]) + assert_identical(actual, expected) + def test_merge_error(self): ds = xr.Dataset({"x": 0}) with pytest.raises(xr.MergeError): @@ -375,6 +382,16 @@ def test_merge_compat(self): assert ds1.identical(ds1.merge(ds2, compat="override")) + def test_merge_compat_minimal(self) -> None: + # https://github.com/pydata/xarray/issues/7405 + # https://github.com/pydata/xarray/issues/7588 + ds1 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 4}) + ds2 = xr.Dataset(coords={"foo": [1, 2, 3], "bar": 5}) + + actual = xr.merge([ds1, ds2], compat="minimal") + expected = xr.Dataset(coords={"foo": [1, 2, 3]}) + assert_identical(actual, expected) + def test_merge_auto_align(self): ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]}) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index a6b6b1f80ce..c1d1058fd6e 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -14,7 +14,7 @@ _get_nan_block_lengths, get_clean_interp_index, ) -from xarray.core.pycompat import array_type +from xarray.namedarray.pycompat import array_type from xarray.tests import ( _CFTIME_CALENDARS, assert_allclose, @@ -24,6 +24,8 @@ requires_bottleneck, requires_cftime, requires_dask, + requires_numbagg, + requires_numbagg_or_bottleneck, requires_scipy, ) @@ -92,32 +94,46 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False return da, df +@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11]) +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) @requires_scipy -def test_interpolate_pd_compat(): +def test_interpolate_pd_compat(method, fill_value) -> None: shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] - for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods): + for shape, frac_nan in itertools.product(shapes, frac_nans): da, df = make_interpolate_example_data(shape, frac_nan) for dim in ["time", "x"]: - actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan) + actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value) + # need limit_direction="both" here, to let pandas fill + # in both directions instead of default forward direction only expected = df.interpolate( - method=method, axis=da.get_axis_num(dim), fill_value=(np.nan, np.nan) + method=method, + axis=da.get_axis_num(dim), + limit_direction="both", + fill_value=fill_value, ) - # Note, Pandas does some odd things with the left/right fill_value - # for the linear methods. This next line inforces the xarray - # fill_value convention on the pandas output. Therefore, this test - # only checks that interpolated values are the same (not nans) - expected.values[pd.isnull(actual.values)] = np.nan - np.testing.assert_allclose(actual.values, expected.values) + if method == "linear": + # Note, Pandas does not take left/right fill_value into account + # for the numpy linear methods. + # see https://github.com/pandas-dev/pandas/issues/55144 + # This aligns the pandas output with the xarray output + fixed = expected.values.copy() + fixed[pd.isnull(actual.values)] = np.nan + fixed[actual.values == fill_value] = fill_value + else: + fixed = expected.values + + np.testing.assert_allclose(actual.values, fixed) @requires_scipy -@pytest.mark.parametrize("method", ["barycentric", "krog", "pchip", "spline", "akima"]) -def test_scipy_methods_function(method): +@pytest.mark.parametrize("method", ["barycentric", "krogh", "pchip", "spline", "akima"]) +def test_scipy_methods_function(method) -> None: # Note: Pandas does some wacky things with these methods and the full # integration tests won't work. da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) @@ -140,7 +156,8 @@ def test_interpolate_pd_compat_non_uniform_index(): method="linear", dim=dim, use_coordinate=True, fill_value=np.nan ) expected = df.interpolate( - method=method, axis=da.get_axis_num(dim), fill_value=np.nan + method=method, + axis=da.get_axis_num(dim), ) # Note, Pandas does some odd things with the left/right fill_value @@ -395,7 +412,7 @@ def test_interpolate_dask_expected_dtype(dtype, method): assert da.dtype == da.compute().dtype -@requires_bottleneck +@requires_numbagg_or_bottleneck def test_ffill(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x") @@ -403,9 +420,9 @@ def test_ffill(): assert_equal(actual, expected) -def test_ffill_use_bottleneck(): +def test_ffill_use_bottleneck_numbagg(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.ffill("x") @@ -414,14 +431,24 @@ def test_ffill_use_bottleneck(): def test_ffill_use_bottleneck_dask(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.ffill("x") +@requires_numbagg +@requires_dask +def test_ffill_use_numbagg_dask(): + with xr.set_options(use_bottleneck=False): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk(x=-1) + # Succeeds with a single chunk: + _ = da.ffill("x").compute() + + def test_bfill_use_bottleneck(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.bfill("x") @@ -430,7 +457,7 @@ def test_bfill_use_bottleneck(): def test_bfill_use_bottleneck_dask(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.bfill("x") @@ -524,7 +551,7 @@ def test_ffill_limit(): def test_interpolate_dataset(ds): actual = ds.interpolate_na(dim="time") # no missing values in var1 - assert actual["var1"].count("time") == actual.dims["time"] + assert actual["var1"].count("time") == actual.sizes["time"] # var2 should be the same as it was assert_array_equal(actual["var2"], ds["var2"]) @@ -582,7 +609,7 @@ def test_get_clean_interp_index_cf_calendar(cf_da, calendar): @requires_cftime @pytest.mark.parametrize( - ("calendar", "freq"), zip(["gregorian", "proleptic_gregorian"], ["1D", "1M", "1Y"]) + ("calendar", "freq"), zip(["gregorian", "proleptic_gregorian"], ["1D", "1ME", "1Y"]) ) def test_get_clean_interp_index_dt(cf_da, calendar, freq): """In the gregorian case, the index should be proportional to normal datetimes.""" @@ -633,12 +660,12 @@ def test_interpolate_na_max_gap_errors(da_time): with pytest.raises(ValueError, match=r"max_gap must be a scalar."): da_time.interpolate_na("t", max_gap=(1,)) - da_time["t"] = pd.date_range("2001-01-01", freq="H", periods=11) + da_time["t"] = pd.date_range("2001-01-01", freq="h", periods=11) with pytest.raises(TypeError, match=r"Expected value of type str"): da_time.interpolate_na("t", max_gap=1) with pytest.raises(TypeError, match=r"Expected integer or floating point"): - da_time.interpolate_na("t", max_gap="1H", use_coordinate=False) + da_time.interpolate_na("t", max_gap="1h", use_coordinate=False) with pytest.raises(ValueError, match=r"Could not convert 'huh' to timedelta64"): da_time.interpolate_na("t", max_gap="huh") @@ -651,12 +678,12 @@ def test_interpolate_na_max_gap_errors(da_time): ) @pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.to_dataset(name="a")]) @pytest.mark.parametrize( - "max_gap", ["3H", np.timedelta64(3, "h"), pd.to_timedelta("3H")] + "max_gap", ["3h", np.timedelta64(3, "h"), pd.to_timedelta("3h")] ) def test_interpolate_na_max_gap_time_specifier( da_time, max_gap, transform, time_range_func ): - da_time["t"] = time_range_func("2001-01-01", freq="H", periods=11) + da_time["t"] = time_range_func("2001-01-01", freq="h", periods=11) expected = transform( da_time.copy(data=[np.nan, 1, 2, 3, 4, 5, np.nan, np.nan, np.nan, np.nan, 10]) ) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py new file mode 100644 index 00000000000..2a3faf32b85 --- /dev/null +++ b/xarray/tests/test_namedarray.py @@ -0,0 +1,574 @@ +from __future__ import annotations + +import copy +import warnings +from abc import abstractmethod +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Generic, cast, overload + +import numpy as np +import pytest + +from xarray.core.indexing import ExplicitlyIndexed +from xarray.namedarray._typing import ( + _arrayfunction_or_api, + _default, + _DType_co, + _ShapeType_co, +) +from xarray.namedarray.core import NamedArray, from_array + +if TYPE_CHECKING: + from types import ModuleType + + from numpy.typing import ArrayLike, DTypeLike, NDArray + + from xarray.namedarray._typing import ( + Default, + _AttrsLike, + _Dim, + _DimsLike, + _DType, + _IndexKeyLike, + _IntOrUnknown, + _Shape, + _ShapeLike, + duckarray, + ) + + +class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): + def __init__(self, array: duckarray[Any, _DType_co]) -> None: + self.array: duckarray[Any, _DType_co] = array + + @property + def dtype(self) -> _DType_co: + return self.array.dtype + + @property + def shape(self) -> _Shape: + return self.array.shape + + +class CustomArray( + CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] +): + def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]: + return np.array(self.array) + + +class CustomArrayIndexable( + CustomArrayBase[_ShapeType_co, _DType_co], + ExplicitlyIndexed, + Generic[_ShapeType_co, _DType_co], +): + def __getitem__( + self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], / + ) -> CustomArrayIndexable[Any, _DType_co]: + if isinstance(key, CustomArrayIndexable): + if isinstance(key.array, type(self.array)): + # TODO: key.array is duckarray here, can it be narrowed down further? + # an _arrayapi cannot be used on a _arrayfunction for example. + return type(self)(array=self.array[key.array]) # type: ignore[index] + else: + raise TypeError("key must have the same array type as self") + else: + return type(self)(array=self.array[key]) + + def __array_namespace__(self) -> ModuleType: + return np + + +class NamedArraySubclassobjects: + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> Any: + """Fixture that needs to be overridden""" + raise NotImplementedError + + @abstractmethod + def cls(self, *args: Any, **kwargs: Any) -> Any: + """Method that needs to be overridden""" + raise NotImplementedError + + @pytest.fixture + def data(self) -> np.ndarray[Any, np.dtype[Any]]: + return 0.5 * np.arange(10).reshape(2, 5) + + @pytest.fixture + def random_inputs(self) -> np.ndarray[Any, np.dtype[np.float32]]: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + def test_properties(self, target: Any, data: Any) -> None: + assert target.dims == ("x", "y") + assert np.array_equal(target.data, data) + assert target.dtype == float + assert target.shape == (2, 5) + assert target.ndim == 2 + assert target.sizes == {"x": 2, "y": 5} + assert target.size == 10 + assert target.nbytes == 80 + assert len(target) == 2 + + def test_attrs(self, target: Any) -> None: + assert target.attrs == {} + attrs = {"foo": "bar"} + target.attrs = attrs + assert target.attrs == attrs + assert isinstance(target.attrs, dict) + target.attrs["foo"] = "baz" + assert target.attrs["foo"] == "baz" + + @pytest.mark.parametrize( + "expected", [np.array([1, 2], dtype=np.dtype(np.int8)), [1, 2]] + ) + def test_init(self, expected: Any) -> None: + actual = self.cls(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) + + actual = self.cls(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) + + def test_data(self, random_inputs: Any) -> None: + expected = self.cls(["x", "y", "z"], random_inputs) + assert np.array_equal(np.asarray(expected.data), random_inputs) + with pytest.raises(ValueError): + expected.data = np.random.random((3, 4)).astype(np.float64) + d2 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + expected.data = d2 + assert np.array_equal(np.asarray(expected.data), d2) + + +class TestNamedArray(NamedArraySubclassobjects): + def cls(self, *args: Any, **kwargs: Any) -> NamedArray[Any, Any]: + return NamedArray(*args, **kwargs) + + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> NamedArray[Any, Any]: + return NamedArray(["x", "y"], data) + + @pytest.mark.parametrize( + "expected", + [ + np.array([1, 2], dtype=np.dtype(np.int8)), + pytest.param( + [1, 2], + marks=pytest.mark.xfail( + reason="NamedArray only supports array-like objects" + ), + ), + ], + ) + def test_init(self, expected: Any) -> None: + super().test_init(expected) + + @pytest.mark.parametrize( + "dims, data, expected, raise_error", + [ + (("x",), [1, 2, 3], np.array([1, 2, 3]), False), + ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False), + ((), 2, np.array(2), False), + # Fail: + ( + ("x",), + NamedArray("time", np.array([1, 2, 3])), + np.array([1, 2, 3]), + True, + ), + ], + ) + def test_from_array( + self, + dims: _DimsLike, + data: ArrayLike, + expected: np.ndarray[Any, Any], + raise_error: bool, + ) -> None: + actual: NamedArray[Any, Any] + if raise_error: + with pytest.raises(TypeError, match="already a Named array"): + actual = from_array(dims, data) + + # Named arrays are not allowed: + from_array(actual) # type: ignore[call-overload] + else: + actual = from_array(dims, data) + + assert np.array_equal(np.asarray(actual.data), expected) + + def test_from_array_with_masked_array(self) -> None: + masked_array: np.ndarray[Any, np.dtype[np.generic]] + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] + with pytest.raises(NotImplementedError): + from_array(("x",), masked_array) + + def test_from_array_with_0d_object(self) -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + narr = from_array((), data) + np.array_equal(np.asarray(narr.data), data) + + # TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api + # and remove this test. + def test_from_array_with_explicitly_indexed( + self, random_inputs: np.ndarray[Any, Any] + ) -> None: + array: CustomArray[Any, Any] + array = CustomArray(random_inputs) + output: NamedArray[Any, Any] + output = from_array(("x", "y", "z"), array) + assert isinstance(output.data, np.ndarray) + + array2: CustomArrayIndexable[Any, Any] + array2 = CustomArrayIndexable(random_inputs) + output2: NamedArray[Any, Any] + output2 = from_array(("x", "y", "z"), array2) + assert isinstance(output2.data, CustomArrayIndexable) + + def test_real_and_imag(self) -> None: + expected_real: np.ndarray[Any, np.dtype[np.float64]] + expected_real = np.arange(3, dtype=np.float64) + + expected_imag: np.ndarray[Any, np.dtype[np.float64]] + expected_imag = -np.arange(3, dtype=np.float64) + + arr: np.ndarray[Any, np.dtype[np.complex128]] + arr = expected_real + 1j * expected_imag + + named_array: NamedArray[Any, np.dtype[np.complex128]] + named_array = NamedArray(["x"], arr) + + actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data + assert np.array_equal(np.asarray(actual_real), expected_real) + assert actual_real.dtype == expected_real.dtype + + actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data + assert np.array_equal(np.asarray(actual_imag), expected_imag) + assert actual_imag.dtype == expected_imag.dtype + + # Additional tests as per your original class-based code + @pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (b"foo", np.dtype("S3")), + ], + ) + def test_from_array_0d_string(self, data: Any, dtype: DTypeLike) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + def test_from_array_0d_object(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert np.array_equal(np.asarray(named_array.data), expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + def test_from_array_0d_datetime(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + @pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], + ) + def test_from_array_0d_timedelta( + self, timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] + ) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + @pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], + ) + def test_dims_setter( + self, dims: Any, data_shape: Any, new_dims: Any, raises: bool + ) -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(dims, np.asarray(np.random.random(data_shape))) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) + + def test_duck_array_class( + self, + ) -> None: + def test_duck_array_typevar( + a: duckarray[Any, _DType], + ) -> duckarray[Any, _DType]: + # Mypy checks a is valid: + b: duckarray[Any, _DType] = a + + # Runtime check if valid: + if isinstance(b, _arrayfunction_or_api): + return b + else: + raise TypeError( + f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi" + ) + + numpy_a: NDArray[np.int64] + numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(numpy_a) + + masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]] + masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call] + test_duck_array_typevar(masked_a) + + custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] + custom_a = CustomArrayIndexable(numpy_a) + test_duck_array_typevar(custom_a) + + # Test numpy's array api: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp + + # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: + arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] + arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(arrayapi_a) + + def test_new_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert narr_float.dtype == dtype_float + + dtype_int = np.dtype(np.int8) + narr_int: NamedArray[Any, np.dtype[np.int8]] + narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert narr_int.dtype == dtype_int + + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert var_float.dtype == dtype_float + + var_int: Variable[Any, np.dtype[np.int8]] + var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert var_int.dtype == dtype_int + + def test_replace_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + np_val: np.ndarray[Any, np.dtype[np.float32]] + np_val = np.array([1.5, 3.2], dtype=dtype_float) + np_val2: np.ndarray[Any, np.dtype[np.float32]] + np_val2 = 2 * np_val + + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np_val) + assert narr_float.dtype == dtype_float + + narr_float2: NamedArray[Any, np.dtype[np.float32]] + narr_float2 = NamedArray(("x",), np_val2) + assert narr_float2.dtype == dtype_float + + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np_val) + assert var_float.dtype == dtype_float + + var_float2: Variable[Any, np.dtype[np.float32]] + var_float2 = var_float._replace(("x",), np_val2) + assert var_float2.dtype == dtype_float + + @pytest.mark.parametrize( + "dim,expected_ndim,expected_shape,expected_dims", + [ + (None, 3, (1, 2, 5), (None, "x", "y")), + (_default, 3, (1, 2, 5), ("dim_2", "x", "y")), + ("z", 3, (1, 2, 5), ("z", "x", "y")), + ], + ) + def test_expand_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dim: _Dim | Default, + expected_ndim: int, + expected_shape: _ShapeLike, + expected_dims: _DimsLike, + ) -> None: + result = target.expand_dims(dim=dim) + assert result.ndim == expected_ndim + assert result.shape == expected_shape + assert result.dims == expected_dims + + @pytest.mark.parametrize( + "dims, expected_sizes", + [ + ((), {"y": 5, "x": 2}), + (["y", "x"], {"y": 5, "x": 2}), + (["y", ...], {"y": 5, "x": 2}), + ], + ) + def test_permute_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dims: _DimsLike, + expected_sizes: dict[_Dim, _IntOrUnknown], + ) -> None: + actual = target.permute_dims(*dims) + assert actual.sizes == expected_sizes + + def test_permute_dims_errors( + self, + target: NamedArray[Any, np.dtype[np.float32]], + ) -> None: + with pytest.raises(ValueError, match=r"'y'.*permuted list"): + dims = ["y"] + target.permute_dims(*dims) + + @pytest.mark.parametrize( + "broadcast_dims,expected_ndim", + [ + ({"x": 2, "y": 5}, 2), + ({"x": 2, "y": 5, "z": 2}, 3), + ({"w": 1, "x": 2, "y": 5}, 3), + ], + ) + def test_broadcast_to( + self, + target: NamedArray[Any, np.dtype[np.float32]], + broadcast_dims: Mapping[_Dim, int], + expected_ndim: int, + ) -> None: + expand_dims = set(broadcast_dims.keys()) - set(target.dims) + # loop over expand_dims and call .expand_dims(dim=dim) in a loop + for dim in expand_dims: + target = target.expand_dims(dim=dim) + result = target.broadcast_to(broadcast_dims) + assert result.ndim == expected_ndim + assert result.sizes == broadcast_dims + + def test_broadcast_to_errors( + self, target: NamedArray[Any, np.dtype[np.float32]] + ) -> None: + with pytest.raises( + ValueError, + match=r"operands could not be broadcast together with remapped shapes", + ): + target.broadcast_to({"x": 2, "y": 2}) + + with pytest.raises(ValueError, match=r"Cannot add new dimensions"): + target.broadcast_to({"x": 2, "y": 2, "z": 2}) + + def test_warn_on_repeated_dimension_names(self) -> None: + with pytest.warns(UserWarning, match="Duplicate dimension names"): + NamedArray(("x", "x"), np.arange(4).reshape(2, 2)) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 3cecf1b52ec..8ad1cbe11be 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -165,7 +165,6 @@ def test_concat_attr_retention(self) -> None: result = concat([ds1, ds2], dim="dim1") assert result.attrs == original_attrs - @pytest.mark.xfail def test_merge_attr_retention(self) -> None: da1 = create_test_dataarray_attrs(var="var1") da2 = create_test_dataarray_attrs(var="var2") diff --git a/xarray/tests/test_parallelcompat.py b/xarray/tests/test_parallelcompat.py new file mode 100644 index 00000000000..dbe40be710c --- /dev/null +++ b/xarray/tests/test_parallelcompat.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from importlib.metadata import EntryPoint +from typing import Any + +import numpy as np +import pytest + +from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks +from xarray.namedarray._typing import _Chunks +from xarray.namedarray.daskmanager import DaskManager +from xarray.namedarray.parallelcompat import ( + ChunkManagerEntrypoint, + get_chunked_array_type, + guess_chunkmanager, + list_chunkmanagers, + load_chunkmanagers, +) +from xarray.tests import has_dask, requires_dask + + +class DummyChunkedArray(np.ndarray): + """ + Mock-up of a chunked array class. + + Adds a (non-functional) .chunks attribute by following this example in the numpy docs + https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray + """ + + chunks: T_NormalizedChunks + + def __new__( + cls, + shape, + dtype=float, + buffer=None, + offset=0, + strides=None, + order=None, + chunks=None, + ): + obj = super().__new__(cls, shape, dtype, buffer, offset, strides, order) + obj.chunks = chunks + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.chunks = getattr(obj, "chunks", None) + + def rechunk(self, chunks, **kwargs): + copied = self.copy() + copied.chunks = chunks + return copied + + +class DummyChunkManager(ChunkManagerEntrypoint): + """Mock-up of ChunkManager class for DummyChunkedArray""" + + def __init__(self): + self.array_cls = DummyChunkedArray + + def is_chunked_array(self, data: Any) -> bool: + return isinstance(data, DummyChunkedArray) + + def chunks(self, data: DummyChunkedArray) -> T_NormalizedChunks: + return data.chunks + + def normalize_chunks( + self, + chunks: T_Chunks | T_NormalizedChunks, + shape: tuple[int, ...] | None = None, + limit: int | None = None, + dtype: np.dtype | None = None, + previous_chunks: T_NormalizedChunks | None = None, + ) -> T_NormalizedChunks: + from dask.array.core import normalize_chunks + + return normalize_chunks(chunks, shape, limit, dtype, previous_chunks) + + def from_array( + self, data: T_DuckArray | np.typing.ArrayLike, chunks: _Chunks, **kwargs + ) -> DummyChunkedArray: + from dask import array as da + + return da.from_array(data, chunks, **kwargs) + + def rechunk(self, data: DummyChunkedArray, chunks, **kwargs) -> DummyChunkedArray: + return data.rechunk(chunks, **kwargs) + + def compute(self, *data: DummyChunkedArray, **kwargs) -> tuple[np.ndarray, ...]: + from dask.array import compute + + return compute(*data, **kwargs) + + def apply_gufunc( + self, + func, + signature, + *args, + axes=None, + axis=None, + keepdims=False, + output_dtypes=None, + output_sizes=None, + vectorize=None, + allow_rechunk=False, + meta=None, + **kwargs, + ): + from dask.array.gufunc import apply_gufunc + + return apply_gufunc( + func, + signature, + *args, + axes=axes, + axis=axis, + keepdims=keepdims, + output_dtypes=output_dtypes, + output_sizes=output_sizes, + vectorize=vectorize, + allow_rechunk=allow_rechunk, + meta=meta, + **kwargs, + ) + + +@pytest.fixture +def register_dummy_chunkmanager(monkeypatch): + """ + Mocks the registering of an additional ChunkManagerEntrypoint. + + This preserves the presence of the existing DaskManager, so a test that relies on this and DaskManager both being + returned from list_chunkmanagers() at once would still work. + + The monkeypatching changes the behavior of list_chunkmanagers when called inside xarray.namedarray.parallelcompat, + but not when called from this tests file. + """ + # Should include DaskManager iff dask is available to be imported + preregistered_chunkmanagers = list_chunkmanagers() + + monkeypatch.setattr( + "xarray.namedarray.parallelcompat.list_chunkmanagers", + lambda: {"dummy": DummyChunkManager()} | preregistered_chunkmanagers, + ) + yield + + +class TestGetChunkManager: + def test_get_chunkmanger(self, register_dummy_chunkmanager) -> None: + chunkmanager = guess_chunkmanager("dummy") + assert isinstance(chunkmanager, DummyChunkManager) + + def test_fail_on_nonexistent_chunkmanager(self) -> None: + with pytest.raises(ValueError, match="unrecognized chunk manager foo"): + guess_chunkmanager("foo") + + @requires_dask + def test_get_dask_if_installed(self) -> None: + chunkmanager = guess_chunkmanager(None) + assert isinstance(chunkmanager, DaskManager) + + @pytest.mark.skipif(has_dask, reason="requires dask not to be installed") + def test_dont_get_dask_if_not_installed(self) -> None: + with pytest.raises(ValueError, match="unrecognized chunk manager dask"): + guess_chunkmanager("dask") + + @requires_dask + def test_choose_dask_over_other_chunkmanagers( + self, register_dummy_chunkmanager + ) -> None: + chunk_manager = guess_chunkmanager(None) + assert isinstance(chunk_manager, DaskManager) + + +class TestGetChunkedArrayType: + def test_detect_chunked_arrays(self, register_dummy_chunkmanager) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + chunk_manager = get_chunked_array_type(dummy_arr) + assert isinstance(chunk_manager, DummyChunkManager) + + def test_ignore_inmemory_arrays(self, register_dummy_chunkmanager) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + chunk_manager = get_chunked_array_type(*[dummy_arr, 1.0, np.array([5, 6])]) + assert isinstance(chunk_manager, DummyChunkManager) + + with pytest.raises(TypeError, match="Expected a chunked array"): + get_chunked_array_type(5.0) + + def test_raise_if_no_arrays_chunked(self, register_dummy_chunkmanager) -> None: + with pytest.raises(TypeError, match="Expected a chunked array "): + get_chunked_array_type(*[1.0, np.array([5, 6])]) + + def test_raise_if_no_matching_chunkmanagers(self) -> None: + dummy_arr = DummyChunkedArray([1, 2, 3]) + + with pytest.raises( + TypeError, match="Could not find a Chunk Manager which recognises" + ): + get_chunked_array_type(dummy_arr) + + @requires_dask + def test_detect_dask_if_installed(self) -> None: + import dask.array as da + + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + chunk_manager = get_chunked_array_type(dask_arr) + assert isinstance(chunk_manager, DaskManager) + + @requires_dask + def test_raise_on_mixed_array_types(self, register_dummy_chunkmanager) -> None: + import dask.array as da + + dummy_arr = DummyChunkedArray([1, 2, 3]) + dask_arr = da.from_array([1, 2, 3], chunks=(1,)) + + with pytest.raises(TypeError, match="received multiple types"): + get_chunked_array_type(*[dask_arr, dummy_arr]) + + +def test_bogus_entrypoint() -> None: + # Create a bogus entry-point as if the user broke their setup.cfg + # or is actively developing their new chunk manager + entry_point = EntryPoint( + "bogus", "xarray.bogus.doesnotwork", "xarray.chunkmanagers" + ) + with pytest.warns(UserWarning, match="Failed to load chunk manager"): + assert len(load_chunkmanagers([entry_point])) == 0 diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 02f7f4b9be2..6f983a121fe 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -15,7 +15,7 @@ import xarray as xr import xarray.plot as xplt from xarray import DataArray, Dataset -from xarray.core.utils import module_available +from xarray.namedarray.utils import module_available from xarray.plot.dataarray_plot import _infer_interval_breaks from xarray.plot.dataset_plot import _infer_meta_data from xarray.plot.utils import ( @@ -43,6 +43,7 @@ # import mpl and change the backend before other mpl imports try: import matplotlib as mpl + import matplotlib.dates import matplotlib.pyplot as plt import mpl_toolkits except ImportError: @@ -167,7 +168,14 @@ def imshow_called(self, plotmethod): def contourf_called(self, plotmethod): plotmethod() - paths = plt.gca().findobj(mpl.collections.PathCollection) + + # Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8 + def matchfunc(x): + return isinstance( + x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet) + ) + + paths = plt.gca().findobj(matchfunc) return len(paths) > 0 @@ -421,6 +429,7 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None: ]: p = a.plot.pcolormesh(x=x, y=y) v = p.get_paths()[0].vertices + assert isinstance(v, np.ndarray) # Check all vertices are different, except last vertex which should be the # same as the first @@ -440,7 +449,7 @@ def test_str_coordinates_pcolormesh(self) -> None: def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) - cmap = mpl.cm.viridis + cmap_expected = mpl.colormaps["viridis"] # use copy to ensure cmap is not changed by contourf() # Set vmin and vmax so that _build_discrete_colormap is called with @@ -450,55 +459,59 @@ def test_contourf_cmap_set(self) -> None: # extend='neither' (but if extend='neither' the under and over values # would not be used because the data would all be within the plotted # range) - pl = a.plot.contourf(cmap=copy(cmap), vmin=0.1, vmax=0.9) + pl = a.plot.contourf(cmap=copy(cmap_expected), vmin=0.1, vmax=0.9) # check the set_bad color + cmap = pl.cmap + assert cmap is not None assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) # make a copy here because we want a local cmap that we will modify. - cmap = copy(mpl.cm.viridis) + cmap_expected = copy(mpl.colormaps["viridis"]) - cmap.set_bad("w") + cmap_expected.set_bad("w") # check we actually changed the set_bad color assert np.all( - cmap(np.ma.masked_invalid([np.nan]))[0] - != mpl.cm.viridis(np.ma.masked_invalid([np.nan]))[0] + cmap_expected(np.ma.masked_invalid([np.nan]))[0] + != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] ) - cmap.set_under("r") + cmap_expected.set_under("r") # check we actually changed the set_under color - assert cmap(-np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) - cmap.set_over("g") + cmap_expected.set_over("g") # check we actually changed the set_over color - assert cmap(np.inf) != mpl.cm.viridis(-np.inf) + assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf) # copy to ensure cmap is not changed by contourf() - pl = a.plot.contourf(cmap=copy(cmap)) + pl = a.plot.contourf(cmap=copy(cmap_expected)) + cmap = pl.cmap + assert cmap is not None # check the set_bad color has been kept assert_array_equal( - pl.cmap(np.ma.masked_invalid([np.nan]))[0], cmap(np.ma.masked_invalid([np.nan]))[0], + cmap_expected(np.ma.masked_invalid([np.nan]))[0], ) # check the set_under color has been kept - assert pl.cmap(-np.inf) == cmap(-np.inf) + assert cmap(-np.inf) == cmap_expected(-np.inf) # check the set_over color has been kept - assert pl.cmap(np.inf) == cmap(np.inf) + assert cmap(np.inf) == cmap_expected(np.inf) def test3d(self) -> None: self.darray.plot() @@ -716,6 +729,13 @@ def test_labels_with_units_with_interval(self, dim) -> None: expected = "dim_0_bins_center [m]" assert actual == expected + def test_multiplot_over_length_one_dim(self) -> None: + a = easy_array((3, 1, 1, 1)) + d = DataArray(a, dims=("x", "col", "row", "hue")) + d.plot(col="col") + d.plot(row="row") + d.plot(hue="hue") + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -767,12 +787,17 @@ def test_plot_nans(self) -> None: self.darray[1] = np.nan self.darray.plot.line() - def test_x_ticks_are_rotated_for_time(self) -> None: + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.arange(len(time)), [("t", time)]) a.plot.line() - rotation = plt.gca().get_xticklabels()[0].get_rotation() - assert rotation != 0 + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) def test_xyincrease_false_changes_axes(self) -> None: self.darray.plot.line(xincrease=False, yincrease=False) @@ -831,19 +856,25 @@ def test_coord_with_interval_step(self) -> None: """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x(self) -> None: """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_y(self) -> None: """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") - assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + line = plt.gca().lines[0] + assert isinstance(line, mpl.lines.Line2D) + assert len(np.asarray(line.get_xdata())) == ((len(bins) - 1) * 2) def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: """Test that step plot with intervals both on x and y axes raises an error.""" @@ -883,8 +914,11 @@ def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) def test_primitive_returned(self) -> None: - h = self.darray.plot.hist() - assert isinstance(h[-1][0], mpl.patches.Rectangle) + n, bins, patches = self.darray.plot.hist() + assert isinstance(n, np.ndarray) + assert isinstance(bins, np.ndarray) + assert isinstance(patches, mpl.container.BarContainer) + assert isinstance(patches[0], mpl.patches.Rectangle) @pytest.mark.slow def test_plot_nans(self) -> None: @@ -928,9 +962,9 @@ def test_cmap_sequential_option(self) -> None: assert cmap_params["cmap"] == "magma" def test_cmap_sequential_explicit_option(self) -> None: - with xr.set_options(cmap_sequential=mpl.cm.magma): + with xr.set_options(cmap_sequential=mpl.colormaps["magma"]): cmap_params = _determine_cmap_params(self.data) - assert cmap_params["cmap"] == mpl.cm.magma + assert cmap_params["cmap"] == mpl.colormaps["magma"] def test_cmap_divergent_option(self) -> None: with xr.set_options(cmap_divergent="magma"): @@ -1170,7 +1204,7 @@ def test_discrete_colormap_list_of_levels(self) -> None: def test_discrete_colormap_int_levels(self) -> None: for extend, levels, vmin, vmax, cmap in [ ("neither", 7, None, None, None), - ("neither", 7, None, 20, mpl.cm.RdBu), + ("neither", 7, None, 20, mpl.colormaps["RdBu"]), ("both", 7, 4, 8, None), ("min", 10, 4, 15, None), ]: @@ -1327,12 +1361,17 @@ def test_xyincrease_true_changes_axes(self) -> None: diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9 assert all(abs(x) < 1 for x in diffs) - def test_x_ticks_are_rotated_for_time(self) -> None: + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) - a.plot(x="t") - rotation = plt.gca().get_xticklabels()[0].get_rotation() - assert rotation != 0 + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) def test_plot_nans(self) -> None: x1 = self.darray[:5] @@ -1720,8 +1759,8 @@ class TestContour(Common2dMixin, PlotTestCase): # matplotlib cmap.colors gives an rgbA ndarray # when seaborn is used, instead we get an rgb tuple @staticmethod - def _color_as_tuple(c): - return tuple(c[:3]) + def _color_as_tuple(c: Any) -> tuple[Any, Any, Any]: + return c[0], c[1], c[2] def test_colors(self) -> None: # with single color, we don't want rgb array @@ -1743,10 +1782,16 @@ def test_colors_np_levels(self) -> None: # https://github.com/pydata/xarray/issues/3284 levels = np.array([-0.5, 0.0, 0.5, 1.0]) artist = self.darray.plot.contour(levels=levels, colors=["k", "r", "w", "b"]) - assert self._color_as_tuple(artist.cmap.colors[1]) == (1.0, 0.0, 0.0) - assert self._color_as_tuple(artist.cmap.colors[2]) == (1.0, 1.0, 1.0) + cmap = artist.cmap + assert isinstance(cmap, mpl.colors.ListedColormap) + colors = cmap.colors + assert isinstance(colors, list) + + assert self._color_as_tuple(colors[1]) == (1.0, 0.0, 0.0) + assert self._color_as_tuple(colors[2]) == (1.0, 1.0, 1.0) # the last color is now under "over" - assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) + assert hasattr(cmap, "_rgba_over") + assert self._color_as_tuple(cmap._rgba_over) == (0.0, 0.0, 1.0) def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): @@ -1798,7 +1843,9 @@ def test_dont_infer_interval_breaks_for_cartopy(self) -> None: artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size - assert artist.get_array().size <= self.darray.size + arr = artist.get_array() + assert arr is not None + assert arr.size <= self.darray.size class TestPcolormeshLogscale(PlotTestCase): @@ -1851,6 +1898,25 @@ def test_interval_breaks_logspace(self) -> None: class TestImshow(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.imshow) + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + @pytest.mark.slow def test_imshow_called(self) -> None: # Having both statements ensures the test works properly @@ -1949,6 +2015,7 @@ def test_normalize_rgb_imshow( ) -> None: da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert arr is not None assert 0 <= arr.min() <= arr.max() <= 1 def test_normalize_rgb_one_arg_error(self) -> None: @@ -1965,7 +2032,10 @@ def test_imshow_rgb_values_in_valid_range(self) -> None: da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() - assert out.dtype == np.uint8 + assert out is not None + dtype = out.dtype + assert dtype is not None + assert dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") @@ -1991,6 +2061,25 @@ class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) subplot_kws = {"projection": "3d"} + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) @@ -2000,6 +2089,7 @@ def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() @@ -2042,7 +2132,7 @@ def test_seaborn_palette_as_cmap(self) -> None: def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 assert_array_equal(g.axs.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axs): @@ -2051,7 +2141,7 @@ def test_convenient_facetgrid(self) -> None: assert "x" == ax.get_xlabel() # Inferring labels - g = self.plotfunc(d, col="z", col_wrap=2) + g = self.plotfunc(d, col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015 assert_array_equal(g.axs.shape, [2, 2]) for (y, x), ax in np.ndenumerate(g.axs): assert ax.has_data() @@ -2122,6 +2212,7 @@ def test_colorbar(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y") for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2132,7 +2223,9 @@ def test_colorbar_scatter(self) -> None: fg: xplt.FacetGrid = ds.plot.scatter(x="a", y="a", row="x", hue="a") cbar = fg.cbar assert cbar is not None + assert hasattr(cbar, "vmin") assert cbar.vmin == 0 + assert hasattr(cbar, "vmax") assert cbar.vmax == 3 @pytest.mark.slow @@ -2199,6 +2292,7 @@ def test_can_set_vmin_vmax(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) clim = np.array(image.get_clim()) assert np.allclose(expected, clim) @@ -2215,6 +2309,7 @@ def test_can_set_norm(self) -> None: norm = mpl.colors.SymLogNorm(0.1) self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): + assert isinstance(image, mpl.image.AxesImage) assert image.norm is norm @pytest.mark.slow @@ -2708,23 +2803,32 @@ def test_bad_args( x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar ) - @pytest.mark.xfail(reason="datetime,timedelta hue variable not supported.") - @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) - def test_datetime_hue(self, hue_style: Literal["discrete", "continuous"]) -> None: + def test_datetime_hue(self) -> None: ds2 = self.ds.copy() + + # TODO: Currently plots as categorical, should it behave as numerical? ds2["hue"] = pd.date_range("2000-1-1", periods=4) - ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) + ds2.plot.scatter(x="A", y="B", hue="hue") ds2["hue"] = pd.timedelta_range("-1D", periods=4, freq="D") - ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) + ds2.plot.scatter(x="A", y="B", hue="hue") - @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) - def test_facetgrid_hue_style( - self, hue_style: Literal["discrete", "continuous"] - ) -> None: - g = self.ds.plot.scatter( - x="A", y="B", row="row", col="col", hue="hue", hue_style=hue_style - ) + def test_facetgrid_hue_style(self) -> None: + ds2 = self.ds.copy() + + # Numbers plots as continuous: + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + + # Datetimes plots as categorical: + # TODO: Currently plots as categorical, should it behave as numerical? + ds2["hue"] = pd.date_range("2000-1-1", periods=4) + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + + # Strings plots as categorical: + ds2["hue"] = ["a", "a", "b", "b"] + g = ds2.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") assert isinstance(g._mappables[-1], mpl.collections.PathCollection) @pytest.mark.parametrize( @@ -2743,15 +2847,20 @@ def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") + axes = pc.axes + assert axes is not None # should make a discrete legend - assert pc.axes.legend_ is not None + assert hasattr(axes, "legend_") + assert axes.legend_ is not None def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] pc = ds2.plot.scatter(x="A", y="B", markersize="hue") - actual = [t.get_text() for t in pc.axes.get_legend().texts] + axes = pc.axes + assert axes is not None + actual = [t.get_text() for t in axes.get_legend().texts] expected = ["hue", "a", "b"] assert actual == expected @@ -2772,7 +2881,9 @@ def test_legend_labels_facetgrid(self) -> None: def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") - assert len(sc.figure.axes) == 2 + fig = sc.figure + assert fig is not None + assert len(fig.axes) == 2 class TestDatetimePlot(PlotTestCase): @@ -2811,6 +2922,7 @@ def test_datetime_plot1d(self) -> None: # mpl.dates.AutoDateLocator passes and no other subclasses: assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime_plot2d(self) -> None: # Test that matplotlib-native datetime works: da = DataArray( @@ -2824,6 +2936,7 @@ def test_datetime_plot2d(self) -> None: p = da.plot.pcolormesh() ax = p.axes + assert ax is not None # Make sure only mpl converters are used, use type() so only # mpl.dates.AutoDateLocator passes and no other subclasses: @@ -2842,7 +2955,7 @@ def setUp(self) -> None: """ # case for 1d array data = np.random.rand(4, 12) - time = xr.cftime_range(start="2017", periods=12, freq="1M", calendar="noleap") + time = xr.cftime_range(start="2017", periods=12, freq="1ME", calendar="noleap") darray = DataArray(data, dims=["x", "time"]) darray.coords["time"] = time @@ -3259,3 +3372,16 @@ def test_plot1d_default_rcparams() -> None: np.testing.assert_allclose( ax.collections[0].get_edgecolor(), mpl.colors.to_rgba_array("k") ) + + +@requires_matplotlib +def test_plot1d_filtered_nulls() -> None: + ds = xr.tutorial.scatter_example_dataset(seed=42) + y = ds.y.where(ds.y > 0.2) + expected = y.notnull().sum().item() + + with figure_context(): + pc = y.plot.scatter() + actual = pc.get_offsets().shape[0] + + assert expected == actual diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 421be1df2dc..b518c973d3a 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -2,11 +2,26 @@ import sys from importlib.metadata import EntryPoint + +if sys.version_info >= (3, 10): + from importlib.metadata import EntryPoints +else: + EntryPoints = list[EntryPoint] from unittest import mock import pytest from xarray.backends import common, plugins +from xarray.tests import ( + has_h5netcdf, + has_netCDF4, + has_pydap, + has_pynio, + has_scipy, + has_zarr, +) + +# Do not import list_engines here, this will break the lazy tests importlib_metadata_mock = "importlib.metadata" @@ -57,7 +72,7 @@ def test_broken_plugin() -> None: "xarray.backends", ) with pytest.warns(RuntimeWarning) as record: - _ = plugins.build_engines([broken_backend]) + _ = plugins.build_engines(EntryPoints([broken_backend])) assert len(record) == 1 message = str(record[0].message) assert "Engine 'broken_backend'" in message @@ -99,23 +114,29 @@ def test_set_missing_parameters() -> None: assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder") assert backend_2.open_dataset_parameters == ("filename_or_obj",) - backend = DummyBackendEntrypointKwargs() - backend.open_dataset_parameters = ("filename_or_obj", "decoder") # type: ignore[misc] - plugins.set_missing_parameters({"engine": backend}) - assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") + backend_kwargs = DummyBackendEntrypointKwargs + backend_kwargs.open_dataset_parameters = ("filename_or_obj", "decoder") + plugins.set_missing_parameters({"engine": backend_kwargs}) + assert backend_kwargs.open_dataset_parameters == ("filename_or_obj", "decoder") - backend_args = DummyBackendEntrypointArgs() - backend_args.open_dataset_parameters = ("filename_or_obj", "decoder") # type: ignore[misc] + backend_args = DummyBackendEntrypointArgs + backend_args.open_dataset_parameters = ("filename_or_obj", "decoder") plugins.set_missing_parameters({"engine": backend_args}) assert backend_args.open_dataset_parameters == ("filename_or_obj", "decoder") + # reset + backend_1.open_dataset_parameters = None + backend_1.open_dataset_parameters = None + backend_kwargs.open_dataset_parameters = None + backend_args.open_dataset_parameters = None + def test_set_missing_parameters_raise_error() -> None: - backend = DummyBackendEntrypointKwargs() + backend = DummyBackendEntrypointKwargs with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend_args = DummyBackendEntrypointArgs() + backend_args = DummyBackendEntrypointArgs with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend_args}) @@ -128,7 +149,7 @@ def test_build_engines() -> None: dummy_pkg_entrypoint = EntryPoint( "dummy", "xarray.tests.test_plugins:backend_1", "xarray_backends" ) - backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + backend_entrypoints = plugins.build_engines(EntryPoints([dummy_pkg_entrypoint])) assert isinstance(backend_entrypoints["dummy"], DummyBackendEntrypoint1) assert backend_entrypoints["dummy"].open_dataset_parameters == ( @@ -142,10 +163,16 @@ def test_build_engines() -> None: mock.MagicMock(return_value=DummyBackendEntrypoint1), ) def test_build_engines_sorted() -> None: - dummy_pkg_entrypoints = [ - EntryPoint("dummy2", "xarray.tests.test_plugins:backend_1", "xarray.backends"), - EntryPoint("dummy1", "xarray.tests.test_plugins:backend_1", "xarray.backends"), - ] + dummy_pkg_entrypoints = EntryPoints( + [ + EntryPoint( + "dummy2", "xarray.tests.test_plugins:backend_1", "xarray.backends" + ), + EntryPoint( + "dummy1", "xarray.tests.test_plugins:backend_1", "xarray.backends" + ), + ] + ) backend_entrypoints = list(plugins.build_engines(dummy_pkg_entrypoints)) indices = [] @@ -191,28 +218,29 @@ def test_lazy_import() -> None: When importing xarray these should not be imported as well. Only when running code for the first time that requires them. """ - blacklisted = [ + deny_list = [ + "cubed", + "cupy", + # "dask", # TODO: backends.locks is not lazy yet :( + "dask.array", + "dask.distributed", + "flox", "h5netcdf", + "matplotlib", + "nc_time_axis", "netCDF4", - "PseudoNetCDF", - "pydap", "Nio", + "numbagg", + "pint", + "pydap", "scipy", - "zarr", - "matplotlib", - "nc_time_axis", - "flox", - # "dask", # TODO: backends.locks is not lazy yet :( - "dask.array", - "dask.distributed", "sparse", - "cupy", - "pint", + "zarr", ] # ensure that none of the above modules has been imported before modules_backup = {} for pkg in list(sys.modules.keys()): - for mod in blacklisted + ["xarray"]: + for mod in deny_list + ["xarray"]: if pkg.startswith(mod): modules_backup[pkg] = sys.modules[pkg] del sys.modules[pkg] @@ -228,7 +256,7 @@ def test_lazy_import() -> None: # lazy loaded are loaded when importing xarray is_imported = set() for pkg in sys.modules: - for mod in blacklisted: + for mod in deny_list: if pkg.startswith(mod): is_imported.add(mod) break @@ -239,3 +267,56 @@ def test_lazy_import() -> None: finally: # restore original sys.modules.update(modules_backup) + + +def test_list_engines() -> None: + from xarray.backends import list_engines + + engines = list_engines() + assert list_engines.cache_info().currsize == 1 + + assert ("scipy" in engines) == has_scipy + assert ("h5netcdf" in engines) == has_h5netcdf + assert ("netcdf4" in engines) == has_netCDF4 + assert ("pydap" in engines) == has_pydap + assert ("zarr" in engines) == has_zarr + assert ("pynio" in engines) == has_pynio + assert "store" in engines + + +def test_refresh_engines() -> None: + from xarray.backends import list_engines, refresh_engines + + EntryPointMock1 = mock.MagicMock() + EntryPointMock1.name = "test1" + EntryPointMock1.load.return_value = DummyBackendEntrypoint1 + + if sys.version_info >= (3, 10): + return_value = EntryPoints([EntryPointMock1]) + else: + return_value = {"xarray.backends": [EntryPointMock1]} + + with mock.patch("xarray.backends.plugins.entry_points", return_value=return_value): + list_engines.cache_clear() + engines = list_engines() + assert "test1" in engines + assert isinstance(engines["test1"], DummyBackendEntrypoint1) + + EntryPointMock2 = mock.MagicMock() + EntryPointMock2.name = "test2" + EntryPointMock2.load.return_value = DummyBackendEntrypoint2 + + if sys.version_info >= (3, 10): + return_value2 = EntryPoints([EntryPointMock2]) + else: + return_value2 = {"xarray.backends": [EntryPointMock2]} + + with mock.patch("xarray.backends.plugins.entry_points", return_value=return_value2): + refresh_engines() + engines = list_engines() + assert "test1" not in engines + assert "test2" in engines + assert isinstance(engines["test2"], DummyBackendEntrypoint2) + + # reset to original + refresh_engines() diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index d78afa36011..403a72f9028 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -3,13 +3,12 @@ import numpy as np import pandas as pd import pytest -from packaging.version import Version +from numpy.testing import assert_array_equal import xarray as xr from xarray import DataArray, Dataset, set_options from xarray.tests import ( assert_allclose, - assert_array_equal, assert_equal, assert_identical, get_expected_rolling_indices, @@ -25,6 +24,31 @@ ] +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("min_periods", [1, 10]) +def test_cumulative(d, func, min_periods) -> None: + # One dim + result = getattr(d.cumulative("z", min_periods=min_periods), func)() + expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)() + assert_identical(result, expected) + + # Multiple dim + result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)() + expected = getattr( + d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods), + func, + )() + assert_identical(result, expected) + + +def test_cumulative_vs_cum(d) -> None: + result = d.cumulative("z").sum() + expected = d.cumsum("z") + # cumsum drops the coord of the dimension; cumulative doesn't + expected = expected.assign_coords(z=result["z"]) + assert_identical(result, expected) + + class TestDataArrayRolling: @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", [True, False]) @@ -120,23 +144,33 @@ def test_rolling_properties(self, da) -> None: with pytest.raises(ValueError, match="min_periods must be greater than zero"): da.rolling(time=2, min_periods=0) + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in DataArray dimensions", + ): + da.rolling(foo=2) + @requires_bottleneck - @pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax") + ) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) @pytest.mark.parametrize("pad", (True, False)) def test_rolling_wrapped_bottleneck( - self, da, name, min_periods, center, pad + self, da, name, min_periods, center, pad, compute_backend ) -> None: import bottleneck as bn window = 7 + # Test all bottleneck functions rolling_obj = da.rolling( time=window, min_periods=min_periods, center=center, pad=pad ) + window = 7 func_name = f"move_{name}" actual = getattr(rolling_obj, name)() expected_values = getattr(bn, func_name)( @@ -145,12 +179,26 @@ def test_rolling_wrapped_bottleneck( expected_indices = get_expected_rolling_indices( da.sizes["time"], window, center, pad ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func_name in ["move_argmin", "move_argmax"]: + expected_values = window - 1 - expected_values expected = da.copy(data=expected_values).isel(time=expected_indices) assert_equal(actual, expected) + # Using assert_allclose because we get tiny (1e-17) differences in numbagg. + np.testing.assert_allclose(actual.values, expected) + with pytest.warns(DeprecationWarning, match="Reductions are applied"): getattr(rolling_obj, name)(dim="time") + # Test center + rolling_obj = da.rolling(time=7, center=center) + actual = getattr(rolling_obj, name)()["time"] + # Using assert_allclose because we get tiny (1e-17) differences in numbagg. + assert_allclose(actual, da["time"]) + @requires_dask @pytest.mark.parametrize("name", ("mean", "count")) @pytest.mark.parametrize("center", (True, False, None)) @@ -208,8 +256,10 @@ def test_rolling_wrapped_dask_nochunk(self, center, pad) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("pad", (False,)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) - @pytest.mark.parametrize("window", (2, 3, 4)) - def test_rolling_pandas_compat(self, center, pad, window, min_periods) -> None: + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_pandas_compat( + self, center, pad, window, min_periods, compute_backend + ) -> None: s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -236,8 +286,8 @@ def test_rolling_pandas_compat(self, center, pad, window, min_periods) -> None: np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"]) @pytest.mark.parametrize("center", (True, False)) - @pytest.mark.parametrize("window", (2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct(self, center: bool, window: int) -> None: length = 10 s = pd.Series(np.arange(length)) da = DataArray.from_series(s) @@ -287,7 +337,9 @@ def test_rolling_construct(self, center, window) -> None: @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) - def test_rolling_reduce(self, da, center, pad, min_periods, window, name) -> None: + def test_rolling_reduce( + self, da, center, min_periods, pad, window, name, compute_backend + ) -> None: if min_periods is not None and window < min_periods: min_periods = window @@ -303,7 +355,7 @@ def test_rolling_reduce(self, da, center, pad, min_periods, window, name) -> Non actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) - assert actual.dims == expected.dims + assert actual.sizes == expected.sizes @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("pad", (True, False)) @@ -311,7 +363,7 @@ def test_rolling_reduce(self, da, center, pad, min_periods, window, name) -> Non @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "max")) def test_rolling_reduce_nonnumeric( - self, center, pad, min_periods, window, name + self, center, pad, min_periods, window, name, compute_backend ) -> None: da = DataArray( [0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time" @@ -328,7 +380,7 @@ def test_rolling_reduce_nonnumeric( actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) - assert actual.dims == expected.dims + assert actual.sizes == expected.sizes @pytest.mark.parametrize( "time, min_periods, pad, expected", @@ -354,7 +406,9 @@ def test_rolling_reduce_nonnumeric( [7, 2, False, DataArray([5, 5, 5, 5, 5], dims="time")], ), ) - def test_rolling_count_correct(self, time, min_periods, pad, expected) -> None: + def test_rolling_count_correct( + self, time, min_periods, pad, expected, compute_backend + ) -> None: da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") result = da.rolling(time=time, min_periods=min_periods, pad=pad).count() assert_equal(result, expected) @@ -371,9 +425,11 @@ def test_rolling_count_correct(self, time, min_periods, pad, expected) -> None: @pytest.mark.parametrize("pad", (True, False, {"time": True, "x": False})) @pytest.mark.parametrize("min_periods", (None, 1)) @pytest.mark.parametrize("name", ("sum", "mean", "max")) - def test_ndrolling_reduce(self, da, center, pad, min_periods, name) -> None: + def test_ndrolling_reduce( + self, da, center, min_periods, pad, name, compute_backend + ) -> None: rolling_obj = da.rolling( - time=3, x=2, center=center, pad=pad, min_periods=min_periods + time=3, x=2, pad=pad, center=center, min_periods=min_periods ) actual = getattr(rolling_obj, name)() @@ -386,7 +442,7 @@ def test_ndrolling_reduce(self, da, center, pad, min_periods, name) -> None: )() assert_allclose(actual, expected) - assert actual.dims == expected.dims + assert actual.sizes == expected.sizes if name in ["mean"]: # test our reimplementation of nanmean using np.nanmean @@ -499,16 +555,8 @@ class TestDataArrayRollingExp: [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]], ) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) - @pytest.mark.parametrize("func", ["mean", "sum"]) + @pytest.mark.parametrize("func", ["mean", "sum", "var", "std"]) def test_rolling_exp_runs(self, da, dim, window_type, window, func) -> None: - import numbagg - - if ( - Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.1") - and func == "sum" - ): - pytest.skip("rolling_exp.sum requires numbagg 0.2.1") - da = da.where(da > 0.2) rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window}) @@ -540,14 +588,6 @@ def test_rolling_exp_mean_pandas(self, da, dim, window_type, window) -> None: @pytest.mark.parametrize("backend", ["numpy"], indirect=True) @pytest.mark.parametrize("func", ["mean", "sum"]) def test_rolling_exp_keep_attrs(self, da, func) -> None: - import numbagg - - if ( - Version(getattr(numbagg, "__version__", "0.1.0")) < Version("0.2.1") - and func == "sum" - ): - pytest.skip("rolling_exp.sum requires numbagg 0.2.1") - attrs = {"attrs": "da"} da.attrs = attrs @@ -667,6 +707,11 @@ def test_rolling_properties(self, ds) -> None: ds.rolling(time=2, min_periods=0) with pytest.raises(KeyError, match="time2"): ds.rolling(time2=2) + with pytest.raises( + KeyError, + match=r"\('foo',\) not found in Dataset dimensions", + ): + ds.rolling(foo=2) @requires_bottleneck @pytest.mark.parametrize( @@ -678,7 +723,7 @@ def test_rolling_properties(self, ds) -> None: @pytest.mark.parametrize("pad", (False,)) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) def test_rolling_wrapped_bottleneck( - self, ds, name, min_periods, key, center, pad + self, ds, name, pad, center, min_periods, key, compute_backend ) -> None: import bottleneck as bn @@ -696,12 +741,18 @@ def test_rolling_wrapped_bottleneck( if key == "z1": # z1 does not depend on 'Time' axis. Stored as it is. expected = ds[key] elif key == "z2": - expected = getattr(bn, func_name)( - ds[key].values, window=7, axis=0, min_count=min_periods + expected = ( + ds[key] + .copy( + data=getattr(bn, func_name)( + ds[key].values, window=7, axis=0, min_count=min_periods + ) + ) + .isel(time=expected_indices) ) else: raise ValueError - assert_equal(actual[key], expected[key].isel(time=expected_indices)) + assert_equal(actual[key], expected) @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @@ -729,7 +780,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: df = pd.DataFrame( { "x": np.random.randn(20), @@ -746,19 +797,58 @@ def test_rolling_construct(self, center, window) -> None: np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) - # with stride - ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose( - df_rolling["x"][::2].values, ds_rolling_mean["x"].values - ) - np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"]) # with fill_value ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( "window" ) - assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() + assert (ds_rolling_mean.isnull().sum() == 0).to_dataarray(dim="vars").all() assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct_stride(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean() + + # With an index (dimension coordinate) + ds_rolling = ds.rolling(index=window, center=center) + ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + ) + np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) + + # Without index (https://github.com/pydata/xarray/issues/7021) + ds2 = ds.drop_vars("index") + ds2_rolling = ds2.rolling(index=window, center=center) + ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + ) + + # Mixed coordinates, indexes and 2D coordinates + ds3 = xr.Dataset( + {"x": ("t", range(20)), "x2": ("y", range(5))}, + { + "t": range(20), + "y": ("y", range(5)), + "t2": ("t", range(20)), + "y2": ("y", range(5)), + "yt": (["t", "y"], np.ones((20, 5))), + }, + ) + ds3_rolling = ds3.rolling(t=window, center=center) + ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w") + for coord in ds3.coords: + assert coord in ds3_rolling_mean.coords + @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False)) @@ -780,7 +870,7 @@ def test_rolling_reduce(self, ds, center, min_periods, window, name) -> None: actual = rolling_obj.reduce(getattr(np, "nan%s" % name)) expected = getattr(rolling_obj, name)() assert_allclose(actual, expected) - assert ds.dims == actual.dims + assert ds.sizes == actual.sizes # make sure the order of data_var are not changed. assert list(ds.data_vars.keys()) == list(actual.data_vars.keys()) @@ -807,7 +897,7 @@ def test_ndrolling_reduce(self, ds, center, min_periods, name, dask) -> None: name, )() assert_allclose(actual, expected) - assert actual.dims == expected.dims + assert actual.sizes == expected.sizes # Do it in the opposite order expected = getattr( @@ -818,7 +908,7 @@ def test_ndrolling_reduce(self, ds, center, min_periods, name, dask) -> None: )() assert_allclose(actual, expected) - assert actual.dims == expected.dims + assert actual.sizes == expected.sizes @pytest.mark.parametrize("center", (True, False, (True, False))) @pytest.mark.parametrize("fill_value", (np.nan, 0.0)) @@ -846,9 +936,7 @@ def test_ndrolling_construct(self, center, fill_value, dask) -> None: ) assert_allclose(actual, expected) - @pytest.mark.xfail( - reason="See https://github.com/pydata/xarray/pull/4369 or docstring" - ) + @requires_dask @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("ds", (2,), indirect=True) @pytest.mark.parametrize("name", ("mean", "max")) @@ -870,7 +958,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: @requires_numbagg class TestDatasetRollingExp: - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) def test_rolling_exp(self, ds) -> None: result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset) @@ -891,13 +981,18 @@ def test_rolling_exp_keep_attrs(self, ds) -> None: # discard attrs result = ds.rolling_exp(time=10).mean(keep_attrs=False) assert result.attrs == {} - assert result.z1.attrs == {} + # TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do + # that at the moment. We should change in `apply_func` rather than + # special-case it here. + # + # assert result.z1.attrs == {} # test discard attrs using global option with set_options(keep_attrs=False): result = ds.rolling_exp(time=10).mean() assert result.attrs == {} - assert result.z1.attrs == {} + # See above + # assert result.z1.attrs == {} # keyword takes precedence over global option with set_options(keep_attrs=False): @@ -908,7 +1003,8 @@ def test_rolling_exp_keep_attrs(self, ds) -> None: with set_options(keep_attrs=True): result = ds.rolling_exp(time=10).mean(keep_attrs=False) assert result.attrs == {} - assert result.z1.attrs == {} + # See above + # assert result.z1.attrs == {} with pytest.warns( UserWarning, diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f64ce9338d7..09c12818754 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -10,7 +10,7 @@ import xarray as xr from xarray import DataArray, Variable -from xarray.core.pycompat import array_type +from xarray.namedarray.pycompat import array_type from xarray.tests import assert_equal, assert_identical, requires_dask filterwarnings = pytest.mark.filterwarnings @@ -147,7 +147,6 @@ def test_variable_property(prop): ], ), True, - marks=xfail(reason="Coercion to dense"), ), param( do("conjugate"), @@ -201,7 +200,6 @@ def test_variable_property(prop): param( do("reduce", func="sum", dim="x"), True, - marks=xfail(reason="Coercion to dense"), ), param( do("rolling_window", dim="x", window=2, window_dim="x_win"), @@ -218,7 +216,7 @@ def test_variable_property(prop): param( do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") ), - param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_dict"), False), (do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True), ], ids=repr, @@ -237,7 +235,14 @@ def test_variable_method(func, sparse_output): assert isinstance(ret_s.data, sparse.SparseArray) assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) else: - assert np.allclose(ret_s, ret_d, equal_nan=True) + if func.meth != "to_dict": + assert np.allclose(ret_s, ret_d) + else: + # pop the arrays from the dict + arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data") + + assert np.allclose(arr_s, arr_d) + assert ret_s == ret_d @pytest.mark.parametrize( @@ -292,7 +297,7 @@ def test_bivariate_ufunc(self): def test_repr(self): expected = dedent( """\ - + Size: 288B """ ) assert expected == repr(self.var) @@ -573,7 +578,7 @@ def setUp(self): def test_to_dataset_roundtrip(self): x = self.sp_xr - assert_equal(x, x.to_dataset("x").to_array("x")) + assert_equal(x, x.to_dataset("x").to_dataarray("x")) def test_align(self): a1 = xr.DataArray( @@ -676,10 +681,10 @@ def test_dataarray_repr(self): ) expected = dedent( """\ - + Size: 64B Coordinates: - y (x) int64 + y (x) int64 48B Dimensions without coordinates: x""" ) assert expected == repr(a) @@ -691,13 +696,13 @@ def test_dataset_repr(self): ) expected = dedent( """\ - + Size: 112B Dimensions: (x: 4) Coordinates: - y (x) int64 + y (x) int64 48B Dimensions without coordinates: x Data variables: - a (x) float64 """ + a (x) float64 64B """ ) assert expected == repr(ds) @@ -708,11 +713,11 @@ def test_sparse_dask_dataset_repr(self): ).chunk() expected = dedent( """\ - + Size: 32B Dimensions: (x: 4) Dimensions without coordinates: x Data variables: - a (x) float64 dask.array""" + a (x) float64 32B dask.array""" ) assert expected == repr(ds) @@ -825,7 +830,7 @@ def test_reindex(self): @pytest.mark.xfail def test_merge(self): x = self.sp_xr - y = xr.merge([x, x.rename("bar")]).to_array() + y = xr.merge([x, x.rename("bar")]).to_dataarray() assert isinstance(y, sparse.SparseArray) @pytest.mark.xfail @@ -873,10 +878,6 @@ def test_dask_token(): import dask s = sparse.COO.from_numpy(np.array([0, 0, 1, 2])) - - # https://github.com/pydata/sparse/issues/300 - s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__) - a = DataArray(s) t1 = dask.base.tokenize(a) t2 = dask.base.tokenize(a) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py new file mode 100644 index 00000000000..44f0d56cde8 --- /dev/null +++ b/xarray/tests/test_strategies.py @@ -0,0 +1,271 @@ +import numpy as np +import numpy.testing as npt +import pytest + +pytest.importorskip("hypothesis") +# isort: split + +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st +from hypothesis import given +from hypothesis.extra.array_api import make_strategies_namespace + +from xarray.core.variable import Variable +from xarray.testing.strategies import ( + attrs, + dimension_names, + dimension_sizes, + supported_dtypes, + unique_subset_of, + variables, +) +from xarray.tests import requires_numpy_array_api + +ALLOWED_ATTRS_VALUES_TYPES = (int, bool, str, np.ndarray) + + +class TestDimensionNamesStrategy: + @given(dimension_names()) + def test_types(self, dims): + assert isinstance(dims, list) + for d in dims: + assert isinstance(d, str) + + @given(dimension_names()) + def test_unique(self, dims): + assert len(set(dims)) == len(dims) + + @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted)) + def test_number_of_dims(self, data, ndims): + min_dims, max_dims = ndims + dim_names = data.draw(dimension_names(min_dims=min_dims, max_dims=max_dims)) + assert isinstance(dim_names, list) + assert min_dims <= len(dim_names) <= max_dims + + +class TestDimensionSizesStrategy: + @given(dimension_sizes()) + def test_types(self, dims): + assert isinstance(dims, dict) + for d, n in dims.items(): + assert isinstance(d, str) + assert len(d) >= 1 + + assert isinstance(n, int) + assert n >= 0 + + @given(st.data(), st.tuples(st.integers(0, 10), st.integers(0, 10)).map(sorted)) + def test_number_of_dims(self, data, ndims): + min_dims, max_dims = ndims + dim_sizes = data.draw(dimension_sizes(min_dims=min_dims, max_dims=max_dims)) + assert isinstance(dim_sizes, dict) + assert min_dims <= len(dim_sizes) <= max_dims + + @given(st.data()) + def test_restrict_names(self, data): + capitalized_names = st.text(st.characters(), min_size=1).map(str.upper) + dim_sizes = data.draw(dimension_sizes(dim_names=capitalized_names)) + for dim in dim_sizes.keys(): + assert dim.upper() == dim + + +def check_dict_values(dictionary: dict, allowed_attrs_values_types) -> bool: + """Helper function to assert that all values in recursive dict match one of a set of types.""" + for key, value in dictionary.items(): + if isinstance(value, allowed_attrs_values_types) or value is None: + continue + elif isinstance(value, dict): + # If the value is a dictionary, recursively check it + if not check_dict_values(value, allowed_attrs_values_types): + return False + else: + # If the value is not an integer or a dictionary, it's not valid + return False + return True + + +class TestAttrsStrategy: + @given(attrs()) + def test_type(self, attrs): + assert isinstance(attrs, dict) + check_dict_values(attrs, ALLOWED_ATTRS_VALUES_TYPES) + + +class TestVariablesStrategy: + @given(variables()) + def test_given_nothing(self, var): + assert isinstance(var, Variable) + + @given(st.data()) + def test_given_incorrect_types(self, data): + with pytest.raises(TypeError, match="dims must be provided as a"): + data.draw(variables(dims=["x", "y"])) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="dtype must be provided as a"): + data.draw(variables(dtype=np.dtype("int32"))) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="attrs must be provided as a"): + data.draw(variables(attrs=dict())) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="Callable"): + data.draw(variables(array_strategy_fn=np.array([0]))) # type: ignore[arg-type] + + @given(st.data(), dimension_names()) + def test_given_fixed_dim_names(self, data, fixed_dim_names): + var = data.draw(variables(dims=st.just(fixed_dim_names))) + + assert list(var.dims) == fixed_dim_names + + @given(st.data(), dimension_sizes()) + def test_given_fixed_dim_sizes(self, data, dim_sizes): + var = data.draw(variables(dims=st.just(dim_sizes))) + + assert var.dims == tuple(dim_sizes.keys()) + assert var.shape == tuple(dim_sizes.values()) + + @given(st.data(), supported_dtypes()) + def test_given_fixed_dtype(self, data, dtype): + var = data.draw(variables(dtype=st.just(dtype))) + + assert var.dtype == dtype + + @given(st.data(), npst.arrays(shape=npst.array_shapes(), dtype=supported_dtypes())) + def test_given_fixed_data_dims_and_dtype(self, data, arr): + def fixed_array_strategy_fn(*, shape=None, dtype=None): + """The fact this ignores shape and dtype is only okay because compatible shape & dtype will be passed separately.""" + return st.just(arr) + + dim_names = data.draw(dimension_names(min_dims=arr.ndim, max_dims=arr.ndim)) + dim_sizes = {name: size for name, size in zip(dim_names, arr.shape)} + + var = data.draw( + variables( + array_strategy_fn=fixed_array_strategy_fn, + dims=st.just(dim_sizes), + dtype=st.just(arr.dtype), + ) + ) + + npt.assert_equal(var.data, arr) + assert var.dtype == arr.dtype + + @given(st.data(), st.integers(0, 3)) + def test_given_array_strat_arbitrary_size_and_arbitrary_data(self, data, ndims): + dim_names = data.draw(dimension_names(min_dims=ndims, max_dims=ndims)) + + def array_strategy_fn(*, shape=None, dtype=None): + return npst.arrays(shape=shape, dtype=dtype) + + var = data.draw( + variables( + array_strategy_fn=array_strategy_fn, + dims=st.just(dim_names), + dtype=supported_dtypes(), + ) + ) + + assert var.ndim == ndims + + @given(st.data()) + def test_catch_unruly_dtype_from_custom_array_strategy_fn(self, data): + def dodgy_array_strategy_fn(*, shape=None, dtype=None): + """Dodgy function which ignores the dtype it was passed""" + return npst.arrays(shape=shape, dtype=npst.floating_dtypes()) + + with pytest.raises( + ValueError, match="returned an array object with a different dtype" + ): + data.draw( + variables( + array_strategy_fn=dodgy_array_strategy_fn, + dtype=st.just(np.dtype("int32")), + ) + ) + + @given(st.data()) + def test_catch_unruly_shape_from_custom_array_strategy_fn(self, data): + def dodgy_array_strategy_fn(*, shape=None, dtype=None): + """Dodgy function which ignores the shape it was passed""" + return npst.arrays(shape=(3, 2), dtype=dtype) + + with pytest.raises( + ValueError, match="returned an array object with a different shape" + ): + data.draw( + variables( + array_strategy_fn=dodgy_array_strategy_fn, + dims=st.just({"a": 2, "b": 1}), + dtype=supported_dtypes(), + ) + ) + + @requires_numpy_array_api + @given(st.data()) + def test_make_strategies_namespace(self, data): + """ + Test not causing a hypothesis.InvalidArgument by generating a dtype that's not in the array API. + + We still want to generate dtypes not in the array API by default, but this checks we don't accidentally override + the user's choice of dtypes with non-API-compliant ones. + """ + from numpy import ( + array_api as np_array_api, # requires numpy>=1.26.0, and we expect a UserWarning to be raised + ) + + np_array_api_st = make_strategies_namespace(np_array_api) + + data.draw( + variables( + array_strategy_fn=np_array_api_st.arrays, + dtype=np_array_api_st.scalar_dtypes(), + ) + ) + + +class TestUniqueSubsetOf: + @given(st.data()) + def test_invalid(self, data): + with pytest.raises(TypeError, match="must be an Iterable or a Mapping"): + data.draw(unique_subset_of(0)) # type: ignore[call-overload] + + with pytest.raises(ValueError, match="length-zero object"): + data.draw(unique_subset_of({})) + + @given(st.data(), dimension_sizes(min_dims=1)) + def test_mapping(self, data, dim_sizes): + subset_of_dim_sizes = data.draw(unique_subset_of(dim_sizes)) + + for dim, length in subset_of_dim_sizes.items(): + assert dim in dim_sizes + assert dim_sizes[dim] == length + + @given(st.data(), dimension_names(min_dims=1)) + def test_iterable(self, data, dim_names): + subset_of_dim_names = data.draw(unique_subset_of(dim_names)) + + for dim in subset_of_dim_names: + assert dim in dim_names + + +class TestReduction: + """ + These tests are for checking that the examples given in the docs page on testing actually work. + """ + + @given(st.data(), variables(dims=dimension_names(min_dims=1))) + def test_mean(self, data, var): + """ + Test that given a Variable of at least one dimension, + the mean of the Variable is always equal to the mean of the underlying array. + """ + + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) + + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py new file mode 100644 index 00000000000..b0e737bd317 --- /dev/null +++ b/xarray/tests/test_treenode.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import cast + +import pytest + +from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode +from xarray.datatree_.datatree.iterators import LevelOrderIter, PreOrderIter + + +class TestFamilyTree: + def test_lonely(self): + root: TreeNode = TreeNode() + assert root.parent is None + assert root.children == {} + + def test_parenting(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + + assert mary.parent == john + assert john.children["Mary"] is mary + + def test_no_time_traveller_loops(self): + john: TreeNode = TreeNode() + + with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): + john._set_parent(john, "John") + + with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): + john.children = {"John": john} + + mary: TreeNode = TreeNode() + rose: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + rose._set_parent(mary, "Rose") + + with pytest.raises(InvalidTreeError, match="is already a descendant"): + john._set_parent(rose, "John") + + with pytest.raises(InvalidTreeError, match="is already a descendant"): + rose.children = {"John": john} + + def test_parent_swap(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + mary._set_parent(john, "Mary") + + steve: TreeNode = TreeNode() + mary._set_parent(steve, "Mary") + + assert mary.parent == steve + assert steve.children["Mary"] is mary + assert "Mary" not in john.children + + def test_multi_child_family(self): + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate}) + assert john.children["Mary"] is mary + assert john.children["Kate"] is kate + assert mary.parent is john + assert kate.parent is john + + def test_disown_child(self): + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary.orphan() + assert mary.parent is None + assert "Mary" not in john.children + + def test_doppelganger_child(self): + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode() + + with pytest.raises(TypeError): + john.children = {"Kate": 666} + + with pytest.raises(InvalidTreeError, match="Cannot add same node"): + john.children = {"Kate": kate, "Evil_Kate": kate} + + john = TreeNode(children={"Kate": kate}) + evil_kate: TreeNode = TreeNode() + evil_kate._set_parent(john, "Kate") + assert john.children["Kate"] is evil_kate + + def test_sibling_relationships(self): + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + ashley: TreeNode = TreeNode() + TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley}) + assert kate.siblings["Mary"] is mary + assert kate.siblings["Ashley"] is ashley + assert "Kate" not in kate.siblings + + def test_ancestors(self): + tony: TreeNode = TreeNode() + michael: TreeNode = TreeNode(children={"Tony": tony}) + vito = TreeNode(children={"Michael": michael}) + assert tony.root is vito + assert tony.parents == (michael, vito) + assert tony.ancestors == (vito, michael, tony) + + +class TestGetNodes: + def test_get_child(self): + steven: TreeNode = TreeNode() + sue = TreeNode(children={"Steven": steven}) + mary = TreeNode(children={"Sue": sue}) + john = TreeNode(children={"Mary": mary}) + + # get child + assert john._get_item("Mary") is mary + assert mary._get_item("Sue") is sue + + # no child exists + with pytest.raises(KeyError): + john._get_item("Kate") + + # get grandchild + assert john._get_item("Mary/Sue") is sue + + # get great-grandchild + assert john._get_item("Mary/Sue/Steven") is steven + + # get from middle of tree + assert mary._get_item("Sue/Steven") is steven + + def test_get_upwards(self): + sue: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + mary = TreeNode(children={"Sue": sue, "Kate": kate}) + john = TreeNode(children={"Mary": mary}) + + assert sue._get_item("../") is mary + assert sue._get_item("../../") is john + + # relative path + assert sue._get_item("../Kate") is kate + + def test_get_from_root(self): + sue: TreeNode = TreeNode() + mary = TreeNode(children={"Sue": sue}) + john = TreeNode(children={"Mary": mary}) # noqa + + assert sue._get_item("/Mary") is mary + + +class TestSetNodes: + def test_set_child_node(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + assert john.children["Mary"] is mary + assert isinstance(mary, TreeNode) + assert mary.children == {} + assert mary.parent is john + + def test_child_already_exists(self): + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary_2: TreeNode = TreeNode() + with pytest.raises(KeyError): + john._set_item("Mary", mary_2, allow_overwrite=False) + + def test_set_grandchild(self): + rose: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode() + + john._set_item("Mary", mary) + john._set_item("Mary/Rose", rose) + + assert john.children["Mary"] is mary + assert isinstance(mary, TreeNode) + assert "Rose" in mary.children + assert rose.parent is mary + + def test_create_intermediate_child(self): + john: TreeNode = TreeNode() + rose: TreeNode = TreeNode() + + # test intermediate children not allowed + with pytest.raises(KeyError, match="Could not reach"): + john._set_item(path="Mary/Rose", item=rose, new_nodes_along_path=False) + + # test intermediate children allowed + john._set_item("Mary/Rose", rose, new_nodes_along_path=True) + assert "Mary" in john.children + mary = john.children["Mary"] + assert isinstance(mary, TreeNode) + assert mary.children == {"Rose": rose} + assert rose.parent == mary + assert rose.parent == mary + + def test_overwrite_child(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + # test overwriting not allowed + marys_evil_twin: TreeNode = TreeNode() + with pytest.raises(KeyError, match="Already a node object"): + john._set_item("Mary", marys_evil_twin, allow_overwrite=False) + assert john.children["Mary"] is mary + assert marys_evil_twin.parent is None + + # test overwriting allowed + marys_evil_twin = TreeNode() + john._set_item("Mary", marys_evil_twin, allow_overwrite=True) + assert john.children["Mary"] is marys_evil_twin + assert marys_evil_twin.parent is john + + +class TestPruning: + def test_del_child(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john._set_item("Mary", mary) + + del john["Mary"] + assert "Mary" not in john.children + assert mary.parent is None + + with pytest.raises(KeyError): + del john["Mary"] + + +def create_test_tree() -> tuple[NamedNode, NamedNode]: + # a + # ├── b + # │ ├── d + # │ └── e + # │ ├── f + # │ └── g + # └── c + # └── h + # └── i + a: NamedNode = NamedNode(name="a") + b: NamedNode = NamedNode() + c: NamedNode = NamedNode() + d: NamedNode = NamedNode() + e: NamedNode = NamedNode() + f: NamedNode = NamedNode() + g: NamedNode = NamedNode() + h: NamedNode = NamedNode() + i: NamedNode = NamedNode() + + a.children = {"b": b, "c": c} + b.children = {"d": d, "e": e} + e.children = {"f": f, "g": g} + c.children = {"h": h} + h.children = {"i": i} + + return a, f + + +class TestIterators: + def test_preorderiter(self): + root, _ = create_test_tree() + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], PreOrderIter(root)) + ] + expected = [ + "a", + "b", + "d", + "e", + "f", + "g", + "c", + "h", + "i", + ] + assert result == expected + + def test_levelorderiter(self): + root, _ = create_test_tree() + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) + ] + expected = [ + "a", # root Node is unnamed + "b", + "c", + "d", + "e", + "h", + "f", + "g", + "i", + ] + assert result == expected + + +class TestAncestry: + + def test_parents(self): + _, leaf_f = create_test_tree() + expected = ["e", "b", "a"] + assert [node.name for node in leaf_f.parents] == expected + + def test_lineage(self): + _, leaf_f = create_test_tree() + expected = ["f", "e", "b", "a"] + assert [node.name for node in leaf_f.lineage] == expected + + def test_ancestors(self): + _, leaf_f = create_test_tree() + ancestors = leaf_f.ancestors + expected = ["a", "b", "e", "f"] + for node, expected_name in zip(ancestors, expected): + assert node.name == expected_name + + def test_subtree(self): + root, _ = create_test_tree() + subtree = root.subtree + expected = [ + "a", + "b", + "d", + "e", + "f", + "g", + "c", + "h", + "i", + ] + for node, expected_name in zip(subtree, expected): + assert node.name == expected_name + + def test_descendants(self): + root, _ = create_test_tree() + descendants = root.descendants + expected = [ + "b", + "d", + "e", + "f", + "g", + "c", + "h", + "i", + ] + for node, expected_name in zip(descendants, expected): + assert node.name == expected_name + + def test_leaves(self): + tree, _ = create_test_tree() + leaves = tree.leaves + expected = [ + "d", + "f", + "g", + "i", + ] + for node, expected_name in zip(leaves, expected): + assert node.name == expected_name + + def test_levels(self): + a, f = create_test_tree() + + assert a.level == 0 + assert f.level == 3 + + assert a.depth == 3 + assert f.depth == 3 + + assert a.width == 1 + assert f.width == 3 + + +class TestRenderTree: + def test_render_nodetree(self): + sam: NamedNode = NamedNode() + ben: NamedNode = NamedNode() + mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben}) + kate: NamedNode = NamedNode() + john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate}) + expected_nodes = [ + "NamedNode()", + "\tNamedNode('Mary')", + "\t\tNamedNode('Sam')", + "\t\tNamedNode('Ben')", + "\tNamedNode('Kate')", + ] + expected_str = "NamedNode('Mary')" + john_repr = john.__repr__() + mary_str = mary.__str__() + + assert mary_str == expected_str + + john_nodes = john_repr.splitlines() + assert len(john_nodes) == len(expected_nodes) + for expected_node, repr_node in zip(expected_nodes, john_nodes): + assert expected_node == repr_node + + +def test_nodepath(): + path = NodePath("/Mary") + assert path.root == "/" + assert path.stem == "Mary" diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py new file mode 100644 index 00000000000..1d4ef89ae29 --- /dev/null +++ b/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ as an example of inplace binary ops + da += _int + da += _list + da += _ndarray + da += _var + + # __neg__ as an example of unary ops + _test(-da) + + +def test_dataset_typed_ops() -> None: + """Tests for type checking of typed_ops on Dataset""" + + ds = Dataset({"a": ("t", [1, 2, 3])}) + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + _da = DataArray([1, 2, 3], dims=["t"]) + + # __add__ as an example of binary ops + _test(ds + _int) + _test(ds + _list) + _test(ds + _ndarray) + _test(ds + _var) + _test(ds + _da) + _test(ds + ds) + + # __radd__ as an example of reflexive binary ops + _test(_int + ds) + _test(_list + ds) + _test(_ndarray + ds) + _test(_var + ds) + _test(_da + ds) + + # __eq__ as an example of cmp ops + _test(ds == _int) + _test(ds == _list) + _test(ds == _ndarray) + _test(ds == _var) + _test(ds == _da) + _test(_int == ds) # type: ignore[arg-type] # typeshed problem + _test(_list == ds) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == ds) + _test(_var == ds) + _test(_da == ds) + + # __lt__ as another example of cmp ops + _test(ds < _int) + _test(ds < _list) + _test(ds < _ndarray) + _test(ds < _var) + _test(ds < _da) + _test(_int > ds) + _test(_list > ds) + _test(_ndarray > ds) # type: ignore[arg-type] # numpy problem + _test(_var > ds) + _test(_da > ds) + + # __iadd__ as an example of inplace binary ops + ds += _int + ds += _list + ds += _ndarray + ds += _var + ds += _da + + # __neg__ as an example of unary ops + _test(-ds) + + +def test_dataarray_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArrayGroupBy""" + + da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"]) + grp = da.groupby("x") + + def _testda(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + def _testds(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _testda(grp + _da) + _testds(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _testda(_da + grp) + _testds(_ds + grp) + + # __eq__ as an example of cmp ops + _testda(grp == _da) + _testda(_da == grp) + _testds(grp == _ds) + _testds(_ds == grp) + + # __lt__ as another example of cmp ops + _testda(grp < _da) + _testda(_da > grp) + _testds(grp < _ds) + _testds(_ds > grp) + + +def test_dataset_groupy_typed_ops() -> None: + """Tests for type checking of typed_ops on DatasetGroupBy""" + + ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])}) + grp = ds.groupby("x") + + def _test(ds: Dataset) -> None: + # mypy checks the input type + assert isinstance(ds, Dataset) + + _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x") + _ds = _da.to_dataset(name="a") + + # __add__ as an example of binary ops + _test(grp + _da) + _test(grp + _ds) + + # __radd__ as an example of reflexive binary ops + _test(_da + grp) + _test(_ds + grp) + + # __eq__ as an example of cmp ops + _test(grp == _da) + _test(_da == grp) + _test(grp == _ds) + _test(_ds == grp) + + # __lt__ as another example of cmp ops + _test(grp < _da) + _test(_da > grp) + _test(grp < _ds) + _test(_ds > grp) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 6cd73e9cfb7..6b4c3f38ee9 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -4,7 +4,7 @@ import pytest import xarray as xr -from xarray.tests import assert_array_equal, mock +from xarray.tests import assert_allclose, assert_array_equal, mock from xarray.tests import assert_identical as assert_identical_ @@ -16,16 +16,16 @@ def assert_identical(a, b): assert_array_equal(a, b) -def test_unary(): - args = [ - 0, - np.zeros(2), +@pytest.mark.parametrize( + "a", + [ xr.Variable(["x"], [0, 0]), xr.DataArray([0, 0], dims="x"), xr.Dataset({"y": ("x", [0, 0])}), - ] - for a in args: - assert_identical(a + 1, np.cos(a)) + ], +) +def test_unary(a): + assert_allclose(a + 1, np.cos(a)) def test_binary(): diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9e872c93c0c..2f11fe688b7 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -4,9 +4,7 @@ import operator import numpy as np -import pandas as pd import pytest -from packaging import version import xarray as xr from xarray.core import dtypes, duck_array_ops @@ -17,6 +15,7 @@ assert_identical, requires_dask, requires_matplotlib, + requires_numbagg, ) from xarray.tests.test_plot import PlotTestCase from xarray.tests.test_variable import _PAD_XR_NP_ARGS @@ -303,11 +302,13 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + from xarray.core.groupby import GroupBy + xarray_classes = ( xr.Variable, xr.DataArray, xr.Dataset, - xr.core.groupby.GroupBy, + GroupBy, ) if not isinstance(obj, xarray_classes): @@ -1498,10 +1499,11 @@ def test_dot_dataarray(dtype): data_array = xr.DataArray(data=array1, dims=("x", "y")) other = xr.DataArray(data=array2, dims=("y", "z")) - expected = attach_units( - xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} - ) - actual = xr.dot(data_array, other) + with xr.set_options(use_opt_einsum=False): + expected = attach_units( + xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} + ) + actual = xr.dot(data_array, other) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -1530,13 +1532,6 @@ class TestVariable: ids=repr, ) def test_aggregation(self, func, dtype): - if ( - func.name == "prod" - and dtype.kind == "f" - and version.parse(pint.__version__) < version.parse("0.19") - ): - pytest.xfail(reason="nanprod is not by older `pint` versions") - array = np.linspace(0, 1, 10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless ) @@ -1639,15 +1634,19 @@ def test_raw_numpy_methods(self, func, unit, error, dtype): variable = xr.Variable("x", array) args = [ - item * unit - if isinstance(item, (int, float, list)) and func.name != "item" - else item + ( + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + ) for item in func.args ] kwargs = { - key: value * unit - if isinstance(value, (int, float, list)) and func.name != "item" - else value + key: ( + value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + ) for key, value in func.kwargs.items() } @@ -1658,15 +1657,19 @@ def test_raw_numpy_methods(self, func, unit, error, dtype): return converted_args = [ - strip_units(convert_units(item, {None: unit_registry.m})) - if func.name != "item" - else item + ( + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + ) for item in args ] converted_kwargs = { - key: strip_units(convert_units(value, {None: unit_registry.m})) - if func.name != "item" - else value + key: ( + strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + ) for key, value in kwargs.items() } @@ -1978,9 +1981,11 @@ def test_masking(self, func, unit, error, dtype): strip_units( convert_units( other, - {None: base_unit} - if is_compatible(base_unit, unit) - else {None: None}, + ( + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None} + ), ) ), ), @@ -2008,6 +2013,7 @@ def test_squeeze(self, dim, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize( "func", ( @@ -2029,7 +2035,7 @@ def test_squeeze(self, dim, dtype): ), ids=repr, ) - def test_computation(self, func, dtype): + def test_computation(self, func, dtype, compute_backend): base_unit = unit_registry.m array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit variable = xr.Variable(("x", "y"), array) @@ -2391,13 +2397,6 @@ def test_repr(self, func, variant, dtype): ids=repr, ) def test_aggregation(self, func, dtype): - if ( - func.name == "prod" - and dtype.kind == "f" - and version.parse(pint.__version__) < version.parse("0.19") - ): - pytest.xfail(reason="nanprod is not by older `pint` versions") - array = np.arange(10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless ) @@ -2449,8 +2448,9 @@ def test_binary_operations(self, func, dtype): data_array = xr.DataArray(data=array) units = extract_units(func(array)) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -2535,7 +2535,6 @@ def test_univariate_ufunc(self, units, error, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) - @pytest.mark.xfail(reason="needs the type register system for __array_ufunc__") @pytest.mark.parametrize( "unit,error", ( @@ -3768,6 +3767,7 @@ def test_differentiate_integrate(self, func, variant, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize( "variant", ( @@ -3788,7 +3788,7 @@ def test_differentiate_integrate(self, func, variant, dtype): ), ids=repr, ) - def test_computation(self, func, variant, dtype): + def test_computation(self, func, variant, dtype, compute_backend): unit = unit_registry.m variants = { @@ -3814,8 +3814,9 @@ def test_computation(self, func, variant, dtype): if not isinstance(func, (function, method)): units.update(extract_units(func(array.reshape(-1)))) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -3836,23 +3837,21 @@ def test_computation(self, func, variant, dtype): method("groupby", "x"), method("groupby_bins", "y", bins=4), method("coarsen", y=2), - pytest.param( - method("rolling", y=3), - marks=pytest.mark.xfail( - reason="numpy.lib.stride_tricks.as_strided converts to ndarray" - ), - ), - pytest.param( - method("rolling_exp", y=3), - marks=pytest.mark.xfail( - reason="numbagg functions are not supported by pint" - ), - ), + method("rolling", y=3), + pytest.param(method("rolling_exp", y=3), marks=requires_numbagg), method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")), ), ids=repr, ) def test_computation_objects(self, func, variant, dtype): + if variant == "data": + if func.name == "rolling_exp": + pytest.xfail(reason="numbagg functions are not supported by pint") + elif func.name == "rolling": + pytest.xfail( + reason="numpy.lib.stride_tricks.as_strided converts to ndarray" + ) + unit = unit_registry.m variants = { @@ -3883,11 +3882,11 @@ def test_computation_objects(self, func, variant, dtype): def test_resample(self, dtype): array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m - time = pd.date_range("10-09-2010", periods=len(array), freq="1y") + time = xr.date_range("10-09-2010", periods=len(array), freq="YE") data_array = xr.DataArray(data=array, coords={"time": time}, dims="time") units = extract_units(data_array) - func = method("resample", time="6m") + func = method("resample", time="6ME") expected = attach_units(func(strip_units(data_array)).mean(), units) actual = func(data_array).mean() @@ -3895,6 +3894,7 @@ def test_resample(self, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize( "variant", ( @@ -3915,7 +3915,7 @@ def test_resample(self, dtype): ), ids=repr, ) - def test_grouped_operations(self, func, variant, dtype): + def test_grouped_operations(self, func, variant, dtype, compute_backend): unit = unit_registry.m variants = { @@ -3945,9 +3945,12 @@ def test_grouped_operations(self, func, variant, dtype): for key, value in func.kwargs.items() } expected = attach_units( - func(strip_units(data_array).groupby("y"), **stripped_kwargs), units + func( + strip_units(data_array).groupby("y", squeeze=False), **stripped_kwargs + ), + units, ) - actual = func(data_array.groupby("y")) + actual = func(data_array.groupby("y", squeeze=False)) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -4090,13 +4093,6 @@ def test_repr(self, func, variant, dtype): ids=repr, ) def test_aggregation(self, func, dtype): - if ( - func.name == "prod" - and dtype.kind == "f" - and version.parse(pint.__version__) < version.parse("0.19") - ): - pytest.xfail(reason="nanprod is not by older `pint` versions") - unit_a, unit_b = ( (unit_registry.Pa, unit_registry.degK) if func.name != "cumprod" @@ -5256,6 +5252,7 @@ def test_interp_reindex_like_indexing(self, func, unit, error, dtype): assert_units_equal(expected, actual) assert_equal(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize( "func", ( @@ -5278,7 +5275,7 @@ def test_interp_reindex_like_indexing(self, func, unit, error, dtype): "coords", ), ) - def test_computation(self, func, variant, dtype): + def test_computation(self, func, variant, dtype, compute_backend): variants = { "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), "dims": ((1, 1), unit_registry.m, 1), @@ -5390,7 +5387,7 @@ def test_resample(self, variant, dtype): array1 = np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit1 array2 = np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit2 - t = pd.date_range("10-09-2010", periods=array1.shape[0], freq="1y") + t = xr.date_range("10-09-2010", periods=array1.shape[0], freq="YE") y = np.arange(5) * dim_unit z = np.arange(8) * dim_unit @@ -5402,7 +5399,7 @@ def test_resample(self, variant, dtype): ) units = extract_units(ds) - func = method("resample", time="6m") + func = method("resample", time="6ME") expected = attach_units(func(strip_units(ds)).mean(), units) actual = func(ds).mean() @@ -5410,6 +5407,7 @@ def test_resample(self, variant, dtype): assert_units_equal(expected, actual) assert_equal(expected, actual) + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize( "func", ( @@ -5431,7 +5429,7 @@ def test_resample(self, variant, dtype): "coords", ), ) - def test_grouped_operations(self, func, variant, dtype): + def test_grouped_operations(self, func, variant, dtype, compute_backend): variants = { "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), "dims": ((1, 1), unit_registry.m, 1), @@ -5459,9 +5457,9 @@ def test_grouped_operations(self, func, variant, dtype): name: strip_units(value) for name, value in func.kwargs.items() } expected = attach_units( - func(strip_units(ds).groupby("y"), **stripped_kwargs), units + func(strip_units(ds).groupby("y", squeeze=False), **stripped_kwargs), units ) - actual = func(ds.groupby("y")) + actual = func(ds.groupby("y", squeeze=False)) assert_units_equal(expected, actual) assert_equal(expected, actual) @@ -5631,12 +5629,12 @@ def test_duck_array_ops(self): import dask.array d = dask.array.array([1, 2, 3]) - q = pint.Quantity(d, units="m") + q = unit_registry.Quantity(d, units="m") da = xr.DataArray(q, dims="x") actual = da.mean().compute() actual.name = None - expected = xr.DataArray(pint.Quantity(np.array(2.0), units="m")) + expected = xr.DataArray(unit_registry.Quantity(np.array(2.0), units="m")) assert_units_equal(expected, actual) # Don't use isinstance b/c we don't want to allow subclasses through diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 567e5a4e936..82841024182 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -1,13 +1,18 @@ from __future__ import annotations -from collections.abc import Hashable, Iterable, Sequence +from collections.abc import Hashable import numpy as np import pandas as pd import pytest from xarray.core import duck_array_ops, utils -from xarray.core.utils import either_dict_or_kwargs, expand_args_to_dims, iterate_nested +from xarray.core.utils import ( + either_dict_or_kwargs, + expand_args_to_dims, + infix_dims, + iterate_nested, +) from xarray.tests import assert_array_equal, requires_dask @@ -23,11 +28,13 @@ def new_method(): @pytest.mark.parametrize( - "a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]] + ["a", "b", "expected"], + [ + [np.array(["a"]), np.array(["b"]), np.array(["a", "b"])], + [np.array([1], dtype="int64"), np.array([2], dtype="int64"), pd.Index([1, 2])], + ], ) def test_maybe_coerce_to_str(a, b, expected): - a = np.array([a]) - b = np.array([b]) index = pd.Index(a).append(pd.Index(b)) actual = utils.maybe_coerce_to_str(index, [a, b]) @@ -237,7 +244,7 @@ def test_either_dict_or_kwargs(): ], ) def test_infix_dims(supplied, all_, expected): - result = list(utils.infix_dims(supplied, all_)) + result = list(infix_dims(supplied, all_)) assert result == expected @@ -246,7 +253,7 @@ def test_infix_dims(supplied, all_, expected): ) def test_infix_dims_errors(supplied, all_): with pytest.raises(ValueError): - list(utils.infix_dims(supplied, all_)) + list(infix_dims(supplied, all_)) @pytest.mark.parametrize( @@ -255,17 +262,18 @@ def test_infix_dims_errors(supplied, all_): pytest.param("a", ("a",), id="str"), pytest.param(["a", "b"], ("a", "b"), id="list_of_str"), pytest.param(["a", 1], ("a", 1), id="list_mixed"), + pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"), pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"), pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"), pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"), + pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"), + pytest.param((), (), id="empty_tuple"), + pytest.param(set(), (), id="empty_collection"), pytest.param(None, None, id="None"), pytest.param(..., ..., id="ellipsis"), ], ) -def test_parse_dims( - dim: str | Iterable[Hashable] | None, - expected: tuple[Hashable, ...], -) -> None: +def test_parse_dims(dim, expected) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables actual = utils.parse_dims(dim, all_dims, replace_none=False) assert actual == expected @@ -295,7 +303,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None: pytest.param(["x", 2], id="list_missing_all"), ], ) -def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None: +def test_parse_dims_raises(dim) -> None: all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables with pytest.raises(ValueError, match="'x'"): utils.parse_dims(dim, all_dims, check_exists=True) @@ -311,10 +319,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None: pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"), ], ) -def test_parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis], - expected: tuple[Hashable, ...], -) -> None: +def test_parse_ordered_dims(dim, expected) -> None: all_dims = ("a", "b", "c") actual = utils.parse_ordered_dims(dim, all_dims) assert actual == expected diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index f656818c71f..061510f2515 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,17 +1,18 @@ from __future__ import annotations import warnings +from abc import ABC from copy import copy, deepcopy from datetime import datetime, timedelta from textwrap import dedent +from typing import Generic import numpy as np import pandas as pd import pytest import pytz -from packaging.version import Version -from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options +from xarray import DataArray, Dataset, IndexVariable, Variable, set_options from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.common import full_like, ones_like, zeros_like from xarray.core.indexing import ( @@ -25,9 +26,10 @@ PandasIndexingAdapter, VectorizedIndexer, ) -from xarray.core.pycompat import array_type +from xarray.core.types import T_DuckArray from xarray.core.utils import NDArrayMixin from xarray.core.variable import as_compatible_data, as_variable +from xarray.namedarray.pycompat import array_type from xarray.tests import ( assert_allclose, assert_array_equal, @@ -44,6 +46,7 @@ requires_sparse, source_ndarray, ) +from xarray.tests.test_namedarray import NamedArraySubclassobjects dask_array_type = array_type("dask") @@ -61,32 +64,26 @@ def var(): return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) -class VariableSubclassobjects: - cls: staticmethod[Variable] +@pytest.mark.parametrize( + "data", + [ + np.array(["a", "bc", "def"], dtype=object), + np.array(["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[ns]"), + ], +) +def test_as_compatible_data_writeable(data): + pd.set_option("mode.copy_on_write", True) + # GH8843, ensure writeable arrays for data_vars even with + # pandas copy-on-write mode + assert as_compatible_data(data).flags.writeable + pd.reset_option("mode.copy_on_write") - def test_properties(self): - data = 0.5 * np.arange(10) - v = self.cls(["time"], data, {"foo": "bar"}) - assert v.dims == ("time",) - assert_array_equal(v.values, data) - assert v.dtype == float - assert v.shape == (10,) - assert v.size == 10 - assert v.sizes == {"time": 10} - assert v.nbytes == 80 - assert v.ndim == 1 - assert len(v) == 10 - assert v.attrs == {"foo": "bar"} - - def test_attrs(self): - v = self.cls(["time"], 0.5 * np.arange(10)) - assert v.attrs == {} - attrs = {"foo": "bar"} - v.attrs = attrs - assert v.attrs == attrs - assert isinstance(v.attrs, dict) - v.attrs["foo"] = "baz" - assert v.attrs["foo"] == "baz" + +class VariableSubclassobjects(NamedArraySubclassobjects, ABC): + @pytest.fixture + def target(self, data): + data = 0.5 * np.arange(10).reshape(2, 5) + return Variable(["x", "y"], data) def test_getitem_dict(self): v = self.cls(["x"], np.random.randn(5)) @@ -194,7 +191,7 @@ def test_index_0d_int(self): self._assertIndexedLikeNDArray(x, value, dtype) def test_index_0d_float(self): - for value, dtype in [(0.5, np.float_), (np.float32(0.5), np.float32)]: + for value, dtype in [(0.5, float), (np.float32(0.5), np.float32)]: x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) @@ -204,6 +201,7 @@ def test_index_0d_string(self): x = self.cls(["x"], [value]) self._assertIndexedLikeNDArray(x, value, dtype) + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_index_0d_datetime(self): d = datetime(2000, 1, 1) x = self.cls(["x"], [d]) @@ -215,6 +213,7 @@ def test_index_0d_datetime(self): x = self.cls(["x"], pd.DatetimeIndex([d])) self._assertIndexedLikeNDArray(x, np.datetime64(d), "datetime64[ns]") + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_index_0d_timedelta64(self): td = timedelta(hours=1) @@ -275,6 +274,7 @@ def test_0d_time_data(self): expected = np.datetime64("2000-01-01", "ns") assert x[0].values == expected + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime64_conversion(self): times = pd.date_range("2000-01-01", periods=3) for values, preserve_source in [ @@ -290,6 +290,7 @@ def test_datetime64_conversion(self): same_source = source_ndarray(v.values) is source_ndarray(values) assert preserve_source == same_source + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_timedelta64_conversion(self): times = pd.timedelta_range(start=0, periods=3) for values, preserve_source in [ @@ -310,6 +311,7 @@ def test_object_conversion(self): actual = self.cls("x", data) assert actual.dtype == data.dtype + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime64_valid_range(self): data = np.datetime64("1250-01-01", "us") pderror = pd.errors.OutOfBoundsDatetime @@ -317,6 +319,7 @@ def test_datetime64_valid_range(self): self.cls(["t"], [data]) @pytest.mark.xfail(reason="pandas issue 36615") + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_timedelta64_valid_range(self): data = np.timedelta64("200000", "D") pderror = pd.errors.OutOfBoundsTimedelta @@ -330,14 +333,15 @@ def test_pandas_data(self): assert v[0].values == v.values[0] def test_pandas_period_index(self): - v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="B")) + v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) v = v.load() # for dask-based Variable - assert v[0] == pd.Period("2000", freq="B") - assert "Period('2000-01-03', 'B')" in repr(v) + assert v[0] == pd.Period("2000", freq="D") + assert "Period('2000-01-01', 'D')" in repr(v) - def test_1d_math(self): - x = 1.0 * np.arange(5) - y = np.ones(5) + @pytest.mark.parametrize("dtype", [float, int]) + def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: + x = np.arange(5, dtype=dtype) + y = np.ones(5, dtype=dtype) # should we need `.to_base_variable()`? # probably a break that `+v` changes type? @@ -351,11 +355,18 @@ def test_1d_math(self): assert_identical(base_v, v + 0) assert_identical(base_v, 0 + v) assert_identical(base_v, v * 1) + if dtype is int: + assert_identical(base_v, v << 0) + assert_array_equal(v << 3, x << 3) + assert_array_equal(v >> 2, x >> 2) # binary ops with numpy arrays assert_array_equal((v * x).values, x**2) assert_array_equal((x * v).values, x**2) assert_array_equal(v - y, v - 1) assert_array_equal(y - v, 1 - v) + if dtype is int: + assert_array_equal(v << x, x << x) + assert_array_equal(v >> x, x >> x) # verify attributes are dropped v2 = self.cls(["x"], x, {"units": "meters"}) with set_options(keep_attrs=False): @@ -369,10 +380,10 @@ def test_1d_math(self): # something complicated assert_array_equal((v**2 * w - 1 + x).values, x**2 * y - 1 + x) # make sure dtype is preserved (for Index objects) - assert float == (+v).dtype - assert float == (+v).values.dtype - assert float == (0 + v).dtype - assert float == (0 + v).values.dtype + assert dtype == (+v).dtype + assert dtype == (+v).values.dtype + assert dtype == (0 + v).dtype + assert dtype == (0 + v).values.dtype # check types of returned data assert isinstance(+v, Variable) assert not isinstance(+v, IndexVariable) @@ -455,6 +466,23 @@ def test_encoding_preserved(self): assert_identical(expected.to_base_variable(), actual.to_base_variable()) assert expected.encoding == actual.encoding + def test_drop_encoding(self) -> None: + encoding1 = {"scale_factor": 1} + # encoding set via cls constructor + v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1) + assert v1.encoding == encoding1 + v2 = v1.drop_encoding() + assert v1.encoding == encoding1 + assert v2.encoding == {} + + # encoding set via setter + encoding3 = {"scale_factor": 10} + v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3) + assert v3.encoding == encoding3 + v4 = v3.drop_encoding() + assert v3.encoding == encoding3 + assert v4.encoding == {} + def test_concat(self): x = np.arange(5) y = np.arange(5, 10) @@ -582,7 +610,7 @@ def test_copy_with_data_errors(self) -> None: orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = [2.5, 5.0] with pytest.raises(ValueError, match=r"must match shape of object"): - orig.copy(data=new_data) + orig.copy(data=new_data) # type: ignore[arg-type] def test_copy_index_with_data(self) -> None: orig = IndexVariable("x", np.arange(5)) @@ -850,20 +878,10 @@ def test_getitem_error(self): "mode", [ "mean", - pytest.param( - "median", - marks=pytest.mark.xfail(reason="median is not implemented by Dask"), - ), - pytest.param( - "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") - ), + "median", + "reflect", "edge", - pytest.param( - "linear_ramp", - marks=pytest.mark.xfail( - reason="pint bug: https://github.com/hgrecco/pint/issues/1026" - ), - ), + "linear_ramp", "maximum", "minimum", "symmetric", @@ -891,7 +909,7 @@ def test_pad_constant_values(self, xr_arg, np_arg): actual = v.pad(**xr_arg) expected = np.pad( - np.array(v.data.astype(float)), + np.array(duck_array_ops.astype(v.data, float)), np_arg, mode="constant", constant_values=np.nan, @@ -1033,15 +1051,15 @@ def test_rolling_window_errors(self, dim, window, window_dim, center): class TestVariable(VariableSubclassobjects): - cls = staticmethod(Variable) + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs) @pytest.fixture(autouse=True) def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) - def test_data_and_values(self): + def test_values(self): v = Variable(["time", "x"], self.d) - assert_array_equal(v.data, self.d) assert_array_equal(v.values, self.d) assert source_ndarray(v.values) is self.d with pytest.raises(ValueError): @@ -1050,18 +1068,16 @@ def test_data_and_values(self): d2 = np.random.random((10, 3)) v.values = d2 assert source_ndarray(v.values) is d2 - d3 = np.random.random((10, 3)) - v.data = d3 - assert source_ndarray(v.data) is d3 def test_numpy_same_methods(self): v = Variable([], np.float32(0.0)) assert v.item() == 0 - assert type(v.item()) is float + assert type(v.item()) is float # noqa: E721 v = IndexVariable("x", np.arange(5)) assert 2 == v.searchsorted(2) + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime64_conversion_scalar(self): expected = np.datetime64("2000-01-01", "ns") for values in [ @@ -1074,6 +1090,7 @@ def test_datetime64_conversion_scalar(self): assert v.values == expected assert v.values.dtype == np.dtype("datetime64[ns]") + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_timedelta64_conversion_scalar(self): expected = np.timedelta64(24 * 60 * 60 * 10**9, "ns") for values in [ @@ -1091,15 +1108,16 @@ def test_0d_str(self): assert v.dtype == np.dtype("U3") assert v.values == "foo" - v = Variable([], np.string_("foo")) + v = Variable([], np.bytes_("foo")) assert v.dtype == np.dtype("S3") - assert v.values == bytes("foo", "ascii") + assert v.values == "foo".encode("ascii") def test_0d_datetime(self): v = Variable([], pd.Timestamp("2000-01-01")) assert v.dtype == np.dtype("datetime64[ns]") assert v.values == np.datetime64("2000-01-01", "ns") + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_0d_timedelta(self): for td in [pd.to_timedelta("1s"), np.timedelta64(1, "s")]: v = Variable([], td) @@ -1209,8 +1227,10 @@ def test_as_variable(self): expected = Variable(("x", "y"), data) with pytest.raises(ValueError, match=r"without explicit dimension names"): as_variable(data, name="x") - with pytest.raises(ValueError, match=r"has more than 1-dimension"): - as_variable(expected, name="x") + + # name of nD variable matches dimension name + actual = as_variable(expected, name="x") + assert_identical(expected, actual) # test datetime, timedelta conversion dt = np.array([datetime(1999, 1, 1) + timedelta(days=x) for x in range(10)]) @@ -1223,11 +1243,12 @@ def test_as_variable(self): def test_repr(self): v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) + v = v.astype(np.uint64) expected = dedent( """ - + Size: 48B array([[1, 2, 3], - [4, 5, 6]]) + [4, 5, 6]], dtype=uint64) Attributes: foo: bar """ @@ -1427,10 +1448,10 @@ def test_isel(self): def test_index_0d_numpy_string(self): # regression test to verify our work around for indexing 0d strings - v = Variable([], np.string_("asdf")) + v = Variable([], np.bytes_("asdf")) assert_identical(v[()], v) - v = Variable([], np.unicode_("asdf")) + v = Variable([], np.str_("asdf")) assert_identical(v[()], v) def test_indexing_0d_unicode(self): @@ -1538,6 +1559,7 @@ def test_transpose(self): v.transpose(..., "not_a_dim", missing_dims="warn") assert_identical(expected_ell, actual) + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_transpose_0d(self): for value in [ 3.5, @@ -1667,6 +1689,15 @@ def test_stack_unstack_consistency(self): actual = v.stack(z=("x", "y")).unstack(z={"x": 2, "y": 2}) assert_identical(actual, v) + @pytest.mark.filterwarnings("error::RuntimeWarning") + def test_unstack_without_missing(self): + v = Variable(["z"], [0, 1, 2, 3]) + expected = Variable(["x", "y"], [[0, 1], [2, 3]]) + + actual = v.unstack(z={"x": 2, "y": 2}) + + assert_identical(actual, expected) + def test_broadcasting_math(self): x = np.random.randn(2, 3) v = Variable(["a", "b"], x) @@ -1690,6 +1721,7 @@ def test_broadcasting_math(self): v * w[0], Variable(["a", "b", "c", "d"], np.einsum("ab,cd->abcd", x, y[0])) ) + @pytest.mark.filterwarnings("ignore:Duplicate dimension names") def test_broadcasting_failures(self): a = Variable(["x"], np.arange(10)) b = Variable(["x"], np.arange(5)) @@ -1738,7 +1770,8 @@ def test_reduce(self): v.mean(dim="x", axis=0) @requires_bottleneck - def test_reduce_use_bottleneck(self, monkeypatch): + @pytest.mark.parametrize("compute_backend", ["bottleneck"], indirect=True) + def test_reduce_use_bottleneck(self, monkeypatch, compute_backend): def raise_if_called(*args, **kwargs): raise RuntimeError("should not have been called") @@ -1761,7 +1794,7 @@ def raise_if_called(*args, **kwargs): ) def test_quantile(self, q, axis, dim, skipna): d = self.d.copy() - d[0, 0] = np.NaN + d[0, 0] = np.nan v = Variable(["x", "y"], d) actual = v.quantile(q, dim=dim, skipna=skipna) @@ -1791,10 +1824,7 @@ def test_quantile_method(self, method, use_dask) -> None: q = np.array([0.25, 0.5, 0.75]) actual = v.quantile(q, dim="y", method=method) - if Version(np.__version__) >= Version("1.22"): - expected = np.nanquantile(self.d, q, axis=1, method=method) - else: - expected = np.nanquantile(self.d, q, axis=1, interpolation=method) + expected = np.nanquantile(self.d, q, axis=1, method=method) if use_dask: assert isinstance(actual.data, dask_array_type) @@ -1828,21 +1858,34 @@ def test_quantile_chunked_dim_error(self): with pytest.raises(ValueError, match=r"consists of multiple chunks"): v.quantile(0.5, dim="x") + @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]]) - def test_quantile_out_of_bounds(self, q): + def test_quantile_out_of_bounds(self, q, compute_backend): v = Variable(["x", "y"], self.d) # escape special characters with pytest.raises( - ValueError, match=r"Quantiles must be in the range \[0, 1\]" + ValueError, + match=r"(Q|q)uantiles must be in the range \[0, 1\]", ): v.quantile(q, dim="x") @requires_dask @requires_bottleneck - def test_rank_dask_raises(self): - v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2) - with pytest.raises(TypeError, match=r"arrays stored as dask"): + def test_rank_dask(self): + # Instead of a single test here, we could parameterize the other tests for both + # arrays. But this is sufficient. + v = Variable( + ["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]] + ).chunk(x=1) + expected = Variable( + ["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]] + ) + assert_equal(v.rank("y").compute(), expected) + + with pytest.raises( + ValueError, match=r" with dask='parallelized' consists of multiple chunks" + ): v.rank("x") def test_rank_use_bottleneck(self): @@ -1874,7 +1917,8 @@ def test_rank(self): v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0]) assert_equal(v.rank("x", pct=True), v_expect) # invalid dim - with pytest.raises(ValueError, match=r"not found"): + with pytest.raises(ValueError): + # apply_ufunc error message isn't great here — `ValueError: tuple.index(x): x not in tuple` v.rank("y") def test_big_endian_reduce(self): @@ -2203,7 +2247,8 @@ def test_coarsen_keep_attrs(self, operation="mean"): @requires_dask class TestVariableWithDask(VariableSubclassobjects): - cls = staticmethod(lambda *args: Variable(*args).chunk()) + def cls(self, *args, **kwargs) -> Variable: + return Variable(*args, **kwargs).chunk() def test_chunk(self): unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4))) @@ -2296,12 +2341,35 @@ def test_dask_rolling(self, dim, window, center): assert actual.shape == expected.shape assert_equal(actual, expected) - @pytest.mark.xfail( - reason="https://github.com/pydata/xarray/issues/6209#issuecomment-1025116203" - ) def test_multiindex(self): super().test_multiindex() + @pytest.mark.parametrize( + "mode", + [ + "mean", + pytest.param( + "median", + marks=pytest.mark.xfail(reason="median is not implemented by Dask"), + ), + pytest.param( + "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") + ), + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + @requires_sparse class TestVariableWithSparse: @@ -2315,7 +2383,8 @@ def test_as_sparse(self): class TestIndexVariable(VariableSubclassobjects): - cls = staticmethod(IndexVariable) + def cls(self, *args, **kwargs) -> IndexVariable: + return IndexVariable(*args, **kwargs) def test_init(self): with pytest.raises(ValueError, match=r"must be 1-dimensional"): @@ -2404,11 +2473,6 @@ def test_concat_str_dtype(self, dtype): assert actual.identical(expected) assert np.issubdtype(actual.dtype, dtype) - def test_coordinate_alias(self): - with pytest.warns(Warning, match="deprecated"): - x = Coordinate("x", [1, 2, 3]) - assert isinstance(x, IndexVariable) - def test_datetime64(self): # GH:1932 Make sure indexing keeps precision t = np.array([1518418799999986560, 1518418799999996560], dtype="datetime64[ns]") @@ -2486,7 +2550,7 @@ def test_to_index_variable_copy(self) -> None: assert a.dims == ("x",) -class TestAsCompatibleData: +class TestAsCompatibleData(Generic[T_DuckArray]): def test_unchanged_types(self): types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) for t in types: @@ -2519,6 +2583,20 @@ def test_masked_array(self): assert_array_equal(expected, actual) assert np.dtype(float) == actual.dtype + original = np.ma.MaskedArray([1.0, 2.0], mask=[True, False]) + original.flags.writeable = False + expected = [np.nan, 2.0] + actual = as_compatible_data(original) + assert_array_equal(expected, actual) + assert np.dtype(float) == actual.dtype + + # GH2377 + actual = Variable(dims=tuple(), data=np.ma.masked) + expected = Variable(dims=tuple(), data=np.nan) + assert_array_equal(expected, actual) + assert actual.dtype == expected.dtype + + @pytest.mark.filterwarnings("ignore:Converting non-nanosecond") def test_datetime(self): expected = np.datetime64("2000-01-01") actual = as_compatible_data(expected) @@ -2547,23 +2625,23 @@ def test_datetime(self): @requires_pandas_version_two def test_tz_datetime(self) -> None: - tz = pytz.timezone("US/Eastern") + tz = pytz.timezone("America/New_York") times_ns = pd.date_range("2000", periods=1, tz=tz) times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(times_s) + actual: T_DuckArray = as_compatible_data(times_s) assert actual.array == times_s assert actual.array.dtype == pd.DatetimeTZDtype("ns", tz) series = pd.Series(times_s) with warnings.catch_warnings(): warnings.simplefilter("ignore") - actual = as_compatible_data(series) + actual2: T_DuckArray = as_compatible_data(series) - np.testing.assert_array_equal(actual, series.values) - assert actual.dtype == np.dtype("datetime64[ns]") + np.testing.assert_array_equal(actual2, series.values) + assert actual2.dtype == np.dtype("datetime64[ns]") def test_full_like(self) -> None: # For more thorough tests, see test_variable.py @@ -2591,7 +2669,7 @@ def test_full_like(self) -> None: def test_full_like_dask(self) -> None: orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} - ).chunk(((1, 1), (2,))) + ).chunk(dict(x=(1, 1), y=(2,))) def check(actual, expect_dtype, expect_values): assert actual.dtype == expect_dtype @@ -2662,7 +2740,7 @@ def __init__(self, array): def test_raise_no_warning_for_nan_in_binary_ops(): with assert_no_warnings(): - Variable("x", [1, 2, np.NaN]) > 0 + Variable("x", [1, 2, np.nan]) > 0 class TestBackendIndexing: @@ -2847,9 +2925,11 @@ def test_from_pint_wrapping_dask(self, Var): (pd.date_range("2000", periods=1), False), (datetime(2000, 1, 1), False), (np.array([datetime(2000, 1, 1)]), False), - (pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern")), False), + (pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), False), ( - pd.Series(pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern"))), + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), False, ), ], @@ -2871,8 +2951,9 @@ def test_datetime_conversion_warning(values, warns_under_pandas_version_two) -> # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("US/Eastern") + "ns", pytz.timezone("America/New_York") ) @@ -2884,12 +2965,14 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: (pd.date_range("2000", periods=1), "datetime64[s]"), (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"), ( - pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern")), - pd.DatetimeTZDtype("s", pytz.timezone("US/Eastern")), + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), ), ( - pd.Series(pd.date_range("2000", periods=1, tz=pytz.timezone("US/Eastern"))), - pd.DatetimeTZDtype("s", pytz.timezone("US/Eastern")), + pd.Series( + pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) + ), + pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), ), ] for data, dtype in cases: @@ -2902,8 +2985,9 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. + assert isinstance(var._data, PandasIndexingAdapter) assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("US/Eastern") + "ns", pytz.timezone("America/New_York") ) @@ -2942,3 +3026,19 @@ def test_pandas_two_only_timedelta_conversion_warning() -> None: var = Variable(["time"], data) assert var.dtype == np.dtype("timedelta64[ns]") + + +@requires_pandas_version_two +@pytest.mark.parametrize( + ("index", "dtype"), + [ + (pd.date_range("2000", periods=1), "datetime64"), + (pd.timedelta_range("1", periods=1), "timedelta64"), + ], + ids=lambda x: f"{x}", +) +def test_pandas_indexing_adapter_non_nanosecond_conversion(index, dtype) -> None: + data = PandasIndexingAdapter(index.astype(f"{dtype}[s]")) + with pytest.warns(UserWarning, match="non-nanosecond precision"): + var = Variable(["time"], data) + assert var.dtype == np.dtype(f"{dtype}[ns]") diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index e2530d41fbe..f3337d70a76 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -58,7 +58,7 @@ def test_weighted_weights_nan_raises_dask(as_dataset, weights): @requires_cftime @requires_dask @pytest.mark.parametrize("time_chunks", (1, 5)) -@pytest.mark.parametrize("resample_spec", ("1AS", "5AS", "10AS")) +@pytest.mark.parametrize("resample_spec", ("1YS", "5YS", "10YS")) def test_weighted_lazy_resample(time_chunks, resample_spec): # https://github.com/pydata/xarray/issues/4625 @@ -67,7 +67,7 @@ def mean_func(ds): return ds.weighted(ds.weights).mean("time") # example dataset - t = xr.cftime_range(start="2000", periods=20, freq="1AS") + t = xr.cftime_range(start="2000", periods=20, freq="1YS") weights = xr.DataArray(np.random.rand(len(t)), dims=["time"], coords={"time": t}) data = xr.DataArray( np.random.rand(len(t)), dims=["time"], coords={"time": t, "weights": weights} @@ -608,7 +608,7 @@ def test_weighted_operations_3D(dim, add_nans, skipna): # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan data = DataArray(data, dims=dims, coords=coords) @@ -631,7 +631,7 @@ def test_weighted_quantile_3D(dim, q, add_nans, skipna): # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700) if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan da = DataArray(data, dims=dims, coords=coords) @@ -709,7 +709,7 @@ def test_weighted_operations_different_shapes( # add approximately 25 % NaNs if add_nans: c = int(data.size * 0.25) - data.ravel()[np.random.choice(data.size, c, replace=False)] = np.NaN + data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan data = DataArray(data) @@ -782,9 +782,12 @@ def test_weighted_bad_dim(operation, as_dataset): if operation == "quantile": kwargs["q"] = 0.5 - error_msg = ( - f"{data.__class__.__name__}Weighted" - " does not contain the dimensions: {'bad_dim'}" - ) - with pytest.raises(ValueError, match=error_msg): + with pytest.raises( + ValueError, + match=( + f"Dimensions \\('bad_dim',\\) not found in {data.__class__.__name__}Weighted " + # the order of (dim_0, dim_1) varies + "dimensions \\(('dim_0', 'dim_1'|'dim_1', 'dim_0')\\)" + ), + ): getattr(data.weighted(weights), operation)(**kwargs) diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 17fde8e3b92..82bb3940b98 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -5,17 +5,16 @@ * building tutorials in the documentation. """ + from __future__ import annotations import os import pathlib -import warnings from typing import TYPE_CHECKING import numpy as np from xarray.backends.api import open_dataset as _open_dataset -from xarray.backends.rasterio_ import open_rasterio as _open_rasterio from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -40,10 +39,6 @@ def _construct_cache_dir(path): external_urls = {} # type: dict -external_rasterio_urls = { - "RGB.byte": "https://github.com/rasterio/rasterio/raw/1.2.1/tests/data/RGB.byte.tif", - "shade": "https://github.com/rasterio/rasterio/raw/1.2.1/tests/data/shade.tif", -} file_formats = { "air_temperature": 3, "air_temperature_gradient": 4, @@ -172,84 +167,6 @@ def open_dataset( return ds -def open_rasterio( - name, - engine=None, - cache=True, - cache_dir=None, - **kws, -): - """ - Open a rasterio dataset from the online repository (requires internet). - - .. deprecated:: 0.20.0 - - Deprecated in favor of rioxarray. - For information about transitioning, see: - https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html - - If a local copy is found then always use that to avoid network traffic. - - Available datasets: - - * ``"RGB.byte"``: TIFF file derived from USGS Landsat 7 ETM imagery. - * ``"shade"``: TIFF file derived from from USGS SRTM 90 data - - ``RGB.byte`` and ``shade`` are downloaded from the ``rasterio`` repository [1]_. - - Parameters - ---------- - name : str - Name of the file containing the dataset. - e.g. 'RGB.byte' - cache_dir : path-like, optional - The directory in which to search for and write cached data. - cache : bool, optional - If True, then cache data locally for use on subsequent calls - **kws : dict, optional - Passed to xarray.open_rasterio - - See Also - -------- - xarray.open_rasterio - - References - ---------- - .. [1] https://github.com/rasterio/rasterio - """ - warnings.warn( - "open_rasterio is Deprecated in favor of rioxarray. " - "For information about transitioning, see: " - "https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html", - DeprecationWarning, - stacklevel=2, - ) - try: - import pooch - except ImportError as e: - raise ImportError( - "tutorial.open_rasterio depends on pooch to download and manage datasets." - " To proceed please install pooch." - ) from e - - logger = pooch.get_logger() - logger.setLevel("WARNING") - - cache_dir = _construct_cache_dir(cache_dir) - url = external_rasterio_urls.get(name) - if url is None: - raise ValueError(f"unknown rasterio dataset: {name}") - - # retrieve the file - filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir) - arr = _open_rasterio(filepath, **kws) - if not cache: - arr = arr.load() - pathlib.Path(filepath).unlink() - - return arr - - def load_dataset(*args, **kwargs) -> Dataset: """ Open, load into memory, and close a dataset from the online repository diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index e9681bdf398..c620e45574e 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -34,6 +34,11 @@ import inspect import warnings from functools import wraps +from typing import Callable, TypeVar + +from xarray.core.utils import emit_user_level_warning + +T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -41,7 +46,7 @@ EMPTY = inspect.Parameter.empty -def _deprecate_positional_args(version): +def _deprecate_positional_args(version) -> Callable[[T], T]: """Decorator for methods that issues warnings for positional arguments Using the keyword-only argument syntax in pep 3102, arguments after the @@ -112,3 +117,28 @@ def inner(*args, **kwargs): return inner return _decorator + + +def deprecate_dims(func: T) -> T: + """ + For functions that previously took `dims` as a kwarg, and have now transitioned to + `dim`. This decorator will issue a warning if `dims` is passed while forwarding it + to `dim`. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if "dims" in kwargs: + emit_user_level_warning( + "The `dims` argument has been renamed to `dim`, and will be removed " + "in the future. This renaming is taking place throughout xarray over the " + "next few releases.", + # Upgrade to `DeprecationWarning` in the future, when the renaming is complete. + PendingDeprecationWarning, + ) + kwargs["dim"] = kwargs.pop("dims") + return func(*args, **kwargs) + + # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing + # within the function. + return wrapper # type: ignore diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index efc69c46947..3462af28663 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -12,9 +12,10 @@ while replacing the doctests. """ + import collections import textwrap -from dataclasses import dataclass +from dataclasses import dataclass, field MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" @@ -22,20 +23,35 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable from xarray.core import duck_array_ops from xarray.core.options import OPTIONS -from xarray.core.types import Dims -from xarray.core.utils import contains_only_dask_or_numpy, module_available +from xarray.core.types import Dims, Self +from xarray.core.utils import contains_only_chunked_or_numpy, module_available if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -flox_available = module_available("flox")''' +flox_available = module_available("flox") +''' + +NAMED_ARRAY_MODULE_PREAMBLE = '''\ +"""Mixin classes with reduction operations.""" +# This file was generated using xarray.util.generate_aggregations. Do not edit manually. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable -DEFAULT_PREAMBLE = """ +from xarray.core import duck_array_ops +from xarray.core.types import Dims, Self +''' + +AGGREGATIONS_PREAMBLE = """ class {obj}{cls}Aggregations: __slots__ = () @@ -49,14 +65,44 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> {obj}: + ) -> Self: raise NotImplementedError()""" +NAMED_ARRAY_AGGREGATIONS_PREAMBLE = """ + +class {obj}{cls}Aggregations: + __slots__ = () + + def reduce( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + **kwargs: Any, + ) -> Self: + raise NotImplementedError()""" + + GROUPBY_PREAMBLE = """ class {obj}{cls}Aggregations: _obj: {obj} + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -81,6 +127,19 @@ def _flox_reduce( class {obj}{cls}Aggregations: _obj: {obj} + def _reduce_without_squeeze_warn( + self, + func: Callable[..., Any], + dim: Dims = None, + *, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> {obj}: + raise NotImplementedError() + def reduce( self, func: Callable[..., Any], @@ -103,11 +162,9 @@ def _flox_reduce( TEMPLATE_REDUCTION_SIGNATURE = ''' def {method}( self, - dim: Dims = None, - *,{extra_kwargs} - keep_attrs: bool | None = None, + dim: Dims = None,{kw_only}{extra_kwargs}{keep_attrs} **kwargs: Any, - ) -> {obj}: + ) -> Self: """ Reduce this {obj}'s data by applying ``{method}`` along some dimension(s). @@ -138,9 +195,7 @@ def {method}( TEMPLATE_SEE_ALSO = """ See Also -------- - numpy.{method} - dask.array.{method} - {see_also_obj}.{method} +{see_also_methods} :ref:`{docref}` User guide on {docref_description}.""" @@ -228,6 +283,15 @@ def {method}( ) +@dataclass +class DataStructure: + name: str + create_example: str + example_var_name: str + numeric_only: bool = False + see_also_modules: tuple[str] = tuple + + class Method: def __init__( self, @@ -235,11 +299,12 @@ def __init__( bool_reduce=False, extra_kwargs=tuple(), numeric_only=False, + see_also_modules=("numpy", "dask.array"), ): self.name = name self.extra_kwargs = extra_kwargs self.numeric_only = numeric_only - + self.see_also_modules = see_also_modules if bool_reduce: self.array_method = f"array_{name}" self.np_example_array = """ @@ -248,37 +313,29 @@ def __init__( else: self.array_method = name self.np_example_array = """ - ... np.array([1, 2, 3, 1, 2, np.nan])""" + ... np.array([1, 2, 3, 0, 2, np.nan])""" +@dataclass class AggregationGenerator: _dim_docstring = _DIM_DOCSTRING _template_signature = TEMPLATE_REDUCTION_SIGNATURE - def __init__( - self, - cls, - datastructure, - methods, - docref, - docref_description, - example_call_preamble, - definition_preamble, - see_also_obj=None, - notes=None, - ): - self.datastructure = datastructure - self.cls = cls - self.methods = methods - self.docref = docref - self.docref_description = docref_description - self.example_call_preamble = example_call_preamble - self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) - self.notes = "" if notes is None else notes - if not see_also_obj: - self.see_also_obj = self.datastructure.name - else: - self.see_also_obj = see_also_obj + cls: str + datastructure: DataStructure + methods: tuple[Method, ...] + docref: str + docref_description: str + example_call_preamble: str + definition_preamble: str + has_keep_attrs: bool = True + notes: str = "" + preamble: str = field(init=False) + + def __post_init__(self): + self.preamble = self.definition_preamble.format( + obj=self.datastructure.name, cls=self.cls + ) def generate_methods(self): yield [self.preamble] @@ -286,7 +343,18 @@ def generate_methods(self): yield self.generate_method(method) def generate_method(self, method): - template_kwargs = dict(obj=self.datastructure.name, method=method.name) + has_kw_only = method.extra_kwargs or self.has_keep_attrs + + template_kwargs = dict( + obj=self.datastructure.name, + method=method.name, + keep_attrs=( + "\n keep_attrs: bool | None = None," + if self.has_keep_attrs + else "" + ), + kw_only="\n *," if has_kw_only else "", + ) if method.extra_kwargs: extra_kwargs = "\n " + "\n ".join( @@ -303,7 +371,7 @@ def generate_method(self, method): for text in [ self._dim_docstring.format(method=method.name, cls=self.cls), *(kwarg.docs for kwarg in method.extra_kwargs if kwarg.docs), - _KEEP_ATTRS_DOCSTRING, + _KEEP_ATTRS_DOCSTRING if self.has_keep_attrs else None, _KWARGS_DOCSTRING.format(method=method.name), ]: if text: @@ -311,11 +379,24 @@ def generate_method(self, method): yield TEMPLATE_RETURNS.format(**template_kwargs) + # we want Dataset.count to refer to DataArray.count + # but we also want DatasetGroupBy.count to refer to Dataset.count + # The generic aggregations have self.cls == '' + others = ( + self.datastructure.see_also_modules + if self.cls == "" + else (self.datastructure.name,) + ) + see_also_methods = "\n".join( + " " * 8 + f"{mod}.{method.name}" + for mod in (method.see_also_modules + others) + ) + # Fixes broken links mentioned in #8055 yield TEMPLATE_SEE_ALSO.format( **template_kwargs, docref=self.docref, docref_description=self.docref_description, - see_also_obj=self.see_also_obj, + see_also_methods=see_also_methods, ) notes = self.notes @@ -330,18 +411,12 @@ def generate_method(self, method): yield textwrap.indent(self.generate_example(method=method), "") yield ' """' - yield self.generate_code(method) + yield self.generate_code(method, self.has_keep_attrs) def generate_example(self, method): - create_da = f""" - >>> da = xr.DataArray({method.np_example_array}, - ... dims="time", - ... coords=dict( - ... time=("time", pd.date_range("2001-01-01", freq="M", periods=6)), - ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), - ... ), - ... )""" - + created = self.datastructure.create_example.format( + example_array=method.np_example_array + ) calculation = f"{self.datastructure.example_var_name}{self.example_call_preamble}.{method.name}" if method.extra_kwargs: extra_examples = "".join( @@ -352,7 +427,8 @@ def generate_example(self, method): return f""" Examples - --------{create_da}{self.datastructure.docstring_create} + --------{created} + >>> {self.datastructure.example_var_name} >>> {calculation}(){extra_examples}""" @@ -361,7 +437,7 @@ class GroupByAggregationGenerator(AggregationGenerator): _dim_docstring = _DIM_DOCSTRING_GROUPBY _template_signature = TEMPLATE_REDUCTION_SIGNATURE_GROUPBY - def generate_code(self, method): + def generate_code(self, method, has_keep_attrs): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] if self.datastructure.numeric_only: @@ -382,7 +458,7 @@ def generate_code(self, method): if method_is_not_flox_supported: return f"""\ - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.{method.array_method}, dim=dim,{extra_kwargs} keep_attrs=keep_attrs, @@ -394,7 +470,7 @@ def generate_code(self, method): if ( flox_available and OPTIONS["use_flox"] - and contains_only_dask_or_numpy(self._obj) + and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( func="{method.name}", @@ -404,7 +480,7 @@ def generate_code(self, method): **kwargs, ) else: - return self.reduce( + return self._reduce_without_squeeze_warn( duck_array_ops.{method.array_method}, dim=dim,{extra_kwargs} keep_attrs=keep_attrs, @@ -413,7 +489,7 @@ def generate_code(self, method): class GenericAggregationGenerator(AggregationGenerator): - def generate_code(self, method): + def generate_code(self, method, has_keep_attrs): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] if self.datastructure.numeric_only: @@ -423,18 +499,20 @@ def generate_code(self, method): extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ") else: extra_kwargs = "" + keep_attrs = ( + "\n" + 12 * " " + "keep_attrs=keep_attrs," if has_keep_attrs else "" + ) return f"""\ return self.reduce( duck_array_ops.{method.array_method}, - dim=dim,{extra_kwargs} - keep_attrs=keep_attrs, + dim=dim,{extra_kwargs}{keep_attrs} **kwargs, )""" AGGREGATION_METHODS = ( # Reductions: - Method("count"), + Method("count", see_also_modules=("pandas.DataFrame", "dask.dataframe.DataFrame")), Method("all", bool_reduce=True), Method("any", bool_reduce=True), Method("max", extra_kwargs=(skipna,)), @@ -451,28 +529,34 @@ def generate_code(self, method): ) -@dataclass -class DataStructure: - name: str - docstring_create: str - example_var_name: str - numeric_only: bool = False - - DATASET_OBJECT = DataStructure( name="Dataset", - docstring_create=""" - >>> ds = xr.Dataset(dict(da=da)) - >>> ds""", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... ) + >>> ds = xr.Dataset(dict(da=da))""", example_var_name="ds", numeric_only=True, + see_also_modules=("DataArray",), ) DATAARRAY_OBJECT = DataStructure( name="DataArray", - docstring_create=""" - >>> da""", + create_example=""" + >>> da = xr.DataArray({example_array}, + ... dims="time", + ... coords=dict( + ... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)), + ... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])), + ... ), + ... )""", example_var_name="da", numeric_only=False, + see_also_modules=("Dataset",), ) DATASET_GENERATOR = GenericAggregationGenerator( cls="", @@ -481,8 +565,7 @@ class DataStructure: docref="agg", docref_description="reduction or aggregation operations", example_call_preamble="", - see_also_obj="DataArray", - definition_preamble=DEFAULT_PREAMBLE, + definition_preamble=AGGREGATIONS_PREAMBLE, ) DATAARRAY_GENERATOR = GenericAggregationGenerator( cls="", @@ -491,8 +574,7 @@ class DataStructure: docref="agg", docref_description="reduction or aggregation operations", example_call_preamble="", - see_also_obj="Dataset", - definition_preamble=DEFAULT_PREAMBLE, + definition_preamble=AGGREGATIONS_PREAMBLE, ) DATAARRAY_GROUPBY_GENERATOR = GroupByAggregationGenerator( cls="GroupBy", @@ -510,7 +592,7 @@ class DataStructure: methods=AGGREGATION_METHODS, docref="resampling", docref_description="resampling operations", - example_call_preamble='.resample(time="3M")', + example_call_preamble='.resample(time="3ME")', definition_preamble=RESAMPLE_PREAMBLE, notes=_FLOX_RESAMPLE_NOTES, ) @@ -530,29 +612,64 @@ class DataStructure: methods=AGGREGATION_METHODS, docref="resampling", docref_description="resampling operations", - example_call_preamble='.resample(time="3M")', + example_call_preamble='.resample(time="3ME")', definition_preamble=RESAMPLE_PREAMBLE, notes=_FLOX_RESAMPLE_NOTES, ) +NAMED_ARRAY_OBJECT = DataStructure( + name="NamedArray", + create_example=""" + >>> from xarray.namedarray.core import NamedArray + >>> na = NamedArray( + ... "x",{example_array}, + ... )""", + example_var_name="na", + numeric_only=False, + see_also_modules=("Dataset", "DataArray"), +) + +NAMED_ARRAY_GENERATOR = GenericAggregationGenerator( + cls="", + datastructure=NAMED_ARRAY_OBJECT, + methods=AGGREGATION_METHODS, + docref="agg", + docref_description="reduction or aggregation operations", + example_call_preamble="", + definition_preamble=NAMED_ARRAY_AGGREGATIONS_PREAMBLE, + has_keep_attrs=False, +) + + +def write_methods(filepath, generators, preamble): + with open(filepath, mode="w", encoding="utf-8") as f: + f.write(preamble) + for gen in generators: + for lines in gen.generate_methods(): + for line in lines: + f.write(line + "\n") + if __name__ == "__main__": import os from pathlib import Path p = Path(os.getcwd()) - filepath = p.parent / "xarray" / "xarray" / "core" / "_aggregations.py" - # filepath = p.parent / "core" / "_aggregations.py" # Run from script location - with open(filepath, mode="w", encoding="utf-8") as f: - f.write(MODULE_PREAMBLE + "\n") - for gen in [ + write_methods( + filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py", + generators=[ DATASET_GENERATOR, DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR, DATASET_RESAMPLE_GENERATOR, DATAARRAY_GROUPBY_GENERATOR, DATAARRAY_RESAMPLE_GENERATOR, - ]: - for lines in gen.generate_methods(): - for line in lines: - f.write(line + "\n") + ], + preamble=MODULE_PREAMBLE, + ) + write_methods( + filepath=p.parent / "xarray" / "xarray" / "namedarray" / "_aggregations.py", + generators=[NAMED_ARRAY_GENERATOR], + preamble=NAMED_ARRAY_MODULE_PREAMBLE, + ) + # filepath = p.parent / "core" / "_aggregations.py" # Run from script location diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index 02a3725f475..ee4dd68b3ba 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -3,14 +3,17 @@ For internal xarray development use only. Usage: - python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py - python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py """ + # Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some # background to some of the design choices made here. -import sys +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) BINOPS_CMP = ( @@ -30,6 +33,8 @@ ("__and__", "operator.and_"), ("__xor__", "operator.xor"), ("__or__", "operator.or_"), + ("__lshift__", "operator.lshift"), + ("__rshift__", "operator.rshift"), ) BINOPS_REFLEXIVE = ( ("__radd__", "operator.add"), @@ -54,6 +59,8 @@ ("__iand__", "operator.iand"), ("__ixor__", "operator.ixor"), ("__ior__", "operator.ior"), + ("__ilshift__", "operator.ilshift"), + ("__irshift__", "operator.irshift"), ) UNARY_OPS = ( ("__neg__", "operator.neg"), @@ -70,155 +77,186 @@ ("conjugate", "ops.conjugate"), ) + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" template_binop = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} + return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> {overload_type}: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... + + def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} return self._binary_op(other, {func})""" template_reflexive = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" template_inplace = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> Self:{type_ignore} return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" template_unary = """ - def {method}(self): + def {method}(self) -> Self: return self._unary_op({func})""" template_other_unary = """ - def {method}(self, *args, **kwargs): + def {method}(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op({func}, *args, **kwargs)""" -required_method_unary = """ - def _unary_op(self, f, *args, **kwargs): - raise NotImplementedError""" -required_method_binary = """ - def _binary_op(self, other, f, reflexive=False): - raise NotImplementedError""" -required_method_inplace = """ - def _inplace_binary_op(self, other, f): - raise NotImplementedError""" +unhashable = """ + # When __eq__ is defined but __hash__ is not, then an object is unhashable, + # and it should be declared as follows: + __hash__: None # type:ignore[assignment]""" # For some methods we override return type `bool` defined by base class `object`. -OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} -NO_OVERRIDE = {"override": ""} - -# Note: in some of the overloads below the return value in reality is NotImplemented, -# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] -# or type(NotImplemented) are not allowed and NoReturn has a different meaning. -# In such cases we are lending the type checkers a hand by specifying the return type -# of the corresponding reflexive method on `other` which will be called instead. -stub_ds = """\ - def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" -stub_da = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" -stub_var = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ...""" -stub_dsgb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DataArray") -> "Dataset": ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_dagb = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: T_DataArray) -> T_DataArray: ... - @overload - def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" -stub_unary = """\ - def {method}(self: {self_type}) -> {self_type}: ...""" -stub_other_unary = """\ - def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ...""" -stub_required_unary = """\ - def _unary_op(self, f, *args, **kwargs): ...""" -stub_required_binary = """\ - def _binary_op(self, other, f, reflexive=...): ...""" -stub_required_inplace = """\ - def _inplace_binary_op(self, other, f): ...""" - - -def unops(self_type): - extra_context = {"self_type": self_type} +# We need to add "# type: ignore[override]" +# Keep an eye out for: +# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240 +# The type ignores might not be necessary anymore at some point. +# +# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray +# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) +# TODO: change once python 3.10 is the minimum. +# +# Mypy seems to require that __iadd__ and __add__ have the same signature. +# This requires some extra type: ignores[misc] in the inplace methods :/ + + +def _type_ignore(ignore: str) -> str: + return f" # type:ignore[{ignore}]" if ignore else "" + + +FuncType = Sequence[tuple[Optional[str], Optional[str]]] +OpsType = tuple[FuncType, str, dict[str, str]] + + +def binops( + other_type: str, return_type: str = "Self", type_ignore_eq: str = "override" +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} + return [ + ([(None, None)], required_method_binary, extras), + (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}), + ( + BINOPS_EQNE, + template_binop, + extras | {"type_ignore": _type_ignore(type_ignore_eq)}, + ), + ([(None, None)], unhashable, extras), + (BINOPS_REFLEXIVE, template_reflexive, extras), + ] + + +def binops_overload( + other_type: str, + overload_type: str, + return_type: str = "Self", + type_ignore_eq: str = "override", +) -> list[OpsType]: + extras = {"other_type": other_type, "return_type": return_type} return [ - ([(None, None)], required_method_unary, stub_required_unary, {}), - (UNARY_OPS, template_unary, stub_unary, extra_context), - (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context), + ([(None, None)], required_method_binary, extras), + ( + BINOPS_NUM + BINOPS_CMP, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": "", + }, + ), + ( + BINOPS_EQNE, + template_binop_overload, + extras + | { + "overload_type": overload_type, + "type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + ([(None, None)], unhashable, extras), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def binops(stub=""): +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} return [ - ([(None, None)], required_method_binary, stub_required_binary, {}), - (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE), - (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED), - (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE), + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), ] -def inplace(): +def unops() -> list[OpsType]: return [ - ([(None, None)], required_method_inplace, stub_required_inplace, {}), - (BINOPS_INPLACE, template_inplace, "", {}), + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), ] ops_info = {} -ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") -ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") -ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") -ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) -ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. -import operator - -from . import nputils, ops''' +from __future__ import annotations -STUBFILE_PREAMBLE = '''\ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike +import operator +from typing import TYPE_CHECKING, Any, Callable, overload -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( +from xarray.core import nputils, ops +from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByIncompatible, - ScalarOrArray, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, VarCompatible, ) -from .variable import Variable -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")''' +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' CLASS_PREAMBLE = """{newline} @@ -229,35 +267,28 @@ class {cls_name}: {method}.__doc__ = {func}.__doc__""" -def render(ops_info, is_module): +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: """Render the module or stub file.""" - yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + yield MODULE_PREAMBLE for cls_name, method_blocks in ops_info.items(): - yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) - yield from _render_classbody(method_blocks, is_module) + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) -def _render_classbody(method_blocks, is_module): - for method_func_pairs, method_template, stub_template, extra in method_blocks: - template = method_template if is_module else stub_template +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: if template: for method, func in method_func_pairs: yield template.format(method=method, func=func, **extra) - if is_module: - yield "" - for method_func_pairs, *_ in method_blocks: - for method, func in method_func_pairs: - if method and func: - yield COPY_DOCSTRING.format(method=method, func=func) + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) if __name__ == "__main__": - option = sys.argv[1].lower() if len(sys.argv) == 2 else None - if option not in {"--module", "--stubs"}: - raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") - is_module = option == "--module" - - for line in render(ops_info, is_module): + for line in render(ops_info): print(line) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 42ce3746942..4c715437588 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -1,4 +1,5 @@ """Utility functions for printing version information.""" + import importlib import locale import os @@ -107,8 +108,6 @@ def show_versions(file=sys.stdout): ("zarr", lambda mod: mod.__version__), ("cftime", lambda mod: mod.__version__), ("nc_time_axis", lambda mod: mod.__version__), - ("PseudoNetCDF", lambda mod: mod.__version__), - ("rasterio", lambda mod: mod.__version__), ("iris", lambda mod: mod.__version__), ("bottleneck", lambda mod: mod.__version__), ("dask", lambda mod: mod.__version__),