Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Return all features in Clay v1 #245

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions terratorch/models/backbones/clay_v1/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@


class Embedder(nn.Module):
default_out_indices = (0,) # Single out_indices for simplicity

def __init__(self,
img_size=256,
num_frames=1,
ckpt_path=None,
bands=["blue", "green", "red", "nir", "swir16", "swir22"],
**kwargs):
default_out_indices = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)

def __init__(
self,
img_size=256,
num_frames=1,
ckpt_path=None,
bands=["blue", "green", "red", "nir", "swir16", "swir22"],
out_indices: tuple[int] = default_out_indices,
**kwargs,
):
super().__init__()
self.feature_info = []
self.img_size = img_size
self.num_frames = num_frames
self.bands = bands
self.out_indices = out_indices

if kwargs.get("datacuber", True) is not None:
self.datacuber = Datacuber(bands=bands)
Expand All @@ -55,8 +59,9 @@ def __init__(self,
)
)

# for use in features list. Single layer feature for simplicity
self.feature_info.append({"num_chs": 768, "reduction": 1, "module": "clay_encoder"})
# for use in features list.
for i in range(12):
self.feature_info.append({"num_chs": 768, "reduction": 1, "module": f"blocks.{i}"})

# assuming this is used to fine tune a network on top of the embeddings

Expand Down Expand Up @@ -103,8 +108,7 @@ def forward_features(self, x):
datacube = x
embeddings = self.clay_encoder(datacube)

# TODO: actually return features individually
return [embeddings]
return [embeddings[i] for i in self.out_indices]

def fake_datacube(self):
"Generate a fake datacube for model export."
Expand Down
12 changes: 8 additions & 4 deletions terratorch/models/backbones/clay_v1/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,15 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
FeedForward(dim, mlp_dim)
]))

def forward(self, x):
def forward(self, x) -> list[torch.Tensor]:
out = []
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
out.append(x.clone())
x = self.norm(x)
out[-1] = x.clone()
return out


class Encoder(nn.Module):
Expand Down Expand Up @@ -368,12 +372,12 @@ def forward(self, datacube):
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]

# pass the patches through the transformer
patches = self.transformer(patches) # [B (1 + L) D]
patches = self.transformer(patches) # list of [B (1 + L) D]

# # remove the cls token
# embeddings = patches[:, 1: , :] # [B L D]

return patches # [B (1 + L) D]
return patches # list [B (1 + L) D]


class FCBlock(nn.Module):
Expand Down
Loading