Skip to content

Commit

Permalink
add padding to non divisible images
Browse files Browse the repository at this point in the history
  • Loading branch information
singam96 committed Nov 2, 2024
1 parent e354c58 commit 6031b85
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
33 changes: 32 additions & 1 deletion terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from timm.models import FeatureInfo
from timm.models._builder import build_model_with_cfg
from timm.models._registry import generate_default_cfgs, register_model
from torch import nn
from torch import nn, Tensor

from terratorch.datasets import HLSBands
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
Expand Down Expand Up @@ -66,6 +66,16 @@ def checkpoint_filter_fn(

return state_dict


def pad_images(imgs: Tensor, patch_size: int, padding: str) -> Tensor:
    """Pad the last two dims of ``imgs`` up to the next multiple of ``patch_size``.

    Padding is only ever appended at the end of each of the two trailing
    dimensions (bottom / right), so the original values keep their indices.

    Args:
        imgs: Tensor whose last two dimensions are treated as height and width.
            NOTE(review): callers presumably pass (B, C, T, H, W) or
            (B, C, H, W) batches -- confirm against the model's forward.
        patch_size: Size the trailing two dimensions must be divisible by.
        padding: Padding mode handed to ``torch.nn.functional.pad``
            (e.g. ``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``).

    Returns:
        ``imgs`` itself when both trailing dimensions are already divisible by
        ``patch_size``; otherwise a new, padded tensor.
    """
    h, w = imgs.shape[-2:]
    # The outer modulo maps "already divisible" to 0 instead of a full patch.
    h_pad = (patch_size - h % patch_size) % patch_size
    w_pad = (patch_size - w % patch_size) % patch_size
    if h_pad == 0 and w_pad == 0:
        return imgs

    # F.pad takes pad widths starting from the LAST dim: (left, right, top, bottom).
    pad = (0, w_pad, 0, h_pad)
    if padding == "constant":
        # Constant padding is implemented for tensors of any rank.
        return nn.functional.pad(imgs, pad, mode=padding)

    # Non-constant modes only accept a 4-element pad on 3D/4D inputs, so a 5D
    # multi-temporal batch would raise NotImplementedError. Fold every leading
    # dimension into one batch axis, pad as a 4D tensor, then restore the shape.
    leading = imgs.shape[:-2]
    flat = imgs.reshape(-1, 1, h, w)
    flat = nn.functional.pad(flat, pad, mode=padding)
    return flat.reshape(*leading, h + h_pad, w + w_pad)

def _create_prithvi(
variant: str,
pretrained: bool = False, # noqa: FBT001, FBT002
Expand All @@ -76,6 +86,9 @@ def _create_prithvi(
if pretrained_bands is None:
pretrained_bands = PRETRAINED_BANDS

padding = kwargs.get("padding", "none")
patch_size = kwargs.get("patch_size", 16)

# Little hack because VIT does not support timm's features_only
# so we do it ourselves
encoder_only = kwargs.get("features_only", False)
Expand Down Expand Up @@ -113,6 +126,24 @@ def forward_filter_indices(*args, **kwargs):
model.model_bands = model_bands
model.pretrained_bands = pretrained_bands

if padding != "none":
    # Capture the unwrapped callables before they are replaced below, so the
    # wrappers call the real implementations rather than recursing.
    original_forward = model.forward
    original_forward_features = model.forward_features

    def pad_and_forward(forward_fn, patch_size, padding, *args, **kwargs):
        """Pad the first positional argument, then delegate to ``forward_fn``.

        Every remaining positional and keyword argument is passed through
        unchanged; previously positional arguments after the first were
        silently dropped.
        """
        inputs = pad_images(args[0], patch_size, padding)
        return forward_fn(inputs, *args[1:], **kwargs)

    def forward_pad_images(*args, **kwargs):
        """Replacement for ``model.forward`` that pads its input first."""
        return pad_and_forward(original_forward, patch_size, padding, *args, **kwargs)

    def forward_features_pad_images(*args, **kwargs):
        """Replacement for ``model.forward_features`` that pads its input first."""
        return pad_and_forward(original_forward_features, patch_size, padding, *args, **kwargs)

    model.forward = forward_pad_images
    model.forward_features = forward_features_pad_images


return model

def create_prithvi_vit_100(
Expand Down
15 changes: 14 additions & 1 deletion tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):

@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
def test_vit_models_non_divisible_input(model_name, input_non_divisible):
    """Smoke test: a backbone built with padding runs on ``input_non_divisible``."""
    # `padding` may be 'none', 'constant', 'reflect', 'replicate' or 'circular';
    # it defaults to 'none'.
    model = timm.create_model(
        model_name,
        pretrained=False,
        num_frames=NUM_FRAMES,
        padding="constant",
    )
    model(input_non_divisible)

@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
Expand Down Expand Up @@ -105,6 +106,18 @@ def test_out_indices(model_name, input_224):
for filtered_index, full_index in enumerate(out_indices):
assert torch.allclose(full_output[full_index], output[filtered_index])

@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
def test_out_indices_non_divisible(model_name, input_non_divisible):
    """Check ``out_indices`` filtering on a padded ``input_non_divisible`` batch.

    The features returned by calling the backbone must match the entries of
    ``forward_features`` at the requested indices.
    """
    out_indices = [2, 4, 8, 10]
    backbone = timm.create_model(
        model_name,
        pretrained=False,
        features_only=True,
        num_frames=NUM_FRAMES,
        out_indices=out_indices,
        padding="constant",
    )
    assert backbone.feature_info.out_indices == out_indices

    selected = backbone(input_non_divisible)
    every_layer = backbone.forward_features(input_non_divisible)

    # Position i of the filtered output corresponds to layer out_indices[i].
    for position, layer in enumerate(out_indices):
        assert torch.allclose(every_layer[layer], selected[position])

@pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"])
def test_scale_mae(model_name):
out_indices = [2, 4, 8, 10]
Expand Down

0 comments on commit 6031b85

Please sign in to comment.