Skip to content

Commit

Permalink
Allow all integer dtypes in polyval (#7619)
Browse files Browse the repository at this point in the history
* Fixed unneeded ValueError

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed unneeded ValueError

* Revert "Merge"

This reverts commit 87a82a2, reversing
changes made to fb27a96.

* add to whats-new

* add tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dcherian <[email protected]>
Co-authored-by: Tom Nicholas <[email protected]>
Co-authored-by: Michael Niklas <[email protected]>
  • Loading branch information
5 people authored Mar 22, 2023
1 parent ed09383 commit 1e361cc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Deprecations
Bug fixes
~~~~~~~~~

- Fix :py:meth:`xr.polyval` with non-system standard integer coeffs (:pull:`7619`).
By `Shreyal Gupta <https://github.com/Ravenin7>`_ and `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,7 +1968,7 @@ def polyval(
raise ValueError(
f"Dimension `{degree_dim}` should be a coordinate variable with labels."
)
if not np.issubdtype(coeffs[degree_dim].dtype, int):
if not np.issubdtype(coeffs[degree_dim].dtype, np.integer):
raise ValueError(
f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead."
)
Expand Down
30 changes: 30 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,36 @@ def test_where_attrs() -> None:
xr.DataArray([1000.0, 2000.0, 3000.0], dims="x"),
id="timedelta",
),
pytest.param(
xr.DataArray([1, 2, 3], dims="x"),
xr.DataArray(
[2, 3, 4],
dims="degree",
coords={"degree": np.array([0, 1, 2], dtype=np.int64)},
),
xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
id="int64-degree",
),
pytest.param(
xr.DataArray([1, 2, 3], dims="x"),
xr.DataArray(
[2, 3, 4],
dims="degree",
coords={"degree": np.array([0, 1, 2], dtype=np.int32)},
),
xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
id="int32-degree",
),
pytest.param(
xr.DataArray([1, 2, 3], dims="x"),
xr.DataArray(
[2, 3, 4],
dims="degree",
coords={"degree": np.array([0, 1, 2], dtype=np.uint8)},
),
xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
id="uint8-degree",
),
],
)
def test_polyval(
Expand Down

0 comments on commit 1e361cc

Please sign in to comment.