diff --git a/lib/iris/experimental/regrid.py b/lib/iris/experimental/regrid.py index 021d0ce3e2..fadefe4598 100644 --- a/lib/iris/experimental/regrid.py +++ b/lib/iris/experimental/regrid.py @@ -18,7 +18,6 @@ Regridding functions. """ - from __future__ import (absolute_import, division, print_function) from six.moves import (filter, input, map, range, zip) # noqa import six @@ -384,21 +383,22 @@ def _weighted_mean_with_mdtol(data, weights, axis=None, mdtol=0): Numpy array (possibly masked) or scalar. """ - res = ma.average(data, weights=weights, axis=axis) - if ma.isMaskedArray(data) and mdtol < 1: - weights_total = weights.sum(axis=axis) - masked_weights = weights.copy() - masked_weights[~ma.getmaskarray(data)] = 0 - masked_weights_total = masked_weights.sum(axis=axis) - frac_masked = np.true_divide(masked_weights_total, weights_total) - mask_pt = frac_masked > mdtol - if np.any(mask_pt): - if np.isscalar(res): - res = ma.masked - elif ma.isMaskedArray(res): - res.mask |= mask_pt - else: - res = ma.masked_array(res, mask=mask_pt) + if ma.is_masked(data): + res, unmasked_weights_sum = ma.average(data, weights=weights, + axis=axis, returned=True) + if mdtol < 1: + weights_sum = weights.sum(axis=axis) + frac_masked = 1 - np.true_divide(unmasked_weights_sum, weights_sum) + mask_pt = frac_masked > mdtol + if np.any(mask_pt): + if np.isscalar(res): + res = ma.masked + elif ma.isMaskedArray(res): + res.mask |= mask_pt + else: + res = ma.masked_array(res, mask=mask_pt) + else: + res = np.average(data, weights=weights, axis=axis) return res