diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 00000000000..02bc5d0f7b0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,39 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + + + +**What happened**: + +**What you expected to happen**: + +**Minimal Complete Verifiable Example**: + +```python +# Put your MCVE code here +``` + +**Anything else we need to know?**: + +**Environment**: + +
Output of xr.show_versions()</summary> + + + +</details>
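For reference, a minimal sketch of the kind of MCVE the new bug template asks for (hypothetical data and operation, not taken from the diff): self-contained imports, the smallest input that still shows the problem, and one call that demonstrates the behaviour.

```python
# Hypothetical MCVE sketch: everything needed to reproduce is in the snippet.
import numpy as np
import xarray as xr

# Smallest input that still exhibits the behaviour being reported.
da = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))

# A single operation demonstrating what happened vs. what was expected.
print(da.mean(dim="x"))
```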
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index c712cf27979..00000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,35 +0,0 @@ ---- -name: Bug report / Feature request -about: 'Post a problem or idea' -title: '' -labels: '' -assignees: '' - ---- - - - - -#### MCVE Code Sample - - -```python -# Your code here - -``` - -#### Expected Output - - -#### Problem Description - - - -#### Versions - -
Output of xr.show_versions()</summary> - - - -</details>
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000000..3389fbfe071 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: General Question + url: https://stackoverflow.com/questions/tagged/python-xarray + about: "If you have a question like *How do I append to an xarray.Dataset?* then please ask on Stack Overflow using the #python-xarray tag." diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 00000000000..7021fe490aa --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,22 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + + + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context about the feature request here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a921bddaa23..c9c0b720c35 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,4 +3,5 @@ - [ ] Closes #xxxx - [ ] Tests added - [ ] Passes `isort -rc . && black . && mypy . && flake8` - - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API + - [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst` + - [ ] New functions/methods are listed in `api.rst` diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26bf4803ef6..447f0007fc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,12 +11,16 @@ repos: rev: stable hooks: - id: black + - repo: https://github.com/keewis/blackdoc + rev: stable + hooks: + - id: blackdoc - repo: https://gitlab.com/pycqa/flake8 rev: 3.7.9 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.761 # Must match ci/requirements/*.yml + rev: v0.780 # Must match ci/requirements/*.yml hooks: - id: mypy # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..7a909aefd08 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1 @@ +Xarray's contributor guidelines [can be found in our online documentation](http://xarray.pydata.org/en/stable/contributing.html) diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index 3fdd1d7236d..c890d61d966 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -1,4 +1,4 @@ -How to issue an xarray release in 16 easy steps +# How to issue an xarray release in 17 easy steps Time required: about an hour. @@ -6,7 +6,16 @@ Time required: about an hour. ``` git pull upstream master ``` - 2. Look over whats-new.rst and the docs. Make sure "What's New" is complete + 2. Get a list of contributors with: + ``` + git log "$(git tag --sort="v:refname" | sed -n 'x;$p').." --format=%aN | sort -u | perl -pe 's/\n/$1, /' + ``` + or by substituting the _previous_ release in: + ``` + git log v0.X.Y-1.. --format=%aN | sort -u | perl -pe 's/\n/$1, /' + ``` + Add these into `whats-new.rst` somewhere :) + 3. 
Look over whats-new.rst and the docs. Make sure "What's New" is complete (check the date!) and consider adding a brief summary note describing the release at the top. Things to watch out for: @@ -16,41 +25,41 @@ Time required: about an hour. due to a bad merge. Check for these before a release by using git diff, e.g., `git diff v0.X.Y whats-new.rst` where 0.X.Y is the previous release. - 3. If you have any doubts, run the full test suite one final time! + 4. If you have any doubts, run the full test suite one final time! ``` pytest ``` - 4. Check that the ReadTheDocs build is passing. - 5. On the master branch, commit the release in git: + 5. Check that the ReadTheDocs build is passing. + 6. On the master branch, commit the release in git: ``` git commit -am 'Release v0.X.Y' ``` - 6. Tag the release: + 7. Tag the release: ``` git tag -a v0.X.Y -m 'v0.X.Y' ``` - 7. Build source and binary wheels for pypi: + 8. Build source and binary wheels for pypi: ``` git clean -xdf # this deletes all uncommited changes! python setup.py bdist_wheel sdist ``` - 8. Use twine to check the package build: + 9. Use twine to check the package build: ``` twine check dist/xarray-0.X.Y* ``` - 9. Use twine to register and upload the release on pypi. Be careful, you can't +10. Use twine to register and upload the release on pypi. Be careful, you can't take this back! ``` twine upload dist/xarray-0.X.Y* ``` You will need to be listed as a package owner at https://pypi.python.org/pypi/xarray for this to work. -10. Push your changes to master: +11. Push your changes to master: ``` git push upstream master git push upstream --tags ``` -11. Update the stable branch (used by ReadTheDocs) and switch back to master: +12. Update the stable branch (used by ReadTheDocs) and switch back to master: ``` git checkout stable git rebase master @@ -60,7 +69,7 @@ Time required: about an hour. It's OK to force push to 'stable' if necessary. (We also update the stable branch with `git cherrypick` for documentation only fixes that apply the current released version.) -12. Add a section for the next release (v.X.Y+1) to doc/whats-new.rst: +13. Add a section for the next release (v.X.Y+1) to doc/whats-new.rst: ``` .. _whats-new.0.X.Y+1: @@ -86,19 +95,19 @@ Time required: about an hour. Internal Changes ~~~~~~~~~~~~~~~~ ``` -13. Commit your changes and push to master again: +14. Commit your changes and push to master again: ``` git commit -am 'New whatsnew section' git push upstream master ``` You're done pushing to master! -14. Issue the release on GitHub. Click on "Draft a new release" at +15. Issue the release on GitHub. Click on "Draft a new release" at https://github.com/pydata/xarray/releases. Type in the version number, but don't bother to describe it -- we maintain that on the docs instead. -15. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ +16. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -16. Issue the release announcement! For bug fix releases, I usually only email +17. Issue the release announcement! For bug fix releases, I usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): - pydata@googlegroups.com @@ -109,18 +118,8 @@ Time required: about an hour. Google search will turn up examples of prior release announcements (look for "ANN xarray"). 
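Step 2's shell pipeline collects author names since the previous release tag. A rough Python equivalent (a sketch only; `git describe --tags --abbrev=0` only approximates the `git tag --sort ... | sed` selection of the last tag, and it assumes the script runs from the repository root):

```python
# Sketch of the step-2 contributor listing in Python.
import subprocess

last_tag = subprocess.check_output(
    ["git", "describe", "--tags", "--abbrev=0"], text=True
).strip()
names = subprocess.check_output(
    ["git", "log", f"{last_tag}..", "--format=%aN"], text=True
).splitlines()
# Mirrors `sort -u` plus the perl join into a comma-separated list.
print(", ".join(sorted(set(names))))
```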
- You can get a list of contributors with: - ``` - git log "$(git tag --sort="v:refname" | sed -n 'x;$p').." --format="%aN" | sort -u - ``` - or by substituting the _previous_ release in: - ``` - git log v0.X.Y-1.. --format="%aN" | sort -u - ``` - NB: copying this output into a Google Groups form can cause - [issues](https://groups.google.com/forum/#!topic/xarray/hK158wAviPs) with line breaks, so take care -Note on version numbering: +## Note on version numbering We follow a rough approximation of semantic version. Only major releases (0.X.0) should include breaking changes. Minor releases (0.X.Y) are for bug fixes and diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ff85501c555..e04c8f74f68 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -108,21 +108,3 @@ jobs: python ci/min_deps_check.py ci/requirements/py36-bare-minimum.yml python ci/min_deps_check.py ci/requirements/py36-min-all-deps.yml displayName: minimum versions policy - -- job: Docs - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/install.yml - parameters: - env_file: ci/requirements/doc.yml - - bash: | - source activate xarray-tests - # Replicate the exact environment created by the readthedocs CI - conda install --yes --quiet -c pkgs/main mock pillow sphinx sphinx_rtd_theme - displayName: Replicate readthedocs CI environment - - bash: | - source activate xarray-tests - cd doc - sphinx-build -W --keep-going -j auto -b html -d _build/doctrees . _build/html - displayName: Build HTML docs diff --git a/ci/azure/install.yml b/ci/azure/install.yml index eff229e863a..83895eebe01 100644 --- a/ci/azure/install.yml +++ b/ci/azure/install.yml @@ -10,6 +10,8 @@ steps: conda env create -n xarray-tests --file ${{ parameters.env_file }} displayName: Install conda dependencies +# TODO: add sparse back in, once Numba works with the development version of +# NumPy again: https://github.com/pydata/xarray/issues/4146 - bash: | source activate xarray-tests conda uninstall -y --force \ @@ -23,7 +25,8 @@ steps: cftime \ rasterio \ pint \ - bottleneck + bottleneck \ + sparse python -m pip install \ -i https://pypi.anaconda.org/scipy-wheels-nightly/simple \ --no-deps \ diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index 86540197dcc..a72cd000680 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -15,8 +15,8 @@ dependencies: - cfgrib=0.9 - cftime=1.0 - coveralls - - dask=2.2 - - distributed=2.2 + - dask=2.5 + - distributed=2.5 - flake8 - h5netcdf=0.7 - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml index a5eded49cd4..cd2b1a18c77 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py36-min-nep18.yml @@ -6,12 +6,11 @@ dependencies: # require drastically newer packages than everything else - python=3.6 - coveralls - - dask=2.4 - - distributed=2.4 + - dask=2.5 + - distributed=2.5 - msgpack-python=0.6 # remove once distributed is bumped. 
distributed GH3491 - numpy=1.17 - pandas=0.25 - - pint=0.11 - pip - pytest - pytest-cov @@ -19,3 +18,5 @@ dependencies: - scipy=1.2 - setuptools=41.2 - sparse=0.8 + - pip: + - pint==0.13 diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml index a500173f277..aa2baf9dcce 100644 --- a/ci/requirements/py36.yml +++ b/ci/requirements/py36.yml @@ -28,7 +28,6 @@ dependencies: - numba - numpy - pandas - - pint - pip - pseudonetcdf - pydap @@ -45,3 +44,4 @@ dependencies: - zarr - pip: - numbagg + - pint diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/py37-windows.yml index e9e5c7a900a..8b12704d644 100644 --- a/ci/requirements/py37-windows.yml +++ b/ci/requirements/py37-windows.yml @@ -28,7 +28,6 @@ dependencies: - numba - numpy - pandas - - pint - pip - pseudonetcdf - pydap @@ -45,3 +44,4 @@ dependencies: - zarr - pip: - numbagg + - pint diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index dba3926596e..70c453e8776 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -28,7 +28,6 @@ dependencies: - numba - numpy - pandas - - pint - pip - pseudonetcdf - pydap @@ -45,3 +44,4 @@ dependencies: - zarr - pip: - numbagg + - pint diff --git a/ci/requirements/py38-all-but-dask.yml b/ci/requirements/py38-all-but-dask.yml index a375d9e1e5a..6d76eecbd6a 100644 --- a/ci/requirements/py38-all-but-dask.yml +++ b/ci/requirements/py38-all-but-dask.yml @@ -25,7 +25,6 @@ dependencies: - numba - numpy - pandas - - pint - pip - pseudonetcdf - pydap @@ -42,3 +41,4 @@ dependencies: - zarr - pip: - numbagg + - pint diff --git a/ci/requirements/py38.yml b/ci/requirements/py38.yml index 24602f884e9..6f35138978c 100644 --- a/ci/requirements/py38.yml +++ b/ci/requirements/py38.yml @@ -22,13 +22,12 @@ dependencies: - isort - lxml # Optional dep of pydap - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml + - mypy=0.780 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba - numpy - pandas - - pint - pip - pseudonetcdf - pydap @@ -45,3 +44,4 @@ dependencies: - zarr - pip: - numbagg + - pint diff --git a/doc/_templates/autosummary/accessor.rst b/doc/_templates/autosummary/accessor.rst new file mode 100644 index 00000000000..4ba745cd6fd --- /dev/null +++ b/doc/_templates/autosummary/accessor.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/_templates/autosummary/accessor_attribute.rst b/doc/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 00000000000..b5ad65d6a73 --- /dev/null +++ b/doc/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/_templates/autosummary/accessor_callable.rst b/doc/_templates/autosummary/accessor_callable.rst new file mode 100644 index 00000000000..7a3301814f5 --- /dev/null +++ b/doc/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. 
autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/doc/_templates/autosummary/accessor_method.rst b/doc/_templates/autosummary/accessor_method.rst new file mode 100644 index 00000000000..aefbba6ef1b --- /dev/null +++ b/doc/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 313428c29d2..efef4259b74 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -9,8 +9,6 @@ .. autosummary:: :toctree: generated/ - auto_combine - Dataset.nbytes Dataset.chunks @@ -43,8 +41,6 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any - core.rolling.DatasetCoarsen.argmax - core.rolling.DatasetCoarsen.argmin core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean @@ -70,8 +66,6 @@ core.groupby.DatasetGroupBy.where core.groupby.DatasetGroupBy.all core.groupby.DatasetGroupBy.any - core.groupby.DatasetGroupBy.argmax - core.groupby.DatasetGroupBy.argmin core.groupby.DatasetGroupBy.count core.groupby.DatasetGroupBy.max core.groupby.DatasetGroupBy.mean @@ -87,8 +81,6 @@ core.resample.DatasetResample.all core.resample.DatasetResample.any core.resample.DatasetResample.apply - core.resample.DatasetResample.argmax - core.resample.DatasetResample.argmin core.resample.DatasetResample.assign core.resample.DatasetResample.assign_coords core.resample.DatasetResample.bfill @@ -112,8 +104,6 @@ core.resample.DatasetResample.dims core.resample.DatasetResample.groups - core.rolling.DatasetRolling.argmax - core.rolling.DatasetRolling.argmin core.rolling.DatasetRolling.count core.rolling.DatasetRolling.max core.rolling.DatasetRolling.mean @@ -187,8 +177,6 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any - core.rolling.DataArrayCoarsen.argmax - core.rolling.DataArrayCoarsen.argmin core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean @@ -213,8 +201,6 @@ core.groupby.DataArrayGroupBy.where core.groupby.DataArrayGroupBy.all core.groupby.DataArrayGroupBy.any - core.groupby.DataArrayGroupBy.argmax - core.groupby.DataArrayGroupBy.argmin core.groupby.DataArrayGroupBy.count core.groupby.DataArrayGroupBy.max core.groupby.DataArrayGroupBy.mean @@ -230,8 +216,6 @@ core.resample.DataArrayResample.all core.resample.DataArrayResample.any core.resample.DataArrayResample.apply - core.resample.DataArrayResample.argmax - core.resample.DataArrayResample.argmin core.resample.DataArrayResample.assign_coords core.resample.DataArrayResample.bfill core.resample.DataArrayResample.count @@ -254,8 +238,6 @@ core.resample.DataArrayResample.dims core.resample.DataArrayResample.groups - core.rolling.DataArrayRolling.argmax - core.rolling.DataArrayRolling.argmin core.rolling.DataArrayRolling.count core.rolling.DataArrayRolling.max core.rolling.DataArrayRolling.mean @@ -425,8 +407,6 @@ IndexVariable.all IndexVariable.any - IndexVariable.argmax - IndexVariable.argmin IndexVariable.argsort IndexVariable.astype IndexVariable.broadcast_equals @@ -566,8 +546,6 @@ CFTimeIndex.all CFTimeIndex.any CFTimeIndex.append - CFTimeIndex.argmax - CFTimeIndex.argmin CFTimeIndex.argsort CFTimeIndex.asof CFTimeIndex.asof_locs diff --git a/doc/api.rst b/doc/api.rst index 3f25ac1a070..603e3e8f6cf 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -21,7 +21,6 @@ Top-level functions 
broadcast concat merge - auto_combine combine_by_coords combine_nested where @@ -233,6 +232,15 @@ Reshaping and reorganizing Dataset.sortby Dataset.broadcast_like +Plotting +-------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + Dataset.plot.scatter + DataArray ========= @@ -403,6 +411,122 @@ Computation :py:attr:`~core.groupby.DataArrayGroupBy.where` :py:attr:`~core.groupby.DataArrayGroupBy.quantile` + +String manipulation +------------------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.str.capitalize + DataArray.str.center + DataArray.str.contains + DataArray.str.count + DataArray.str.decode + DataArray.str.encode + DataArray.str.endswith + DataArray.str.find + DataArray.str.get + DataArray.str.index + DataArray.str.isalnum + DataArray.str.isalpha + DataArray.str.isdecimal + DataArray.str.isdigit + DataArray.str.isnumeric + DataArray.str.isspace + DataArray.str.istitle + DataArray.str.isupper + DataArray.str.len + DataArray.str.ljust + DataArray.str.lower + DataArray.str.lstrip + DataArray.str.match + DataArray.str.pad + DataArray.str.repeat + DataArray.str.replace + DataArray.str.rfind + DataArray.str.rindex + DataArray.str.rjust + DataArray.str.rstrip + DataArray.str.slice + DataArray.str.slice_replace + DataArray.str.startswith + DataArray.str.strip + DataArray.str.swapcase + DataArray.str.title + DataArray.str.translate + DataArray.str.upper + DataArray.str.wrap + DataArray.str.zfill + +Datetimelike properties +----------------------- + +**Datetime properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.year + DataArray.dt.month + DataArray.dt.day + DataArray.dt.hour + DataArray.dt.minute + DataArray.dt.second + DataArray.dt.microsecond + DataArray.dt.nanosecond + DataArray.dt.weekofyear + DataArray.dt.week + DataArray.dt.dayofweek + DataArray.dt.weekday + DataArray.dt.weekday_name + DataArray.dt.dayofyear + DataArray.dt.quarter + DataArray.dt.days_in_month + DataArray.dt.daysinmonth + DataArray.dt.season + DataArray.dt.time + DataArray.dt.is_month_start + DataArray.dt.is_month_end + DataArray.dt.is_quarter_end + DataArray.dt.is_year_start + DataArray.dt.is_leap_year + +**Datetime methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.round + DataArray.dt.strftime + +**Timedelta properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.days + DataArray.dt.seconds + DataArray.dt.microseconds + DataArray.dt.nanoseconds + +**Timedelta methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.round + + Reshaping and reorganizing -------------------------- @@ -419,6 +543,27 @@ Reshaping and reorganizing DataArray.sortby DataArray.broadcast_like +Plotting +-------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_callable.rst + + DataArray.plot + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.plot.contourf + DataArray.plot.contour + DataArray.plot.hist + DataArray.plot.imshow + DataArray.plot.line + DataArray.plot.pcolormesh + DataArray.plot.step + .. _api.ufuncs: Universal functions @@ -664,25 +809,6 @@ Creating custom indexes cftime_range -Plotting -======== - -.. 
autosummary:: - :toctree: generated/ - - Dataset.plot - plot.scatter - DataArray.plot - plot.plot - plot.contourf - plot.contour - plot.hist - plot.imshow - plot.line - plot.pcolormesh - plot.step - plot.FacetGrid - Faceting -------- .. autosummary:: diff --git a/doc/conf.py b/doc/conf.py index 6b16468d29e..d3d126cb33f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -20,6 +20,12 @@ import sys from contextlib import suppress +# --------- autosummary templates ------------------ +# TODO: eventually replace this with a sphinx.ext.auto_accessor module +import sphinx +from sphinx.ext.autodoc import AttributeDocumenter, Documenter, MethodDocumenter +from sphinx.util import rpartition + # make sure the source version is preferred (#3567) root = pathlib.Path(__file__).absolute().parent.parent os.environ["PYTHONPATH"] = str(root) @@ -358,3 +364,113 @@ "dask": ("https://docs.dask.org/en/latest", None), "cftime": ("https://unidata.github.io/cftime", None), } + + +# --------- autosummary templates ------------------ +# TODO: eventually replace this with a sphinx.ext.auto_accessor module +class AccessorDocumenter(MethodDocumenter): + """ + Specialized Documenter subclass for accessors. + """ + + objtype = "accessor" + directivetype = "method" + + # lower than MethodDocumenter so this is not chosen for normal methods + priority = 0.6 + + def format_signature(self): + # this method gives an error/warning for the accessors, therefore + # overriding it (accessor has no arguments) + return "" + + +class AccessorLevelDocumenter(Documenter): + """ + Specialized Documenter subclass for objects on accessor level (methods, + attributes). + """ + + # This is the simple straightforward version + # modname is None, base the last elements (eg 'hour') + # and path the part before (eg 'Series.dt') + # def resolve_name(self, modname, parents, path, base): + # modname = 'pandas' + # mod_cls = path.rstrip('.') + # mod_cls = mod_cls.split('.') + # + # return modname, mod_cls + [base] + + def resolve_name(self, modname, parents, path, base): + if modname is None: + if path: + mod_cls = path.rstrip(".") + else: + mod_cls = None + # if documenting a class-level object without path, + # there must be a current class, either from a parent + # auto directive ... + mod_cls = self.env.temp_data.get("autodoc:class") + # ... or from a class directive + if mod_cls is None: + mod_cls = self.env.temp_data.get("py:class") + # ... if still None, there's no way to know + if mod_cls is None: + return None, [] + # HACK: this is added in comparison to ClassLevelDocumenter + # mod_cls still exists of class.accessor, so an extra + # rpartition is needed + modname, accessor = rpartition(mod_cls, ".") + modname, cls = rpartition(modname, ".") + parents = [cls, accessor] + # if the module name is still missing, get it like above + if not modname: + modname = self.env.temp_data.get("autodoc:module") + if not modname: + if sphinx.__version__ > "1.3": + modname = self.env.ref_context.get("py:module") + else: + modname = self.env.temp_data.get("py:module") + # ... 
else, it stays None, which means invalid + return modname, parents + [base] + + +class AccessorAttributeDocumenter(AccessorLevelDocumenter, AttributeDocumenter): + + objtype = "accessorattribute" + directivetype = "attribute" + + # lower than AttributeDocumenter so this is not chosen for normal attributes + priority = 0.6 + + +class AccessorMethodDocumenter(AccessorLevelDocumenter, MethodDocumenter): + + objtype = "accessormethod" + directivetype = "method" + + # lower than MethodDocumenter so this is not chosen for normal methods + priority = 0.6 + + +class AccessorCallableDocumenter(AccessorLevelDocumenter, MethodDocumenter): + """ + This documenter lets us removes .__call__ from the method signature for + callable accessors like Series.plot + """ + + objtype = "accessorcallable" + directivetype = "method" + + # lower than MethodDocumenter; otherwise the doc build prints warnings + priority = 0.5 + + def format_name(self): + return MethodDocumenter.format_name(self).rstrip(".__call__") + + +def setup(app): + app.add_autodocumenter(AccessorDocumenter) + app.add_autodocumenter(AccessorAttributeDocumenter) + app.add_autodocumenter(AccessorMethodDocumenter) + app.add_autodocumenter(AccessorCallableDocumenter) diff --git a/doc/contributing.rst b/doc/contributing.rst index 51dba2bb0cc..9e6a3c250e9 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -148,7 +148,7 @@ We'll now kick off a two-step process: 1. Install the build dependencies 2. Build and install xarray -.. code-block:: none +.. code-block:: sh # Create and activate the build environment # This is for Linux and MacOS. On Windows, use py37-windows.yml instead. @@ -162,7 +162,10 @@ We'll now kick off a two-step process: # Build and install xarray pip install -e . -At this point you should be able to import *xarray* from your locally built version:: +At this point you should be able to import *xarray* from your locally +built version: + +.. code-block:: sh $ python # start an interpreter >>> import xarray @@ -256,7 +259,9 @@ Some other important things to know about the docs: - The tutorials make heavy use of the `ipython directive `_ sphinx extension. This directive lets you put code in the documentation which will be run - during the doc build. For example:: + during the doc build. For example: + + .. code:: rst .. ipython:: python @@ -290,7 +295,7 @@ Requirements Make sure to follow the instructions on :ref:`creating a development environment above `, but to build the docs you need to use the environment file ``ci/requirements/doc.yml``. -.. code-block:: none +.. code-block:: sh # Create and activate the docs environment conda env create -f ci/requirements/doc.yml @@ -347,7 +352,10 @@ Code Formatting xarray uses several tools to ensure a consistent code format throughout the project: -- `Black `_ for standardized code formatting +- `Black `_ for standardized + code formatting +- `blackdoc `_ for + standardized code formatting in documentation - `Flake8 `_ for general code quality - `isort `_ for standardized order in imports. See also `flake8-isort `_. @@ -356,12 +364,13 @@ xarray uses several tools to ensure a consistent code format throughout the proj ``pip``:: - pip install black flake8 isort mypy + pip install black flake8 isort mypy blackdoc and then run from the root of the Xarray repository:: isort -rc . black -t py36 . + blackdoc -t py36 . flake8 mypy . 
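To illustrate what the newly added blackdoc hook does (hypothetical docstrings, not part of the diff): it runs black over code embedded in docstrings and documentation, so a doctest gets normalized the same way black would format a module.

```python
# Before `blackdoc -t py36 .` (hypothetical docstring):
def before():
    """
    >>> import xarray as xr
    >>> da = xr.DataArray([1,2,3],dims=('x',))
    """


# After: the embedded code now follows black style.
def after():
    """
    >>> import xarray as xr
    >>> da = xr.DataArray([1, 2, 3], dims=("x",))
    """
```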
diff --git a/doc/dask.rst b/doc/dask.rst index df223982ba4..de25ee2200e 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -432,6 +432,7 @@ received by the applied function. print(da.sizes) return da.time + mapped = xr.map_blocks(func, ds.temperature) mapped @@ -461,9 +462,10 @@ Here is a common example where automated inference will not work. :okexcept: def func(da): - print(da.sizes) + print(da.sizes) return da.isel(time=[1]) + mapped = xr.map_blocks(func, ds.temperature) ``func`` cannot be run on 0-shaped inputs because it is not possible to extract element 1 along a @@ -501,6 +503,7 @@ Notice that the 0-shaped sizes were not printed to screen. Since ``template`` ha def func(obj, a, b=0): return obj + a + b + mapped = ds.map_blocks(func, args=[10], kwargs={"b": 10}) expected = ds + 10 + 10 mapped.identical(expected) diff --git a/doc/internals.rst b/doc/internals.rst index 27c7c4e1d87..46c117e312b 100644 --- a/doc/internals.rst +++ b/doc/internals.rst @@ -182,9 +182,10 @@ re-open it directly with Zarr: .. ipython:: python - ds = xr.tutorial.load_dataset('rasm') - ds.to_zarr('rasm.zarr', mode='w') + ds = xr.tutorial.load_dataset("rasm") + ds.to_zarr("rasm.zarr", mode="w") import zarr - zgroup = zarr.open('rasm.zarr') + + zgroup = zarr.open("rasm.zarr") print(zgroup.tree()) - dict(zgroup['Tair'].attrs) + dict(zgroup["Tair"].attrs) \ No newline at end of file diff --git a/doc/io.rst b/doc/io.rst index 1f854586202..4aac5e0b6f7 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -994,8 +994,8 @@ be done directly from zarr, as described in the GRIB format via cfgrib ---------------------- -xarray supports reading GRIB files via ECMWF cfgrib_ python driver and ecCodes_ -C-library, if they are installed. To open a GRIB file supply ``engine='cfgrib'`` +xarray supports reading GRIB files via ECMWF cfgrib_ python driver, +if it is installed. To open a GRIB file supply ``engine='cfgrib'`` to :py:func:`open_dataset`: .. ipython:: @@ -1003,13 +1003,11 @@ to :py:func:`open_dataset`: In [1]: ds_grib = xr.open_dataset("example.grib", engine="cfgrib") -We recommend installing ecCodes via conda:: +We recommend installing cfgrib via conda:: - conda install -c conda-forge eccodes - pip install cfgrib + conda install -c conda-forge cfgrib .. _cfgrib: https://github.com/ecmwf/cfgrib -.. _ecCodes: https://confluence.ecmwf.int/display/ECC/ecCodes+Home .. _io.pynio: diff --git a/doc/plotting.rst b/doc/plotting.rst index 14e64650902..02ddba1e00c 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -220,7 +220,7 @@ from the time and assign it as a non-dimension coordinate: .. ipython:: python - decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta('1d') + decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta("1d") air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day)) air1d_multi @@ -912,4 +912,4 @@ One can also make line plots with multidimensional coordinates. In this case, `` f, ax = plt.subplots(2, 1) da.plot.line(x="lon", hue="y", ax=ax[0]) @savefig plotting_example_2d_hue_xy.png - da.plot.line(x="lon", hue="x", ax=ax[1]) + da.plot.line(x="lon", hue="x", ax=ax[1]) \ No newline at end of file diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eaed6d2811b..98f35beacc1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,15 @@ Breaking changes `_. (:pull:`3274`) By `Elliott Sales de Andrade `_ +- The old :py:func:`auto_combine` function has now been removed in + favour of the :py:func:`combine_by_coords` and + :py:func:`combine_nested` functions. 
This also means that + the default behaviour of :py:func:`open_mfdataset` has changed to use + ``combine='by_coords'`` as the default argument value. (:issue:`2616`, :pull:`3926`) + By `Tom Nicholas `_. +- The ``DataArray`` and ``Variable`` HTML reprs now expand the data section by + default (:issue:`4176`) + By `Stephan Hoyer `_. Enhancements ~~~~~~~~~~~~ @@ -48,6 +57,13 @@ Enhancements New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support + sequences of 'dim' arguments, and if a sequence is passed return a dict + (which can be passed to :py:meth:`isel` to get the value of the minimum) of + the indices for each dimension of the minimum or maximum of a DataArray. + (:pull:`3936`) + By `John Omotani `_, thanks to `Keisuke Fujii + `_ for work in :pull:`1469`. - Added :py:meth:`xarray.infer_freq` for extending frequency inferring to CFTime indexes and data (:pull:`4033`). By `Pascal Bourgault `_. - ``chunks='auto'`` is now supported in the ``chunks`` argument of @@ -69,13 +85,16 @@ New Features - Limited the length of array items with long string reprs to a reasonable width (:pull:`3900`) By `Maximilian Roos `_ +- Limited the number of lines of large arrays when numpy reprs would have greater than 40. + (:pull:`3905`) + By `Maximilian Roos `_ - Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) By `Todd Jennings `_ - Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, - :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`) - By `Kai Mühlbauer `_. -- More support for unit aware arrays with pint (:pull:`3643`) + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`, :pull:`4135`) + By `Kai Mühlbauer `_ and `Pascal Bourgault `_. +- More support for unit aware arrays with pint (:pull:`3643`, :pull:`3975`) By `Justus Magin `_. - Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even without ``append_dim``, as long as dimension sizes do not change. @@ -99,7 +118,6 @@ New Features By `Deepak Cherian `_ - :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`) By `Deepak Cherian `_ - - Add keyword ``decode_timedelta`` to :py:func:`xarray.open_dataset`, (:py:func:`xarray.open_dataarray`, :py:func:`xarray.open_dataarray`, :py:func:`xarray.decode_cf`) that allows to disable/enable the decoding of timedeltas @@ -108,6 +126,8 @@ New Features Bug fixes ~~~~~~~~~ +- Fix errors combining attrs in :py:func:`open_mfdataset` (:issue:`4009`, :pull:`4173`) + By `John Omotani `_ - If groupby receives a ``DataArray`` with name=None, assign a default name (:issue:`158`) By `Phil Butcher `_. - Support dark mode in VS code (:issue:`4024`) @@ -177,6 +197,8 @@ Documentation By `Justus Magin `_. - Narrative documentation now describes :py:meth:`map_blocks`: :ref:`dask.automatic-parallelization`. By `Deepak Cherian `_. +- Document ``.plot``, ``.dt``, ``.str`` accessors the way they are called. (:issue:`3625`, :pull:`3988`) + By `Justus Magin `_. - Add documentation for the parameters and return values of :py:meth:`DataArray.sel`. By `Justus Magin `_. @@ -187,6 +209,9 @@ Internal Changes - Run the ``isort`` pre-commit hook only on python source files and update the ``flake8`` version. (:issue:`3750`, :pull:`3711`) By `Justus Magin `_. +- Add `blackdoc `_ to the list of + checkers for development. (:pull:`4177`) + By `Justus Magin `_. 
- Add a CI job that runs the tests with every optional dependency except ``dask``. (:issue:`3794`, :pull:`3919`) By `Justus Magin `_. @@ -253,6 +278,8 @@ New Features :py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile` (:issue:`3843`, :pull:`3844`) By `Aaron Spring `_. +- Add a diff summary for `testing.assert_allclose`. (:issue:`3617`, :pull:`3847`) + By `Justus Magin `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index cb4824d188d..3886edc60e6 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -16,7 +16,7 @@ from .coding.frequencies import infer_freq from .conventions import SerializationWarning, decode_cf from .core.alignment import align, broadcast -from .core.combine import auto_combine, combine_by_coords, combine_nested +from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like from .core.computation import apply_ufunc, corr, cov, dot, polyval, where from .core.concat import concat @@ -47,7 +47,6 @@ "align", "apply_ufunc", "as_variable", - "auto_combine", "broadcast", "cftime_range", "combine_by_coords", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 0919d2a582b..8d7c2230b2d 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,7 +4,6 @@ from io import BytesIO from numbers import Number from pathlib import Path -from textwrap import dedent from typing import ( TYPE_CHECKING, Callable, @@ -23,7 +22,6 @@ from ..core.combine import ( _infer_concat_order_from_positions, _nested_combine, - auto_combine, combine_by_coords, ) from ..core.dataarray import DataArray @@ -726,14 +724,14 @@ def close(self): def open_mfdataset( paths, chunks=None, - concat_dim="_not_supplied", + concat_dim=None, compat="no_conflicts", preprocess=None, engine=None, lock=None, data_vars="all", coords="different", - combine="_old_auto", + combine="by_coords", autoclose=None, parallel=False, join="outer", @@ -746,9 +744,8 @@ def open_mfdataset( the datasets into one before returning the result, and if combine='nested' then ``combine_nested`` is used. The filepaths must be structured according to which combining function is used, the details of which are given in the documentation for - ``combine_by_coords`` and ``combine_nested``. By default the old (now deprecated) - ``auto_combine`` will be used, please specify either ``combine='by_coords'`` or - ``combine='nested'`` in future. Requires dask to be installed. See documentation for + ``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'`` + will be used. Requires dask to be installed. See documentation for details on dask [1]_. Global attributes from the ``attrs_file`` are used for the combined dataset. @@ -758,7 +755,7 @@ def open_mfdataset( Either a string glob in the form ``"path/to/my/files/*.nc"`` or an explicit list of files to open. Paths can be given as strings or as pathlib Paths. If concatenation along more than one dimension is desired, then ``paths`` must be a - nested list-of-lists (see ``manual_combine`` for details). (A string glob will + nested list-of-lists (see ``combine_nested`` for details). (A string glob will be expanded to a 1-dimensional list.) chunks : int or dict, optional Dictionary with keys given by dimension names and values given by chunk sizes. @@ -768,15 +765,16 @@ def open_mfdataset( see the full documentation for more details [2]_. 
concat_dim : str, or list of str, DataArray, Index or None, optional Dimensions to concatenate files along. You only need to provide this argument - if any of the dimensions along which you want to concatenate is not a dimension - in the original datasets, e.g., if you want to stack a collection of 2D arrays - along a third dimension. Set ``concat_dim=[..., None, ...]`` explicitly to - disable concatenation along a particular dimension. + if ``combine='by_coords'``, and if any of the dimensions along which you want to + concatenate is not a dimension in the original datasets, e.g., if you want to + stack a collection of 2D arrays along a third dimension. Set + ``concat_dim=[..., None, ...]`` explicitly to disable concatenation along a + particular dimension. Default is None, which for a 1D list of filepaths is + equivalent to opening the files separately and then merging them with + ``xarray.merge``. combine : {'by_coords', 'nested'}, optional Whether ``xarray.combine_by_coords`` or ``xarray.combine_nested`` is used to - combine all the data. If this argument is not provided, `xarray.auto_combine` is - used, but in the future this behavior will switch to use - `xarray.combine_by_coords` by default. + combine all the data. Default is to use ``xarray.combine_by_coords``. compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for @@ -869,7 +867,6 @@ def open_mfdataset( -------- combine_by_coords combine_nested - auto_combine open_dataset References @@ -897,11 +894,8 @@ def open_mfdataset( # If combine='nested' then this creates a flat list which is easier to # iterate over, while saving the originally-supplied structure as "ids" if combine == "nested": - if str(concat_dim) == "_not_supplied": - raise ValueError("Must supply concat_dim when using " "combine='nested'") - else: - if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: - concat_dim = [concat_dim] + if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: + concat_dim = [concat_dim] combined_ids_paths = _infer_concat_order_from_positions(paths) ids, paths = (list(combined_ids_paths.keys()), list(combined_ids_paths.values())) @@ -933,30 +927,7 @@ def open_mfdataset( # Combine all datasets, closing them in case of a ValueError try: - if combine == "_old_auto": - # Use the old auto_combine for now - # Remove this after deprecation cycle from #2616 is complete - basic_msg = dedent( - """\ - In xarray version 0.15 the default behaviour of `open_mfdataset` - will change. To retain the existing behavior, pass - combine='nested'. To use future default behavior, pass - combine='by_coords'. 
See - http://xarray.pydata.org/en/stable/combining.html#combining-multi - """ - ) - warnings.warn(basic_msg, FutureWarning, stacklevel=2) - - combined = auto_combine( - datasets, - concat_dim=concat_dim, - compat=compat, - data_vars=data_vars, - coords=coords, - join=join, - from_openmfds=True, - ) - elif combine == "nested": + if combine == "nested": # Combined nested list by successive concat and merge operations # along each dimension, using structure given by "ids" combined = _nested_combine( @@ -967,12 +938,18 @@ def open_mfdataset( coords=coords, ids=ids, join=join, + combine_attrs="drop", ) elif combine == "by_coords": # Redo ordering from coordinates, ignoring how they were ordered # previously combined = combine_by_coords( - datasets, compat=compat, data_vars=data_vars, coords=coords, join=join + datasets, + compat=compat, + data_vars=data_vars, + coords=coords, + join=join, + combine_attrs="drop", ) else: raise ValueError( diff --git a/xarray/coding/times.py b/xarray/coding/times.py index dafa8ca03b1..77b2d2c7937 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -158,7 +158,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) except (KeyError, OutOfBoundsDatetime, OverflowError): dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar + flat_num_dates.astype(float), units, calendar ) if ( @@ -179,7 +179,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): dates = cftime_to_nptime(dates) elif use_cftime: dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar + flat_num_dates.astype(float), units, calendar ) else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) diff --git a/xarray/conventions.py b/xarray/conventions.py index 588fcea71a3..fc0572944f3 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -116,7 +116,7 @@ def maybe_default_fill_value(var): def maybe_encode_bools(var): if ( - (var.dtype == np.bool) + (var.dtype == bool) and ("dtype" not in var.encoding) and ("dtype" not in var.attrs) ): diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 1f990457798..58bd7178fa2 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,7 +1,5 @@ import itertools -import warnings from collections import Counter -from textwrap import dedent import pandas as pd @@ -762,272 +760,3 @@ def combine_by_coords( join=join, combine_attrs=combine_attrs, ) - - -# Everything beyond here is only needed until the deprecation cycle in #2616 -# is completed - - -_CONCAT_DIM_DEFAULT = "__infer_concat_dim__" - - -def auto_combine( - datasets, - concat_dim="_not_supplied", - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - from_openmfds=False, -): - """ - Attempt to auto-magically combine the given datasets into one. - - This entire function is deprecated in favour of ``combine_nested`` and - ``combine_by_coords``. - - This method attempts to combine a list of datasets into a single entity by - inspecting metadata and using a combination of concat and merge. - It does not concatenate along more than one dimension or sort data under - any circumstances. It does align coordinates, but different variables on - datasets can cause it to fail under some scenarios. In complex cases, you - may need to clean up your data and use ``concat``/``merge`` explicitly. 
- ``auto_combine`` works well if you have N years of data and M data - variables, and each combination of a distinct time period and set of data - variables is saved its own dataset. - - Parameters - ---------- - datasets : sequence of xarray.Dataset - Dataset objects to merge. - concat_dim : str or DataArray or Index, optional - Dimension along which to concatenate variables, as used by - :py:func:`xarray.concat`. You only need to provide this argument if - the dimension along which you want to concatenate is not a dimension - in the original datasets, e.g., if you want to stack a collection of - 2D arrays along a third dimension. - By default, xarray attempts to infer this argument by examining - component files. Set ``concat_dim=None`` explicitly to disable - concatenation. - compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts', 'override'}, optional - String indicating how to compare variables of the same name for - potential conflicts: - - - 'broadcast_equals': all values must be equal when variables are - broadcast against each other to ensure common dimensions. - - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the - same. - - 'no_conflicts': only values which are not null in both datasets - must be equal. The returned dataset then contains the combination - of all non-null values. - - 'override': skip comparing and pick variable from first dataset - data_vars : {'minimal', 'different', 'all' or list of str}, optional - Details are in the documentation of concat - coords : {'minimal', 'different', 'all' o list of str}, optional - Details are in the documentation of concat - fill_value : scalar, optional - Value to use for newly missing values - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional - String indicating how to combine differing indexes - (excluding concat_dim) in objects - - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be - aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be - those of the first object with that dimension. Indexes for the same - dimension must have the same size in all objects. - - Returns - ------- - combined : xarray.Dataset - - See also - -------- - concat - Dataset.merge - """ - - if not from_openmfds: - basic_msg = dedent( - """\ - In xarray version 0.15 `auto_combine` will be deprecated. See - http://xarray.pydata.org/en/stable/combining.html#combining-multi""" - ) - warnings.warn(basic_msg, FutureWarning, stacklevel=2) - - if concat_dim == "_not_supplied": - concat_dim = _CONCAT_DIM_DEFAULT - message = "" - else: - message = dedent( - """\ - Also `open_mfdataset` will no longer accept a `concat_dim` argument. - To get equivalent behaviour from now on please use the new - `combine_nested` function instead (or the `combine='nested'` option to - `open_mfdataset`).""" - ) - - if _dimension_coords_exist(datasets): - message += dedent( - """\ - The datasets supplied have global dimension coordinates. You may want - to use the new `combine_by_coords` function (or the - `combine='by_coords'` option to `open_mfdataset`) to order the datasets - before concatenation. 
Alternatively, to continue concatenating based - on the order the datasets are supplied in future, please use the new - `combine_nested` function (or the `combine='nested'` option to - open_mfdataset).""" - ) - else: - message += dedent( - """\ - The datasets supplied do not have global dimension coordinates. In - future, to continue concatenating without supplying dimension - coordinates, please use the new `combine_nested` function (or the - `combine='nested'` option to open_mfdataset.""" - ) - - if _requires_concat_and_merge(datasets): - manual_dims = [concat_dim].append(None) - message += dedent( - """\ - The datasets supplied require both concatenation and merging. From - xarray version 0.15 this will operation will require either using the - new `combine_nested` function (or the `combine='nested'` option to - open_mfdataset), with a nested list structure such that you can combine - along the dimensions {}. Alternatively if your datasets have global - dimension coordinates then you can use the new `combine_by_coords` - function.""".format( - manual_dims - ) - ) - - warnings.warn(message, FutureWarning, stacklevel=2) - - return _old_auto_combine( - datasets, - concat_dim=concat_dim, - compat=compat, - data_vars=data_vars, - coords=coords, - fill_value=fill_value, - join=join, - ) - - -def _dimension_coords_exist(datasets): - """ - Check if the datasets have consistent global dimension coordinates - which would in future be used by `auto_combine` for concatenation ordering. - """ - - # Group by data vars - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) - - # Simulates performing the multidimensional combine on each group of data - # variables before merging back together - try: - for vars, datasets_with_same_vars in grouped_by_vars: - _infer_concat_order_from_coords(list(datasets_with_same_vars)) - return True - except ValueError: - # ValueError means datasets don't have global dimension coordinates - # Or something else went wrong in trying to determine them - return False - - -def _requires_concat_and_merge(datasets): - """ - Check if the datasets require the use of both xarray.concat and - xarray.merge, which in future might require the user to use - `manual_combine` instead. - """ - # Group by data vars - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) - - return len(list(grouped_by_vars)) > 1 - - -def _old_auto_combine( - datasets, - concat_dim=_CONCAT_DIM_DEFAULT, - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", -): - if concat_dim is not None: - dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim - - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped = itertools.groupby(sorted_datasets, key=vars_as_keys) - - concatenated = [ - _auto_concat( - list(datasets), - dim=dim, - data_vars=data_vars, - coords=coords, - compat=compat, - fill_value=fill_value, - join=join, - ) - for vars, datasets in grouped - ] - else: - concatenated = datasets - merged = merge(concatenated, compat=compat, fill_value=fill_value, join=join) - return merged - - -def _auto_concat( - datasets, - dim=None, - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - compat="no_conflicts", -): - if len(datasets) == 1 and dim is None: - # There is nothing more to combine, so kick out early. 
- return datasets[0] - else: - if dim is None: - ds0 = datasets[0] - ds1 = datasets[1] - concat_dims = set(ds0.dims) - if ds0.dims != ds1.dims: - dim_tuples = set(ds0.dims.items()) - set(ds1.dims.items()) - concat_dims = {i for i, _ in dim_tuples} - if len(concat_dims) > 1: - concat_dims = {d for d in concat_dims if not ds0[d].equals(ds1[d])} - if len(concat_dims) > 1: - raise ValueError( - "too many different dimensions to " "concatenate: %s" % concat_dims - ) - elif len(concat_dims) == 0: - raise ValueError( - "cannot infer dimension to concatenate: " - "supply the ``concat_dim`` argument " - "explicitly" - ) - (dim,) = concat_dims - return concat( - datasets, - dim=dim, - data_vars=data_vars, - coords=coords, - fill_value=fill_value, - compat=compat, - ) diff --git a/xarray/core/common.py b/xarray/core/common.py index e343f342040..f759f4c32dd 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1481,7 +1481,7 @@ def zeros_like(other, dtype: DTypeLike = None): * lat (lat) int64 1 2 * lon (lon) int64 0 1 2 - >>> xr.zeros_like(x, dtype=np.float) + >>> xr.zeros_like(x, dtype=float) array([[0., 0., 0.], [0., 0., 0.]]) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cecd4fd8e70..d8a0c53e817 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1096,10 +1096,14 @@ def cov(da_a, da_b, dim=None, ddof=1): Examples -------- - >>> da_a = DataArray(np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), - ... dims=("space", "time"), - ... coords=[('space', ['IA', 'IL', 'IN']), - ... ('time', pd.date_range("2000-01-01", freq="1D", periods=3))]) + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) >>> da_a array([[1. , 2. , 3. ], @@ -1108,10 +1112,14 @@ def cov(da_a, da_b, dim=None, ddof=1): Coordinates: * space (space) >> da_b = DataArray(np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), - ... dims=("space", "time"), - ... coords=[('space', ['IA', 'IL', 'IN']), - ... ('time', pd.date_range("2000-01-01", freq="1D", periods=3))]) + >>> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) >>> da_b array([[ 0.2, 0.4, 0.6], @@ -1123,7 +1131,7 @@ def cov(da_a, da_b, dim=None, ddof=1): >>> xr.cov(da_a, da_b) array(-3.53055556) - >>> xr.cov(da_a, da_b, dim='time') + >>> xr.cov(da_a, da_b, dim="time") array([ 0.2, -0.5, 1.69333333]) Coordinates: @@ -1165,10 +1173,14 @@ def corr(da_a, da_b, dim=None): Examples -------- - >>> da_a = DataArray(np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), - ... dims=("space", "time"), - ... coords=[('space', ['IA', 'IL', 'IN']), - ... ('time', pd.date_range("2000-01-01", freq="1D", periods=3))]) + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) >>> da_a array([[1. , 2. , 3. ], @@ -1177,10 +1189,14 @@ def corr(da_a, da_b, dim=None): Coordinates: * space (space) >> da_b = DataArray(np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), - ... dims=("space", "time"), - ... 
coords=[('space', ['IA', 'IL', 'IN']), - ... ('time', pd.date_range("2000-01-01", freq="1D", periods=3))]) + >>> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) >>> da_b array([[ 0.2, 0.4, 0.6], @@ -1192,7 +1208,7 @@ def corr(da_a, da_b, dim=None): >>> xr.corr(da_a, da_b) array(-0.57087777) - >>> xr.corr(da_a, da_b, dim='time') + >>> xr.corr(da_a, da_b, dim="time") array([ 1., -1., 1.]) Coordinates: @@ -1563,7 +1579,7 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks)) dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[(indx.data,)]) + res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) # we need to attach back the dim name res.name = dim else: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 44773e36e30..0ce76a5e23a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -53,7 +53,7 @@ from .formatting import format_item from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, _extract_indexes_from_coords +from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import ( @@ -260,7 +260,7 @@ class DataArray(AbstractArray, DataWithCoords): _resample_cls = resample.DataArrayResample _weighted_cls = weighted.DataArrayWeighted - dt = property(CombinedDatetimelikeAccessor) + dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor) def __init__( self, @@ -2713,8 +2713,15 @@ def func(self, other): # don't support automatic alignment with in-place arithmetic. other_coords = getattr(other, "coords", None) other_variable = getattr(other, "variable", other) - with self.coords._merge_inplace(other_coords): - f(self.variable, other_variable) + try: + with self.coords._merge_inplace(other_coords): + f(self.variable, other_variable) + except MergeError as exc: + raise MergeError( + "Automatic alignment is not supported for in-place operations.\n" + "Consider aligning the indices manually or using a not-in-place operation.\n" + "See https://github.com/pydata/xarray/issues/3910 for more explanations." + ) from exc return self return func @@ -2722,24 +2729,7 @@ def func(self, other): def _copy_attrs_from(self, other: Union["DataArray", Dataset, Variable]) -> None: self.attrs = other.attrs - @property - def plot(self) -> _PlotMethods: - """ - Access plotting functions for DataArray's - - >>> d = xr.DataArray([[1, 2], [3, 4]]) - - For convenience just call this directly - - >>> d.plot() - - Or use it as a namespace to use xarray.plot functions as - DataArray methods - - >>> d.plot.imshow() # equivalent to xarray.plot.imshow(d) - - """ - return _PlotMethods(self) + plot = utils.UncachedAccessor(_PlotMethods) def _title_for_slice(self, truncate: int = 50) -> str: """ @@ -3829,9 +3819,212 @@ def idxmax( keep_attrs=keep_attrs, ) + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: + """Index or indices of the minimum of the DataArray over one or more dimensions. 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 44773e36e30..0ce76a5e23a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py
@@ -53,7 +53,7 @@ from .formatting import format_item from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, _extract_indexes_from_coords +from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import (
@@ -260,7 +260,7 @@ class DataArray(AbstractArray, DataWithCoords): _resample_cls = resample.DataArrayResample _weighted_cls = weighted.DataArrayWeighted - dt = property(CombinedDatetimelikeAccessor) + dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor) def __init__( self,
@@ -2713,8 +2713,15 @@ def func(self, other): # don't support automatic alignment with in-place arithmetic. other_coords = getattr(other, "coords", None) other_variable = getattr(other, "variable", other) - with self.coords._merge_inplace(other_coords): - f(self.variable, other_variable) + try: + with self.coords._merge_inplace(other_coords): + f(self.variable, other_variable) + except MergeError as exc: + raise MergeError( + "Automatic alignment is not supported for in-place operations.\n" + "Consider aligning the indices manually or using a not-in-place operation.\n" + "See https://github.com/pydata/xarray/issues/3910 for more explanations." + ) from exc return self return func
@@ -2722,24 +2729,7 @@ def func(self, other): def _copy_attrs_from(self, other: Union["DataArray", Dataset, Variable]) -> None: self.attrs = other.attrs - @property - def plot(self) -> _PlotMethods: - """ - Access plotting functions for DataArray's - - >>> d = xr.DataArray([[1, 2], [3, 4]]) - - For convenience just call this directly - - >>> d.plot() - - Or use it as a namespace to use xarray.plot functions as - DataArray methods - - >>> d.plot.imshow() # equivalent to xarray.plot.imshow(d) - - """ - return _PlotMethods(self) + plot = utils.UncachedAccessor(_PlotMethods) def _title_for_slice(self, truncate: int = 50) -> str: """
@@ -3829,9 +3819,212 @@ def idxmax( keep_attrs=keep_attrs, ) + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: + """Index or indices of the minimum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See also + -------- + Variable.argmin, DataArray.idxmin + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.min() + <xarray.DataArray ()> + array(-1) + >>> array.argmin() + <xarray.DataArray ()> + array(2) + >>> array.argmin(...) + {'x': <xarray.DataArray ()> + array(2)} + >>> array.isel(array.argmin(...)) + <xarray.DataArray ()> + array(-1) + + >>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]], + ... [[1, 3, 2], [2, -5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z")) + >>> array.min(dim="x") + <xarray.DataArray (y: 3, z: 3)> + array([[ 1, 2, 1], + [ 2, -5, 1], + [ 2, 1, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim="x") + <xarray.DataArray (y: 3, z: 3)> + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim=["x"]) + {'x': <xarray.DataArray (y: 3, z: 3)> + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z} + >>> array.min(dim=("x", "z")) + <xarray.DataArray (y: 3)> + array([ 1, -5, 1]) + Dimensions without coordinates: y + >>> array.argmin(dim=["x", "z"]) + {'x': <xarray.DataArray (y: 3)> + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)> + array([2, 1, 1]) + Dimensions without coordinates: y} + >>> array.isel(array.argmin(dim=["x", "z"])) + <xarray.DataArray (y: 3)> + array([ 1, -5, 1]) + Dimensions without coordinates: y + """ + result = self.variable.argmin(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) +
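
A quick usage sketch of the method added above: with a list of dimensions, ``argmin`` returns a dict of indices that can be fed straight into ``isel``.

import xarray as xr

array = xr.DataArray(
    [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
    dims=("x", "y", "z"),
)
indices = array.argmin(dim=["x", "z"])  # {'x': <DataArray>, 'z': <DataArray>}
# selecting at those indices recovers the minimum along ("x", "z")
assert array.isel(indices).equals(array.min(dim=("x", "z")))
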
+ def argmax( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: + """Index or indices of the maximum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See also + -------- + Variable.argmax, DataArray.idxmax + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.max() + <xarray.DataArray ()> + array(3) + >>> array.argmax() + <xarray.DataArray ()> + array(3) + >>> array.argmax(...) + {'x': <xarray.DataArray ()> + array(3)} + >>> array.isel(array.argmax(...)) + <xarray.DataArray ()> + array(3) + + >>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]], + ... [[1, 3, 2], [2, 5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z")) + >>> array.max(dim="x") + <xarray.DataArray (y: 3, z: 3)> + array([[3, 3, 2], + [3, 5, 2], + [2, 3, 3]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim="x") + <xarray.DataArray (y: 3, z: 3)> + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim=["x"]) + {'x': <xarray.DataArray (y: 3, z: 3)> + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z} + >>> array.max(dim=("x", "z")) + <xarray.DataArray (y: 3)> + array([3, 5, 3]) + Dimensions without coordinates: y + >>> array.argmax(dim=["x", "z"]) + {'x': <xarray.DataArray (y: 3)> + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)> + array([0, 1, 2]) + Dimensions without coordinates: y} + >>> array.isel(array.argmax(dim=["x", "z"])) + <xarray.DataArray (y: 3)> + array([3, 5, 3]) + Dimensions without coordinates: y + """ + result = self.variable.argmax(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names - str = property(StringAccessor) + str = utils.UncachedAccessor(StringAccessor) # priority must be higher than Variable to properly work with binary ufuncs
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a8011afd3e3..b46b1d6dce0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py
@@ -27,6 +27,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np
@@ -1241,13 +1242,25 @@ def loc(self) -> _LocIndexer: """ return _LocIndexer(self) - def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore + ... + + @overload + def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore + ... + + @overload + def __getitem__(self, key: Any) -> "Dataset": + ... + + def __getitem__(self, key): """Access variables or coordinates of this dataset as a :py:class:`~xarray.DataArray`. Indexing with a list of names will return a new ``Dataset`` object.
""" - # TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328 if utils.is_dict_like(key): return self.isel(**cast(Mapping, key)) @@ -5563,16 +5576,7 @@ def real(self): def imag(self): return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) - @property - def plot(self): - """ - Access plotting functions for Datasets. - Use it as a namespace to use xarray.plot functions as Dataset methods - - >>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...) - - """ - return _Dataset_PlotMethods(self) + plot = utils.UncachedAccessor(_Dataset_PlotMethods) def filter_by_attrs(self, **kwargs): """Returns a ``Dataset`` with variables that match specific conditions. @@ -6364,5 +6368,131 @@ def idxmax( ) ) + def argmin(self, dim=None, axis=None, **kwargs): + """Indices of the minima of the member variables. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmin will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmin + + """ + if dim is None and axis is None: + warnings.warn( + "Once the behaviour of DataArray.argmin() and Variable.argmin() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmin() with no argument, since we don't return a dict of " + "Datasets.", + DeprecationWarning, + stacklevel=2, + ) + if ( + dim is None + or axis is not None + or (not isinstance(dim, Sequence) and dim is not ...) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + argmin_func = getattr(duck_array_ops, "argmin") + return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence or ..., DataArray.argmin() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmin() with a sequence or ... for dim" + ) + + def argmax(self, dim=None, axis=None, **kwargs): + """Indices of the maxima of the member variables. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmax will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. 
+ axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmax + + """ + if dim is None and axis is None: + warnings.warn( + "Once the behaviour of DataArray.argmax() and Variable.argmax() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmax() with no argument, since we don't return a dict of " + "Datasets.", + DeprecationWarning, + stacklevel=2, + ) + if ( + dim is None + or axis is not None + or (not isinstance(dim, Sequence) and dim is not ...) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + argmax_func = getattr(duck_array_ops, "argmax") + return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence or ..., DataArray.argmax() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmax() with a sequence or ... for dim" + ) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 1340b456cf2..df579d23544 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py
@@ -6,6 +6,7 @@ import contextlib import inspect import warnings +from distutils.version import LooseVersion from functools import partial import numpy as np
@@ -20,6 +21,14 @@ except ImportError: dask_array = None # type: ignore +# TODO: remove after we stop supporting dask < 2.9.1 +try: + import dask + + dask_version = dask.__version__ +except ImportError: + dask_version = None + def _dask_or_eager_func( name,
@@ -199,8 +208,19 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): """ arr1 = asarray(arr1) arr2 = asarray(arr2) + lazy_equiv = lazy_array_equiv(arr1, arr2) if lazy_equiv is None: + # TODO: remove after we require dask >= 2.9.1 + sufficient_dask_version = ( + dask_version is not None and LooseVersion(dask_version) >= "2.9.1" + ) + if not sufficient_dask_version and any( + isinstance(arr, dask_array_type) for arr in [arr1, arr2] + ): + arr1 = np.array(arr1) + arr2 = np.array(arr2) + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) else: return lazy_equiv
@@ -339,6 +359,7 @@ def f(values, axis=None, skipna=None, **kwargs): cumprod_1d.numeric_only = True cumsum_1d = _create_nan_agg_method("cumsum") cumsum_1d.numeric_only = True +unravel_index = _dask_or_eager_func("unravel_index") _mean = _create_nan_agg_method("mean")
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index d6732fc182e..28eaae5f05b 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py
@@ -3,7 +3,7 @@ import contextlib import functools from datetime import datetime, timedelta -from itertools import zip_longest +from itertools import chain, zip_longest from typing import Hashable import numpy as np @@
-140,7 +140,7 @@ def format_item(x, timedelta_format=None, quote_strings=True): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): return repr(x) if quote_strings else x - elif isinstance(x, (float, np.float)): + elif isinstance(x, (float, np.float_)): return f"{x:.4}" else: return str(x)
@@ -422,6 +422,17 @@ def set_numpy_options(*args, **kwargs): np.set_printoptions(**original) +def limit_lines(string: str, *, limit: int): + """ + If the string is more lines than the limit, + this returns the middle lines replaced by an ellipsis + """ + lines = string.splitlines() + if len(lines) > limit: + string = "\n".join(chain(lines[: limit // 2], ["..."], lines[-limit // 2 :])) + return string + + def short_numpy_repr(array): array = np.asarray(array)
@@ -447,7 +458,7 @@ def short_data_repr(array): elif hasattr(internal_data, "__array_function__") or isinstance( internal_data, dask_array_type ): - return repr(array.data) + return limit_lines(repr(array.data), limit=40) elif array._in_memory or array.size < 1e5: return short_numpy_repr(array) else:
@@ -539,7 +550,10 @@ def extra_items_repr(extra_keys, mapping, ab_side): for k in a_keys & b_keys: try: # compare xarray variable - compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + if not callable(compat): + compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + else: + compatible = compat(a_mapping[k], b_mapping[k]) is_variable = True except AttributeError: # compare attribute value
@@ -596,8 +610,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): def _compat_to_str(compat): + if callable(compat): + compat = compat.__name__ + if compat == "equals": return "equal" + elif compat == "allclose": + return "close" else: return compat
@@ -611,8 +630,12 @@ def diff_array_repr(a, b, compat): ] summary.append(diff_dim_summary(a, b)) + if callable(compat): + equiv = compat + else: + equiv = array_equiv - if not array_equiv(a.data, b.data): + if not equiv(a.data, b.data): temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)] diff_data_repr = [ ab_side + "\n" + ab_data_repr
diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 69832d6ca3d..400ef61502e 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py
@@ -20,7 +20,9 @@ def short_data_repr_html(array): internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): return internal_data._repr_html_() - return escape(short_data_repr(array)) + else: + text = escape(short_data_repr(array)) + return f"<pre>{text}</pre>" def format_dims(dims, coord_names):
@@ -123,7 +125,7 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): f"<label for='{data_id}' title='Show/Hide data repr'>{data_icon}</label>" f"<div class='xr-var-attrs'>{attrs_ul}</div>" - f"<pre class='xr-var-data'>{data_repr}</pre>" + f"<div class='xr-var-data'>{data_repr}</div>" )
@@ -182,7 +184,7 @@ def dim_section(obj): def array_section(obj): # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) - collapsed = "" + collapsed = "checked" variable = getattr(obj, "variable", obj) preview = escape(inline_variable_array_repr(variable, max_width=70)) data_repr = short_data_repr_html(obj)
@@ -193,7 +195,7 @@ def array_section(obj): f"<input id='{data_id}' class='xr-array-in' type='checkbox' {collapsed}>" f"<label for='{data_id}' title='Show/hide data repr'>{data_icon}</label>" f"<div class='xr-array-preview xr-preview'><span>{preview}</span></div>" - f"<pre class='xr-array-data'>{data_repr}</pre>" + f"<div class='xr-array-data'>{data_repr}</div>" "</div>" )
diff --git a/xarray/core/ops.py b/xarray/core/ops.py index b789f93b4f1..d4aeea37aad 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py
@@ -47,8 +47,6 @@ # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ - "argmax", - "argmin", "max", "min", "mean",
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 3a77753d0d1..86044e72dd2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py
@@ -252,7 +252,10 @@ def map_blocks( to the function being applied in ``xr.map_blocks()``: >>> xr.map_blocks( - ... calculate_anomaly, array, kwargs={"groupby_type": "time.year"}, template=array, + ... calculate_anomaly, + ... array, + ... kwargs={"groupby_type": "time.year"}, + ... template=array, ... ) <xarray.DataArray (time: 24)> array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 ,
diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 1126cf3037f..0542f850b02 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py
@@ -787,6 +787,24 @@ def drop_dims_from_indexers( ) +class UncachedAccessor: + """ Acts like a property, but on both classes and class instances + + This class is necessary because some tools (e.g. pydoc and sphinx) + inspect classes for which property returns itself and not the + accessor. + """ + + def __init__(self, accessor): + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + return self._accessor + + return self._accessor(obj) + + # Singleton type, as per https://github.com/python/typing/pull/240 class Default(Enum): token = 0
diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e19132b1b06..c505c749557 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py
@@ -6,7 +6,17 @@ from collections import defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Hashable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np import pandas as pd
@@ -2069,6 +2079,166 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): ) return type(self)(self.dims, numeric_array, self._attrs) + def _unravel_argminmax( + self, + argminmax: str, + dim: Union[Hashable, Sequence[Hashable], None], + axis: Union[int, None], + keep_attrs: Optional[bool], + skipna: Optional[bool], + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Apply argmin or argmax over one or more dimensions, returning the result as a + dict of DataArray that can be passed directly to isel. + """ + if dim is None and axis is None: + warnings.warn( + "Behaviour of argmin/argmax with neither dim nor axis argument will " + "change to return a dict of indices of each dimension.
To get a " + "single, flat index, please use np.argmin(da.data) or " + "np.argmax(da.data) instead of da.argmin() or da.argmax().", + DeprecationWarning, + stacklevel=3, + ) + + argminmax_func = getattr(duck_array_ops, argminmax) + + if dim is ...: + # In future, should do this also when (dim is None and axis is None) + dim = self.dims + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return self.reduce( + argminmax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + ) + + # Get a name for the new dimension that does not conflict with any existing + # dimension + newdimname = "_unravel_argminmax_dim_0" + count = 1 + while newdimname in self.dims: + newdimname = "_unravel_argminmax_dim_{}".format(count) + count += 1 + + stacked = self.stack({newdimname: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(self.sizes[d] for d in dim) + + result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna) + + result_unravelled_indices = duck_array_ops.unravel_index( + result_flat_indices.data, reduce_shape + ) + + result = { + d: Variable(dims=result_dims, data=i) + for d, i in zip(dim, result_unravelled_indices) + } + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for v in result.values(): + v.attrs = self.attrs + + return result + + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Index or indices of the minimum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See also + -------- + DataArray.argmin, DataArray.idxmin + """ + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + + def argmax( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Index or indices of the maximum of the Variable over one or more dimensions. 
+ If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See also + -------- + DataArray.argmax, DataArray.idxmax + """ + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + ops.inject_all_ops_and_reduce_methods(Variable)
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 21ed06ea85f..fa143342c06 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py
@@ -72,11 +72,11 @@ class Weighted: def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... - @overload # noqa: F811 - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811 + @overload + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: ...
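
How the ``utils.UncachedAccessor`` introduced in this diff behaves for ``plot``, ``str`` and ``dt`` (this mirrors the new ``test_accessor`` tests further down):

import xarray as xr
from xarray.plot.plot import _PlotMethods

assert xr.DataArray.plot is _PlotMethods   # on the class: the accessor type itself
da = xr.DataArray([[1, 2], [3, 4]])
assert isinstance(da.plot, _PlotMethods)   # on an instance: a fresh accessor object
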
- def __init__(self, obj, weights): # noqa: F811 + def __init__(self, obj, weights): """ Create a Weighted object diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 0a4ca305306..be79f0ab04c 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -62,7 +62,7 @@ def _infer_line_data(darray, x, y, hue): else: if x is None and y is None and hue is None: - raise ValueError("For 2D inputs, please" "specify either hue, x or y.") + raise ValueError("For 2D inputs, please specify either hue, x or y.") if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) @@ -445,6 +445,11 @@ def __init__(self, darray): def __call__(self, **kwargs): return plot(self._da, **kwargs) + # we can't use functools.wraps here since that also modifies the name / qualname + __doc__ = __call__.__doc__ = plot.__doc__ + __call__.__wrapped__ = plot # type: ignore + __call__.__annotations__ = plot.__annotations__ + @functools.wraps(hist) def hist(self, ax=None, **kwargs): return hist(self._da, ax=ax, **kwargs) diff --git a/xarray/testing.py b/xarray/testing.py index e7bf5f9221a..9681503414e 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,10 +1,11 @@ """Testing functions exposed to the user API""" +import functools from typing import Hashable, Set, Union import numpy as np import pandas as pd -from xarray.core import duck_array_ops, formatting +from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import default_indexes @@ -118,27 +119,31 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """ __tracebackhide__ = True assert type(a) == type(b) - kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes) + + equiv = functools.partial( + _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes + ) + equiv.__name__ = "allclose" + + def compat_variable(a, b): + a = getattr(a, "variable", a) + b = getattr(b, "variable", b) + + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + if isinstance(a, Variable): - assert a.dims == b.dims - allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs) - assert allclose, f"{a.values}\n{b.values}" + allclose = compat_variable(a, b) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, DataArray): - assert_allclose(a.variable, b.variable, **kwargs) - assert set(a.coords) == set(b.coords) - for v in a.coords.variables: - # can't recurse with this function as coord is sometimes a - # DataArray, so call into _data_allclose_or_equiv directly - allclose = _data_allclose_or_equiv( - a.coords[v].values, b.coords[v].values, **kwargs - ) - assert allclose, "{}\n{}".format(a.coords[v].values, b.coords[v].values) + allclose = utils.dict_equiv( + a.coords, b.coords, compat=compat_variable + ) and compat_variable(a.variable, b.variable) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, Dataset): - assert set(a.data_vars) == set(b.data_vars) - assert set(a.coords) == set(b.coords) - for k in list(a.variables) + list(a.coords): - assert_allclose(a[k], b[k], **kwargs) - + allclose = a._coord_names == b._coord_names and utils.dict_equiv( + a.variables, b.variables, compat=compat_variable + ) + assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv) else: raise TypeError("{} not supported by assertion comparison".format(type(a))) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 
3642c1eb9b7..6a840e6303e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py
@@ -885,7 +885,7 @@ def test_roundtrip_endian(self): "x": np.arange(3, 10, dtype=">i2"), "y": np.arange(3, 20, dtype="<i4"), "z": np.arange(3, 30, dtype="=i8"), - "w": ("x", np.arange(3, 10, dtype=np.float)), + "w": ("x", np.arange(3, 10, dtype=float)),
diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py
@@ -47,7 +47,7 @@ def test_short_data_repr_html(dataarray): data_repr = fh.short_data_repr_html(dataarray) - assert data_repr.startswith("array") + assert data_repr.startswith("<pre>array") def test_short_data_repr_html_non_str_keys(dataset):
@@ -108,8 +108,8 @@ def test_summarize_attrs_with_unsafe_attr_name_and_value(): def test_repr_of_dataarray(dataarray): formatted = fh.array_repr(dataarray) assert "dim_0" in formatted - # has an expandable data section - assert formatted.count("class='xr-array-in' type='checkbox' >") == 1 + # has an expanded data section + assert formatted.count("class='xr-array-in' type='checkbox' checked>") == 1 # coords and attrs don't have any items so they'll be disabled and collapsed assert ( formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") == 2
diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4dc33557d3a..610730e9eb2 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py
@@ -118,6 +118,12 @@ class TestPlot(PlotTestCase): def setup_array(self): self.darray = DataArray(easy_array((2, 3, 4))) + def test_accessor(self): + from ..plot.plot import _PlotMethods + + assert DataArray.plot is _PlotMethods + assert isinstance(self.darray.plot, _PlotMethods) + def test_label_from_attrs(self): da = self.darray.copy() assert "" == label_from_attrs(da)
@@ -150,7 +156,7 @@ def test1d(self): (self.darray[:, 0, 0] + 1j).plot() def test_1d_bool(self): - xr.ones_like(self.darray[:, 0, 0], dtype=np.bool).plot() + xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot() def test_1d_x_y_kw(self): z = np.arange(10)
@@ -1038,7 +1044,7 @@ def test_1d_raises_valueerror(self): self.plotfunc(self.darray[0, :]) def test_bool(self): - xr.ones_like(self.darray, dtype=np.bool).plot() + xr.ones_like(self.darray, dtype=bool).plot() def test_complex_raises_typeerror(self): with raises_regex(TypeError, "complex128"):
@@ -2105,6 +2111,12 @@ def setUp(self): ds.B.attrs["units"] = "Bunits" self.ds = ds + def test_accessor(self): + from ..plot.dataset_plot import _Dataset_PlotMethods + + assert Dataset.plot is _Dataset_PlotMethods + assert isinstance(self.ds.plot, _Dataset_PlotMethods) + @pytest.mark.parametrize( "add_guide, hue_style, legend, colorbar", [
diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 041b7341ade..f4961af58e9 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py
@@ -1,3 +1,5 @@ +import pytest + import xarray as xr
@@ -5,3 +7,26 @@ def test_allclose_regression(): x = xr.DataArray(1.01) y = xr.DataArray(1.02) xr.testing.assert_allclose(x, y, atol=0.01) + + +@pytest.mark.parametrize( + "obj1,obj2", + ( + pytest.param( + xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable", + ), + pytest.param( + xr.DataArray([1e-17, 2], dims="x"), + xr.DataArray([0, 3], dims="x"), + id="DataArray", + ), + pytest.param( + xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}), + xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}), + id="Dataset", + ), + ), +) +def test_assert_allclose(obj1, obj2): + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2)
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 5dd4a42cff0..20a5f0e8613 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py
@@ -7,9 +7,8 @@ import pytest import xarray as xr -from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE -from xarray.testing import assert_allclose,
assert_identical +from xarray.testing import assert_allclose, assert_equal, assert_identical from .test_variable import _PAD_XR_NP_ARGS, VariableSubclassobjects @@ -27,11 +26,6 @@ pytest.mark.skipif( not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled" ), - # TODO: remove this once pint has a released version with __array_function__ - pytest.mark.skipif( - not hasattr(unit_registry.Quantity, "__array_function__"), - reason="pint does not implement __array_function__ yet", - ), # pytest.mark.filterwarnings("ignore:::pint[.*]"), ] @@ -51,10 +45,23 @@ def dimensionality(obj): def compatible_mappings(first, second): return { key: is_compatible(unit1, unit2) - for key, (unit1, unit2) in merge_mappings(first, second) + for key, (unit1, unit2) in zip_mappings(first, second) } +def merge_mappings(base, *mappings): + result = base.copy() + for m in mappings: + result.update(m) + + return result + + +def zip_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -257,50 +264,11 @@ def assert_units_equal(a, b): assert extract_units(a) == extract_units(b) -def assert_equal_with_units(a, b): - # works like xr.testing.assert_equal, but also explicitly checks units - # so, it is more like assert_identical - __tracebackhide__ = True - - if isinstance(a, xr.Dataset) or isinstance(b, xr.Dataset): - a_units = extract_units(a) - b_units = extract_units(b) - - a_without_units = strip_units(a) - b_without_units = strip_units(b) - - assert a_without_units.equals(b_without_units), formatting.diff_dataset_repr( - a, b, "equals" - ) - assert a_units == b_units - else: - a = a if not isinstance(a, (xr.DataArray, xr.Variable)) else a.data - b = b if not isinstance(b, (xr.DataArray, xr.Variable)) else b.data - - assert type(a) == type(b) or ( - isinstance(a, Quantity) and isinstance(b, Quantity) - ) - - # workaround until pint implements allclose in __array_function__ - if isinstance(a, Quantity) or isinstance(b, Quantity): - assert ( - hasattr(a, "magnitude") and hasattr(b, "magnitude") - ) and np.allclose(a.magnitude, b.magnitude, equal_nan=True) - assert (hasattr(a, "units") and hasattr(b, "units")) and a.units == b.units - else: - assert np.allclose(a, b, equal_nan=True) - - @pytest.fixture(params=[float, int]) def dtype(request): return request.param -def merge_mappings(*mappings): - for key in set(mappings[0]).intersection(*mappings[1:]): - yield key, tuple(m[key] for m in mappings) - - def merge_args(default_args, new_args): from itertools import zip_longest @@ -329,19 +297,29 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + xarray_classes = ( + xr.Variable, + xr.DataArray, + xr.Dataset, + xr.core.groupby.GroupBy, + ) + + if not isinstance(obj, xarray_classes): + # remove typical xarray args like "dim" + exclude_kwargs = ("dim", "dims") + all_kwargs = { + key: value + for key, value in all_kwargs.items() + if key not in exclude_kwargs + } + func = getattr(obj, self.name, None) + if func is None or not isinstance(func, Callable): # fall back to module level numpy functions if not a xarray object if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xarray args like "dim" - exclude_kwargs = ("dim", "dims") - all_kwargs = { - key: value - 
for key, value in all_kwargs.items() - if key not in exclude_kwargs - } else: raise AttributeError(f"{obj} has no method named '{self.name}'") @@ -425,6 +403,10 @@ def test_apply_ufunc_dataset(dtype): assert_identical(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -512,6 +494,10 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): assert_allclose(expected_b, actual_b) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -929,6 +915,10 @@ def test_concat_dataset(variant, unit, error, dtype): assert_identical(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -1036,6 +1026,10 @@ def test_merge_dataarray(variant, unit, error, dtype): assert_allclose(expected, actual) +# TODO: remove once pint==0.12 has been released +@pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" +) @pytest.mark.parametrize( "unit,error", ( @@ -1385,7 +1379,6 @@ def wrapper(cls): "test_datetime64_conversion", "test_timedelta64_conversion", "test_pandas_period_index", - "test_1d_math", "test_1d_reduce", "test_array_interface", "test___array__", @@ -1413,13 +1406,20 @@ def example_1d_objects(self): ]: yield (self.cls("x", data), data) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) + def test_real_and_imag(self): + super().test_real_and_imag() + @pytest.mark.parametrize( "func", ( method("all"), method("any"), - method("argmax"), - method("argmin"), + method("argmax", dim="x"), + method("argmin", dim="x"), method("argsort"), method("cumprod"), method("cumsum"), @@ -1443,13 +1443,33 @@ def test_aggregation(self, func, dtype): ) variable = xr.Variable("x", array) - units = extract_units(func(array)) + numpy_kwargs = func.kwargs.copy() + if "dim" in func.kwargs: + numpy_kwargs["axis"] = variable.get_axis_num(numpy_kwargs.pop("dim")) + + units = extract_units(func(array, **numpy_kwargs)) expected = attach_units(func(strip_units(variable)), units) actual = func(variable) assert_units_equal(expected, actual) - xr.testing.assert_identical(expected, actual) + assert_allclose(expected, actual) + + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) + def test_aggregate_complex(self): + variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m) + expected = xr.Variable((), (0.5 + 1j) * unit_registry.m) + actual = variable.mean() + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "func", ( @@ -1748,6 +1768,10 @@ def test_isel(self, indices, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "unit,error", ( 
@@ -1886,7 +1910,7 @@ def test_squeeze(self, dtype): pytest.param( method("quantile", q=[0.25, 0.75]), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -2224,18 +2248,34 @@ def test_repr(self, func, variant, dtype): # warnings or errors, but does not check the result func(data_array) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose", + ) @pytest.mark.parametrize( "func", ( function("all"), function("any"), - function("argmax"), - function("argmin"), + pytest.param( + function("argmax"), + marks=pytest.mark.skip( + reason="calling np.argmax as a function on xarray objects is not " + "supported" + ), + ), + pytest.param( + function("argmin"), + marks=pytest.mark.skip( + reason="calling np.argmin as a function on xarray objects is not " + "supported" + ), + ), function("max"), function("mean"), pytest.param( function("median"), - marks=pytest.mark.xfail( + marks=pytest.mark.skip( reason="median does not work with dataarrays yet" ), ), @@ -2251,8 +2291,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax"), - method("argmin"), + method("argmax", dim="x"), + method("argmin", dim="x"), method("max"), method("mean"), method("median"), @@ -2275,6 +2315,10 @@ def test_aggregation(self, func, dtype): ) data_array = xr.DataArray(data=array, dims="x") + numpy_kwargs = func.kwargs.copy() + if "dim" in numpy_kwargs: + numpy_kwargs["axis"] = data_array.get_axis_num(numpy_kwargs.pop("dim")) + # units differ based on the applied function, so we need to # first compute the units units = extract_units(func(array)) @@ -2282,7 +2326,7 @@ def test_aggregation(self, func, dtype): actual = func(data_array) assert_units_equal(expected, actual) - xr.testing.assert_allclose(expected, actual) + assert_allclose(expected, actual) @pytest.mark.parametrize( "func", @@ -3283,6 +3327,10 @@ def test_head_tail_thin(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( "func", @@ -3356,6 +3404,10 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( "func", @@ -3523,7 +3575,7 @@ def test_stacking_reordering(self, func, dtype): pytest.param( method("quantile", q=[0.25, 0.75]), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -3558,6 +3610,10 @@ def test_computation(self, func, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) + # TODO: remove once pint==0.12 has been released + @pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", reason="pint bug in isclose" + ) @pytest.mark.parametrize( "func", ( @@ -3572,7 +3628,9 @@ def test_computation(self, func, dtype): ), pytest.param( 
method("rolling_exp", y=3), - marks=pytest.mark.xfail(reason="units not supported by numbagg"), + marks=pytest.mark.xfail( + reason="numbagg functions are not supported by pint" + ), ), ), ids=repr, @@ -3618,7 +3676,7 @@ def test_resample(self, dtype): pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), marks=pytest.mark.xfail( - LooseVersion(pint.__version__) < "0.12", + LooseVersion(pint.__version__) <= "0.12", reason="quantile / nanquantile not implemented yet", ), ), @@ -3653,15 +3711,16 @@ def test_grouped_operations(self, func, dtype): xr.testing.assert_identical(expected, actual) +@pytest.mark.filterwarnings("error::pint.UnitStrippedWarning") class TestDataset: @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param(1, xr.MergeError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + unit_registry.dimensionless, xr.MergeError, id="dimensionless" ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"), pytest.param(unit_registry.mm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="same_unit"), ), @@ -3670,11 +3729,10 @@ class TestDataset: "shared", ( "nothing", - pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), pytest.param( - "coords", - marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), + "coords", ), ) def test_init(self, shared, unit, error, dtype): @@ -3682,60 +3740,53 @@ def test_init(self, shared, unit, error, dtype): scaled_unit = unit_registry.mm a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa - b = np.linspace(-1, 0, 12).astype(dtype) * unit_registry.Pa - - raw_x = np.arange(a.shape[0]) - x = raw_x * original_unit - x2 = x.to(scaled_unit) - - raw_y = np.arange(b.shape[0]) - y = raw_y * unit - y_units = unit if isinstance(y, unit_registry.Quantity) else None - if isinstance(y, unit_registry.Quantity): - if y.check(scaled_unit): - y2 = y.to(scaled_unit) - else: - y2 = y * 1000 - y2_units = y2.units - else: - y2 = y * 1000 - y2_units = None + b = np.linspace(-1, 0, 10).astype(dtype) * unit_registry.degK + + values_a = np.arange(a.shape[0]) + dim_a = values_a * original_unit + coord_a = dim_a.to(scaled_unit) + + values_b = np.arange(b.shape[0]) + dim_b = values_b * unit + coord_b = ( + dim_b.to(scaled_unit) + if unit_registry.is_compatible_with(dim_b, scaled_unit) + and unit != scaled_unit + else dim_b * 1000 + ) variants = { - "nothing": ({"x": x, "x2": ("x", x2)}, {"y": y, "y2": ("y", y2)}), - "dims": ( - {"x": x, "x2": ("x", strip_units(x2))}, - {"x": y, "y2": ("x", strip_units(y2))}, + "nothing": ({}, {}), + "dims": ({"x": dim_a}, {"x": dim_b}), + "coords": ( + {"x": values_a, "y": ("x", coord_a)}, + {"x": values_b, "y": ("x", coord_b)}, ), - "coords": ({"x": raw_x, "y": ("x", x2)}, {"x": raw_y, "y": ("x", y2)}), } coords_a, coords_b = variants.get(shared) dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x") - arr1 = xr.DataArray(data=a, coords=coords_a, dims=dims_a) - arr2 = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + a = xr.DataArray(data=a, coords=coords_a, dims=dims_a) + b = xr.DataArray(data=b, coords=coords_b, dims=dims_b) + if error is not None and shared != "nothing": with pytest.raises(error): - xr.Dataset(data_vars={"a": arr1, "b": arr2}) + xr.Dataset(data_vars={"a": a, "b": b}) 
return - actual = xr.Dataset(data_vars={"a": arr1, "b": arr2}) + actual = xr.Dataset(data_vars={"a": a, "b": b}) - expected_units = { - "a": a.units, - "b": b.units, - "x": x.units, - "x2": x2.units, - "y": y_units, - "y2": y2_units, - } + units = merge_mappings( + extract_units(a.rename("a")), extract_units(b.rename("b")) + ) expected = attach_units( - xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}), - expected_units, + xr.Dataset(data_vars={"a": strip_units(a), "b": strip_units(b)}), units ) - assert_equal_with_units(actual, expected) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) @@ -3743,79 +3794,79 @@ def test_init(self, shared, unit, error, dtype): @pytest.mark.parametrize( "variant", ( + "data", pytest.param( - "with_dims", + "dims", marks=pytest.mark.xfail(reason="units in indexes are not supported"), ), - pytest.param("with_coords"), - pytest.param("without_coords"), + "coords", ), ) - @pytest.mark.filterwarnings("error:::pint[.*]") def test_repr(self, func, variant, dtype): - array1 = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.Pa - array2 = np.linspace(0, 1, 10, dtype=dtype) * unit_registry.degK + unit1, unit2 = ( + (unit_registry.Pa, unit_registry.degK) if variant == "data" else (1, 1) + ) + + array1 = np.linspace(1, 2, 10, dtype=dtype) * unit1 + array2 = np.linspace(0, 1, 10, dtype=dtype) * unit2 x = np.arange(len(array1)) * unit_registry.s y = x.to(unit_registry.ms) variants = { - "with_dims": {"x": x}, - "with_coords": {"y": ("x", y)}, - "without_coords": {}, + "dims": {"x": x}, + "coords": {"y": ("x", y)}, + "data": {}, } - data_array = xr.Dataset( + ds = xr.Dataset( data_vars={"a": ("x", array1), "b": ("x", array2)}, coords=variants.get(variant), ) # FIXME: this just checks that the repr does not raise # warnings or errors, but does not check the result - func(data_array) + func(ds) @pytest.mark.parametrize( "func", ( + function("all"), + function("any"), pytest.param( - function("all"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + function("argmax"), + marks=pytest.mark.skip( + reason="calling np.argmax as a function on xarray objects is not " + "supported" + ), ), pytest.param( - function("any"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + function("argmin"), + marks=pytest.mark.skip( + reason="calling np.argmin as a function on xarray objects is not " + "supported" + ), ), - function("argmax"), - function("argmin"), function("max"), function("min"), function("mean"), pytest.param( function("median"), - marks=pytest.mark.xfail( - reason="np.median does not work with dataset yet" - ), + marks=pytest.mark.xfail(reason="median does not work with dataset yet"), ), function("sum"), pytest.param( function("prod"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + marks=pytest.mark.xfail(reason="prod does not work with dataset yet"), ), function("std"), function("var"), function("cumsum"), - pytest.param( - function("cumprod"), - marks=pytest.mark.xfail(reason="fails within xarray"), - ), - pytest.param( - method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), - pytest.param( - method("any"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), - method("argmax"), - method("argmin"), + function("cumprod"), + method("all"), + method("any"), + method("argmax", dim="x"), + method("argmin", dim="x"), method("max"), method("min"), method("mean"), @@ 
-3823,68 +3874,64 @@ def test_repr(self, func, variant, dtype): method("sum"), pytest.param( method("prod"), - marks=pytest.mark.xfail(reason="not implemented by pint"), + marks=pytest.mark.xfail(reason="prod does not work with dataset yet"), ), method("std"), method("var"), method("cumsum"), - pytest.param( - method("cumprod"), marks=pytest.mark.xfail(reason="fails within xarray") - ), + method("cumprod"), ), ids=repr, ) def test_aggregation(self, func, dtype): - unit_a = ( - unit_registry.Pa if func.name != "cumprod" else unit_registry.dimensionless - ) - unit_b = ( - unit_registry.kg / unit_registry.m ** 3 + unit_a, unit_b = ( + (unit_registry.Pa, unit_registry.degK) if func.name != "cumprod" - else unit_registry.dimensionless - ) - a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x") - b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x") - x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x") - y = xr.DataArray( - data=np.arange(10, 20).astype(dtype) * unit_registry.s, dims="x" + else (unit_registry.dimensionless, unit_registry.dimensionless) ) - ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y}) + a = np.linspace(0, 1, 10).astype(dtype) * unit_a + b = np.linspace(-1, 0, 10).astype(dtype) * unit_b + + ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + + if "dim" in func.kwargs: + numpy_kwargs = func.kwargs.copy() + dim = numpy_kwargs.pop("dim") + + axis_a = ds.a.get_axis_num(dim) + axis_b = ds.b.get_axis_num(dim) + + numpy_kwargs_a = numpy_kwargs.copy() + numpy_kwargs_a["axis"] = axis_a + numpy_kwargs_b = numpy_kwargs.copy() + numpy_kwargs_b["axis"] = axis_b + else: + numpy_kwargs_a = {} + numpy_kwargs_b = {} + + units_a = array_extract_units(func(a, **numpy_kwargs_a)) + units_b = array_extract_units(func(b, **numpy_kwargs_b)) + units = {"a": units_a, "b": units_b} actual = func(ds) - expected = attach_units( - func(strip_units(ds)), - { - "a": extract_units(func(a)).get(None), - "b": extract_units(func(b)).get(None), - }, - ) + expected = attach_units(func(strip_units(ds)), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_allclose(expected, actual) @pytest.mark.parametrize("property", ("imag", "real")) def test_numpy_properties(self, property, dtype): - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray( - data=np.linspace(0, 1, 10) * unit_registry.Pa, dims="x" - ), - "b": xr.DataArray( - data=np.linspace(-1, 0, 15) * unit_registry.Pa, dims="y" - ), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) + a = np.linspace(0, 1, 10) * unit_registry.Pa + b = np.linspace(-1, 0, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) units = extract_units(ds) actual = getattr(ds, property) expected = attach_units(getattr(strip_units(ds), property), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -3898,31 +3945,19 @@ def test_numpy_properties(self, property, dtype): ids=repr, ) def test_numpy_methods(self, func, dtype): - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray( - data=np.linspace(1, -1, 10) * unit_registry.Pa, dims="x" - ), - "b": xr.DataArray( - data=np.linspace(-1, 1, 15) * unit_registry.Pa, dims="y" - ), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) - units = { - "a": array_extract_units(func(ds.a)), 
- "b": array_extract_units(func(ds.b)), - "x": unit_registry.m, - "y": unit_registry.s, - } + a = np.linspace(1, -1, 10) * unit_registry.Pa + b = np.linspace(-1, 1, 15) * unit_registry.degK + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) + + units_a = array_extract_units(func(a)) + units_b = array_extract_units(func(b)) + units = {"a": units_a, "b": units_b} actual = func(ds) expected = attach_units(func(strip_units(ds)), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr) @pytest.mark.parametrize( @@ -3939,21 +3974,13 @@ def test_numpy_methods(self, func, dtype): ) def test_numpy_methods_with_args(self, func, unit, error, dtype): data_unit = unit_registry.m - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=np.arange(10) * data_unit, dims="x"), - "b": xr.DataArray(data=np.arange(15) * data_unit, dims="y"), - }, - coords={ - "x": np.arange(10) * unit_registry.m, - "y": np.arange(15) * unit_registry.s, - }, - ) + a = np.linspace(0, 10, 15) * unit_registry.m + b = np.linspace(-2, 12, 20) * unit_registry.m + ds = xr.Dataset({"a": ("x", a), "b": ("y", b)}) units = extract_units(ds) kwargs = { - key: (value * unit if isinstance(value, (int, float)) else value) - for key, value in func.kwargs.items() + key: array_attach_units(value, unit) for key, value in func.kwargs.items() } if error is not None: @@ -3970,7 +3997,8 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype): actual = func(ds, **kwargs) expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", (method("isnull"), method("notnull"), method("count")), ids=repr @@ -4000,22 +4028,13 @@ def test_missing_value_detection(self, func, dtype): * unit_registry.Pa ) - x = np.arange(array1.shape[0]) * unit_registry.m - y = np.arange(array1.shape[1]) * unit_registry.m - z = np.arange(array2.shape[0]) * unit_registry.m - - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("z", "x")), - }, - coords={"x": x, "y": y, "z": z}, - ) + ds = xr.Dataset({"a": (("x", "y"), array1), "b": (("z", "x"), array2)}) expected = func(strip_units(ds)) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="ffill and bfill lose the unit") @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) @@ -4029,23 +4048,14 @@ def test_missing_value_filling(self, func, dtype): * unit_registry.Pa ) - x = np.arange(len(array1)) - - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("y", array2)}) + units = extract_units(ds) - expected = attach_units( - func(strip_units(ds), dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(func(strip_units(ds), dim="x"), units) actual = func(ds, dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", @@ -4055,14 +4065,7 @@ def test_missing_value_filling(self, func, dtype): unit_registry.dimensionless, DimensionalityError, 
id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param( - unit_registry.cm, - None, - id="compatible_unit", - marks=pytest.mark.xfail( - reason="where converts the array, not the fill value" - ), - ), + pytest.param(unit_registry.cm, None, id="compatible_unit",), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4083,30 +4086,26 @@ def test_fillna(self, fill_value, unit, error, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.m ) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - } - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + value = fill_value * unit + units = extract_units(ds) if error is not None: with pytest.raises(error): - ds.fillna(value=fill_value * unit) + ds.fillna(value=value) return - actual = ds.fillna(value=fill_value * unit) + actual = ds.fillna(value=value) expected = attach_units( strip_units(ds).fillna( - value=strip_units( - convert_units(fill_value * unit, {None: unit_registry.m}) - ) + value=strip_units(convert_units(value, {None: unit_registry.m})) ), - {"a": unit_registry.m, "b": unit_registry.m}, + units, ) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) def test_dropna(self, dtype): array1 = ( @@ -4117,22 +4116,14 @@ def test_dropna(self, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.Pa ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).dropna(dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(strip_units(ds).dropna(dim="x"), units) actual = ds.dropna(dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4153,34 +4144,28 @@ def test_isin(self, unit, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.m ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) values = raw_values * unit - if ( - isinstance(values, unit_registry.Quantity) - and values.check(unit_registry.m) - and unit != unit_registry.m - ): - raw_values = values.to(unit_registry.m).magnitude + converted_values = ( + convert_units(values, {None: unit_registry.m}) + if is_compatible(unit, unit_registry.m) + else values + ) - expected = strip_units(ds).isin(raw_values) - if not isinstance(values, unit_registry.Quantity) or not values.check( - unit_registry.m - ): + expected = strip_units(ds).isin(strip_units(converted_values)) + # TODO: use `unit_registry.is_compatible_with(unit, unit_registry.m)` instead. + # Needs `pint>=0.12.1`, though, so we probably should wait until that is released. 
+ if not is_compatible(unit, unit_registry.m): expected.a[:] = False expected.b[:] = False + actual = ds.isin(values) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "variant", ("masking", "replacing_scalar", "replacing_array", "dropping") @@ -4202,13 +4187,8 @@ def test_where(self, variant, unit, error, dtype): array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": np.arange(len(array1))}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) condition = ds < 0.5 * original_unit other = np.linspace(-2, -1, 10).astype(dtype) * unit @@ -4230,15 +4210,13 @@ def test_where(self, variant, unit, error, dtype): for key, value in kwargs.items() } - expected = attach_units( - strip_units(ds).where(**kwargs_without_units), - {"a": original_unit, "b": original_unit}, - ) + expected = attach_units(strip_units(ds).where(**kwargs_without_units), units,) actual = ds.where(**kwargs) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="interpolate strips units") + @pytest.mark.xfail(reason="interpolate_na uses numpy.vectorize") def test_interpolate_na(self, dtype): array1 = ( np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) @@ -4248,24 +4226,15 @@ def test_interpolate_na(self, dtype): np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * unit_registry.Pa ) - x = np.arange(len(array1)) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).interpolate_na(dim="x"), - {"a": unit_registry.degK, "b": unit_registry.Pa}, - ) + expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units,) actual = ds.interpolate_na(dim="x") - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="wrong argument order for `where`") @pytest.mark.parametrize( "unit,error", ( @@ -4278,31 +4247,40 @@ def test_interpolate_na(self, dtype): pytest.param(unit_registry.m, None, id="same_unit"), ), ) - def test_combine_first(self, unit, error, dtype): + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_combine_first(self, variant, unit, error, dtype): + variants = { + "data": (unit_registry.m, unit, 1, 1), + "dims": (1, 1, unit_registry.m, unit), + } + data_unit, other_data_unit, dims_unit, other_dims_unit = variants.get(variant) + array1 = ( - np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) - * unit_registry.m + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * data_unit ) array2 = ( - np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) - * unit_registry.m + np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * data_unit ) - x = np.arange(len(array1)) + x = np.arange(len(array1)) * dims_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), 
- }, - coords={"x": x}, + data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}, ) - other_array1 = np.ones_like(array1) * unit - other_array2 = -1 * np.ones_like(array2) * unit + units = extract_units(ds) + + other_array1 = np.ones_like(array1) * other_data_unit + other_array2 = np.full_like(array2, fill_value=-1) * other_data_unit + other_x = (np.arange(array1.shape[0]) + 5) * other_dims_unit other = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=other_array1, dims="x"), - "b": xr.DataArray(data=other_array2, dims="x"), - }, - coords={"x": np.arange(array1.shape[0])}, + data_vars={"a": ("x", other_array1), "b": ("x", other_array2)}, + coords={"x": other_x}, ) if error is not None: @@ -4312,16 +4290,13 @@ def test_combine_first(self, unit, error, dtype): return expected = attach_units( - strip_units(ds).combine_first( - strip_units( - convert_units(other, {"a": unit_registry.m, "b": unit_registry.m}) - ) - ), - {"a": unit_registry.m, "b": unit_registry.m}, + strip_units(ds).combine_first(strip_units(convert_units(other, units))), + units, ) actual = ds.combine_first(other) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4334,7 +4309,7 @@ def test_combine_first(self, unit, error, dtype): ), ) @pytest.mark.parametrize( - "variation", + "variant", ( "data", pytest.param( @@ -4343,50 +4318,67 @@ def test_combine_first(self, unit, error, dtype): "coords", ), ) - @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) - def test_comparisons(self, func, variation, unit, dtype): - def is_compatible(a, b): - a = a if a is not None else 1 - b = b if b is not None else 1 - quantity = np.arange(5) * a - - return a == b or quantity.check(b) - + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip("behaviour of identical is unclear"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, variant, unit, dtype): array1 = np.linspace(0, 5, 10).astype(dtype) array2 = np.linspace(-5, 0, 10).astype(dtype) coord = np.arange(len(array1)).astype(dtype) - original_unit = unit_registry.m - quantity1 = array1 * original_unit - quantity2 = array2 * original_unit - x = coord * original_unit - y = coord * original_unit + variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } + data_unit, dim_unit, coord_unit = variants.get(variant) - units = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} - data_unit, dim_unit, coord_unit = units.get(variation) + a = array1 * data_unit + b = array2 * data_unit + x = coord * dim_unit + y = coord * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=quantity1, dims="x"), - "b": xr.DataArray(data=quantity2, dims="x"), - }, - coords={"x": x, "y": ("x", y)}, + data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)}, ) + units = extract_units(ds) + + other_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } + other_data_unit, other_dim_unit, other_coord_unit = other_variants.get(variant) other_units = { - "a": data_unit if quantity1.check(data_unit) else None, - "b": data_unit if quantity2.check(data_unit) else None, - "x": dim_unit if x.check(dim_unit) else None, - "y": coord_unit if y.check(coord_unit) else None, + "a": other_data_unit, + "b": other_data_unit, + "x": other_dim_unit, + "y": 
other_coord_unit, } - other = attach_units(strip_units(convert_units(ds, other_units)), other_units) - units = extract_units(ds) + to_convert = { + key: unit if is_compatible(unit, reference) else None + for key, (unit, reference) in zip_mappings(units, other_units) + } + # convert units where possible, then attach all units to the converted dataset + other = attach_units(strip_units(convert_units(ds, to_convert)), other_units) other_units = extract_units(other) + # make sure all units are compatible and only then try to + # convert and compare values equal_ds = all( - is_compatible(units[name], other_units[name]) for name in units.keys() + is_compatible(unit, other_unit) + for _, (unit, other_unit) in zip_mappings(units, other_units) ) and (strip_units(ds).equals(strip_units(convert_units(other, units)))) equal_units = units == other_units expected = equal_ds and (func.name != "identical" or equal_units) @@ -4395,6 +4387,9 @@ def is_compatible(a, b): assert expected == actual + # TODO: eventually use another decorator / wrapper function that + # applies a filter to the parametrize combinations: + # we only need a single test for data @pytest.mark.parametrize( "unit", ( @@ -4405,14 +4400,29 @@ def is_compatible(a, b): pytest.param(unit_registry.m, id="identical_unit"), ), ) - def test_broadcast_like(self, unit, dtype): - array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa - array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_broadcast_like(self, variant, unit, dtype): + variants = { + "data": ((unit_registry.m, unit), (1, 1)), + "dims": ((1, 1), (unit_registry.m, unit)), + } + (data_unit1, data_unit2), (dim_unit1, dim_unit2) = variants.get(variant) - x1 = np.arange(2) * unit_registry.m - x2 = np.arange(2) * unit - y1 = np.array([0]) * unit_registry.m - y2 = np.arange(3) * unit + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1 + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2 + + x1 = np.arange(2) * dim_unit1 + x2 = np.arange(2) * dim_unit2 + y1 = np.array([0]) * dim_unit1 + y2 = np.arange(3) * dim_unit2 ds1 = xr.Dataset( data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1} @@ -4426,7 +4436,8 @@ def test_broadcast_like(self, unit, dtype): ) actual = ds1.broadcast_like(ds2) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit", @@ -4439,32 +4450,25 @@ def test_broadcast_like(self, unit, dtype): ), ) def test_broadcast_equals(self, unit, dtype): + # TODO: does this use indexes? 
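+        # ``left`` has 2D variables while ``right`` has 1D variables that
+        # broadcast against them, so the variables have to be broadcast
+        # before their values can be compared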
left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m left_array2 = np.zeros(shape=(3, 6), dtype=dtype) * unit_registry.m right_array1 = np.ones(shape=(2,)) * unit - right_array2 = np.ones(shape=(3,)) * unit + right_array2 = np.zeros(shape=(3,)) * unit left = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=left_array1, dims=("x", "y")), - "b": xr.DataArray(data=left_array2, dims=("y", "z")), - } - ) - right = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=right_array1, dims="x"), - "b": xr.DataArray(data=right_array2, dims="y"), - } + {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)}, ) + right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)}) - units = { - **extract_units(left), - **({} if left_array1.check(unit) else {"a": None, "b": None}), - } - expected = strip_units(left).broadcast_equals( - strip_units(convert_units(right, units)) - ) & left_array1.check(unit) + units = merge_mappings( + extract_units(left), + {} if is_compatible(left_array1, unit) else {"a": None, "b": None}, + ) + expected = is_compatible(left_array1, unit) and strip_units( + left + ).broadcast_equals(strip_units(convert_units(right, units))) actual = left.broadcast_equals(right) assert expected == actual @@ -4474,68 +4478,74 @@ def test_broadcast_equals(self, unit, dtype): (method("unstack"), method("reset_index", "v"), method("reorder_levels")), ids=repr, ) - def test_stacking_stacked(self, func, dtype): - array1 = ( - np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m - ) + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + ), + ), + ) + def test_stacking_stacked(self, variant, func, dtype): + variants = { + "data": (unit_registry.m, 1), + "dims": (1, unit_registry.m), + } + data_unit, dim_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit array2 = ( np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype) - * unit_registry.m + * data_unit ) - x = np.arange(array1.shape[0]) - y = np.arange(array1.shape[1]) - z = np.arange(array2.shape[2]) + x = np.arange(array1.shape[0]) * dim_unit + y = np.arange(array1.shape[1]) * dim_unit + z = np.arange(array2.shape[2]) * dim_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, coords={"x": x, "y": y, "z": z}, ) + units = extract_units(ds) stacked = ds.stack(v=("x", "y")) - expected = attach_units( - func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m} - ) + expected = attach_units(func(strip_units(stacked)), units) actual = func(stacked) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="does not work with quantities yet") + @pytest.mark.xfail( + reason="stacked dimension's labels have to be hashable, but is a numpy.array" + ) def test_to_stacked_array(self, dtype): - labels = np.arange(5).astype(dtype) * unit_registry.s - arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} + labels = range(5) * unit_registry.s + arrays = { + name: np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + for name in labels + } - ds = xr.Dataset( - data_vars={ - name: xr.DataArray(data=array, dims="x") - for name, array in arrays.items() - } - ) + ds = 
xr.Dataset({name: ("x", array) for name, array in arrays.items()}) + units = {None: unit_registry.m, "y": unit_registry.s} func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) actual = func(ds).rename(None) - expected = attach_units( - func(strip_units(ds)).rename(None), - {None: unit_registry.m, "y": unit_registry.s}, - ) + expected = attach_units(func(strip_units(ds)).rename(None), units,) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("transpose", "y", "x", "z1", "z2"), - method("stack", a=("x", "y")), + method("stack", u=("x", "y")), method("set_index", x="x2"), - pytest.param( - method("shift", x=2), - marks=pytest.mark.xfail(reason="tries to concatenate nan arrays"), - ), + method("shift", x=2), method("roll", x=2, roll_coords=False), method("sortby", "x2"), ), @@ -4560,20 +4570,19 @@ def test_stacking_reordering(self, func, dtype): ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y", "z1")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z2")), + "a": (("x", "y", "z1"), array1), + "b": (("x", "y", "z2"), array2), }, coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)}, ) + units = extract_units(ds) - expected = attach_units( - func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK} - ) + expected = attach_units(func(strip_units(ds)), units) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="indexes strip units") @pytest.mark.parametrize( "indices", ( @@ -4585,22 +4594,14 @@ def test_isel(self, indices, dtype): array1 = np.arange(10).astype(dtype) * unit_registry.s array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa - x = np.arange(len(array1)) * unit_registry.m - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims="x"), - "b": xr.DataArray(data=array2, dims="x"), - }, - coords={"x": x}, - ) + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}) + units = extract_units(ds) - expected = attach_units( - strip_units(ds).isel(x=indices), - {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m}, - ) + expected = attach_units(strip_units(ds).isel(x=indices), units) actual = ds.isel(x=indices) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4617,7 +4618,7 @@ def test_isel(self, indices, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4636,20 +4637,24 @@ def test_sel(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? 
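+        # indexes don't support units, so only quantities with the
+        # identical unit can currently be looked up; everything else
+        # raises a KeyError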
+ if error is not None: with pytest.raises(error): ds.sel(x=values) return expected = attach_units( - strip_units(ds).sel(x=strip_units(convert_units(values, {None: x.units}))), - {"a": array1.units, "b": array2.units, "x": x.units}, + strip_units(ds).sel( + x=strip_units(convert_units(values, {None: unit_registry.m})) + ), + extract_units(ds), ) actual = ds.sel(x=values) - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4666,7 +4671,7 @@ def test_sel(self, raw_values, unit, error, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4685,9 +4690,9 @@ def test_drop_sel(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? + if error is not None: with pytest.raises(error): ds.drop_sel(x=values) @@ -4695,12 +4700,14 @@ def test_drop_sel(self, raw_values, unit, error, dtype): expected = attach_units( strip_units(ds).drop_sel( - x=strip_units(convert_units(values, {None: x.units})) + x=strip_units(convert_units(values, {None: unit_registry.m})) ), extract_units(ds), ) actual = ds.drop_sel(x=values) - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -4717,7 +4724,7 @@ def test_drop_sel(self, raw_values, unit, error, dtype): pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.mm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4736,9 +4743,9 @@ def test_loc(self, raw_values, unit, error, dtype): values = raw_values * unit - if error is not None and not ( - isinstance(raw_values, (int, float)) and x.check(unit) - ): + # TODO: if we choose dm as compatible unit, single value keys + # can be found. Should we check that? 
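+        # like ``sel``, ``loc`` looks up in a unit-naive index, so only
+        # the identical unit succeeds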
+ if error is not None: with pytest.raises(error): ds.loc[{"x": values}] @@ -4746,12 +4753,14 @@ def test_loc(self, raw_values, unit, error, dtype): expected = attach_units( strip_units(ds).loc[ - {"x": strip_units(convert_units(values, {None: x.units}))} + {"x": strip_units(convert_units(values, {None: unit_registry.m}))} ], - {"a": array1.units, "b": array2.units, "x": x.units}, + extract_units(ds), ) actual = ds.loc[{"x": values}] - assert_equal_with_units(expected, actual) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -4762,14 +4771,34 @@ def test_loc(self, raw_values, unit, error, dtype): ), ids=repr, ) - def test_head_tail_thin(self, func, dtype): - array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK - array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_head_tail_thin(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit_a, unit_b), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_a + array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_b coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, + "x": np.arange(10) * dim_unit, + "y": np.arange(5) * dim_unit, + "z": np.arange(8) * dim_unit, + "u": ("x", np.linspace(0, 1, 10) * coord_unit), + "v": ("y", np.linspace(1, 2, 5) * coord_unit), + "w": ("z", np.linspace(-1, 0, 8) * coord_unit), } ds = xr.Dataset( @@ -4783,8 +4812,10 @@ def test_head_tail_thin(self, func, dtype): expected = attach_units(func(strip_units(ds)), extract_units(ds)) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) + @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all")) @pytest.mark.parametrize( "shape", ( @@ -4795,13 +4826,9 @@ def test_head_tail_thin(self, func, dtype): pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"), ), ) - def test_squeeze(self, shape, dtype): + def test_squeeze(self, shape, dim, dtype): names = "xyzt" - coords = { - name: np.arange(length).astype(dtype) - * (unit_registry.m if name != "t" else unit_registry.s) - for name, length in zip(names, shape) - } + dim_lengths = dict(zip(names, shape)) array1 = ( np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK ) @@ -4811,74 +4838,59 @@ def test_squeeze(self, shape, dtype): ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=tuple(names[: len(shape)])), - "b": xr.DataArray(data=array2, dims=tuple(names[: len(shape)])), + "a": (tuple(names[: len(shape)]), array1), + "b": (tuple(names[: len(shape)]), array2), }, - coords=coords, ) units = extract_units(ds) - expected = attach_units(strip_units(ds).squeeze(), units) + kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {} - actual = ds.squeeze() - assert_equal_with_units(actual, expected) + expected = attach_units(strip_units(ds).squeeze(**kwargs), units) - # try squeezing the dimensions separately - names = tuple(dim for dim, coord in coords.items() if len(coord) == 1) - for name in names: - expected = 
attach_units(strip_units(ds).squeeze(dim=name), units) - actual = ds.squeeze(dim=name) - assert_equal_with_units(actual, expected) + actual = ds.squeeze(**kwargs) + + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( - "unit,error", + "func", ( - pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + method("interp"), marks=pytest.mark.xfail(reason="uses scipy") ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + method("reindex"), ), + ids=repr, ) - def test_interp(self, unit, error): - array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK - array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.s, + def test_interp_reindex(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), } + data_unit, coord_unit = variants.get(variant) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) - - new_coords = (np.arange(10) + 0.5) * unit + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit - if error is not None: - with pytest.raises(error): - ds.interp(x=new_coords) + y = np.arange(10) * coord_unit - return + x = np.arange(10) + new_x = np.arange(8) + 0.5 - units = extract_units(ds) - expected = attach_units( - strip_units(ds).interp(x=strip_units(convert_units(new_coords, units))), - units, + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} ) - actual = ds.interp(x=new_coords) + units = extract_units(ds) - assert_equal_with_units(actual, expected) + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) - @pytest.mark.xfail(reason="ignores units") + assert_units_equal(expected, actual) + assert_equal(expected, actual) + + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( @@ -4891,106 +4903,67 @@ def test_interp(self, unit, error): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_interp_like(self, unit, error, dtype): - array1 = ( - np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, - } + @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr) + def test_interp_reindex_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit - other = xr.Dataset( - data_vars={ - "c": xr.DataArray(data=np.empty((20, 
10)), dims=("x", "y")), - "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), - }, - coords={ - "x": (np.arange(20) + 0.3) * unit, - "y": (np.arange(10) - 0.2) * unit, - "z": (np.arange(15) + 0.4) * unit, - }, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) if error is not None: with pytest.raises(error): - ds.interp_like(other) + func(ds, x=new_x) return - units = extract_units(ds) - expected = attach_units( - strip_units(ds).interp_like(strip_units(convert_units(other, units))), units - ) - actual = ds.interp_like(other) + expected = attach_units(func(strip_units(ds), x=new_x), units) + actual = func(ds, x=new_x) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) - @pytest.mark.xfail(reason="indexes don't support units") + @pytest.mark.parametrize("variant", ("data", "coords")) @pytest.mark.parametrize( - "unit,error", + "func", ( - pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" + method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy") ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + method("reindex_like"), ), + ids=repr, ) - def test_reindex(self, unit, error, dtype): - array1 = ( - np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(1, 2, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.s, + def test_interp_reindex_like(self, func, variant, dtype): + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), } + data_unit, coord_unit = variants.get(variant) - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) - - new_coords = (np.arange(10) + 0.5) * unit + array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit + array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit - if error is not None: - with pytest.raises(error): - ds.reindex(x=new_coords) + y = np.arange(10) * coord_unit - return + x = np.arange(10) + new_x = np.arange(8) + 0.5 - expected = attach_units( - strip_units(ds).reindex( - x=strip_units(convert_units(new_coords, {None: coords["x"].units})) - ), - extract_units(ds), + ds = xr.Dataset( + {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)} ) - actual = ds.reindex(x=new_coords) + units = extract_units(ds) + + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) + + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) - assert_equal_with_units(actual, expected) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( @@ -5005,54 +4978,32 @@ def test_reindex(self, unit, error, dtype): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_reindex_like(self, unit, error, dtype): - array1 = ( - np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) + 
@pytest.mark.parametrize( + "func", (method("interp_like"), method("reindex_like")), ids=repr + ) + def test_interp_reindex_like_indexing(self, func, unit, error, dtype): + array1 = np.linspace(-1, 0, 10).astype(dtype) + array2 = np.linspace(0, 1, 10).astype(dtype) - coords = { - "x": np.arange(10) * unit_registry.m, - "y": np.arange(5) * unit_registry.m, - "z": np.arange(8) * unit_registry.m, - } + x = np.arange(10) * unit_registry.m + new_x = (np.arange(8) + 0.5) * unit - ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "z")), - }, - coords=coords, - ) + ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + units = extract_units(ds) - other = xr.Dataset( - data_vars={ - "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")), - "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")), - }, - coords={ - "x": (np.arange(20) + 0.3) * unit, - "y": (np.arange(10) - 0.2) * unit, - "z": (np.arange(15) + 0.4) * unit, - }, - ) + other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x}) if error is not None: with pytest.raises(error): - ds.reindex_like(other) + func(ds, other) return - units = extract_units(ds) - expected = attach_units( - strip_units(ds).reindex_like(strip_units(convert_units(other, units))), - units, - ) - actual = ds.reindex_like(other) + expected = attach_units(func(strip_units(ds), other), units) + actual = func(ds, other) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", @@ -5062,30 +5013,46 @@ def test_reindex_like(self, unit, error, dtype): method("integrate", coord="x"), pytest.param( method("quantile", q=[0.25, 0.75]), - marks=pytest.mark.xfail(reason="nanquantile not implemented"), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="nanquantile not implemented yet", + ), ), method("reduce", func=np.sum, dim="x"), method("map", np.fabs), ), ids=repr, ) - def test_computation(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit ds = xr.Dataset( data_vars={ "a": xr.DataArray(data=array1, dims=("x", "y")), "b": xr.DataArray(data=array2, dims=("x", "z")), }, - coords={"x": x, "y": y, "z": z}, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, ) units = extract_units(ds) @@ -5093,69 +5060,105 @@ def test_computation(self, func, dtype): expected = attach_units(func(strip_units(ds)), units) actual = func(ds) - assert_equal_with_units(expected, actual) + 
assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("groupby", "x"), - method("groupby_bins", "x", bins=4), + pytest.param( + method("groupby_bins", "x", bins=2), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="needs assert_allclose but that does not work with pint", + ), + ), method("coarsen", x=2), pytest.param( method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units") ), pytest.param( method("rolling_exp", x=3), - marks=pytest.mark.xfail(reason="uses numbagg which strips units"), + marks=pytest.mark.xfail( + reason="numbagg functions are not supported by pint" + ), ), ), ids=repr, ) - def test_computation_objects(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_computation_objects(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2 + x = np.arange(4) * dim_unit + y = np.arange(5) * dim_unit + z = np.arange(3) * dim_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, - coords={"x": x, "y": y, "z": z}, + data_vars={"a": (("x", "y"), array1), "b": (("x", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)}, ) units = extract_units(ds) args = [] if func.name != "groupby" else ["y"] - reduce_func = method("mean", *args) - expected = attach_units(reduce_func(func(strip_units(ds))), units) - actual = reduce_func(func(ds)) + expected = attach_units(func(strip_units(ds)).mean(*args), units) + actual = func(ds).mean(*args) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + # TODO: remove once pint 0.12 has been released + if LooseVersion(pint.__version__) <= "0.12": + assert_equal(expected, actual) + else: + assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_resample(self, variant, dtype): + # TODO: move this to test_computation_objects + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit2 - def test_resample(self, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa - ) t = pd.date_range("10-09-2010", 
periods=array1.shape[0], freq="1y") - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + y = np.arange(5) * dim_unit + z = np.arange(8) * dim_unit + + u = np.linspace(-1, 0, 5) * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("time", "y")), - "b": xr.DataArray(data=array2, dims=("time", "z")), - }, - coords={"time": t, "y": y, "z": z}, + data_vars={"a": (("time", "y"), array1), "b": (("time", "z"), array2)}, + coords={"time": t, "y": y, "z": z, "u": ("y", u)}, ) units = extract_units(ds) @@ -5164,43 +5167,59 @@ def test_resample(self, dtype): expected = attach_units(func(strip_units(ds)).mean(), units) actual = func(ds).mean() - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("assign", c=lambda ds: 10 * ds.b), - method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), + method("assign_coords", v=("x", np.arange(5) * unit_registry.s)), method("first"), method("last"), pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), - marks=pytest.mark.xfail(reason="nanquantile not implemented"), + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) <= "0.12", + reason="nanquantile not implemented", + ), ), ), ids=repr, ) - def test_grouped_operations(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - x = np.arange(10) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_grouped_operations(self, func, variant, dtype): + variants = { + "data": ((unit_registry.degK, unit_registry.Pa), 1, 1), + "dims": ((1, 1), unit_registry.m, 1), + "coords": ((1, 1), 1, unit_registry.m), + } + (unit1, unit2), dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + u = np.linspace(-1, 0, 4) * coord_unit ds = xr.Dataset( - data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - }, - coords={"x": x, "y": y, "z": z}, + data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)}, + coords={"x": x, "y": y, "z": z, "u": ("y", u)}, ) - units = extract_units(ds) - units.update({"c": unit_registry.Pa, "v": unit_registry.s}) + + assigned_units = {"c": unit2, "v": unit_registry.s} + units = merge_mappings(extract_units(ds), assigned_units) stripped_kwargs = { name: strip_units(value) for name, value in func.kwargs.items() @@ -5210,20 +5229,26 @@ def test_grouped_operations(self, func, dtype): ) actual = func(ds.groupby("y")) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "func", ( method("pipe", lambda ds: ds * 10), method("assign", d=lambda ds: ds.b * 10), - method("assign_coords", y2=("y", np.arange(5) * unit_registry.mm)), + method("assign_coords", y2=("y", np.arange(4) * unit_registry.mm)), method("assign_attrs", attr1="value"), method("rename", 
x2="x_mm"), method("rename_vars", c="temperature"), method("rename_dims", x="offset_x"), - method("swap_dims", {"x": "x2"}), - method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1), + method("swap_dims", {"x": "u"}), + pytest.param( + method( + "expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1 + ), + marks=pytest.mark.xfail(reason="indexes don't support units"), + ), method("drop_vars", "x"), method("drop_dims", "z"), method("set_coords", names="c"), @@ -5232,40 +5257,55 @@ def test_grouped_operations(self, func, dtype): ), ids=repr, ) - def test_content_manipulation(self, func, dtype): - array1 = ( - np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) - * unit_registry.m ** 3 - ) - array2 = ( - np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype) - * unit_registry.Pa - ) - array3 = np.linspace(0, 10, 10).astype(dtype) * unit_registry.degK + @pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param( + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + ), + "coords", + ), + ) + def test_content_manipulation(self, func, variant, dtype): + variants = { + "data": ( + (unit_registry.m ** 3, unit_registry.Pa, unit_registry.degK), + 1, + 1, + ), + "dims": ((1, 1, 1), unit_registry.m, 1), + "coords": ((1, 1, 1), 1, unit_registry.m), + } + (unit1, unit2, unit3), dim_unit, coord_unit = variants.get(variant) - x = np.arange(10) * unit_registry.m - x2 = x.to(unit_registry.mm) - y = np.arange(5) * unit_registry.m - z = np.arange(8) * unit_registry.m + array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1 + array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2 + array3 = np.linspace(0, 10, 5).astype(dtype) * unit3 + + x = np.arange(5) * dim_unit + y = np.arange(4) * dim_unit + z = np.arange(3) * dim_unit + + x2 = np.linspace(-1, 0, 5) * coord_unit ds = xr.Dataset( data_vars={ - "a": xr.DataArray(data=array1, dims=("x", "y")), - "b": xr.DataArray(data=array2, dims=("x", "y", "z")), - "c": xr.DataArray(data=array3, dims="x"), + "a": (("x", "y"), array1), + "b": (("x", "y", "z"), array2), + "c": ("x", array3), }, coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, ) - units = { - **extract_units(ds), - **{ - "y2": unit_registry.mm, - "x_mm": unit_registry.mm, - "offset_x": unit_registry.m, - "d": unit_registry.Pa, - "temperature": unit_registry.degK, - }, + + new_units = { + "y2": unit_registry.mm, + "x_mm": coord_unit, + "offset_x": unit_registry.m, + "d": unit2, + "temperature": unit3, } + units = merge_mappings(extract_units(ds), new_units) stripped_kwargs = { key: strip_units(value) for key, value in func.kwargs.items() @@ -5273,7 +5313,8 @@ def test_content_manipulation(self, func, dtype): expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) actual = func(ds) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) @pytest.mark.parametrize( "unit,error", @@ -5298,25 +5339,29 @@ def test_content_manipulation(self, func, dtype): ), ) def test_merge(self, variant, unit, error, dtype): - original_data_unit = unit_registry.m - original_dim_unit = unit_registry.m - original_coord_unit = unit_registry.m + left_variants = { + "data": (unit_registry.m, 1, 1), + "dims": (1, unit_registry.m, 1), + "coords": (1, 1, unit_registry.m), + } - variants = { - "data": (unit, original_dim_unit, original_coord_unit), - "dims": (original_data_unit, unit, original_coord_unit), - "coords": (original_data_unit, 
original_dim_unit, unit), + left_data_unit, left_dim_unit, left_coord_unit = left_variants.get(variant) + + right_variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), } - data_unit, dim_unit, coord_unit = variants.get(variant) + right_data_unit, right_dim_unit, right_coord_unit = right_variants.get(variant) - left_array = np.arange(10).astype(dtype) * original_data_unit - right_array = np.arange(-5, 5).astype(dtype) * data_unit + left_array = np.arange(10).astype(dtype) * left_data_unit + right_array = np.arange(-5, 5).astype(dtype) * right_data_unit - left_dim = np.arange(10, 20) * original_dim_unit - right_dim = np.arange(5, 15) * dim_unit + left_dim = np.arange(10, 20) * left_dim_unit + right_dim = np.arange(5, 15) * right_dim_unit - left_coord = np.arange(-10, 0) * original_coord_unit - right_coord = np.arange(-15, -5) * coord_unit + left_coord = np.arange(-10, 0) * left_coord_unit + right_coord = np.arange(-15, -5) * right_coord_unit left = xr.Dataset( data_vars={"a": ("x", left_array)}, @@ -5339,4 +5384,5 @@ def test_merge(self, variant, unit, error, dtype): expected = attach_units(strip_units(left).merge(strip_units(converted)), units) actual = left.merge(right) - assert_equal_with_units(expected, actual) + assert_units_equal(expected, actual) + assert_equal(expected, actual) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 3003e0d66f3..d79d40d67c0 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1657,7 +1657,7 @@ def test_reduce_funcs(self): assert_identical(v.all(dim="x"), Variable([], False)) v = Variable("t", pd.date_range("2000-01-01", periods=3)) - assert v.argmax(skipna=True) == 2 + assert v.argmax(skipna=True, dim="t") == 2 assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03")))