Skip to content

Commit

Permalink
Merge pull request #320 from davidhassell/dask-contains
Browse files Browse the repository at this point in the history
dask: `Data.__contains__`
  • Loading branch information
davidhassell authored Mar 2, 2022
2 parents b195c02 + 77a52f3 commit f130048
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 29 deletions.
26 changes: 26 additions & 0 deletions cf/data/dask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,32 @@ def allclose(a_blocks, b_blocks, rtol=rtol, atol=atol):
)


def cf_contains(a, value):
    """Whether or not an array contains a value.

    .. versionadded:: TODODASK

    .. seealso:: `cf.Data.__contains__`

    :Parameters:

        a: `numpy.ndarray`
            The array.

        value: array_like
            The value.

    :Returns:

        `numpy.ndarray`
            A size 1 Boolean array, with the same number of dimensions
            as *a*, that indicates whether or not *a* contains the
            value.

    """
    # A shape of all-ones keeps the dimensionality of *a* while
    # holding exactly one Boolean element.
    result_shape = (1,) * a.ndim
    found = value in a
    return np.reshape(np.asarray(found), result_shape)


try:
from scipy.ndimage import convolve1d
except ImportError:
Expand Down
83 changes: 66 additions & 17 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
)
from .dask_utils import (
_da_ma_allclose,
cf_contains,
cf_dt2rt,
cf_harden_mask,
cf_percentile,
Expand Down Expand Up @@ -650,44 +651,90 @@ def __contains__(self, value):
x.__contains__(y) <==> y in x
Returns True if the value is contained anywhere in the data
array. The value may be a `cf.Data` object.
Returns True if the scalar *value* is contained anywhere in
the data. If *value* is not scalar then an exception is
raised.
**Performance**
All delayed operations are executed, and there is no
short-circuit once the first occurrence is found.
`__contains__` causes all delayed operations to be computed
unless *value* is a `Data` object with incompatible units, in
which case `False` is always returned.
**Examples:**
**Examples**
>>> d = cf.Data([[0.0, 1, 2], [3, 4, 5]], 'm')
>>> d = cf.Data([[0, 1, 2], [3, 4, 5]], 'm')
>>> 4 in d
True
>>> cf.Data(3) in d
>>> 4.0 in d
True
>>> cf.Data([2.5], units='2 m') in d
>>> cf.Data(5) in d
True
>>> [[2]] in d
>>> cf.Data(5, 'm') in d
True
>>> numpy.array([[[2]]]) in d
>>> cf.Data(0.005, 'km') in d
True
>>> Data(2, 'seconds') in d
>>> 99 in d
False
>>> cf.Data(2, 'seconds') in d
False
"""
>>> [1] in d
Traceback (most recent call last):
...
TypeError: elementwise comparison failed; must test against a scalar, not [1]
>>> [1, 2] in d
Traceback (most recent call last):
...
TypeError: elementwise comparison failed; must test against a scalar, not [1, 2]
def contains_chunk(a, value):
out = value in a
return np.array(out).reshape((1,) * a.ndim)
>>> d = cf.Data(["foo", "bar"])
>>> 'foo' in d
True
>>> 'xyz' in d
False
"""
# Check that value is scalar by seeing if its shape is ()
shape = getattr(value, "shape", None)
if shape is None:
if isinstance(value, str):
# Strings are scalars, even though they have a len().
shape = ()
else:
try:
len(value)
except TypeError:
# value has no len() so assume that it is a scalar
shape = ()
else:
# value has a len() so assume that it is not a scalar
shape = True
elif is_dask_collection(value) and math.isnan(value.size):
# value is a dask array with unknown size, so calculate
# the size. This is acceptable, as we're going to compute
# it anyway at the end of this method.
value.compute_chunk_sizes()
shape = value.shape

if shape:
raise TypeError(
"elementwise comparison failed; must test against a scalar, "
f"not {value!r}"
)

if isinstance(value, self.__class__):  # TODODASK: check other types too
# If value is a scalar Data object then conform its units
if isinstance(value, self.__class__):
self_units = self.Units
value_units = value.Units
if value_units.equivalent(self_units):
if not value_units.equals(self_units):
value = value.copy()
value.Units = self_units
elif value_units:
# No need to check the dask array if the value units
# are incompatible
return False

value = value._get_dask()
Expand All @@ -698,10 +745,12 @@ def contains_chunk(a, value):
dx_ind = out_ind

dx = da.blockwise(
partial(contains_chunk, value=value),
cf_contains,
out_ind,
dx,
dx_ind,
value,
(),
adjust_chunks={i: 1 for i in out_ind},
dtype=bool,
)
Expand Down
64 changes: 52 additions & 12 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import reduce
from operator import mul

import dask.array as da
import numpy as np

SCIPY_AVAILABLE = False
Expand Down Expand Up @@ -1094,21 +1095,60 @@ def test_Data_AUXILIARY_MASK(self):
self.assertEqual(f.shape, fm.shape)
self.assertTrue((f._auxiliary_mask_return().array == fm).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "TypeError: 'int' is not iterable")
def test_Data___contains__(self):
def test_Data__contains__(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return

d = cf.Data([[0.0, 1, 2], [3, 4, 5]], units="m")
self.assertIn(4, d)
self.assertNotIn(40, d)
self.assertIn(cf.Data(3), d)
self.assertIn(cf.Data([[[[3]]]]), d)
value = d[1, 2]
value.Units *= 2
value.squeeze(0)
self.assertIn(value, d)
self.assertIn(np.array([[[2]]]), d)
d = cf.Data([[0, 1, 2], [3, 4, 5]], units="m", chunks=2)

for value in (
4,
4.0,
cf.Data(3),
cf.Data(0.005, "km"),
np.array(2),
da.from_array(2),
):
self.assertIn(value, d)

for value in (
99,
np.array(99),
da.from_array(99),
cf.Data(99, "km"),
cf.Data(2, "seconds"),
):
self.assertNotIn(value, d)

for value in (
[1],
[[1]],
[1, 2],
[[1, 2]],
np.array([1]),
np.array([[1]]),
np.array([1, 2]),
np.array([[1, 2]]),
da.from_array([1]),
da.from_array([[1]]),
da.from_array([1, 2]),
da.from_array([[1, 2]]),
cf.Data([1]),
cf.Data([[1]]),
cf.Data([1, 2]),
cf.Data([[1, 2]]),
cf.Data([0.005], "km"),
):
with self.assertRaises(TypeError):
value in d

# Strings
d = cf.Data(["foo", "bar"])
self.assertIn("foo", d)
self.assertNotIn("xyz", d)

with self.assertRaises(TypeError):
["foo"] in d

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_asdata(self):
Expand Down

0 comments on commit f130048

Please sign in to comment.