Skip to content

Commit

Permalink
Implement one-liner coregistration (#267)
Browse files Browse the repository at this point in the history
* Remove unused imports

* Include a one-liner function for DEM coregistration

* Split coreg into horizontal/vertical coreg. Update stats and plots

* Minor comment edit

* Fix issue with edge artifacts brought up in #314

* Add an experimatal function to slightly extrapolate DEM prior to applying transformation

* Load DEMs only in area of overlap using new geoutils functionalities.

* Linting and mypy

* Upgrade pre-commit

* Improve docstring. Make output file optional.

* Add a 5 NMAD filter to the filtering option.
  • Loading branch information
adehecq authored Oct 18, 2022
1 parent e878d78 commit 792a486
Showing 1 changed file with 204 additions and 17 deletions.
221 changes: 204 additions & 17 deletions xdem/coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_has_cv2 = False
import fiona
import geoutils as gu
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio as rio
Expand All @@ -25,6 +26,7 @@
import scipy.optimize
import skimage.transform
from geoutils import spatial_tools
from geoutils._typing import AnyNumber
from geoutils.georaster import RasterType
from rasterio import Affine
from tqdm import tqdm, trange
Expand Down Expand Up @@ -1483,6 +1485,7 @@ def apply_matrix(
centroid: tuple[float, float, float] | None = None,
resampling: int | str = "bilinear",
dilate_mask: bool = False,
fill_max_search: int = 0,
) -> NDArrayf:
"""
Apply a 3D transformation matrix to a 2.5D DEM.
Expand All @@ -1503,7 +1506,10 @@ def apply_matrix(
:param invert: Invert the transformation matrix.
:param centroid: The X/Y/Z transformation centroid. Irrelevant for pure translations. Defaults to the midpoint (Z=0)
:param resampling: The resampling method to use. Can be `nearest`, `bilinear`, `cubic` or an integer from 0-5.
:param dilate_mask: Dilate the nan mask to exclude edge pixels that could be wrong.
:param dilate_mask: DEPRECATED - This option does not do anything anymore. Will be removed in the future.
:param fill_max_search: Set to > 0 value to fill the DEM before applying the transformation, to avoid spreading\
gaps. The DEM will be filled with rasterio.fill.fillnodata with max_search_distance set to fill_max_search.\
This is experimental, use at your own risk !
:returns: The transformed DEM with NaNs as nodata values (replaces a potential mask of the input `dem`).
"""
Expand Down Expand Up @@ -1536,8 +1542,11 @@ def apply_matrix(

nan_mask = spatial_tools.get_mask(dem)
assert np.count_nonzero(~nan_mask) > 0, "Given DEM had all nans."
# Create a filled version of the DEM. (skimage doesn't like nans)
filled_dem = np.where(~nan_mask, demc, np.nan)
# Optionally, fill DEM around gaps to reduce spread of gaps
if fill_max_search > 0:
filled_dem = rio.fill.fillnodata(demc, mask=(~nan_mask).astype("uint8"), max_search_distance=fill_max_search)
else:
filled_dem = demc # np.where(~nan_mask, demc, np.nan) # I don't know why this was needed - to delete

# Get the centre coordinates of the DEM pixels.
x_coords, y_coords = _get_x_and_y_coords(demc.shape, transform)
Expand Down Expand Up @@ -1579,9 +1588,11 @@ def apply_matrix(
# Shift the elevation values of the soon-to-be-warped DEM.
filled_dem -= deramp(x_coords, y_coords)

# Create gap-free arrays of x and y coordinates to be converted into index coordinates.
x_inds = rio.fill.fillnodata(transformed_points[:, :, 0].copy(), mask=(~nan_mask).astype("uint8"))
y_inds = rio.fill.fillnodata(transformed_points[:, :, 1].copy(), mask=(~nan_mask).astype("uint8"))
# Create arrays of x and y coordinates to be converted into index coordinates.
x_inds = transformed_points[:, :, 0].copy()
x_inds[x_inds == 0] = np.nan
y_inds = transformed_points[:, :, 1].copy()
y_inds[y_inds == 0] = np.nan

# Divide the coordinates by the resolution to create index coordinates.
x_inds /= resolution
Expand All @@ -1601,19 +1612,20 @@ def apply_matrix(
transformed_dem = skimage.transform.warp(
filled_dem, inds, order=resampling_order, mode="constant", cval=np.nan, preserve_range=True
)
# Warp the NaN mask, setting true to all values outside the new frame.
tr_nan_mask = (
skimage.transform.warp(
nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True
)
> 0
)
# TODO: remove these lines when dilate_mask is deprecated
# # Warp the NaN mask, setting true to all values outside the new frame.
# tr_nan_mask = (
# skimage.transform.warp(
# nan_mask.astype("uint8"), inds, order=resampling_order, mode="constant", cval=1, preserve_range=True
# )
# > 0
# )

if dilate_mask:
tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order)
# if dilate_mask:
# tr_nan_mask = scipy.ndimage.binary_dilation(tr_nan_mask, iterations=resampling_order)

# Apply the transformed nan_mask
transformed_dem[tr_nan_mask] = np.nan
# # Apply the transformed nan_mask
# transformed_dem[tr_nan_mask] = np.nan

assert np.count_nonzero(~np.isnan(transformed_dem)) > 0, "Transformed DEM has all nans."

Expand Down Expand Up @@ -2141,3 +2153,178 @@ def warp_dem(
assert not np.all(np.isnan(warped)), "All-NaN output."

return warped.reshape(dem.shape)


hmodes_dict = {
"nuth_kaab": NuthKaab(),
"nuth_kaab_block": BlockwiseCoreg(coreg=NuthKaab(), subdivision=16),
"icp": ICP(),
}

vmodes_dict = {
"median": BiasCorr(bias_func=np.median),
"mean": BiasCorr(bias_func=np.mean),
"deramp": Deramp(),
}


def dem_coregistration(
src_dem_path: str,
ref_dem_path: str,
out_dem_path: str | None = None,
shpfile: str | None = None,
coreg_method: Coreg | None = None,
hmode: str = "nuth_kaab",
vmode: str = "median",
deramp_degree: int = 1,
grid: str = "ref",
filtering: bool = True,
slope_lim: list[AnyNumber] | tuple[AnyNumber, AnyNumber] = (0.1, 40),
plot: bool = False,
out_fig: str = None,
verbose: bool = False,
) -> tuple[xdem.DEM, pd.DataFrame]:
"""
A one-line function to coregister a selected DEM to a reference DEM.
Reads both DEMs, reprojects them on the same grid, mask content of shpfile, filter steep slopes and outliers, \
run the coregistration, returns the coregistered DEM and some statistics.
Optionally, save the coregistered DEM to file and make a figure.
:param src_dem_path: path to the input DEM to be coregistered
:param ref_dem: path to the reference DEM
:param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file.
:param shpfile: path to a vector file containing areas to be masked for coregistration
:param coreg_method: The xdem coregistration method, or pipeline. If set to None, DEMs will be resampled to \
ref grid and optionally filtered, but not coregistered. Will be used in priority over hmode and vmode.
:param hmode: The method to be used for horizontally aligning the DEMs, e.g. Nuth & Kaab or ICP. Can be any \
of {list(vmodes_dict.keys())}.
:param vmode: The method to be used for vertically aligning the DEMs, e.g. mean/median bias correction or \
deramping. Can be any of {list(hmodes_dict.keys())}.
:param deramp_degree: The degree of the polynomial for deramping.
:param grid: the grid to be used during coregistration, set either to "ref" or "src".
:param filtering: if set to True, filtering will be applied prior to coregistration
:param plot: Set to True to plot a figure of elevation diff before/after coregistration
:param out_fig: Path to the output figure. If None will display to screen.
:param verbose: set to True to print details on screen during coregistration.
:returns: a tuple containing 1) coregistered DEM as an xdem.DEM instance and 2) DataFrame of coregistration \
statistics (count of obs, median and NMAD over stable terrain) before and after coreg.
"""
# Check input arguments
if (coreg_method is not None) and ((hmode is not None) or (vmode is not None)):
warnings.warn("Both `coreg_method` and `hmode/vmode` are set. Using coreg_method.")

if hmode not in list(hmodes_dict.keys()):
raise ValueError(f"vhmode must be in {list(hmodes_dict.keys())}")

if vmode not in list(vmodes_dict.keys()):
raise ValueError(f"vmode must be in {list(vmodes_dict.keys())}")

# Load both DEMs
if verbose:
print("Loading and reprojecting input data")
if grid == "ref":
ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0)
elif grid == "src":
ref_dem, src_dem = gu.spatial_tools.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1)
else:
raise ValueError(f"`grid` must be either 'ref' or 'src' - currently set to {grid}")

# Convert to DEM instance with Float32 dtype
ref_dem = xdem.DEM(ref_dem.astype(np.float32))
src_dem = xdem.DEM(src_dem.astype(np.float32))

# Create raster mask
if shpfile is not None:
outlines = gu.Vector(shpfile)
stable_mask = ~outlines.create_mask(src_dem)
else:
stable_mask = np.ones(src_dem.data.shape, dtype="bool")

# Calculate dDEM
ddem = src_dem - ref_dem

# Filter gross outliers in stable terrain
if filtering:
# Remove gross blunders where dh differ by 5 NMAD from the median
inlier_mask = stable_mask & (np.abs(ddem.data - np.median(ddem)) < 5 * xdem.spatialstats.nmad(ddem)).filled(
False
)

# Exclude steep slopes for coreg
slope = xdem.terrain.slope(ref_dem)
inlier_mask[slope.data < slope_lim[0]] = False
inlier_mask[slope.data > slope_lim[1]] = False

else:
inlier_mask = stable_mask

# Calculate dDEM statistics on pixels used for coreg
inlier_data = ddem.data[inlier_mask].compressed()
nstable_orig, mean_orig = len(inlier_data), np.mean(inlier_data)
med_orig, nmad_orig = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data)

# Coregister to reference - Note: this will spread NaN
# Better strategy: calculate shift, update transform, resample
if isinstance(coreg_method, xdem.coreg.Coreg):
coreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose)
dem_coreg = coreg_method.apply(src_dem, dilate_mask=False)
elif coreg_method is None:
# Horizontal coregistration
hcoreg_method = hmodes_dict[hmode]
hcoreg_method.fit(ref_dem, src_dem, inlier_mask, verbose=verbose)
dem_hcoreg = hcoreg_method.apply(src_dem, dilate_mask=False)

