diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2fedade0a60..1c8564ff9ae 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Minor improvements to support of the python `array api standard `_, + internally using the function ``xp.astype()`` instead of the method ``arr.astype()``, as the latter is not in the standard. + (:pull:`7847`) By `Tom Nicholas `_. .. _whats-new.2023.05.0: diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index c6c4af87d1c..31028f10350 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -51,6 +51,7 @@ import numpy as np +from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.types import T_DataArray @@ -2085,13 +2086,16 @@ def _get_res_multi(val, pat): else: # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_get_res_multi, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxgroups}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_get_res_multi, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxgroups}, + ), + self._obj.dtype.kind, + ) def extractall( self, @@ -2258,15 +2262,18 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype): return res - return self._apply( - # dtype MUST be object or strings can be truncated - # See: https://github.com/numpy/numpy/issues/8352 - func=_get_res, - func_args=(pat,), - dtype=np.object_, - output_core_dims=[[group_dim, match_dim]], - output_sizes={group_dim: maxgroups, match_dim: maxcount}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + func=_get_res, + func_args=(pat,), + dtype=np.object_, + output_core_dims=[[group_dim, match_dim]], + output_sizes={group_dim: maxgroups, match_dim: maxcount}, + ), + self._obj.dtype.kind, + ) def findall( self, @@ -2385,13 +2392,16 @@ def _partitioner( # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=arrfunc, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: 3}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=arrfunc, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: 3}, + ), + self._obj.dtype.kind, + ) def partition( self, @@ -2510,13 +2520,16 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 - return self._apply( - func=_dosplit, - func_args=(sep,), - dtype=np.object_, - output_core_dims=[[dim]], - output_sizes={dim: maxsplit}, - ).astype(self._obj.dtype.kind) + return duck_array_ops.astype( + self._apply( + func=_dosplit, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxsplit}, + ), + self._obj.dtype.kind, + ) def split( self, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 27a5399d639..f4813af9782 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1421,7 +1421,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): pads = [(0, 0) if d != dim else dim_pad for d in self.dims] data = np.pad( - trimmed_data.astype(dtype), + duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", constant_values=fill_value, @@ -1570,7 +1570,7 @@ def pad( pad_option_kwargs["reflect_type"] = reflect_type array = np.pad( - self.data.astype(dtype, copy=False), + duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, **pad_option_kwargs, @@ -2438,7 +2438,7 @@ def rolling_window( """ if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) - var = self.astype(dtype, copy=False) + var = duck_array_ops.astype(self, dtype, copy=False) else: dtype = self.dtype var = self