From 34c4c1a8ffc014e2726b16018aff0776778c2fda Mon Sep 17 00:00:00 2001 From: Max Jones Date: Wed, 11 May 2022 18:05:56 -0400 Subject: [PATCH] Use exceptions rather than assert statements for generator --- xbatcher/generators.py | 14 ++++++++++++-- xbatcher/tests/test_generators.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index da80995..ecf827b 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -17,10 +17,13 @@ def _as_xarray_dataset(ds): def _slices(dimsize, size, overlap=0): # return a list of slices to chop up a single dimension + if overlap >= size: + raise ValueError( + 'input overlap must be less than the input sample length, but ' + f'the input sample length is {size} and the overlap is {overlap}' + ) slices = [] stride = size - overlap - assert stride > 0 - assert stride <= dimsize for start in range(0, dimsize, stride): end = start + size if end <= dimsize: @@ -34,6 +37,13 @@ def _iterate_through_dataset(ds, dims, overlap={}): dimsize = ds.dims[dim] size = dims[dim] olap = overlap.get(dim, 0) + if size > dimsize: + raise ValueError( + 'input sample length must be less than or equal to the ' + f'dimension length, but the sample length of {size} ' + f'is greater than the dimension length of {dimsize} ' + f'for {dim}' + ) dim_slices.append(_slices(dimsize, size, olap)) for slices in itertools.product(*dim_slices): diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 23f9448..f36eeea 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -211,3 +211,16 @@ def test_preload_batch_true(sample_ds_1d): for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) assert not ds_batch.chunks + + +def test_batch_exceptions(sample_ds_1d): + # ValueError when input_dim[dim] > ds.sizes[dim] + with pytest.raises(ValueError) as e: + BatchGenerator(sample_ds_1d, input_dims={'x': 110}) + assert len(e) == 1 + # ValueError when input_overlap[dim] > input_dim[dim] + with pytest.raises(ValueError) as e: + BatchGenerator( + sample_ds_1d, input_dims={'x': 10}, input_overlap={'x': 20} + ) + assert len(e) == 1