From e8853739d512d7fea45eb4d69bd8fa8163d9519e Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Thu, 18 Apr 2019 16:35:02 -0400 Subject: [PATCH] think I fixed it --- xbatcher/generators.py | 9 ++++++--- xbatcher/tests/test_generators.py | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 8963973..ea3d664 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -41,10 +41,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'): # remove input_dims coordinates from datasets, rename the dimensions # then put intput_dims back in as coordinates out = ds.copy() - out = (out.drop(input_dims) - .rename({dim: dim + suffix for dim in input_dims})) for dim in input_dims: - out.coords[dim] = dim + suffix, ds[dim].values + newdim = dim + suffix + out = out.rename({dim: newdim}) + # extra steps needed if there is a coordinate + if newdim in out: + out = out.drop(newdim) + out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs return out diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index d21891c..24ebc76 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -50,7 +50,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, bsize): @pytest.mark.parametrize("bsize", [5, 10]) def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize): - # fix for #3 + # test for #3 ds_dropped = sample_ds_1d.drop('x') bg = BatchGenerator(ds_dropped, input_dims={'x': bsize}, concat_input_dims=True) @@ -60,7 +60,6 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, bsize): assert ds_batch.dims['input_batch'] == sample_ds_1d.dims['x']//bsize assert 'x' not in ds_batch.coords - @pytest.mark.parametrize("olap", [1, 4]) def test_batch_1d_overlap(sample_ds_1d, olap): bsize = 10