diff --git a/README.md b/README.md index 52f51ae..fbf0474 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ metnet3 = MetNet3( sparse_input_2496_channels = 8, dense_input_2496_channels = 8, dense_input_4996_channels = 8, - surface_target_channels = 4, - hrrr_target_channels = 4, - precipitation_target_channels = 4 + precipitation_target_channels = 2, + surface_target_channels = 6, + hrrr_target_channels = 617, ) lead_times = torch.randint(0, 722, (2,)) @@ -41,17 +41,40 @@ sparse_input_2496 = torch.randn((2, 8, 624, 624)) dense_input_2496 = torch.randn((2, 8, 624, 624)) dense_input_4996 = torch.randn((2, 8, 624, 624)) +precipitation_target = torch.randint(0, 2, (2, 512, 512)) +surface_target = torch.randint(0, 6, (2, 128, 128)) +hrrr_target = torch.randn(2, 617, 128, 128) + +total_loss, loss_breakdown = metnet3( + lead_times = lead_times, + sparse_input_2496 = sparse_input_2496, + dense_input_2496 = dense_input_2496, + dense_input_4996 = dense_input_4996, + precipitation_target = precipitation_target, + surface_target = surface_target, + hrrr_target = hrrr_target +) + +total_loss.backward() + +# after much training from above, you can predict as follows + +metnet3.eval() + surface_target, hrrr_target, precipitation_target = metnet3( lead_times = lead_times, sparse_input_2496 = sparse_input_2496, dense_input_2496 = dense_input_2496, dense_input_4996 = dense_input_4996 ) + ``` ## Todo -- [ ] figure out all the cross entropy and MSE losses, and handle the normalization of MSE losses specifically as mentioned in the paper, utilizing sync batchnorm for tracking running mean and variance +- [x] figure out all the cross entropy and MSE losses + +- [ ] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training, as well as allow researcher to pass in their own normalization variables - [ ] figure out the topological embedding, consult a neural weather researcher ## Citations diff --git 
a/metnet3_pytorch/metnet3_pytorch.py b/metnet3_pytorch/metnet3_pytorch.py index c569cff..a3a099a 100644 --- a/metnet3_pytorch/metnet3_pytorch.py +++ b/metnet3_pytorch/metnet3_pytorch.py @@ -200,7 +200,6 @@ def forward(self, x, cond = None): for block in self.blocks: x = block(x, cond = cond) - print(x.shape) return x # multi-headed rms normalization, for query / key normalized attention @@ -582,9 +581,9 @@ def __init__( dense_input_2496_channels = 8, dense_input_4996_channels = 8, surface_and_hrrr_target_spatial_size = 128, - surface_target_channels = 4, - hrrr_target_channels = 4, - precipitation_target_channels = 4, + surface_target_channels = 6, + hrrr_target_channels = 617, + precipitation_target_channels = 2, crop_size_post_16km = 48, resnet_block_depth = 2, ): @@ -595,6 +594,10 @@ def __init__( self.surface_and_hrrr_target_spatial_size = surface_and_hrrr_target_spatial_size + self.surface_target_shape = ((self.surface_and_hrrr_target_spatial_size,) * 2) + self.hrrr_target_shape = (hrrr_target_channels, *self.surface_target_shape) + self.precipitation_target_shape = (surface_and_hrrr_target_spatial_size * 4,) * 2 + self.lead_time_embedding = nn.Embedding(num_lead_times, lead_time_embed_dim) dim_in_4km = sparse_input_2496_channels + dense_input_2496_channels @@ -683,6 +686,8 @@ def __init__( nn.Conv2d(dim, precipitation_target_channels, 1) ) + self.mse_loss_scaler = LossScaler() + def forward( self, lead_times, @@ -693,7 +698,9 @@ def forward( hrrr_target = None, precipitation_target = None ): - assert lead_times.shape[0] == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same' + batch = lead_times.shape[0] + + assert batch == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same' assert sparse_input_2496.shape[1:] == self.sparse_input_2496_shape assert dense_input_2496.shape[1:] == 
self.dense_input_2496_shape @@ -744,4 +751,42 @@ def forward( precipitation_pred = self.to_precipitation_pred(x) - return Predictions(surface_pred, hrrr_pred, precipitation_pred) + exist_targets = [exists(target) for target in (surface_target, hrrr_target, precipitation_target)] + + pred = Predictions(surface_pred, hrrr_pred, precipitation_pred) + + if not any(exist_targets): + return pred + + assert all(exist_targets), 'all targets must be passed in for loss calculation' + + assert batch == surface_target.shape[0] == hrrr_target.shape[0] == precipitation_target.shape[0] + + assert surface_target.shape[1:] == self.surface_target_shape + assert hrrr_target.shape[1:] == self.hrrr_target_shape + assert precipitation_target.shape[1:] == self.precipitation_target_shape + + # calculate categorical losses + + surface_pred = rearrange(surface_pred, '... h w -> ... (h w)') + precipitation_pred = rearrange(precipitation_pred, '... h w -> ... (h w)') + + surface_target = rearrange(surface_target, '... h w -> ... (h w)') + precipitation_target = rearrange(precipitation_target, '... h w -> ... (h w)') + + surface_loss = F.cross_entropy(surface_pred, surface_target) + precipitation_loss = F.cross_entropy(precipitation_pred, precipitation_target) + + # calculate HRRR mse loss + + hrrr_pred = self.mse_loss_scaler(hrrr_pred) + + hrrr_loss = F.mse_loss(hrrr_pred, hrrr_target) + + # total loss + + total_loss = hrrr_loss + precipitation_loss + surface_loss + + loss_breakdown = LossBreakdown(surface_loss, hrrr_loss, precipitation_loss) + + return total_loss, loss_breakdown diff --git a/setup.py b/setup.py index 7a43dc8..ccfd345 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'metnet3-pytorch', packages = find_packages(exclude=[]), - version = '0.0.1', + version = '0.0.2', license='MIT', description = 'MetNet 3 - Pytorch', author = 'Phil Wang',