Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend regridder saving/loading to all regridders #357

Merged
merged 11 commits into from
May 30, 2024
4 changes: 2 additions & 2 deletions docs/src/userguide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ Saving and Loading a Regridder
A regridder can be set up for reuse, this saves time performing the
computationally expensive initialisation process::

from esmf_regrid.experimental.unstructured_scheme import MeshToGridESMFRegridder
from esmf_regrid.experimental.unstructured_scheme import ESMFAreaWeighted

# Initialise the regridder with a source mesh and target grid.
regridder = MeshToGridESMFRegridder(source_mesh_cube, target_grid_cube)
regridder = ESMFAreaWeighted().regridder(source_mesh_cube, target_grid_cube)

# use the initialised regridder to regrid the data from the source cube
# onto a cube with the same grid as `target_grid_cube`.
Expand Down
10 changes: 6 additions & 4 deletions docs/src/userguide/scheme_comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ These were formerly the only way to do regridding with a source or
target cube defined on an unstructured mesh. These are less flexible and
require that the source/target be defined on a grid/mesh. Unlike the above
regridders whose method is fixed, these regridders take a ``method`` keyword
of ``conservative``, ``bilinear`` or ``nearest``. While most of the
functionality in these regridders have been ported into the above schemes and
regridders, these remain the only regridders capable of being saved and loaded by
:mod:`esmf_regrid.experimental.io`.
of ``conservative``, ``bilinear`` or ``nearest``. All the
functionality in these regridders has now been ported into the above schemes and
regridders. Before version 0.10, these were the only regridders capable of being
saved and loaded by :mod:`esmf_regrid.experimental.io`, so while the above generic
regridders are recomended, these regridders are still available for the sake of
consistency with regridders saved from older versions.


