Skip to content

Commit

Permalink
Formalise regridder file format (SciTools#137)
Browse files Browse the repository at this point in the history
* improve regridder save format

* add normalization to file format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments

* address review comments

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stephenworsley and pre-commit-ci[bot] authored Dec 13, 2021
1 parent a3e147b commit 4222dca
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 35 deletions.
4 changes: 4 additions & 0 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy import ma
import scipy.sparse

import esmf_regrid
from ._esmf_sdo import GridInfo

__all__ = [
Expand Down Expand Up @@ -77,7 +78,9 @@ def __init__(self, src, tgt, precomputed_weights=None):
self.src = src
self.tgt = tgt

self.esmf_regrid_version = esmf_regrid.__version__
if precomputed_weights is None:
self.esmf_version = ESMF.__version__
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(), tgt.make_esmf_field()
)
Expand All @@ -99,6 +102,7 @@ def __init__(self, src, tgt, precomputed_weights=None):
precomputed_weights.shape,
)
)
self.esmf_version = None
self.weight_matrix = precomputed_weights

def regrid(self, src_array, norm_type="fracarea", mdtol=1):
Expand Down
100 changes: 65 additions & 35 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import scipy.sparse

import esmf_regrid
from esmf_regrid.experimental.unstructured_scheme import (
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
Expand All @@ -18,6 +19,16 @@
MeshToGridESMFRegridder,
]
REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS}
SOURCE_NAME = "regridder_source_field"
TARGET_NAME = "regridder_target_field"
WEIGHTS_NAME = "regridder_weights"
WEIGHTS_SHAPE_NAME = "weights_shape"
WEIGHTS_ROW_NAME = "weight_matrix_rows"
WEIGHTS_COL_NAME = "weight_matrix_columns"
REGRIDDER_TYPE = "regridder_type"
VERSION_ESMF = "ESMF_version"
VERSION_INITIAL = "esmf_regrid_version_on_initialise"
MDTOL = "mdtol"


