From 1c5febf6fe41d13b61bb35fb76931c3aa3e53f72 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 11 Aug 2021 17:15:43 -0700 Subject: [PATCH 1/8] [loaders refactor] initial commit --- setup.cfg | 2 +- xbatcher/generators.py | 55 ++++++++++++++---- xbatcher/loaders/__init__.py | 0 xbatcher/loaders/torch.py | 84 ++++++++++++++++++++++++++++ xbatcher/tests/test_generators.py | 22 ++++++++ xbatcher/tests/test_torch_loaders.py | 78 ++++++++++++++++++++++++++ 6 files changed, 230 insertions(+), 11 deletions(-) create mode 100644 xbatcher/loaders/__init__.py create mode 100644 xbatcher/loaders/torch.py create mode 100644 xbatcher/tests/test_torch_loaders.py diff --git a/setup.cfg b/setup.cfg index 959a2fb..84899d3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9 [isort] known_first_party=xbatcher -known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray +known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 612be61..6208efd 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -2,6 +2,8 @@ import itertools from collections import OrderedDict +from collections.abc import Iterator +from typing import Any, Dict, Hashable import xarray as xr @@ -99,12 +101,12 @@ class BatchGenerator: def __init__( self, - ds, - input_dims, - input_overlap={}, - batch_dims={}, - concat_input_dims=False, - preload_batch=True, + ds: xr.Dataset, + input_dims: Dict[Hashable, int], + input_overlap: Dict[Hashable, int] = {}, + batch_dims: Dict[Hashable, int] = {}, + concat_input_dims: bool = False, + preload_batch: bool = True, ): self.ds = _as_xarray_dataset(ds) @@ -115,7 +117,38 @@ def __init__( self.concat_input_dims = concat_input_dims self.preload_batch = preload_batch - def __iter__(self): + self._batches: Dict[ + int, Any + ] = self._gen_batches() # dict cache for batches + # in the future, we can make this a lru cache or similar thing (cachey?) + + def __iter__(self) -> Iterator[xr.Dataset]: + for batch in self._batches.values(): + yield batch + + def __len__(self) -> int: + return len(self._batches) + + def __getitem__(self, idx: int) -> xr.Dataset: + + if not isinstance(idx, int): + raise NotImplementedError( + f'{type(self).__name__}.__getitem__ currently requires a single integer key' + ) + + if idx < 0: + idx = list(self._batches)[idx] + + if idx in self._batches: + return self._batches[idx] + else: + raise IndexError('list index out of range') + + def _gen_batches(self) -> dict: + # in the future, we will want to do the batch generation lazily + # going the eager route for now is allowing me to fill out the loader api + # but it is likely to perform poorly. 
+ batches = [] for ds_batch in self._iterate_batch_dims(self.ds): if self.preload_batch: ds_batch.load() @@ -132,13 +165,15 @@ def __iter__(self): new_input_dims = [ dim + new_dim_suffix for dim in self.input_dims ] - yield _maybe_stack_batch_dims(dsc, new_input_dims) + batches.append(_maybe_stack_batch_dims(dsc, new_input_dims)) else: for ds_input in input_generator: - yield _maybe_stack_batch_dims( - ds_input, list(self.input_dims) + batches.append( + _maybe_stack_batch_dims(ds_input, list(self.input_dims)) ) + return dict(zip(range(len(batches)), batches)) + def _iterate_batch_dims(self, ds): return _iterate_through_dataset(ds, self.batch_dims) diff --git a/xbatcher/loaders/__init__.py b/xbatcher/loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py new file mode 100644 index 0000000..91b2c92 --- /dev/null +++ b/xbatcher/loaders/torch.py @@ -0,0 +1,84 @@ +from typing import Any, Callable, Optional, Tuple + +import torch + +# Notes: +# This module includes two PyTorch datasets. +# - The MapDataset provides an indexable interface +# - The IterableDataset provides a simple iterable interface +# Both can be provided as arguments to the the Torch DataLoader +# Assumptions made: +# - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators ina dataset) +# TODOs: +# - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y +# - need to test with additional dataset parameters (e.g. transforms) + + +class MapDataset(torch.utils.data.Dataset): + def __init__( + self, + X_generator, + y_generator, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + ''' + PyTorch Dataset adapter for Xbatcher + + Parameters + ---------- + X_generator : xbatcher.BatchGenerator + y_generator : xbatcher.BatchGenerator + transform : callable, optional + A function/transform that takes in an array and returns a transformed version. + target_transform : callable, optional + A function/transform that takes in the target and transforms it. 
+ ''' + self.X_generator = X_generator + self.y_generator = y_generator + self.transform = transform + self.target_transform = target_transform + + def __len__(self) -> int: + return len(self.X_generator) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + if torch.is_tensor(idx): + idx = idx.tolist() + assert len(idx) == 1 + + # TODO: figure out the dataset -> array workflow + # currently hardcoding a variable name + X_batch = self.X_generator[idx]['x'].data + y_batch = self.y_generator[idx]['y'].data + + if self.transform: + X_batch = self.transform(X_batch) + + if self.target_transform: + y_batch = self.target_transform(y_batch) + print('x_batch.shape', X_batch.shape) + return X_batch, y_batch + + +class IterableDataset(torch.utils.data.IterableDataset): + def __init__( + self, + X_generator, + y_generator, + ) -> None: + ''' + PyTorch Dataset adapter for Xbatcher + + Parameters + ---------- + X_generator : xbatcher.BatchGenerator + y_generator : xbatcher.BatchGenerator + ''' + + self.X_generator = X_generator + self.y_generator = y_generator + + def __iter__(self): + for xb, yb in zip(self.X_generator, self.y_generator): + yield (xb['x'].data, yb['y'].data) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 54984b3..bddb04e 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -18,6 +18,28 @@ def sample_ds_1d(): return ds +@pytest.mark.parametrize('bsize', [5, 6]) +def test_batcher_lenth(sample_ds_1d, bsize): + bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize}) + assert len(bg) == sample_ds_1d.dims['x'] // bsize + + +def test_batcher_getitem(sample_ds_1d): + bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10}) + + # first batch + assert bg[0].dims['x'] == 10 + # last batch + assert bg[-1].dims['x'] == 10 + # raises IndexError for out of range index + with pytest.raises(IndexError, match=r'list index out of range'): + bg[9999999] + + # raises NotImplementedError for iterable index + with pytest.raises(NotImplementedError): + bg[[1, 2, 3]] + + # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension # Should we enforce that each batch size always has to be the same @pytest.mark.parametrize('bsize', [5, 10]) diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py new file mode 100644 index 0000000..4e9c897 --- /dev/null +++ b/xbatcher/tests/test_torch_loaders.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest +import xarray as xr + +torch = pytest.importorskip('torch') + +from xbatcher import BatchGenerator +from xbatcher.loaders.torch import IterableDataset, MapDataset + + +@pytest.fixture(scope='module') +def ds_xy(): + n_samples = 100 + n_features = 5 + ds = xr.Dataset( + { + 'x': ( + ['sample', 'feature'], + np.random.random((n_samples, n_features)), + ), + 'y': (['sample'], np.random.random(n_samples)), + }, + ) + return ds + + +def test_map_dataset(ds_xy): + + x = ds_xy['x'] + y = ds_xy['y'] + + x_gen = BatchGenerator(x, {'sample': 10}) + y_gen = BatchGenerator(y, {'sample': 10}) + + dataset = MapDataset(x_gen, y_gen) + + # test __getitem__ + x_batch, y_batch = dataset[0] + assert len(x_batch) == len(y_batch) + assert isinstance(x_batch, np.ndarray) + + # test __len__ + assert len(dataset) == len(x_gen) + + # test integration with torch DataLoader + loader = torch.utils.data.DataLoader(dataset) + + for x_batch, y_batch in loader: + assert len(x_batch) == len(y_batch) + assert isinstance(x_batch, torch.Tensor) + + # TODO: why does 
pytorch add an extra dimension (length 1) to x_batch + assert x_gen[-1]['x'].shape == x_batch.shape[1:] + # TODO: also need to revisit the variable extraction bits here + assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :]) + + +def test_iterable_dataset(ds_xy): + + x = ds_xy['x'] + y = ds_xy['y'] + + x_gen = BatchGenerator(x, {'sample': 10}) + y_gen = BatchGenerator(y, {'sample': 10}) + + dataset = IterableDataset(x_gen, y_gen) + + # test integration with torch DataLoader + loader = torch.utils.data.DataLoader(dataset) + + for x_batch, y_batch in loader: + assert len(x_batch) == len(y_batch) + assert isinstance(x_batch, torch.Tensor) + + # TODO: why does pytorch add an extra dimension (length 1) to x_batch + assert x_gen[-1]['x'].shape == x_batch.shape[1:] + # TODO: also need to revisit the variable extraction bits here + assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :]) From 94519b3a46ac0e3fc8d986f174380345f626c6f6 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 11 Aug 2021 17:28:06 -0700 Subject: [PATCH 2/8] add torch to dev environment --- dev-requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dev-requirements.txt b/dev-requirements.txt index 642f9c5..e9fee44 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,4 @@ pytest +torch coverage -r requirements.txt From 04480ba2007ddf6866065d0c76868fffdf2e7edc Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Thu, 12 Aug 2021 14:49:57 -0700 Subject: [PATCH 3/8] fix mypy checks --- .pre-commit-config.yaml | 13 +++++++++++++ conftest.py | 1 + doc/conf.py | 2 ++ setup.py | 1 + xbatcher/generators.py | 5 ++--- xbatcher/loaders/torch.py | 1 - 6 files changed, 19 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a50297..d284551 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,3 +33,16 @@ repos: hooks: - id: prettier language_version: system + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.910 + hooks: + - id: mypy + additional_dependencies: [ + # Type stubs + types-setuptools, + types-pkg_resources, + # Dependencies that are typed + numpy, + xarray, + ] diff --git a/conftest.py b/conftest.py index 44ad179..acc08d3 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ +# type: ignore import pytest diff --git a/doc/conf.py b/doc/conf.py index 5d7674c..ef5de36 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,6 +12,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. 
+# type: ignore + import os import sys diff --git a/setup.py b/setup.py index 685f7b9..50c7aa7 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# type: ignore import os from setuptools import find_packages, setup diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 6208efd..da80995 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -2,8 +2,7 @@ import itertools from collections import OrderedDict -from collections.abc import Iterator -from typing import Any, Dict, Hashable +from typing import Any, Dict, Hashable, Iterator import xarray as xr @@ -163,7 +162,7 @@ def _gen_batches(self) -> dict: ] dsc = xr.concat(all_dsets, dim='input_batch') new_input_dims = [ - dim + new_dim_suffix for dim in self.input_dims + str(dim) + new_dim_suffix for dim in self.input_dims ] batches.append(_maybe_stack_batch_dims(dsc, new_input_dims)) else: diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py index 91b2c92..c7da394 100644 --- a/xbatcher/loaders/torch.py +++ b/xbatcher/loaders/torch.py @@ -57,7 +57,6 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: if self.target_transform: y_batch = self.target_transform(y_batch) - print('x_batch.shape', X_batch.shape) return X_batch, y_batch From 6104bf3d5c32de34e2560f03f93fa155c533c603 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Sat, 9 Oct 2021 08:44:58 -0700 Subject: [PATCH 4/8] add torch accessor --- xbatcher/accessors.py | 18 ++++++++++++++++++ xbatcher/loaders/torch.py | 6 +++--- xbatcher/tests/test_accessors.py | 22 ++++++++++++++++++++++ xbatcher/tests/test_torch_loaders.py | 2 +- 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index 4a92bf8..44c2e84 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -24,3 +24,21 @@ def generator(self, *args, **kwargs): Keyword arguments to pass to the `BatchGenerator` constructor. 
''' return BatchGenerator(self._obj, *args, **kwargs) + + +@xr.register_dataarray_accessor('torch') +class TorchAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def to_tensor(self): + """Convert this DataArray to a torch.Tensor""" + import torch + + return torch.tensor(self._obj.data) + + def to_named_tensor(self): + """Convert this DataArray to a torch.Tensor with named dimensions""" + import torch + + return torch.tensor(self._obj.data, names=self._obj.dims) diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py index c7da394..bd030ac 100644 --- a/xbatcher/loaders/torch.py +++ b/xbatcher/loaders/torch.py @@ -49,8 +49,8 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: # TODO: figure out the dataset -> array workflow # currently hardcoding a variable name - X_batch = self.X_generator[idx]['x'].data - y_batch = self.y_generator[idx]['y'].data + X_batch = self.X_generator[idx]['x'].torch.to_tensor() + y_batch = self.y_generator[idx]['y'].torch.to_tensor() if self.transform: X_batch = self.transform(X_batch) @@ -80,4 +80,4 @@ def __init__( def __iter__(self): for xb, yb in zip(self.X_generator, self.y_generator): - yield (xb['x'].data, yb['y'].data) + yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor()) diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py index d9be321..4860803 100644 --- a/xbatcher/tests/test_accessors.py +++ b/xbatcher/tests/test_accessors.py @@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d): assert isinstance(bg_acc, BatchGenerator) for batch_class, batch_acc in zip(bg_class, bg_acc): assert batch_class.equals(batch_acc) + + +def test_torch_to_tensor(sample_ds_3d): + torch = pytest.importorskip('torch') + + da = sample_ds_3d['foo'] + t = da.torch.to_tensor() + assert isinstance(t, torch.Tensor) + assert t.names == (None, None, None) + assert t.shape == da.shape + np.testing.assert_array_equal(t, da.values) + + +def test_torch_to_named_tensor(sample_ds_3d): + torch = pytest.importorskip('torch') + + da = sample_ds_3d['foo'] + t = da.torch.to_named_tensor() + assert isinstance(t, torch.Tensor) + assert t.names == da.dims + assert t.shape == da.shape + np.testing.assert_array_equal(t, da.values) diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py index 4e9c897..2741b8f 100644 --- a/xbatcher/tests/test_torch_loaders.py +++ b/xbatcher/tests/test_torch_loaders.py @@ -37,7 +37,7 @@ def test_map_dataset(ds_xy): # test __getitem__ x_batch, y_batch = dataset[0] assert len(x_batch) == len(y_batch) - assert isinstance(x_batch, np.ndarray) + assert isinstance(x_batch, torch.Tensor) # test __len__ assert len(dataset) == len(x_gen) From 86c8560c62893be6005edeadb5e70e11153fcd21 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Feb 2022 09:05:43 -0800 Subject: [PATCH 5/8] lint --- xbatcher/tests/test_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index c80a71e..23f9448 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -61,7 +61,7 @@ def test_batcher_getitem(sample_ds_1d): # raises NotImplementedError for iterable index with pytest.raises(NotImplementedError): bg[[1, 2, 3]] - + # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension # Should we enforce that each batch size always has to be the same From 2bbf2df78b229fab0739f53a4780cf8ef0b038a8 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: 
Wed, 23 Feb 2022 09:34:06 -0800 Subject: [PATCH 6/8] additional test coverage for torch loaders --- xbatcher/loaders/torch.py | 7 +++++- xbatcher/tests/test_torch_loaders.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py index bd030ac..f68f63a 100644 --- a/xbatcher/loaders/torch.py +++ b/xbatcher/loaders/torch.py @@ -45,7 +45,12 @@ def __len__(self) -> int: def __getitem__(self, idx) -> Tuple[Any, Any]: if torch.is_tensor(idx): idx = idx.tolist() - assert len(idx) == 1 + if len(idx) == 1: + idx = idx[0] + else: + raise NotImplementedError( + f'{type(self).__name__}.__getitem__ currently requires a single integer key' + ) # TODO: figure out the dataset -> array workflow # currently hardcoding a variable name diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py index 2741b8f..4e4a412 100644 --- a/xbatcher/tests/test_torch_loaders.py +++ b/xbatcher/tests/test_torch_loaders.py @@ -39,6 +39,15 @@ def test_map_dataset(ds_xy): assert len(x_batch) == len(y_batch) assert isinstance(x_batch, torch.Tensor) + idx = torch.tensor([0]) + x_batch, y_batch = dataset[idx] + assert len(x_batch) == len(y_batch) + assert isinstance(x_batch, torch.Tensor) + + with pytest.raises(NotImplementedError): + idx = torch.tensor([0, 1]) + x_batch, y_batch = dataset[idx] + # test __len__ assert len(dataset) == len(x_gen) @@ -55,6 +64,30 @@ def test_map_dataset(ds_xy): assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :]) +def test_map_dataset_with_transform(ds_xy): + + x = ds_xy['x'] + y = ds_xy['y'] + + x_gen = BatchGenerator(x, {'sample': 10}) + y_gen = BatchGenerator(y, {'sample': 10}) + + def x_transform(batch): + return batch * 0 + 1 + + def y_transform(batch): + return batch * 0 - 1 + + dataset = MapDataset( + x_gen, y_gen, transform=x_transform, target_transform=y_transform + ) + x_batch, y_batch = dataset[0] + assert len(x_batch) == len(y_batch) + assert isinstance(x_batch, torch.Tensor) + assert (x_batch == 1).all() + assert (y_batch == -1).all() + + def test_iterable_dataset(ds_xy): x = ds_xy['x'] From 69909b42c75296e57fc3aa790ab0c618fada0c0f Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Feb 2022 09:36:59 -0800 Subject: [PATCH 7/8] update pre-commit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a8fef5..592d965 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: double-quote-string-fixer - repo: https://github.com/psf/black - rev: 21.12b0 + rev: 22.1.0 hooks: - id: black args: ["--line-length", "80", "--skip-string-normalization"] @@ -39,7 +39,7 @@ repos: language_version: system - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v0.931 hooks: - id: mypy additional_dependencies: [ From 8bcd870316f1131a19936eb979b4508d77e74253 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 23 Feb 2022 09:48:21 -0800 Subject: [PATCH 8/8] update docs --- doc/api.rst | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index f400b2c..f9f424c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -5,12 +5,6 @@ API reference This page provides an auto-generated summary of Xbatcher's API. -Core -==== - -.. 
autoclass:: xbatcher.BatchGenerator - :members: - Dataset.batch and DataArray.batch ================================= @@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch Dataset.batch.generator DataArray.batch.generator + +Core +==== + +.. autoclass:: xbatcher.BatchGenerator + :members: + +Dataloaders +=========== +.. autoclass:: xbatcher.loaders.torch.MapDataset + :members: + +.. autoclass:: xbatcher.loaders.torch.IterableDataset + :members:
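
---

Example usage of the pieces added in this series (an illustrative sketch only, not part of any patch above; it assumes the patches are applied and, per the TODOs in xbatcher/loaders/torch.py, that the data variables are literally named 'x' and 'y', mirroring the test fixture)::

    import numpy as np
    import torch
    import xarray as xr

    from xbatcher import BatchGenerator
    from xbatcher.loaders.torch import IterableDataset, MapDataset

    # Toy data mirroring the test fixture: 100 samples x 5 features.
    ds = xr.Dataset(
        {
            'x': (['sample', 'feature'], np.random.random((100, 5))),
            'y': (['sample'], np.random.random(100)),
        }
    )

    # One generator each for inputs and targets, 10 samples per batch.
    x_gen = BatchGenerator(ds['x'], input_dims={'sample': 10})
    y_gen = BatchGenerator(ds['y'], input_dims={'sample': 10})

    # Map-style dataset: indexable, so it works with the default DataLoader sampler.
    loader = torch.utils.data.DataLoader(MapDataset(x_gen, y_gen))
    for x_batch, y_batch in loader:
        assert isinstance(x_batch, torch.Tensor)

    # Iterable-style dataset: zips the two generators and yields tensors lazily.
    for x_batch, y_batch in torch.utils.data.DataLoader(IterableDataset(x_gen, y_gen)):
        pass

    # The new torch accessor also works on any DataArray directly.
    t = ds['x'].torch.to_tensor()         # plain tensor, unnamed dimensions
    nt = ds['x'].torch.to_named_tensor()  # dims named ('sample', 'feature')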