Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dask: Deal with some (but not all) TODODASK placeholders #541

Merged
merged 4 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cf/data/dask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def _da_ma_allclose(x, y, masked_equal=True, rtol=None, atol=None):
the corresponding NumPy method (see the `numpy.ma.allclose` API
reference).

TODODASK: put in a PR to Dask to request to add as genuine method.

.. versionadded:: 4.0.0
.. versionadded:: 3.14.0

:Parameters:

Expand All @@ -58,6 +56,8 @@ def _da_ma_allclose(x, y, masked_equal=True, rtol=None, atol=None):
the given *rtol* and *atol* tolerance.

"""
# TODODASK: put in a PR to Dask to request to add as genuine method.

if rtol is None:
rtol = cf_rtol()
if atol is None:
Expand Down
12 changes: 3 additions & 9 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@
from .mixin import DataClassDeprecationsMixin
from .utils import (
YMDhms,
_is_numeric_dtype,
conform_units,
convert_to_datetime,
convert_to_reftime,
first_non_missing_value,
is_numeric_dtype,
new_axis_identifier,
scalar_masked_array,
)
Expand Down Expand Up @@ -7128,8 +7128,8 @@ def equals(
# We assume that all inputs are masked arrays. Note we compare the
# data first as this may return False due to different dtype without
# having to wait until the compute call.
self_is_numeric = _is_numeric_dtype(self_dx)
other_is_numeric = _is_numeric_dtype(other_dx)
self_is_numeric = is_numeric_dtype(self_dx)
other_is_numeric = is_numeric_dtype(other_dx)
if self_is_numeric and other_is_numeric:
data_comparison = _da_ma_allclose(
self_dx,
Expand Down Expand Up @@ -10201,8 +10201,6 @@ def squeeze(self, axes=None, inplace=False, i=False):
"""
d = _inplace_enabled_define_and_cleanup(self)

# TODODASK - check if axis parsing is done in dask

