Skip to content

Commit

Permalink
[core] Properly fix compose_flows() this time
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 7, 2023
1 parent 9a35d1b commit 121999f
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 51 deletions.
62 changes: 14 additions & 48 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
r"""Functions relating to tensors representing vector fields."""

from itertools import combinations_with_replacement, permutations, product
from itertools import permutations, product
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
Expand Down Expand Up @@ -56,10 +56,12 @@ def affine_flow(matrix: Tensor, grid: Union[Grid, Tensor], channels_last: bool =


def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor:
r"""Compute composite flow field ``w = v o u``."""
u = move_dim(u, 1, -1)
w = F.grid_sample(v, u, mode="bilinear", padding_mode="border", align_corners=align_corners)
return w
r"""Compute composite flow field ``w = v o u = u(x) + v(x + u(x))``."""
grid = Grid(shape=u.shape[2:], align_corners=align_corners)
x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device)
x = move_dim(x.unsqueeze(0).add_(u), 1, -1)
v = F.grid_sample(v, x, mode="bilinear", padding_mode="border", align_corners=align_corners)
return u.add(v)


def curl(
Expand Down Expand Up @@ -462,9 +464,9 @@ def jacobian_dict(
for i in range(D):
deriv[FlowDerivativeKeys.symbol(i, i)].add_(1)
jac = {}
for i, j in combinations_with_replacement(range(D), 2):
for i, j in product(range(D), repeat=2):
jac[(i, j)] = deriv[FlowDerivativeKeys.symbol(i, j)]
return {(i, j): jac[(i, j) if i < j else (j, i)] for i, j in product(range(D), repeat=2)}
return jac


def jacobian_matrix(
Expand All @@ -491,55 +493,19 @@ def jacobian_matrix(
Full Jacobian matrices as tensor of shape ``(N, ..., X, D, D)``.
"""
jac = jacobian_dict(
flow,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
add_identity=add_identity,
)
D = flow.ndim - 2
mat = torch.cat([jac[(i, j)] for i, j in product(range(D), repeat=2)], dim=1)
mat = move_dim(mat, 1, -1)
mat = mat.reshape(mat.shape[:-1] + (D, D))
return mat.contiguous()


def jacobian_triu(
flow: torch.Tensor,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
add_identity: bool = False,
) -> Tensor:
r"""Evaluate Jacobian of spatial deformation.
Args:
flow: Input vector field as tensor of shape ``(N, D, ..., X)``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the spatial
deformation given by :math:`x + u(x)` (True), where :math:`u` is the flow field,
by adding the identity matrix to the Jacobian of :math:`u`.
Returns:
Flattened upper triangular Jacobian matrices as tensor of shape ``(N, D * (D + 1) / 2, ..., X)``.
"""
jac = jacobian_dict(
deriv = jacobian_dict(
flow,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
add_identity=add_identity,
)
D = flow.ndim - 2
return torch.cat([jac[(i, j)] for i, j in combinations_with_replacement(range(D), 2)], dim=1)
jac = torch.cat(list(deriv.values()), dim=1)
jac = move_dim(jac, 1, -1)
jac = jac.reshape(jac.shape[:-1] + (D, D))
return jac.contiguous()


def normalize_flow(
Expand Down
2 changes: 0 additions & 2 deletions src/deepali/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
from .flow import jacobian_det
from .flow import jacobian_dict
from .flow import jacobian_matrix
from .flow import jacobian_triu
from .flow import normalize_flow
from .flow import sample_flow
from .flow import warp_grid
Expand Down Expand Up @@ -217,7 +216,6 @@
"jacobian_det",
"jacobian_dict",
"jacobian_matrix",
"jacobian_triu",
"max_pool",
"min_pool",
"normalize_flow",
Expand Down
147 changes: 146 additions & 1 deletion tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def test_flow_derivatives() -> None:
z = p.narrow(1, 2, 1)

flow = torch.cat([x.mul(z), y.mul(z), x.mul(y)], dim=1)

deriv = U.flow_derivatives(flow, order=1)

assert difference(deriv["du/dx"], z).abs().max().lt(1e-5)
Expand Down Expand Up @@ -235,3 +234,149 @@ def test_flow_divergence_free() -> None:
assert flow.shape == data.shape
div = U.divergence(flow)
assert div.abs().max().lt(1e-4)


def test_flow_jacobian() -> None:
# 2D flow field
p = Grid(size=(64, 32)).coords(channels_last=False).unsqueeze_(0)

x = p.narrow(1, 0, 1)
y = p.narrow(1, 1, 1)

interior = [slice(1, n - 1) for n in p.shape[2:]]

# u = [x^2, xy]
flow = torch.cat([x.square(), x.mul(y)], dim=1)

jac = torch.zeros((p.shape[0],) + p.shape[2:] + (2, 2))
jac[..., 0, 0] = x.squeeze(1).mul(2)
jac[..., 1, 0] = y.squeeze(1)
jac[..., 1, 1] = x.squeeze(1)

derivs = U.jacobian_dict(flow)
for (i, j), deriv in derivs.items():
atol = 1e-5
error = difference(jac[..., i, j].unsqueeze(1), deriv)
if (i, j) == (0, 0):
error = error[[slice(None), slice(None)] + interior]
if error.abs().max().gt(atol):
raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}")

mat = U.jacobian_matrix(flow)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=1e-5,
)

jac[..., 0, 0] += 1
jac[..., 1, 1] += 1

mat = U.jacobian_matrix(flow, add_identity=True)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=1e-5,
)

det = U.jacobian_det(flow)
assert torch.allclose(
det[[slice(None), 0] + interior],
jac[[slice(None)] + interior].det(),
atol=1e-5,
)

# 3D flow field
p = Grid(size=(64, 32, 16)).coords(channels_last=False).unsqueeze_(0)

x = p.narrow(1, 0, 1)
y = p.narrow(1, 1, 1)
z = p.narrow(1, 2, 1)

interior = [slice(1, n - 1) for n in p.shape[2:]]

# u = [z^2, 0, xy]
flow = torch.cat([z.square(), torch.zeros_like(y), x.mul(y)], dim=1)

jac = torch.zeros((p.shape[0],) + p.shape[2:] + (3, 3))
jac[..., 0, 2] = z.squeeze(1).mul(2)
jac[..., 2, 0] = y.squeeze(1)
jac[..., 2, 1] = x.squeeze(1)

derivs = U.jacobian_dict(flow)
for (i, j), deriv in derivs.items():
atol = 1e-5
error = difference(jac[..., i, j].unsqueeze(1), deriv)
if (i, j) == (0, 2):
error = error[[slice(None), slice(None)] + interior]
if error.abs().max().gt(atol):
raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}")

mat = U.jacobian_matrix(flow)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=1e-5,
)

jac[..., 0, 0] += 1
jac[..., 1, 1] += 1
jac[..., 2, 2] += 1

mat = U.jacobian_matrix(flow, add_identity=True)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=1e-5,
)

det = U.jacobian_det(flow)
assert torch.allclose(
det[[slice(None), 0] + interior],
jac[[slice(None)] + interior].det(),
atol=1e-5,
)

# u = [0, x + y^3, yz]
flow = torch.cat([torch.zeros_like(x), x.add(y.pow(3)), y.mul(z)], dim=1)

jac = torch.zeros((p.shape[0],) + p.shape[2:] + (3, 3))
jac[..., 1, 0] = 1
jac[..., 1, 1] = y.squeeze(1).square().mul(3)
jac[..., 2, 1] = z.squeeze(1)
jac[..., 2, 2] = y.squeeze(1)

derivs = U.jacobian_dict(flow)
for (i, j), deriv in derivs.items():
atol = 1e-5
error = difference(jac[..., i, j].unsqueeze(1), deriv)
if (i, j) == (1, 1):
atol = 0.005
error = error[[slice(None), slice(None)] + interior]
if error.abs().max().gt(atol):
raise AssertionError(f"max absolute difference of jac[{i}, {j}] > {atol}")

mat = U.jacobian_matrix(flow)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=0.005,
)

jac[..., 0, 0] += 1
jac[..., 1, 1] += 1
jac[..., 2, 2] += 1

mat = U.jacobian_matrix(flow, add_identity=True)
assert torch.allclose(
mat[[slice(None)] + interior],
jac[[slice(None)] + interior],
atol=0.005,
)

det = U.jacobian_det(flow)
assert torch.allclose(
det[[slice(None), 0] + interior],
jac[[slice(None)] + interior].det(),
atol=0.01,
)

0 comments on commit 121999f

Please sign in to comment.