diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index 4a2bd6ab..cd67e299 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -432,9 +432,26 @@ def __call__(self, cube): """ mesh = cube.mesh - # TODO: Ensure cube has the same mesh as that of the recorded mesh. - # For the time being, we simply check that the mesh exists. + # TODO: replace temporary hack when iris issues are sorted. assert mesh is not None + # Ignore differences in var_name that might be caused by saving. + # TODO: uncomment this when iris issue with masked array comparison is sorted. + # self_mesh = copy.deepcopy(self.mesh) + # self_mesh.var_name = mesh.var_name + # for self_coord, other_coord in zip(self_mesh.all_coords, mesh.all_coords): + # if self_coord is not None: + # self_coord.var_name = other_coord.var_name + # for self_con, other_con in zip( + # self_mesh.all_connectivities, mesh.all_connectivities + # ): + # if self_con is not None: + # self_con.var_name = other_con.var_name + # if self_mesh != mesh: + # raise ValueError( + # "The given cube is not defined on the same " + # "source mesh as this regridder." + # ) + mesh_dim = cube.mesh_dim() regrid_info = (mesh_dim, self.grid_x, self.grid_y, self.regridder) @@ -691,7 +708,12 @@ def __call__(self, cube): """ grid_x, grid_y = get_xy_dim_coords(cube) - if (grid_x != self.grid_x) or (grid_y != self.grid_y): + # Ignore differences in var_name that might be caused by saving. + self_grid_x = copy.deepcopy(self.grid_x) + self_grid_x.var_name = grid_x.var_name + self_grid_y = copy.deepcopy(self.grid_y) + self_grid_y.var_name = grid_y.var_name + if (grid_x != self_grid_x) or (grid_y != self_grid_y): raise ValueError( "The given cube is not defined on the same " "source grid as this regridder." diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index 532e8a75..d1bdaa8f 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -1,5 +1,7 @@ """Unit tests for round tripping (saving then loading) with :mod:`esmf_regrid.experimental.io`.""" +from copy import deepcopy + import numpy as np from numpy import ma @@ -42,8 +44,6 @@ def _make_mesh_to_grid_regridder(): lat_bounds = (-90, 90) # TODO check that circularity is preserved. tgt = _grid_cube(tgt_lons, tgt_lats, lon_bounds, lat_bounds, circular=True) - tgt.coord("longitude").var_name = "longitude" - tgt.coord("latitude").var_name = "latitude" src = _gridlike_mesh_cube(src_lons, src_lats) rg = MeshToGridESMFRegridder(src, tgt, mdtol=0.5) @@ -94,8 +94,12 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path): loaded_rg = load_regridder(str(filename)) assert original_rg.mdtol == loaded_rg.mdtol - assert original_rg.grid_x == loaded_rg.grid_x - assert original_rg.grid_y == loaded_rg.grid_y + loaded_grid_x = deepcopy(loaded_rg.grid_x) + loaded_grid_x.var_name = original_rg.grid_x.var_name + assert original_rg.grid_x == loaded_grid_x + loaded_grid_y = deepcopy(loaded_rg.grid_y) + loaded_grid_y.var_name = original_rg.grid_y.var_name + assert original_rg.grid_y == loaded_grid_y # TODO: uncomment when iris mesh comparison becomes available. # assert original_rg.mesh == loaded_rg.mesh @@ -111,7 +115,17 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path): src_mask = np.zeros(src.data.shape) src_mask[0] = 1 src.data = ma.array(src_data, mask=src_mask) - assert original_rg(src) == loaded_rg(src) + # Compare results, ignoring var_name changes due to saving. + original_result = original_rg(src) + loaded_result = loaded_rg(src) + original_result.var_name = loaded_result.var_name + original_result.coord("latitude").var_name = loaded_result.coord( + "latitude" + ).var_name + original_result.coord("longitude").var_name = loaded_result.coord( + "longitude" + ).var_name + assert original_result == loaded_result # Ensure version data is equal. assert original_rg.regridder.esmf_version == loaded_rg.regridder.esmf_version diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py index d35e79e2..8b9869b1 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py @@ -14,6 +14,7 @@ ) from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__mesh_to_MeshInfo import ( _gridlike_mesh, + _gridlike_mesh_cube, ) from esmf_regrid.tests.unit.experimental.unstructured_scheme.test__regrid_unstructured_to_rectilinear__prepare import ( _flat_mesh_cube, @@ -136,6 +137,34 @@ def test_invalid_mdtol(): _ = MeshToGridESMFRegridder(src, tgt, mdtol=-1) +@pytest.mark.xfail +def test_mistmatched_mesh(): + """ + Test the calling of :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. + + Checks that an error is raised when the regridder is called with a cube + whose mesh does not match the one used for initialisation. + """ + src = _flat_mesh_cube() + + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + rg = MeshToGridESMFRegridder(src, tgt) + + other_src = _gridlike_mesh_cube(n_lons, n_lats) + + with pytest.raises(ValueError) as excinfo: + _ = rg(other_src) + expected_message = ( + "The given cube is not defined on the same " "source mesh as this regridder." + ) + assert expected_message in str(excinfo.value) + + def test_laziness(): """Test that regridding is lazy when source data is lazy.""" n_lons = 12