Skip to content

Commit

Permalink
Fix regridding issue where target grid is more than 2D (#2087)
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela authored Jun 8, 2023
1 parent 92e20fd commit 1cf81d2
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 10 deletions.
22 changes: 18 additions & 4 deletions esmvalcore/preprocessor/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,22 +681,36 @@ def regrid(cube, target_grid, scheme, lat_offset=True, lon_offset=True):
return cube


def _rechunk(cube, target_grid):
def _rechunk(
cube: iris.cube.Cube,
target_grid: iris.cube.Cube,
) -> iris.cube.Cube:
"""Re-chunk cube with optimal chunk sizes for target grid."""
if not cube.has_lazy_data() or cube.ndim < 3:
# Only rechunk lazy multidimensional data
return cube

if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid.shape):
lon_coord = target_grid.coord(axis='X')
lat_coord = target_grid.coord(axis='Y')
if lon_coord.ndim != 1 or lat_coord.ndim != 1:
# This function only supports 1D lat/lon coordinates.
return cube

lon_dim, = target_grid.coord_dims(lon_coord)
lat_dim, = target_grid.coord_dims(lat_coord)
grid_indices = sorted((lon_dim, lat_dim))
target_grid_shape = tuple(target_grid.shape[i] for i in grid_indices)

if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid_shape):
# Only rechunk if target grid is more than a factor of 2 larger,
# because rechunking will keep the original chunk in memory.
return cube

data = cube.lazy_data()

# Compute a good chunk size for the target array
tgt_shape = data.shape[:-2] + target_grid.shape
tgt_chunks = data.chunks[:-2] + target_grid.shape
tgt_shape = data.shape[:-2] + target_grid_shape
tgt_chunks = data.chunks[:-2] + target_grid_shape
tgt_data = da.empty(tgt_shape, dtype=data.dtype, chunks=tgt_chunks)
tgt_data = tgt_data.rechunk({i: "auto" for i in range(cube.ndim - 2)})

Expand Down
73 changes: 67 additions & 6 deletions tests/unit/preprocessor/_regrid/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,34 @@ def test_use_mask_if_discontiguities_in_coords(caplog):
assert msg in caplog.text


def make_test_cube(shape):
data = da.empty(shape, dtype=np.float32)
cube = iris.cube.Cube(data)
if len(shape) > 2:
cube.add_dim_coord(
iris.coords.DimCoord(
np.arange(shape[0]),
standard_name='time',
),
0,
)
cube.add_dim_coord(
iris.coords.DimCoord(
np.linspace(-90., 90., shape[-2], endpoint=True),
standard_name='latitude',
),
len(shape) - 2,
)
cube.add_dim_coord(
iris.coords.DimCoord(
np.linspace(0., 360., shape[-1]),
standard_name='longitude',
),
len(shape) - 1,
)
return cube


def test_rechunk_on_increased_grid():
"""Test that an increase in grid size rechunks."""
with dask.config.set({'array.chunk-size': '128 M'}):
Expand All @@ -333,10 +361,9 @@ def test_rechunk_on_increased_grid():
src_grid_dims = (91, 180)
data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)

tgt_grid_dims = (361, 720)
tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)

result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
tgt_grid_dims = (2, 361, 720)
tgt_grid = make_test_cube(tgt_grid_dims)
result = _rechunk(iris.cube.Cube(data), tgt_grid)

assert result.core_data().chunks == ((123, 123), (91, ), (180, ))

Expand All @@ -350,9 +377,9 @@ def test_no_rechunk_on_decreased_grid():
data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)

tgt_grid_dims = (91, 180)
tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)
tgt_grid = make_test_cube(tgt_grid_dims)

result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
result = _rechunk(iris.cube.Cube(data), tgt_grid)

assert result.core_data().chunks == data.chunks

Expand Down Expand Up @@ -380,5 +407,39 @@ def test_no_rechunk_non_lazy():
assert result.data is cube.data


def test_no_rechunk_unsupported_grid():
"""Test that 2D target coordinates are ignored.
Because they are not supported at the moment. This could be
implemented at a later stage if needed.
"""
cube = iris.cube.Cube(da.arange(2 * 4).reshape([1, 2, 4]))
tgt_grid_dims = (5, 10)
tgt_data = da.empty(tgt_grid_dims, dtype=np.float32)
tgt_grid = iris.cube.Cube(tgt_data)
lat_points = np.linspace(-90., 90., tgt_grid_dims[0], endpoint=True)
lon_points = np.linspace(0., 360., tgt_grid_dims[1])

tgt_grid.add_aux_coord(
iris.coords.AuxCoord(
np.broadcast_to(lat_points.reshape(-1, 1), tgt_grid_dims),
standard_name='latitude',
),
(0, 1),
)
tgt_grid.add_aux_coord(
iris.coords.AuxCoord(
np.broadcast_to(lon_points.reshape(1, -1), tgt_grid_dims),
standard_name='longitude',
),
(0, 1),
)

expected_chunks = cube.core_data().chunks
result = _rechunk(cube, tgt_grid)
assert result is cube
assert result.core_data().chunks == expected_chunks


if __name__ == '__main__':
unittest.main()

0 comments on commit 1cf81d2

Please sign in to comment.