forked from SciTools/iris-esmf-regrid
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Regridder load/saving (SciTools#130)
* add regridder saving * add regridder saving * avoid saving bug * add docstrings, copy iris utils * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test/lint fixes * test fix * add test * test functionality * add comments and tests * flake fix * refactor tests * fix tests * fix tests * fix tests * use pytest fixture tmp_path * refresh nox cache * fix test * remove temp file architecture * remove imports * update nox cache * fix tests * increment CONDA_CACHE_BUILD * toggle nox environment reuse * toggle nox environment reuse * determine regridder_type generically * fix saver * fix saver * fix loader Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
7c17235
commit 0484f41
Showing
6 changed files
with
302 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
"""Provides load/save functions for regridders.""" | ||
|
||
import iris | ||
from iris.coords import AuxCoord | ||
from iris.cube import Cube, CubeList | ||
from iris.experimental.ugrid import PARSE_UGRID_ON_LOAD | ||
import numpy as np | ||
import scipy.sparse | ||
|
||
from esmf_regrid.experimental.unstructured_scheme import ( | ||
GridToMeshESMFRegridder, | ||
MeshToGridESMFRegridder, | ||
) | ||
|
||
|
||
SUPPORTED_REGRIDDERS = [ | ||
GridToMeshESMFRegridder, | ||
MeshToGridESMFRegridder, | ||
] | ||
REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS} | ||
|
||
|
||
def save_regridder(rg, filename): | ||
""" | ||
Save a regridder scheme instance. | ||
Saves either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`. | ||
Parameters | ||
---------- | ||
rg : GridToMeshESMFRegridder, MeshToGridESMFRegridder | ||
The regridder instance to save. | ||
filename : str | ||
The file name to save to. | ||
""" | ||
src_name = "regridder source field" | ||
tgt_name = "regridder target field" | ||
regridder_type = rg.__class__.__name__ | ||
if regridder_type == "GridToMeshESMFRegridder": | ||
src_grid = (rg.grid_y, rg.grid_x) | ||
src_shape = [len(coord.points) for coord in src_grid] | ||
src_data = np.zeros(src_shape) | ||
src_cube = Cube(src_data, long_name=src_name) | ||
src_cube.add_dim_coord(src_grid[0], 0) | ||
src_cube.add_dim_coord(src_grid[1], 1) | ||
|
||
tgt_mesh = rg.mesh | ||
tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0]) | ||
tgt_cube = Cube(tgt_data, long_name=tgt_name) | ||
for coord in tgt_mesh.to_MeshCoords("face"): | ||
tgt_cube.add_aux_coord(coord, 0) | ||
elif regridder_type == "MeshToGridESMFRegridder": | ||
src_mesh = rg.mesh | ||
src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0]) | ||
src_cube = Cube(src_data, long_name=src_name) | ||
for coord in src_mesh.to_MeshCoords("face"): | ||
src_cube.add_aux_coord(coord, 0) | ||
|
||
tgt_grid = (rg.grid_y, rg.grid_x) | ||
tgt_shape = [len(coord.points) for coord in tgt_grid] | ||
tgt_data = np.zeros(tgt_shape) | ||
tgt_cube = Cube(tgt_data, long_name=tgt_name) | ||
tgt_cube.add_dim_coord(tgt_grid[0], 0) | ||
tgt_cube.add_dim_coord(tgt_grid[1], 1) | ||
else: | ||
msg = ( | ||
f"Expected a regridder of type `GridToMeshESMFRegridder` or " | ||
f"`MeshToGridESMFRegridder`, got type {regridder_type}." | ||
) | ||
raise TypeError(msg) | ||
|
||
metadata_name = "regridder weights and metadata" | ||
|
||
weight_matrix = rg.regridder.weight_matrix | ||
reformatted_weight_matrix = weight_matrix.tocoo() | ||
weight_data = reformatted_weight_matrix.data | ||
weight_rows = reformatted_weight_matrix.row | ||
weight_cols = reformatted_weight_matrix.col | ||
weight_shape = reformatted_weight_matrix.shape | ||
|
||
mdtol = rg.mdtol | ||
attributes = { | ||
"regridder type": regridder_type, | ||
"mdtol": mdtol, | ||
"weights shape": weight_shape, | ||
} | ||
|
||
metadata_cube = Cube(weight_data, long_name=metadata_name, attributes=attributes) | ||
row_name = "weight matrix rows" | ||
row_coord = AuxCoord(weight_rows, long_name=row_name) | ||
col_name = "weight matrix columns" | ||
col_coord = AuxCoord(weight_cols, long_name=col_name) | ||
metadata_cube.add_aux_coord(row_coord, 0) | ||
metadata_cube.add_aux_coord(col_coord, 0) | ||
|
||
# Avoid saving bug by placing the mesh cube second. | ||
# TODO: simplify this when this bug is fixed in iris. | ||
if regridder_type == "GridToMeshESMFRegridder": | ||
cube_list = CubeList([src_cube, tgt_cube, metadata_cube]) | ||
elif regridder_type == "MeshToGridESMFRegridder": | ||
cube_list = CubeList([tgt_cube, src_cube, metadata_cube]) | ||
iris.fileformats.netcdf.save(cube_list, filename) | ||
|
||
|
||
def load_regridder(filename): | ||
""" | ||
Load a regridder scheme instance. | ||
Loads either a `GridToMeshESMFRegridder` or a `MeshToGridESMFRegridder`. | ||
Parameters | ||
---------- | ||
filename : str | ||
The file name to load from. | ||
""" | ||
with PARSE_UGRID_ON_LOAD.context(): | ||
cubes = iris.load(filename) | ||
|
||
src_name = "regridder source field" | ||
tgt_name = "regridder target field" | ||
metadata_name = "regridder weights and metadata" | ||
|
||
# Extract the source, target and metadata information. | ||
src_cube = cubes.extract_cube(src_name) | ||
tgt_cube = cubes.extract_cube(tgt_name) | ||
metadata_cube = cubes.extract_cube(metadata_name) | ||
|
||
# Determine the regridder type. | ||
regridder_type = metadata_cube.attributes["regridder type"] | ||
assert regridder_type in REGRIDDER_NAME_MAP.keys() | ||
scheme = REGRIDDER_NAME_MAP[regridder_type] | ||
|
||
# Reconstruct the weight matrix. | ||
weight_data = metadata_cube.data | ||
row_name = "weight matrix rows" | ||
weight_rows = metadata_cube.coord(row_name).points | ||
col_name = "weight matrix columns" | ||
weight_cols = metadata_cube.coord(col_name).points | ||
weight_shape = metadata_cube.attributes["weights shape"] | ||
weight_matrix = scipy.sparse.csr_matrix( | ||
(weight_data, (weight_rows, weight_cols)), shape=weight_shape | ||
) | ||
|
||
mdtol = metadata_cube.attributes["mdtol"] | ||
|
||
regridder = scheme( | ||
src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix | ||
) | ||
return regridder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Unit tests for :mod:`esmf_regrid.experimental.io`.""" |
118 changes: 118 additions & 0 deletions
118
esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`.""" | ||
|
||
from iris.cube import Cube | ||
import numpy as np | ||
from numpy import ma | ||
|
||
from esmf_regrid.experimental.io import load_regridder, save_regridder | ||
from esmf_regrid.experimental.unstructured_scheme import ( | ||
GridToMeshESMFRegridder, | ||
MeshToGridESMFRegridder, | ||
) | ||
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__cube_to_GridInfo import ( | ||
_grid_cube, | ||
) | ||
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import ( | ||
_gridlike_mesh, | ||
) | ||
|
||
|
||
def _make_grid_to_mesh_regridder(): | ||
src_lons = 3 | ||
src_lats = 4 | ||
tgt_lons = 5 | ||
tgt_lats = 6 | ||
lon_bounds = (-180, 180) | ||
lat_bounds = (-90, 90) | ||
# TODO check that circularity is preserved. | ||
src = _grid_cube(src_lons, src_lats, lon_bounds, lat_bounds, circular=True) | ||
src.coord("longitude").var_name = "longitude" | ||
src.coord("latitude").var_name = "latitude" | ||
mesh = _gridlike_mesh(tgt_lons, tgt_lats) | ||
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face") | ||
tgt_data = np.zeros(tgt_lons * tgt_lats) | ||
tgt = Cube(tgt_data) | ||
tgt.add_aux_coord(mesh_coord_x, 0) | ||
tgt.add_aux_coord(mesh_coord_y, 0) | ||
|
||
rg = GridToMeshESMFRegridder(src, tgt, mdtol=0.5) | ||
return rg, src | ||
|
||
|
||
def _make_mesh_to_grid_regridder(): | ||
src_lons = 3 | ||
src_lats = 4 | ||
tgt_lons = 5 | ||
tgt_lats = 6 | ||
lon_bounds = (-180, 180) | ||
lat_bounds = (-90, 90) | ||
# TODO check that circularity is preserved. | ||
tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True) | ||
tgt.coord("longitude").var_name = "longitude" | ||
tgt.coord("latitude").var_name = "latitude" | ||
mesh = _gridlike_mesh(src_lons, src_lats) | ||
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face") | ||
src_data = np.zeros(src_lons * src_lats) | ||
src = Cube(src_data) | ||
src.add_aux_coord(mesh_coord_x, 0) | ||
src.add_aux_coord(mesh_coord_y, 0) | ||
|
||
rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5) | ||
return rg, src | ||
|
||
|
||
def test_GridToMeshESMFRegridder_round_trip(tmp_path): | ||
"""Test save/load round tripping for `GridToMeshESMFRegridder`.""" | ||
original_rg, src = _make_grid_to_mesh_regridder() | ||
filename = tmp_path / "regridder.nc" | ||
save_regridder(original_rg, filename) | ||
loaded_rg = load_regridder(str(filename)) | ||
|
||
assert original_rg.mdtol == loaded_rg.mdtol | ||
assert original_rg.grid_x == loaded_rg.grid_x | ||
assert original_rg.grid_y == loaded_rg.grid_y | ||
# TODO: uncomment when iris mesh comparison becomes available. | ||
# assert original_rg.mesh == loaded_rg.mesh | ||
|
||
# Compare the weight matrices. | ||
original_matrix = original_rg.regridder.weight_matrix | ||
loaded_matrix = loaded_rg.regridder.weight_matrix | ||
# Ensure the original and loaded weight matrix have identical type. | ||
assert type(original_matrix) is type(loaded_matrix) # noqa E721 | ||
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense()) | ||
|
||
# Demonstrate regridding still gives the same results. | ||
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape) | ||
src_mask = np.zeros(src.data.shape) | ||
src_mask[0, 0] = 1 | ||
src.data = ma.array(src_data, mask=src_mask) | ||
# TODO: make this a cube comparison when mesh comparison becomes available. | ||
assert np.array_equal(original_rg(src).data, loaded_rg(src).data) | ||
|
||
|
||
def test_MeshToGridESMFRegridder_round_trip(tmp_path): | ||
"""Test save/load round tripping for `MeshToGridESMFRegridder`.""" | ||
original_rg, src = _make_mesh_to_grid_regridder() | ||
filename = tmp_path / "regridder.nc" | ||
save_regridder(original_rg, filename) | ||
loaded_rg = load_regridder(str(filename)) | ||
|
||
assert original_rg.mdtol == loaded_rg.mdtol | ||
assert original_rg.grid_x == loaded_rg.grid_x | ||
assert original_rg.grid_y == loaded_rg.grid_y | ||
# TODO: uncomment when iris mesh comparison becomes available. | ||
# assert original_rg.mesh == loaded_rg.mesh | ||
|
||
# Compare the weight matrices. | ||
original_matrix = original_rg.regridder.weight_matrix | ||
loaded_matrix = loaded_rg.regridder.weight_matrix | ||
# Ensure the original and loaded weight matrix have identical type. | ||
assert type(original_matrix) is type(loaded_matrix) # noqa E721 | ||
assert np.array_equal(original_matrix.todense(), loaded_matrix.todense()) | ||
|
||
# Demonstrate regridding still gives the same results. | ||
src_data = np.arange(np.product(src.data.shape)).reshape(src.data.shape) | ||
src_mask = np.zeros(src.data.shape) | ||
src_mask[0] = 1 | ||
src.data = ma.array(src_data, mask=src_mask) | ||
assert original_rg(src) == loaded_rg(src) |
13 changes: 13 additions & 0 deletions
13
esmf_regrid/tests/unit/experimental/io/test_save_regridder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Unit tests for :mod:`esmf_regrid.experimental.io.save_regridder`.""" | ||
|
||
import pytest | ||
|
||
from esmf_regrid.experimental.io import save_regridder | ||
|
||
|
||
def test_invalid_type(tmp_path): | ||
"""Test that `save_regridder` raises a TypeError where appropriate.""" | ||
invalid_obj = None | ||
filename = tmp_path / "regridder.nc" | ||
with pytest.raises(TypeError): | ||
save_regridder(invalid_obj, filename) |