Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RefinedGridInfo class (for potential Hovmoller/zonal mean support) #165

Merged
merged 22 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 188 additions & 6 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cartopy.crs as ccrs
import ESMF
import numpy as np
import scipy.sparse


class SDO(ABC):
Expand Down Expand Up @@ -35,6 +36,11 @@ def shape(self):
"""Return shape."""
return self._shape

@property
def _refined_shape(self):
"""Return shape passed to ESMF."""
return self._shape

@property
def dims(self):
"""Return number of dimensions."""
Expand All @@ -45,6 +51,11 @@ def size(self):
"""Return the number of cells in the sdo."""
return np.prod(self._shape)

@property
def _refined_size(self):
"""Return the number of cells passed to ESMF."""
return np.prod(self._refined_shape)

@property
def index_offset(self):
"""Return the index offset."""
Expand Down Expand Up @@ -170,7 +181,9 @@ def __init__(
shape = self.lons.shape

self.lonbounds = lonbounds
self._refined_lonbounds = lonbounds
self.latbounds = latbounds
self._refined_latbounds = latbounds
if crs is None:
self.crs = ccrs.Geodetic()
else:
Expand All @@ -185,26 +198,28 @@ def __init__(
)

def _as_esmf_info(self):
shape = np.array(self._shape)
shape = np.array(self._refined_shape)

londims = len(self.lons.shape)

if londims == 1:
if self.circular:
adjustedlonbounds = self.lonbounds[:-1]
adjustedlonbounds = self._refined_lonbounds[:-1]
else:
adjustedlonbounds = self.lonbounds
adjustedlonbounds = self._refined_lonbounds
centerlons, centerlats = np.meshgrid(self.lons, self.lats)
cornerlons, cornerlats = np.meshgrid(adjustedlonbounds, self.latbounds)
cornerlons, cornerlats = np.meshgrid(
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
adjustedlonbounds, self._refined_latbounds
)
elif londims == 2:
if self.circular:
slice = np.s_[:, :-1]
else:
slice = np.s_[:]
centerlons = self.lons[slice]
centerlats = self.lats[slice]
cornerlons = self.lonbounds[slice]
cornerlats = self.latbounds[slice]
cornerlons = self._refined_lonbounds[slice]
cornerlats = self._refined_latbounds[slice]

truecenters = ccrs.Geodetic().transform_points(self.crs, centerlons, centerlats)
truecorners = ccrs.Geodetic().transform_points(self.crs, cornerlons, cornerlats)
Expand Down Expand Up @@ -275,3 +290,170 @@ def _make_esmf_sdo(self):
grid_areas[:] = areas.T

return grid


class RefinedGridInfo(GridInfo):
"""
Class for handling structured grids represented in :mod:`ESMF` in higher resolution.

A specialised version of :class:`GridInfo`. Designed to provide higher
accuracy conservative regridding for rectilinear grids, especially those with
particularly large cells which may not be well represented by :mod:`ESMF`. This
class differs from :class:`GridInfo` primarily in the way it represents itself
as a :class:`~ESMF.api.field.Field` in :mod:`ESMF`. This :class:`~ESMF.api.field.Field`
is designed to be a higher resolution version of the given grid and should
contain enough information for area weighted regridding but may be
inappropriate for other :mod:`ESMF` regridding schemes.

"""

def __init__(
self,
lonbounds,
latbounds,
resolution=3,
crs=None,
):
"""
Create a :class:`RefinedGridInfo` object describing the grid.

Parameters
----------
lonbounds : :obj:`~numpy.typing.ArrayLike`
A 1D array or list describing the longitude bounds of the grid.
Must be strictly increasing (for example, if a bound goes from
170 to -170 consider transposing -170 to 190).
latbounds : :obj:`~numpy.typing.ArrayLike`
A 1D array or list describing the latitude bounds of the grid.
Must be strictly increasing.
resolution : int, default=400
A number describing how many latitude slices each cell should
be divided into when passing a higher resolution grid to ESMF.
crs : :class:`cartopy.crs.CRS`, optional
Describes how to interpret the
above arguments. If ``None``, defaults to :class:`~cartopy.crs.Geodetic`.

"""
# Convert bounds to numpy arrays where necessary.
if not isinstance(lonbounds, np.ndarray):
lonbounds = np.array(lonbounds)
if not isinstance(latbounds, np.ndarray):
latbounds = np.array(latbounds)

# Ensure bounds are strictly increasing.
if not np.all(lonbounds[:-1] < lonbounds[1:]):
raise ValueError("The longitude bounds must be strictly increasing.")
if not np.all(latbounds[:-1] < latbounds[1:]):
raise ValueError("The latitude bounds must be strictly increasing.")

self.resolution = resolution
self.n_lons_orig = len(lonbounds) - 1
self.n_lats_orig = len(latbounds) - 1

# Create dummy lat/lon values
lons = np.zeros(self.n_lons_orig)
lats = np.zeros(self.n_lats_orig)
super().__init__(lons, lats, lonbounds, latbounds, crs=crs)

if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]):
self._refined_latbounds = np.array([-90, 0, 90])
self._refined_lonbounds = lonbounds
else:
self._refined_latbounds = latbounds
self._refined_lonbounds = np.append(
np.linspace(
lonbounds[:-1],
lonbounds[1:],
self.resolution,
endpoint=False,
axis=1,
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
).flatten(),
lonbounds[-1],
)
self.lon_expansion = int(
(len(self._refined_lonbounds) - 1) / (len(self.lonbounds) - 1)
)
self.lat_expansion = int(
(len(self._refined_latbounds) - 1) / (len(self.latbounds) - 1)
)

