Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Apr 4, 2019
1 parent 1945462 commit 374dec0
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@ def test_batch_1d(sample_ds_1d, bsize):
ds_batch_expected = sample_ds_1d.isel(x=expected_slice)
assert ds_batch.equals(ds_batch_expected)

@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_1d_concat(sample_ds_1d, bsize):
bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize},
concat_input_dims=True)
for n, ds_batch in enumerate(bg):
assert isinstance(ds_batch, xr.Dataset)
assert ds_batch.dims['x_input'] == bsize
assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize
assert 'x' in ds_batch.coords

@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_1d_no_coordinate(sample_ds_1d, bsize):
# fix for #3
ds_dropped = sample_ds_1d.drop('x')
bg = BatchGenerator(ds_dropped, input_dims={'x': bsize})
for n, ds_batch in enumerate(bg):
assert isinstance(ds_batch, xr.Dataset)
assert ds_batch.dims['x'] == bsize
expected_slice = slice(bsize*n, bsize*(n+1))
ds_batch_expected = ds_dropped.isel(x=expected_slice)
assert ds_batch.equals(ds_batch_expected)

@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
# fix for #3
ds_dropped = sample_ds_1d.drop('x')
bg = BatchGenerator(ds_dropped, input_dims={'x': bsize},
concat_input_dims=True)
for n, ds_batch in enumerate(bg):
assert isinstance(ds_batch, xr.Dataset)
assert ds_batch.dims['x_input'] == bsize
assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize
assert 'x' not in ds_batch.coords


@pytest.mark.parametrize("olap", [1, 4])
def test_batch_1d_overlap(sample_ds_1d, olap):
Expand Down

0 comments on commit 374dec0

Please sign in to comment.