Skip to content

Commit

Permalink
Regridder load/saving (SciTools#130)
Browse files Browse the repository at this point in the history
* add regridder saving

* add regridder saving

* avoid saving bug

* add docstrings, copy iris utils

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test/lint fixes

* test fix

* add test

* test functionality

* add comments and tests

* flake fix

* refactor tests

* fix tests

* fix tests

* fix tests

* use pytest fixture tmp_path

* refresh nox cache

* fix test

* remove temp file architecture

* remove imports

* update nox cache

* fix tests

* increment CONDA_CACHE_BUILD

* toggle nox environment reuse

* toggle nox environment reuse

* determine regridder_type generically

* fix saver

* fix saver

* fix loader

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stephenworsley and pre-commit-ci[bot] authored Nov 18, 2021
1 parent 7c17235 commit 0484f41
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .cirrus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ env:
# Maximum cache period (in weeks) before forcing a new cache upload.
CACHE_PERIOD: "2"
# Increment the build number to force new conda cache upload.
CONDA_CACHE_BUILD: "0"
CONDA_CACHE_BUILD: "1"
# Increment the build number to force new nox cache upload.
NOX_CACHE_BUILD: "0"
NOX_CACHE_BUILD: "2"
# Increment the build number to force new pip cache upload.
PIP_CACHE_BUILD: "0"
# Pip package to be installed.
Expand Down
149 changes: 149 additions & 0 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Provides load/save functions for regridders."""

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
from iris.experimental.ugrid import PARSE_UGRID_ON_LOAD
import numpy as np
import scipy.sparse

from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)


SUPPORTED_REGRIDDERS = [
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
]
REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS}


def save_regridder(rg, filename):
"""
Save a regridder scheme instance.
Saves either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`.
Parameters
----------
rg : GridToMeshESMFRegridder, MeshToGridESMFRegridder
The regridder instance to save.
filename : str
The file name to save to.
"""
src_name = "regridder source field"
tgt_name = "regridder target field"
regridder_type = rg.__class__.__name__
if regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_shape = [len(coord.points) for coord in src_grid]
src_data = np.zeros(src_shape)
src_cube = Cube(src_data, long_name=src_name)
src_cube.add_dim_coord(src_grid[0], 0)
src_cube.add_dim_coord(src_grid[1], 1)

tgt_mesh = rg.mesh
tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0])
tgt_cube = Cube(tgt_data, long_name=tgt_name)
for coord in tgt_mesh.to_MeshCoords("face"):
tgt_cube.add_aux_coord(coord, 0)
elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0])
src_cube = Cube(src_data, long_name=src_name)
for coord in src_mesh.to_MeshCoords("face"):
src_cube.add_aux_coord(coord, 0)

tgt_grid = (rg.grid_y, rg.grid_x)
tgt_shape = [len(coord.points) for coord in tgt_grid]
tgt_data = np.zeros(tgt_shape)
tgt_cube = Cube(tgt_data, long_name=tgt_name)
tgt_cube.add_dim_coord(tgt_grid[0], 0)
tgt_cube.add_dim_coord(tgt_grid[1], 1)
else:
msg = (
f"Expected a regridder of type `GridToMeshESMFRegridder` or "
f"`MeshToGridESMFRegridder`, got type {regridder_type}."
)
raise TypeError(msg)

metadata_name = "regridder weights and metadata"

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = weight_matrix.tocoo()
weight_data = reformatted_weight_matrix.data
weight_rows = reformatted_weight_matrix.row
weight_cols = reformatted_weight_matrix.col
weight_shape = reformatted_weight_matrix.shape

mdtol = rg.mdtol
attributes = {
"regridder type": regridder_type,
"mdtol": mdtol,
"weights shape": weight_shape,
}

metadata_cube = Cube(weight_data, long_name=metadata_name, attributes=attributes)
row_name = "weight matrix rows"
row_coord = AuxCoord(weight_rows, long_name=row_name)
col_name = "weight matrix columns"
col_coord = AuxCoord(weight_cols, long_name=col_name)
metadata_cube.add_aux_coord(row_coord, 0)
metadata_cube.add_aux_coord(col_coord, 0)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
cube_list = CubeList([src_cube, tgt_cube, metadata_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, metadata_cube])
iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
"""
Load a regridder scheme instance.
Loads either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`.
Parameters
----------
filename : str
The file name to load from.
"""
with PARSE_UGRID_ON_LOAD.context():
cubes = iris.load(filename)

src_name = "regridder source field"
tgt_name = "regridder target field"
metadata_name = "regridder weights and metadata"

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(src_name)
tgt_cube = cubes.extract_cube(tgt_name)
metadata_cube = cubes.extract_cube(metadata_name)

# Determine the regridder type.
regridder_type = metadata_cube.attributes["regridder type"]
assert regridder_type in REGRIDDER_NAME_MAP.keys()
scheme = REGRIDDER_NAME_MAP[regridder_type]

# Reconstruct the weight matrix.
weight_data = metadata_cube.data
row_name = "weight matrix rows"
weight_rows = metadata_cube.coord(row_name).points
col_name = "weight matrix columns"
weight_cols = metadata_cube.coord(col_name).points
weight_shape = metadata_cube.attributes["weights shape"]
weight_matrix = scipy.sparse.csr_matrix(
(weight_data, (weight_rows, weight_cols)), shape=weight_shape
)

mdtol = metadata_cube.attributes["mdtol"]

regridder = scheme(
src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix
)
return regridder
27 changes: 19 additions & 8 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def copy_coords(src_coords, add_method):
return new_cube


def _regrid_unstructured_to_rectilinear__prepare(src_mesh_cube, target_grid_cube):
def _regrid_unstructured_to_rectilinear__prepare(
src_mesh_cube, target_grid_cube, precomputed_weights=None
):
"""
First (setup) part of 'regrid_unstructured_to_rectilinear'.
Expand Down Expand Up @@ -257,7 +259,7 @@ def _regrid_unstructured_to_rectilinear__prepare(src_mesh_cube, target_grid_cube
meshinfo = _mesh_to_MeshInfo(mesh)
gridinfo = _cube_to_GridInfo(target_grid_cube)

regridder = Regridder(meshinfo, gridinfo)
regridder = Regridder(meshinfo, gridinfo, precomputed_weights)

regrid_info = (mesh_dim, grid_x, grid_y, regridder)

Expand Down Expand Up @@ -350,7 +352,9 @@ def regrid_unstructured_to_rectilinear(src_cube, grid_cube, mdtol=0):
class MeshToGridESMFRegridder:
"""Regridder class for unstructured to rectilinear cubes."""

def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
def __init__(
self, src_mesh_cube, target_grid_cube, mdtol=1, precomputed_weights=None
):
"""
Create regridder for conversions between source mesh and target grid.
Expand Down Expand Up @@ -382,9 +386,12 @@ def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
self.mdtol = mdtol

partial_regrid_info = _regrid_unstructured_to_rectilinear__prepare(
src_mesh_cube, target_grid_cube
src_mesh_cube, target_grid_cube, precomputed_weights=precomputed_weights
)

# Record source mesh.
self.mesh = src_mesh_cube.mesh

# Store regrid info.
_, self.grid_x, self.grid_y, self.regridder = partial_regrid_info

Expand Down Expand Up @@ -491,7 +498,9 @@ def copy_coords(src_coords, add_method):
return new_cube


def _regrid_rectilinear_to_unstructured__prepare(src_grid_cube, target_mesh_cube):
def _regrid_rectilinear_to_unstructured__prepare(
src_grid_cube, target_mesh_cube, precomputed_weights=None
):
"""
First (setup) part of 'regrid_rectilinear_to_unstructured'.
Expand All @@ -510,7 +519,7 @@ def _regrid_rectilinear_to_unstructured__prepare(src_grid_cube, target_mesh_cube
meshinfo = _mesh_to_MeshInfo(mesh)
gridinfo = _cube_to_GridInfo(src_grid_cube)

regridder = Regridder(gridinfo, meshinfo)
regridder = Regridder(gridinfo, meshinfo, precomputed_weights)

regrid_info = (grid_x_dim, grid_y_dim, grid_x, grid_y, mesh, regridder)

Expand Down Expand Up @@ -610,7 +619,9 @@ def regrid_rectilinear_to_unstructured(src_cube, mesh_cube, mdtol=0):
class GridToMeshESMFRegridder:
"""Regridder class for rectilinear to unstructured cubes."""

def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
def __init__(
self, src_mesh_cube, target_grid_cube, mdtol=1, precomputed_weights=None
):
"""
Create regridder for conversions between source grid and target mesh.
Expand All @@ -637,7 +648,7 @@ def __init__(self, src_mesh_cube, target_grid_cube, mdtol=1):
self.mdtol = mdtol

partial_regrid_info = _regrid_rectilinear_to_unstructured__prepare(
src_mesh_cube, target_grid_cube
src_mesh_cube, target_grid_cube, precomputed_weights=precomputed_weights
)

# Store regrid info.
Expand Down
1 change: 1 addition & 0 deletions esmf_regrid/tests/unit/experimental/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for :mod:`esmf_regrid.experimental.io`."""
118 changes: 118 additions & 0 deletions esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`."""

from iris.cube import Cube
import numpy as np
from numpy import ma

from esmf_regrid.experimental.io import load_regridder, save_regridder
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__cube_to_GridInfo import (
_grid_cube,
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import (
_gridlike_mesh,
)


def _make_grid_to_mesh_regridder():
src_lons = 3
src_lats = 4
tgt_lons = 5
tgt_lats = 6
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
# TODO check that circularity is preserved.
src = _grid_cube(src_lons, src_lats, lon_bounds, lat_bounds, circular=True)
src.coord("longitude").var_name = "longitude"
src.coord("latitude").var_name = "latitude"
mesh = _gridlike_mesh(tgt_lons, tgt_lats)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
tgt_data = np.zeros(tgt_lons * tgt_lats)
tgt = Cube(tgt_data)
tgt.add_aux_coord(mesh_coord_x, 0)
tgt.add_aux_coord(mesh_coord_y, 0)

rg = GridToMeshESMFRegridder(src, tgt, mdtol=0.5)
return rg, src


def _make_mesh_to_grid_regridder():
src_lons = 3
src_lats = 4
tgt_lons = 5
tgt_lats = 6
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
# TODO check that circularity is preserved.
tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True)
tgt.coord("longitude").var_name = "longitude"
tgt.coord("latitude").var_name = "latitude"
mesh = _gridlike_mesh(src_lons, src_lats)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
src_data = np.zeros(src_lons * src_lats)
src = Cube(src_data)
src.add_aux_coord(mesh_coord_x, 0)
src.add_aux_coord(mesh_coord_y, 0)

rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5)
return rg, src


def test_GridToMeshESMFRegridder_round_trip(tmp_path):
"""Test save/load round tripping for `GridToMeshESMFRegridder`."""
original_rg, src = _make_grid_to_mesh_regridder()
filename = tmp_path / "regridder.nc"
save_regridder(original_rg, filename)
loaded_rg = load_regridder(str(filename))

assert original_rg.mdtol == loaded_rg.mdtol
assert original_rg.grid_x == loaded_rg.grid_x
assert original_rg.grid_y == loaded_rg.grid_y
# TODO: uncomment when iris mesh comparison becomes available.
# assert original_rg.mesh == loaded_rg.mesh

# Compare the weight matrices.
original_matrix = original_rg.regridder.weight_matrix
loaded_matrix = loaded_rg.regridder.weight_matrix
# Ensure the original and loaded weight matrix have identical type.
assert type(original_matrix) is type(loaded_matrix) # noqa E721
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense())

# Demonstrate regridding still gives the same results.
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape)
src_mask = np.zeros(src.data.shape)
src_mask[0, 0] = 1
src.data = ma.array(src_data, mask=src_mask)
# TODO: make this a cube comparison when mesh comparison becomes available.
assert np.array_equal(original_rg(src).data, loaded_rg(src).data)


def test_MeshToGridESMFRegridder_round_trip(tmp_path):
"""Test save/load round tripping for `MeshToGridESMFRegridder`."""
original_rg, src = _make_mesh_to_grid_regridder()
filename = tmp_path / "regridder.nc"
save_regridder(original_rg, filename)
loaded_rg = load_regridder(str(filename))

assert original_rg.mdtol == loaded_rg.mdtol
assert original_rg.grid_x == loaded_rg.grid_x
assert original_rg.grid_y == loaded_rg.grid_y
# TODO: uncomment when iris mesh comparison becomes available.
# assert original_rg.mesh == loaded_rg.mesh

# Compare the weight matrices.
original_matrix = original_rg.regridder.weight_matrix
loaded_matrix = loaded_rg.regridder.weight_matrix
# Ensure the original and loaded weight matrix have identical type.
assert type(original_matrix) is type(loaded_matrix) # noqa E721
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense())

# Demonstrate regridding still gives the same results.
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape)
src_mask = np.zeros(src.data.shape)
src_mask[0] = 1
src.data = ma.array(src_data, mask=src_mask)
assert original_rg(src) == loaded_rg(src)
13 changes: 13 additions & 0 deletions esmf_regrid/tests/unit/experimental/io/test_save_regridder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Unit tests for :mod:`esmf_regrid.experimental.io.save_regridder`."""

import pytest

from esmf_regrid.experimental.io import save_regridder


def test_invalid_type(tmp_path):
"""Test that `save_regridder` raises a TypeError where appropriate."""
invalid_obj = None
filename = tmp_path / "regridder.nc"
with pytest.raises(TypeError):
save_regridder(invalid_obj, filename)

0 comments on commit 0484f41

Please sign in to comment.