Skip to content

Commit

Permalink
Add isolatiude experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
hlinander committed Nov 5, 2024
1 parent 3d77aed commit f29a802
Show file tree
Hide file tree
Showing 16 changed files with 2,618 additions and 18 deletions.
4 changes: 2 additions & 2 deletions experiments/weather/models/hp_windowing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch


def window_partition(x: torch.Tensor, window_size):
def window_partition(x: torch.Tensor, window_size, device):
"""
Args:
x: (B, D, N, C)
Expand All @@ -28,7 +28,7 @@ def window_partition(x: torch.Tensor, window_size):
return windows


def window_reverse(windows, window_size, D, N):
def window_reverse(windows, window_size, D, N, device):
"""
Args:
windows: (num_windows*B, window_size, C)
Expand Down
10 changes: 5 additions & 5 deletions experiments/weather/models/hp_windowing_isolatitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_pad_windows(nside):
assert (new - data_pre).sum() == 0.0


def window_reverse(windows, window_size, D, N):
def window_reverse(windows, window_size, D, N, device):
window_size_d, window_size_hp = window_size
nside = healpix.npix2nside(N)
C = windows.shape[-1]
Expand All @@ -99,7 +99,7 @@ def window_reverse(windows, window_size, D, N):
interspersed = flattened_interspersed(nside, window_size_hp, hp_windows)
padded_windows = pad_windows(window_size_hp, interspersed)

indices = torch.tensor(padded_windows)
indices = torch.tensor(padded_windows, device=device)

Nw, W = indices.shape

Expand All @@ -111,21 +111,21 @@ def window_reverse(windows, window_size, D, N):
# 0 1 2 3 4 5
x = x.contiguous().view(B, D, Nw, W, C)

new = torch.zeros((B, D, N, C))
new = torch.zeros((B, D, N, C), device=device)
new[:, :, indices, :] = x

return new


def window_partition(x: torch.Tensor, window_size):
def window_partition(x: torch.Tensor, window_size, device):
window_size_d, window_size_hp = window_size

nside = healpix.npix2nside(x.shape[2])
hp_windows = get_isolatitude_windows_hp(nside)
interspersed = flattened_interspersed(nside, window_size_hp, hp_windows)
padded_windows = pad_windows(window_size_hp, interspersed)

indices = torch.tensor(padded_windows)
indices = torch.tensor(padded_windows, device=device)
windowed = x[:, :, indices, :]

B, D, Nw, W, C = windowed.shape
Expand Down
16 changes: 13 additions & 3 deletions experiments/weather/models/swin_hp_pangu.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,11 @@ def __init__(
else:
self.shifter = hp_shifting.NoShift()

attn_mask = self.shifter.get_mask()
attn_mask = self.shifter.get_mask(
lambda x, window_size: window_partition(
x, window_size, device=next(self.parameters()).device
)
)

self.register_buffer("attn_mask", attn_mask)

Expand Down Expand Up @@ -334,13 +338,19 @@ def forward(self, x):

# partition windows
x_windows = window_partition(
shifted_x, self.window_size
# shifted_x, self.window_size
shifted_x,
self.window_size,
device=next(self.parameters()).device,
) # nW*B, window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size, C

# merge windows
shifted_x = window_reverse(attn_windows, self.window_size, D, N) # B N' C
# shifted_x = window_reverse(attn_windows, self.window_size, D, N) # B N' C
shifted_x = window_reverse(
attn_windows, self.window_size, D, N, device=next(self.parameters()).device
) # B N' C

# reverse cyclic shift
x = self.shifter.shift_back(shifted_x)
Expand Down
12 changes: 9 additions & 3 deletions experiments/weather/models/swin_hp_pangu_isolatitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ def __init__(
else:
self.shifter = hp_shifting.NoShift()

attn_mask = self.shifter.get_mask(window_partition)
attn_mask = self.shifter.get_mask(
lambda x, window_size: window_partition(
x, window_size, device=next(self.parameters()).device
)
)

self.register_buffer("attn_mask", attn_mask)

Expand Down Expand Up @@ -333,13 +337,15 @@ def forward(self, x):

# partition windows
x_windows = window_partition(
shifted_x, self.window_size
shifted_x, self.window_size, device=next(self.parameters()).device
) # nW*B, window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size, C

# merge windows
shifted_x = window_reverse(attn_windows, self.window_size, D, N) # B N' C
shifted_x = window_reverse(
attn_windows, self.window_size, D, N, device=next(self.parameters()).device
) # B N' C

# reverse cyclic shift
x = self.shifter.shift_back(shifted_x)
Expand Down
Loading

0 comments on commit f29a802

Please sign in to comment.