diff --git a/xdem/coreg.py b/xdem/coreg.py index 49f5c088..4af8f296 100644 --- a/xdem/coreg.py +++ b/xdem/coreg.py @@ -14,6 +14,7 @@ _has_cv2 = False import fiona import geoutils as gu +import matplotlib.pyplot as plt import numpy as np import pandas as pd import rasterio as rio @@ -25,6 +26,7 @@ import scipy.optimize import skimage.transform from geoutils import spatial_tools +from geoutils._typing import AnyNumber from geoutils.georaster import RasterType from rasterio import Affine from tqdm import tqdm, trange @@ -1483,6 +1485,7 @@ def apply_matrix( centroid: tuple[float, float, float] | None = None, resampling: int | str = "bilinear", dilate_mask: bool = False, + fill_max_search: int = 0, ) -> NDArrayf: """ Apply a 3D transformation matrix to a 2.5D DEM. @@ -1503,7 +1506,10 @@ def apply_matrix( :param invert: Invert the transformation matrix. :param centroid: The X/Y/Z transformation centroid. Irrelevant for pure translations. Defaults to the midpoint (Z=0) :param resampling: The resampling method to use. Can be `nearest`, `bilinear`, `cubic` or an integer from 0-5. - :param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong. + :param dilate_mask: DEPRECATED - This option does not do anything anymore. Will be removed in the future. + :param fill_max_search: Set to > 0 value to fill the DEM before applying the transformation, to avoid spreading\ + gaps. The DEM will be filled with rasterio.fill.fillnodata with max_search_distance set to fill_max_search.\ + This is experimental, use at your own risk ! :returns: The transformed DEM with NaNs as nodata values (replaces a potential mask of the input `dem`). """ @@ -1536,8 +1542,11 @@ def apply_matrix( nan_mask = spatial_tools.get_mask(dem) assert np.count_nonzero(~nan_mask) > 0, "Given DEM had all nans." - # Create a filled version of the DEM. (skimage doesn't like nans) - filled_dem = np.where(~nan_mask, demc, np.nan) + # Optionally, fill DEM around gaps to reduce spread of gaps + if fill_max_search > 0: + filled_dem = rio.fill.fillnodata(demc, mask=(~nan_mask).astype("uint8"), max_search_distance=fill_max_search) + else: + filled_dem = demc # np.where(~nan_mask, demc, np.nan) # I don't know why this was needed - to delete # Get the centre coordinates of the DEM pixels. x_coords, y_coords = _get_x_and_y_coords(demc.shape, transform) @@ -1579,9 +1588,11 @@ def apply_matrix( # Shift the elevation values of the soon-to-be-warped DEM. filled_dem -= deramp(x_coords, y_coords) - # Create gap-free arrays of x and y coordinates to be converted into index coordinates. - x_inds = rio.fill.fillnodata(transformed_points[:, :, 0].copy(), mask=(~nan_mask).astype("uint8")) - y_inds = rio.fill.fillnodata(transformed_points[:, :, 1].copy(), mask=(~nan_mask).astype("uint8")) + # Create arrays of x and y coordinates to be converted into index coordinates. + x_inds = transformed_points[:, :, 0].copy() + x_inds[x_inds == 0] = np.nan + y_inds = transformed_points[:, :, 1].copy() + y_inds[y_inds == 0] = np.nan # Divide the coordinates by the resolution to create index coordinates. x_inds /= resolution @@ -1601,19 +1612,20 @@ def apply_matrix( transformed_dem = skimage.transform.warp( filled_dem, inds, order=resampling_order, mode="constant", cval=np.nan, preserve_range=True ) - # Warp the NaN mask, setting true to all values outside the new frame. - tr_nan_mask = ( - skimage.transform.warp( - nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True - ) - > 0 - ) + # TODO: remove these lines when dilate_mask is deprecated + # # Warp the NaN mask, setting true to all values outside the new frame. + # tr_nan_mask = ( + # skimage.transform.warp( + # nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True + # ) + # > 0 + # ) - if dilate_mask: - tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order) + # if dilate_mask: + # tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order) - # Apply the transformed nan_mask - transformed_dem[tr_nan_mask] = np.nan + # # Apply the transformed nan_mask + # transformed_dem[tr_nan_mask] = np.nan assert np.count_nonzero(~np.isnan(transformed_dem)) > 0, "Transformed DEM has all nans." @@ -2141,3 +2153,178 @@ def warp_dem( assert not np.all(np.isnan(warped)), "All-NaN output." return warped.reshape(dem.shape) + + +hmodes_dict = { + "nuth_kaab": NuthKaab(), + "nuth_kaab_block": BlockwiseCoreg(coreg=NuthKaab(), subdivision=16), + "icp": ICP(), +} + +vmodes_dict = { + "median": BiasCorr(bias_func=np.median), + "mean": BiasCorr(bias_func=np.mean), + "deramp": Deramp(), +} + + +def dem_coregistration( + src_dem_path: str, + ref_dem_path: str, + out_dem_path: str | None = None, + shpfile: str | None = None, + coreg_method: Coreg | None = None, + hmode: str = "nuth_kaab", + vmode: str = "median", + deramp_degree: int = 1, + grid: str = "ref", + filtering: bool = True, + slope_lim: list[AnyNumber] | tuple[AnyNumber, AnyNumber] = (0.1, 40), + plot: bool = False, + out_fig: str = None, + verbose: bool = False, +) -> tuple[xdem.DEM, pd.DataFrame]: + """ + A one-line function to coregister a selected DEM to a reference DEM. + Reads both DEMs, reprojects them on the same grid, mask content of shpfile, filter steep slopes and outliers, \ +run the coregistration, returns the coregistered DEM and some statistics. + Optionally, save the coregistered DEM to file and make a figure. + + :param src_dem_path: path to the input DEM to be coregistered + :param ref_dem: path to the reference DEM + :param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file. + :param shpfile: path to a vector file containing areas to be masked for coregistration + :param coreg_method: The xdem coregistration method, or pipeline. If set to None, DEMs will be resampled to \ +ref grid and optionally filtered, but not coregistered. Will be used in priority over hmode and vmode. + :param hmode: The method to be used for horizontally aligning the DEMs, e.g. Nuth & Kaab or ICP. Can be any \ +of {list(vmodes_dict.keys())}. + :param vmode: The method to be used for vertically aligning the DEMs, e.g. mean/median bias correction or \ +deramping. Can be any of {list(hmodes_dict.keys())}. + :param deramp_degree: The degree of the polynomial for deramping. + :param grid: the grid to be used during coregistration, set either to "ref" or "src". + :param filtering: if set to True, filtering will be applied prior to coregistration + :param plot: Set to True to plot a figure of elevation diff before/after coregistration + :param out_fig: Path to the output figure. If None will display to screen. + :param verbose: set to True to print details on screen during coregistration. + + :returns: a tuple containing 1) coregistered DEM as an xdem.DEM instance and 2) DataFrame of coregistration \ +statistics (count of obs, median and NMAD over stable terrain) before and after coreg. + """ + # Check input arguments + if (coreg_method is not None) and ((hmode is not None) or (vmode is not None)): + warnings.warn("Both `coreg_method` and `hmode/vmode` are set. Using coreg_method.") + + if hmode not in list(hmodes_dict.keys()): + raise ValueError(f"vhmode must be in {list(hmodes_dict.keys())}") + + if vmode not in list(vmodes_dict.keys()): + raise ValueError(f"vmode must be in {list(vmodes_dict.keys())}") + + # Load both DEMs + if verbose: + print("Loading and reprojecting input data") + if grid == "ref": + ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0) + elif grid == "src": + ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1) + else: + raise ValueError(f"`grid` must be either 'ref' or 'src' - currently set to {grid}") + + # Convert to DEM instance with Float32 dtype + ref_dem = xdem.DEM(ref_dem.astype(np.float32)) + src_dem = xdem.DEM(src_dem.astype(np.float32)) + + # Create raster mask + if shpfile is not None: + outlines = gu.Vector(shpfile) + stable_mask = ~outlines.create_mask(src_dem) + else: + stable_mask = np.ones(src_dem.data.shape, dtype="bool") + + # Calculate dDEM + ddem = src_dem - ref_dem + + # Filter gross outliers in stable terrain + if filtering: + # Remove gross blunders where dh differ by 5 NMAD from the median + inlier_mask = stable_mask & (np.abs(ddem.data - np.median(ddem)) < 5 * xdem.spatialstats.nmad(ddem)).filled( + False + ) + + # Exclude steep slopes for coreg + slope = xdem.terrain.slope(ref_dem) + inlier_mask[slope.data < slope_lim[0]] = False + inlier_mask[slope.data > slope_lim[1]] = False + + else: + inlier_mask = stable_mask + + # Calculate dDEM statistics on pixels used for coreg + inlier_data = ddem.data[inlier_mask].compressed() + nstable_orig, mean_orig = len(inlier_data), np.mean(inlier_data) + med_orig, nmad_orig = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data) + + # Coregister to reference - Note: this will spread NaN + # Better strategy: calculate shift, update transform, resample + if isinstance(coreg_method, xdem.coreg.Coreg): + coreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose) + dem_coreg = coreg_method.apply(src_dem, dilate_mask=False) + elif coreg_method is None: + # Horizontal coregistration + hcoreg_method = hmodes_dict[hmode] + hcoreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose) + dem_hcoreg = hcoreg_method.apply(src_dem, dilate_mask=False) + + # Vertical coregistration + vcoreg_method = vmodes_dict[vmode] + if vmode == "deramp": + vcoreg_method.degree = deramp_degree + vcoreg_method.fit(ref_dem, dem_hcoreg, inlier_mask, verbose=verbose) + dem_coreg = vcoreg_method.apply(dem_hcoreg, dilate_mask=False) + + ddem_coreg = dem_coreg - ref_dem + + # Calculate new stats + inlier_data = ddem_coreg.data[inlier_mask].compressed() + nstable_coreg, mean_coreg = len(inlier_data), np.mean(inlier_data) + med_coreg, nmad_coreg = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data) + + # Plot results + if plot: + # Max colorbar value - 98th percentile rounded to nearest 5 + vmax = np.percentile(np.abs(ddem.data.compressed()), 98) // 5 * 5 + + plt.figure(figsize=(11, 5)) + + ax1 = plt.subplot(121) + plt.imshow(ddem.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax) + cb = plt.colorbar() + cb.set_label("Elevation change (m)") + ax1.set_title(f"Before coreg\n\nmean = {mean_orig:.1f} m - med = {med_orig:.1f} m - NMAD = {nmad_orig:.1f} m") + + ax2 = plt.subplot(122, sharex=ax1, sharey=ax1) + plt.imshow(ddem_coreg.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax) + cb = plt.colorbar() + cb.set_label("Elevation change (m)") + ax2.set_title( + f"After coreg\n\n\nmean = {mean_coreg:.1f} m - med = {med_coreg:.1f} m - NMAD = {nmad_coreg:.1f} m" + ) + + plt.tight_layout() + if out_fig is None: + plt.show() + else: + plt.savefig(out_fig, dpi=200) + plt.close() + + # Save coregistered DEM + if out_dem_path is not None: + dem_coreg.save(out_dem_path, tiled=True) + + # Save stats to DataFrame + out_stats = pd.DataFrame( + ((nstable_orig, med_orig, nmad_orig, nstable_coreg, med_coreg, nmad_coreg),), + columns=("nstable_orig", "med_orig", "nmad_orig", "nstable_coreg", "med_coreg", "nmad_coreg"), + ) + + return dem_coreg, out_stats