From 54c10767da49573fd9fa22dc0b093bfadad224ca Mon Sep 17 00:00:00 2001 From: Sachin Date: Sat, 2 Nov 2024 06:45:40 +0530 Subject: [PATCH] add padding to non divisible images Signed-off-by: Sachin --- terratorch/models/backbones/prithvi_vit.py | 33 +++++++++++++++++++++- tests/test_backbones.py | 15 +++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index bd0e708e..ecb2a973 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -9,7 +9,7 @@ from timm.models import FeatureInfo from timm.models._builder import build_model_with_cfg from timm.models._registry import generate_default_cfgs, register_model -from torch import nn +from torch import nn, Tensor from terratorch.datasets import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights @@ -66,6 +66,16 @@ def checkpoint_filter_fn( return state_dict + +def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: + p = patch_size + # h, w = imgs.shape[3], imgs.shape[4] + t, h, w = imgs.shape[-3:] + h_pad, w_pad = (p - h % p) % p, (p - w % p) % p # Ensure padding is within bounds + if h_pad > 0 or w_pad > 0: + imgs = nn.functional.pad(imgs, (0, w_pad, 0, h_pad), mode=padding) + return imgs + def _create_prithvi( variant: str, pretrained: bool = False, # noqa: FBT001, FBT002 @@ -76,6 +86,9 @@ def _create_prithvi( if pretrained_bands is None: pretrained_bands = PRETRAINED_BANDS + padding = kwargs.get("padding", "none") + patch_size = kwargs.get("patch_size", 16) + # Little hack because VIT does not support timm's features_only # so we do it ourselves encoder_only = kwargs.get("features_only", False) @@ -113,6 +126,24 @@ def forward_filter_indices(*args, **kwargs): model.model_bands = model_bands model.pretrained_bands = pretrained_bands + if padding != "none": + original_forward = model.forward + original_forward_features = model.forward_features + + def pad_and_forward(forward_fn, patch_size, padding, *args, **kwargs): + inputs = pad_images(args[0], patch_size, padding) + return forward_fn(inputs, **kwargs) + + def forward_pad_images(*args, **kwargs): + return pad_and_forward(original_forward, patch_size, padding, *args, **kwargs) + + def forward_features_pad_images(*args, **kwargs): + return pad_and_forward(original_forward_features, patch_size, padding, *args, **kwargs) + + model.forward = forward_pad_images + model.forward_features = forward_features_pad_images + + return model def create_prithvi_vit_100( diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 12ae3f9f..480599d8 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -57,7 +57,8 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal): @pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) def test_vit_models_non_divisible_input(model_name, input_non_divisible): - backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES) + #padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none' + backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES,padding='constant') backbone(input_non_divisible) @pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) @@ -105,6 +106,18 @@ def test_out_indices(model_name, input_224): for filtered_index, full_index in enumerate(out_indices): assert torch.allclose(full_output[full_index], output[filtered_index]) +@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) +def test_out_indices_non_divisible(model_name, input_non_divisible): + out_indices = [2, 4, 8, 10] + backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, out_indices=out_indices, padding='constant') + assert backbone.feature_info.out_indices == out_indices + + output = backbone(input_non_divisible) + full_output = backbone.forward_features(input_non_divisible) + + for filtered_index, full_index in enumerate(out_indices): + assert torch.allclose(full_output[full_index], output[filtered_index]) + @pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"]) def test_scale_mae(model_name): out_indices = [2, 4, 8, 10]