Skip to content

Commit

Permalink
Add RefinedGridInfo class (for potential Hovmoller/zonal mean support) (
Browse files Browse the repository at this point in the history
#165)

* add RefinedGridInfo class

* test for expanded lats

* lint fix

* add resolution to scheme

* fix tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* save/load resolution

* save/load resolution

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tests

* fix tests

* fix RefinedGridInfo source behaviour

* fix regridder loading

* fix tests

* add error tests

* add explanation to the mathematics.

* check bounds are strictly increasing

* address review comments

* address review comments

* address review comments

* address review comments

* fix test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stephenworsley and pre-commit-ci[bot] authored Mar 30, 2022
1 parent 2013e7b commit b4042d8
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 39 deletions.
194 changes: 188 additions & 6 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cartopy.crs as ccrs
import ESMF
import numpy as np
import scipy.sparse


class SDO(ABC):
Expand Down Expand Up @@ -35,6 +36,11 @@ def shape(self):
"""Return shape."""
return self._shape

@property
def _refined_shape(self):
"""Return shape passed to ESMF."""
return self._shape

@property
def dims(self):
"""Return number of dimensions."""
Expand All @@ -45,6 +51,11 @@ def size(self):
"""Return the number of cells in the sdo."""
return np.prod(self._shape)

@property
def _refined_size(self):
"""Return the number of cells passed to ESMF."""
return np.prod(self._refined_shape)

@property
def index_offset(self):
"""Return the index offset."""
Expand Down Expand Up @@ -170,7 +181,9 @@ def __init__(
shape = self.lons.shape

self.lonbounds = lonbounds
self._refined_lonbounds = lonbounds
self.latbounds = latbounds
self._refined_latbounds = latbounds
if crs is None:
self.crs = ccrs.Geodetic()
else:
Expand All @@ -185,26 +198,28 @@ def __init__(
)

def _as_esmf_info(self):
shape = np.array(self._shape)
shape = np.array(self._refined_shape)

londims = len(self.lons.shape)

if londims == 1:
if self.circular:
adjustedlonbounds = self.lonbounds[:-1]
adjustedlonbounds = self._refined_lonbounds[:-1]
else:
adjustedlonbounds = self.lonbounds
adjustedlonbounds = self._refined_lonbounds
centerlons, centerlats = np.meshgrid(self.lons, self.lats)
cornerlons, cornerlats = np.meshgrid(adjustedlonbounds, self.latbounds)
cornerlons, cornerlats = np.meshgrid(
adjustedlonbounds, self._refined_latbounds
)
elif londims == 2:
if self.circular:
slice = np.s_[:, :-1]
else:
slice = np.s_[:]
centerlons = self.lons[slice]
centerlats = self.lats[slice]
cornerlons = self.lonbounds[slice]
cornerlats = self.latbounds[slice]
cornerlons = self._refined_lonbounds[slice]
cornerlats = self._refined_latbounds[slice]

truecenters = ccrs.Geodetic().transform_points(self.crs, centerlons, centerlats)
truecorners = ccrs.Geodetic().transform_points(self.crs, cornerlons, cornerlats)
Expand Down Expand Up @@ -275,3 +290,170 @@ def _make_esmf_sdo(self):
grid_areas[:] = areas.T

return grid


class RefinedGridInfo(GridInfo):
"""
Class for handling structured grids represented in :mod:`ESMF` in higher resolution.
A specialised version of :class:`GridInfo`. Designed to provide higher
accuracy conservative regridding for rectilinear grids, especially those with
particularly large cells which may not be well represented by :mod:`ESMF`. This
class differs from :class:`GridInfo` primarily in the way it represents itself
as a :class:`~ESMF.api.field.Field` in :mod:`ESMF`. This :class:`~ESMF.api.field.Field`
is designed to be a higher resolution version of the given grid and should
contain enough information for area weighted regridding but may be
inappropriate for other :mod:`ESMF` regridding schemes.
"""

def __init__(
self,
lonbounds,
latbounds,
resolution=3,
crs=None,
):
"""
Create a :class:`RefinedGridInfo` object describing the grid.
Parameters
----------
lonbounds : :obj:`~numpy.typing.ArrayLike`
A 1D array or list describing the longitude bounds of the grid.
Must be strictly increasing (for example, if a bound goes from
170 to -170 consider transposing -170 to 190).
latbounds : :obj:`~numpy.typing.ArrayLike`
A 1D array or list describing the latitude bounds of the grid.
Must be strictly increasing.
resolution : int, default=400
A number describing how many latitude slices each cell should
be divided into when passing a higher resolution grid to ESMF.
crs : :class:`cartopy.crs.CRS`, optional
Describes how to interpret the
above arguments. If ``None``, defaults to :class:`~cartopy.crs.Geodetic`.
"""
# Convert bounds to numpy arrays where necessary.
if not isinstance(lonbounds, np.ndarray):
lonbounds = np.array(lonbounds)
if not isinstance(latbounds, np.ndarray):
latbounds = np.array(latbounds)

# Ensure bounds are strictly increasing.
if not np.all(lonbounds[:-1] < lonbounds[1:]):
raise ValueError("The longitude bounds must be strictly increasing.")
if not np.all(latbounds[:-1] < latbounds[1:]):
raise ValueError("The latitude bounds must be strictly increasing.")

self.resolution = resolution
self.n_lons_orig = len(lonbounds) - 1
self.n_lats_orig = len(latbounds) - 1

# Create dummy lat/lon values
lons = np.zeros(self.n_lons_orig)
lats = np.zeros(self.n_lats_orig)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs)

if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]):
self._refined_latbounds = np.array([-90, 0, 90])
self._refined_lonbounds = lonbounds
else:
self._refined_latbounds = latbounds
self._refined_lonbounds = np.append(
np.linspace(
lonbounds[:-1],
lonbounds[1:],
self.resolution,
endpoint=False,
axis=1,
).flatten(),
lonbounds[-1],
)
self.lon_expansion = int(
(len(self._refined_lonbounds) - 1) / (len(self.lonbounds) - 1)
)
self.lat_expansion = int(
(len(self._refined_latbounds) - 1) / (len(self.latbounds) - 1)
)

