diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index cea7054..1a5b412 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -55,11 +55,11 @@ def affine_flow(matrix: Tensor, grid: Union[Grid, Tensor], channels_last: bool = return flow -def compose_flows(a: Tensor, b: Tensor, align_corners: bool = True) -> Tensor: - r"""Compute composite flow field ``c = b o a``.""" - a = move_dim(b, 1, -1) - c = F.grid_sample(b, a, mode="bilinear", padding_mode="border", align_corners=align_corners) - return c +def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor: + r"""Compute composite flow field ``w = v o u``.""" + u = move_dim(u, 1, -1) + w = F.grid_sample(v, u, mode="bilinear", padding_mode="border", align_corners=align_corners) + return w def curl(