From 121999f9f8dce1f1bb1b2be0922c1aafb2f99611 Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Mon, 6 Nov 2023 23:57:12 +0000 Subject: [PATCH] [core] Properly fix compose_flows() this time --- src/deepali/core/flow.py | 62 ++++---------- src/deepali/core/functional.py | 2 - tests/test_core_flow_utils.py | 147 ++++++++++++++++++++++++++++++++- 3 files changed, 160 insertions(+), 51 deletions(-) diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 1a5b412..a01ed3b 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -1,6 +1,6 @@ r"""Functions relating to tensors representing vector fields.""" -from itertools import combinations_with_replacement, permutations, product +from itertools import permutations, product from typing import Dict, Optional, Sequence, Tuple, Union import torch @@ -56,10 +56,12 @@ def affine_flow(matrix: Tensor, grid: Union[Grid, Tensor], channels_last: bool = def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor: - r"""Compute composite flow field ``w = v o u``.""" - u = move_dim(u, 1, -1) - w = F.grid_sample(v, u, mode="bilinear", padding_mode="border", align_corners=align_corners) - return w + r"""Compute composite flow field ``w = v o u = u(x) + v(x + u(x))``.""" + grid = Grid(shape=u.shape[2:], align_corners=align_corners) + x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device) + x = move_dim(x.unsqueeze(0).add_(u), 1, -1) + v = F.grid_sample(v, x, mode="bilinear", padding_mode="border", align_corners=align_corners) + return u.add(v) def curl( @@ -462,9 +464,9 @@ def jacobian_dict( for i in range(D): deriv[FlowDerivativeKeys.symbol(i, i)].add_(1) jac = {} - for i, j in combinations_with_replacement(range(D), 2): + for i, j in product(range(D), repeat=2): jac[(i, j)] = deriv[FlowDerivativeKeys.symbol(i, j)] - return {(i, j): jac[(i, j) if i < j else (j, i)] for i, j in product(range(D), repeat=2)} + return jac def jacobian_matrix( @@ -491,46 +493,8 @@ def jacobian_matrix( Full Jacobian matrices as tensor of shape ``(N, ..., X, D, D)``. """ - jac = jacobian_dict( - flow, - mode=mode, - sigma=sigma, - spacing=spacing, - stride=stride, - add_identity=add_identity, - ) D = flow.ndim - 2 - mat = torch.cat([jac[(i, j)] for i, j in product(range(D), repeat=2)], dim=1) - mat = move_dim(mat, 1, -1) - mat = mat.reshape(mat.shape[:-1] + (D, D)) - return mat.contiguous() - - -def jacobian_triu( - flow: torch.Tensor, - mode: Optional[str] = None, - sigma: Optional[float] = None, - spacing: Optional[Union[Scalar, Array]] = None, - stride: Optional[ScalarOrTuple[int]] = None, - add_identity: bool = False, -) -> Tensor: - r"""Evaluate Jacobian of spatial deformation. - - Args: - flow: Input vector field as tensor of shape ``(N, D, ..., X)``. - mode: Mode of :func:`flow_derivatives()` approximation. - sigma: Standard deviation of Gaussian used for computing spatial derivatives. - spacing: Physical size of image voxels used to compute spatial derivatives. - stride: Number of output grid points between control points plus one for ``mode='bspline'``. - add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the spatial - deformation given by :math:`x + u(x)` (True), where :math:`u` is the flow field, - by adding the identity matrix to the Jacobian of :math:`u`. - - Returns: - Flattened upper triangular Jacobian matrices as tensor of shape ``(N, D * (D + 1) / 2, ..., X)``. - - """ - jac = jacobian_dict( + deriv = jacobian_dict( flow, mode=mode, sigma=sigma, @@ -538,8 +502,10 @@ def jacobian_triu( stride=stride, add_identity=add_identity, ) - D = flow.ndim - 2 - return torch.cat([jac[(i, j)] for i, j in combinations_with_replacement(range(D), 2)], dim=1) + jac = torch.cat(list(deriv.values()), dim=1) + jac = move_dim(jac, 1, -1) + jac = jac.reshape(jac.shape[:-1] + (D, D)) + return jac.contiguous() def normalize_flow( diff --git a/src/deepali/core/functional.py b/src/deepali/core/functional.py index e064e55..6eb6add 100644 --- a/src/deepali/core/functional.py +++ b/src/deepali/core/functional.py @@ -102,7 +102,6 @@ from .flow import jacobian_det from .flow import jacobian_dict from .flow import jacobian_matrix -from .flow import jacobian_triu from .flow import normalize_flow from .flow import sample_flow from .flow import warp_grid @@ -217,7 +216,6 @@ "jacobian_det", "jacobian_dict", "jacobian_matrix", - "jacobian_triu", "max_pool", "min_pool", "normalize_flow", diff --git a/tests/test_core_flow_utils.py b/tests/test_core_flow_utils.py index 8cc4d7e..917bd99 100644 --- a/tests/test_core_flow_utils.py +++ b/tests/test_core_flow_utils.py @@ -162,7 +162,6 @@ def test_flow_derivatives() -> None: z = p.narrow(1, 2, 1) flow = torch.cat([x.mul(z), y.mul(z), x.mul(y)], dim=1) - deriv = U.flow_derivatives(flow, order=1) assert difference(deriv["du/dx"], z).abs().max().lt(1e-5) @@ -235,3 +234,149 @@ def test_flow_divergence_free() -> None: assert flow.shape == data.shape div = U.divergence(flow) assert div.abs().max().lt(1e-4) + + +def test_flow_jacobian() -> None: + # 2D flow field + p = Grid(size=(64, 32)).coords(channels_last=False).unsqueeze_(0) + + x = p.narrow(1, 0, 1) + y = p.narrow(1, 1, 1) + + interior = [slice(1, n - 1) for n in p.shape[2:]] + + # u = [x^2, xy] + flow = torch.cat([x.square(), x.mul(y)], dim=1) + + jac = torch.zeros((p.shape[0],) + p.shape[2:] + (2, 2)) + jac[..., 0, 0] = x.squeeze(1).mul(2) + jac[..., 1, 0] = y.squeeze(1) + jac[..., 1, 1] = x.squeeze(1) + + derivs = U.jacobian_dict(flow) + for (i, j), deriv in derivs.items(): + atol = 1e-5 + error = difference(jac[..., i, j].unsqueeze(1), deriv) + if (i, j) == (0, 0): + error = error[[slice(None), slice(None)] + interior] + if error.abs().max().gt(atol): + raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}") + + mat = U.jacobian_matrix(flow) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=1e-5, + ) + + jac[..., 0, 0] += 1 + jac[..., 1, 1] += 1 + + mat = U.jacobian_matrix(flow, add_identity=True) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=1e-5, + ) + + det = U.jacobian_det(flow) + assert torch.allclose( + det[[slice(None), 0] + interior], + jac[[slice(None)] + interior].det(), + atol=1e-5, + ) + + # 3D flow field + p = Grid(size=(64, 32, 16)).coords(channels_last=False).unsqueeze_(0) + + x = p.narrow(1, 0, 1) + y = p.narrow(1, 1, 1) + z = p.narrow(1, 2, 1) + + interior = [slice(1, n - 1) for n in p.shape[2:]] + + # u = [z^2, 0, xy] + flow = torch.cat([z.square(), torch.zeros_like(y), x.mul(y)], dim=1) + + jac = torch.zeros((p.shape[0],) + p.shape[2:] + (3, 3)) + jac[..., 0, 2] = z.squeeze(1).mul(2) + jac[..., 2, 0] = y.squeeze(1) + jac[..., 2, 1] = x.squeeze(1) + + derivs = U.jacobian_dict(flow) + for (i, j), deriv in derivs.items(): + atol = 1e-5 + error = difference(jac[..., i, j].unsqueeze(1), deriv) + if (i, j) == (0, 2): + error = error[[slice(None), slice(None)] + interior] + if error.abs().max().gt(atol): + raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}") + + mat = U.jacobian_matrix(flow) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=1e-5, + ) + + jac[..., 0, 0] += 1 + jac[..., 1, 1] += 1 + jac[..., 2, 2] += 1 + + mat = U.jacobian_matrix(flow, add_identity=True) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=1e-5, + ) + + det = U.jacobian_det(flow) + assert torch.allclose( + det[[slice(None), 0] + interior], + jac[[slice(None)] + interior].det(), + atol=1e-5, + ) + + # u = [0, x + y^3, yz] + flow = torch.cat([torch.zeros_like(x), x.add(y.pow(3)), y.mul(z)], dim=1) + + jac = torch.zeros((p.shape[0],) + p.shape[2:] + (3, 3)) + jac[..., 1, 0] = 1 + jac[..., 1, 1] = y.squeeze(1).square().mul(3) + jac[..., 2, 1] = z.squeeze(1) + jac[..., 2, 2] = y.squeeze(1) + + derivs = U.jacobian_dict(flow) + for (i, j), deriv in derivs.items(): + atol = 1e-5 + error = difference(jac[..., i, j].unsqueeze(1), deriv) + if (i, j) == (1, 1): + atol = 0.005 + error = error[[slice(None), slice(None)] + interior] + if error.abs().max().gt(atol): + raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}") + + mat = U.jacobian_matrix(flow) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=0.005, + ) + + jac[..., 0, 0] += 1 + jac[..., 1, 1] += 1 + jac[..., 2, 2] += 1 + + mat = U.jacobian_matrix(flow, add_identity=True) + assert torch.allclose( + mat[[slice(None)] + interior], + jac[[slice(None)] + interior], + atol=0.005, + ) + + det = U.jacobian_det(flow) + assert torch.allclose( + det[[slice(None), 0] + interior], + jac[[slice(None)] + interior].det(), + atol=0.01, + )