Commit c053ca8
handle both the cross entropy loss as well as the mse loss
lucidrains committed Nov 7, 2023
1 parent 134d839 commit c053ca8
Showing 3 changed files with 79 additions and 11 deletions.
README.md (27 additions & 4 deletions)

@@ -31,27 +31,50 @@ metnet3 = MetNet3(
    sparse_input_2496_channels = 8,
    dense_input_2496_channels = 8,
    dense_input_4996_channels = 8,
-   surface_target_channels = 4,
-   hrrr_target_channels = 4,
-   precipitation_target_channels = 4
+   precipitation_target_channels = 2,
+   surface_target_channels = 6,
+   hrrr_target_channels = 617,
)

lead_times = torch.randint(0, 722, (2,))
sparse_input_2496 = torch.randn((2, 8, 624, 624))
dense_input_2496 = torch.randn((2, 8, 624, 624))
dense_input_4996 = torch.randn((2, 8, 624, 624))

+ precipitation_target = torch.randint(0, 2, (2, 512, 512))
+ surface_target = torch.randint(0, 6, (2, 128, 128))
+ hrrr_target = torch.randn(2, 617, 128, 128)

+ total_loss, loss_breakdown = metnet3(
+     lead_times = lead_times,
+     sparse_input_2496 = sparse_input_2496,
+     dense_input_2496 = dense_input_2496,
+     dense_input_4996 = dense_input_4996,
+     precipitation_target = precipitation_target,
+     surface_target = surface_target,
+     hrrr_target = hrrr_target
+ )
+
+ total_loss.backward()
+
+ # after much training from above, you can predict as follows
+
+ metnet3.eval()
+
surface_target, hrrr_target, precipitation_target = metnet3(
    lead_times = lead_times,
    sparse_input_2496 = sparse_input_2496,
    dense_input_2496 = dense_input_2496,
    dense_input_4996 = dense_input_4996
)

```

## Todo

- - [ ] figure out all the cross entropy and MSE losses, and handle the normalization of MSE losses specifically as mentioned in the paper, utilizing sync batchnorm for tracking running mean and variance
+ - [x] figure out all the cross entropy and MSE losses
+
+ - [ ] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training, as well as allow researcher to pass in their own normalization variables

- [ ] figure out the topological embedding, consult a neural weather researcher

## Citations
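The new normalization todo above calls for tracking a running mean and variance of the HRRR targets. As a rough illustration of that idea, and not code from this commit, per-channel statistics could be kept with an exponential moving average; the `HRRRNorm` name and every detail below are hypothetical:

```python
import torch
from torch import nn

class HRRRNorm(nn.Module):
    # hypothetical sketch: per-channel running statistics for normalizing
    # HRRR targets, in the spirit of the todo item above

    def __init__(self, channels, momentum = 0.01, eps = 1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # buffers, so the statistics persist in the model's state dict
        self.register_buffer('running_mean', torch.zeros(channels))
        self.register_buffer('running_var', torch.ones(channels))

    @torch.no_grad()
    def update_(self, hrrr_target):
        # hrrr_target: (batch, channels, height, width)
        mean = hrrr_target.mean(dim = (0, 2, 3))
        var = hrrr_target.var(dim = (0, 2, 3))
        # exponential moving average towards the batch statistics
        self.running_mean.lerp_(mean, self.momentum)
        self.running_var.lerp_(var, self.momentum)

    def normalize(self, hrrr):
        mean = self.running_mean.view(1, -1, 1, 1)
        std = (self.running_var.view(1, -1, 1, 1) + self.eps).sqrt()
        return (hrrr - mean) / std
```

In distributed training the batch statistics would additionally need to be synchronized across workers, which is what the earlier todo's mention of sync batchnorm was getting at.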
metnet3_pytorch/metnet3_pytorch.py (51 additions & 6 deletions)

@@ -200,7 +200,6 @@ def forward(self, x, cond = None):
        for block in self.blocks:
            x = block(x, cond = cond)

-       print(x.shape)
        return x

# multi-headed rms normalization, for query / key normalized attention

@@ -582,9 +581,9 @@ def __init__(
        dense_input_2496_channels = 8,
        dense_input_4996_channels = 8,
        surface_and_hrrr_target_spatial_size = 128,
-       surface_target_channels = 4,
-       hrrr_target_channels = 4,
-       precipitation_target_channels = 4,
+       surface_target_channels = 6,
+       hrrr_target_channels = 617,
+       precipitation_target_channels = 2,
        crop_size_post_16km = 48,
        resnet_block_depth = 2,
    ):
@@ -595,6 +594,10 @@

        self.surface_and_hrrr_target_spatial_size = surface_and_hrrr_target_spatial_size

+       self.surface_target_shape = ((self.surface_and_hrrr_target_spatial_size,) * 2)
+       self.hrrr_target_shape = (hrrr_target_channels, *self.surface_target_shape)
+       self.precipitation_target_shape = (surface_and_hrrr_target_spatial_size * 4,) * 2
+
        self.lead_time_embedding = nn.Embedding(num_lead_times, lead_time_embed_dim)

        dim_in_4km = sparse_input_2496_channels + dense_input_2496_channels

@@ -683,6 +686,8 @@ def __init__(
            nn.Conv2d(dim, precipitation_target_channels, 1)
        )

+       self.mse_loss_scaler = LossScaler()

    def forward(
        self,
        lead_times,
@@ -693,7 +698,9 @@ def forward(
        hrrr_target = None,
        precipitation_target = None
    ):
-       assert lead_times.shape[0] == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same'
+       batch = lead_times.shape[0]
+
+       assert batch == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same'

        assert sparse_input_2496.shape[1:] == self.sparse_input_2496_shape
        assert dense_input_2496.shape[1:] == self.dense_input_2496_shape
@@ -744,4 +751,42 @@ def forward(

        precipitation_pred = self.to_precipitation_pred(x)

-       return Predictions(surface_pred, hrrr_pred, precipitation_pred)
+       exist_targets = [exists(target) for target in (surface_target, hrrr_target, precipitation_target)]
+
+       pred = Predictions(surface_pred, hrrr_pred, precipitation_pred)
+
+       if not any(exist_targets):
+           return pred
+
+       assert all(exist_targets), 'all targets must be passed in for loss calculation'
+
+       assert batch == surface_target.shape[0] == hrrr_target.shape[0] == precipitation_target.shape[0]
+
+       assert surface_target.shape[1:] == self.surface_target_shape
+       assert hrrr_target.shape[1:] == self.hrrr_target_shape
+       assert precipitation_target.shape[1:] == self.precipitation_target_shape
+
+       # calculate categorical losses
+
+       surface_pred = rearrange(surface_pred, '... h w -> ... (h w)')
+       precipitation_pred = rearrange(precipitation_pred, '... h w -> ... (h w)')
+
+       surface_target = rearrange(surface_target, '... h w -> ... (h w)')
+       precipitation_target = rearrange(precipitation_target, '... h w -> ... (h w)')
+
+       surface_loss = F.cross_entropy(surface_pred, surface_target)
+       precipitation_loss = F.cross_entropy(precipitation_pred, precipitation_target)
+
+       # calculate HRRR mse loss
+
+       hrrr_pred = self.mse_loss_scaler(hrrr_pred)
+
+       hrrr_loss = F.mse_loss(hrrr_pred, hrrr_target)
+
+       # total loss
+
+       total_loss = hrrr_loss + precipitation_loss + surface_loss
+
+       loss_breakdown = LossBreakdown(surface_loss, hrrr_loss, precipitation_loss)
+
+       return total_loss, loss_breakdown
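The `LossScaler` applied to `hrrr_pred` above is defined outside the shown hunks. As a rough idea of the general technique, and not the repository's actual code, a loss scaler of this kind can be an identity on the forward pass whose backward pass reweights per-channel gradients, so that the MSE over hundreds of HRRR channels is not dominated by a few high-variance channels. Everything below, including the class names, is an assumption:

```python
import torch
from torch import nn
from torch.autograd import Function

class LossScaleFunction(Function):
    # identity on the forward pass; the backward pass rebalances gradients
    # so each channel contributes a comparable share to the MSE loss

    @staticmethod
    def forward(ctx, x, eps):
        ctx.eps = eps
        return x

    @staticmethod
    def backward(ctx, grads):
        # grads: (batch, channels, height, width)
        num_channels = grads.shape[1]

        # inverse of the per-channel gradient norm, guarded against zeros
        norms = grads.norm(p = 2, dim = (-1, -2), keepdim = True)
        weight = 1. / norms.clamp(min = ctx.eps)

        # l1-normalize over channels, then rescale to preserve overall magnitude
        weight = weight / weight.sum(dim = 1, keepdim = True).clamp(min = ctx.eps)
        return num_channels * weight * grads, None

class LossScaler(nn.Module):
    def __init__(self, eps = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return LossScaleFunction.apply(x, self.eps)
```

With an API of this shape, `hrrr_pred = self.mse_loss_scaler(hrrr_pred)` leaves the predictions numerically unchanged and only alters how gradients flow back from `F.mse_loss`.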
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
  name = 'metnet3-pytorch',
  packages = find_packages(exclude=[]),
- version = '0.0.1',
+ version = '0.0.2',
  license='MIT',
  description = 'MetNet 3 - Pytorch',
  author = 'Phil Wang',
