Add PyTorch dataloader #25

Merged (9 commits), Feb 23, 2022
Changes from 4 commits
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -33,3 +33,16 @@ repos:
hooks:
- id: prettier
language_version: system

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-setuptools,
types-pkg_resources,
# Dependencies that are typed
numpy,
xarray,
]
1 change: 1 addition & 0 deletions conftest.py
@@ -1,3 +1,4 @@
# type: ignore
import pytest


1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,3 +1,4 @@
pytest
torch
coverage
-r requirements.txt
2 changes: 2 additions & 0 deletions doc/conf.py
@@ -12,6 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.

# type: ignore

import os
import sys

2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xbatcher
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
1 change: 1 addition & 0 deletions setup.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# type: ignore
import os

from setuptools import find_packages, setup
18 changes: 18 additions & 0 deletions xbatcher/accessors.py
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
Keyword arguments to pass to the `BatchGenerator` constructor.
'''
return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor('torch')
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

return torch.tensor(self._obj.data)

def to_named_tensor(self):
"""Convert this DataArray to a torch.Tensor with named dimensions"""
import torch

return torch.tensor(self._obj.data, names=self._obj.dims)
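For orientation, here is a minimal usage sketch of the new accessor, using a small synthetic DataArray and assuming that importing xbatcher registers the torch accessor (as the accessor tests below rely on):

import numpy as np
import xarray as xr

import xbatcher  # noqa: F401 - importing registers the 'torch' accessor

# small synthetic DataArray, purely for illustration
da = xr.DataArray(
    np.random.random((4, 3)),
    dims=('sample', 'feature'),
    name='foo',
)

t = da.torch.to_tensor()         # plain torch.Tensor with shape (4, 3)
nt = da.torch.to_named_tensor()  # same data, with names ('sample', 'feature')
print(t.shape, nt.names)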
56 changes: 45 additions & 11 deletions xbatcher/generators.py
@@ -2,6 +2,7 @@

import itertools
from collections import OrderedDict
from typing import Any, Dict, Hashable, Iterator

import xarray as xr

@@ -99,12 +100,12 @@ class BatchGenerator:

def __init__(
self,
ds,
input_dims,
input_overlap={},
batch_dims={},
concat_input_dims=False,
preload_batch=True,
ds: xr.Dataset,
input_dims: Dict[Hashable, int],
input_overlap: Dict[Hashable, int] = {},
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
):

self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch

def __iter__(self):
self._batches: Dict[
int, Any
] = self._gen_batches() # dict cache for batches
# in the future, we can make this a lru cache or similar thing (cachey?)

def __iter__(self) -> Iterator[xr.Dataset]:
for batch in self._batches.values():
yield batch

def __len__(self) -> int:
return len(self._batches)

def __getitem__(self, idx: int) -> xr.Dataset:

if not isinstance(idx, int):
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

if idx < 0:
idx = list(self._batches)[idx]

if idx in self._batches:
return self._batches[idx]
else:
raise IndexError('list index out of range')

def _gen_batches(self) -> dict:
# in the future, we will want to do the batch generation lazily
# going the eager route for now is allowing me to fill out the loader api
# but it is likely to perform poorly.
Review comment from @jhamman (Contributor, Author), Aug 12, 2021, on lines +146 to +149:

Flagging this as something to discuss / work out a design for. It feels quite important that we are able to generate arbitrary batches on the fly. The current implementation eagerly generates batches, which will not scale well. However, the pure generator approach doesn't work if you need to randomly access batches (e.g. via __getitem__).

batches = []
for ds_batch in self._iterate_batch_dims(self.ds):
if self.preload_batch:
ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
]
dsc = xr.concat(all_dsets, dim='input_batch')
new_input_dims = [
dim + new_dim_suffix for dim in self.input_dims
str(dim) + new_dim_suffix for dim in self.input_dims
]
yield _maybe_stack_batch_dims(dsc, new_input_dims)
batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
else:
for ds_input in input_generator:
yield _maybe_stack_batch_dims(
ds_input, list(self.input_dims)
batches.append(
_maybe_stack_batch_dims(ds_input, list(self.input_dims))
)

return dict(zip(range(len(batches)), batches))

def _iterate_batch_dims(self, ds):
return _iterate_through_dataset(ds, self.batch_dims)

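For orientation, a minimal sketch of the new sequence-style interface on BatchGenerator, using a synthetic 1-D dataset modeled on the sample_ds_1d test fixture in xbatcher/tests/test_generators.py:

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

# illustrative 1-D dataset, similar to the sample_ds_1d test fixture
ds = xr.Dataset(
    {'foo': (['x'], np.random.random(100))},
    coords={'x': np.arange(100)},
)

bg = BatchGenerator(ds, input_dims={'x': 10})

print(len(bg))    # 10 batches of 10 samples each
first = bg[0]     # random access via the new __getitem__
last = bg[-1]     # negative indices resolve to the last cached batch
for batch in bg:  # iteration still yields xr.Dataset objects
    assert batch.dims['x'] == 10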
Empty file added xbatcher/loaders/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions xbatcher/loaders/torch.py
@@ -0,0 +1,83 @@
from typing import Any, Callable, Optional, Tuple

import torch

# Notes:
# This module includes two PyTorch datasets.
# - The MapDataset provides an indexable interface
# - The IterableDataset provides a simple iterable interface
# Both can be provided as arguments to the Torch DataLoader
# Assumptions made:
# - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
# TODOs:
# - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
# - need to test with additional dataset parameters (e.g. transforms)


class MapDataset(torch.utils.data.Dataset):
def __init__(
self,
X_generator,
y_generator,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
transform : callable, optional
A function/transform that takes in an array and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
'''
self.X_generator = X_generator
self.y_generator = y_generator
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self.X_generator)

def __getitem__(self, idx) -> Tuple[Any, Any]:
if torch.is_tensor(idx):
idx = idx.tolist()
assert len(idx) == 1

# TODO: figure out the dataset -> array workflow
# currently hardcoding a variable name
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
y_batch = self.y_generator[idx]['y'].torch.to_tensor()
Review comment from @jhamman (Contributor, Author):

Flagging that we can't use named tensors here while we wait for pytorch/pytorch#29010.


if self.transform:
X_batch = self.transform(X_batch)

if self.target_transform:
y_batch = self.target_transform(y_batch)
return X_batch, y_batch


class IterableDataset(torch.utils.data.IterableDataset):
def __init__(
self,
X_generator,
y_generator,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher

Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
'''

self.X_generator = X_generator
self.y_generator = y_generator

def __iter__(self):
for xb, yb in zip(self.X_generator, self.y_generator):
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
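As a rough usage sketch (mirroring the ds_xy test fixture added below), the MapDataset adapter can be handed directly to a standard torch DataLoader:

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# synthetic feature/target dataset, modeled on the ds_xy test fixture
ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

x_gen = BatchGenerator(ds['x'], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})

dataset = MapDataset(x_gen, y_gen)
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
    # the DataLoader prepends its own batch dimension (length 1 here)
    print(x_batch.shape, y_batch.shape)  # torch.Size([1, 10, 5]) torch.Size([1, 10])
    break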
22 changes: 22 additions & 0 deletions xbatcher/tests/test_accessors.py
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
assert isinstance(bg_acc, BatchGenerator)
for batch_class, batch_acc in zip(bg_class, bg_acc):
assert batch_class.equals(batch_acc)


def test_torch_to_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == (None, None, None)
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)


def test_torch_to_named_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_named_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == da.dims
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)
22 changes: 22 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -18,6 +18,28 @@ def sample_ds_1d():
return ds


@pytest.mark.parametrize('bsize', [5, 6])
def test_batcher_length(sample_ds_1d, bsize):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
assert len(bg) == sample_ds_1d.dims['x'] // bsize


def test_batcher_getitem(sample_ds_1d):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

# first batch
assert bg[0].dims['x'] == 10
# last batch
assert bg[-1].dims['x'] == 10
# raises IndexError for out of range index
with pytest.raises(IndexError, match=r'list index out of range'):
bg[9999999]

# raises NotImplementedError for iterable index
with pytest.raises(NotImplementedError):
bg[[1, 2, 3]]


# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
# Should we enforce that each batch size always has to be the same
@pytest.mark.parametrize('bsize', [5, 10])
78 changes: 78 additions & 0 deletions xbatcher/tests/test_torch_loaders.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest
import xarray as xr

torch = pytest.importorskip('torch')

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import IterableDataset, MapDataset


@pytest.fixture(scope='module')
def ds_xy():
n_samples = 100
n_features = 5
ds = xr.Dataset(
{
'x': (
['sample', 'feature'],
np.random.random((n_samples, n_features)),
),
'y': (['sample'], np.random.random(n_samples)),
},
)
return ds


def test_map_dataset(ds_xy):

x = ds_xy['x']
y = ds_xy['y']

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})

dataset = MapDataset(x_gen, y_gen)

# test __getitem__
x_batch, y_batch = dataset[0]
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, torch.Tensor)

# test __len__
assert len(dataset) == len(x_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])


def test_iterable_dataset(ds_xy):

x = ds_xy['x']
y = ds_xy['y']

x_gen = BatchGenerator(x, {'sample': 10})
y_gen = BatchGenerator(y, {'sample': 10})

dataset = IterableDataset(x_gen, y_gen)

# test integration with torch DataLoader
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
assert len(x_batch) == len(y_batch)
assert isinstance(x_batch, torch.Tensor)

# TODO: why does pytorch add an extra dimension (length 1) to x_batch
assert x_gen[-1]['x'].shape == x_batch.shape[1:]
# TODO: also need to revisit the variable extraction bits here
assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])