diff --git a/tests/test_openDataset.py b/tests/test_openDataset.py index e5a99a8..61540eb 100644 --- a/tests/test_openDataset.py +++ b/tests/test_openDataset.py @@ -11,33 +11,33 @@ from xgrads import open_CtlDataset, open_mfdataset -# def test_template1(): -# dset1 = open_CtlDataset('./ctls/test8.ctl') -# dset2 = open_CtlDataset('./ctls/test9.ctl') -# dset3 = xr.tutorial.open_dataset('air_temperature').load().astype('>f4') +def test_template1(): + dset1 = open_CtlDataset('./ctls/test8.ctl') + dset2 = open_CtlDataset('./ctls/test9.ctl') + dset3 = xr.tutorial.open_dataset('air_temperature').load().astype('>f4') -# for l in range(len(dset1.time)): -# xr.testing.assert_equal(dset1.air[l], dset2.air[l]) -# xr.testing.assert_equal(dset1.air[l], dset3.air[l]) + for l in range(len(dset1.time)): + xr.testing.assert_equal(dset1.air[l], dset2.air[l]) + xr.testing.assert_equal(dset1.air[l], dset3.air[l]) -def test_template2(): - use_close = True if sys.version_info[0] == 3 and sys.version_info[1]>8 else False +# def test_template2(): +# use_close = True if sys.version_info[0] == 3 and sys.version_info[1]>8 else False - dset11 = open_mfdataset('./ctls/test8_*.ctl', parallel=False).load() - dset22 = open_CtlDataset('./ctls/test8.ctl').load() - dset33 = xr.tutorial.open_dataset('air_temperature').load().astype('>f4') +# dset11 = open_mfdataset('./ctls/test8_*.ctl', parallel=False).load() +# dset22 = open_CtlDataset('./ctls/test8.ctl').load() +# dset33 = xr.tutorial.open_dataset('air_temperature').load().astype('>f4') - if use_close: - print('3') - for l in range(len(dset11.time)): - xr.testing.assert_allclose(dset11.air[l], dset22.air[l]) - xr.testing.assert_allclose(dset11.air[l], dset33.air[l]) - else: - print('4') - for l in range(len(dset11.time)): - xr.testing.assert_equal(dset11.air[l], dset22.air[l]) - xr.testing.assert_equal(dset11.air[l], dset33.air[l]) +# if use_close: +# print('3') +# for l in range(len(dset11.time)): +# xr.testing.assert_allclose(dset11.air[l], dset22.air[l]) +# xr.testing.assert_allclose(dset11.air[l], dset33.air[l]) +# else: +# print('4') +# for l in range(len(dset11.time)): +# xr.testing.assert_equal(dset11.air[l], dset22.air[l]) +# xr.testing.assert_equal(dset11.air[l], dset33.air[l]) # def test_template3(): @@ -47,20 +47,20 @@ def test_template2(): # xr.testing.assert_equal(dset1.air[l], dset2.air[l]) -# def test_template4(): -# # test blank line in ctls -# dset1 = open_CtlDataset('./ctls/test81.ctl') -# dset2 = open_CtlDataset('./ctls/test82.ctl') +def test_template4(): + # test blank line in ctls + dset1 = open_CtlDataset('./ctls/test81.ctl') + dset2 = open_CtlDataset('./ctls/test82.ctl') -# assert (dset1.x == dset2.x).all() -# assert (dset1.y == dset2.y).all() -# assert (dset1.air[0] == dset2.air).all() + assert (dset1.x == dset2.x).all() + assert (dset1.y == dset2.y).all() + assert (dset1.air[0] == dset2.air).all() -# def test_ensemble(): -# dset1 = open_CtlDataset('./ctls/ecmf_medium_T2m1.ctl') +def test_ensemble(): + dset1 = open_CtlDataset('./ctls/ecmf_medium_T2m1.ctl') -# expected = np.array([2.011963 , 1.1813354, 1.1660767]) + expected = np.array([2.011963 , 1.1813354, 1.1660767]) -# # check several ensemble values -# assert np.isclose(dset1.t2[:,-1,-1,-1], expected).all() + # check several ensemble values + assert np.isclose(dset1.t2[:,-1,-1,-1], expected).all() diff --git a/xgrads/io.py b/xgrads/io.py index 126357e..40e9345 100644 --- a/xgrads/io.py +++ b/xgrads/io.py @@ -66,10 +66,7 @@ def open_mfdataset(paths, parallel=False, encoding='GBK'): open_ = open_CtlDataset datasets = [open_(p, encoding=encoding) for p in paths] - print(paths[0], datasets[0].air[0,0,0].values) - print(paths[1], datasets[1].air[0,0,0].values) - print(paths[2], datasets[2].air[0,0,0].values) - print(paths[3], datasets[3].air[0,0,0].values) + if parallel: # calling compute here will return the datasets/file_objs lists, # the underlying datasets will still be stored as dask arrays @@ -77,11 +74,8 @@ def open_mfdataset(paths, parallel=False, encoding='GBK'): return xr.concat(datasets[0], dim='time') - combined = xr.merge(datasets) - print('merged: ', combined.air[0:4,0,0].values) - print('concat: ', xr.concat(datasets, dim='time').air[0:4,0,0].values) - print(combined.load()) - print(xr.concat(datasets, dim='time').load()) + combined = xr.concat(datasets, dim='time') + return combined