From a6fe1bd63a0ac798ae77282e2af0b833fb4c391e Mon Sep 17 00:00:00 2001 From: David Hassell Date: Tue, 1 Mar 2022 16:11:40 +0000 Subject: [PATCH 1/3] Data.flatten --- cf/data/data.py | 95 ++++++++++++++++++-------------------------- cf/test/test_Data.py | 1 - 2 files changed, 38 insertions(+), 58 deletions(-) diff --git a/cf/data/data.py b/cf/data/data.py index 068b309d90..bd27c8cf8d 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -9989,10 +9989,10 @@ def flat(self, ignore_masked=True): else: yield cf_masked + @daskified(_DASKIFIED_VERBOSE) + @_inplace_enabled(default=False) def flatten(self, axes=None, inplace=False): - """Flatten axes of the data. - - TODODASK - check against daask flatten behaviour + """Flatten specified axes of the data. Any subset of the axes may be flattened. @@ -10004,21 +10004,16 @@ def flatten(self, axes=None, inplace=False): .. versionadded:: 3.0.2 - .. seealso:: `compressed`, `insert_dimension`, `flip`, `swapaxes`, - `transpose` + .. seealso:: `compressed`, `flat`, `insert_dimension`, `flip`, + `swapaxes`, `transpose` :Parameters: - axes: (sequence of) int or str, optional - Select the axes. By default all axes are flattened. The - *axes* argument may be one, or a sequence, of: - - * An internal axis identifier. Selects this axis. - - * An integer. Selects the axis corresponding to the given - position in the list of axes of the data array. - - No axes are flattened if *axes* is an empty sequence. + axes: (sequence of) `int` + Select the axes to be flattened. By default all axes + are flattened. Each axis is identified by its integer + position. No axes are flattened if *axes* is an empty + sequence. {{inplace: `bool`, optional}} @@ -10030,7 +10025,8 @@ def flatten(self, axes=None, inplace=False): **Examples** - >>> d = cf.Data(numpy.arange(24).reshape(1, 2, 3, 4)) + >>> import numpy as np + >>> d = cf.Data(np.arange(24).reshape(1, 2, 3, 4)) >>> d >>> print(d.array) @@ -10078,27 +10074,18 @@ def flatten(self, axes=None, inplace=False): [15 19 23]]] """ - if inplace: - d = self - else: - d = self.copy() + d = _inplace_enabled_define_and_cleanup(self) - ndim = self._ndim + ndim = d.ndim if not ndim: if axes or axes == 0: raise ValueError( - "Can't flatten: Can't remove an axis from " - "scalar {}".format(self.__class__.__name__) + "Can't flatten: Can't remove an axes from " + f"scalar {self.__class__.__name__}" ) - if inplace: - d = None return d - shape = list(d._shape) - - # Note that it is important that the first axis in the list is - # the left-most flattened axis if axes is None: axes = list(range(ndim)) else: @@ -10106,39 +10093,33 @@ def flatten(self, axes=None, inplace=False): n_axes = len(axes) if n_axes <= 1: - if inplace: - d = None return d - new_shape = [n for i, n in enumerate(shape) if i not in axes] - new_shape.insert(axes[0], np.prod([shape[i] for i in axes])) - - out = d.empty(new_shape, dtype=d.dtype, units=d.Units, chunk=True) - out.hardmask = False - - n_non_flattened_axes = ndim - n_axes - - for key, data in d.section(axes).items(): - flattened_array = data.array.flatten() - size = flattened_array.size - - first_None_index = key.index(None) - - indices = [i for i in key if i is not None] - indices.insert(first_None_index, slice(0, size)) - - shape = [1] * n_non_flattened_axes - shape.insert(first_None_index, size) - - out[tuple(indices)] = flattened_array.reshape(shape) + dx = d._get_dask() - out.hardmask = True + # It is important that the first axis in the list is the + # left-most flattened axis. + # + # E.g. if the shape is (10, 20, 30, 40, 50, 60) and the axes + # to be flattened are [2, 4], then the data must be + # transposed with order [0, 1, 2, 4, 3, 5] + order = [i for i in range(ndim) if i not in axes] + order[axes[0] : axes[0]] = axes + dx = dx.transpose(order) + + # Find the flattened shape. + # + # E.g. if the *transposed* shape is (10, 20, 30, 50, 40, 60) + # and *transposed* axes [2, 3] are to be falttened then + # the new shape will be (10, 20, 1500, 40, 60) + shape = d.shape + new_shape = [n for i, n in enumerate(shape) if i not in axes] + new_shape.insert(axes[0], reduce(mul, [shape[i] for i in axes], 1)) - if inplace: - d.__dict__ = out.__dict__ - out = None + dx = dx.reshape(new_shape) + d._set_dask(dx, reset_mask_hardness=False) - return out + return d @daskified(_DASKIFIED_VERBOSE) @_deprecated_kwarg_check("i") diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py index 53eed8b5e6..45543d2663 100644 --- a/cf/test/test_Data.py +++ b/cf/test/test_Data.py @@ -896,7 +896,6 @@ def test_Data_cumsum(self): e = d.cumsum(axis=i, masked_as_zero=False) self.assertTrue(cf.functions._numpy_allclose(e.array, b)) - @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'") def test_Data_flatten(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: return From 88123680d0181b35595338f57aff25fb09deb605 Mon Sep 17 00:00:00 2001 From: David Hassell Date: Wed, 9 Mar 2022 13:42:21 +0000 Subject: [PATCH 2/3] Typo Co-authored-by: Sadie L. Bartholomew --- cf/data/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf/data/data.py b/cf/data/data.py index bd27c8cf8d..e51fb55f19 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -10110,7 +10110,7 @@ def flatten(self, axes=None, inplace=False): # Find the flattened shape. # # E.g. if the *transposed* shape is (10, 20, 30, 50, 40, 60) - # and *transposed* axes [2, 3] are to be falttened then + # and *transposed* axes [2, 3] are to be flattened then # the new shape will be (10, 20, 1500, 40, 60) shape = d.shape new_shape = [n for i, n in enumerate(shape) if i not in axes] From 1524048fbd02684bdb7719120024906de68f831c Mon Sep 17 00:00:00 2001 From: David Hassell Date: Wed, 9 Mar 2022 13:47:44 +0000 Subject: [PATCH 3/3] Typo --- cf/data/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cf/data/data.py b/cf/data/data.py index e51fb55f19..9be2568636 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -10080,7 +10080,7 @@ def flatten(self, axes=None, inplace=False): if not ndim: if axes or axes == 0: raise ValueError( - "Can't flatten: Can't remove an axes from " + "Can't flatten: Can't remove axes from " f"scalar {self.__class__.__name__}" )