diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 5677c19..805b532 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -71,7 +71,7 @@ def compose_svfs( sigma: Optional[float] = None, spacing: Optional[Union[Scalar, Array]] = None, stride: Optional[ScalarOrTuple[int]] = None, - bch_terms: int = 4, + bch_terms: int = 3, ) -> Tensor: r"""Approximate stationary velocity field (SVF) of composite deformation. @@ -79,19 +79,36 @@ def compose_svfs( of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the `Baker-Campbell-Hausdorff (BCH) formula `_. + The BCH formula with 5 Lie bracket terms (cf. ``bch_terms`` parameter) is + + .. math:: + + w = v + u + \frac{1}{2} [v, u] + + \frac{1}{12} ([v, [v, u]] - [u, [v, u]]) + + \frac{1}{48} ([[v, [v, u]], u] - [v, [u, [v, u]]]) + + where + + .. math:: + + [[v, [v, u]], u] - [v, [u, [v, u]]] = -2 [u, [v, [v, u]]] + References: - - Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach. - doi:10.1007/978-3-540-85988-8_90 + - Bossa & Olmos, 2008. A new algorithm for the computation of the group logarithm of diffeomorphisms. + https://inria.hal.science/inria-00629873 + - Vercauteren et al., 2008. Symmetric log-domain diffeomorphic registration: A Demons-based approach. + https://doi.org/10.1007/978-3-540-85988-8_90 Args: u: First applied stationary velocity field as tensor of shape ``(N, D, ..., X)``. v: Second applied stationary velocity field as tensor of shape ``(N, D, ..., X)``. - bch_terms: Number of terms of the BCH formula to consider. Must be at least 2. - When 2, the returned velocity field is the sum of ``u`` and ``v``. + bch_terms: Number of Lie bracket terms of the BCH formula to consider. + When 0, the returned velocity field is the sum of ``u`` and ``v``. This approximation is accurate if the input velocity fields commute, i.e., - the Lie bracket [v, u] = 0. When ``bch_terms=3``, the approximation is given by - ``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first), - and when ``bch_terms=4``, it is ``w = v + u + 1/2 [v, u] + 1/12 [v, [v, u]]``. + the Lie bracket [v, u] = 0. When ``bch_terms=1``, the approximation is given by + ``w = v + u + 1/2 [v, u]`` (note ``exp(u)`` is applied before ``exp(v)``). Formula + ``w = v + u + \frac{1}{2} [v, u] + \frac{1}{12} ([v, [v, u]] - [u, [v, u]])`` is + used by default, i.e., ``bch_terms=3``. mode: Mode of :func:`flow_derivatives()` approximation. sigma: Standard deviation of Gaussian used for computing spatial derivatives. spacing: Physical size of image voxels used to compute spatial derivatives. @@ -114,25 +131,30 @@ def lb(a: Tensor, b: Tensor) -> Tensor: raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)") if u.shape != v.shape: raise ValueError("compose_svfs() 'u' and 'v' must have the same shape") - if bch_terms < 2: - raise ValueError("compose_svfs() 'bch_terms' must be at least 2") - elif bch_terms > 6: + if bch_terms < 0: + raise ValueError("compose_svfs() 'bch_terms' must not be negative") + elif bch_terms > 5: raise NotImplementedError("compose_svfs() 'bch_terms' of more than 6 not implemented") + # w = v + u w = v.add(u) - if bch_terms >= 3: + if bch_terms >= 1: + # + 1/2 [v, u] vu = lb(v, u) w = w.add(vu.mul(0.5)) - if bch_terms >= 4: + if bch_terms >= 2: + # + 1/12 [v, [v, u]] vvu = lb(v, vu) w = w.add(vvu.mul(1 / 12)) - if bch_terms >= 5: - uv = lb(u, v) - uuv = lb(u, uv) - w = w.add(uuv.mul(1 / 12)) - if bch_terms >= 6: + if bch_terms >= 3: + # - 1/12 [u, [v, u]] + uvu = lb(u, vu) + w = w.sub(uvu.mul(1 / 12)) + if bch_terms >= 4: + # + 1/48 [[v, [v, u]], u] = - 1/48 [u, [v, [v, u]]] + # - 1/48 [v, [u, [v, u]]] = - 1/48 [u, [v, [v, u]]] uvvu = lb(u, vvu) - w = w.sub(uvvu.mul(1 / 24)) + w = w.sub(uvvu.mul((1 if bch_terms == 4 else 2) / 48)) return w diff --git a/tests/_test_compose_svfs.py b/tests/_test_compose_svfs.py index 47d33fc..f4582f6 100644 --- a/tests/_test_compose_svfs.py +++ b/tests/_test_compose_svfs.py @@ -52,7 +52,7 @@ def visualize_flow(ax, flow: Tensor) -> None: # %% # Approximate velocity field of composite displacement field -flow_w = U.expv(U.compose_svfs(u, v, bch_terms=6)) +flow_w = U.expv(U.compose_svfs(u, v, bch_terms=3)) # %% diff --git a/tests/test_core_flow_utils.py b/tests/test_core_flow_utils.py index d1611c1..bf8cdc1 100644 --- a/tests/test_core_flow_utils.py +++ b/tests/test_core_flow_utils.py @@ -438,34 +438,32 @@ def test_flow_compose_svfs() -> None: with pytest.raises(ValueError): U.compose_svfs(p, p, bch_terms=-1) - with pytest.raises(ValueError): - U.compose_svfs(p, p, bch_terms=0) - with pytest.raises(ValueError): - U.compose_svfs(p, p, bch_terms=1) with pytest.raises(NotImplementedError): - U.compose_svfs(p, p, bch_terms=7) + U.compose_svfs(p, p, bch_terms=6) # u = [yz, xz, xy] and v = u u = v = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1) + w = U.compose_svfs(u, v, bch_terms=0) + assert torch.allclose(w, u.add(v)) + w = U.compose_svfs(u, v, bch_terms=1) + assert torch.allclose(w, u.add(v)) w = U.compose_svfs(u, v, bch_terms=2) assert torch.allclose(w, u.add(v)) w = U.compose_svfs(u, v, bch_terms=3) assert torch.allclose(w, u.add(v)) w = U.compose_svfs(u, v, bch_terms=4) - assert torch.allclose(w, u.add(v)) + assert torch.allclose(w, u.add(v), atol=1e-5) w = U.compose_svfs(u, v, bch_terms=5) - assert torch.allclose(w, u.add(v)) - w = U.compose_svfs(u, v, bch_terms=6) assert torch.allclose(w, u.add(v), atol=1e-5) # u = [yz, xz, xy] and v = [x, y, z] u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1) v = torch.cat([x, y, z], dim=1) - w = U.compose_svfs(u, v, bch_terms=2) + w = U.compose_svfs(u, v, bch_terms=0) assert torch.allclose(w, u.add(v)) - w = U.compose_svfs(u, v, bch_terms=3) + w = U.compose_svfs(u, v, bch_terms=1) assert torch.allclose(w, u.mul(0.5).add(v), atol=1e-6) # u = random_svf(), u -> 0 at boundary @@ -474,7 +472,7 @@ def test_flow_compose_svfs() -> None: generator = torch.Generator().manual_seed(42) u = random_svf(size, stride=4, generator=generator).mul_(0.1) v = random_svf(size, stride=4, generator=generator).mul_(0.05) - w = U.compose_svfs(u, v, bch_terms=6) + w = U.compose_svfs(u, v, bch_terms=5) flow_u = U.expv(u) flow_v = U.expv(v)