Merge pull request #25 from jhamman/loader/torch
Add pytorch dataloader
Joe Hamman authored Feb 23, 2022
2 parents 802bbd5 + 8bcd870 commit 3af1306
Showing 14 changed files with 341 additions and 19 deletions.
15 changes: 14 additions & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
- id: double-quote-string-fixer

- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.1.0
hooks:
- id: black
args: ["--line-length", "80", "--skip-string-normalization"]
@@ -37,3 +37,16 @@ repos:
hooks:
- id: prettier
language_version: system

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-setuptools,
types-pkg_resources,
# Dependencies that are typed
numpy,
xarray,
]
1 change: 1 addition & 0 deletions conftest.py
@@ -1,3 +1,4 @@
# type: ignore
import pytest


2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -1,4 +1,6 @@
pytest
torch
coverage
pytest-cov
adlfs
-r requirements.txt
20 changes: 14 additions & 6 deletions doc/api.rst
@@ -5,12 +5,6 @@ API reference

This page provides an auto-generated summary of Xbatcher's API.

Core
====

.. autoclass:: xbatcher.BatchGenerator
:members:

Dataset.batch and DataArray.batch
=================================

@@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch

Dataset.batch.generator
DataArray.batch.generator

Core
====

.. autoclass:: xbatcher.BatchGenerator
:members:

Dataloaders
===========
.. autoclass:: xbatcher.loaders.torch.MapDataset
:members:

.. autoclass:: xbatcher.loaders.torch.IterableDataset
:members:
2 changes: 2 additions & 0 deletions doc/conf.py
@@ -12,6 +12,8 @@
# All configuration values have a default; values that are commented out
# serve to show the default.

# type: ignore

import os
import sys

2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xbatcher
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
1 change: 1 addition & 0 deletions setup.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# type: ignore
import os

from setuptools import find_packages, setup
18 changes: 18 additions & 0 deletions xbatcher/accessors.py
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
Keyword arguments to pass to the `BatchGenerator` constructor.
'''
return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor('torch')
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

return torch.tensor(self._obj.data)

def to_named_tensor(self):
"""Convert this DataArray to a torch.Tensor with named dimensions"""
import torch

return torch.tensor(self._obj.data, names=self._obj.dims)
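Note (not part of the diff): a minimal sketch of how the new torch accessor might be used, assuming that importing xbatcher registers the accessor, as it already does for the existing batch accessor. The sample array, dimension names, and shape below are purely illustrative.

import numpy as np
import xarray as xr
import xbatcher  # noqa: F401  (assumed to register the .torch accessor)

# Illustrative DataArray; name, dims, and shape are arbitrary.
da = xr.DataArray(
    np.random.rand(4, 3), dims=('sample', 'feature'), name='foo'
)

t = da.torch.to_tensor()         # plain torch.Tensor with shape (4, 3)
nt = da.torch.to_named_tensor()  # tensor with names ('sample', 'feature')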
56 changes: 45 additions & 11 deletions xbatcher/generators.py
@@ -2,6 +2,7 @@

import itertools
from collections import OrderedDict
from typing import Any, Dict, Hashable, Iterator

import xarray as xr

@@ -99,12 +100,12 @@ class BatchGenerator:

def __init__(
self,
ds,
input_dims,
input_overlap={},
batch_dims={},
concat_input_dims=False,
preload_batch=True,
ds: xr.Dataset,
input_dims: Dict[Hashable, int],
input_overlap: Dict[Hashable, int] = {},
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
):

self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch

def __iter__(self):
self._batches: Dict[
int, Any
] = self._gen_batches() # dict cache for batches
# in the future, we can make this a lru cache or similar thing (cachey?)

def __iter__(self) -> Iterator[xr.Dataset]:
for batch in self._batches.values():
yield batch

def __len__(self) -> int:
return len(self._batches)

def __getitem__(self, idx: int) -> xr.Dataset:

if not isinstance(idx, int):
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

if idx < 0:
idx = list(self._batches)[idx]

if idx in self._batches:
return self._batches[idx]
else:
raise IndexError('list index out of range')

def _gen_batches(self) -> dict:
# in the future, we will want to do the batch generation lazily
# going the eager route for now is allowing me to fill out the loader api
# but it is likely to perform poorly.
batches = []
for ds_batch in self._iterate_batch_dims(self.ds):
if self.preload_batch:
ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
]
dsc = xr.concat(all_dsets, dim='input_batch')
new_input_dims = [
dim + new_dim_suffix for dim in self.input_dims
str(dim) + new_dim_suffix for dim in self.input_dims
]
yield _maybe_stack_batch_dims(dsc, new_input_dims)
batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
else:
for ds_input in input_generator:
yield _maybe_stack_batch_dims(
ds_input, list(self.input_dims)
batches.append(
_maybe_stack_batch_dims(ds_input, list(self.input_dims))
)

return dict(zip(range(len(batches)), batches))

def _iterate_batch_dims(self, ds):
return _iterate_through_dataset(ds, self.batch_dims)

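Note (not part of the diff): batches are now generated eagerly in _gen_batches() and cached in a dict, which is what makes len() and integer indexing (including negative indices) possible alongside plain iteration. A minimal sketch with an illustrative one-dimensional dataset, mirroring the new tests below:

import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

ds = xr.Dataset({'foo': (('x',), np.arange(100.0))})  # illustrative data
bg = BatchGenerator(ds, input_dims={'x': 10})

assert len(bg) == 10             # batches are pre-generated and cached
first, last = bg[0], bg[-1]      # integer indexing; negative indices wrap around
for batch in bg:                 # iteration behaves as before
    assert batch.dims['x'] == 10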
Empty file added xbatcher/loaders/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions xbatcher/loaders/torch.py
@@ -0,0 +1,88 @@
from typing import Any, Callable, Optional, Tuple

import torch

# Notes:
# This module includes two PyTorch datasets.
# - The MapDataset provides an indexable interface
# - The IterableDataset provides a simple iterable interface
# Both can be provided as arguments to the Torch DataLoader
# Assumptions made:
# - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
# TODOs:
# - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
# - need to test with additional dataset parameters (e.g. transforms)


class MapDataset(torch.utils.data.Dataset):
def __init__(
self,
X_generator,
y_generator,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher
Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
transform : callable, optional
A function/transform that takes in an array and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
'''
self.X_generator = X_generator
self.y_generator = y_generator
self.transform = transform
self.target_transform = target_transform

def __len__(self) -> int:
return len(self.X_generator)

def __getitem__(self, idx) -> Tuple[Any, Any]:
if torch.is_tensor(idx):
idx = idx.tolist()
if len(idx) == 1:
idx = idx[0]
else:
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
)

# TODO: figure out the dataset -> array workflow
# currently hardcoding a variable name
X_batch = self.X_generator[idx]['x'].torch.to_tensor()
y_batch = self.y_generator[idx]['y'].torch.to_tensor()

if self.transform:
X_batch = self.transform(X_batch)

if self.target_transform:
y_batch = self.target_transform(y_batch)
return X_batch, y_batch


class IterableDataset(torch.utils.data.IterableDataset):
def __init__(
self,
X_generator,
y_generator,
) -> None:
'''
PyTorch Dataset adapter for Xbatcher
Parameters
----------
X_generator : xbatcher.BatchGenerator
y_generator : xbatcher.BatchGenerator
'''

self.X_generator = X_generator
self.y_generator = y_generator

def __iter__(self):
for xb, yb in zip(self.X_generator, self.y_generator):
yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
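Note (not part of the diff): a rough end-to-end sketch of wiring the adapters to a torch.utils.data.DataLoader. The dataset construction, dimension names, and batch sizes are illustrative assumptions; the variable names 'x' and 'y' are required by the hardcoded lookup noted in the TODO above, and batch_size=None is used because each xbatcher batch already has its final size.

import numpy as np
import torch
import xarray as xr
from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Illustrative training data: 1000 samples, 10 features, one target per sample.
ds = xr.Dataset(
    {
        'x': (('sample', 'feature'), np.random.rand(1000, 10)),
        'y': (('sample',), np.random.rand(1000)),
    }
)
X_gen = BatchGenerator(ds[['x']], input_dims={'sample': 100})
y_gen = BatchGenerator(ds[['y']], input_dims={'sample': 100})

dataset = MapDataset(X_gen, y_gen)
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for X_batch, y_batch in loader:
    # e.g. torch.Size([100, 10]) torch.Size([100])
    print(X_batch.shape, y_batch.shape)
    break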
22 changes: 22 additions & 0 deletions xbatcher/tests/test_accessors.py
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
assert isinstance(bg_acc, BatchGenerator)
for batch_class, batch_acc in zip(bg_class, bg_acc):
assert batch_class.equals(batch_acc)


def test_torch_to_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == (None, None, None)
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)


def test_torch_to_named_tensor(sample_ds_3d):
torch = pytest.importorskip('torch')

da = sample_ds_3d['foo']
t = da.torch.to_named_tensor()
assert isinstance(t, torch.Tensor)
assert t.names == da.dims
assert t.shape == da.shape
np.testing.assert_array_equal(t, da.values)
22 changes: 22 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -41,6 +41,28 @@ def test_constructor_coerces_to_dataset():
assert bg.ds.equals(da.to_dataset())


@pytest.mark.parametrize('bsize', [5, 6])
def test_batcher_length(sample_ds_1d, bsize):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
assert len(bg) == sample_ds_1d.dims['x'] // bsize


def test_batcher_getitem(sample_ds_1d):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})

# first batch
assert bg[0].dims['x'] == 10
# last batch
assert bg[-1].dims['x'] == 10
# raises IndexError for out of range index
with pytest.raises(IndexError, match=r'list index out of range'):
bg[9999999]

# raises NotImplementedError for iterable index
with pytest.raises(NotImplementedError):
bg[[1, 2, 3]]


# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
# Should we enforce that each batch size always has to be the same
@pytest.mark.parametrize('bsize', [5, 10])