Skip to content

Commit

Permalink
think I fixed it
Browse files Browse the repository at this point in the history
  • Loading branch information
rabernat committed Apr 18, 2019
1 parent 374dec0 commit e885373
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 6 additions & 3 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
# remove input_dims coordinates from datasets, rename the dimensions
# then put intput_dims back in as coordinates
out = ds.copy()
out = (out.drop(input_dims)
.rename({dim: dim + suffix for dim in input_dims}))
for dim in input_dims:
out.coords[dim] = dim + suffix, ds[dim].values
newdim = dim + suffix
out = out.rename({dim: newdim})
# extra steps needed if there is a coordinate
if newdim in out:
out = out.drop(newdim)
out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
return out


Expand Down
3 changes: 1 addition & 2 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, bsize):

@pytest.mark.parametrize("bsize", [5, 10])
def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
# fix for #3
# test for #3
ds_dropped = sample_ds_1d.drop('x')
bg = BatchGenerator(ds_dropped, input_dims={'x': bsize},
concat_input_dims=True)
Expand All @@ -60,7 +60,6 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize):
assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize
assert 'x' not in ds_batch.coords


@pytest.mark.parametrize("olap", [1, 4])
def test_batch_1d_overlap(sample_ds_1d, olap):
bsize = 10
Expand Down

0 comments on commit e885373

Please sign in to comment.