@property
def _refined_shape(self):
"""Return shape passed to ESMF."""
return (
self.n_lats_orig * self.lat_expansion,
self.n_lons_orig * self.lon_expansion,
)

def _collapse_weights(self, is_tgt):
"""
Return a matrix to collapse the weight matrix.
The refined grid may contain more cells than the represented grid. When this is
the case, the generated weight matrix will refer to too many points and will have
to be collapsed. This is done by multiplying by this matrix, pre-multiplying when
the target grid is represented and post multiplying when the source grid is
represented.
Parameters
----------
is_tgt : bool
True if the target field is being represented, False otherwise.
"""
# The column indices represent each of the cells in the refined grid.
column_indices = np.arange(self._refined_size)

# The row indices represent the cells of the unrefined grid. These are broadcast
# so that each row index coincides with all column indices of the refined cells
# which the unrefined cell is split into.
if self.lat_expansion > 1:
# The latitudes are expanded only in the case where there is one latitude
# bound from -90 to 90. In this case, there is no longitude expansion.
row_indices = np.empty([self.n_lons_orig, self.lat_expansion])
row_indices[:] = np.arange(self.n_lons_orig)[:, np.newaxis]
else:
# The row indices are broadcast across a dimension representing the expansion
# of the longitude. Each row index is broadcast and flattened so that all the
# row indices representing the unrefined cell match up with the column indices
# representing the refined cells it is split into.
row_indices = np.empty(
[self.n_lons_orig, self.lon_expansion, self.n_lats_orig]
)
row_indices[:] = np.arange(self.n_lons_orig * self.n_lats_orig).reshape(
[self.n_lons_orig, self.n_lats_orig]
)[:, np.newaxis, :]
row_indices = row_indices.flatten()
matrix_shape = (self.size, self._refined_size)
refinement_weights = scipy.sparse.csr_matrix(
(
np.ones(self._refined_size),
(row_indices, column_indices),
),
shape=matrix_shape,
)
if is_tgt:
# When the RefinedGridInfo is the target of the regridder, we want to take
# the average of the weights of each refined target cell. This is because
# these weights represent the proportion of area of the target cells which
# is covered by a given source cell. Since the refined cells are divided in
# such a way that they have equal area, the weights for the unrefined cells
# can be reconstructed by taking an average. This is done via matrix
# multiplication, with the returned matrix pre-multiplying the weight matrix
# so that it operates on the rows of the weight matrix (representing the
# target cells). At this point the returned matrix consists of ones, so we
# divided by the number of refined cells per unrefined cell.
refinement_weights = refinement_weights / (
self.lon_expansion * self.lat_expansion
)
else:
# When the RefinedGridInfo is the source of the regridder, we want to take
# the sum of the weights of each refined target cell. This is because those
# weights represent the proportion of the area of a given target cell which
# is covered by each refined source cell. The total proportion covered by
# each unrefined source cell is then the sum of the weights from each of its
# refined cells. This sum is done by matrix multiplication, the returned
# matrix post-multiplying the weight matrix so that it operates on the columns
# of the weight matrix (representing the source cells). In order for the
# post-multiplication to work, the returned matrix must be transposed.
refinement_weights = refinement_weights.T
return refinement_weights
18 changes: 16 additions & 2 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.sparse

import esmf_regrid
from ._esmf_sdo import GridInfo
from ._esmf_sdo import GridInfo, RefinedGridInfo

__all__ = [
"GridInfo",
Expand Down Expand Up @@ -103,9 +103,23 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
(self.tgt.size, self.src.size),
(self.tgt._refined_size, self.src._refined_size),
(self.tgt.index_offset, self.src.index_offset),
)
if isinstance(tgt, RefinedGridInfo):
# At this point, the weight matrix represents more target points than
# tgt respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = (
tgt._collapse_weights(is_tgt=True) @ self.weight_matrix
)
if isinstance(src, RefinedGridInfo):
# At this point, the weight matrix represents more source points than
# src respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = self.weight_matrix @ src._collapse_weights(
is_tgt=False
)
else:
if not scipy.sparse.isspmatrix(precomputed_weights):
raise ValueError(
Expand Down
8 changes: 8 additions & 0 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
VERSION_INITIAL = "esmf_regrid_version_on_initialise"
MDTOL = "mdtol"
METHOD = "method"
RESOLUTION = "resolution"


def save_regridder(rg, filename):
Expand Down Expand Up @@ -87,6 +88,7 @@ def save_regridder(rg, filename):
raise TypeError(msg)

method = rg.method
resolution = rg.resolution

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
Expand All @@ -113,6 +115,8 @@ def save_regridder(rg, filename):
MDTOL: mdtol,
METHOD: method,
}
if resolution is not None:
attributes[RESOLUTION] = resolution

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
Expand Down Expand Up @@ -178,6 +182,9 @@ def load_regridder(filename):
# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method = weights_cube.attributes.get(METHOD, "conservative")
resolution = weights_cube.attributes.get(RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)

# Reconstruct the weight matrix.
weight_data = weights_cube.data
Expand All @@ -196,6 +203,7 @@ def load_regridder(filename):
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
resolution=resolution,
)

esmf_version = weights_cube.attributes[VERSION_ESMF]
Expand Down
Loading

0 comments on commit b4042d8

Please sign in to comment.