Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_numpy() and as_numpy() methods #5568

Merged
merged 33 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
17c5755
added to_numpy() and as_numpy() methods
TomNicholas Jul 2, 2021
48ba107
remove special-casing of cupy arrays in .values in favour of using .t…
TomNicholas Jul 2, 2021
ae6e931
lint
max-sixty Jul 2, 2021
dc24d3f
Fix mypy (I think?)
max-sixty Jul 2, 2021
6ce6b05
Merge branch 'main' of https://github.com/pydata/xarray into to_numpy
TomNicholas Jul 3, 2021
04d7b02
Merge branch 'to_numpy' of https://github.com/TomNicholas/xarray into…
TomNicholas Jul 3, 2021
ee34649
added Dataset.as_numpy()
TomNicholas Jul 3, 2021
552b322
improved docstrings
TomNicholas Jul 3, 2021
1215e69
add what's new
TomNicholas Jul 3, 2021
af8a1ee
add to API docs
TomNicholas Jul 3, 2021
e095bf0
linting
TomNicholas Jul 3, 2021
eb7d84d
fix failures by only importing pint when needed
TomNicholas Jul 7, 2021
74c05e3
refactor pycompat into class
TomNicholas Jul 7, 2021
45245d0
compute instead of load
TomNicholas Jul 8, 2021
27fc4e5
added tests
TomNicholas Jul 8, 2021
3e8cb24
fixed sparse test
TomNicholas Jul 8, 2021
f9d6370
tests and fixes for ds.as_numpy()
TomNicholas Jul 9, 2021
50fdf4c
fix sparse tests
TomNicholas Jul 9, 2021
1c94a97
fix linting
TomNicholas Jul 9, 2021
2d07c0f
tests for Variable
TomNicholas Jul 9, 2021
9673cea
test IndexVariable too
TomNicholas Jul 9, 2021
0d624cc
use numpy.asarray to avoid a copy
TomNicholas Jul 12, 2021
2f1ff46
also convert coords
TomNicholas Jul 14, 2021
afd35e2
Merge branch 'main' into to_numpy
TomNicholas Jul 15, 2021
6d33b35
Force tests again after #5600
TomNicholas Jul 16, 2021
eae95f5
Merge branch 'main' into to_numpy
TomNicholas Jul 16, 2021
b90b7e3
Apply suggestions from code review
dcherian Jul 16, 2021
f39b301
Update xarray/core/variable.py
dcherian Jul 16, 2021
8b346d3
fix import
TomNicholas Jul 21, 2021
576ab7b
formatting
TomNicholas Jul 21, 2021
4ed1dd8
Fix fsspec error by merging branch 'main' into to_numpy
TomNicholas Jul 21, 2021
976f89a
remove type check
TomNicholas Jul 21, 2021
7bc5d6f
remove attempt to call to_numpy
TomNicholas Jul 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ Dataset methods
open_zarr
Dataset.to_netcdf
Dataset.to_pandas
Dataset.as_numpy
Dataset.to_zarr
save_mfdataset
Dataset.to_array
Expand Down Expand Up @@ -716,6 +717,8 @@ DataArray methods
DataArray.to_pandas
DataArray.to_series
DataArray.to_dataframe
DataArray.to_numpy
DataArray.as_numpy
DataArray.to_index
DataArray.to_masked_array
DataArray.to_cdms2
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ New Features
- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None``
(:issue:`5510`).
By `Elle Smith <https://github.com/ellesmith88>`_.
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
52 changes: 47 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,12 @@ def __init__(
self._close = None

def _replace(
self,
self: T_DataArray,
variable: Variable = None,
coords=None,
name: Union[Hashable, None, Default] = _default,
indexes=None,
) -> "DataArray":
) -> T_DataArray:
if variable is None:
variable = self.variable
if coords is None:
Expand Down Expand Up @@ -623,7 +623,16 @@ def __len__(self) -> int:

@property
def data(self) -> Any:
"""The array's data as a dask or numpy array"""
"""
The DataArray's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.

See Also
--------
DataArray.to_numpy
DataArray.as_numpy
DataArray.values
"""
return self.variable.data

@data.setter
Expand All @@ -632,13 +641,46 @@ def data(self, value: Any) -> None:

