From ac5eae64dd248f13cc9de3a22cbcfe4688d79fa8 Mon Sep 17 00:00:00 2001
From: ChristophReich1996
Date: Thu, 11 Jul 2024 13:51:21 +0200
Subject: [PATCH] Add option to initialize flow pred.

---
 torchvision/models/optical_flow/raft.py | 67 +++++++++++++------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index c294777ee6f..55ac65b5ec6 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -14,7 +14,6 @@
 from .._utils import handle_legacy_interface
 from ._utils import grid_sample, make_coords_grid, upsample_flow
 
-
 __all__ = (
     "RAFT",
     "raft_large",
@@ -120,7 +119,7 @@ class FeatureEncoder(nn.Module):
     """
 
     def __init__(
-        self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
+            self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
     ):
         super().__init__()
 
@@ -149,7 +148,7 @@ def __init__(
 
         num_downsamples = len(list(filter(lambda s: s == 2, strides)))
         self.output_dim = layers[-1]
-        self.downsample_factor = 2**num_downsamples
+        self.downsample_factor = 2 ** num_downsamples
 
     def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
         block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
@@ -481,13 +480,16 @@ def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block
         if not hasattr(self.update_block, "hidden_state_size"):
             raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")
 
-    def forward(self, image1, image2, num_flow_updates: int = 12):
+    def forward(self, image1, image2, num_flow_updates: int = 12, flow_init: Optional[Tensor] = None):
         batch_size, _, h, w = image1.shape
         if (h, w) != image2.shape[-2:]:
             raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
+        if (flow_init is not None) and ((batch_size, 2, h // 8, w // 8) != flow_init.shape):
+            raise ValueError(
+                f"initial optical flow must have the shape ({batch_size}, 2, {h // 8}, {w // 8}), instead got {flow_init.shape}")
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
@@ -516,6 +518,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
         coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
 
+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
         flow_predictions = []
         for _ in range(num_flow_updates):
             coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
@@ -754,33 +759,33 @@ class Raft_Small_Weights(WeightsEnum):
 
 
 def _raft(
-    *,
-    weights=None,
-    progress=False,
-    # Feature encoder
-    feature_encoder_layers,
-    feature_encoder_block,
-    feature_encoder_norm_layer,
-    # Context encoder
-    context_encoder_layers,
-    context_encoder_block,
-    context_encoder_norm_layer,
-    # Correlation block
-    corr_block_num_levels,
-    corr_block_radius,
-    # Motion encoder
-    motion_encoder_corr_layers,
-    motion_encoder_flow_layers,
-    motion_encoder_out_channels,
-    # Recurrent block
-    recurrent_block_hidden_state_size,
-    recurrent_block_kernel_size,
-    recurrent_block_padding,
-    # Flow Head
-    flow_head_hidden_size,
-    # Mask predictor
-    use_mask_predictor,
-    **kwargs,
+        *,
+        weights=None,
+        progress=False,
+        # Feature encoder
+        feature_encoder_layers,
+        feature_encoder_block,
+        feature_encoder_norm_layer,
+        # Context encoder
+        context_encoder_layers,
+        context_encoder_block,
+        context_encoder_norm_layer,
+        # Correlation block
+        corr_block_num_levels,
+        corr_block_radius,
+        # Motion encoder
+        motion_encoder_corr_layers,
+        motion_encoder_flow_layers,
+        motion_encoder_out_channels,
+        # Recurrent block
+        recurrent_block_hidden_state_size,
+        recurrent_block_kernel_size,
+        recurrent_block_padding,
+        # Flow Head
+        flow_head_hidden_size,
+        # Mask predictor
+        use_mask_predictor,
+        **kwargs,
 ):
     feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
         block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
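
Note (not part of the patch): a minimal usage sketch of the new
`flow_init` argument for warm-starting RAFT across consecutive video
frames. `raft_small` and `torch.nn.functional.interpolate` are existing
torchvision/PyTorch APIs; the frame tensors are random stand-ins, and
the downsample-and-rescale step is an assumption that follows from the
patch itself: `flow_init` is added to `coords1` on the 1/8-resolution
grid, so a full-resolution prediction has to be resized to
(h // 8, w // 8) and its values divided by 8 before being passed back in.

    import torch
    import torch.nn.functional as F
    from torchvision.models.optical_flow import raft_small

    # Build the small RAFT variant; pretrained weights are not needed
    # for a shape-level demonstration.
    model = raft_small(weights=None).eval()

    # Three consecutive "frames" (random stand-ins for real video);
    # H and W must be divisible by 8, per the check in forward().
    img1, img2, img3 = (torch.randn(1, 3, 128, 128) for _ in range(3))

    with torch.no_grad():
        # First pair: no flow_init, so coords1 starts at coords0,
        # i.e. zero initial flow, exactly as before this patch.
        flow12 = model(img1, img2, num_flow_updates=12)[-1]

        # Bring the full-resolution flow down to the 1/8 grid and
        # rescale its magnitude before warm-starting the next pair.
        init = F.interpolate(
            flow12, scale_factor=1 / 8, mode="bilinear", align_corners=False
        ) / 8.0
        flow23 = model(img2, img3, num_flow_updates=12, flow_init=init)[-1]

The shape of `init` here is (1, 2, 16, 16), which satisfies the new
validation `(batch_size, 2, h // 8, w // 8) != flow_init.shape` added
in forward().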