offer another way to normalize hrrr targets and predictions
lucidrains committed Nov 8, 2023
1 parent 2bfcd1f commit 9ab7137
Showing 3 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -74,8 +74,8 @@ surface_target, hrrr_target, precipitation_target = metnet3(
 
 - [x] figure out all the cross entropy and MSE losses
 - [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
+- [x] allow researcher to pass in their own normalization variables for HRRR
 
-- [ ] allow researcher to pass in their own normalization variables for HRRR
 - [ ] figure out the topological embedding, consult a neural weather researcher
 
 ## Citations
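For context, a minimal usage sketch of the newly checked-off item, assuming the constructor's other arguments keep their defaults; the statistics tensor here is a hypothetical placeholder that a researcher would compute from their own dataset:

```python
import torch
from metnet3_pytorch import MetNet3

# hypothetical per-channel statistics: row 0 holds means, row 1 holds variances
my_hrrr_statistics = torch.stack((torch.zeros(617), torch.ones(617)))

metnet3 = MetNet3(
    hrrr_target_channels = 617,
    hrrr_norm_statistics = my_hrrr_statistics  # bypasses the sync batchnorm running-stats hack
)
```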
30 changes: 23 additions & 7 deletions metnet3_pytorch/metnet3_pytorch.py
@@ -599,6 +599,7 @@ def __init__(
         surface_and_hrrr_target_spatial_size = 128,
         surface_target_channels = 6,
         hrrr_target_channels = 617,
+        hrrr_norm_statistics: Optional[Tensor] = None,
         precipitation_target_channels = 2,
         crop_size_post_16km = 48,
         resnet_block_depth = 2,
@@ -702,7 +703,13 @@ def __init__(
             nn.Conv2d(dim, precipitation_target_channels, 1)
         )
 
-        self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)
+        self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)
+
+        if self.has_hrrr_norm_statistics:
+            assert hrrr_norm_statistics.shape == (2, hrrr_target_channels), f'normalization statistics must be of shape (2, {hrrr_target_channels}), containing the mean and variance of each target channel calculated from the dataset'
+            self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
+        else:
+            self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)
 
         self.mse_loss_scaler = LossScaler()
 
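As a reference point, here is one way the `(2, hrrr_target_channels)` tensor demanded by the assert above might be precomputed. This is a sketch, not part of the commit, and the `hrrr_targets` tensor is a stand-in for a real dataset:

```python
import torch

# stand-in for a dataset of HRRR targets, shape (num_samples, channels, height, width)
hrrr_targets = torch.randn(16, 617, 32, 32)

# per-channel mean and variance, reduced over the batch and spatial dimensions
mean = hrrr_targets.mean(dim = (0, 2, 3))
variance = hrrr_targets.var(dim = (0, 2, 3))

# stacked into the (2, channels) shape the assert checks for
hrrr_norm_statistics = torch.stack((mean, variance))
assert hrrr_norm_statistics.shape == (2, 617)
```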
@@ -797,14 +804,23 @@ def forward(
 
         # calculate HRRR mse loss
 
-        # use a batchnorm to normalize each channel to mean zero and unit variance
+        if self.has_hrrr_norm_statistics:
+            mean, variance = self.hrrr_norm_statistics
+            mean = rearrange(mean, 'c -> c 1 1')
+            variance = rearrange(variance, 'c -> c 1 1')
+            inv_std = variance.clamp(min = 1e-5).rsqrt()
 
+            normed_hrrr_target = (hrrr_target - mean) * inv_std
+            normed_hrrr_pred = (hrrr_pred - mean) * inv_std
+        else:
+            # use a batchnorm to normalize each channel to mean zero and unit variance
+
-        if self.training:
-            _ = self.batchnorm_hrrr(hrrr_target)
+            if self.training:
+                _ = self.batchnorm_hrrr(hrrr_target)
 
-        with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
-            normed_hrrr_pred = frozen_batchnorm(hrrr_pred)
-            normed_hrrr_target = frozen_batchnorm(hrrr_target)
+            with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
+                normed_hrrr_pred = frozen_batchnorm(hrrr_pred)
+                normed_hrrr_target = frozen_batchnorm(hrrr_target)
 
         # proposed loss gradient rescaler from section 4.3.2
 
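Distilled from the new branch above, a standalone sketch of the statistics-based path; the function name is illustrative, not from the repository:

```python
import torch
from einops import rearrange

def normalize_with_statistics(t: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    # stats is the registered buffer: row 0 per-channel means, row 1 per-channel variances
    mean, variance = stats
    mean = rearrange(mean, 'c -> c 1 1')
    variance = rearrange(variance, 'c -> c 1 1')

    # clamp the variance away from zero before the reciprocal square root,
    # mirroring the 1e-5 clamp in the forward pass
    inv_std = variance.clamp(min = 1e-5).rsqrt()

    # broadcasting (c 1 1) against (b c h w) normalizes every channel
    return (t - mean) * inv_std

# the same statistics are applied to both target and prediction,
# so the MSE loss is computed in a consistent normalized space
```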
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'metnet3-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'MetNet 3 - Pytorch',
   author = 'Phil Wang',
