Skip to content

Commit

Permalink
Merge pull request #63 from meghanrjones/slice-exceptions
Browse files Browse the repository at this point in the history
Use exceptions rather than assert statements for generator
  • Loading branch information
Joe Hamman authored May 18, 2022
2 parents f8c45db + 34c4c1a commit 05b978c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
14 changes: 12 additions & 2 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ def _as_xarray_dataset(ds):

def _slices(dimsize, size, overlap=0):
# return a list of slices to chop up a single dimension
if overlap >= size:
raise ValueError(
'input overlap must be less than the input sample length, but '
f'the input sample length is {size} and the overlap is {overlap}'
)
slices = []
stride = size - overlap
assert stride > 0
assert stride <= dimsize
for start in range(0, dimsize, stride):
end = start + size
if end <= dimsize:
Expand All @@ -34,6 +37,13 @@ def _iterate_through_dataset(ds, dims, overlap={}):
dimsize = ds.dims[dim]
size = dims[dim]
olap = overlap.get(dim, 0)
if size > dimsize:
raise ValueError(
'input sample length must be less than or equal to the '
f'dimension length, but the sample length of {size} '
f'is greater than the dimension length of {dimsize} '
f'for {dim}'
)
dim_slices.append(_slices(dimsize, size, olap))

for slices in itertools.product(*dim_slices):
Expand Down
13 changes: 13 additions & 0 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,16 @@ def test_preload_batch_true(sample_ds_1d):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
assert not ds_batch.chunks


def test_batch_exceptions(sample_ds_1d):
# ValueError when input_dim[dim] > ds.sizes[dim]
with pytest.raises(ValueError) as e:
BatchGenerator(sample_ds_1d, input_dims={'x': 110})
assert len(e) == 1
# ValueError when input_overlap[dim] > input_dim[dim]
with pytest.raises(ValueError) as e:
BatchGenerator(
sample_ds_1d, input_dims={'x': 10}, input_overlap={'x': 20}
)
assert len(e) == 1

0 comments on commit 05b978c

Please sign in to comment.