diff --git a/geoutils/raster/delayed.py b/geoutils/raster/delayed.py index bc118564..76221cee 100644 --- a/geoutils/raster/delayed.py +++ b/geoutils/raster/delayed.py @@ -2,7 +2,7 @@ Module for dask-delayed functions for out-of-memory raster operations. """ import warnings -from typing import Any, Literal +from typing import Any, Literal, TypeVar import dask.array as da import dask.delayed @@ -12,6 +12,7 @@ import rasterio as rio from scipy.interpolate import interpn +from geoutils._typing import NDArrayNum from geoutils.projtools import _get_bounds_projected, _get_footprint_projected # 1/ SUBSAMPLING @@ -23,7 +24,9 @@ # usage by having to drop an axis and re-chunk along 1D of the 2D array, so we use the dask.delayed solution instead) -def _random_state_from_user_input(random_state: np.random.RandomState | int | None = None) -> np.random.RandomState: +def _random_state_from_user_input( + random_state: np.random.RandomState | int | None = None, +) -> np.random.RandomState | np.random.Generator: """Define random state based on varied user input.""" # Define state for random sampling (to fix results during testing) @@ -61,7 +64,7 @@ def _get_subsample_size_from_user_input(subsample: int | float, total_nb_valids: def _get_indices_block_per_subsample( - xxs: np.ndarray, num_chunks: tuple[int, int], nb_valids_per_block: list[int] + xxs: NDArrayNum, num_chunks: tuple[int, int], nb_valids_per_block: list[int] ) -> list[list[int]]: """ Get list of 1D valid subsample indices relative to the block for each block. @@ -102,14 +105,14 @@ def _get_indices_block_per_subsample( return relative_ind_per_block -@dask.delayed -def _delayed_nb_valids(arr_chunk: np.ndarray) -> np.ndarray: +@dask.delayed # type: ignore +def _delayed_nb_valids(arr_chunk: NDArrayNum) -> NDArrayNum: """Count number of valid values per block.""" return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1)) -@dask.delayed -def _delayed_subsample_block(arr_chunk: np.ndarray, subsample_indices: np.ndarray) -> np.ndarray: +@dask.delayed # type: ignore +def _delayed_subsample_block(arr_chunk: NDArrayNum, subsample_indices: NDArrayNum) -> NDArrayNum: """Subsample the valid values at the corresponding 1D valid indices per block.""" s_chunk = arr_chunk[np.isfinite(arr_chunk)][subsample_indices] @@ -117,10 +120,10 @@ def _delayed_subsample_block(arr_chunk: np.ndarray, subsample_indices: np.ndarra return s_chunk -@dask.delayed +@dask.delayed # type: ignore def _delayed_subsample_indices_block( - arr_chunk: np.ndarray, subsample_indices: np.ndarray, block_id: dict[str, Any] -) -> np.ndarray: + arr_chunk: NDArrayNum, subsample_indices: NDArrayNum, block_id: dict[str, Any] +) -> NDArrayNum: """Return 2D indices from the subsampled 1D valid indices per block.""" # Unravel indices of valid data to the shape of the block @@ -138,7 +141,7 @@ def delayed_subsample( subsample: int | float = 1, return_indices: bool = False, random_state: np.random.RandomState | int | None = None, -) -> np.ndarray: +) -> NDArrayNum | tuple[NDArrayNum, NDArrayNum]: """ Subsample a raster at valid values on out-of-memory chunks. 
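The updated return annotation on `delayed_subsample` documents the two output shapes the function already had: a 1D array of sampled values by default, or a tuple of two index arrays when `return_indices=True`. A minimal usage sketch follows (not part of the patch; the toy array and sample size are illustrative, and the final equality mirrors the check done in the test module):

```python
import dask.array as da
import numpy as np

from geoutils.raster.delayed import delayed_subsample

# Small fully-finite toy array, chunked so the delayed path has several blocks to sample from
darr = da.from_array(np.random.default_rng(42).normal(size=(100, 100)), chunks=(25, 25))

# Default: a 1D array of sampled values (NDArrayNum)
values = delayed_subsample(darr, subsample=50, random_state=42)

# With return_indices=True: a tuple of two 1D index arrays (rows, cols)
rows, cols = delayed_subsample(darr, subsample=50, return_indices=True, random_state=42)

# With the same random_state, indexing the array at the returned indices gives the same values
assert np.array_equal(values, np.array(darr.vindex[rows, cols]))
```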
@@ -246,7 +249,15 @@ def delayed_subsample( # Code structure inspired by https://blog.dask.org/2021/07/02/ragged-output and the "block_id" in map_blocks -def _get_interp_indices_per_block(interp_x, interp_y, starts, num_chunks, chunksize, xres, yres): +def _get_interp_indices_per_block( + interp_x: NDArrayNum, + interp_y: NDArrayNum, + starts: list[tuple[int, ...]], + num_chunks: tuple[int, int], + chunksize: tuple[int, int], + xres: float, + yres: float, +) -> list[list[int]]: """Map blocks where each pair of interpolation coordinates will have to be computed.""" # TODO 1: Check the robustness for chunksize different and X and Y @@ -267,10 +278,10 @@ def _get_interp_indices_per_block(interp_x, interp_y, starts, num_chunks, chunks return ind_per_block -@dask.delayed +@dask.delayed # type: ignore def _delayed_interp_points_block( - arr_chunk: np.ndarray, block_id: dict[str, Any], interp_coords: np.ndarray -) -> np.ndarray: + arr_chunk: NDArrayNum, block_id: dict[str, Any], interp_coords: NDArrayNum +) -> NDArrayNum: """ Interpolate block in 2D out-of-memory for a regular or equal grid. """ @@ -296,7 +307,7 @@ def delayed_interp_points( points: tuple[list[float], list[float]], resolution: tuple[float, float], method: Literal["nearest", "linear", "cubic", "quintic"] = "linear", -) -> np.ndarray: +) -> NDArrayNum: """ Interpolate raster at point coordinates on out-of-memory chunks. @@ -313,8 +324,9 @@ def delayed_interp_points( :return: Array of raster value(s) for the given points. """ + # TODO: Replace by a generic 2D point casting function accepting multiple inputs (living outside this function) # Convert input to 2D array - points = np.vstack((points[0], points[1])) + points_arr = np.vstack((points[0], points[1])) # Map depth of overlap required for each interpolation method # TODO: Double-check this window somewhere in SciPy's documentation @@ -333,7 +345,7 @@ def delayed_interp_points( # Get samples indices per blocks ind_per_block = _get_interp_indices_per_block( - points[0, :], points[1, :], starts, num_chunks, chunksize, resolution[0], resolution[1] + points_arr[0, :], points_arr[1, :], starts, num_chunks, chunksize, resolution[0], resolution[1] ) # Create a delayed object for each block, and flatten the blocks into a 1d shape @@ -353,7 +365,7 @@ def delayed_interp_points( # Compute values delayed list_interp = [ - _delayed_interp_points_block(blocks[i], block_ids[i], points[:, ind_per_block[i]]) + _delayed_interp_points_block(blocks[i], block_ids[i], points_arr[:, ind_per_block[i]]) for i, data_chunk in enumerate(blocks) if len(ind_per_block[i]) > 0 ] @@ -380,6 +392,9 @@ def delayed_interp_points( # We define a GeoGrid and GeoTiling class (which composes GeoGrid) to consistently deal with georeferenced footprints # of chunked grids +GeoGridType = TypeVar("GeoGridType", bound="GeoGrid") + + class GeoGrid: """ Georeferenced grid class. 
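For context on the `points` handling fixed above (the input tuple is now stacked into a separate `points_arr` instead of overwriting the argument), here is a hedged usage sketch of `delayed_interp_points`. The coordinate convention, where `resolution=(1, 1)` makes point coordinates behave like pixel coordinates, follows the test module and is an assumption here:

```python
import dask.array as da
import numpy as np

from geoutils.raster.delayed import delayed_interp_points

rng = np.random.default_rng(0)
darr = da.from_array(rng.normal(size=(200, 300)), chunks=(50, 50))

# Point coordinates passed as a tuple of two sequences, as in the signature above
interp_x = rng.uniform(10, 190, size=20)
interp_y = rng.uniform(10, 190, size=20)

vals = delayed_interp_points(darr, points=(interp_x, interp_y), resolution=(1, 1), method="linear")
assert vals.shape == (20,)  # one interpolated value per input point
```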
@@ -430,13 +445,23 @@ def bounds_projected(self, crs: rio.crs.CRS = None) -> rio.coords.BoundingBox: def footprint(self) -> gpd.GeoDataFrame: return _get_footprint_projected(self.bounds, in_crs=self.crs, out_crs=self.crs, densify_points=100) - def footprint_projected(self, crs: rio.crs.CRS = None): + def footprint_projected(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame: if crs is None: crs = self.crs return _get_footprint_projected(self.bounds, in_crs=self.crs, out_crs=crs, densify_points=100) - def shift(self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"] | Literal["pixel"] = "pixel"): - """Shift geogrid, not inplace.""" + @classmethod + def from_dict(cls: type[GeoGridType], dict_meta: dict[str, Any]) -> GeoGridType: + """Create a GeoGrid from a dictionary containing transform, shape and CRS.""" + return cls(**dict_meta) + + def shift( + self: GeoGridType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "pixel", + ) -> GeoGridType: + """Shift into a new geogrid (not inplace).""" if distance_unit not in ["georeferenced", "pixel"]: raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.") @@ -456,7 +481,7 @@ def shift(self, xoff: float, yoff: float, distance_unit: Literal["georeferenced" shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff) - return GeoGrid(transform=shifted_transform, crs=self.crs, shape=self.shape) + return self.from_dict({"transform": shifted_transform, "crs": self.crs, "shape": self.shape}) def _get_block_ids_per_chunk(chunks: tuple[tuple[int, ...], tuple[int, ...]]) -> list[dict[str, int]]: @@ -536,7 +561,9 @@ def get_block_footprints(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame: return pd.concat(footprints) -def _chunks2d_from_chunksizes_shape(chunksizes: tuple[int, int], shape: tuple[int, int]): +def _chunks2d_from_chunksizes_shape( + chunksizes: tuple[int, int], shape: tuple[int, int] +) -> tuple[tuple[int, ...], tuple[int, ...]]: """Get tuples of chunk sizes for X/Y dimensions based on chunksizes and array shape.""" # Chunksize is fixed, except for the last chunk depending on the shape @@ -583,10 +610,10 @@ def _combined_blocks_shape_transform( return combined_meta, relative_block_indexes -@dask.delayed +@dask.delayed # type: ignore def _delayed_reproject_per_block( - *src_arrs: tuple[np.ndarray], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any -) -> np.ndarray: + *src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any +) -> NDArrayNum: """ Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks). """ @@ -604,7 +631,7 @@ def _delayed_reproject_per_block( # Then fill it with the source chunks values for i, arr in enumerate(src_arrs): bid = block_ids[i] - comb_src_arr[bid["rys"]:bid["rye"], bid["rxs"]:bid["rxe"]] = arr + comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr # Now, we can simply call Rasterio! 
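The `GeoGridType` TypeVar and the new `from_dict` classmethod exist so that `shift` returns an instance of the calling class rather than a hard-coded `GeoGrid`, which keeps subclasses intact. A standalone sketch of the pattern (toy classes, not the geoutils implementation):

```python
from __future__ import annotations

from typing import Any, TypeVar

GridType = TypeVar("GridType", bound="Grid")


class Grid:
    def __init__(self, origin: tuple[float, float]) -> None:
        self.origin = origin

    @classmethod
    def from_dict(cls: type[GridType], meta: dict[str, Any]) -> GridType:
        # Rebuild an instance of the *calling* class from a metadata dictionary
        return cls(**meta)

    def shift(self: GridType, xoff: float, yoff: float) -> GridType:
        # Return a shifted copy; routing through from_dict preserves the subclass
        return self.from_dict({"origin": (self.origin[0] + xoff, self.origin[1] + yoff)})


class NamedGrid(Grid):
    pass


shifted = NamedGrid((0.0, 0.0)).shift(1.0, -1.0)
assert isinstance(shifted, NamedGrid)  # the subclass survives the round-trip
```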
@@ -638,9 +665,9 @@ def delayed_reproject( dst_shape: tuple[int, int], dst_crs: rio.crs.CRS, resampling: rio.enums.Resampling, - src_nodata: int | float = None, - dst_nodata: int | float = None, - dst_chunksizes: tuple[int, int] = None, + src_nodata: int | float | None = None, + dst_nodata: int | float | None = None, + dst_chunksizes: tuple[int, int] | None = None, **kwargs: Any, ) -> da.Array: """ @@ -679,6 +706,8 @@ def delayed_reproject( src_geotiling = ChunkedGeoGrid(grid=src_geogrid, chunks=src_chunks) # For destination, we need to create the chunks based on destination chunksizes + if dst_chunksizes is None: + dst_chunksizes = darr.chunksize dst_chunks = _chunks2d_from_chunksizes_shape(chunksizes=dst_chunksizes, shape=dst_shape) dst_geotiling = ChunkedGeoGrid(grid=dst_geogrid, chunks=dst_chunks) diff --git a/tests/test_raster/test_delayed.py b/tests/test_raster/test_delayed.py index f0c9c142..e387ad4a 100644 --- a/tests/test_raster/test_delayed.py +++ b/tests/test_raster/test_delayed.py @@ -1,7 +1,7 @@ """Tests for dask-delayed functions.""" import os from tempfile import NamedTemporaryFile -from typing import Callable, Any +from typing import Any, Callable import dask.array as da import numpy as np @@ -9,8 +9,10 @@ import pytest import rasterio as rio import xarray as xr -from pyproj import CRS from dask.distributed import Client, LocalCluster +from dask_memusage import install +from pluggy import PluggyTeardownRaisedWarning +from pyproj import CRS from geoutils.examples import _EXAMPLES_DIRECTORY from geoutils.raster.delayed import ( @@ -19,29 +21,30 @@ delayed_subsample, ) -from pluggy import PluggyTeardownRaisedWarning -from dask_memusage import install - # Ignore teardown warning given by Dask when closing the local cluster (due to dask-memusage plugin) pytestmark = pytest.mark.filterwarnings("ignore", category=PluggyTeardownRaisedWarning) -# Fixture to use a single cluster for the entire module -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") # type: ignore def cluster(): + """Fixture to use a single cluster for the entire module (otherwise raise runtime errors).""" # Need cluster to be single-threaded to use dask-memusage confidently dask_cluster = LocalCluster(n_workers=2, threads_per_worker=1, dashboard_address=None) yield dask_cluster dask_cluster.close() -def _run_dask_measuring_memusage(cluster, dask_func: Callable, *args_dask_func: Any, **kwargs_dask_func: Any) -> tuple[Any, float]: + +def _run_dask_measuring_memusage( + cluster: Any, dask_func: Callable[..., Any], *args_dask_func: Any, **kwargs_dask_func: Any +) -> tuple[Any, float]: """Run a dask function monitoring its memory usage.""" - # Create a name temporary file that won't delete immediatley + # Create a name temporary file that won't delete immediately fn_tmp_csv = NamedTemporaryFile(suffix=".csv", delete=False).name # Setup cluster and client within context managers for a clean shutdown install(cluster.scheduler, fn_tmp_csv) - with Client(cluster) as client: + with Client(cluster) as _: outputs = dask_func(*args_dask_func, **kwargs_dask_func) # Read memusage file and cleanup @@ -56,6 +59,7 @@ def _run_dask_measuring_memusage(cluster, dask_func: Callable, *args_dask_func: return outputs, memusage_mb + def _estimate_subsample_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, int], subsample_size: int) -> float: """ Estimate the theoretical memory usage of the delayed subsampling method. 
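The new `dst_chunksizes is None` fallback reuses the source array's chunk size to build the destination chunks. The snippet below illustrates the resulting chunk tuples with a small standalone helper that mimics what `_chunks2d_from_chunksizes_shape` is documented to do (the last chunk absorbs the remainder); it is a sketch, not the geoutils implementation:

```python
import dask.array as da


def chunks_from_chunksizes_shape(
    chunksizes: tuple[int, int], shape: tuple[int, int]
) -> tuple[tuple[int, ...], ...]:
    # Fixed chunk size along each dimension; the trailing chunk absorbs the remainder
    return tuple(
        tuple(min(size, dim - start) for start in range(0, dim, size))
        for size, dim in zip(chunksizes, shape)
    )


darr = da.zeros((5511, 6768), chunks=(1241, 3221))
dst_chunksizes = None
if dst_chunksizes is None:
    dst_chunksizes = darr.chunksize  # mirrors the fallback added to delayed_reproject

print(chunks_from_chunksizes_shape(dst_chunksizes, (5511, 6768)))
# ((1241, 1241, 1241, 1241, 547), (3221, 3221, 326))
```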
@@ -87,12 +91,13 @@ def _estimate_subsample_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, i meta_memusage = list_per_block + list_all_blocks # Final estimate of memory usage of operation in MB - max_op_memusage = fac_dask_margin * (chunk_memusage + sample_memusage + out_memusage + meta_memusage) / (2 ** 20) + max_op_memusage = fac_dask_margin * (chunk_memusage + sample_memusage + out_memusage + meta_memusage) / (2**20) # We add a base memory usage of ~50 MB + 10MB per 1000 chunks (loaded in background by Dask even on tiny data) max_op_memusage += 50 + 10 * (num_chunks / 1000) return max_op_memusage + def _estimate_interp_points_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, int], ninterp: int) -> float: """ Estimate the theoretical memory usage of the delayed interpolation method. @@ -125,14 +130,19 @@ def _estimate_interp_points_memusage(darr: da.Array, chunksizes_in_mem: tuple[in meta_memusage = list_per_block + list_all_blocks + dict_all_blocks # Final estimate of memory usage of operation in MB - max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2 ** 20) + max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2**20) # We add a base memory usage of ~50 MB + 10MB per 1000 chunks (loaded in background by Dask even on tiny data) max_op_memusage += 50 + 10 * (num_chunks / 1000) return max_op_memusage -def _estimate_reproject_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, int], dst_chunksizes: tuple[int, int], - rel_res_fac: tuple[float, float]) -> float: + +def _estimate_reproject_memusage( + darr: da.Array, + chunksizes_in_mem: tuple[int, int], + dst_chunksizes: tuple[int, int], + rel_res_fac: tuple[float, float], +) -> float: """ Estimate the theoretical memory usage of the delayed reprojection method. 
(we don't need to be super precise, just within a factor of ~2 to check memory usage performs as expected) @@ -153,7 +163,7 @@ def _estimate_reproject_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, i nb_source_chunks_per_dest = 8 * x_rel_source_chunks * y_rel_source_chunks # Combined memory usage of one chunk operation = squared array made from combined chunksize + original chunks - total_nb = np.ceil(np.sqrt(nb_source_chunks_per_dest))**2 + nb_source_chunks_per_dest + total_nb = np.ceil(np.sqrt(nb_source_chunks_per_dest)) ** 2 + nb_source_chunks_per_dest # We multiply the memory usage of a single chunk to the number of loaded/combined chunks chunk_memusage = darr.dtype.itemsize * np.prod(chunksizes_in_mem) * total_nb @@ -170,19 +180,21 @@ def _estimate_reproject_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, i meta_memusage = combined_meta + dict_all_blocks # Final estimate of memory usage of operation in MB - max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2 ** 20) + max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2**20) # We add a base memory usage of ~50 MB + 10MB per 1000 chunks (loaded in background by Dask even on tiny data) max_op_memusage += 50 + 10 * (num_chunks / 1000) return max_op_memusage -def _build_dst_transform_shifted_newres(src_transform: rio.transform.Affine, - src_shape: tuple[int, int], - src_crs: CRS, - dst_crs: CRS, - bounds_rel_shift: tuple[float, float], - res_rel_fac: tuple[float, float]) -> rio.transform.Affine: +def _build_dst_transform_shifted_newres( + src_transform: rio.transform.Affine, + src_shape: tuple[int, int], + src_crs: CRS, + dst_crs: CRS, + bounds_rel_shift: tuple[float, float], + res_rel_fac: tuple[float, float], +) -> rio.transform.Affine: """ Build a destination transform intersecting the source transform given source/destination shapes, and possibly introducing a relative shift in upper-left bound and multiplicative change in resolution. @@ -212,10 +224,12 @@ def _build_dst_transform_shifted_newres(src_transform: rio.transform.Affine, west=tmp_bounds.left + bounds_rel_shift[0] * tmp_res[0] * src_shape[1], north=tmp_bounds.top + 150 * bounds_rel_shift[0] * tmp_res[1] * src_shape[0], xsize=tmp_res[0] / res_rel_fac[0], - ysize=tmp_res[1] / res_rel_fac[1]) + ysize=tmp_res[1] / res_rel_fac[1], + ) return dst_transform + class TestDelayed: """ Testing delayed functions is pretty straightforward. @@ -229,12 +243,11 @@ class TestDelayed: subsample and interp_points, or destination chunksizes to map output of reproject). 2. During execution, we capture memory usage and check that only the expected amount of memory (one or several chunk combinations + metadata) is indeed used during the compute() call. 
- """ + """ # Write big test files on disk out-of-memory, with different input shapes not necessarily aligned between themselves # or with chunks - fn_nc_shape = {"test_square.nc": (10000, 10000), - "test_complex.nc": (5511, 6768)} + fn_nc_shape = {"test_square.nc": (10000, 10000), "test_complex.nc": (5511, 6768)} # We can use a constant value for storage chunks, it doesn't have any influence on the accuracy of delayed methods # (can change slightly RAM usage, but pretty stable as long as chunksizes in memory are larger and # significantly bigger) @@ -254,9 +267,9 @@ class TestDelayed: writer = ds.to_netcdf(fn, encoding=encoding_kwargs, compute=False) writer.compute() - @pytest.mark.parametrize("fn", list_fn) - @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) - @pytest.mark.parametrize("subsample_size", [100, 100000]) + @pytest.mark.parametrize("fn", list_fn) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) # type: ignore + @pytest.mark.parametrize("subsample_size", [100, 100000]) # type: ignore def test_delayed_subsample(self, fn: str, chunksizes_in_mem: tuple[int, int], subsample_size: int, cluster: Any): """ Checks for delayed subsampling function, both for output and memory usage. @@ -272,15 +285,17 @@ def test_delayed_subsample(self, fn: str, chunksizes_in_mem: tuple[int, int], su # 1/ Estimation of theoretical memory usage of the subsampling script - max_op_memusage = _estimate_subsample_memusage(darr=darr, chunksizes_in_mem=chunksizes_in_mem, - subsample_size=subsample_size) + max_op_memusage = _estimate_subsample_memusage( + darr=darr, chunksizes_in_mem=chunksizes_in_mem, subsample_size=subsample_size + ) # 2/ Run delayed subsample with dask memory usage monitoring # Derive subsample from delayed function # (passed to wrapper function to measure memory usage during execution) - sub, measured_op_memusage = _run_dask_measuring_memusage(cluster, delayed_subsample, darr, - subsample=subsample_size, random_state=42) + sub, measured_op_memusage = _run_dask_measuring_memusage( + cluster, delayed_subsample, darr, subsample=subsample_size, random_state=42 + ) # Check the measured memory usage is smaller than the maximum estimated one assert measured_op_memusage < max_op_memusage @@ -297,9 +312,9 @@ def test_delayed_subsample(self, fn: str, chunksizes_in_mem: tuple[int, int], su sub2 = np.array(darr.vindex[indices[0], indices[1]]) assert np.array_equal(sub, sub2) - @pytest.mark.parametrize("fn", list_fn) - @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) - @pytest.mark.parametrize("ninterp", [100, 100000]) + @pytest.mark.parametrize("fn", list_fn) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) # type: ignore + @pytest.mark.parametrize("ninterp", [100, 100000]) # type: ignore def test_delayed_interp_points(self, fn: str, chunksizes_in_mem: tuple[int, int], ninterp: int, cluster: Any): """ Checks for delayed interpolate points function. 
@@ -318,14 +333,15 @@ def test_delayed_interp_points(self, fn: str, chunksizes_in_mem: tuple[int, int] interp_y = rng.choice(ds.y.size, ninterp) + rng.random(ninterp) # 1/ Estimation of theoretical memory usage of the subsampling script - max_op_memusage = _estimate_interp_points_memusage(darr=darr, chunksizes_in_mem=chunksizes_in_mem, - ninterp=ninterp) - + max_op_memusage = _estimate_interp_points_memusage( + darr=darr, chunksizes_in_mem=chunksizes_in_mem, ninterp=ninterp + ) # 2/ Run interpolation of random point coordinates with memory monitoring - interp1, measured_op_memusage = _run_dask_measuring_memusage(cluster, delayed_interp_points, darr, - points=(interp_x, interp_y), resolution=(1, 1)) + interp1, measured_op_memusage = _run_dask_measuring_memusage( + cluster, delayed_interp_points, darr, points=(interp_x, interp_y), resolution=(1, 1) + ) # Check the measured memory usage is smaller than the maximum estimated one assert measured_op_memusage < max_op_memusage @@ -340,19 +356,25 @@ def test_delayed_interp_points(self, fn: str, chunksizes_in_mem: tuple[int, int] assert np.array_equal(interp1, interp2, equal_nan=True) - @pytest.mark.parametrize("fn", list_fn) - @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) - @pytest.mark.parametrize("dst_chunksizes", [(2000, 2000), (1398, 2983)]) + @pytest.mark.parametrize("fn", list_fn) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000), (1241, 3221)]) # type: ignore + @pytest.mark.parametrize("dst_chunksizes", [(2000, 2000), (1398, 2983)]) # type: ignore # Shift upper left corner of output bounds (relative to projected input bounds) by fractions of the raster size - @pytest.mark.parametrize("dst_bounds_rel_shift", [(0, 0), (-0.2, 0.5)]) + @pytest.mark.parametrize("dst_bounds_rel_shift", [(0, 0), (-0.2, 0.5)]) # type: ignore # Modify output resolution (relative to projected input resolution) by a factor - @pytest.mark.parametrize("dst_res_rel_fac", [(1, 1), (2.1, 0.54)]) + @pytest.mark.parametrize("dst_res_rel_fac", [(1, 1), (2.1, 0.54)]) # type: ignore # Same for shape - @pytest.mark.parametrize("dst_shape_diff", [(0, 0), (-28, 117)]) - def test_delayed_reproject(self, fn: str, chunksizes_in_mem: tuple[int, int], - dst_chunksizes: tuple[int, int], dst_bounds_rel_shift: tuple[float, float], - dst_res_rel_fac: tuple[float, float], dst_shape_diff: tuple[int, int], - cluster: Any): + @pytest.mark.parametrize("dst_shape_diff", [(0, 0), (-28, 117)]) # type: ignore + def test_delayed_reproject( + self, + fn: str, + chunksizes_in_mem: tuple[int, int], + dst_chunksizes: tuple[int, int], + dst_bounds_rel_shift: tuple[float, float], + dst_res_rel_fac: tuple[float, float], + dst_shape_diff: tuple[int, int], + cluster: Any, + ): """ Checks for the delayed reproject function. 
Variables that influence specifically the delayed function are: @@ -393,14 +415,20 @@ def test_delayed_reproject(self, fn: str, chunksizes_in_mem: tuple[int, int], resampling = rio.enums.Resampling.bilinear # Get shifted dst_transform with new resolution - dst_transform = _build_dst_transform_shifted_newres(src_transform=src_transform, src_crs=src_crs, dst_crs=dst_crs, - src_shape=src_shape, bounds_rel_shift=dst_bounds_rel_shift, - res_rel_fac=dst_res_rel_fac) + dst_transform = _build_dst_transform_shifted_newres( + src_transform=src_transform, + src_crs=src_crs, + dst_crs=dst_crs, + src_shape=src_shape, + bounds_rel_shift=dst_bounds_rel_shift, + res_rel_fac=dst_res_rel_fac, + ) # 1/ Estimation of theoretical memory usage of the subsampling script - max_op_memusage = _estimate_reproject_memusage(darr, chunksizes_in_mem=chunksizes_in_mem, dst_chunksizes=dst_chunksizes, - rel_res_fac=dst_res_rel_fac) + max_op_memusage = _estimate_reproject_memusage( + darr, chunksizes_in_mem=chunksizes_in_mem, dst_chunksizes=dst_chunksizes, rel_res_fac=dst_res_rel_fac + ) # 2/ Run delayed reproject with memory monitoring @@ -408,7 +436,7 @@ def test_delayed_reproject(self, fn: str, chunksizes_in_mem: tuple[int, int], # (delayed_reproject returns a delayed array that might not fit in memory, unlike subsampling/interpolation) fn_tmp_out = os.path.join(_EXAMPLES_DIRECTORY, os.path.splitext(os.path.basename(fn))[0] + "_reproj.nc") - def reproject_and_write(*args, **kwargs): + def reproject_and_write(*args: Any, **kwargs: Any) -> None: # Run delayed reprojection reproj_arr_tmp = delayed_reproject(*args, **kwargs) @@ -456,20 +484,19 @@ def reproject_and_write(*args, **kwargs): # Keeping this to visualize Rasterio resampling issue # if PLOT: - # import matplotlib.pyplot as plt - # plt.figure() - # plt.imshow((reproj_arr - dst_arr), cmap="RdYlBu", vmin=-0.2, vmax=0.2, interpolation="None") - # plt.colorbar() - # plt.savefig("/home/atom/ongoing/diff_close_zero.png", dpi=500) - # plt.figure() - # plt.imshow(np.abs(reproj_arr - dst_arr), cmap="RdYlBu", vmin=99997, vmax=100001, interpolation="None") - # plt.colorbar() - # plt.savefig("/home/atom/ongoing/diff_nodata.png", dpi=500) - # plt.figure() - # plt.imshow(dst_arr, cmap="RdYlBu", vmin=-1, vmax=1, interpolation="None") - # plt.colorbar() - # plt.savefig("/home/atom/ongoing/dst.png", dpi=500) - + # import matplotlib.pyplot as plt + # plt.figure() + # plt.imshow((reproj_arr - dst_arr), cmap="RdYlBu", vmin=-0.2, vmax=0.2, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/diff_close_zero.png", dpi=500) + # plt.figure() + # plt.imshow(np.abs(reproj_arr - dst_arr), cmap="RdYlBu", vmin=99997, vmax=100001, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/diff_nodata.png", dpi=500) + # plt.figure() + # plt.imshow(dst_arr, cmap="RdYlBu", vmin=-1, vmax=1, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/dst.png", dpi=500) # Due to (what appears to be) Rasterio errors, we have to remain large here: even though some reprojections # are pretty good, some can get a bit nasty
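The `reproject_and_write` wrapper above exists because `delayed_reproject` returns a lazy dask array that may not fit in memory, so the test writes it to disk instead of calling `compute()`. A hedged sketch of that write pattern (stand-in array, illustrative file name and dims):

```python
import dask.array as da
import xarray as xr

# Stand-in for the lazy output of delayed_reproject
reproj_arr = da.zeros((4000, 4000), chunks=(1000, 1000))

# Wrapping in a DataArray and writing to NetCDF streams chunks to disk,
# keeping memory usage bounded by the chunk size rather than the full array
data = xr.DataArray(data=reproj_arr, dims=("y", "x"))
data.to_netcdf("reprojected_tmp.nc")
```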