Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use duck array ops in more places #8267

Merged
merged 6 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- More improvements to support the Python `array API standard <https://data-apis.org/array-api/latest/>`_
by using duck array ops in more places in the codebase. (:pull:`8267`)
By `Tom White <https://github.com/tomwhite>`_.


.. _whats-new.2023.09.0:

Expand Down
13 changes: 7 additions & 6 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from xarray.coding.times import infer_calendar_name
from xarray.core import duck_array_ops
from xarray.core.common import (
_contains_datetime_like_objects,
is_np_datetime_like,
Expand Down Expand Up @@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name):
from xarray.coding.cftimeindex import CFTimeIndex

if not isinstance(values, CFTimeIndex):
values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
values_as_cftimeindex = CFTimeIndex(duck_array_ops.reshape(values, (-1,)))

Can't we just get used to using reshape instead?

else:
values_as_cftimeindex = values
if name == "season":
Expand All @@ -69,7 +70,7 @@ def _access_through_series(values, name):
"""Coerce an array of datetime-like values to a pandas Series and
access requested datetime component
"""
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
values_as_series = pd.Series(duck_array_ops.reshape(values, (-1,)), copy=False)

if name == "season":
months = values_as_series.dt.month.values
field_values = _season_from_months(months)
Expand Down Expand Up @@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq):
from xarray.coding.cftimeindex import CFTimeIndex

if is_np_datetime_like(values.dtype):
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
values_as_series = pd.Series(duck_array_ops.reshape(values, (-1,)), copy=False)

method = getattr(values_as_series.dt, name)
else:
values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
values_as_cftimeindex = CFTimeIndex(duck_array_ops.reshape(values, (-1,)))

method = getattr(values_as_cftimeindex, name)

field_values = method(freq=freq).values
Expand Down Expand Up @@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str):
"""
from xarray.coding.cftimeindex import CFTimeIndex

values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
values_as_cftimeindex = CFTimeIndex(duck_array_ops.reshape(values, (-1,)))


field_values = values_as_cftimeindex.strftime(date_format)
return field_values.values.reshape(values.shape)
Expand All @@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str):
"""Coerce an array of datetime-like values to a pandas Series and
apply string formatting
"""
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
values_as_series = pd.Series(duck_array_ops.reshape(values, (-1,)), copy=False)

strs = values_as_series.dt.strftime(date_format)
return strs.values.reshape(values.shape)

Expand Down
3 changes: 2 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,7 +2123,8 @@ def _calc_idxminmax(
chunkmanager = get_chunked_array_type(array.data)
chunks = dict(zip(array.dims, array.chunks))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
data = dask_coord[duck_array_ops.ravel(indx.data)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data = dask_coord[duck_array_ops.ravel(indx.data)]
data = dask_coord[duck_array_ops.reshape(indx.data, (-1,))]

res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
# we need to attach back the dim name
res.name = dim
else:
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def reshape(array, shape):
return xp.reshape(array, shape)


def ravel(array):
return reshape(array, (-1,))


Comment on lines +340 to +343
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def ravel(array):
return reshape(array, (-1,))

@contextlib.contextmanager
def _ignore_warnings_if(condition):
if condition:
Expand All @@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs):
values = asarray(values)

if coerce_strings and values.dtype.kind in "SU":
values = values.astype(object)
values = astype(values, object)

func = None
if skipna or (skipna is None and values.dtype.kind in "cfO"):
Expand Down
10 changes: 7 additions & 3 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from xarray.core import dtypes, nputils, utils
from xarray.core import dtypes, duck_array_ops, nputils, utils
from xarray.core.duck_array_ops import (
astype,
count,
Expand All @@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1):
xarray version of pandas.core.nanops._maybe_null_out
"""
if axis is not None and getattr(result, "ndim", False):
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
null_mask = (
np.take(mask.shape, axis).prod()
- duck_array_ops.sum(mask, axis)
- min_count
) < 0
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))

elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
null_mask = mask.size - mask.sum()
null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)

return result
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,7 @@ def coarsen_reshape(self, windows, boundary, side):
else:
shape.append(variable.shape[i])

return variable.data.reshape(shape), tuple(axes)
return duck_array_ops.reshape(variable.data, shape), tuple(axes)

def isnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is a missing value.
Expand Down
12 changes: 8 additions & 4 deletions xarray/tests/test_coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import xarray as xr
from xarray import DataArray, Dataset, set_options
from xarray.core import duck_array_ops
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This effectively changes the test from only testing public xarray API to now relying on some internals too, which is not ideal from a code separation perspective.

In practice I think this is probably fine here, and I know you're doing it to allow Cubed arrays to go through the modified tests, but just something to bear in mind.

from xarray.tests import (
assert_allclose,
assert_equal,
Expand Down Expand Up @@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None:
expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (
("year", "month"),
ds.vart.data.reshape((-1, 12)),
duck_array_ops.reshape(ds.vart.data, (-1, 12)),
{"a": "b"},
)
expected["varx"] = (
("x", "x_reshaped"),
ds.varx.data.reshape((-1, 5)),
duck_array_ops.reshape(ds.varx.data, (-1, 5)),
{"a": "b"},
)
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
ds.vartx.data.reshape(2, 5, 4, 12),
duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)),
{"a": "b"},
)
expected["vary"] = ds.vary
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
expected.coords["time"] = (
("year", "month"),
duck_array_ops.reshape(ds.time.data, (-1, 12)),
)

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg):

actual = v.pad(**xr_arg)
expected = np.pad(
np.array(v.data.astype(float)),
np.array(duck_array_ops.astype(v.data, float)),
np_arg,
mode="constant",
constant_values=np.nan,
Expand Down
Loading