Skip to content

Commit

Permalink
Check mesh equality on MeshToGridESMFRegridder call (SciTools#138)
Browse files Browse the repository at this point in the history
* check mesh equality on call

* fix test

* attempt lenient equality

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* apply var_name leniency to grids

* temporary test fix

* fix tests

* fix tests

* fix tests

* fix tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stephenworsley and pre-commit-ci[bot] authored Dec 13, 2021
1 parent 4222dca commit 53e8b61
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
28 changes: 25 additions & 3 deletions esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,26 @@ def __call__(self, cube):
"""
mesh = cube.mesh
# TODO: Ensure cube has the same mesh as that of the recorded mesh.
# For the time being, we simply check that the mesh exists.
# TODO: replace temporary hack when iris issues are sorted.
assert mesh is not None
# Ignore differences in var_name that might be caused by saving.
# TODO: uncomment this when iris issue with masked array comparison is sorted.
# self_mesh = copy.deepcopy(self.mesh)
# self_mesh.var_name = mesh.var_name
# for self_coord, other_coord in zip(self_mesh.all_coords, mesh.all_coords):
# if self_coord is not None:
# self_coord.var_name = other_coord.var_name
# for self_con, other_con in zip(
# self_mesh.all_connectivities, mesh.all_connectivities
# ):
# if self_con is not None:
# self_con.var_name = other_con.var_name
# if self_mesh != mesh:
# raise ValueError(
# "The given cube is not defined on the same "
# "source mesh as this regridder."
# )

mesh_dim = cube.mesh_dim()

regrid_info = (mesh_dim, self.grid_x, self.grid_y, self.regridder)
Expand Down Expand Up @@ -691,7 +708,12 @@ def __call__(self, cube):
"""
grid_x, grid_y = get_xy_dim_coords(cube)
if (grid_x != self.grid_x) or (grid_y != self.grid_y):
# Ignore differences in var_name that might be caused by saving.
self_grid_x = copy.deepcopy(self.grid_x)
self_grid_x.var_name = grid_x.var_name
self_grid_y = copy.deepcopy(self.grid_y)
self_grid_y.var_name = grid_y.var_name
if (grid_x != self_grid_x) or (grid_y != self_grid_y):
raise ValueError(
"The given cube is not defined on the same "
"source grid as this regridder."
Expand Down
24 changes: 19 additions & 5 deletions esmf_regrid/tests/unit/experimental/io/test_round_tripping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`."""

from copy import deepcopy

import numpy as np
from numpy import ma

Expand Down Expand Up @@ -42,8 +44,6 @@ def _make_mesh_to_grid_regridder():
lat_bounds = (-90, 90)
# TODO check that circularity is preserved.
tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True)
tgt.coord("longitude").var_name = "longitude"
tgt.coord("latitude").var_name = "latitude"
src = _gridlike_mesh_cube(src_lons, src_lats)

rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5)
Expand Down Expand Up @@ -94,8 +94,12 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path):
loaded_rg = load_regridder(str(filename))

assert original_rg.mdtol == loaded_rg.mdtol
assert original_rg.grid_x == loaded_rg.grid_x
assert original_rg.grid_y == loaded_rg.grid_y
loaded_grid_x = deepcopy(loaded_rg.grid_x)
loaded_grid_x.var_name = original_rg.grid_x.var_name
assert original_rg.grid_x == loaded_grid_x
loaded_grid_y = deepcopy(loaded_rg.grid_y)
loaded_grid_y.var_name = original_rg.grid_y.var_name
assert original_rg.grid_y == loaded_grid_y
# TODO: uncomment when iris mesh comparison becomes available.
# assert original_rg.mesh == loaded_rg.mesh

Expand All @@ -111,7 +115,17 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path):
src_mask = np.zeros(src.data.shape)
src_mask[0] = 1
src.data = ma.array(src_data, mask=src_mask)
assert original_rg(src) == loaded_rg(src)
# Compare results, ignoring var_name changes due to saving.
original_result = original_rg(src)
loaded_result = loaded_rg(src)
original_result.var_name = loaded_result.var_name
original_result.coord("latitude").var_name = loaded_result.coord(
"latitude"
).var_name
original_result.coord("longitude").var_name = loaded_result.coord(
"longitude"
).var_name
assert original_result == loaded_result

# Ensure version data is equal.
assert original_rg.regridder.esmf_version == loaded_rg.regridder.esmf_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import (
_gridlike_mesh,
_gridlike_mesh_cube,
)
from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__regrid_unstructured_to_rectilinear__prepare import (
_flat_mesh_cube,
Expand Down Expand Up @@ -136,6 +137,34 @@ def test_invalid_mdtol():
_ = MeshToGridESMFRegridder(src, tgt, mdtol=-1)


@pytest.mark.xfail
def test_mistmatched_mesh():
"""
Test the calling of :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`.
Checks that an error is raised when the regridder is called with a cube
whose mesh does not match the one used for initialisation.
"""
src = _flat_mesh_cube()

n_lons = 6
n_lats = 5
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)
tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)

rg = MeshToGridESMFRegridder(src, tgt)

other_src = _gridlike_mesh_cube(n_lons, n_lats)

with pytest.raises(ValueError) as excinfo:
_ = rg(other_src)
expected_message = (
"The given cube is not defined on the same " "source mesh as this regridder."
)
assert expected_message in str(excinfo.value)


def test_laziness():
"""Test that regridding is lazy when source data is lazy."""
n_lons = 12
Expand Down

0 comments on commit 53e8b61

Please sign in to comment.