From 28bc38f857cd89295b33fb26df8eef5ac4035c86 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 5 Apr 2016 23:47:05 -0400 Subject: [PATCH] multidimensional groupby --- xarray/core/groupby.py | 24 ++++++++++++++++++++++-- xarray/test/test_dataarray.py | 25 ++++++++++++++++++------- xarray/test/test_dataset.py | 2 -- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d6daeb318ee..f47de953abc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -102,8 +102,28 @@ def __init__(self, obj, group, squeeze=False, grouper=None): from .dataset import as_dataset if group.ndim != 1: - # TODO: remove this limitation? - raise ValueError('`group` must be 1 dimensional') + # try to stack the dims of the group into a single dim + # TODO: figure out how to exclude dimensions from the stacking + # (e.g. group over space dims but leave time dim intact) + orig_dims = group.dims + stacked_dim_name = 'stacked_' + '_'.join(orig_dims) + # the copy is necessary here + group = group.stack(**{stacked_dim_name: orig_dims}).copy() + # without it, an error is raised deep in pandas + ######################## + # xarray/core/groupby.py + # ---> 31 inverse, values = pd.factorize(ar, sort=True) + # pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint) + # --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True) + # pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)() + # pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)() + # pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)() + # ValueError: buffer source array is read-only + ####################### + # seems related to + # https://github.com/pydata/pandas/issues/10043 + # https://github.com/pydata/pandas/pull/10070 + obj = obj.stack(**{stacked_dim_name: orig_dims}) if getattr(group, 'name', None) is None: raise ValueError('`group` must have a name') if not hasattr(group, 'dims'): diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index 7c5081c92ac..d0f83457e68 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -1244,6 +1244,17 @@ def test_groupby_first_and_last(self): expected = array # should be a no-op self.assertDataArrayIdentical(expected, actual) + def test_groupby_multidim(self): + array = DataArray([[0,1],[2,3]], + coords={'lon': (['ny','nx'], [[30,40],[40,50]] ), + 'lat': (['ny','nx'], [[10,10],[20,20]] ),}, + dims=['ny','nx']) + for dim, expected_sum in [ + ('lon', DataArray([0, 3, 3], coords={'lon': [30,40,50]})), + ('lat', DataArray([1,5], coords={'lat': [10,20]}))]: + actual_sum = array.groupby(dim).sum() + self.assertDataArrayIdentical(expected_sum, actual_sum) + def make_rolling_example_array(self): times = pd.date_range('2000-01-01', freq='1D', periods=21) values = np.random.random((21, 4)) @@ -1792,29 +1803,29 @@ def test_full_like(self): actual = _full_like(DataArray([1, 2, 3]), fill_value=np.nan) self.assertEqual(actual.dtype, np.float) np.testing.assert_equal(actual.values, np.nan) - + def test_dot(self): x = np.linspace(-3, 3, 6) y = np.linspace(-3, 3, 5) - z = range(4) + z = range(4) da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) da = DataArray(da_vals, coords=[x, y, z], dims=['x', 'y', 'z']) - + dm_vals = range(4) dm = DataArray(dm_vals, coords=[z], dims=['z']) - + # nd dot 1d actual = da.dot(dm) expected_vals = np.tensordot(da_vals, dm_vals, [2, 0]) expected = DataArray(expected_vals, coords=[x, y], dims=['x', 'y']) self.assertDataArrayEqual(expected, actual) - + # all shared dims actual = da.dot(da) expected_vals = np.tensordot(da_vals, da_vals, axes=([0, 1, 2], [0, 1, 2])) expected = DataArray(expected_vals) self.assertDataArrayEqual(expected, actual) - + # multiple shared dims dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4)) j = np.linspace(-3, 3, 20) @@ -1823,7 +1834,7 @@ def test_dot(self): expected_vals = np.tensordot(da_vals, dm_vals, axes=([1, 2], [1, 2])) expected = DataArray(expected_vals, coords=[x, j], dims=['x', 'j']) self.assertDataArrayEqual(expected, actual) - + with self.assertRaises(NotImplementedError): da.dot(dm.to_dataset(name='dm')) with self.assertRaises(TypeError): diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index 5d27cce4b69..db65192ee0b 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -1545,8 +1545,6 @@ def test_groupby_iter(self): def test_groupby_errors(self): data = create_test_data() - with self.assertRaisesRegexp(ValueError, 'must be 1 dimensional'): - data.groupby('var1') with self.assertRaisesRegexp(ValueError, 'must have a name'): data.groupby(np.arange(10)) with self.assertRaisesRegexp(ValueError, 'length does not match'):