Skip to content

Commit

Permalink
Merge pull request #2730 from tv3141/speedup_AreaWeighted
Browse files Browse the repository at this point in the history
Speedup area weighted regridding
  • Loading branch information
pelson authored Oct 24, 2017
2 parents bfdc172 + a1b485e commit 5d5e3fb
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions lib/iris/experimental/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Regridding functions.
"""

from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip) # noqa
import six
Expand Down Expand Up @@ -384,21 +383,22 @@ def _weighted_mean_with_mdtol(data, weights, axis=None, mdtol=0):
Numpy array (possibly masked) or scalar.
"""
res = ma.average(data, weights=weights, axis=axis)
if ma.isMaskedArray(data) and mdtol < 1:
weights_total = weights.sum(axis=axis)
masked_weights = weights.copy()
masked_weights[~ma.getmaskarray(data)] = 0
masked_weights_total = masked_weights.sum(axis=axis)
frac_masked = np.true_divide(masked_weights_total, weights_total)
mask_pt = frac_masked > mdtol
if np.any(mask_pt):
if np.isscalar(res):
res = ma.masked
elif ma.isMaskedArray(res):
res.mask |= mask_pt
else:
res = ma.masked_array(res, mask=mask_pt)
if ma.is_masked(data):
res, unmasked_weights_sum = ma.average(data, weights=weights,
axis=axis, returned=True)
if mdtol < 1:
weights_sum = weights.sum(axis=axis)
frac_masked = 1 - np.true_divide(unmasked_weights_sum, weights_sum)
mask_pt = frac_masked > mdtol
if np.any(mask_pt):
if np.isscalar(res):
res = ma.masked
elif ma.isMaskedArray(res):
res.mask |= mask_pt
else:
res = ma.masked_array(res, mask=mask_pt)
else:
res = np.average(data, weights=weights, axis=axis)
return res


Expand Down

0 comments on commit 5d5e3fb

Please sign in to comment.