start using a specific linear attention for channel modulation, which should be superior to squeeze excites
lucidrains committed Nov 16, 2023
1 parent bf63284 commit a0b107d
Showing 3 changed files with 120 additions and 3 deletions.
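
As context for the commit message above: a squeeze-excitation block gates each channel with a single pooled scalar, while XCA-style (cross-covariance) attention forms a channels-by-channels attention map from L2-normalized queries and keys, so the channel modulation stays spatially informed and data-dependent. The sketch below is illustrative only and not part of this diff; it is a single-head simplification with made-up module names and shapes.

```python
import torch
from torch import nn
import torch.nn.functional as F

class SqueezeExcite(nn.Module):
    # classic SE block: global pool -> bottleneck MLP -> per-channel sigmoid gate
    def __init__(self, dim, reduction = 4):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(dim, dim // reduction),
            nn.SiLU(),
            nn.Linear(dim // reduction, dim),
            nn.Sigmoid()
        )

    def forward(self, x):                                       # x: (b, c, h, w)
        scale = self.gate(x.mean(dim = (-1, -2)))                # (b, c)
        return x * scale[..., None, None]                        # one scalar gate per channel

class ChannelCrossCovAttention(nn.Module):
    # single-head XCA-style attention: softmax over a (c x c) cross-covariance map
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(dim, dim, 1)

    def forward(self, x):                                       # x: (b, c, h, w)
        b, c, h, w = x.shape
        q, k, v = self.to_qkv(x).flatten(2).chunk(3, dim = 1)   # each (b, c, h * w)
        q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
        attn = (q @ k.transpose(-1, -2)).softmax(dim = -1)      # (b, c, c) channel-to-channel map
        out = (attn @ v).reshape(b, c, h, w)
        return self.to_out(out)

x = torch.randn(1, 32, 16, 16)
print(SqueezeExcite(32)(x).shape, ChannelCrossCovAttention(32)(x).shape)   # both (1, 32, 16, 16)
```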
10 changes: 10 additions & 0 deletions README.md
@@ -125,3 +125,13 @@ surface_preds, hrrr_pred, precipitation_preds = metnet3(
    url = {https://api.semanticscholar.org/CorpusID:259129311}
}
```

```bibtex
@inproceedings{ElNouby2021XCiTCI,
    title = {XCiT: Cross-Covariance Image Transformers},
    author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year = {2021},
    url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```
109 changes: 108 additions & 1 deletion metnet3_pytorch/metnet3_pytorch.py
@@ -1,7 +1,7 @@
from pathlib import Path
from contextlib import contextmanager
from functools import partial
from collections import namedtuple
from contextlib import contextmanager

import torch
from torch import nn, Tensor, einsum
@@ -38,6 +38,11 @@ def cast_tuple(val, length = 1):
def safe_div(num, den, eps = 1e-10):
    return num / den.clamp(min = eps)

# tensor helpers

def l2norm(t):
    return F.normalize(t, dim = -1)

# prepare batch norm in maxvit for distributed training

def MaybeSyncBatchnorm2d(is_distributed = None):
@@ -328,6 +333,91 @@ def MBConv(

# attention related classes

class XCAttention(Module):
    """
    this specific linear attention was proposed in https://arxiv.org/abs/2106.09681 (El-Nouby et al.)
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        cond_dim: Optional[int] = None,
        dim_head = 32,
        heads = 8,
        scale = 8,
        flash = False,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.has_cond = exists(cond_dim)

        self.film = None

        if self.has_cond:
            self.film = Sequential(
                nn.Linear(cond_dim, dim * 2),
                nn.SiLU(),
                nn.Linear(dim * 2, dim * 2),
                Rearrange('b (r d) -> r b 1 d', r = 2)
            )

        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)

        self.to_qkv = Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
        )

        self.scale = scale

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attn_dropout = nn.Dropout(dropout)

        self.to_out = Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim)
        )

    def forward(
        self,
        x,
        cond: Optional[Tensor] = None
    ):
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack_one(x, 'b * c')

        x = self.norm(x)

        # conditioning

        if exists(self.film):
            assert exists(cond)

            gamma, beta = self.film(cond)
            x = x * gamma + beta

        # cosine sim linear attention

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j n -> b h i n', attn, v)

        out = self.to_out(out)

        out = unpack_one(out, ps, 'b * c')
        return rearrange(out, 'b h w c -> b c h w')

class Attention(Module):
    def __init__(
        self,
@@ -656,6 +746,14 @@ def __init__(
            depth = resnet_block_depth
        )

        self.linear_attn_4km = XCAttention(
            dim = dim,
            cond_dim = lead_time_embed_dim,
            dim_head = attn_dim_head,
            heads = attn_heads,
            dropout = attn_dropout
        )

        self.downsample_and_pad_to_8km = Sequential(
            Downsample2x(),
            CenterPad(input_spatial_size)
@@ -697,6 +795,11 @@ def __init__(
            depth = resnet_block_depth
        )

        self.linear_attn_8km = XCAttention(
            dim = dim,
            cond_dim = lead_time_embed_dim
        )

        self.upsample_8km_to_4km = Upsample2x(dim)

        self.crop_post_4km = CenterCrop(surface_and_hrrr_target_spatial_size)
@@ -863,6 +966,8 @@ def forward(

        x = self.resnet_blocks_down_4km(x, cond = cond)

        x = self.linear_attn_4km(x, cond = cond) + x

        x = self.downsample_and_pad_to_8km(x)

        x = torch.cat((input_4996, x), dim = 1)
@@ -883,6 +988,8 @@ def forward(

        x = self.resnet_blocks_up_8km(x, cond = cond)

        x = self.linear_attn_8km(x, cond = cond) + x

        x = self.upsample_8km_to_4km(x)

        x = torch.cat((skip_connect_4km, x), dim = 1)
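
A quick, illustrative shape check of the new module, applied residually the same way the forward pass above wires it in. The import path mirrors this repository's file, but the dimensions and the `lead_time_cond` tensor are arbitrary assumptions, not values taken from the model.

```python
import torch
from metnet3_pytorch.metnet3_pytorch import XCAttention

attn = XCAttention(
    dim = 512,
    cond_dim = 32,        # stands in for the lead time embedding dimension
    dim_head = 32,
    heads = 8
)

feats = torch.randn(2, 512, 32, 32)    # (batch, channels, height, width) feature map
lead_time_cond = torch.randn(2, 32)    # conditioning vector consumed by the FiLM layers

out = attn(feats, cond = lead_time_cond) + feats   # residual application, as in the forward pass above
assert out.shape == feats.shape
```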
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'metnet3-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.11',
  version = '0.0.12',
  license='MIT',
  description = 'MetNet 3 - Pytorch',
  author = 'Phil Wang',
@@ -20,7 +20,7 @@
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'torch>=1.6',
    'torch>=2.0',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