@property
def values(self) -> np.ndarray:
"""The array's data as a numpy.ndarray"""
"""
The array's data as a numpy.ndarray.

If the array's data is not a numpy.ndarray this will attempt to convert
it naively using np.array(), which will raise an error if the array
type does not support coercion like this (e.g. cupy).
"""
return self.variable.values

@values.setter
def values(self, value: Any) -> None:
self.variable.values = value

def to_numpy(self) -> np.ndarray:
"""
Coerces wrapped data to numpy and returns a numpy.ndarray.

See also
--------
DataArray.as_numpy : Same but returns the surrounding DataArray instead.
Dataset.as_numpy
dcherian marked this conversation as resolved.
Show resolved Hide resolved
DataArray.values
DataArray.data
"""
return self.variable.to_numpy()

def as_numpy(self: T_DataArray) -> T_DataArray:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.

See also
--------
DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object.
Dataset.as_numpy : Converts all variables in a Dataset.
dcherian marked this conversation as resolved.
Show resolved Hide resolved
DataArray.values
DataArray.data
"""
coords = {k: v.as_numpy() for k, v in self._coords.items()}
return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes)

@property
def _in_memory(self) -> bool:
return self.variable._in_memory
Expand Down Expand Up @@ -931,7 +973,7 @@ def persist(self, **kwargs) -> "DataArray":
ds = self._to_temp_dataset().persist(**kwargs)
return self._from_temp_dataset(ds)

def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
"""Returns a copy of this array.

If `deep=True`, a deep copy is made of the data array.
Expand Down
12 changes: 12 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,18 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset":

return self._replace(variables, attrs=attrs)

def as_numpy(self: "Dataset") -> "Dataset":
"""
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.

See also
--------
DataArray.as_numpy
DataArray.to_numpy : Returns only the data as a numpy.ndarray object.
"""
numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
return self._replace(variables=numpy_variables)

@property
def _level_coords(self) -> Dict[str, Hashable]:
"""Return a mapping of all MultiIndex levels and their corresponding
Expand Down
76 changes: 46 additions & 30 deletions xarray/core/pycompat.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,63 @@
from distutils.version import LooseVersion
from importlib import import_module

import numpy as np

from .utils import is_duck_array

integer_types = (int, np.integer)

try:
import dask
import dask.array
from dask.base import is_dask_collection

dask_version = LooseVersion(dask.__version__)
class DuckArrayModule:
"""
Solely for internal isinstance and version checks.

# solely for isinstance checks
dask_array_type = (dask.array.Array,)
Motivated by having to only import pint when required (as pint currently imports xarray)
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
"""

def is_duck_dask_array(x):
return is_duck_array(x) and is_dask_collection(x)
def __init__(self, mod):
try:
duck_array_module = import_module(mod)
duck_array_version = LooseVersion(duck_array_module.__version__)

if mod == "dask":
duck_array_type = (import_module("dask.array").Array,)
elif mod == "pint":
duck_array_type = (duck_array_module.Quantity,)
elif mod == "cupy":
duck_array_type = (duck_array_module.ndarray,)
elif mod == "sparse":
duck_array_type = (duck_array_module.SparseArray,)
else:
raise NotImplementedError

except ImportError: # pragma: no cover
duck_array_module = None
duck_array_version = LooseVersion("0.0.0")
duck_array_type = ()

self.module = duck_array_module
self.version = duck_array_version
self.type = duck_array_type
self.available = duck_array_module is not None

except ImportError: # pragma: no cover
dask_version = LooseVersion("0.0.0")
dask_array_type = ()
is_duck_dask_array = lambda _: False
is_dask_collection = lambda _: False

try:
# solely for isinstance checks
import sparse
def is_duck_dask_array(x):
if DuckArrayModule("dask").available:
from dask.base import is_dask_collection

return is_duck_array(x) and is_dask_collection(x)
else:
return False


sparse_version = LooseVersion(sparse.__version__)
sparse_array_type = (sparse.SparseArray,)
except ImportError: # pragma: no cover
sparse_version = LooseVersion("0.0.0")
sparse_array_type = ()
dsk = DuckArrayModule("dask")
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
dask_version = dsk.version
dask_array_type = dsk.type

try:
# solely for isinstance checks
import cupy
sp = DuckArrayModule("sparse")
sparse_array_type = sp.type
sparse_version = sp.version

cupy_version = LooseVersion(cupy.__version__)
cupy_array_type = (cupy.ndarray,)
except ImportError: # pragma: no cover
cupy_version = LooseVersion("0.0.0")
cupy_array_type = ()
cupy_array_type = DuckArrayModule("cupy").type
27 changes: 26 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable
from .options import _get_keep_attrs
from .pycompat import (
DuckArrayModule,
cupy_array_type,
dask_array_type,
integer_types,
is_duck_dask_array,
sparse_array_type,
)
from .utils import (
NdimSizeLenMixin,
Expand Down Expand Up @@ -259,7 +261,7 @@ def _as_array_or_item(data):

TODO: remove this (replace with np.asarray) once these issues are fixed
"""
data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data)
data = np.asarray(data)
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
if data.ndim == 0:
if data.dtype.kind == "M":
data = np.datetime64(data, "ns")
Expand Down Expand Up @@ -1069,6 +1071,29 @@ def chunk(self, chunks={}, name=None, lock=False):

