
Commit

first pass at inputs, redo hrrr
lucidrains committed Nov 12, 2023
1 parent 2e86e5b commit 98f99ba
Showing 3 changed files with 74 additions and 60 deletions.
35 changes: 19 additions & 16 deletions README.md
@@ -28,9 +28,9 @@ metnet3 = MetNet3(
lead_time_embed_dim = 32,
input_spatial_size = 624,
attn_dim_head = 8,
- sparse_input_2496_channels = 8,
- dense_input_2496_channels = 8,
- dense_input_4996_channels = 8,
+ hrrr_channels = 617,
+ input_2496_channels = 2 + 14 + 1 + 2 + 20,
+ input_4996_channels = 16 + 1,
precipitation_target_bins = dict(
mrms_rate = 512,
mrms_accumulation = 512,
@@ -43,16 +43,16 @@ metnet3 = MetNet3(
omo_wind_component_y = 256,
omo_wind_direction = 180
),
- hrrr_loss_weight = 10,
- hrrr_target_channels = 617
+ hrrr_loss_weight = 10
)

# inputs

lead_times = torch.randint(0, 722, (2,))
- sparse_input_2496 = torch.randn((2, 8, 624, 624))
- dense_input_2496 = torch.randn((2, 8, 624, 624))
- dense_input_4996 = torch.randn((2, 8, 624, 624))
+ hrrr_input_2496 = torch.randn((2, 617, 624, 624))
+ hrrr_stale_state = torch.randn((2, 1, 624, 624))
+ input_2496 = torch.randn((2, 39, 624, 624))
+ input_4996 = torch.randn((2, 17, 624, 624))

# targets

@@ -74,9 +74,10 @@ hrrr_target = torch.randn(2, 617, 128, 128)

total_loss, loss_breakdown = metnet3(
lead_times = lead_times,
- sparse_input_2496 = sparse_input_2496,
- dense_input_2496 = dense_input_2496,
- dense_input_4996 = dense_input_4996,
+ hrrr_input_2496 = hrrr_input_2496,
+ hrrr_stale_state = hrrr_stale_state,
+ input_2496 = input_2496,
+ input_4996 = input_4996,
precipitation_targets = precipitation_targets,
surface_targets = surface_targets,
hrrr_target = hrrr_target
@@ -88,13 +89,15 @@ total_loss.backward()

metnet3.eval()

- surface_targets, hrrr_target, precipitation_targets = metnet3(
+ surface_preds, hrrr_pred, precipitation_preds = metnet3(
lead_times = lead_times,
- sparse_input_2496 = sparse_input_2496,
- dense_input_2496 = dense_input_2496,
- dense_input_4996 = dense_input_4996
+ hrrr_input_2496 = hrrr_input_2496,
+ hrrr_stale_state = hrrr_stale_state,
+ input_2496 = input_2496,
+ input_4996 = input_4996,
)


+ # Dict[str, Tensor], Tensor, Dict[str, Tensor]
```
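
As a side note (not part of this commit or the library API): if the returned HRRR prediction is still in normalized space, it can be mapped back to physical units by inverting the `(x - mean) * inv_std` normalization used in `metnet3_pytorch.py`. A minimal sketch, assuming per-channel mean and variance stacked into a `(2, 617)` tensor in the same layout the `hrrr_norm_statistics` argument expects; the `unnormalize_hrrr` helper and the statistics below are hypothetical:

```python
import torch

def unnormalize_hrrr(normed_hrrr_pred, hrrr_norm_statistics, eps = 1e-5):
    # hrrr_norm_statistics: (2, channels) tensor holding per-channel mean and variance
    mean, variance = hrrr_norm_statistics
    std = variance.clamp(min = eps).sqrt().view(1, -1, 1, 1)
    mean = mean.view(1, -1, 1, 1)
    # invert the (x - mean) * inv_std normalization used inside the model
    return normed_hrrr_pred * std + mean

# hypothetical statistics computed offline from the HRRR training set
hrrr_norm_statistics = torch.stack((torch.zeros(617), torch.ones(617)))

hrrr_pred_physical = unnormalize_hrrr(hrrr_pred, hrrr_norm_statistics)
```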

@@ -103,8 +106,8 @@ surface_targets, hrrr_target, precipitation_targets = metnet3(
- [x] figure out all the cross entropy and MSE losses
- [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
- [x] allow researcher to pass in their own normalization variables for HRRR
+ - [x] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions

- - [ ] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
- [ ] make sure model can be easily saved and loaded, with different ways of handling hrrr norm
- [ ] figure out the topological embedding, consult a neural weather researcher

97 changes: 54 additions & 43 deletions metnet3_pytorch/metnet3_pytorch.py
@@ -594,9 +594,8 @@ def __init__(
vit_window_size = 8,
vit_mbconv_expansion_rate = 4,
vit_mbconv_shrinkage_rate = 0.25,
- sparse_input_2496_channels = 8,
- dense_input_2496_channels = 8,
- dense_input_4996_channels = 8,
+ input_2496_channels = 2 + 14 + 1 + 2 + 20,
+ input_4996_channels = 16 + 1,
surface_and_hrrr_target_spatial_size = 128,
precipitation_target_bins: Dict[str, int] = dict(
mrms_rate = 512,
@@ -610,26 +609,26 @@ def __init__(
omo_wind_component_y = 256,
omo_wind_direction = 180
),
- hrrr_target_channels = 617,
+ hrrr_channels = 617,
hrrr_norm_statistics: Optional[Tensor] = None,
hrrr_loss_weight = 10,
crop_size_post_16km = 48,
resnet_block_depth = 2,
):
super().__init__()
- self.sparse_input_2496_shape = (sparse_input_2496_channels, input_spatial_size, input_spatial_size)
- self.dense_input_2496_shape = (dense_input_2496_channels, input_spatial_size, input_spatial_size)
- self.dense_input_4996_shape = (dense_input_4996_channels, input_spatial_size, input_spatial_size)
+ self.hrrr_input_2496_shape = (hrrr_channels, input_spatial_size, input_spatial_size)
+ self.input_2496_shape = (input_2496_channels, input_spatial_size, input_spatial_size)
+ self.input_4996_shape = (input_4996_channels, input_spatial_size, input_spatial_size)

self.surface_and_hrrr_target_spatial_size = surface_and_hrrr_target_spatial_size

self.surface_target_shape = ((self.surface_and_hrrr_target_spatial_size,) * 2)
- self.hrrr_target_shape = (hrrr_target_channels, *self.surface_target_shape)
+ self.hrrr_target_shape = (hrrr_channels, *self.surface_target_shape)
self.precipitation_target_shape = (surface_and_hrrr_target_spatial_size * 4,) * 2

self.lead_time_embedding = nn.Embedding(num_lead_times, lead_time_embed_dim)

- dim_in_4km = sparse_input_2496_channels + dense_input_2496_channels
+ dim_in_4km = hrrr_channels + input_2496_channels + 1

self.to_skip_connect_4km = CenterCrop(crop_size_post_16km * 4)

@@ -645,7 +644,7 @@ def __init__(
CenterPad(input_spatial_size)
)

- dim_in_8km = dense_input_4996_channels + dim
+ dim_in_8km = input_4996_channels + dim

self.resnet_blocks_down_8km = ResnetBlocks(
dim = dim,
@@ -726,53 +725,79 @@ def __init__(

self.to_hrrr_pred = Sequential(
ChanLayerNorm(dim),
- nn.Conv2d(dim, hrrr_target_channels, 1)
+ nn.Conv2d(dim, hrrr_channels, 1)
)

# they scale hrrr loss by 10. but also divided by number of channels

- self.hrrr_loss_weight = hrrr_loss_weight / hrrr_target_channels
+ self.hrrr_loss_weight = hrrr_loss_weight / hrrr_channels

self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)

if self.has_hrrr_norm_statistics:
- assert hrrr_norm_statistics.shape == (2, hrrr_target_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
+ assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {hrrr_channels}), containing mean and variance of each target calculated from the dataset'
self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
else:
- self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)
+ self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_channels, affine = False)

self.mse_loss_scaler = LossScaler()

@beartype
def forward(
self,
*,
lead_times,
- sparse_input_2496,
- dense_input_2496,
- dense_input_4996,
+ hrrr_input_2496,
+ hrrr_stale_state,
+ input_2496,
+ input_4996,
surface_targets: Optional[Dict[str, Tensor]] = None,
precipitation_targets: Optional[Dict[str, Tensor]] = None,
hrrr_target: Optional[Tensor] = None,
):
batch = lead_times.shape[0]

- assert batch == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same'
+ assert batch == hrrr_input_2496.shape[0] == input_2496.shape[0] == input_4996.shape[0], 'batch size across all inputs must be the same'

- assert sparse_input_2496.shape[1:] == self.sparse_input_2496_shape
- assert dense_input_2496.shape[1:] == self.dense_input_2496_shape
- assert dense_input_4996.shape[1:] == self.dense_input_4996_shape
+ assert hrrr_input_2496.shape[1:] == self.hrrr_input_2496_shape
+ assert input_2496.shape[1:] == self.input_2496_shape
+ assert input_4996.shape[1:] == self.input_4996_shape

+ # normalize HRRR input and target as needed

+ if self.has_hrrr_norm_statistics:
+ mean, variance = self.hrrr_norm_statistics
+ mean = rearrange(mean, 'c -> c 1 1')
+ variance = rearrange(variance, 'c -> c 1 1')
+ inv_std = variance.clamp(min = 1e-5).rsqrt()

+ normed_hrrr_input = (hrrr_input_2496 - mean) * inv_std

+ if exists(hrrr_target):
+ normed_hrrr_target = (hrrr_target - mean) * inv_std

+ else:
+ # use a batchnorm to normalize each channel to mean zero and unit variance

+ with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
+ normed_hrrr_input = frozen_batchnorm(hrrr_input_2496)

+ if exists(hrrr_target):
+ normed_hrrr_target = frozen_batchnorm(hrrr_target)

# main network

cond = self.lead_time_embedding(lead_times)

- x = torch.cat((sparse_input_2496, dense_input_2496), dim = 1)
+ x = torch.cat((normed_hrrr_input, hrrr_stale_state, input_2496), dim = 1)

skip_connect_4km = self.to_skip_connect_4km(x)

x = self.resnet_blocks_down_4km(x, cond = cond)

x = self.downsample_and_pad_to_8km(x)

- x = torch.cat((dense_input_4996, x), dim = 1)
+ x = torch.cat((input_4996, x), dim = 1)

skip_connect_8km = self.to_skip_connect_8km(x)
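
Two notes on the new forward pass above, outside the diff itself. First, the 4 km concatenation stacks the 617 normalized HRRR channels, the single-channel `hrrr_stale_state` and the 39 channels of `input_2496`, which is exactly the `dim_in_4km = hrrr_channels + input_2496_channels + 1` (617 + 39 + 1 = 657) set in the constructor. Second, `freeze_batchnorm` is defined elsewhere in `metnet3_pytorch.py` and is not shown in this diff; a minimal sketch of the behaviour the code relies on (normalizing with the batchnorm's running statistics without updating them) might look like the following, though the actual helper may differ:

```python
from contextlib import contextmanager
import torch.nn as nn

@contextmanager
def freeze_batchnorm(bn: nn.Module):
    # temporarily switch to eval mode so the batchnorm normalizes with its
    # running statistics and does not update them
    was_training = bn.training
    bn.eval()
    try:
        yield bn
    finally:
        bn.train(was_training)
```

With this behaviour the running mean and variance only move when `self.batchnorm_hrrr(hrrr_target)` is called explicitly during training, which happens further down in the loss computation.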

@@ -866,30 +891,16 @@ def forward(
ce_losses = ce_losses + precipition_loss

# calculate HRRR mse loss
+ # proposed loss gradient rescaler from section 4.3.2

- if self.has_hrrr_norm_statistics:
- mean, variance = self.hrrr_norm_statistics
- mean = rearrange(mean, 'c -> c 1 1')
- variance = rearrange(variance, 'c -> c 1 1')
- inv_std = variance.clamp(min = 1e-5).rsqrt()

- normed_hrrr_target = (hrrr_target - mean) * inv_std
- normed_hrrr_pred = (hrrr_pred - mean) * inv_std
- else:
- # use a batchnorm to normalize each channel to mean zero and unit variance

- if self.training:
- _ = self.batchnorm_hrrr(hrrr_target)

- with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
- normed_hrrr_pred = frozen_batchnorm(hrrr_pred)
- normed_hrrr_target = frozen_batchnorm(hrrr_target)
+ hrrr_pred = self.mse_loss_scaler(hrrr_pred)

- # proposed loss gradient rescaler from section 4.3.2
+ hrrr_loss = F.mse_loss(hrrr_pred, normed_hrrr_target)

- normed_hrrr_pred = self.mse_loss_scaler(normed_hrrr_pred)
+ # update hrrr normalization statistics, if using batchnorm way

- hrrr_loss = F.mse_loss(normed_hrrr_pred, normed_hrrr_target)
+ if not self.has_hrrr_norm_statistics and self.training:
+ _ = self.batchnorm_hrrr(hrrr_target)

# total loss

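`self.mse_loss_scaler` is the `LossScaler` instantiated in `__init__`; its definition lives elsewhere in the file and is not part of this diff, and per the comment in the hunk above it implements the loss gradient rescaler from section 4.3.2 of the MetNet-3 paper. As a generic illustration only — the actual rescaling rule may differ — the identity-forward / rescale-backward pattern such a module relies on can be written as a small autograd function:

```python
import torch
from torch.autograd import Function

class GradScale(Function):
    # identity in the forward pass; gradient multiplied by `scale` in the backward pass
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

# hypothetical usage: the loss value is unchanged, but the gradient flowing
# back into the prediction is scaled by the given factor
pred = torch.randn(2, 617, 128, 128, requires_grad = True)
scaled_pred = GradScale.apply(pred, 0.1)
```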
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'metnet3-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.8',
+ version = '0.0.9',
license='MIT',
description = 'MetNet 3 - Pytorch',
author = 'Phil Wang',
