diff --git a/esmvalcore/preprocessor/_regrid.py b/esmvalcore/preprocessor/_regrid.py
index cb52055596..daa6b6acfd 100644
--- a/esmvalcore/preprocessor/_regrid.py
+++ b/esmvalcore/preprocessor/_regrid.py
@@ -681,13 +681,27 @@ def regrid(cube, target_grid, scheme, lat_offset=True, lon_offset=True):
     return cube
 
 
-def _rechunk(cube, target_grid):
+def _rechunk(
+    cube: iris.cube.Cube,
+    target_grid: iris.cube.Cube,
+) -> iris.cube.Cube:
     """Re-chunk cube with optimal chunk sizes for target grid."""
     if not cube.has_lazy_data() or cube.ndim < 3:
         # Only rechunk lazy multidimensional data
         return cube
 
-    if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid.shape):
+    lon_coord = target_grid.coord(axis='X')
+    lat_coord = target_grid.coord(axis='Y')
+    if lon_coord.ndim != 1 or lat_coord.ndim != 1:
+        # This function only supports 1D lat/lon coordinates.
+        return cube
+
+    lon_dim, = target_grid.coord_dims(lon_coord)
+    lat_dim, = target_grid.coord_dims(lat_coord)
+    grid_indices = sorted((lon_dim, lat_dim))
+    target_grid_shape = tuple(target_grid.shape[i] for i in grid_indices)
+
+    if 2 * np.prod(cube.shape[-2:]) > np.prod(target_grid_shape):
         # Only rechunk if target grid is more than a factor of 2 larger,
         # because rechunking will keep the original chunk in memory.
         return cube
@@ -695,8 +709,8 @@ def _rechunk(cube, target_grid):
     data = cube.lazy_data()
 
     # Compute a good chunk size for the target array
-    tgt_shape = data.shape[:-2] + target_grid.shape
-    tgt_chunks = data.chunks[:-2] + target_grid.shape
+    tgt_shape = data.shape[:-2] + target_grid_shape
+    tgt_chunks = data.chunks[:-2] + target_grid_shape
     tgt_data = da.empty(tgt_shape, dtype=data.dtype, chunks=tgt_chunks)
     tgt_data = tgt_data.rechunk({i: "auto" for i in range(cube.ndim - 2)})
 
diff --git a/tests/unit/preprocessor/_regrid/test_regrid.py b/tests/unit/preprocessor/_regrid/test_regrid.py
index a964eaf646..780c0a1db3 100644
--- a/tests/unit/preprocessor/_regrid/test_regrid.py
+++ b/tests/unit/preprocessor/_regrid/test_regrid.py
@@ -325,6 +325,34 @@ def test_use_mask_if_discontiguities_in_coords(caplog):
     assert msg in caplog.text
 
 
+def make_test_cube(shape):
+    data = da.empty(shape, dtype=np.float32)
+    cube = iris.cube.Cube(data)
+    if len(shape) > 2:
+        cube.add_dim_coord(
+            iris.coords.DimCoord(
+                np.arange(shape[0]),
+                standard_name='time',
+            ),
+            0,
+        )
+    cube.add_dim_coord(
+        iris.coords.DimCoord(
+            np.linspace(-90., 90., shape[-2], endpoint=True),
+            standard_name='latitude',
+        ),
+        len(shape) - 2,
+    )
+    cube.add_dim_coord(
+        iris.coords.DimCoord(
+            np.linspace(0., 360., shape[-1]),
+            standard_name='longitude',
+        ),
+        len(shape) - 1,
+    )
+    return cube
+
+
 def test_rechunk_on_increased_grid():
     """Test that an increase in grid size rechunks."""
     with dask.config.set({'array.chunk-size': '128 M'}):
@@ -333,10 +361,9 @@ def test_rechunk_on_increased_grid():
         src_grid_dims = (91, 180)
         data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)
 
-        tgt_grid_dims = (361, 720)
-        tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)
-
-        result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
+        tgt_grid_dims = (2, 361, 720)
+        tgt_grid = make_test_cube(tgt_grid_dims)
+        result = _rechunk(iris.cube.Cube(data), tgt_grid)
 
     assert result.core_data().chunks == ((123, 123), (91, ), (180, ))
 
@@ -350,9 +377,9 @@ def test_no_rechunk_on_decreased_grid():
         data = da.empty((time_dim, ) + src_grid_dims, dtype=np.float32)
 
         tgt_grid_dims = (91, 180)
-        tgt_grid = da.empty(tgt_grid_dims, dtype=np.float32)
+        tgt_grid = make_test_cube(tgt_grid_dims)
 
-        result = _rechunk(iris.cube.Cube(data), iris.cube.Cube(tgt_grid))
+        result = _rechunk(iris.cube.Cube(data), tgt_grid)
 
     assert result.core_data().chunks == data.chunks
 
@@ -380,5 +407,39 @@ def test_no_rechunk_non_lazy():
     assert result.data is cube.data
 
 
+def test_no_rechunk_unsupported_grid():
+    """Test that 2D target coordinates are ignored.
+
+    They are not supported at the moment; this could be implemented
+    at a later stage if needed.
+    """
+    cube = iris.cube.Cube(da.arange(2 * 4).reshape([1, 2, 4]))
+    tgt_grid_dims = (5, 10)
+    tgt_data = da.empty(tgt_grid_dims, dtype=np.float32)
+    tgt_grid = iris.cube.Cube(tgt_data)
+    lat_points = np.linspace(-90., 90., tgt_grid_dims[0], endpoint=True)
+    lon_points = np.linspace(0., 360., tgt_grid_dims[1])
+
+    tgt_grid.add_aux_coord(
+        iris.coords.AuxCoord(
+            np.broadcast_to(lat_points.reshape(-1, 1), tgt_grid_dims),
+            standard_name='latitude',
+        ),
+        (0, 1),
+    )
+    tgt_grid.add_aux_coord(
+        iris.coords.AuxCoord(
+            np.broadcast_to(lon_points.reshape(1, -1), tgt_grid_dims),
+            standard_name='longitude',
+        ),
+        (0, 1),
+    )
+
+    expected_chunks = cube.core_data().chunks
+    result = _rechunk(cube, tgt_grid)
+    assert result is cube
+    assert result.core_data().chunks == expected_chunks
+
+
 if __name__ == '__main__':
     unittest.main()