Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regridding issue where target grid is more than 2D #2087

Merged
merged 3 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions esmvalcore/preprocessor/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,22 +681,36 @@ def regrid(cube, target_grid, scheme, lat_offset=True, lon_offset=True):
return cube


def _rechunk(cube, target_grid):
def _rechunk(
cube: iris.cube.Cube,
target_grid: iris.cube.Cube,
) -> iris.cube.Cube:
"""Re-chunk cube with optimal chunk sizes for target grid."""
if not cube.has_lazy_data() or cube.ndim < 3:
# Only rechunk lazy multidimensional data
return cube

if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid.shape):
lon_coord = target_grid.coord(axis='X')
lat_coord = target_grid.coord(axis='Y')
if lon_coord.ndim != 1 or lat_coord.ndim != 1:
# This function only supports 1D lat/lon coordinates.
valeriupredoi marked this conversation as resolved.
Show resolved Hide resolved
return cube

lon_dim, = target_grid.coord_dims(lon_coord)
lat_dim, = target_grid.coord_dims(lat_coord)
grid_indices = sorted((lon_dim, lat_dim))
target_grid_shape = tuple(target_grid.shape[i] for i in grid_indices)

if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid_shape):
zklaus marked this conversation as resolved.
Show resolved Hide resolved
# Only rechunk if target grid is more than a factor of 2 larger,
# because rechunking will keep the original chunk in memory.
return cube

data = cube.lazy_data()

# Compute a good chunk size for the target array
tgt_shape = data.shape[:-2] + target_grid.shape
tgt_chunks = data.chunks[:-2] + target_grid.shape
tgt_shape = data.shape[:-2] + target_grid_shape
tgt_chunks = data.chunks[:-2] + target_grid_shape
tgt_data = da.empty(tgt_shape, dtype=data.dtype, chunks=tgt_chunks)
tgt_data = tgt_data.rechunk({i: "auto" for i in range(cube.ndim - 2)})

Expand Down
73 changes: 67 additions & 6 deletions tests/unit/preprocessor/_regrid/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,34 @@ def test_use_mask_if_discontiguities_in_coords(caplog):
assert msg in caplog.text


def make_test_cube(shape):
    """Create a test cube of the given shape with 1D lat/lon coordinates.

    The cube wraps a ``float32`` array created with ``da.empty``. The
    last two dimensions are described by ``latitude`` and ``longitude``
    dimension coordinates. If ``shape`` has more than two dimensions,
    the leading dimension additionally gets a ``time`` coordinate.
    """
    ndim = len(shape)
    cube = iris.cube.Cube(da.empty(shape, dtype=np.float32))

    # Only cubes with an extra leading dimension carry a time axis.
    if ndim > 2:
        time_coord = iris.coords.DimCoord(
            np.arange(shape[0]),
            standard_name='time',
        )
        cube.add_dim_coord(time_coord, 0)

    # Latitude always sits on the second-to-last dimension.
    lat_coord = iris.coords.DimCoord(
        np.linspace(-90., 90., shape[-2], endpoint=True),
        standard_name='latitude',
    )
    cube.add_dim_coord(lat_coord, ndim - 2)

    # Longitude always sits on the last dimension.
    lon_coord = iris.coords.DimCoord(
        np.linspace(0., 360., shape[-1]),
        standard_name='longitude',
    )
    cube.add_dim_coord(lon_coord, ndim - 1)

    return cube


def test_rechunk_on_increased_grid():
"""Test that an increase in grid size rechunks."""
with dask.config.set({'array.chunk-size': '128 M'}):
Expand All @@ -333,10 +361,9 @@ def test_rechunk_on_increased_grid():
src_grid_dims = (91, 180)
data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)

tgt_grid_dims = (361, 720)
tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)

result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
tgt_grid_dims = (2, 361, 720)
tgt_grid = make_test_cube(tgt_grid_dims)
result = _rechunk(iris.cube.Cube(data), tgt_grid)

assert result.core_data().chunks == ((123, 123), (91, ), (180, ))

Expand All @@ -350,9 +377,9 @@ def test_no_rechunk_on_decreased_grid():
data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)

tgt_grid_dims = (91, 180)
tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)
tgt_grid = make_test_cube(tgt_grid_dims)

result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
result = _rechunk(iris.cube.Cube(data), tgt_grid)

assert result.core_data().chunks == data.chunks

Expand Down Expand Up @@ -380,5 +407,39 @@ def test_no_rechunk_non_lazy():
assert result.data is cube.data


def test_no_rechunk_unsupported_grid():
    """Test that a target grid with 2D lat/lon coordinates is ignored.

    Two-dimensional target coordinates are not supported at the moment,
    so the input cube must be returned as the very same object with its
    chunks untouched. Support could be implemented at a later stage if
    needed.
    """
    src_cube = iris.cube.Cube(da.arange(2 * 4).reshape([1, 2, 4]))

    grid_shape = (5, 10)
    target = iris.cube.Cube(da.empty(grid_shape, dtype=np.float32))

    lat_1d = np.linspace(-90., 90., grid_shape[0], endpoint=True)
    lon_1d = np.linspace(0., 360., grid_shape[1])

    # Broadcast the 1D points over the full grid to obtain 2D coordinates
    # that span both dimensions of the target cube.
    points_by_name = {
        'latitude': np.broadcast_to(lat_1d.reshape(-1, 1), grid_shape),
        'longitude': np.broadcast_to(lon_1d.reshape(1, -1), grid_shape),
    }
    for standard_name, points in points_by_name.items():
        target.add_aux_coord(
            iris.coords.AuxCoord(points, standard_name=standard_name),
            (0, 1),
        )

    chunks_before = src_cube.core_data().chunks
    result = _rechunk(src_cube, target)

    assert result is src_cube
    assert result.core_data().chunks == chunks_before


# Invoke unittest's command-line runner when this file is executed directly
# as a script (and not when it is imported as a module).
if __name__ == '__main__':
unittest.main()