diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py
index 43a7d9c..a9d19be 100644
--- a/xbatcher/accessors.py
+++ b/xbatcher/accessors.py
@@ -1,8 +1,21 @@
+from typing import Union
+
 import xarray as xr
 
 from .generators import BatchGenerator
 
 
+def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
+    """
+    Convert xarray.Dataset to xarray.DataArray if needed, so that it can
+    be converted into a Tensor object.
+    """
+    if isinstance(xr_obj, xr.Dataset):
+        xr_obj = xr_obj.to_array().squeeze(dim="variable")
+
+    return xr_obj
+
+
 @xr.register_dataarray_accessor("batch")
 @xr.register_dataset_accessor("batch")
 class BatchAccessor:
@@ -26,31 +39,32 @@ def generator(self, *args, **kwargs):
         return BatchGenerator(self._obj, *args, **kwargs)
 
 
+@xr.register_dataarray_accessor("tf")
+@xr.register_dataset_accessor("tf")
+class TFAccessor:
+    def __init__(self, xarray_obj):
+        self._obj = xarray_obj
+
+    def to_tensor(self):
+        """Convert this DataArray to a tensorflow.Tensor"""
+        import tensorflow as tf
+
+        dataarray = _as_xarray_dataarray(xr_obj=self._obj)
+
+        return tf.convert_to_tensor(dataarray.data)
+
+
 @xr.register_dataarray_accessor("torch")
 @xr.register_dataset_accessor("torch")
 class TorchAccessor:
     def __init__(self, xarray_obj):
         self._obj = xarray_obj
 
-    def _as_xarray_dataarray(self, xr_obj):
-        """
-        Convert xarray.Dataset to xarray.DataArray if needed, so that it can
-        be converted into a torch.Tensor object.
-        """
-        try:
-            # Convert xr.Dataset to xr.DataArray
-            dataarray = xr_obj.to_array().squeeze(dim="variable")
-        except AttributeError:  # 'DataArray' object has no attribute 'to_array'
-            # If object is already an xr.DataArray
-            dataarray = xr_obj
-
-        return dataarray
-
     def to_tensor(self):
         """Convert this DataArray to a torch.Tensor"""
         import torch
 
-        dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
+        dataarray = _as_xarray_dataarray(xr_obj=self._obj)
 
         return torch.tensor(data=dataarray.data)
 
@@ -62,6 +76,6 @@ def to_named_tensor(self):
         """
         import torch
 
-        dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
+        dataarray = _as_xarray_dataarray(xr_obj=self._obj)
 
         return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))
diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py
index 18d24e0..cb3b37e 100644
--- a/xbatcher/tests/test_accessors.py
+++ b/xbatcher/tests/test_accessors.py
@@ -22,6 +22,30 @@ def sample_ds_3d():
     return ds
 
 
+@pytest.fixture(scope="module")
+def sample_dataArray():
+    return xr.DataArray(np.zeros((2, 4), dtype="i4"), dims=("x", "y"), name="foo")
+
+
+@pytest.fixture(scope="module")
+def sample_Dataset():
+    return xr.Dataset(
+        {
+            "x": xr.DataArray(np.arange(10), dims="x"),
+            "foo": xr.DataArray(np.ones(10, dtype="float"), dims="x"),
+        }
+    )
+
+
+def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
+    assert isinstance(
+        xbatcher.accessors._as_xarray_dataarray(sample_dataArray), xr.DataArray
+    )
+    assert isinstance(
+        xbatcher.accessors._as_xarray_dataarray(sample_Dataset), xr.DataArray
+    )
+
+
 def test_batch_accessor_ds(sample_ds_3d):
     bg_class = BatchGenerator(sample_ds_3d, input_dims={"x": 5})
     bg_acc = sample_ds_3d.batch.generator(input_dims={"x": 5})
@@ -40,6 +64,25 @@ def test_batch_accessor_da(sample_ds_3d):
     assert batch_class.equals(batch_acc)
 
 
+@pytest.mark.parametrize(
+    "foo_var",
+    [
+        "foo",  # xr.DataArray
+        ["foo"],  # xr.Dataset
+    ],
+)
+def test_tf_to_tensor(sample_ds_3d, foo_var):
+    tf = pytest.importorskip("tensorflow")
+
+    foo = sample_ds_3d[foo_var]
+    t = foo.tf.to_tensor()
+    assert isinstance(t, tf.Tensor)
+    assert t.shape == tuple(foo.sizes.values())
+
+    foo_array = foo.to_array().squeeze() if hasattr(foo, "to_array") else foo
+    np.testing.assert_array_equal(t, foo_array.values)
+
+
 @pytest.mark.parametrize(
     "foo_var",
     [