Overview: Miscellaneous Functions
Expand Down
166 changes: 138 additions & 28 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Provides load/save functions for regridders."""

from contextlib import contextmanager

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
Expand All @@ -13,9 +15,19 @@
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import (
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
GridRecord,
MeshRecord,
)


SUPPORTED_REGRIDDERS = [
ESMFAreaWeightedRegridder,
ESMFBilinearRegridder,
ESMFNearestRegridder,
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
]
Expand All @@ -34,6 +46,8 @@
MDTOL = "mdtol"
METHOD = "method"
RESOLUTION = "resolution"
SOURCE_RESOLUTION = "src_resolution"
TARGET_RESOLUTION = "tgt_resolution"


def _add_mask_to_cube(mask, cube, name):
Expand All @@ -43,6 +57,49 @@ def _add_mask_to_cube(mask, cube, name):
cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


@contextmanager
def _managed_var_name(src_cube, tgt_cube):
pp-mo marked this conversation as resolved.
Show resolved Hide resolved
src_coord_names = []
src_mesh_coords = []
if src_cube.mesh is not None:
src_mesh = src_cube.mesh
src_mesh_coords = src_mesh.coords()
for coord in src_mesh_coords:
src_coord_names.append(coord.var_name)
tgt_coord_names = []
tgt_mesh_coords = []
if tgt_cube.mesh is not None:
tgt_mesh = tgt_cube.mesh
tgt_mesh_coords = tgt_mesh.coords()
for coord in tgt_mesh_coords:
tgt_coord_names.append(coord.var_name)

try:
for coord in src_mesh_coords:
coord.var_name = "_".join([SOURCE_NAME, "mesh", coord.name()])
for coord in tgt_mesh_coords:
coord.var_name = "_".join([TARGET_NAME, "mesh", coord.name()])
yield None
finally:
for coord, var_name in zip(src_mesh_coords, src_coord_names):
coord.var_name = var_name
for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names):
coord.var_name = var_name


def _clean_var_names(cube):
pp-mo marked this conversation as resolved.
Show resolved Hide resolved
cube.var_name = None
for coord in cube.coords():
coord.var_name = None
if cube.mesh is not None:
cube.mesh.var_name = None
for coord in cube.mesh.coords():
coord.var_name = None
for con in cube.mesh.connectivities():
con.var_name = None
return cube


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.
pp-mo marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -76,28 +133,56 @@ def _standard_grid_cube(grid, name):
cube.add_aux_coord(grid[1], [0, 1])
return cube

if regridder_type == "GridToMeshESMFRegridder":
def _standard_mesh_cube(mesh, location, name):
mesh_coords = mesh.to_MeshCoords(location)
data = np.zeros(mesh_coords[0].points.shape[0])
cube = Cube(data, var_name=name, long_name=name)
for coord in mesh_coords:
cube.add_aux_coord(coord, 0)
return cube

if regridder_type in [
"ESMFAreaWeightedRegridder",
"ESMFBilinearRegridder",
"ESMFNearestRegridder",
]:
src_grid = rg._src
if isinstance(src_grid, GridRecord):
src_cube = _standard_grid_cube(
(src_grid.grid_y, src_grid.grid_x), SOURCE_NAME
)
elif isinstance(src_grid, MeshRecord):
src_mesh, src_location = src_grid
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
else:
raise ValueError("Improper type for `rg._src`.")
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = rg._tgt
if isinstance(tgt_grid, GridRecord):
tgt_cube = _standard_grid_cube(
(tgt_grid.grid_y, tgt_grid.grid_x), TARGET_NAME
)
elif isinstance(tgt_grid, MeshRecord):
tgt_mesh, tgt_location = tgt_grid
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
else:
raise ValueError("Improper type for `rg._tgt`.")
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
elif regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_mesh = rg.mesh
tgt_location = rg.location
tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location)
tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0])
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh_coords:
tgt_cube.add_aux_coord(coord, 0)
tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)

elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_location = rg.location
src_mesh_coords = src_mesh.to_MeshCoords(src_location)
src_data = np.zeros(src_mesh_coords[0].points.shape[0])
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh_coords:
src_cube.add_aux_coord(coord, 0)
src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
_add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)

tgt_grid = (rg.grid_y, rg.grid_x)
Expand All @@ -112,7 +197,18 @@ def _standard_grid_cube(grid, name):

method = str(check_method(rg.method).name)

resolution = rg.resolution
if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
resolution = rg.resolution
src_resolution = None
tgt_resolution = None
elif regridder_type == "ESMFAreaWeightedRegridder":
resolution = None
src_resolution = rg.src_resolution
tgt_resolution = rg.tgt_resolution
else:
resolution = None
src_resolution = None
tgt_resolution = None

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
Expand Down Expand Up @@ -141,6 +237,10 @@ def _standard_grid_cube(grid, name):
}
if resolution is not None:
attributes[RESOLUTION] = resolution
if src_resolution is not None:
attributes[SOURCE_RESOLUTION] = src_resolution
if tgt_resolution is not None:
attributes[TARGET_RESOLUTION] = tgt_resolution

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
Expand All @@ -158,17 +258,14 @@ def _standard_grid_cube(grid, name):
long_name=WEIGHTS_SHAPE_NAME,
)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
# Save cubes while ensuring var_names do not conflict for the sake of consistency.
with _managed_var_name(src_cube, tgt_cube):
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube])

for cube in cube_list:
cube.attributes = attributes
for cube in cube_list:
cube.attributes = attributes
pp-mo marked this conversation as resolved.
Show resolved Hide resolved

iris.fileformats.netcdf.save(cube_list, filename)
iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
Expand All @@ -193,8 +290,8 @@ def load_regridder(filename):
cubes = iris.load(filename)

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(SOURCE_NAME)
tgt_cube = cubes.extract_cube(TARGET_NAME)
src_cube = _clean_var_names(cubes.extract_cube(SOURCE_NAME))
tgt_cube = _clean_var_names(cubes.extract_cube(TARGET_NAME))
weights_cube = cubes.extract_cube(WEIGHTS_NAME)
weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME)

Expand All @@ -210,8 +307,14 @@ def load_regridder(filename):
)

resolution = weights_cube.attributes.get(RESOLUTION, None)
src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None)
tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None)
if resolution is not None:
resolution = int(resolution)
if src_resolution is not None:
src_resolution = int(src_resolution)
if tgt_resolution is not None:
tgt_resolution = int(tgt_resolution)

# Reconstruct the weight matrix.
weight_data = weights_cube.data
Expand All @@ -234,18 +337,25 @@ def load_regridder(filename):
use_tgt_mask = False

if scheme is GridToMeshESMFRegridder:
resolution_keyword = "src_resolution"
resolution_keyword = SOURCE_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is MeshToGridESMFRegridder:
resolution_keyword = "tgt_resolution"
resolution_keyword = TARGET_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is ESMFAreaWeightedRegridder:
kwargs = {
SOURCE_RESOLUTION: src_resolution,
TARGET_RESOLUTION: tgt_resolution,
"mdtol": mdtol,
}
elif scheme is ESMFBilinearRegridder:
kwargs = {"mdtol": mdtol}
else:
raise NotImplementedError
kwargs = {resolution_keyword: resolution}
kwargs = {}

regridder = scheme(
src_cube,
tgt_cube,
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
Expand Down
11 changes: 11 additions & 0 deletions esmf_regrid/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,8 @@ def regridder(
self,
src_grid,
tgt_grid,
src_resolution=None,
tgt_resolution=None,
use_src_mask=None,
use_tgt_mask=None,
tgt_location="face",
Expand All @@ -980,6 +982,11 @@ def regridder(
tgt_grid : :class:`iris.cube.Cube` or :class:`iris.experimental.ugrid.Mesh`
The unstructured :class:`~iris.cube.Cube`or
:class:`~iris.experimental.ugrid.Mesh` defining the target.
src_resolution, tgt_resolution : int, optional
If present, represents the amount of latitude slices per source/target cell
given to ESMF for calculation. If resolution is set, ``src`` and ``tgt``
respectively must have strictly increasing bounds (bounds may be transposed
plus or minus 360 degrees to make the bounds strictly increasing).
use_src_mask : :obj:`~numpy.typing.ArrayLike` or bool, optional
Array describing which elements :mod:`esmpy` will ignore on the src_grid.
If True, the mask will be derived from src_grid.
Expand Down Expand Up @@ -1017,6 +1024,8 @@ def regridder(
src_grid,
tgt_grid,
mdtol=self.mdtol,
src_resolution=src_resolution,
tgt_resolution=tgt_resolution,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
tgt_location="face",
Expand Down Expand Up @@ -1465,8 +1474,10 @@ def __init__(
if tgt_location is not "face".
"""
kwargs = dict()
self.src_resolution = src_resolution
if src_resolution is not None:
kwargs["src_resolution"] = src_resolution
self.tgt_resolution = tgt_resolution
if tgt_resolution is not None:
kwargs["tgt_resolution"] = tgt_resolution
if tgt_location is not None and tgt_location != "face":
Expand Down
Loading