@property
def _refined_shape(self):
"""Return shape passed to ESMF."""
return (
self.n_lats_orig * self.lat_expansion,
self.n_lons_orig * self.lon_expansion,
)

def _collapse_weights(self, is_tgt):
"""
Return a matrix to collapse the weight matrix.

The refined grid may contain more cells than the represented grid. When this is
the case, the generated weight matrix will refer to too many points and will have
to be collapsed. This is done by multiplying by this matrix, pre-multiplying when
the target grid is represented and post multiplying when the source grid is
represented.

Parameters
----------
is_tgt : bool
True if the target field is being represented, False otherwise.
"""
# The column indices represent each of the cells in the refined grid.
column_indices = np.arange(self._refined_size)

# The row indices represent the cells of the unrefined grid. These are broadcast
# so that each row index coincides with all column indices of the refined cells
# which the unrefined cell is split into.
if self.lat_expansion > 1:
# The latitudes are expanded only in the case where there is one latitude
# bound from -90 to 90. In this case, there is no longitude expansion.
row_indices = np.empty([self.n_lons_orig, self.lat_expansion])
row_indices[:] = np.arange(self.n_lons_orig)[:, np.newaxis]
else:
# The row indices are broadcast across a dimension representing the expansion
# of the longitude. Each row index is broadcast and flattened so that all the
# row indices representing the unrefined cell match up with the column indices
# representing the refined cells it is split into.
row_indices = np.empty(
[self.n_lons_orig, self.lon_expansion, self.n_lats_orig]
)
row_indices[:] = np.arange(self.n_lons_orig * self.n_lats_orig).reshape(
[self.n_lons_orig, self.n_lats_orig]
)[:, np.newaxis, :]
row_indices = row_indices.flatten()
matrix_shape = (self.size, self._refined_size)
refinement_weights = scipy.sparse.csr_matrix(
(
np.ones(self._refined_size),
(row_indices, column_indices),
),
shape=matrix_shape,
)
if is_tgt:
# When the RefinedGridInfo is the target of the regridder, we want to take
# the average of the weights of each refined target cell. This is because
# these weights represent the proportion of area of the target cells which
# is covered by a given source cell. Since the refined cells are divided in
# such a way that they have equal area, the weights for the unrefined cells
# can be reconstructed by taking an average. This is done via matrix
# multiplication, with the returned matrix pre-multiplying the weight matrix
# so that it operates on the rows of the weight matrix (representing the
# target cells). At this point the returned matrix consists of ones, so we
# divided by the number of refined cells per unrefined cell.
refinement_weights = refinement_weights / (
self.lon_expansion * self.lat_expansion
)
else:
# When the RefinedGridInfo is the source of the regridder, we want to take
# the sum of the weights of each refined target cell. This is because those
# weights represent the proportion of the area of a given target cell which
# is covered by each refined source cell. The total proportion covered by
# each unrefined source cell is then the sum of the weights from each of its
# refined cells. This sum is done by matrix multiplication, the returned
# matrix post-multiplying the weight matrix so that it operates on the columns
# of the weight matrix (representing the source cells). In order for the
# post-multiplication to work, the returned matrix must be transposed.
refinement_weights = refinement_weights.T
return refinement_weights
18 changes: 16 additions & 2 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.sparse

import esmf_regrid
from ._esmf_sdo import GridInfo
from ._esmf_sdo import GridInfo, RefinedGridInfo

__all__ = [
"GridInfo",
Expand Down Expand Up @@ -103,9 +103,23 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
(self.tgt.size, self.src.size),
(self.tgt._refined_size, self.src._refined_size),
(self.tgt.index_offset, self.src.index_offset),
)
if isinstance(tgt, RefinedGridInfo):
# At this point, the weight matrix represents more target points than
# tgt respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = (
tgt._collapse_weights(is_tgt=True) @ self.weight_matrix
)
if isinstance(src, RefinedGridInfo):
# At this point, the weight matrix represents more source points than
# src respresents. In order to collapse these points, we collapse the
# weights matrix by the appropriate matrix multiplication.
self.weight_matrix = self.weight_matrix @ src._collapse_weights(
is_tgt=False
)
else:
if not scipy.sparse.isspmatrix(precomputed_weights):
raise ValueError(
Expand Down
8 changes: 8 additions & 0 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
VERSION_INITIAL = "esmf_regrid_version_on_initialise"
MDTOL = "mdtol"
METHOD = "method"
RESOLUTION = "resolution"


def save_regridder(rg, filename):
Expand Down Expand Up @@ -87,6 +88,7 @@ def save_regridder(rg, filename):
raise TypeError(msg)

method = rg.method
resolution = rg.resolution

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
Expand All @@ -113,6 +115,8 @@ def save_regridder(rg, filename):
MDTOL: mdtol,
METHOD: method,
}
if resolution is not None:
attributes[RESOLUTION] = resolution

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
Expand Down Expand Up @@ -178,6 +182,9 @@ def load_regridder(filename):
# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method = weights_cube.attributes.get(METHOD, "conservative")
resolution = weights_cube.attributes.get(RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)

# Reconstruct the weight matrix.
weight_data = weights_cube.data
Expand All @@ -196,6 +203,7 @@ def load_regridder(filename):
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
resolution=resolution,
)

esmf_version = weights_cube.attributes[VERSION_ESMF]
Expand Down
Loading