diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index a01ed3b..5677c19 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -64,6 +64,79 @@ def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor: return u.add(v) +def compose_svfs( + u: Tensor, + v: Tensor, + mode: Optional[str] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, + stride: Optional[ScalarOrTuple[int]] = None, + bch_terms: int = 4, +) -> Tensor: + r"""Approximate stationary velocity field (SVF) of composite deformation. + + The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map + of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the + `Baker-Campbell-Hausdorff (BCH) formula `_. + + References: + - Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach. + doi:10.1007/978-3-540-85988-8_90 + + Args: + u: First applied stationary velocity field as tensor of shape ``(N, D, ..., X)``. + v: Second applied stationary velocity field as tensor of shape ``(N, D, ..., X)``. + bch_terms: Number of terms of the BCH formula to consider. Must be at least 2. + When 2, the returned velocity field is the sum of ``u`` and ``v``. + This approximation is accurate if the input velocity fields commute, i.e., + the Lie bracket [v, u] = 0. When ``bch_terms=3``, the approximation is given by + ``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first), + and when ``bch_terms=4``, it is ``w = v + u + 1/2 [v, u] + 1/12 [v, [v, u]]``. + mode: Mode of :func:`flow_derivatives()` approximation. + sigma: Standard deviation of Gaussian used for computing spatial derivatives. + spacing: Physical size of image voxels used to compute spatial derivatives. + stride: Number of output grid points between control points plus one for ``mode='bspline'``. + + Returns: + Approximation of BCH formula as tensor of shape ``(N, D, ..., X)``. + + """ + + def lb(a: Tensor, b: Tensor) -> Tensor: + return lie_bracket(a, b, mode=mode, sigma=sigma, spacing=spacing, stride=stride) + + for name, flow in [("u", u), ("v", v)]: + if flow.ndim < 4: + raise ValueError( + f"compose_svfs() '{name}' must be vector field of shape (N, D, ..., X)" + ) + if flow.shape[1] != flow.ndim - 2: + raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)") + if u.shape != v.shape: + raise ValueError("compose_svfs() 'u' and 'v' must have the same shape") + if bch_terms < 2: + raise ValueError("compose_svfs() 'bch_terms' must be at least 2") + elif bch_terms > 6: + raise NotImplementedError("compose_svfs() 'bch_terms' of more than 6 not implemented") + + w = v.add(u) + if bch_terms >= 3: + vu = lb(v, u) + w = w.add(vu.mul(0.5)) + if bch_terms >= 4: + vvu = lb(v, vu) + w = w.add(vvu.mul(1 / 12)) + if bch_terms >= 5: + uv = lb(u, v) + uuv = lb(u, uv) + w = w.add(uuv.mul(1 / 12)) + if bch_terms >= 6: + uvvu = lb(u, vvu) + w = w.sub(uvvu.mul(1 / 24)) + + return w + + def curl( flow: Tensor, mode: Optional[str] = None, @@ -508,6 +581,71 @@ def jacobian_matrix( return jac.contiguous() +def lie_bracket( + v: Tensor, + u: Tensor, + mode: Optional[str] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, + stride: Optional[ScalarOrTuple[int]] = None, +) -> Tensor: + r"""Lie bracket of two vector fields. + + Evaluate Lie bracket given by ``[v, u] = Jac(v) * u - Jac(u) * v`` as defined in Eq (6) + of Vercauteren et al. (2008). + + Most authors define the Lie bracket as the opposite of (6). Numerical simulations, + and personal communication with M. Bossa, showed the relevance of this definition. + Future research will aim at fully understanding the reason of this discrepancy. + + References: + - Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach. + doi:10.1007/978-3-540-85988-8_90 + + Args: + u: Left vector field as tensor of shape ``(N, D, ..., X)``. + v: Right vector field as tensor of shape ``(N, D, ..., X)``. + mode: Mode of :func:`flow_derivatives()` approximation. + sigma: Standard deviation of Gaussian used for computing spatial derivatives. + spacing: Physical size of image voxels used to compute spatial derivatives. + stride: Number of output grid points between control points plus one for ``mode='bspline'``. + + Returns: + Lie bracket of vector fields as tensor of shape ``(N, D, ..., X)``. + + """ + for name, flow in [("u", u), ("v", v)]: + if flow.ndim < 4: + raise ValueError(f"lie_bracket() '{name}' must be vector field of shape (N, D, ..., X)") + if flow.shape[1] != flow.ndim - 2: + raise ValueError(f"lie_bracket() '{name}' must have shape (N, D, ..., X)") + if u.shape != v.shape: + raise ValueError("lie_bracket() 'u' and 'v' must have the same shape") + jac_u = jacobian_dict( + u, + mode=mode, + sigma=sigma, + spacing=spacing, + stride=stride, + ) + jac_v = jacobian_dict( + v, + mode=mode, + sigma=sigma, + spacing=spacing, + stride=stride, + ) + D = flow.ndim - 2 + w = torch.zeros_like(u) + for i in range(D): + w_i = w.narrow(1, i, 1) + for j in range(D): + w_i = w_i.add_(jac_v[(i, j)].mul(u.narrow(1, j, 1))) + for j in range(D): + w_i = w_i.sub_(jac_u[(i, j)].mul(v.narrow(1, j, 1))) + return w + + def normalize_flow( data: Tensor, size: Optional[Union[Tensor, torch.Size]] = None, diff --git a/src/deepali/core/functional.py b/src/deepali/core/functional.py index 6eb6add..31c56a7 100644 --- a/src/deepali/core/functional.py +++ b/src/deepali/core/functional.py @@ -93,6 +93,7 @@ from .flow import affine_flow from .flow import compose_flows +from .flow import compose_svfs from .flow import curl from .flow import denormalize_flow from .flow import divergence @@ -102,6 +103,7 @@ from .flow import jacobian_det from .flow import jacobian_dict from .flow import jacobian_matrix +from .flow import lie_bracket from .flow import normalize_flow from .flow import sample_flow from .flow import warp_grid @@ -182,6 +184,7 @@ "closest_point_distances", "closest_point_indices", "compose_flows", + "compose_svfs", "conv", "conv1d", "crop", @@ -216,6 +219,7 @@ "jacobian_det", "jacobian_dict", "jacobian_matrix", + "lie_bracket", "max_pool", "min_pool", "normalize_flow", diff --git a/tests/_test_compose_svfs.py b/tests/_test_compose_svfs.py new file mode 100644 index 0000000..47d33fc --- /dev/null +++ b/tests/_test_compose_svfs.py @@ -0,0 +1,71 @@ +# %% +# Imports +from typing import Optional, Sequence + +import matplotlib.pyplot as plt + +import torch +from torch import Tensor +from torch.random import Generator + +from deepali.core import Grid +import deepali.core.bspline as B +import deepali.core.functional as U + + +# %% +# Auxiliary functions +def random_svf( + size: Sequence[int], + stride: int = 1, + generator: Optional[Generator] = None, +) -> Tensor: + cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride) + data = torch.randn((1, 3) + cp_grid_size, generator=generator) + data = U.fill_border(data, margin=3, value=0, inplace=True) + return B.evaluate_cubic_bspline(data, size=size, stride=stride) + + +def visualize_flow(ax, flow: Tensor) -> None: + grid = Grid(shape=flow.shape[2:], align_corners=True) + x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device) + x = U.move_dim(x.unsqueeze(0).add_(flow), 1, -1) + target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5)) + warped_grid = U.warp_image(target_grid, x) + ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray") + + +# %% +# Random velocity fields +size = (128, 128, 128) +generator = torch.Generator().manual_seed(42) +u = random_svf(size, stride=8, generator=generator).mul_(0.1) +v = random_svf(size, stride=8, generator=generator).mul_(0.05) + + +# %% +# Evaluate displacement fields +flow_u = U.expv(u) +flow_v = U.expv(v) +flow = U.compose_flows(flow_u, flow_v) + + +# %% +# Approximate velocity field of composite displacement field +flow_w = U.expv(U.compose_svfs(u, v, bch_terms=6)) + + +# %% +# Visualize composite displacement fields and error norm +fig, axes = plt.subplots(1, 3, figsize=(30, 10)) + +visualize_flow(axes[0], flow) +visualize_flow(axes[1], flow_w) + +error = flow_w.sub(flow).norm(dim=1, keepdim=True) + +ax = axes[2] +_ = ax.imshow(error[0, 0, error.shape[2] // 2], cmap="jet", vmin=0, vmax=0.1) + + +# %% diff --git a/tests/test_core_flow_utils.py b/tests/test_core_flow_utils.py index 917bd99..d1611c1 100644 --- a/tests/test_core_flow_utils.py +++ b/tests/test_core_flow_utils.py @@ -1,9 +1,14 @@ +from typing import Optional, Sequence + +import pytest import torch import torch.nn.functional as F from torch import Tensor +from torch.random import Generator from deepali.core import Grid from deepali.core.enum import FlowDerivativeKeys +import deepali.core.bspline as B import deepali.core.functional as U @@ -83,6 +88,17 @@ def periodic_flow_divergence(p: Tensor) -> Tensor: return du_dx.add(dv_dy) +def random_svf( + size: Sequence[int], + stride: int = 1, + generator: Optional[Generator] = None, +) -> Tensor: + cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride) + data = torch.randn((1, 3) + cp_grid_size, generator=generator) + data = U.fill_border(data, margin=3, value=0, inplace=True) + return B.evaluate_cubic_bspline(data, size=size, stride=stride) + + def difference(a: Tensor, b: Tensor, margin: int = 0) -> Tensor: assert a.shape == b.shape i = [ @@ -380,3 +396,90 @@ def test_flow_jacobian() -> None: jac[[slice(None)] + interior].det(), atol=0.01, ) + + +def test_flow_lie_bracket() -> None: + p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1) + + x = p.narrow(1, 0, 1) + y = p.narrow(1, 1, 1) + z = p.narrow(1, 2, 1) + + # u = [yz, xz, xy] and v = [x, y, z] + u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1) + v = torch.cat([x, y, z], dim=1) + w = u + + lb_uv = U.lie_bracket(u, v) + assert torch.allclose(U.lie_bracket(v, u), lb_uv.neg()) + assert U.lie_bracket(u, u).abs().lt(1e-6).all() + assert torch.allclose(lb_uv, w, atol=1e-6) + + # u = [z^2, 0, xy] and v = [0, x + y^3, yz] + u = torch.cat([z.square(), torch.zeros_like(y), x.mul(y)], dim=1) + v = torch.cat([torch.zeros_like(x), x.add(y.pow(3)), y.mul(z)], dim=1) + w = torch.cat([-2 * y * z**2, z**2, x * y**2 - x**2 - x * y**3], dim=1).neg_() + + lb_uv = U.lie_bracket(u, v) + assert torch.allclose(U.lie_bracket(v, u), lb_uv.neg()) + assert U.lie_bracket(u, u).abs().lt(1e-6).all() + error = difference(lb_uv, w).abs() + assert error[:, :, 1:-1, 1:-1, 1:-1].max().lt(1e-5) + assert error.max().lt(0.134) + + +def test_flow_compose_svfs() -> None: + # 3D flow fields + p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1) + + x = p.narrow(1, 0, 1) + y = p.narrow(1, 1, 1) + z = p.narrow(1, 2, 1) + + with pytest.raises(ValueError): + U.compose_svfs(p, p, bch_terms=-1) + with pytest.raises(ValueError): + U.compose_svfs(p, p, bch_terms=0) + with pytest.raises(ValueError): + U.compose_svfs(p, p, bch_terms=1) + with pytest.raises(NotImplementedError): + U.compose_svfs(p, p, bch_terms=7) + + # u = [yz, xz, xy] and v = u + u = v = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1) + + w = U.compose_svfs(u, v, bch_terms=2) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=3) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=4) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=5) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=6) + assert torch.allclose(w, u.add(v), atol=1e-5) + + # u = [yz, xz, xy] and v = [x, y, z] + u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1) + v = torch.cat([x, y, z], dim=1) + + w = U.compose_svfs(u, v, bch_terms=2) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=3) + assert torch.allclose(w, u.mul(0.5).add(v), atol=1e-6) + + # u = random_svf(), u -> 0 at boundary + # v = random_svf(), v -> 0 at boundary + size = (64, 64, 64) + generator = torch.Generator().manual_seed(42) + u = random_svf(size, stride=4, generator=generator).mul_(0.1) + v = random_svf(size, stride=4, generator=generator).mul_(0.05) + w = U.compose_svfs(u, v, bch_terms=6) + + flow_u = U.expv(u) + flow_v = U.expv(v) + flow_w = U.expv(w) + flow = U.compose_flows(flow_u, flow_v) + + error = flow_w.sub(flow).norm(dim=1) + assert error.max().lt(0.01)