diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py
index 38d247c..229ae69 100644
--- a/xbatcher/accessors.py
+++ b/xbatcher/accessors.py
@@ -27,15 +27,32 @@ def generator(self, *args, **kwargs):
 
 
 @xr.register_dataarray_accessor('torch')
+@xr.register_dataset_accessor('torch')
 class TorchAccessor:
     def __init__(self, xarray_obj):
         self._obj = xarray_obj
 
+    def _as_xarray_dataarray(self, xr_obj):
+        """
+        Convert xarray.Dataset to xarray.DataArray if needed, so that it can
+        be converted into a torch.Tensor object.
+        """
+        try:
+            # Convert xr.Dataset to xr.DataArray
+            dataarray = xr_obj.to_array().squeeze(dim='variable')
+        except AttributeError:  # 'DataArray' object has no attribute 'to_array'
+            # If object is already an xr.DataArray
+            dataarray = xr_obj
+
+        return dataarray
+
     def to_tensor(self):
         """Convert this DataArray to a torch.Tensor"""
         import torch
 
-        return torch.tensor(self._obj.data)
+        dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
+
+        return torch.tensor(data=dataarray.data)
 
     def to_named_tensor(self):
         """
@@ -45,4 +62,6 @@ def to_named_tensor(self):
         """
         import torch
 
-        return torch.tensor(self._obj.data, names=tuple(self._obj.sizes))
+        dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
+
+        return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))
diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py
index 4860803..39a421b 100644
--- a/xbatcher/tests/test_accessors.py
+++ b/xbatcher/tests/test_accessors.py
@@ -40,23 +40,41 @@ def test_batch_accessor_da(sample_ds_3d):
         assert batch_class.equals(batch_acc)
 
 
-def test_torch_to_tensor(sample_ds_3d):
+@pytest.mark.parametrize(
+    'foo_var',
+    [
+        'foo',  # xr.DataArray
+        ['foo'],  # xr.Dataset
+    ],
+)
+def test_torch_to_tensor(sample_ds_3d, foo_var):
     torch = pytest.importorskip('torch')
 
-    da = sample_ds_3d['foo']
-    t = da.torch.to_tensor()
+    foo = sample_ds_3d[foo_var]
+    t = foo.torch.to_tensor()
     assert isinstance(t, torch.Tensor)
     assert t.names == (None, None, None)
-    assert t.shape == da.shape
-    np.testing.assert_array_equal(t, da.values)
+    assert t.shape == tuple(foo.sizes.values())
 
+    foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
+    np.testing.assert_array_equal(t, foo_array.values)
 
-def test_torch_to_named_tensor(sample_ds_3d):
+
+@pytest.mark.parametrize(
+    'foo_var',
+    [
+        'foo',  # xr.DataArray
+        ['foo'],  # xr.Dataset
+    ],
+)
+def test_torch_to_named_tensor(sample_ds_3d, foo_var):
     torch = pytest.importorskip('torch')
 
-    da = sample_ds_3d['foo']
-    t = da.torch.to_named_tensor()
+    foo = sample_ds_3d[foo_var]
+    t = foo.torch.to_named_tensor()
     assert isinstance(t, torch.Tensor)
-    assert t.names == da.dims
-    assert t.shape == da.shape
-    np.testing.assert_array_equal(t, da.values)
+    assert t.names == tuple(foo.dims)
+    assert t.shape == tuple(foo.sizes.values())
+
+    foo_array = foo.to_array().squeeze() if hasattr(foo, 'to_array') else foo
+    np.testing.assert_array_equal(t, foo_array.values)
diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py
index bbde292..f95ad9d 100644
--- a/xbatcher/tests/test_torch_loaders.py
+++ b/xbatcher/tests/test_torch_loaders.py
@@ -24,10 +24,17 @@ def ds_xy():
     return ds
 
 
-def test_map_dataset(ds_xy):
-
-    x = ds_xy['x']
-    y = ds_xy['y']
+@pytest.mark.parametrize(
+    ('x_var', 'y_var'),
+    [
+        ('x', 'y'),  # xr.DataArray
+        (['x'], ['y']),  # xr.Dataset
+    ],
+)
+def test_map_dataset(ds_xy, x_var, y_var):
+
+    x = ds_xy[x_var]
+    y = ds_xy[y_var]
 
     x_gen = BatchGenerator(x, {'sample': 10})
     y_gen = BatchGenerator(y, {'sample': 10})
@@ -54,23 +61,35 @@ def test_map_dataset(ds_xy):
     assert len(dataset) == len(x_gen)
 
     # test integration with torch DataLoader
-    loader = torch.utils.data.DataLoader(dataset)
+    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
 
     for x_batch, y_batch in loader:
-        assert x_batch.shape == (1, 10, 5)
-        assert y_batch.shape == (1, 10)
+        assert x_batch.shape == (10, 5)
+        assert y_batch.shape == (10,)
         assert isinstance(x_batch, torch.Tensor)
 
-    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
-    assert x_gen[-1].shape == x_batch.shape[1:]
-    # TODO: add test for xarray.Dataset
-    assert np.array_equal(x_gen[-1], x_batch[0, :, :])
+    # Check that array shape of last item in generator is same as the batch image
+    assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
+    # Check that array values from last item in generator and batch are the same
+    gen_array = (
+        x_gen[-1].to_array().squeeze()
+        if hasattr(x_gen[-1], 'to_array')
+        else x_gen[-1]
+    )
+    np.testing.assert_array_equal(gen_array, x_batch)
 
 
-def test_map_dataset_with_transform(ds_xy):
+@pytest.mark.parametrize(
+    ('x_var', 'y_var'),
+    [
+        ('x', 'y'),  # xr.DataArray
+        (['x'], ['y']),  # xr.Dataset
+    ],
+)
+def test_map_dataset_with_transform(ds_xy, x_var, y_var):
 
-    x = ds_xy['x']
-    y = ds_xy['y']
+    x = ds_xy[x_var]
+    y = ds_xy[y_var]
 
     x_gen = BatchGenerator(x, {'sample': 10})
     y_gen = BatchGenerator(y, {'sample': 10})
@@ -92,10 +111,17 @@ def y_transform(batch):
         assert (y_batch == -1).all()
 
 
-def test_iterable_dataset(ds_xy):
+@pytest.mark.parametrize(
+    ('x_var', 'y_var'),
+    [
+        ('x', 'y'),  # xr.DataArray
+        (['x'], ['y']),  # xr.Dataset
+    ],
+)
+def test_iterable_dataset(ds_xy, x_var, y_var):
 
-    x = ds_xy['x']
-    y = ds_xy['y']
+    x = ds_xy[x_var]
+    y = ds_xy[y_var]
 
     x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})
@@ -103,14 +129,19 @@ def test_iterable_dataset(ds_xy, x_var, y_var):
     dataset = IterableDataset(x_gen, y_gen)
 
     # test integration with torch DataLoader
-    loader = torch.utils.data.DataLoader(dataset)
+    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
 
     for x_batch, y_batch in loader:
-        assert x_batch.shape == (1, 10, 5)
-        assert y_batch.shape == (1, 10)
+        assert x_batch.shape == (10, 5)
+        assert y_batch.shape == (10,)
         assert isinstance(x_batch, torch.Tensor)
 
-    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
-    assert x_gen[-1].shape == x_batch.shape[1:]
-    # TODO: add test for xarray.Dataset
-    assert np.array_equal(x_gen[-1], x_batch[0, :, :])
+    # Check that array shape of last item in generator is same as the batch image
+    assert tuple(x_gen[-1].sizes.values()) == x_batch.shape
+    # Check that array values from last item in generator and batch are the same
+    gen_array = (
+        x_gen[-1].to_array().squeeze()
+        if hasattr(x_gen[-1], 'to_array')
+        else x_gen[-1]
+    )
+    np.testing.assert_array_equal(gen_array, x_batch)
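
Usage sketch (not part of the patch above): with the torch accessor now registered on
both DataArray and Dataset objects, a single-variable Dataset can be converted directly
to a torch.Tensor. The Dataset construction below is a hypothetical example; only the
.torch.to_tensor() and .torch.to_named_tensor() calls come from the patch.

    import numpy as np
    import xarray as xr

    import xbatcher.accessors  # noqa: F401  importing the module registers the accessors

    # Hypothetical single-variable Dataset; any dimension names and shape would do
    ds = xr.Dataset({'air': (('time', 'y', 'x'), np.random.rand(4, 3, 2))})

    t = ds.torch.to_tensor()         # plain tensor with shape (4, 3, 2)
    nt = ds.torch.to_named_tensor()  # named tensor with names ('time', 'y', 'x')

Note that _as_xarray_dataarray only squeezes out the length-1 'variable' dimension that
Dataset.to_array() introduces, so a Dataset with exactly one data variable round-trips
cleanly; a multi-variable Dataset is expected to fail in squeeze(dim='variable'), since
xarray only squeezes dimensions of length one.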