From b4042d8330ee863501999a826294b140ff6a58f3 Mon Sep 17 00:00:00 2001 From: stephenworsley <49274989+stephenworsley@users.noreply.github.com> Date: Wed, 30 Mar 2022 16:21:09 +0100 Subject: [PATCH] Add RefinedGridInfo class (for potential Hovmoller/zonal mean support) (#165) * add RefinedGridInfo class * test for expanded lats * lint fix * add resolution to scheme * fix tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * save/load resolution * save/load resolution * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tests * fix tests * fix RefinedGridInfo source behaviour * fix regridder loading * fix tests * add error tests * add explanation to the mathematics. * check bounds are strictly increasing * address review comments * address review comments * address review comments * address review comments * fix test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- esmf_regrid/_esmf_sdo.py | 194 +++++++++++++++++- esmf_regrid/esmf_regridder.py | 18 +- esmf_regrid/experimental/io.py | 8 + .../experimental/unstructured_scheme.py | 108 +++++++--- .../unit/_esmf_sdo/test_RefinedGridInfo.py | 97 +++++++++ .../experimental/io/test_round_tripping.py | 36 +++- .../__init__.py | 0 .../test_MeshInfo.py | 0 .../test_GridToMeshESMFRegridder.py | 44 ++++ .../test_MeshToGridESMFRegridder.py | 48 +++++ .../test__cube_to_GridInfo.py | 6 +- ...test_regrid_rectilinear_to_unstructured.py | 36 +++- ...test_regrid_unstructured_to_rectilinear.py | 34 +++ 13 files changed, 590 insertions(+), 39 deletions(-) create mode 100644 esmf_regrid/tests/unit/_esmf_sdo/test_RefinedGridInfo.py rename esmf_regrid/tests/unit/experimental/{unstuctured_regrid => unstructured_regrid}/__init__.py (100%) rename esmf_regrid/tests/unit/experimental/{unstuctured_regrid => unstructured_regrid}/test_MeshInfo.py (100%) diff --git a/esmf_regrid/_esmf_sdo.py b/esmf_regrid/_esmf_sdo.py index 0940f3e3..bc25850f 100644 --- a/esmf_regrid/_esmf_sdo.py +++ b/esmf_regrid/_esmf_sdo.py @@ -5,6 +5,7 @@ import cartopy.crs as ccrs import ESMF import numpy as np +import scipy.sparse class SDO(ABC): @@ -35,6 +36,11 @@ def shape(self): """Return shape.""" return self._shape + @property + def _refined_shape(self): + """Return shape passed to ESMF.""" + return self._shape + @property def dims(self): """Return number of dimensions.""" @@ -45,6 +51,11 @@ def size(self): """Return the number of cells in the sdo.""" return np.prod(self._shape) + @property + def _refined_size(self): + """Return the number of cells passed to ESMF.""" + return np.prod(self._refined_shape) + @property def index_offset(self): """Return the index offset.""" @@ -170,7 +181,9 @@ def __init__( shape = self.lons.shape self.lonbounds = lonbounds + self._refined_lonbounds = lonbounds self.latbounds = latbounds + self._refined_latbounds = latbounds if crs is None: self.crs = ccrs.Geodetic() else: @@ -185,17 +198,19 @@ def __init__( ) def _as_esmf_info(self): - shape = np.array(self._shape) + shape = np.array(self._refined_shape) londims = len(self.lons.shape) if londims == 1: if self.circular: - adjustedlonbounds = self.lonbounds[:-1] + adjustedlonbounds = self._refined_lonbounds[:-1] else: - adjustedlonbounds = self.lonbounds + adjustedlonbounds = self._refined_lonbounds centerlons, centerlats = np.meshgrid(self.lons, self.lats) - cornerlons, cornerlats = np.meshgrid(adjustedlonbounds, self.latbounds) + cornerlons, cornerlats = np.meshgrid( + adjustedlonbounds, self._refined_latbounds + ) elif londims == 2: if self.circular: slice = np.s_[:, :-1] @@ -203,8 +218,8 @@ def _as_esmf_info(self): slice = np.s_[:] centerlons = self.lons[slice] centerlats = self.lats[slice] - cornerlons = self.lonbounds[slice] - cornerlats = self.latbounds[slice] + cornerlons = self._refined_lonbounds[slice] + cornerlats = self._refined_latbounds[slice] truecenters = ccrs.Geodetic().transform_points(self.crs, centerlons, centerlats) truecorners = ccrs.Geodetic().transform_points(self.crs, cornerlons, cornerlats) @@ -275,3 +290,170 @@ def _make_esmf_sdo(self): grid_areas[:] = areas.T return grid + + +class RefinedGridInfo(GridInfo): + """ + Class for handling structured grids represented in :mod:`ESMF` in higher resolution. + + A specialised version of :class:`GridInfo`. Designed to provide higher + accuracy conservative regridding for rectilinear grids, especially those with + particularly large cells which may not be well represented by :mod:`ESMF`. This + class differs from :class:`GridInfo` primarily in the way it represents itself + as a :class:`~ESMF.api.field.Field` in :mod:`ESMF`. This :class:`~ESMF.api.field.Field` + is designed to be a higher resolution version of the given grid and should + contain enough information for area weighted regridding but may be + inappropriate for other :mod:`ESMF` regridding schemes. + + """ + + def __init__( + self, + lonbounds, + latbounds, + resolution=3, + crs=None, + ): + """ + Create a :class:`RefinedGridInfo` object describing the grid. + + Parameters + ---------- + lonbounds : :obj:`~numpy.typing.ArrayLike` + A 1D array or list describing the longitude bounds of the grid. + Must be strictly increasing (for example, if a bound goes from + 170 to -170 consider transposing -170 to 190). + latbounds : :obj:`~numpy.typing.ArrayLike` + A 1D array or list describing the latitude bounds of the grid. + Must be strictly increasing. + resolution : int, default=400 + A number describing how many latitude slices each cell should + be divided into when passing a higher resolution grid to ESMF. + crs : :class:`cartopy.crs.CRS`, optional + Describes how to interpret the + above arguments. If ``None``, defaults to :class:`~cartopy.crs.Geodetic`. + + """ + # Convert bounds to numpy arrays where necessary. + if not isinstance(lonbounds, np.ndarray): + lonbounds = np.array(lonbounds) + if not isinstance(latbounds, np.ndarray): + latbounds = np.array(latbounds) + + # Ensure bounds are strictly increasing. + if not np.all(lonbounds[:-1] < lonbounds[1:]): + raise ValueError("The longitude bounds must be strictly increasing.") + if not np.all(latbounds[:-1] < latbounds[1:]): + raise ValueError("The latitude bounds must be strictly increasing.") + + self.resolution = resolution + self.n_lons_orig = len(lonbounds) - 1 + self.n_lats_orig = len(latbounds) - 1 + + # Create dummy lat/lon values + lons = np.zeros(self.n_lons_orig) + lats = np.zeros(self.n_lats_orig) + super().__init__(lons, lats, lonbounds, latbounds, crs=crs) + + if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]): + self._refined_latbounds = np.array([-90, 0, 90]) + self._refined_lonbounds = lonbounds + else: + self._refined_latbounds = latbounds + self._refined_lonbounds = np.append( + np.linspace( + lonbounds[:-1], + lonbounds[1:], + self.resolution, + endpoint=False, + axis=1, + ).flatten(), + lonbounds[-1], + ) + self.lon_expansion = int( + (len(self._refined_lonbounds) - 1) / (len(self.lonbounds) - 1) + ) + self.lat_expansion = int( + (len(self._refined_latbounds) - 1) / (len(self.latbounds) - 1) + ) + + @property + def _refined_shape(self): + """Return shape passed to ESMF.""" + return ( + self.n_lats_orig * self.lat_expansion, + self.n_lons_orig * self.lon_expansion, + ) + + def _collapse_weights(self, is_tgt): + """ + Return a matrix to collapse the weight matrix. + + The refined grid may contain more cells than the represented grid. When this is + the case, the generated weight matrix will refer to too many points and will have + to be collapsed. This is done by multiplying by this matrix, pre-multiplying when + the target grid is represented and post multiplying when the source grid is + represented. + + Parameters + ---------- + is_tgt : bool + True if the target field is being represented, False otherwise. + """ + # The column indices represent each of the cells in the refined grid. + column_indices = np.arange(self._refined_size) + + # The row indices represent the cells of the unrefined grid. These are broadcast + # so that each row index coincides with all column indices of the refined cells + # which the unrefined cell is split into. + if self.lat_expansion > 1: + # The latitudes are expanded only in the case where there is one latitude + # bound from -90 to 90. In this case, there is no longitude expansion. + row_indices = np.empty([self.n_lons_orig, self.lat_expansion]) + row_indices[:] = np.arange(self.n_lons_orig)[:, np.newaxis] + else: + # The row indices are broadcast across a dimension representing the expansion + # of the longitude. Each row index is broadcast and flattened so that all the + # row indices representing the unrefined cell match up with the column indices + # representing the refined cells it is split into. + row_indices = np.empty( + [self.n_lons_orig, self.lon_expansion, self.n_lats_orig] + ) + row_indices[:] = np.arange(self.n_lons_orig * self.n_lats_orig).reshape( + [self.n_lons_orig, self.n_lats_orig] + )[:, np.newaxis, :] + row_indices = row_indices.flatten() + matrix_shape = (self.size, self._refined_size) + refinement_weights = scipy.sparse.csr_matrix( + ( + np.ones(self._refined_size), + (row_indices, column_indices), + ), + shape=matrix_shape, + ) + if is_tgt: + # When the RefinedGridInfo is the target of the regridder, we want to take + # the average of the weights of each refined target cell. This is because + # these weights represent the proportion of area of the target cells which + # is covered by a given source cell. Since the refined cells are divided in + # such a way that they have equal area, the weights for the unrefined cells + # can be reconstructed by taking an average. This is done via matrix + # multiplication, with the returned matrix pre-multiplying the weight matrix + # so that it operates on the rows of the weight matrix (representing the + # target cells). At this point the returned matrix consists of ones, so we + # divided by the number of refined cells per unrefined cell. + refinement_weights = refinement_weights / ( + self.lon_expansion * self.lat_expansion + ) + else: + # When the RefinedGridInfo is the source of the regridder, we want to take + # the sum of the weights of each refined target cell. This is because those + # weights represent the proportion of the area of a given target cell which + # is covered by each refined source cell. The total proportion covered by + # each unrefined source cell is then the sum of the weights from each of its + # refined cells. This sum is done by matrix multiplication, the returned + # matrix post-multiplying the weight matrix so that it operates on the columns + # of the weight matrix (representing the source cells). In order for the + # post-multiplication to work, the returned matrix must be transposed. + refinement_weights = refinement_weights.T + return refinement_weights diff --git a/esmf_regrid/esmf_regridder.py b/esmf_regrid/esmf_regridder.py index 4113fb66..0226fdf9 100644 --- a/esmf_regrid/esmf_regridder.py +++ b/esmf_regrid/esmf_regridder.py @@ -6,7 +6,7 @@ import scipy.sparse import esmf_regrid -from ._esmf_sdo import GridInfo +from ._esmf_sdo import GridInfo, RefinedGridInfo __all__ = [ "GridInfo", @@ -103,9 +103,23 @@ def __init__(self, src, tgt, method="conservative", precomputed_weights=None): ) self.weight_matrix = _weights_dict_to_sparse_array( weights_dict, - (self.tgt.size, self.src.size), + (self.tgt._refined_size, self.src._refined_size), (self.tgt.index_offset, self.src.index_offset), ) + if isinstance(tgt, RefinedGridInfo): + # At this point, the weight matrix represents more target points than + # tgt respresents. In order to collapse these points, we collapse the + # weights matrix by the appropriate matrix multiplication. + self.weight_matrix = ( + tgt._collapse_weights(is_tgt=True) @ self.weight_matrix + ) + if isinstance(src, RefinedGridInfo): + # At this point, the weight matrix represents more source points than + # src respresents. In order to collapse these points, we collapse the + # weights matrix by the appropriate matrix multiplication. + self.weight_matrix = self.weight_matrix @ src._collapse_weights( + is_tgt=False + ) else: if not scipy.sparse.isspmatrix(precomputed_weights): raise ValueError( diff --git a/esmf_regrid/experimental/io.py b/esmf_regrid/experimental/io.py index cf01578f..77a71172 100644 --- a/esmf_regrid/experimental/io.py +++ b/esmf_regrid/experimental/io.py @@ -30,6 +30,7 @@ VERSION_INITIAL = "esmf_regrid_version_on_initialise" MDTOL = "mdtol" METHOD = "method" +RESOLUTION = "resolution" def save_regridder(rg, filename): @@ -87,6 +88,7 @@ def save_regridder(rg, filename): raise TypeError(msg) method = rg.method + resolution = rg.resolution weight_matrix = rg.regridder.weight_matrix reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix) @@ -113,6 +115,8 @@ def save_regridder(rg, filename): MDTOL: mdtol, METHOD: method, } + if resolution is not None: + attributes[RESOLUTION] = resolution weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME) row_coord = AuxCoord( @@ -178,6 +182,9 @@ def load_regridder(filename): # Determine the regridding method, allowing for files created when # conservative regridding was the only method. method = weights_cube.attributes.get(METHOD, "conservative") + resolution = weights_cube.attributes.get(RESOLUTION, None) + if resolution is not None: + resolution = int(resolution) # Reconstruct the weight matrix. weight_data = weights_cube.data @@ -196,6 +203,7 @@ def load_regridder(filename): mdtol=mdtol, method=method, precomputed_weights=weight_matrix, + resolution=resolution, ) esmf_version = weights_cube.attributes[VERSION_ESMF] diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index dddf8962..fc0e7db8 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -7,7 +7,7 @@ from iris.analysis._interpolation import get_xy_dim_coords import numpy as np -from esmf_regrid.esmf_regridder import GridInfo, Regridder +from esmf_regrid.esmf_regridder import GridInfo, RefinedGridInfo, Regridder from esmf_regrid.experimental.unstructured_regrid import MeshInfo @@ -133,7 +133,7 @@ def _mesh_to_MeshInfo(mesh, location): return meshinfo -def _cube_to_GridInfo(cube, center): +def _cube_to_GridInfo(cube, center, resolution): # This is a simplified version of an equivalent function/method in PR #26. # It is anticipated that this function will be replaced by the one in PR #26. # @@ -147,14 +147,22 @@ def _cube_to_GridInfo(cube, center): # TODO: accommodate other x/y coords. # TODO: perform checks on lat/lon. # Checks may cover units, coord systems (e.g. rotated pole), contiguous bounds. - return GridInfo( - lon.points, - lat.points, - _bounds_cf_to_simple_1d(lon.bounds), - _bounds_cf_to_simple_1d(lat.bounds), - circular=lon.circular, - center=center, - ) + if resolution is None: + grid_info = GridInfo( + lon.points, + lat.points, + _bounds_cf_to_simple_1d(lon.bounds), + _bounds_cf_to_simple_1d(lat.bounds), + circular=lon.circular, + center=center, + ) + else: + grid_info = RefinedGridInfo( + _bounds_cf_to_simple_1d(lon.bounds), + _bounds_cf_to_simple_1d(lat.bounds), + resolution=resolution, + ) + return grid_info def _regrid_along_mesh_dim(regridder, data, mesh_dim, mdtol): @@ -256,6 +264,7 @@ def _regrid_unstructured_to_rectilinear__prepare( target_grid_cube, method, precomputed_weights=None, + resolution=None, ): """ First (setup) part of 'regrid_unstructured_to_rectilinear'. @@ -306,7 +315,7 @@ def _regrid_unstructured_to_rectilinear__prepare( mesh_dim = src_mesh_cube.mesh_dim() meshinfo = _mesh_to_MeshInfo(mesh, location) - gridinfo = _cube_to_GridInfo(target_grid_cube, center=center) + gridinfo = _cube_to_GridInfo(target_grid_cube, center=center, resolution=resolution) regridder = Regridder( meshinfo, gridinfo, method=method, precomputed_weights=precomputed_weights @@ -358,7 +367,11 @@ def _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol): def regrid_unstructured_to_rectilinear( - src_cube, grid_cube, mdtol=0, method="conservative" + src_cube, + grid_cube, + mdtol=0, + method="conservative", + resolution=None, ): r""" Regrid unstructured :class:`~iris.cube.Cube` onto rectilinear grid. @@ -399,6 +412,9 @@ def regrid_unstructured_to_rectilinear( Either "conservative" or "bilinear". Corresponds to the :mod:`ESMF` methods :attr:`~ESMF.api.constants.RegridMethod.CONSERVE` or :attr:`~ESMF.api.constants.RegridMethod.BILINEAR` used to calculate weights. + resolution : int, optional + If present, represents the amount of latitude slices per cell + given to ESMF for calculation. Returns ------- @@ -407,7 +423,10 @@ def regrid_unstructured_to_rectilinear( """ regrid_info = _regrid_unstructured_to_rectilinear__prepare( - src_cube, grid_cube, method=method + src_cube, + grid_cube, + method=method, + resolution=resolution, ) result = _regrid_unstructured_to_rectilinear__perform(src_cube, regrid_info, mdtol) return result @@ -423,6 +442,7 @@ def __init__( mdtol=None, method="conservative", precomputed_weights=None, + resolution=None, ): """ Create regridder for conversions between source mesh and target grid. @@ -448,6 +468,11 @@ def __init__( If ``None``, :mod:`ESMF` will be used to calculate regridding weights. Otherwise, :mod:`ESMF` will be bypassed and ``precomputed_weights`` will be used as the regridding weights. + resolution : int, optional + If present, represents the amount of latitude slices per cell + given to ESMF for calculation. If resolution is set, target_grid_cube + must have strictly increasing bounds (bounds may be transposed plus or + minus 360 degrees to make the bounds strictly increasing). """ # TODO: Record information about the identity of the mesh. This would @@ -472,11 +497,21 @@ def __init__( self.mdtol = mdtol self.method = method + if resolution is not None: + if not (isinstance(resolution, int) and resolution > 0): + raise ValueError("resolution must be a positive integer.") + if method != "conservative": + raise ValueError( + "resolution can only be set for conservative regridding." + ) + self.resolution = resolution + partial_regrid_info = _regrid_unstructured_to_rectilinear__prepare( src_mesh_cube, target_grid_cube, method=self.method, precomputed_weights=precomputed_weights, + resolution=resolution, ) # Record source mesh. @@ -621,6 +656,7 @@ def _regrid_rectilinear_to_unstructured__prepare( target_mesh_cube, method, precomputed_weights=None, + resolution=None, ): """ First (setup) part of 'regrid_rectilinear_to_unstructured'. @@ -662,7 +698,7 @@ def _regrid_rectilinear_to_unstructured__prepare( grid_y_dim = src_grid_cube.coord_dims(grid_y)[0] meshinfo = _mesh_to_MeshInfo(mesh, location) - gridinfo = _cube_to_GridInfo(src_grid_cube, center=center) + gridinfo = _cube_to_GridInfo(src_grid_cube, center=center, resolution=resolution) regridder = Regridder( gridinfo, meshinfo, method=method, precomputed_weights=precomputed_weights @@ -718,7 +754,11 @@ def _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol): def regrid_rectilinear_to_unstructured( - src_cube, mesh_cube, mdtol=0, method="conservative" + src_cube, + mesh_cube, + mdtol=0, + method="conservative", + resolution=None, ): r""" Regrid rectilinear :class:`~iris.cube.Cube` onto unstructured mesh. @@ -763,6 +803,9 @@ def regrid_rectilinear_to_unstructured( Either "conservative" or "bilinear". Corresponds to the :mod:`ESMF` methods :attr:`~ESMF.api.constants.RegridMethod.CONSERVE` or :attr:`~ESMF.api.constants.RegridMethod.BILINEAR` used to calculate weights. + resolution : int, optional + If present, represents the amount of latitude slices per cell + given to ESMF for calculation. Returns ------- @@ -771,7 +814,10 @@ def regrid_rectilinear_to_unstructured( """ regrid_info = _regrid_rectilinear_to_unstructured__prepare( - src_cube, mesh_cube, method=method + src_cube, + mesh_cube, + method=method, + resolution=resolution, ) result = _regrid_rectilinear_to_unstructured__perform(src_cube, regrid_info, mdtol) return result @@ -782,11 +828,12 @@ class GridToMeshESMFRegridder: def __init__( self, - src_mesh_cube, - target_grid_cube, + src_grid_cube, + target_mesh_cube, mdtol=None, method="conservative", precomputed_weights=None, + resolution=None, ): """ Create regridder for conversions between source grid and target mesh. @@ -794,9 +841,9 @@ def __init__( Parameters ---------- src_grid_cube : :class:`iris.cube.Cube` - The unstructured :class:`~iris.cube.Cube` cube providing the source grid. - target_grid_cube : :class:`iris.cube.Cube` - The rectilinear :class:`~iris.cube.Cube` providing the target mesh. + The rectilinear :class:`~iris.cube.Cube` cube providing the source grid. + target_mesh_cube : :class:`iris.cube.Cube` + The unstructured :class:`~iris.cube.Cube` providing the target mesh. mdtol : float, optional Tolerance of missing data. The value returned in each element of the returned array will be masked if the fraction of masked data @@ -812,12 +859,25 @@ def __init__( If ``None``, :mod:`ESMF` will be used to calculate regridding weights. Otherwise, :mod:`ESMF` will be bypassed and ``precomputed_weights`` will be used as the regridding weights. + resolution : int, optional + If present, represents the amount of latitude slices per cell + given to ESMF for calculation. If resolution is set, src_grid_cube + must have strictly increasing bounds (bounds may be transposed plus or + minus 360 degrees to make the bounds strictly increasing). """ if method not in ["conservative", "bilinear"]: raise ValueError( f"method must be either 'bilinear' or 'conservative', got '{method}'." ) + + if resolution is not None: + if not (isinstance(resolution, int) and resolution > 0): + raise ValueError("resolution must be a positive integer.") + if method != "conservative": + raise ValueError( + "resolution can only be set for conservative regridding." + ) # Missing data tolerance. # Code directly copied from iris. if mdtol is None: @@ -830,12 +890,14 @@ def __init__( raise ValueError(msg.format(mdtol)) self.mdtol = mdtol self.method = method + self.resolution = resolution partial_regrid_info = _regrid_rectilinear_to_unstructured__prepare( - src_mesh_cube, - target_grid_cube, + src_grid_cube, + target_mesh_cube, method=self.method, precomputed_weights=precomputed_weights, + resolution=self.resolution, ) # Store regrid info. diff --git a/esmf_regrid/tests/unit/_esmf_sdo/test_RefinedGridInfo.py b/esmf_regrid/tests/unit/_esmf_sdo/test_RefinedGridInfo.py new file mode 100644 index 00000000..45f9ca11 --- /dev/null +++ b/esmf_regrid/tests/unit/_esmf_sdo/test_RefinedGridInfo.py @@ -0,0 +1,97 @@ +"""Unit tests for :class:`esmf_regrid._esmf_sdo.RefinedGridInfo`.""" + +import numpy as np +from numpy import ma + +from esmf_regrid.esmf_regridder import RefinedGridInfo, Regridder +from esmf_regrid.experimental.unstructured_regrid import MeshInfo +from esmf_regrid.tests import make_grid_args +from esmf_regrid.tests.unit.experimental.unstructured_regrid.test_MeshInfo import ( + _make_small_mesh_args, +) + + +def test_expanded_lons_with_mesh(): + """ + Basic test for regridding with :meth:`~esmf_regrid.esmf_regridder.RefinedGridInfo.make_esmf_field`. + + Mirrors the tests in :func:`~esmf_regrid.tests.unit.experimental.unstructured_regrid.test_MeshInfo.test_regrid_with_mesh` + but with slightly different expected values due to increased accuracy. + """ + mesh_args = _make_small_mesh_args() + mesh = MeshInfo(*mesh_args) + + grid_args = make_grid_args(2, 3) + grid = RefinedGridInfo(*grid_args[2:4], resolution=4) + + mesh_to_grid_regridder = Regridder(mesh, grid) + mesh_input = np.array([3, 2]) + grid_output = mesh_to_grid_regridder.regrid(mesh_input) + expected_grid_output = np.array( + [ + [2.671534474734418, 3.0], + [2.088765949748455, 2.922517356506756], + [2.0, 2.340882413622917], + ] + ) + assert ma.allclose(expected_grid_output, grid_output) + + grid_to_mesh_regridder = Regridder(grid, mesh) + grid_input = np.array([[0, 0], [1, 0], [2, 1]]) + mesh_output = grid_to_mesh_regridder.regrid(grid_input) + expected_mesh_output = np.array([0.14117205318254747, 1.1976140197893996]) + assert ma.allclose(expected_mesh_output, mesh_output) + + def _give_extra_dims(array): + result = np.stack([array, array + 1]) + result = np.stack([result, result + 10, result + 100]) + return result + + extra_dim_mesh_input = _give_extra_dims(mesh_input) + extra_dim_grid_output = mesh_to_grid_regridder.regrid(extra_dim_mesh_input) + extra_dim_expected_grid_output = _give_extra_dims(expected_grid_output) + assert ma.allclose(extra_dim_expected_grid_output, extra_dim_grid_output) + + extra_dim_grid_input = _give_extra_dims(grid_input) + extra_dim_mesh_output = grid_to_mesh_regridder.regrid(extra_dim_grid_input) + extra_dim_expected_mesh_output = _give_extra_dims(expected_mesh_output) + assert ma.allclose(extra_dim_expected_mesh_output, extra_dim_mesh_output) + + +def test_expanded_lats_with_mesh(): + """Basic test for regridding with :meth:`~esmf_regrid.esmf_regridder.RefinedGridInfo.make_esmf_field`.""" + mesh_args = _make_small_mesh_args() + mesh = MeshInfo(*mesh_args) + + grid = RefinedGridInfo(np.array([0, 5, 10]), np.array([-90, 90]), resolution=4) + + mesh_to_grid_regridder = Regridder(mesh, grid) + mesh_input = np.array([3, 2]) + grid_output = mesh_to_grid_regridder.regrid(mesh_input) + expected_grid_output = np.array( + [ + [2.2024695514629724, 2.4336888097502642], + ] + ) + assert ma.allclose(expected_grid_output, grid_output) + + grid_to_mesh_regridder = Regridder(grid, mesh) + grid_input = np.array([[1, 2]]) + mesh_output = grid_to_mesh_regridder.regrid(grid_input) + expected_mesh_output = np.array([1.7480791292591336, 1.496070008348207]) + assert ma.allclose(expected_mesh_output, mesh_output) + + def _give_extra_dims(array): + result = np.stack([array, array + 1]) + result = np.stack([result, result + 10, result + 100]) + return result + + extra_dim_mesh_input = _give_extra_dims(mesh_input) + extra_dim_grid_output = mesh_to_grid_regridder.regrid(extra_dim_mesh_input) + extra_dim_expected_grid_output = _give_extra_dims(expected_grid_output) + assert ma.allclose(extra_dim_expected_grid_output, extra_dim_grid_output) + + extra_dim_grid_input = _give_extra_dims(grid_input) + extra_dim_mesh_output = grid_to_mesh_regridder.regrid(extra_dim_grid_input) + extra_dim_expected_mesh_output = _give_extra_dims(expected_mesh_output) + assert ma.allclose(extra_dim_expected_mesh_output, extra_dim_mesh_output) diff --git a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index 79960897..fcdd39cb 100644 --- a/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -18,7 +18,7 @@ ) -def _make_grid_to_mesh_regridder(method="conservative"): +def _make_grid_to_mesh_regridder(method="conservative", resolution=None): src_lons = 3 src_lats = 4 tgt_lons = 5 @@ -35,11 +35,13 @@ def _make_grid_to_mesh_regridder(method="conservative"): location = "face" tgt = _gridlike_mesh_cube(tgt_lons, tgt_lats, location=location) - rg = GridToMeshESMFRegridder(src, tgt, method=method, mdtol=0.5) + rg = GridToMeshESMFRegridder( + src, tgt, method=method, mdtol=0.5, resolution=resolution + ) return rg, src -def _make_mesh_to_grid_regridder(method="conservative"): +def _make_mesh_to_grid_regridder(method="conservative", resolution=None): src_lons = 3 src_lats = 4 tgt_lons = 5 @@ -54,7 +56,9 @@ def _make_mesh_to_grid_regridder(method="conservative"): location = "face" src = _gridlike_mesh_cube(src_lons, src_lats, location=location) - rg = MeshToGridESMFRegridder(src, tgt, method=method, mdtol=0.5) + rg = MeshToGridESMFRegridder( + src, tgt, method=method, mdtol=0.5, resolution=resolution + ) return rg, src @@ -95,6 +99,18 @@ def test_GridToMeshESMFRegridder_round_trip(tmp_path): == loaded_rg.regridder.esmf_regrid_version ) + # Ensure resolution is equal. + assert original_rg.resolution == loaded_rg.resolution + original_res_rg, _ = _make_grid_to_mesh_regridder(resolution=8) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.resolution == loaded_res_rg.resolution + assert ( + original_res_rg.regridder.src.resolution + == loaded_res_rg.regridder.src.resolution + ) + def test_GridToMeshESMFRegridder_bilinear_round_trip(tmp_path): """Test save/load round tripping for `GridToMeshESMFRegridder`.""" @@ -184,6 +200,18 @@ def test_MeshToGridESMFRegridder_round_trip(tmp_path): == loaded_rg.regridder.esmf_regrid_version ) + # Ensure resolution is equal. + assert original_rg.resolution == loaded_rg.resolution + original_res_rg, _ = _make_mesh_to_grid_regridder(resolution=8) + res_filename = tmp_path / "regridder_res.nc" + save_regridder(original_res_rg, res_filename) + loaded_res_rg = load_regridder(str(res_filename)) + assert original_res_rg.resolution == loaded_res_rg.resolution + assert ( + original_res_rg.regridder.tgt.resolution + == loaded_res_rg.regridder.tgt.resolution + ) + def test_MeshToGridESMFRegridder_bilinear_round_trip(tmp_path): """Test save/load round tripping for `MeshToGridESMFRegridder`.""" diff --git a/esmf_regrid/tests/unit/experimental/unstuctured_regrid/__init__.py b/esmf_regrid/tests/unit/experimental/unstructured_regrid/__init__.py similarity index 100% rename from esmf_regrid/tests/unit/experimental/unstuctured_regrid/__init__.py rename to esmf_regrid/tests/unit/experimental/unstructured_regrid/__init__.py diff --git a/esmf_regrid/tests/unit/experimental/unstuctured_regrid/test_MeshInfo.py b/esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py similarity index 100% rename from esmf_regrid/tests/unit/experimental/unstuctured_regrid/test_MeshInfo.py rename to esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_GridToMeshESMFRegridder.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_GridToMeshESMFRegridder.py index 2e627ef6..2f4c9d85 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_GridToMeshESMFRegridder.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_GridToMeshESMFRegridder.py @@ -215,6 +215,30 @@ def test_invalid_method(): assert expected_message in str(excinfo.value) +def test_invalid_resolution(): + """ + Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`. + + Checks that an error is raised when the resolution is invalid. + """ + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + tgt = _gridlike_mesh_cube(n_lons, n_lats, location="face") + src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + with pytest.raises(ValueError) as excinfo: + _ = GridToMeshESMFRegridder(src, tgt, method="conservative", resolution=-1) + expected_message = "resolution must be a positive integer." + assert expected_message in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + _ = GridToMeshESMFRegridder(src, tgt, method="bilinear", resolution=4) + expected_message = "resolution can only be set for conservative regridding." + assert expected_message in str(excinfo.value) + + def test_default_mdtol(): """ Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`. @@ -328,3 +352,23 @@ def test_laziness(): expected_chunks = ((120,), (2, 2)) assert out_chunks == expected_chunks assert np.allclose(result.data, src_data.reshape([-1, h])) + + +def test_resolution(): + """ + Test for :func:`esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`. + + Tests for the resolution keyword. + """ + tgt = _flat_mesh_cube() + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + grid = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + resolution = 8 + + result = GridToMeshESMFRegridder(grid, tgt, resolution=resolution) + assert result.resolution == resolution + assert result.regridder.src.resolution == resolution diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py index a0617df4..d826f392 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py @@ -210,6 +210,30 @@ def test_invalid_method(): assert expected_message in str(excinfo.value) +def test_invalid_resolution(): + """ + Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. + + Checks that an error is raised when the resolution is invalid. + """ + n_lons = 6 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _gridlike_mesh_cube(n_lons, n_lats, location="face") + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) + + with pytest.raises(ValueError) as excinfo: + _ = MeshToGridESMFRegridder(src, tgt, method="conservative", resolution=-1) + expected_message = "resolution must be a positive integer." + assert expected_message in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + _ = MeshToGridESMFRegridder(src, tgt, method="bilinear", resolution=4) + expected_message = "resolution can only be set for conservative regridding." + assert expected_message in str(excinfo.value) + + def test_default_mdtol(): """ Test initialisation of :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. @@ -298,3 +322,27 @@ def test_laziness(): expected_chunks = ((1,), (3, 3, 3), (10,), (12,), (2, 2)) assert out_chunks == expected_chunks assert np.allclose(result.data.reshape([1, i, -1, h]), src_data) + + +def test_resolution(): + """ + Test for :func:`esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`. + + Tests for the resolution keyword. + """ + mesh_cube = _flat_mesh_cube() + + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + lon_bands = _grid_cube(1, 4, lon_bounds, lat_bounds) + lat_bands = _grid_cube(4, 1, lon_bounds, lat_bounds) + + resolution = 8 + + lon_band_rg = MeshToGridESMFRegridder(mesh_cube, lon_bands, resolution=resolution) + assert lon_band_rg.resolution == resolution + assert lon_band_rg.regridder.tgt.resolution == resolution + + lat_band_rg = MeshToGridESMFRegridder(mesh_cube, lat_bands, resolution=resolution) + assert lat_band_rg.resolution == resolution + assert lat_band_rg.regridder.tgt.resolution == resolution diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test__cube_to_GridInfo.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test__cube_to_GridInfo.py index af54e388..4299d131 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test__cube_to_GridInfo.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test__cube_to_GridInfo.py @@ -41,7 +41,7 @@ def test_global_grid(): lat_bounds = (-90, 90) cube = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) - gridinfo = _cube_to_GridInfo(cube, center=False) + gridinfo = _cube_to_GridInfo(cube, center=False, resolution=None) # Ensure conversion to ESMF works without error _ = gridinfo.make_esmf_field() @@ -61,7 +61,7 @@ def test_local_grid(): lat_bounds = (20, 60) cube = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds) - gridinfo = _cube_to_GridInfo(cube, center=False) + gridinfo = _cube_to_GridInfo(cube, center=False, resolution=None) # Ensure conversion to ESMF works without error _ = gridinfo.make_esmf_field() @@ -84,7 +84,7 @@ def test_grid_with_scalars(): cube = cube[:, 0] assert len(cube.shape) == 1 - gridinfo = _cube_to_GridInfo(cube, center=False) + gridinfo = _cube_to_GridInfo(cube, center=False, resolution=None) # Ensure conversion to ESMF works without error _ = gridinfo.make_esmf_field() diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_rectilinear_to_unstructured.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_rectilinear_to_unstructured.py index 97559995..77cab677 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_rectilinear_to_unstructured.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_rectilinear_to_unstructured.py @@ -56,7 +56,7 @@ def test_flat_cubes(): src_T.transpose() result_transposed = regrid_rectilinear_to_unstructured(src_T, tgt) - expected_data = np.ones([n_lats, n_lons]) + expected_data = np.ones_like(tgt.data) expected_cube = _add_metadata(tgt) # Lenient check for data. @@ -233,3 +233,37 @@ def test_mask_handling(): assert ma.allclose(expected_0, result_0.data) assert ma.allclose(expected_05, result_05.data) assert ma.allclose(expected_1, result_1.data) + + +def test_resolution(): + """ + Basic test for :func:`esmf_regrid.experimental.unstructured_scheme.regrid_rectilinear_to_unstructured`. + + Tests the resolution keyword with grids that would otherwise not work. + """ + tgt = _flat_mesh_cube() + + # The resulting grid has full latitude bounds and cells must be split up. + n_lons = 1 + n_lats = 4 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + src = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds) + # Ensure data in the target grid is different to the expected data. + # i.e. target grid data is all zero, expected data is all one + tgt.data[:] = 0 + + src = _add_metadata(src) + src.data[:] = 1 # Ensure all data in the source is one. + result = regrid_rectilinear_to_unstructured(src, tgt, resolution=8) + + expected_data = np.ones_like(tgt.data) + expected_cube = _add_metadata(tgt) + + # Lenient check for data. + # Note that when resolution=None, this would be a fully masked array. + assert np.allclose(expected_data, result.data) + + # Check metadata and scalar coords. + expected_cube.data = result.data + assert expected_cube == result diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_unstructured_to_rectilinear.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_unstructured_to_rectilinear.py index 56580ea5..c617dc8f 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_unstructured_to_rectilinear.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_regrid_unstructured_to_rectilinear.py @@ -176,3 +176,37 @@ def test_multidim_cubes(): # Check metadata and scalar coords. result.data = expected_data assert expected_cube == result + + +def test_resolution(): + """ + Basic test for :func:`esmf_regrid.experimental.unstructured_scheme.regrid_unstructured_to_rectilinear`. + + Tests the resolution keyword with grids that would otherwise not work. + """ + src = _flat_mesh_cube() + + # The resulting grid has full latitude bounds and cells must be split up. + n_lons = 1 + n_lats = 5 + lon_bounds = (-180, 180) + lat_bounds = (-90, 90) + tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds) + # Ensure data in the target grid is different to the expected data. + # i.e. target grid data is all zero, expected data is all one + tgt.data[:] = 0 + + src = _add_metadata(src) + src.data[:] = 1 # Ensure all data in the source is one. + result = regrid_unstructured_to_rectilinear(src, tgt, resolution=8) + + expected_data = np.ones([n_lats, n_lons]) + expected_cube = _add_metadata(tgt) + + # Lenient check for data. + # Note that when resolution=None, this would be a fully masked array. + assert np.allclose(expected_data, result.data) + + # Check metadata and scalar coords. + expected_cube.data = result.data + assert expected_cube == result