redo the hrrr norm strategy to be completely flexible
lucidrains committed Nov 13, 2023
1 parent 98f99ba commit bf63284
Showing 3 changed files with 84 additions and 14 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
<img src="./metnet3.png" width="450px"></img>

-## MetNet-3 - Pytorch (wip)
+## MetNet-3 - Pytorch

Implementation of <a href="https://blog.research.google/2023/11/metnet-3-state-of-art-neural-weather.html">MetNet 3</a>, SOTA neural weather model out of Google Deepmind, in Pytorch

@@ -43,7 +43,9 @@ metnet3 = MetNet3(
omo_wind_component_y = 256,
omo_wind_direction = 180
),
-    hrrr_loss_weight = 10
+    hrrr_loss_weight = 10,
+    hrrr_norm_strategy = 'sync_batchnorm',  # uses a sync batchnorm to normalize the input hrrr and the target, without having to precalculate the per-channel mean and variance of the hrrr dataset
+    hrrr_norm_statistics = None             # alternatively, set `hrrr_norm_strategy = 'precalculated'` and pass the per-channel mean and variance as a tensor of shape (2, 617) through this keyword argument
)

# inputs
@@ -107,8 +109,8 @@ surface_preds, hrrr_pred, precipitation_preds = metnet3(
- [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as a hack)
- [x] allow researcher to pass in their own normalization variables for HRRR
- [x] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
+- [x] make sure model can be easily saved and loaded, with different ways of handling hrrr norm

-- [ ] make sure model can be easily saved and loaded, with different ways of handling hrrr norm
- [ ] figure out the topological embedding, consult a neural weather researcher

## Citations
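Editor's aside on the README change above, not part of this commit: a minimal sketch of the three values the new `hrrr_norm_strategy` keyword accepts. The other constructor arguments are elided with `...`, and `hrrr_stats` is a hypothetical (2, 617) tensor holding the per-channel mean and variance.

metnet3 = MetNet3(..., hrrr_norm_strategy = 'none')            # hrrr input and target used unnormalized
metnet3 = MetNet3(..., hrrr_norm_strategy = 'sync_batchnorm')  # running statistics tracked with a (sync) batchnorm during training
metnet3 = MetNet3(..., hrrr_norm_strategy = 'precalculated', hrrr_norm_statistics = hrrr_stats)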
88 changes: 78 additions & 10 deletions metnet3_pytorch/metnet3_pytorch.py
@@ -1,3 +1,4 @@
+from pathlib import Path
from contextlib import contextmanager
from functools import partial
from collections import namedtuple
@@ -13,7 +14,9 @@
from einops.layers.torch import Rearrange, Reduce

from beartype import beartype
-from beartype.typing import Tuple, Union, List, Optional, Dict
+from beartype.typing import Tuple, Union, List, Optional, Dict, Literal
+
+import pickle

# helpers

@@ -609,13 +612,27 @@ def __init__(
omo_wind_component_y = 256,
omo_wind_direction = 180
),
+hrrr_norm_strategy: Union[
+    Literal['none'],
+    Literal['precalculated'],
+    Literal['sync_batchnorm']
+] = 'none',
hrrr_channels = 617,
hrrr_norm_statistics: Optional[Tensor] = None,
hrrr_loss_weight = 10,
crop_size_post_16km = 48,
resnet_block_depth = 2,
):
super().__init__()

+# for autosaving the config
+
+_locals = locals()
+_locals.pop('self', None)
+_locals.pop('__class__', None)
+_locals.pop('hrrr_norm_statistics', None)
+self._configs = pickle.dumps(_locals)

self.hrrr_input_2496_shape = (hrrr_channels, input_spatial_size, input_spatial_size)
self.input_2496_shape = (input_2496_channels, input_spatial_size, input_spatial_size)
self.input_4996_shape = (input_4996_channels, input_spatial_size, input_spatial_size)
@@ -732,15 +749,60 @@ def __init__(

self.hrrr_loss_weight = hrrr_loss_weight / hrrr_channels

-self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)
+self.mse_loss_scaler = LossScaler()

+# norm statistics

-if self.has_hrrr_norm_statistics:
-    assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
+default_hrrr_statistics = torch.empty((2, hrrr_channels), dtype = torch.float32)

+if hrrr_norm_strategy == 'none':
+    self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)

+elif hrrr_norm_strategy == 'precalculated':
+    assert exists(hrrr_norm_statistics), 'hrrr_norm_statistics must be passed in, if normalizing input hrrr as well as target with precalculated dataset mean and variance'
+    assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {hrrr_channels}), containing mean and variance of each target calculated from the dataset'
    self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
-else:

+elif hrrr_norm_strategy == 'sync_batchnorm':
+    self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)
    self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_channels, affine = False)

-self.mse_loss_scaler = LossScaler()
+self.hrrr_norm_strategy = hrrr_norm_strategy
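# editor's note, not part of the diff: only the 'precalculated' branch registers
# its statistics persistently, so only those are written into a checkpoint's state
# dict; 'none' and 'sync_batchnorm' register the placeholder with persistent = False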

+@classmethod
+def init_and_load_from(cls, path, strict = True):
+    path = Path(path)
+    assert path.exists()
+    pkg = torch.load(str(path), map_location = 'cpu')
+
+    assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
+
+    config = pickle.loads(pkg['config'])
+    tokenizer = cls(**config)
+    tokenizer.load(path, strict = strict)
+    return tokenizer
+
+def save(self, path, overwrite = True):
+    path = Path(path)
+    assert overwrite or not path.exists(), f'{str(path)} already exists'
+
+    pkg = dict(
+        model_state_dict = self.state_dict(),
+        config = self._configs
+    )
+
+    torch.save(pkg, str(path))
+
+def load(self, path, strict = True):
+    path = Path(path)
+    assert path.exists()
+
+    pkg = torch.load(str(path))
+    state_dict = pkg.get('model_state_dict')
+
+    assert exists(state_dict)
+
+    self.load_state_dict(state_dict, strict = strict)
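# editor's sketch, not part of the diff: the save / load roundtrip these methods
# enable, assuming `metnet3` was constructed with a self-contained strategy such
# as hrrr_norm_strategy = 'sync_batchnorm' (the path is arbitrary)
#
#   metnet3.save('./metnet3.pt')
#   metnet3 = MetNet3.init_and_load_from('./metnet3.pt')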

@beartype
def forward(
@@ -763,9 +825,9 @@ def forward(
assert input_2496.shape[1:] == self.input_2496_shape
assert input_4996.shape[1:] == self.input_4996_shape

-# normalize HRRR input and target as needed
+# normalize HRRR input and target, if needed

-if self.has_hrrr_norm_statistics:
+if self.hrrr_norm_strategy == 'precalculated':
mean, variance = self.hrrr_norm_statistics
mean = rearrange(mean, 'c -> c 1 1')
variance = rearrange(variance, 'c -> c 1 1')
@@ -776,7 +838,7 @@ def forward(
if exists(hrrr_target):
normed_hrrr_target = (hrrr_target - mean) * inv_std

-else:
+elif self.hrrr_norm_strategy == 'sync_batchnorm':
# use a batchnorm to normalize each channel to mean zero and unit variance

with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
@@ -785,6 +847,12 @@ def forward(
if exists(hrrr_target):
normed_hrrr_target = frozen_batchnorm(hrrr_target)

+elif self.hrrr_norm_strategy == 'none':
+    normed_hrrr_input = hrrr_input_2496
+
+    if exists(hrrr_target):
+        normed_hrrr_target = hrrr_target

# main network

cond = self.lead_time_embedding(lead_times)
@@ -899,7 +967,7 @@ def forward(

# update hrrr normalization statistics, if using batchnorm way

-if not self.has_hrrr_norm_statistics and self.training:
+if self.training and self.hrrr_norm_strategy == 'sync_batchnorm':
_ = self.batchnorm_hrrr(hrrr_target)

# total loss
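Editor's aside: for the 'precalculated' strategy, the (2, 617) statistics tensor could be computed offline along these lines. A hedged sketch, not part of this commit; `hrrr_dataloader` is a hypothetical iterable yielding float tensors of shape (batch, 617, height, width).

import torch

def compute_hrrr_statistics(hrrr_dataloader, channels = 617):
    # accumulate per-channel sums of values and squared values over every pixel
    total = torch.zeros(channels, dtype = torch.float64)
    total_sq = torch.zeros(channels, dtype = torch.float64)
    count = 0

    for hrrr in hrrr_dataloader:
        total += hrrr.double().sum(dim = (0, 2, 3))
        total_sq += hrrr.double().pow(2).sum(dim = (0, 2, 3))
        count += hrrr.shape[0] * hrrr.shape[2] * hrrr.shape[3]

    mean = total / count
    variance = total_sq / count - mean ** 2  # E[x^2] - E[x]^2
    return torch.stack((mean, variance)).float()  # shape (2, 617), ready for hrrr_norm_statistics

The result would then be passed as `hrrr_norm_statistics` alongside `hrrr_norm_strategy = 'precalculated'`.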
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'metnet3-pytorch',
packages = find_packages(exclude=[]),
-version = '0.0.9',
+version = '0.0.11',
license='MIT',
description = 'MetNet 3 - Pytorch',
author = 'Phil Wang',