if not d.ndim:
if axes or axes == 0:
raise ValueError(
Expand Down Expand Up @@ -10757,10 +10755,6 @@ def func(

dx = d.to_dask_array()

# TODODASK: Steps to preserve invalid values shown, taking same
# approach as pre-daskification, but maybe we can now change approach
# to avoid finding mask and data, which requires early compute...
# Step 1. extract the non-masked data and the mask separately
if preserve_invalid:
# Assume all inputs are masked, as checking for a mask to confirm
# is expensive. If unmasked, effective mask will be all False.
Expand Down
2 changes: 1 addition & 1 deletion cf/data/fragment/abstract/fragmentarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _parse_indices(self, indices):
continue

if isinstance(i, Integral) or not getattr(i, "ndim", True):
# TODODASK: what about [] or np.array([])?
# TODOCFA: what about [] or np.array([])?

# 'i' is an integer or a scalar numpy/dask array
raise ValueError(
Expand Down
20 changes: 9 additions & 11 deletions cf/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
_units_None = Units(None)


def _is_numeric_dtype(array):
def is_numeric_dtype(array):
"""True if the given array is of a numeric or boolean data type.

.. versionadded:: 4.0.0
.. versionadded:: 3.14.0

:Parameters:

Expand All @@ -36,29 +36,27 @@ def _is_numeric_dtype(array):
**Examples**

>>> a = np.array([0, 1, 2])
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
True
>>> a = np.array([False, True, True])
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
True
>>> a = np.array(["a", "b", "c"], dtype="S1")
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
False
>>> a = np.ma.array([10.0, 2.0, 3.0], mask=[1, 0, 0])
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
True
>>> a = np.array(10)
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
True
>>> a = np.empty(1, dtype=object)
>>> cf.data.utils._is_numeric_dtype(a)
>>> cf.data.utils.is_numeric_dtype(a)
False

"""
# TODODASK: do we need to make any specific checks relating to ways of
# encoding datetimes, which could be encoded as strings, e.g. as in
# "2000-12-3 12:00", yet could be considered, or encoded as, numeric?
dtype = array.dtype

# This checks if the dtype is either a standard "numeric" type (i.e.
# int types, floating point types or complex floating point types)
# or Boolean, which are effectively a restricted int type (0 or 1).
Expand Down
3 changes: 2 additions & 1 deletion cf/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,8 @@ def roll(self, axis, shift, inplace=False):
>>> f.roll('X', -3)

"""
# TODODASK - allow multiple roll axes
# TODODASK: Consider allowing multiple roll axes, now that
# Data supports them.

axis = self.domain_axis(
axis,
Expand Down
19 changes: 6 additions & 13 deletions cf/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8865,11 +8865,6 @@ def collapse(

d = aux[0]

# TODODASK: remove once dask. For some reason,
# without this we now get LAMA related failures in
# Partition.nbytes ...
_ = aux.dtype

if aux.has_bounds() or (aux[:-1] != aux[1:]).any():
logger.info(
f" Removing {aux.construct_type} {key!r}"
Expand Down Expand Up @@ -13759,8 +13754,7 @@ def percentile(
# Initialise the output field with the percentile data
# ------------------------------------------------------------

# TODODASK: Make sure that this is OK whaen `ranks` is a
# scalar
# TODODASK: Make sure that this is OK when `ranks` is a scalar

out = type(self)()
out.set_properties(self.properties())
Expand Down Expand Up @@ -14219,7 +14213,8 @@ def roll(self, axis, shift, inplace=False, i=False, **kwargs):
>>> f.roll('X', -3)

"""
# TODODASK - allow multiple roll axes
# TODODASK: Consider allowing multiple roll axes, since Data
# now supports them.

axis = self.domain_axis(
axis,
Expand All @@ -14238,8 +14233,9 @@ def roll(self, axis, shift, inplace=False, i=False, **kwargs):

iaxes = self._axis_positions(axis, parse=False)
if iaxes:
# TODODASK - remove these two lines when multiaxis rolls
# are allowed at 3.14.0
# TODODASK: Remove these two lines if multiaxis rolls are
# allowed

iaxis = iaxes[0]
shift = shift[0]

Expand Down Expand Up @@ -14833,7 +14829,6 @@ def section(self, axes=None, stop=None, min_step=1, **kwargs):
<CF Field: eastward_wind(model_level_number(1), latitude(145), longitude(192)) m s-1>]

"""

# TODODASK: This still needs some attention, keyword checking,
# testing, docs, etc., but has been partially
# already updated due to changes already happening
Expand Down Expand Up @@ -15414,7 +15409,6 @@ def regrids(
# Retrieve the destination field's mask if appropriate
dst_mask = None
if dst_field and use_dst_mask and dst.data.ismasked:
# TODODASK: Just get the mask?
dst_mask = regrid_get_destination_mask(
dst, dst_order, axes=dst_axis_keys
)
Expand Down Expand Up @@ -16168,7 +16162,6 @@ def regridc(
# Retrieve the destination field's mask if appropriate
dst_mask = None
if not dst_dict and use_dst_mask and dst.data.ismasked:
# TODODASK: Just get the mask?
dst_mask = regrid_get_destination_mask(
dst,
dst_order,
Expand Down
8 changes: 4 additions & 4 deletions cf/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2921,10 +2921,10 @@ def _section(x, axes=None, stop=None, chunks=False, min_step=1):
ndim = x.ndim
shape = x.shape

# TODODASK: For v4.0.0, redefine axes by removing the next
# line. I.e. the specified axes would be those that you
# want to be chopped, not those that you want to remain
# whole.
# TODODASK: For v4.0.0, consider redefining the axes by removing
# the next line. I.e. the specified axes would be those
# that you want to be chopped, not those that you want
# to remain whole.
axes = [i for i in range(ndim) if i not in axes]

indices = [
Expand Down
6 changes: 3 additions & 3 deletions cf/mixin/fielddomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _indices(self, mode, data_axes, ancillary_mask, kwargs):

if envelope or full:
size = domain_axes[axis].get_size()
# TODODASK - consider using dask.arange here
# TODODASK: consider using dask.arange here
d = np.arange(size) # self._Data(range(size))
ind = (d[value],) # .array,)
index = slice(None)
Expand Down Expand Up @@ -735,8 +735,8 @@ def _roll_constructs(self, axis, shift):
# This construct does not span the roll axes
continue

# TODODASK - remove these two lines when multiaxis rolls
# are allowed at v4.0.0
# TODODASK: Consider removing these two lines, now that
# multiaxis rolls are allowed on Data objects.
c_axes = c_axes[0]
c_shifts = c_shifts[0]

Expand Down
8 changes: 4 additions & 4 deletions cf/test/test_Data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,22 @@ def test_Data_Utils__da_ma_allclose(self):
b2 = a / 10000
self.assertTrue(allclose(b1, b2, atol=1e-05).compute())

def test_Data_Utils__is_numeric_dtype(self):
def test_Data_Utils_is_numeric_dtype(self):
"""TODO."""
_is_numeric_dtype = cf.data.utils._is_numeric_dtype
is_numeric_dtype = cf.data.utils.is_numeric_dtype
for a in [
np.array([0, 1, 2]),
np.array([False, True, True]),
np.ma.array([10.0, 2.0, 3.0], mask=[1, 0, 0]),
np.array(10),
]:
self.assertTrue(_is_numeric_dtype(a))
self.assertTrue(is_numeric_dtype(a))

for b in [
np.array(["a", "b", "c"], dtype="S1"),
np.empty(1, dtype=object),
]:
self.assertFalse(_is_numeric_dtype(b))
self.assertFalse(is_numeric_dtype(b))

def test_Data_Utils_convert_to_datetime(self):
"""TODO."""
Expand Down