Skip to content

Commit

Permalink
Merge pull request #833 from davidhassell/regrid-dtos
Browse files Browse the repository at this point in the history
Allow 'nearest_dtos' 2-d regridding to work with discrete sampling geometry source grids
  • Loading branch information
davidhassell authored Nov 27, 2024
2 parents a66687f + 0759153 commit 2031dd8
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 21 deletions.
3 changes: 3 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ version NEXTVERSION

**2024-??-??**

* Allow ``'nearest_dtos'`` 2-d regridding to work with discrete
sampling geometry source grids
(https://github.com/NCAS-CMS/cf-python/issues/832)
* New method: `cf.Field.filled`
(https://github.com/NCAS-CMS/cf-python/issues/811)
* New method: `cf.Field.is_discrete_axis`
Expand Down
52 changes: 45 additions & 7 deletions cf/data/dask_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,17 +507,20 @@ def _regrid(
# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using 'weights.getrow'.
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
indptr = weights.indptr.tolist()
indices = weights.indices
data = weights.data
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
mask = src_mask[indices[i0:i1]]
if not count_nonzero(mask):
n_masked = count_nonzero(mask)
if not n_masked:
continue

if mask.all():
if n_masked == mask.size:
dst_mask[j] = True
continue

Expand All @@ -529,8 +532,8 @@ def _regrid(

del indptr

elif method in ("linear", "bilinear", "nearest_dtos"):
# 2) Linear and nearest neighbour methods:
elif method in ("linear", "bilinear"):
# 2) Linear methods:
#
# Mask out any row j that contains at least one positive
# (i.e. greater than or equal to 'min_weight') w_ji that
Expand All @@ -546,7 +549,9 @@ def _regrid(
# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using 'weights.getrow'.
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
where = np.where
indptr = weights.indptr.tolist()
Expand All @@ -562,12 +567,45 @@ def _regrid(

del indptr, pos_data

elif method == "nearest_dtos":
# 3) Nearest neighbour dtos method:
#
# Set to 0 any weight that corresponds to a masked source
# grid cell.
#
# Mask out any row j for which all source grid cells are
# masked.
dst_size = weights.shape[0]
if dst_mask is None:
dst_mask = np.zeros((dst_size,), dtype=bool)
else:
dst_mask = dst_mask.copy()

# Note: It is much more efficient to access
# 'weights.indptr', 'weights.indices', and
# 'weights.data' directly, rather than iterating
# over rows of 'weights' and using
# 'weights.getrow'. Also, 'np.count_nonzero' is much
# faster than 'np.any' and 'np.all'.
count_nonzero = np.count_nonzero
indptr = weights.indptr.tolist()
indices = weights.indices
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
mask = src_mask[indices[i0:i1]]
n_masked = count_nonzero(mask)
if n_masked == mask.size:
dst_mask[j] = True
elif n_masked:
weights.data[np.arange(i0, i1)[mask]] = 0

del indptr

elif method in (
"patch",
"conservative_2nd",
"nearest_stod",
):
# 3) Patch recovery and second-order conservative methods:
# 4) Patch recovery and second-order conservative methods:
#
# A reference source data mask has already been
# incorporated into the weights matrix, and 'a' is assumed
Expand Down
4 changes: 3 additions & 1 deletion cf/docstring/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@
mapped to the closest destination point. A
destination point can be mapped to multiple source
points. Some destination points may not be
mapped. Useful for regridding of categorical data.
mapped. Each regridded value is the sum of its
contributing source elements. Useful for binning or
for categorical data.
* `None`: This is the default and can only be used
when *dst* is a `RegridOperator`.""",
Expand Down
18 changes: 13 additions & 5 deletions cf/regrid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,10 @@ def regrid(
"are a UGRID mesh"
)

if src_grid.is_locstream or dst_grid.is_locstream:
if dst_grid.is_locstream:
raise ValueError(
f"{method!r} regridding is (at the moment) only available "
"when neither the source and destination grids are "
"DSG featureTypes."
f"{method!r} regridding is (at the moment) not available "
"when the destination grid is a DSG featureType."
)

elif cartesian and (src_grid.is_mesh or dst_grid.is_mesh):
Expand Down Expand Up @@ -656,6 +655,7 @@ def regrid(
dst=dst,
weights_file=weights_file if from_file else None,
src_mesh_location=src_grid.mesh_location,
src_featureType=src_grid.featureType,
dst_featureType=dst_grid.featureType,
src_z=src_grid.z,
dst_z=dst_grid.z,
Expand All @@ -674,6 +674,9 @@ def regrid(
)

if return_operator:
# Note: The `RegridOperator.tosparse` method will also set
# 'dst_mask' to False for destination points with all
# zero weights.
regrid_operator.tosparse()
return regrid_operator

Expand Down Expand Up @@ -1279,7 +1282,7 @@ def spherical_grid(

# Set cyclicity of X axis
if mesh_location or featureType:
cyclic = None
cyclic = False
elif cyclic is None:
cyclic = f.iscyclic(x_axis)
else:
Expand Down Expand Up @@ -2281,6 +2284,11 @@ def create_esmpy_locstream(grid, mask=None):
# but the esmpy mask requires 0/1 for masked/unmasked
# elements.
mask = np.invert(mask).astype("int32")
if mask.size == 1:
# Make sure that there's a mask element for each point in
# the locstream (rather than a scalar that applies to all
# elements).
mask = np.full((location_count,), mask, dtype="int32")
else:
# No masked points
mask = np.full((location_count,), 1, dtype="int32")
Expand Down
81 changes: 73 additions & 8 deletions cf/test/test_regrid_featureType.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@
except ImportError:
pass

disallowed_methods = (
"conservative",
"conservative_2nd",
"nearest_dtos",
)

methods = (
"linear",
"nearest_stod",
Expand Down Expand Up @@ -169,6 +163,78 @@ def test_Field_regrid_grid_to_featureType_3d(self):
else:
self.assertFalse(y.mask.any())

@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
def test_Field_regrid_featureType_to_grid_2d(self):
self.assertFalse(cf.regrid_logging())

# Create some nice data
src = self.dst_featureType
src.del_construct("cellmethod0")
src = src[:12]
src[...] = 273 + np.arange(12)
x = src.coord("X")
x[...] = [4, 6, 9, 11, 14, 16, 4, 6, 9, 11, 14, 16]
y = src.coord("Y")
y[...] = [41, 41, 31, 31, 21, 21, 39, 39, 29, 29, 19, 19]

dst = self.src_grid.copy()
x = dst.coord("X")
x[...] = [5, 10, 15, 20]
y = dst.coord("Y")
y[...] = [10, 20, 30, 40]

# Mask some destination grid points
dst[0, 0, 1, 2] = cf.masked

# Expected destination regridded values
y0 = np.ma.array(
[[0, 0, 0, 0], [0, 0, 1122, 0], [0, 1114, 0, 0], [1106, 0, 0, 0]],
mask=[
[True, True, True, True],
[True, True, False, True],
[True, False, True, True],
[False, True, True, True],
],
)

for src_masked in (False, True):
y = y0.copy()
if src_masked:
src = src.copy()
src[6:8] = cf.masked
# This following element should be smaller, because it
# now only has two source cells contributing to it,
# rather than four.
y[3, 0] = 547

# Loop over whether or not to use the destination grid
# masked points
for use_dst_mask in (False, True):
if use_dst_mask:
y = y.copy()
y[1, 2] = np.ma.masked

kwargs = {"use_dst_mask": use_dst_mask}
method = "nearest_dtos"
for return_operator in (False, True):
if return_operator:
r = src.regrids(
dst, method=method, return_operator=True, **kwargs
)
x = src.regrids(r)
else:
x = src.regrids(dst, method=method, **kwargs)

a = x.array

self.assertEqual(y.size, a.size)
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))

if isinstance(a, np.ma.MaskedArray):
self.assertTrue((y.mask == a.mask).all())
else:
self.assertFalse(y.mask.any())

@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
def test_Field_regrid_grid_to_featureType_2d(self):
self.assertFalse(cf.regrid_logging())
Expand Down Expand Up @@ -196,7 +262,6 @@ def test_Field_regrid_grid_to_featureType_2d(self):
a = x.array

y = esmpy_regrid(coord_sys, method, src, dst, **kwargs)

self.assertEqual(y.size, a.size)
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))

Expand Down Expand Up @@ -259,7 +324,7 @@ def test_Field_regrid_featureType_bad_methods(self):
dst = self.dst_featureType.copy()
src = self.src_grid.copy()

for method in disallowed_methods:
for method in ("conservative", "conservative_2nd"):
with self.assertRaises(ValueError):
src.regrids(dst, method=method)

Expand Down

0 comments on commit 2031dd8

Please sign in to comment.