Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs #85

Merged
merged 4 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions xbatcher/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,32 @@ def generator(self, *args, **kwargs):


@xr.register_dataarray_accessor('torch')
@xr.register_dataset_accessor('torch')
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _as_xarray_dataarray(self, xr_obj):
"""
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
be converted into a torch.Tensor object.
"""
try:
# Convert xr.Dataset to xr.DataArray
dataarray = xr_obj.to_array().squeeze(dim='variable')
except AttributeError: # 'DataArray' object has no attribute 'to_array'
# If object is already an xr.DataArray
dataarray = xr_obj

return dataarray

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

return torch.tensor(self._obj.data)
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data)

def to_named_tensor(self):
"""
Expand All @@ -45,4 +62,6 @@ def to_named_tensor(self):
"""
import torch

return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))
dataarray = self._as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))
40 changes: 29 additions & 11 deletions xbatcher/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,41 @@ def test_batch_accessor_da(sample_ds_3d):
assert batch_class.equals(batch_acc)


def test_torch_to_tensor(sample_ds_3d):
@pytest.mark.parametrize(
'foo_var',
[
'foo', # xr.DataArray
['foo'], # xr.Dataset
],
)
def test_torch_to_tensor(sample_ds_3d, foo_var):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_tensor()
foo = sample_ds_3d[foo_var]
t = foo.torch.to_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == (None, None, None)
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)
assert t.shape == tuple(foo.sizes.values())

foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
np.testing.assert_array_equal(t, foo_array.values)

def test_torch_to_named_tensor(sample_ds_3d):

@pytest.mark.parametrize(
'foo_var',
[
'foo', # xr.DataArray
['foo'], # xr.Dataset
],
)
def test_torch_to_named_tensor(sample_ds_3d, foo_var):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_named_tensor()
foo = sample_ds_3d[foo_var]
t = foo.torch.to_named_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == da.dims
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)
assert t.names == tuple(foo.dims)
assert t.shape == tuple(foo.sizes.values())

foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
np.testing.assert_array_equal(t, foo_array.values)
79 changes: 55 additions & 24 deletions xbatcher/tests/test_torch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ def ds_xy():
return ds


def test_map_dataset(ds_xy):

x = ds_xy['x']
y = ds_xy['y']
@pytest.mark.parametrize(
('x_var', 'y_var'),
[
('x', 'y'), # xr.DataArray
(['x'], ['y']), # xr.Dataset
],
)
def test_map_dataset(ds_xy, x_var, y_var):

x = ds_xy[x_var]
y = ds_xy[y_var]

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})
Expand All @@ -54,23 +61,35 @@ def test_map_dataset(ds_xy):
assert len(dataset) == len(x_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for x_batch, y_batch in loader:
assert x_batch.shape == (1, 10, 5)
assert y_batch.shape == (1, 10)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1].shape == x_batch.shape[1:]
# TODO: add test for xarray.Dataset
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
# Check that array shape of last item in generator is same as the batch image
assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
# Check that array values from last item in generator and batch are the same
gen_array = (
x_gen[-1].to_array().squeeze()
if hasattr(x_gen[-1], 'to_array')
else x_gen[-1]
)
np.testing.assert_array_equal(gen_array, x_batch)


def test_map_dataset_with_transform(ds_xy):
@pytest.mark.parametrize(
('x_var', 'y_var'),
[
('x', 'y'), # xr.DataArray
(['x'], ['y']), # xr.Dataset
],
)
def test_map_dataset_with_transform(ds_xy, x_var, y_var):

x = ds_xy['x']
y = ds_xy['y']
x = ds_xy[x_var]
y = ds_xy[y_var]

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})
Expand All @@ -92,25 +111,37 @@ def y_transform(batch):
assert (y_batch == -1).all()


def test_iterable_dataset(ds_xy):
@pytest.mark.parametrize(
('x_var', 'y_var'),
[
('x', 'y'), # xr.DataArray
(['x'], ['y']), # xr.Dataset
],
)
def test_iterable_dataset(ds_xy, x_var, y_var):

x = ds_xy['x']
y = ds_xy['y']
x = ds_xy[x_var]
y = ds_xy[y_var]

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})

dataset = IterableDataset(x_gen, y_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for x_batch, y_batch in loader:
assert x_batch.shape == (1, 10, 5)
assert y_batch.shape == (1, 10)
assert x_batch.shape == (10, 5)
assert y_batch.shape == (10,)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1].shape == x_batch.shape[1:]
# TODO: add test for xarray.Dataset
assert np.array_equal(x_gen[-1], x_batch[0, :, :])
# Check that array shape of last item in generator is same as the batch image
assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
# Check that array values from last item in generator and batch are the same
gen_array = (
x_gen[-1].to_array().squeeze()
if hasattr(x_gen[-1], 'to_array')
else x_gen[-1]
)
np.testing.assert_array_equal(gen_array, x_batch)