Skip to content

Commit

Permalink
exploring xarray accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Hamman committed Feb 7, 2020
1 parent e885373 commit 95d9fe6
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
3 changes: 2 additions & 1 deletion xbatcher/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . generators import BatchGenerator
from .generators import BatchGenerator
from .accessors import BatchAccessor
13 changes: 13 additions & 0 deletions xbatcher/accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import xarray as xr

from .generators import BatchGenerator


@xr.register_dataarray_accessor("batch")
@xr.register_dataset_accessor("batch")
class BatchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def generator(self, *args, **kwargs):
return BatchGenerator(self._obj, *args, **kwargs)
34 changes: 34 additions & 0 deletions xbatcher/tests/test_accessors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import xarray as xr
import numpy as np
import pytest

from xbatcher import BatchGenerator
import xbatcher


@pytest.fixture(scope='module')
def sample_ds_3d():
shape = (10, 50, 100)
ds = xr.Dataset({'foo': (['time', 'y', 'x'], np.random.rand(*shape)),
'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape))},
{'x': (['x'], np.arange(shape[-1])),
'y': (['y'], np.arange(shape[-2]))})
return ds


def test_batch_accessor_ds(sample_ds_3d):
bg_class = BatchGenerator(sample_ds_3d, input_dims={'x': 5})
bg_acc = sample_ds_3d.batch.generator(input_dims={'x': 5})
assert isinstance(bg_acc, BatchGenerator)
for batch_class, batch_acc in zip(bg_class, bg_acc):
assert isinstance(batch_acc, xr.Dataset)
assert batch_class.equals(batch_acc)


def test_batch_accessor_da(sample_ds_3d):
sample_da = sample_ds_3d['foo']
bg_class = BatchGenerator(sample_da, input_dims={'x': 5})
bg_acc = sample_da.batch.generator(input_dims={'x': 5})
assert isinstance(bg_acc, BatchGenerator)
for batch_class, batch_acc in zip(bg_class, bg_acc):
assert batch_class.equals(batch_acc)

0 comments on commit 95d9fe6

Please sign in to comment.