# Vertical coregistration
vcoreg_method = vmodes_dict[vmode]
if vmode == "deramp":
vcoreg_method.degree = deramp_degree
vcoreg_method.fit(ref_dem, dem_hcoreg, inlier_mask, verbose=verbose)
dem_coreg = vcoreg_method.apply(dem_hcoreg, dilate_mask=False)

ddem_coreg = dem_coreg - ref_dem

# Calculate new stats
inlier_data = ddem_coreg.data[inlier_mask].compressed()
nstable_coreg, mean_coreg = len(inlier_data), np.mean(inlier_data)
med_coreg, nmad_coreg = np.median(inlier_data), xdem.spatialstats.nmad(inlier_data)

# Plot results
if plot:
# Max colorbar value - 98th percentile rounded to nearest 5
vmax = np.percentile(np.abs(ddem.data.compressed()), 98) // 5 * 5

plt.figure(figsize=(11, 5))

ax1 = plt.subplot(121)
plt.imshow(ddem.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax)
cb = plt.colorbar()
cb.set_label("Elevation change (m)")
ax1.set_title(f"Before coreg\n\nmean = {mean_orig:.1f} m - med = {med_orig:.1f} m - NMAD = {nmad_orig:.1f} m")

ax2 = plt.subplot(122, sharex=ax1, sharey=ax1)
plt.imshow(ddem_coreg.data.squeeze(), cmap="coolwarm_r", vmin=-vmax, vmax=vmax)
cb = plt.colorbar()
cb.set_label("Elevation change (m)")
ax2.set_title(
f"After coreg\n\n\nmean = {mean_coreg:.1f} m - med = {med_coreg:.1f} m - NMAD = {nmad_coreg:.1f} m"
)

plt.tight_layout()
if out_fig is None:
plt.show()
else:
plt.savefig(out_fig, dpi=200)
plt.close()

# Save coregistered DEM
if out_dem_path is not None:
dem_coreg.save(out_dem_path, tiled=True)

# Save stats to DataFrame
out_stats = pd.DataFrame(
((nstable_orig, med_orig, nmad_orig, nstable_coreg, med_coreg, nmad_coreg),),
columns=("nstable_orig", "med_orig", "nmad_orig", "nstable_coreg", "med_coreg", "nmad_coreg"),
)

return dem_coreg, out_stats

0 comments on commit 792a486

Please sign in to comment.