return self._replace(data=data)

def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
# TODO an entrypoint so array libraries can choose coercion method?
data = self.data
# TODO first attempt to call .to_numpy() once some libraries implement it
if isinstance(data, dask_array_type):
data = data.compute()
if isinstance(data, cupy_array_type):
data = data.get()
# pint has to be imported dynamically as pint imports xarray
pint_array_type = DuckArrayModule("pint").type
if isinstance(data, pint_array_type):
data = data.magnitude
if isinstance(data, sparse_array_type):
data = data.todense()
data = np.asarray(data)

return data

def as_numpy(self: VariableType) -> VariableType:
"""Coerces wrapped data into a numpy array, returning a Variable."""
return self._replace(data=self.to_numpy())

def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
"""
use sparse-array as backend.
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def LooseVersion(vstring):
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cupy, requires_cupy = _importorskip("cupy")
has_cartopy, requires_cartopy = _importorskip("cartopy")
# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays
has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15")
Expand Down
86 changes: 86 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
has_dask,
raise_if_dask_computes,
requires_bottleneck,
requires_cupy,
requires_dask,
requires_iris,
requires_numbagg,
requires_numexpr,
requires_pint_0_15,
requires_scipy,
requires_sparse,
source_ndarray,
Expand Down Expand Up @@ -7375,3 +7377,87 @@ def test_drop_duplicates(keep):
expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test")
result = ds.drop_duplicates("time", keep=keep)
assert_equal(expected, result)


class TestNumpyCoercion:
# TODO once flexible indexes refactor complete also test coercion of dimension coords
def test_from_numpy(self):
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})

assert_identical(da.as_numpy(), da)
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))

@requires_dask
def test_from_dask(self):
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
da_chunked = da.chunk(1)

assert_identical(da_chunked.as_numpy(), da.compute())
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))

@requires_pint_0_15
def test_from_pint(self):
from pint import Quantity

arr = np.array([1, 2, 3])
da = xr.DataArray(
Quantity(arr, units="Pa"),
dims="x",
coords={"lat": ("x", Quantity(arr + 3, units="m"))},
)

expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
assert_identical(da.as_numpy(), expected)
np.testing.assert_equal(da.to_numpy(), arr)
np.testing.assert_equal(da["lat"].to_numpy(), arr + 3)

@requires_sparse
def test_from_sparse(self):
import sparse

arr = np.diagflat([1, 2, 3])
sparr = sparse.COO.from_numpy(arr)
da = xr.DataArray(
sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)}
)

expected = xr.DataArray(
arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)}
)
assert_identical(da.as_numpy(), expected)
np.testing.assert_equal(da.to_numpy(), arr)

@requires_cupy
def test_from_cupy(self):
import cupy as cp

arr = np.array([1, 2, 3])
da = xr.DataArray(
cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))}
)

expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
assert_identical(da.as_numpy(), expected)
np.testing.assert_equal(da.to_numpy(), arr)

@requires_dask
@requires_pint_0_15
def test_from_pint_wrapping_dask(self):
import dask
from pint import Quantity

arr = np.array([1, 2, 3])
d = dask.array.from_array(arr)
da = xr.DataArray(
Quantity(d, units="Pa"),
dims="x",
coords={"lat": ("x", Quantity(d, units="m") * 2)},
)

result = da.as_numpy()
result.name = None # remove dask-assigned name
expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)})
assert_identical(result, expected)
np.testing.assert_equal(da.to_numpy(), arr)
Loading