Prototype keras data loader (#73)
maxrjones authored Aug 15, 2022 · 1 parent 213264c · commit 0ded974
Showing 6 changed files with 148 additions and 2 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,4 +1,5 @@
pytest
+tensorflow
torch
coverage
pytest-cov
3 changes: 3 additions & 0 deletions doc/api.rst
@@ -30,3 +30,6 @@ Dataloaders

.. autoclass:: xbatcher.loaders.torch.IterableDataset
    :members:

.. autoclass:: xbatcher.loaders.keras.CustomTFDataset
    :members:
2 changes: 1 addition & 1 deletion doc/conf.py
@@ -70,7 +70,7 @@ def setup(app):
    app.connect('autodoc-skip-member', skip)


-autodoc_mock_imports = ['torch']
+autodoc_mock_imports = ['torch', 'tensorflow']

# link to github issues
extlinks = {
2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

[isort]
known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,tensorflow,torch,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
73 changes: 73 additions & 0 deletions xbatcher/loaders/keras.py
@@ -0,0 +1,73 @@
from typing import Any, Callable, Optional, Tuple

import tensorflow as tf
import xarray as xr

# Notes:
# This module includes one Keras dataset, which can be provided to model.fit()
# (see the usage sketch after this file).
# - The CustomTFDataset provides an indexable (map-style) interface.
# Assumptions made:
# - The dataset takes pre-configured X/y xbatcher generators (a single dataset
#   may not always want two separate generators).


class CustomTFDataset(tf.keras.utils.Sequence):
    def __init__(
        self,
        X_generator,
        y_generator,
        *,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        dim: str = 'new_dim',
    ) -> None:
        '''
        Keras Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
            Batch generator for the input (feature) data.
        y_generator : xbatcher.BatchGenerator
            Batch generator for the target data.
        transform : callable, optional
            A function/transform that takes in an array and returns a
            transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        dim : str, optional
            Name of the dimension to pass to :func:`xarray.concat` as the
            dimension to concatenate all variables along. Defaults to
            ``'new_dim'``.
        '''
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform
        self.concat_dim = dim

    def __len__(self) -> int:
        return len(self.X_generator)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        X_batch = tf.convert_to_tensor(
            xr.concat(
                (
                    self.X_generator[idx][key]
                    for key in list(self.X_generator[idx].keys())
                ),
                self.concat_dim,
            ).data
        )
        y_batch = tf.convert_to_tensor(
            xr.concat(
                (
                    self.y_generator[idx][key]
                    for key in list(self.y_generator[idx].keys())
                ),
                self.concat_dim,
            ).data
        )

        # TODO: Should the transformations be applied before tensor conversion?
        if self.transform:
            X_batch = self.transform(X_batch)

        if self.target_transform:
            y_batch = self.target_transform(y_batch)
        return X_batch, y_batch
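
To make the intended workflow concrete, here is a minimal usage sketch (an editorial illustration, not part of this commit). The toy Dataset, the variable names 'a', 'b', and 'y', the batch size, and the quoted shapes are assumptions for illustration; BatchGenerator and CustomTFDataset are the classes touched by this commit.

# Hypothetical usage sketch -- illustrative only, not part of this commit.
import numpy as np
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset

# Toy data: two feature variables and one target, all sharing a 'sample' dim.
ds = xr.Dataset(
    {
        'a': (['sample', 'feature'], np.random.random((100, 5))),
        'b': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

# Pre-configured X/y batch generators, as the dataset expects.
x_gen = BatchGenerator(ds[['a', 'b']], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})

dataset = CustomTFDataset(x_gen, y_gen)

# Each item is an (X, y) pair of tensors; the data variables of the X batch
# are concatenated along 'new_dim', so X here would have shape (2, 10, 5).
x_batch, y_batch = dataset[0]
print(x_batch.shape, y_batch.shape)

# Because CustomTFDataset is a keras.utils.Sequence, it can be passed
# directly to a compiled model, e.g. model.fit(dataset, epochs=1).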
69 changes: 69 additions & 0 deletions xbatcher/tests/test_keras_loaders.py
@@ -0,0 +1,69 @@
import numpy as np
import pytest
import xarray as xr

tf = pytest.importorskip('tensorflow')

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset


@pytest.fixture(scope='module')
def ds_xy():
    n_samples = 100
    n_features = 5
    ds = xr.Dataset(
        {
            'x': (
                ['sample', 'feature'],
                np.random.random((n_samples, n_features)),
            ),
            'y': (['sample'], np.random.random(n_samples)),
        },
    )
    return ds


def test_custom_dataset(ds_xy):

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    dataset = CustomTFDataset(x_gen, y_gen)

    # test __getitem__
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)

    # test __len__
    assert len(dataset) == len(x_gen)


def test_custom_dataset_with_transform(ds_xy):

    x = ds_xy['x']
    y = ds_xy['y']

    x_gen = BatchGenerator(x, {'sample': 10})
    y_gen = BatchGenerator(y, {'sample': 10})

    def x_transform(batch):
        return batch * 0 + 1

    def y_transform(batch):
        return batch * 0 - 1

    dataset = CustomTFDataset(
        x_gen, y_gen, transform=x_transform, target_transform=y_transform
    )
    x_batch, y_batch = dataset[0]
    assert len(x_batch) == len(y_batch)
    assert tf.is_tensor(x_batch)
    assert tf.is_tensor(y_batch)
    assert tf.experimental.numpy.all(x_batch == 1)
    assert tf.experimental.numpy.all(y_batch == -1)
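
Beyond the constant transforms used in these tests, the transform hook could carry real preprocessing. Below is a minimal sketch, assuming (as the current implementation implies, per the TODO above) that the callable receives the already converted tensor; the standardize helper is hypothetical and not part of xbatcher.

# Hypothetical transform example -- illustrative only, not part of this commit.
import tensorflow as tf

def standardize(batch):
    # Scale a batch of inputs to zero mean and unit variance.
    mean = tf.math.reduce_mean(batch)
    std = tf.math.reduce_std(batch)
    return (batch - mean) / std

# dataset = CustomTFDataset(x_gen, y_gen, transform=standardize)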
