From c9133b51dbf23eee9560cb38558e0aa0f7c6cf2a Mon Sep 17 00:00:00 2001 From: Mikael Brudfors Date: Fri, 22 Sep 2023 15:22:55 +0100 Subject: [PATCH 1/4] Allow for defining reference grid on non-integer coordinates Signed-off-by: Mikael Brudfors --- monai/networks/blocks/warp.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 10a115eff8..5221e580b5 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -32,7 +32,7 @@ class Warp(nn.Module): Warp an image with given dense displacement field (DDF). """ - def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value): + def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value, jitter=False): """ For pytorch native APIs, the possible values are: @@ -47,6 +47,11 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ... See also: :py:class:`monai.networks.layers.grid_pull` + + - jitter: bool, default=False + Define reference grid on non-integer values + Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration + based on mutual information. Image and Vision Computing, 19:33-44, 2001. """ super().__init__() # resolves _interp_mode for different methods @@ -84,8 +89,9 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa self._padding_mode = GridSamplePadMode(padding_mode).value self.ref_grid = None + self.jitter = jitter - def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor: + def get_reference_grid(self, ddf: torch.Tensor, jitter: bool=False, seed: int=0) -> torch.Tensor: if ( self.ref_grid is not None and self.ref_grid.shape[0] == ddf.shape[0] @@ -96,6 +102,11 @@ def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor: grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...) grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...) self.ref_grid = grid.to(ddf) + if jitter: + # Define reference grid on non-integer values + with torch.random.fork_rng(enabled=seed): + torch.random.manual_seed(seed) + grid += torch.rand_like(grid) self.ref_grid.requires_grad = False return self.ref_grid @@ -117,7 +128,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, " f"Got {ddf.shape} instead." ) - grid = self.get_reference_grid(ddf) + ddf + grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims) if not USE_COMPILED: # pytorch native grid_sample @@ -144,7 +155,7 @@ class DVF2DDF(nn.Module): """ def __init__( - self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value + self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value, ): super().__init__() if num_steps <= 0: From ea8712601c68887e33840ef61f0eb5b2e2d1cdb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:28:06 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/warp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 5221e580b5..9a7b46af7a 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -47,10 +47,10 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ... See also: :py:class:`monai.networks.layers.grid_pull` - + - jitter: bool, default=False Define reference grid on non-integer values - Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration + Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration based on mutual information. Image and Vision Computing, 19:33-44, 2001. """ super().__init__() @@ -104,7 +104,7 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool=False, seed: int=0) self.ref_grid = grid.to(ddf) if jitter: # Define reference grid on non-integer values - with torch.random.fork_rng(enabled=seed): + with torch.random.fork_rng(enabled=seed): torch.random.manual_seed(seed) grid += torch.rand_like(grid) self.ref_grid.requires_grad = False From 44627a8f2005384de88bdc22dc68d6ee938b3835 Mon Sep 17 00:00:00 2001 From: Mikael Brudfors Date: Fri, 22 Sep 2023 15:31:44 +0100 Subject: [PATCH 3/4] Removed unecessary comma Signed-off-by: Mikael Brudfors --- monai/networks/blocks/warp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 9a7b46af7a..f2a2f63f47 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -155,7 +155,7 @@ class DVF2DDF(nn.Module): """ def __init__( - self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value, + self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value ): super().__init__() if num_steps <= 0: From e8241089a1f32ab7e07d87b819f5907af5dfb022 Mon Sep 17 00:00:00 2001 From: Mikael Brudfors Date: Fri, 22 Sep 2023 18:29:09 +0200 Subject: [PATCH 4/4] Lint fix DCO Remediation Commit for Mikael Brudfors I, Mikael Brudfors , hereby add my Signed-off-by to this commit: 44627a8f2005384de88bdc22dc68d6ee938b3835 Signed-off-by: Mikael Brudfors --- monai/networks/blocks/warp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index f2a2f63f47..8a570e42c4 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -91,7 +91,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa self.ref_grid = None self.jitter = jitter - def get_reference_grid(self, ddf: torch.Tensor, jitter: bool=False, seed: int=0) -> torch.Tensor: + def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int = 0) -> torch.Tensor: if ( self.ref_grid is not None and self.ref_grid.shape[0] == ddf.shape[0]