From a0b107d7f4b792f612341a33b63b54f79200f998 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 16 Nov 2023 08:50:14 -0800
Subject: [PATCH] start using a specific linear attention for channel
 modulation, which should be superior to squeeze excites

---
 README.md                          |  10 +++
 metnet3_pytorch/metnet3_pytorch.py | 109 ++++++++++++++++++++++++++++-
 setup.py                           |   4 +-
 3 files changed, 120 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index e1275b5..7555e9e 100644
--- a/README.md
+++ b/README.md
@@ -125,3 +125,13 @@ surface_preds, hrrr_pred, precipitation_preds = metnet3(
     url = {https://api.semanticscholar.org/CorpusID:259129311}
 }
 ```
+
+```bibtex
+@inproceedings{ElNouby2021XCiTCI,
+    title = {XCiT: Cross-Covariance Image Transformers},
+    author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
+    booktitle = {Neural Information Processing Systems},
+    year = {2021},
+    url = {https://api.semanticscholar.org/CorpusID:235458262}
+}
+```
diff --git a/metnet3_pytorch/metnet3_pytorch.py b/metnet3_pytorch/metnet3_pytorch.py
index 7a0af9c..5d84fa9 100644
--- a/metnet3_pytorch/metnet3_pytorch.py
+++ b/metnet3_pytorch/metnet3_pytorch.py
@@ -1,7 +1,7 @@
 from pathlib import Path
-from contextlib import contextmanager
 from functools import partial
 from collections import namedtuple
+from contextlib import contextmanager
 
 import torch
 from torch import nn, Tensor, einsum
@@ -38,6 +38,11 @@ def cast_tuple(val, length = 1):
 def safe_div(num, den, eps = 1e-10):
     return num / den.clamp(min = eps)
 
+# tensor helpers
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
 # prepare batch norm in maxvit for distributed training
 
 def MaybeSyncBatchnorm2d(is_distributed = None):
@@ -328,6 +333,91 @@ def MBConv(
 
 # attention related classes
 
+class XCAttention(Module):
+    """
+    this specific linear attention was proposed in https://arxiv.org/abs/2106.09681 (El-Nouby et al.)
+    """
+
+    @beartype
+    def __init__(
+        self,
+        *,
+        dim,
+        cond_dim: Optional[int] = None,
+        dim_head = 32,
+        heads = 8,
+        scale = 8,
+        flash = False,
+        dropout = 0.
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+
+        self.has_cond = exists(cond_dim)
+
+        self.film = None
+
+        if self.has_cond:
+            self.film = Sequential(
+                nn.Linear(cond_dim, dim * 2),
+                nn.SiLU(),
+                nn.Linear(dim * 2, dim * 2),
+                Rearrange('b (r d) -> r b 1 d', r = 2)
+            )
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)
+
+        self.to_qkv = Sequential(
+            nn.Linear(dim, dim_inner * 3, bias = False),
+            Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
+        )
+
+        self.scale = scale
+
+        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.to_out = Sequential(
+            Rearrange('b h d n -> b n (h d)'),
+            nn.Linear(dim_inner, dim)
+        )
+
+    def forward(
+        self,
+        x,
+        cond: Optional[Tensor] = None
+    ):
+        x = rearrange(x, 'b c h w -> b h w c')
+        x, ps = pack_one(x, 'b * c')
+
+        x = self.norm(x)
+
+        # conditioning
+
+        if exists(self.film):
+            assert exists(cond)
+
+            gamma, beta = self.film(cond)
+            x = x * gamma + beta
+
+        # cosine sim linear attention
+
+        q, k, v = self.to_qkv(x)
+
+        q, k = map(l2norm, (q, k))
+        q = q * self.temperature.exp()
+
+        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.scale
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h j n -> b h i n', attn, v)
+
+        out = self.to_out(out)
+
+        out = unpack_one(out, ps, 'b * c')
+        return rearrange(out, 'b h w c -> b c h w')
+
 class Attention(Module):
     def __init__(
         self,
@@ -656,6 +746,14 @@ def __init__(
             depth = resnet_block_depth
         )
 
+        self.linear_attn_4km = XCAttention(
+            dim = dim,
+            cond_dim = lead_time_embed_dim,
+            dim_head = attn_dim_head,
+            heads = attn_heads,
+            dropout = attn_dropout
+        )
+
         self.downsample_and_pad_to_8km = Sequential(
             Downsample2x(),
             CenterPad(input_spatial_size)
@@ -697,6 +795,11 @@ def __init__(
             depth = resnet_block_depth
         )
 
+        self.linear_attn_8km = XCAttention(
+            dim = dim,
+            cond_dim = lead_time_embed_dim
+        )
+
         self.upsample_8km_to_4km = Upsample2x(dim)
 
         self.crop_post_4km = CenterCrop(surface_and_hrrr_target_spatial_size)
@@ -863,6 +966,8 @@ def forward(
 
         x = self.resnet_blocks_down_4km(x, cond = cond)
 
+        x = self.linear_attn_4km(x, cond = cond) + x
+
         x = self.downsample_and_pad_to_8km(x)
 
         x = torch.cat((input_4996, x), dim = 1)
@@ -883,6 +988,8 @@ def forward(
 
         x = self.resnet_blocks_up_8km(x, cond = cond)
 
+        x = self.linear_attn_8km(x, cond = cond) + x
+
         x = self.upsample_8km_to_4km(x)
 
         x = torch.cat((skip_connect_4km, x), dim = 1)
diff --git a/setup.py b/setup.py
index 4f6e21b..0b36b63 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'metnet3-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'MetNet 3 - Pytorch',
   author = 'Phil Wang',
@@ -20,7 +20,7 @@
   install_requires=[
     'beartype',
     'einops>=0.7.0',
-    'torch>=1.6',
+    'torch>=2.0',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
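
For context, a minimal sketch of how the new module could be exercised on its own, assuming the patch above is applied and `XCAttention` is imported from `metnet3_pytorch/metnet3_pytorch.py`; the dimensions used here are illustrative, not values taken from the repository:

```python
import torch
from metnet3_pytorch.metnet3_pytorch import XCAttention

# XCiT-style linear attention over the channel dimension,
# FiLM-conditioned on a lead time embedding
attn = XCAttention(
    dim = 512,       # channel dimension of the feature map (illustrative)
    cond_dim = 32,   # lead time embedding dimension (illustrative)
    dim_head = 32,
    heads = 8
)

feats = torch.randn(1, 512, 32, 32)   # (batch, channels, height, width)
lead_time_emb = torch.randn(1, 32)    # conditioning vector

# residual application, mirroring `x = self.linear_attn_4km(x, cond = cond) + x` in the patch
out = attn(feats, cond = lead_time_emb) + feats

assert out.shape == feats.shape
```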