diff --git a/mmedit/core/hooks/__init__.py b/mmedit/core/hooks/__init__.py
index 9183c5015e..575c43b35c 100644
--- a/mmedit/core/hooks/__init__.py
+++ b/mmedit/core/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .ema import ExponentialMovingAverageHook
 from .visualization import VisualizationHook
 
-__all__ = ['VisualizationHook']
+__all__ = ['VisualizationHook', 'ExponentialMovingAverageHook']
diff --git a/mmedit/core/hooks/ema.py b/mmedit/core/hooks/ema.py
new file mode 100644
index 0000000000..0e7f0b2e84
--- /dev/null
+++ b/mmedit/core/hooks/ema.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from copy import deepcopy
+from functools import partial
+
+import mmcv
+import torch
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ExponentialMovingAverageHook(Hook):
+    """Exponential Moving Average Hook.
+
+    Exponential moving average is a trick widely used in the current GAN
+    literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to
+    maintain an extra model with the same architecture whose parameters are
+    updated as a moving average of the trained weights of the original
+    model. In general, the model with the moving averaged weights achieves
+    better performance.
+
+    Args:
+        module_keys (str | tuple[str]): The name(s) of the ema model. Note
+            that these keys must end with '_ema' so that the original model
+            can be found by discarding the last four characters.
+        interp_mode (str, optional): Mode of the interpolation method.
+            Defaults to 'lerp'.
+        interp_cfg (dict | None, optional): Set arguments of the interpolation
+            function. Defaults to None.
+        interval (int, optional): Update interval (by iterations) of the ema
+            model. A non-positive value disables the update. Default: -1.
+        start_iter (int, optional): Start iteration for ema. Before the start
+            iteration is reached, the weights of the ema model are kept the
+            same as those of the original model. Afterwards, its parameters
+            are updated as a moving average of the trained weights of the
+            original model. Default: 0.
+    """
+
+    def __init__(self,
+                 module_keys,
+                 interp_mode='lerp',
+                 interp_cfg=None,
+                 interval=-1,
+                 start_iter=0):
+        super().__init__()
+        assert isinstance(module_keys, str) or mmcv.is_tuple_of(
+            module_keys, str)
+        self.module_keys = (module_keys, ) if isinstance(module_keys,
+                                                         str) else module_keys
+        # sanity check for the format of module keys
+        for k in self.module_keys:
+            assert k.endswith(
+                '_ema'), 'You should give keys that end with "_ema".'
+        self.interp_mode = interp_mode
+        self.interp_cfg = dict() if interp_cfg is None else deepcopy(
+            interp_cfg)
+        self.interval = interval
+        self.start_iter = start_iter
+
+        assert hasattr(
+            self, interp_mode
+        ), f'Currently, we do not support {self.interp_mode} for EMA.'
+        self.interp_func = partial(
+            getattr(self, interp_mode), **self.interp_cfg)
+
+    @staticmethod
+    def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
+        m = momentum if trainable else momentum_nontrainable
+        return a + (b - a) * m
+
+    def every_n_iters(self, runner, n):
+        if runner.iter < self.start_iter:
+            return True
+        return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False
+
+    @torch.no_grad()
+    def after_train_iter(self, runner):
+        if not self.every_n_iters(runner, self.interval):
+            return
+
+        model = runner.model.module if is_module_wrapper(
+            runner.model) else runner.model
+
+        for key in self.module_keys:
+            # get the current ema states
+            ema_net = getattr(model, key)
+            states_ema = ema_net.state_dict(keep_vars=False)
+            # get the current states of the original model
+            net = getattr(model, key[:-4])
+            states_orig = net.state_dict(keep_vars=True)
+
+            for k, v in states_orig.items():
+                if runner.iter < self.start_iter:
+                    states_ema[k].data.copy_(v.data)
+                else:
+                    states_ema[k] = self.interp_func(
+                        v, states_ema[k], trainable=v.requires_grad).detach()
+            ema_net.load_state_dict(states_ema, strict=True)
+
+    def before_run(self, runner):
+        model = runner.model.module if is_module_wrapper(
+            runner.model) else runner.model
+        # sanity check for the ema model
+        for k in self.module_keys:
+            if not hasattr(model, k) and not hasattr(model, k[:-4]):
+                raise RuntimeError(
+                    f'Cannot find {k[:-4]} or {k} network for the EMA hook.')
+            if not hasattr(model, k) and hasattr(model, k[:-4]):
+                setattr(model, k, deepcopy(getattr(model, k[:-4])))
+                warnings.warn(
+                    f'We do not suggest constructing the EMA model {k} in the'
+                    ' hook. Please define it explicitly in the model instead.')
diff --git a/tests/test_runtime/test_ema_hook.py b/tests/test_runtime/test_ema_hook.py
new file mode 100644
index 0000000000..d7bb9e80ad
--- /dev/null
+++ b/tests/test_runtime/test_ema_hook.py
@@ -0,0 +1,236 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DataParallel
+
+from mmedit.core.hooks import ExponentialMovingAverageHook
+
+
+class SimpleModule(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.a = nn.Parameter(torch.tensor([1., 2.]))
+        if torch.__version__ >= '1.7.0':
+            self.register_buffer('b', torch.tensor([2., 3.]), persistent=True)
+            self.register_buffer('c', torch.tensor([0., 1.]), persistent=False)
+        else:
+            self.register_buffer('b', torch.tensor([2., 3.]))
+            self.c = torch.tensor([0., 1.])
+
+
+class SimpleModel(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.module_a = SimpleModule()
+        self.module_b = SimpleModule()
+
+        self.module_a_ema = SimpleModule()
+        self.module_b_ema = SimpleModule()
+
+
+class SimpleModelNoEMA(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.module_a = SimpleModule()
+        self.module_b = SimpleModule()
+
+
+class SimpleRunner:
+
+    def __init__(self):
+        self.model = SimpleModel()
+        self.iter = 0
+
+
+class TestEMA:
+
+    @classmethod
+    def setup_class(cls):
+        cls.default_config = dict(
+            module_keys=('module_a_ema', 'module_b_ema'),
+            interval=1,
+            interp_cfg=dict(momentum=0.5))
+        cls.runner = SimpleRunner()
+
+    @torch.no_grad()
+    def test_ema_hook(self):
+        cfg_ = deepcopy(self.default_config)
+        cfg_['interval'] = -1
+        ema = ExponentialMovingAverageHook(**cfg_)
+        ema.before_run(self.runner)
+        ema.after_train_iter(self.runner)
+
+        module_a = self.runner.model.module_a
+        module_a_ema = self.runner.model.module_a_ema
+
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))
+
+        ema = ExponentialMovingAverageHook(**self.default_config)
+        ema.after_train_iter(self.runner)
+
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))
+
+        module_a.b /= 2.
+        module_a.a.data /= 2.
+        module_a.c /= 2.
+
+        self.runner.iter += 1
+        ema.after_train_iter(self.runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(self.runner.model.module_a.a,
+                           torch.tensor([0.5, 1.]))
+        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]))
+        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]))
+        assert 'c' not in ema_states
+
+        # check for the validity of args
+        with pytest.raises(AssertionError):
+            _ = ExponentialMovingAverageHook(module_keys=['a'])
+
+        with pytest.raises(AssertionError):
+            _ = ExponentialMovingAverageHook(module_keys=('a'))
+
+        with pytest.raises(AssertionError):
+            _ = ExponentialMovingAverageHook(
+                module_keys=('module_a_ema'), interp_mode='xxx')
+
+        # test before run
+        ema = ExponentialMovingAverageHook(**self.default_config)
+        self.runner.model = SimpleModelNoEMA()
+        self.runner.iter = 0
+        ema.before_run(self.runner)
+        assert hasattr(self.runner.model, 'module_a_ema')
+
+        module_a = self.runner.model.module_a
+        module_a_ema = self.runner.model.module_a_ema
+
+        ema.after_train_iter(self.runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))
+
+        module_a.b /= 2.
+        module_a.a.data /= 2.
+        module_a.c /= 2.
+
+        self.runner.iter += 1
+        ema.after_train_iter(self.runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(self.runner.model.module_a.a,
+                           torch.tensor([0.5, 1.]))
+        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]))
+        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]))
+        assert 'c' not in ema_states
+
+        # test ema with simple warm up
+        runner = SimpleRunner()
+        cfg_ = deepcopy(self.default_config)
+        cfg_.update(dict(start_iter=3, interval=1))
+        ema = ExponentialMovingAverageHook(**cfg_)
+        ema.before_run(runner)
+
+        module_a = runner.model.module_a
+        module_a_ema = runner.model.module_a_ema
+
+        module_a.a.data /= 2.
+
+        runner.iter += 1
+        ema.after_train_iter(runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
+        assert torch.equal(ema_states['a'], torch.tensor([0.5, 1.]))
+
+        module_a.a.data /= 2
+        runner.iter += 2
+        ema.after_train_iter(runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(runner.model.module_a.a, torch.tensor([0.25, 0.5]))
+        assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]))
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    def test_ema_hook_cuda(self):
+        ema = ExponentialMovingAverageHook(**self.default_config)
+        cuda_runner = SimpleRunner()
+        cuda_runner.model = cuda_runner.model.cuda()
+        ema.after_train_iter(cuda_runner)
+
+        module_a = cuda_runner.model.module_a
+        module_a_ema = cuda_runner.model.module_a_ema
+
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())
+
+        module_a.b /= 2.
+        module_a.a.data /= 2.
+        module_a.c /= 2.
+
+        cuda_runner.iter += 1
+        ema.after_train_iter(cuda_runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(cuda_runner.model.module_a.a,
+                           torch.tensor([0.5, 1.]).cuda())
+        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
+        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
+        assert 'c' not in ema_states
+
+        # test before run
+        ema = ExponentialMovingAverageHook(**self.default_config)
+        self.runner.model = SimpleModelNoEMA().cuda()
+        self.runner.model = DataParallel(self.runner.model)
+        self.runner.iter = 0
+        ema.before_run(self.runner)
+        assert hasattr(self.runner.model.module, 'module_a_ema')
+
+        module_a = self.runner.model.module.module_a
+        module_a_ema = self.runner.model.module.module_a_ema
+
+        ema.after_train_iter(self.runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())
+
+        module_a.b /= 2.
+        module_a.a.data /= 2.
+        module_a.c /= 2.
+
+        self.runner.iter += 1
+        ema.after_train_iter(self.runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(self.runner.model.module.module_a.a,
+                           torch.tensor([0.5, 1.]).cuda())
+        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
+        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
+        assert 'c' not in ema_states
+
+        # test ema with simple warm up
+        runner = SimpleRunner()
+        runner.model = runner.model.cuda()
+        cfg_ = deepcopy(self.default_config)
+        cfg_.update(dict(start_iter=3, interval=1))
+        ema = ExponentialMovingAverageHook(**cfg_)
+        ema.before_run(runner)
+
+        module_a = runner.model.module_a
+        module_a_ema = runner.model.module_a_ema
+
+        module_a.a.data /= 2.
+
+        runner.iter += 1
+        ema.after_train_iter(runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(runner.model.module_a.a,
+                           torch.tensor([0.5, 1.]).cuda())
+        assert torch.equal(ema_states['a'], torch.tensor([0.5, 1.]).cuda())
+
+        module_a.a.data /= 2
+        runner.iter += 2
+        ema.after_train_iter(runner)
+        ema_states = module_a_ema.state_dict()
+        assert torch.equal(runner.model.module_a.a,
+                           torch.tensor([0.25, 0.5]).cuda())
+        assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]).cuda())
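
Note on the expected values asserted in the tests: with `interp_func(v, states_ema[k])` the hook's `lerp` rule computes `new_ema = orig + (ema - orig) * momentum`. A standalone sketch reproducing the `momentum=0.5` case from `default_config` (plain PyTorch, no runner involved; the variable names are only illustrative):

    import torch

    orig = torch.tensor([0.5, 1.0])  # module_a.a after being halved in the test
    ema = torch.tensor([1.0, 2.0])   # previous EMA weights
    momentum = 0.5                   # as set via interp_cfg in default_config

    # lerp(a=orig, b=ema, momentum=0.5), as implemented by the hook
    new_ema = orig + (ema - orig) * momentum
    print(new_ema)  # tensor([0.7500, 1.5000]), matching the test assertion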
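Usage note: since the hook is registered in the `HOOKS` registry, it can be enabled from a config. The snippet below is only a sketch assuming the standard mmcv `custom_hooks` mechanism and a model that exposes both `generator` and `generator_ema` submodules; the module name, momentum, and priority values are assumptions for illustration and are not part of this PR:

    # Hypothetical config snippet: enable the EMA hook via mmcv's custom_hooks.
    # 'generator_ema' is an assumed attribute on the model; the hook finds the
    # source module ('generator') by stripping the trailing '_ema'.
    custom_hooks = [
        dict(
            type='ExponentialMovingAverageHook',
            module_keys=('generator_ema', ),
            interval=1,                       # update EMA weights every iteration
            start_iter=0,                     # plain weight copy before this iteration
            interp_cfg=dict(momentum=0.999),  # slow-moving average
            priority='VERY_HIGH')
    ]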