def save_regridder(rg, filename):
Expand All @@ -33,33 +44,31 @@ def save_regridder(rg, filename):
filename : str
The file name to save to.
"""
src_name = "regridder source field"
tgt_name = "regridder target field"
regridder_type = rg.__class__.__name__
if regridder_type == "GridToMeshESMFRegridder":
src_grid = (rg.grid_y, rg.grid_x)
src_shape = [len(coord.points) for coord in src_grid]
src_data = np.zeros(src_shape)
src_cube = Cube(src_data, long_name=src_name)
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
src_cube.add_dim_coord(src_grid[0], 0)
src_cube.add_dim_coord(src_grid[1], 1)

tgt_mesh = rg.mesh
tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0])
tgt_cube = Cube(tgt_data, long_name=tgt_name)
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh.to_MeshCoords("face"):
tgt_cube.add_aux_coord(coord, 0)
elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0])
src_cube = Cube(src_data, long_name=src_name)
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh.to_MeshCoords("face"):
src_cube.add_aux_coord(coord, 0)

tgt_grid = (rg.grid_y, rg.grid_x)
tgt_shape = [len(coord.points) for coord in tgt_grid]
tgt_data = np.zeros(tgt_shape)
tgt_cube = Cube(tgt_data, long_name=tgt_name)
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
tgt_cube.add_dim_coord(tgt_grid[0], 0)
tgt_cube.add_dim_coord(tgt_grid[1], 1)
else:
Expand All @@ -69,36 +78,57 @@ def save_regridder(rg, filename):
)
raise TypeError(msg)

metadata_name = "regridder weights and metadata"

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = weight_matrix.tocoo()
weight_data = reformatted_weight_matrix.data
weight_rows = reformatted_weight_matrix.row
weight_cols = reformatted_weight_matrix.col
weight_shape = reformatted_weight_matrix.shape

esmf_version = rg.regridder.esmf_version
esmf_regrid_version = rg.regridder.esmf_regrid_version
save_version = esmf_regrid.__version__

# Currently, all schemes use the fracarea normalization.
normalization = "fracarea"

mdtol = rg.mdtol
attributes = {
"regridder type": regridder_type,
"mdtol": mdtol,
"weights shape": weight_shape,
"title": "iris-esmf-regrid regridding scheme",
REGRIDDER_TYPE: regridder_type,
VERSION_ESMF: esmf_version,
VERSION_INITIAL: esmf_regrid_version,
"esmf_regrid_version_on_save": save_version,
"normalization": normalization,
MDTOL: mdtol,
}

metadata_cube = Cube(weight_data, long_name=metadata_name, attributes=attributes)
row_name = "weight matrix rows"
row_coord = AuxCoord(weight_rows, long_name=row_name)
col_name = "weight matrix columns"
col_coord = AuxCoord(weight_cols, long_name=col_name)
metadata_cube.add_aux_coord(row_coord, 0)
metadata_cube.add_aux_coord(col_coord, 0)
weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
row_coord = AuxCoord(
weight_rows, var_name=WEIGHTS_ROW_NAME, long_name=WEIGHTS_ROW_NAME
)
col_coord = AuxCoord(
weight_cols, var_name=WEIGHTS_COL_NAME, long_name=WEIGHTS_COL_NAME
)
weights_cube.add_aux_coord(row_coord, 0)
weights_cube.add_aux_coord(col_coord, 0)

weight_shape_cube = Cube(
weight_shape,
var_name=WEIGHTS_SHAPE_NAME,
long_name=WEIGHTS_SHAPE_NAME,
)

# Avoid saving bug by placing the mesh cube second.
# TODO: simplify this when this bug is fixed in iris.
if regridder_type == "GridToMeshESMFRegridder":
cube_list = CubeList([src_cube, tgt_cube, metadata_cube])
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
elif regridder_type == "MeshToGridESMFRegridder":
cube_list = CubeList([tgt_cube, src_cube, metadata_cube])
cube_list = CubeList([tgt_cube, src_cube, weights_cube, weight_shape_cube])

for cube in cube_list:
cube.attributes = attributes

iris.fileformats.netcdf.save(cube_list, filename)


Expand All @@ -116,34 +146,34 @@ def load_regridder(filename):
with PARSE_UGRID_ON_LOAD.context():
cubes = iris.load(filename)

src_name = "regridder source field"
tgt_name = "regridder target field"
metadata_name = "regridder weights and metadata"

# Extract the source, target and metadata information.
src_cube = cubes.extract_cube(src_name)
tgt_cube = cubes.extract_cube(tgt_name)
metadata_cube = cubes.extract_cube(metadata_name)
src_cube = cubes.extract_cube(SOURCE_NAME)
tgt_cube = cubes.extract_cube(TARGET_NAME)
weights_cube = cubes.extract_cube(WEIGHTS_NAME)
weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME)

# Determine the regridder type.
regridder_type = metadata_cube.attributes["regridder type"]
regridder_type = weights_cube.attributes[REGRIDDER_TYPE]
assert regridder_type in REGRIDDER_NAME_MAP.keys()
scheme = REGRIDDER_NAME_MAP[regridder_type]

# Reconstruct the weight matrix.
weight_data = metadata_cube.data
row_name = "weight matrix rows"
weight_rows = metadata_cube.coord(row_name).points
col_name = "weight matrix columns"
weight_cols = metadata_cube.coord(col_name).points
weight_shape = metadata_cube.attributes["weights shape"]
weight_data = weights_cube.data
weight_rows = weights_cube.coord(WEIGHTS_ROW_NAME).points
weight_cols = weights_cube.coord(WEIGHTS_COL_NAME).points
weight_shape = weight_shape_cube.data
weight_matrix = scipy.sparse.csr_matrix(
(weight_data, (weight_rows, weight_cols)), shape=weight_shape
)

mdtol = metadata_cube.attributes["mdtol"]
mdtol = weights_cube.attributes[MDTOL]

regridder = scheme(
src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix
)

esmf_version = weights_cube.attributes[VERSION_ESMF]
regridder.regridder.esmf_version = esmf_version
esmf_regrid_version = weights_cube.attributes[VERSION_INITIAL]
regridder.regridder.esmf_regrid_version = esmf_regrid_version
return regridder
14 changes: 14 additions & 0 deletions esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path):
# TODO: make this a cube comparison when mesh comparison becomes available.
assert np.array_equal(original_rg(src).data, loaded_rg(src).data)

# Ensure version data is equal.
assert original_rg.regridder.esmf_version == loaded_rg.regridder.esmf_version
assert (
original_rg.regridder.esmf_regrid_version
== loaded_rg.regridder.esmf_regrid_version
)


def test_MeshToGridESMFRegridder_round_trip(tmp_path):
"""Test save/load round tripping for `MeshToGridESMFRegridder`."""
Expand Down Expand Up @@ -105,3 +112,10 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path):
src_mask[0] = 1
src.data = ma.array(src_data, mask=src_mask)
assert original_rg(src) == loaded_rg(src)

# Ensure version data is equal.
assert original_rg.regridder.esmf_version == loaded_rg.regridder.esmf_version
assert (
original_rg.regridder.esmf_regrid_version
== loaded_rg.regridder.esmf_regrid_version
)

0 comments on commit 4222dca

Please sign in to comment.