From 374dec02ccee476b7677bebf69a91bdc8d95bc93 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Thu, 4 Apr 2019 16:33:24 -0400 Subject: [PATCH] add test --- xbatcher/tests/test_generators.py | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 85104a4..d21891c 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -26,6 +26,40 @@ def test_batch_1d(sample_ds_1d, bsize): ds_batch_expected = sample_ds_1d.isel(x=expected_slice) assert ds_batch.equals(ds_batch_expected) +@pytest.mark.parametrize("bsize", [5, 10]) +def test_batch_1d_concat(sample_ds_1d, bsize): + bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize}, + concat_input_dims=True) + for n, ds_batch in enumerate(bg): + assert isinstance(ds_batch, xr.Dataset) + assert ds_batch.dims['x_input'] == bsize + assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize + assert 'x' in ds_batch.coords + +@pytest.mark.parametrize("bsize", [5, 10]) +def test_batch_1d_no_coordinate(sample_ds_1d, bsize): + # fix for #3 + ds_dropped = sample_ds_1d.drop('x') + bg = BatchGenerator(ds_dropped, input_dims={'x': bsize}) + for n, ds_batch in enumerate(bg): + assert isinstance(ds_batch, xr.Dataset) + assert ds_batch.dims['x'] == bsize + expected_slice = slice(bsize*n, bsize*(n+1)) + ds_batch_expected = ds_dropped.isel(x=expected_slice) + assert ds_batch.equals(ds_batch_expected) + +@pytest.mark.parametrize("bsize", [5, 10]) +def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize): + # fix for #3 + ds_dropped = sample_ds_1d.drop('x') + bg = BatchGenerator(ds_dropped, input_dims={'x': bsize}, + concat_input_dims=True) + for n, ds_batch in enumerate(bg): + assert isinstance(ds_batch, xr.Dataset) + assert ds_batch.dims['x_input'] == bsize + assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize + assert 'x' not in ds_batch.coords + @pytest.mark.parametrize("olap", [1, 4]) def test_batch_1d_overlap(sample_ds_1d, olap):