Linting
rhugonnet committed Apr 26, 2024
1 parent 03e370a commit fafd239
Showing 2 changed files with 158 additions and 102 deletions.
91 changes: 60 additions & 31 deletions in geoutils/raster/delayed.py
@@ -2,7 +2,7 @@
Module for dask-delayed functions for out-of-memory raster operations.
"""
import warnings
-from typing import Any, Literal
+from typing import Any, Literal, TypeVar

import dask.array as da
import dask.delayed
@@ -12,6 +12,7 @@
import rasterio as rio
from scipy.interpolate import interpn

+from geoutils._typing import NDArrayNum
from geoutils.projtools import _get_bounds_projected, _get_footprint_projected

# 1/ SUBSAMPLING
@@ -23,7 +24,9 @@
# usage by having to drop an axis and re-chunk along 1D of the 2D array, so we use the dask.delayed solution instead)


-def _random_state_from_user_input(random_state: np.random.RandomState | int | None = None) -> np.random.RandomState:
+def _random_state_from_user_input(
+    random_state: np.random.RandomState | int | None = None,
+) -> np.random.RandomState | np.random.Generator:
"""Define random state based on varied user input."""

# Define state for random sampling (to fix results during testing)
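Note: the body of _random_state_from_user_input is collapsed in this hunk. As a rough, hypothetical sketch of what a helper with this signature typically does (not the actual implementation), consider:

    import numpy as np

    def random_state_from_user_input(random_state=None):
        # Pass an existing RandomState or Generator through unchanged
        if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            return random_state
        # Otherwise seed a fresh Generator (None draws OS entropy)
        return np.random.default_rng(random_state)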
@@ -61,7 +64,7 @@ def _get_subsample_size_from_user_input(subsample: int | float, total_nb_valids:


def _get_indices_block_per_subsample(
-    xxs: np.ndarray, num_chunks: tuple[int, int], nb_valids_per_block: list[int]
+    xxs: NDArrayNum, num_chunks: tuple[int, int], nb_valids_per_block: list[int]
) -> list[list[int]]:
"""
Get list of 1D valid subsample indices relative to the block for each block.
Expand Down Expand Up @@ -102,25 +105,25 @@ def _get_indices_block_per_subsample(
return relative_ind_per_block


-@dask.delayed
-def _delayed_nb_valids(arr_chunk: np.ndarray) -> np.ndarray:
+@dask.delayed  # type: ignore
+def _delayed_nb_valids(arr_chunk: NDArrayNum) -> NDArrayNum:
"""Count number of valid values per block."""
return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1))


-@dask.delayed
-def _delayed_subsample_block(arr_chunk: np.ndarray, subsample_indices: np.ndarray) -> np.ndarray:
+@dask.delayed  # type: ignore
+def _delayed_subsample_block(arr_chunk: NDArrayNum, subsample_indices: NDArrayNum) -> NDArrayNum:
"""Subsample the valid values at the corresponding 1D valid indices per block."""

s_chunk = arr_chunk[np.isfinite(arr_chunk)][subsample_indices]

return s_chunk


-@dask.delayed
+@dask.delayed  # type: ignore
def _delayed_subsample_indices_block(
-    arr_chunk: np.ndarray, subsample_indices: np.ndarray, block_id: dict[str, Any]
-) -> np.ndarray:
+    arr_chunk: NDArrayNum, subsample_indices: NDArrayNum, block_id: dict[str, Any]
+) -> NDArrayNum:
"""Return 2D indices from the subsampled 1D valid indices per block."""

# Unravel indices of valid data to the shape of the block
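Note: the `# type: ignore` comments added above are the usual workaround for dask.delayed being untyped under mypy, which this linting pass now enforces. A minimal standalone illustration of the per-block pattern these helpers share (hypothetical names, not library code):

    import dask
    import numpy as np

    @dask.delayed  # type: ignore  # the decorator is untyped, so strict mypy flags it
    def count_valids(chunk: np.ndarray) -> np.ndarray:
        # One tiny NumPy result per block, as in _delayed_nb_valids above
        return np.array([np.count_nonzero(np.isfinite(chunk))]).reshape((1, 1))

    lazy = count_valids(np.array([1.0, np.nan, 2.0]))
    print(lazy.compute())  # [[2]]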
@@ -138,7 +141,7 @@ def delayed_subsample(
subsample: int | float = 1,
return_indices: bool = False,
random_state: np.random.RandomState | int | None = None,
-) -> np.ndarray:
+) -> NDArrayNum | tuple[NDArrayNum, NDArrayNum]:
"""
Subsample a raster at valid values on out-of-memory chunks.
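Note: the rest of the signature and docstring are collapsed here. A hedged usage sketch, assuming the elided first positional argument is the dask array to subsample:

    import dask.array as da
    from geoutils.raster.delayed import delayed_subsample

    darr = da.random.random((2000, 2000), chunks=(500, 500))
    # Argument names past the first are taken from the diff above
    sub = delayed_subsample(darr, subsample=100, random_state=42)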
@@ -246,7 +249,15 @@
# Code structure inspired by https://blog.dask.org/2021/07/02/ragged-output and the "block_id" in map_blocks


-def _get_interp_indices_per_block(interp_x, interp_y, starts, num_chunks, chunksize, xres, yres):
+def _get_interp_indices_per_block(
+    interp_x: NDArrayNum,
+    interp_y: NDArrayNum,
+    starts: list[tuple[int, ...]],
+    num_chunks: tuple[int, int],
+    chunksize: tuple[int, int],
+    xres: float,
+    yres: float,
+) -> list[list[int]]:
"""Map blocks where each pair of interpolation coordinates will have to be computed."""

# TODO 1: Check the robustness for chunksizes different in X and Y
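Note: the function body is collapsed in this hunk. The idea is to bin each interpolation coordinate into the chunk that contains it; a hypothetical 1D sketch of that binning (not the elided implementation):

    # Regular axis starting at x0 with pixel size xres, split into chunks of
    # `chunksize` pixels; return the chunk index containing coordinate x.
    def chunk_index_1d(x: float, x0: float, xres: float, chunksize: int) -> int:
        pixel = int((x - x0) / xres)
        return pixel // chunksize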
@@ -267,10 +278,10 @@ def _get_interp_indices_per_block(interp_x, interp_y, starts, num_chunks, chunks
return ind_per_block


-@dask.delayed
+@dask.delayed  # type: ignore
def _delayed_interp_points_block(
-    arr_chunk: np.ndarray, block_id: dict[str, Any], interp_coords: np.ndarray
-) -> np.ndarray:
+    arr_chunk: NDArrayNum, block_id: dict[str, Any], interp_coords: NDArrayNum
+) -> NDArrayNum:
"""
Interpolate block in 2D out-of-memory for a regular or equal grid.
"""
@@ -296,7 +307,7 @@ def delayed_interp_points(
points: tuple[list[float], list[float]],
resolution: tuple[float, float],
method: Literal["nearest", "linear", "cubic", "quintic"] = "linear",
-) -> np.ndarray:
+) -> NDArrayNum:
"""
Interpolate raster at point coordinates on out-of-memory chunks.
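Note: a hedged usage sketch, assuming the elided first argument is the dask array and that point coordinates are expressed in the grid's georeferenced units:

    import dask.array as da
    from geoutils.raster.delayed import delayed_interp_points

    darr = da.random.random((1000, 1000), chunks=(250, 250))
    vals = delayed_interp_points(
        darr, points=([5.5, 375.2], [40.0, 912.7]), resolution=(1.0, 1.0), method="linear"
    )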
@@ -313,8 +324,9 @@
:return: Array of raster value(s) for the given points.
"""

+    # TODO: Replace by a generic 2D point casting function accepting multiple inputs (living outside this function)
# Convert input to 2D array
-    points = np.vstack((points[0], points[1]))
+    points_arr = np.vstack((points[0], points[1]))

# Map depth of overlap required for each interpolation method
# TODO: Double-check this window somewhere in SciPy's documentation
@@ -333,7 +345,7 @@

# Get sample indices per block
ind_per_block = _get_interp_indices_per_block(
-        points[0, :], points[1, :], starts, num_chunks, chunksize, resolution[0], resolution[1]
+        points_arr[0, :], points_arr[1, :], starts, num_chunks, chunksize, resolution[0], resolution[1]
)

# Create a delayed object for each block, and flatten the blocks into a 1d shape
@@ -353,7 +365,7 @@

# Compute values delayed
list_interp = [
-        _delayed_interp_points_block(blocks[i], block_ids[i], points[:, ind_per_block[i]])
+        _delayed_interp_points_block(blocks[i], block_ids[i], points_arr[:, ind_per_block[i]])
for i, data_chunk in enumerate(blocks)
if len(ind_per_block[i]) > 0
]
@@ -380,6 +392,9 @@

# We define a GeoGrid and GeoTiling class (which composes GeoGrid) to consistently deal with georeferenced footprints
# of chunked grids
+GeoGridType = TypeVar("GeoGridType", bound="GeoGrid")
+
+
class GeoGrid:
"""
Georeferenced grid class.
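Note: the GeoGridType TypeVar added above, bound to "GeoGrid", lets from_dict() and shift() below return the calling subclass's type rather than plain GeoGrid. A minimal illustration of the pattern with hypothetical names:

    from typing import TypeVar

    GridType = TypeVar("GridType", bound="Grid")

    class Grid:
        @classmethod
        def make(cls: type[GridType]) -> GridType:
            # Returns the calling class, not hard-coded Grid
            return cls()

    class SubGrid(Grid):
        ...

    sg: SubGrid = SubGrid.make()  # type-checks: GridType binds to SubGrid here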
@@ -430,13 +445,23 @@ def bounds_projected(self, crs: rio.crs.CRS = None) -> rio.coords.BoundingBox:
def footprint(self) -> gpd.GeoDataFrame:
return _get_footprint_projected(self.bounds, in_crs=self.crs, out_crs=self.crs, densify_points=100)

-    def footprint_projected(self, crs: rio.crs.CRS = None):
+    def footprint_projected(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame:
if crs is None:
crs = self.crs
return _get_footprint_projected(self.bounds, in_crs=self.crs, out_crs=crs, densify_points=100)

-    def shift(self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"] | Literal["pixel"] = "pixel"):
-        """Shift geogrid, not inplace."""
+    @classmethod
+    def from_dict(cls: type[GeoGridType], dict_meta: dict[str, Any]) -> GeoGridType:
+        """Create a GeoGrid from a dictionary containing transform, shape and CRS."""
+        return cls(**dict_meta)
+
+    def shift(
+        self: GeoGridType,
+        xoff: float,
+        yoff: float,
+        distance_unit: Literal["georeferenced"] | Literal["pixel"] = "pixel",
+    ) -> GeoGridType:
+        """Shift into a new geogrid (not inplace)."""

if distance_unit not in ["georeferenced", "pixel"]:
raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.")
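Note: a hedged usage sketch of the new shift(); the constructor arguments are inferred from from_dict() and the GeoGrid(...) call below, and the values are made up:

    import rasterio as rio
    from geoutils.raster.delayed import GeoGrid

    grid = GeoGrid(
        transform=rio.transform.from_origin(0.0, 100.0, 1.0, 1.0),
        crs=rio.crs.CRS.from_epsg(4326),
        shape=(100, 100),
    )
    # Returns a new, shifted GeoGrid (or subclass instance) rather than mutating
    shifted = grid.shift(xoff=2, yoff=-3, distance_unit="pixel")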
@@ -456,7 +481,7 @@ def shift(self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"

shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff)

-        return GeoGrid(transform=shifted_transform, crs=self.crs, shape=self.shape)
+        return self.from_dict({"transform": shifted_transform, "crs": self.crs, "shape": self.shape})


def _get_block_ids_per_chunk(chunks: tuple[tuple[int, ...], tuple[int, ...]]) -> list[dict[str, int]]:
@@ -536,7 +561,9 @@ def get_block_footprints(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame:
return pd.concat(footprints)


-def _chunks2d_from_chunksizes_shape(chunksizes: tuple[int, int], shape: tuple[int, int]):
+def _chunks2d_from_chunksizes_shape(
+    chunksizes: tuple[int, int], shape: tuple[int, int]
+) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Get tuples of chunk sizes for X/Y dimensions based on chunksizes and array shape."""

# Chunksize is fixed, except for the last chunk depending on the shape
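Note: a 1D sketch of the rule stated in the comment above (fixed chunksize, with one smaller trailing chunk), which the function presumably applies to each axis:

    def chunks_1d(chunksize: int, length: int) -> tuple[int, ...]:
        # Full chunks, then one remainder chunk if the size does not divide evenly
        n_full, rem = divmod(length, chunksize)
        return (chunksize,) * n_full + ((rem,) if rem else ())

    print(chunks_1d(250, 1000))  # (250, 250, 250, 250)
    print(chunks_1d(300, 1000))  # (300, 300, 300, 100)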
@@ -583,10 +610,10 @@ def _combined_blocks_shape_transform(
return combined_meta, relative_block_indexes


-@dask.delayed
+@dask.delayed  # type: ignore
def _delayed_reproject_per_block(
-    *src_arrs: tuple[np.ndarray], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any
-) -> np.ndarray:
+    *src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any
+) -> NDArrayNum:
"""
Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks).
"""
@@ -604,7 +631,7 @@ def _delayed_reproject_per_block(
# Then fill it with the source chunks values
for i, arr in enumerate(src_arrs):
bid = block_ids[i]
-        comb_src_arr[bid["rys"]:bid["rye"], bid["rxs"]:bid["rxe"]] = arr
+        comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr

# Now, we can simply call Rasterio!
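Note: the change above only adds black's slice-formatting spaces; behavior is identical. A toy illustration of the fill step, mirroring the bid keys from the diff:

    import numpy as np

    # Paste one source chunk into its relative window of the combined array
    comb_src_arr = np.full((4, 4), np.nan)
    src = np.ones((2, 2))
    bid = {"rys": 0, "rye": 2, "rxs": 2, "rxe": 4}
    comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = src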

@@ -638,9 +665,9 @@ def delayed_reproject(
dst_shape: tuple[int, int],
dst_crs: rio.crs.CRS,
resampling: rio.enums.Resampling,
-    src_nodata: int | float = None,
-    dst_nodata: int | float = None,
-    dst_chunksizes: tuple[int, int] = None,
+    src_nodata: int | float | None = None,
+    dst_nodata: int | float | None = None,
+    dst_chunksizes: tuple[int, int] | None = None,
**kwargs: Any,
) -> da.Array:
"""
@@ -679,6 +706,8 @@
src_geotiling = ChunkedGeoGrid(grid=src_geogrid, chunks=src_chunks)

# For destination, we need to create the chunks based on destination chunksizes
+    if dst_chunksizes is None:
+        dst_chunksizes = darr.chunksize
dst_chunks = _chunks2d_from_chunksizes_shape(chunksizes=dst_chunksizes, shape=dst_shape)
dst_geotiling = ChunkedGeoGrid(grid=dst_geogrid, chunks=dst_chunks)

(Second changed file not rendered on this page: the remaining 98 additions and 71 deletions.)
