From 2bfcd1f7b8f8bc92826486c96457742b85ce6c7d Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 7 Nov 2023 19:16:06 -0800 Subject: [PATCH] running mean and variance should only be updated during training --- metnet3_pytorch/metnet3_pytorch.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metnet3_pytorch/metnet3_pytorch.py b/metnet3_pytorch/metnet3_pytorch.py index 7ce9548..b69e10e 100644 --- a/metnet3_pytorch/metnet3_pytorch.py +++ b/metnet3_pytorch/metnet3_pytorch.py @@ -799,7 +799,8 @@ def forward( # use a batchnorm to normalize each channel to mean zero and unit variance - _ = self.batchnorm_hrrr(hrrr_target) + if self.training: + _ = self.batchnorm_hrrr(hrrr_target) with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm: normed_hrrr_pred = frozen_batchnorm(hrrr_pred) diff --git a/setup.py b/setup.py index bc788f2..1fa17f0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'metnet3-pytorch', packages = find_packages(exclude=[]), - version = '0.0.5', + version = '0.0.6', license='MIT', description = 'MetNet 3 - Pytorch', author = 'Phil Wang',