[Feature] Exponential Moving Average Hook #542

Merged (2 commits) on Sep 23, 2021
Changes from 1 commit
add unittest
ckkelvinchan committed Sep 20, 2021
commit df46f0a4bf309d02b75a9150635cdc0cb03dbd8a
236 changes: 236 additions & 0 deletions tests/test_runtime/test_ema_hook.py
from copy import deepcopy

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmedit.core.hooks import ExponentialMovingAverageHook


class SimpleModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor([1., 2.]))
        # compare versions numerically; a plain string comparison would
        # misorder versions such as '1.10.0' < '1.7.0'
        torch_version = tuple(
            int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
        if torch_version >= (1, 7):
            # 'b' is persistent and appears in state_dict(); 'c' is
            # non-persistent and is deliberately excluded from it
            self.register_buffer('b', torch.tensor([2., 3.]), persistent=True)
            self.register_buffer('c', torch.tensor([0., 1.]), persistent=False)
        else:
            self.register_buffer('b', torch.tensor([2., 3.]))
            self.c = torch.tensor([0., 1.])


class SimpleModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.module_a = SimpleModule()
        self.module_b = SimpleModule()

        self.module_a_ema = SimpleModule()
        self.module_b_ema = SimpleModule()


class SimpleModelNoEMA(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.module_a = SimpleModule()
        self.module_b = SimpleModule()


class SimpleRunner:

    def __init__(self):
        self.model = SimpleModel()
        self.iter = 0


class TestEMA:

    @classmethod
    def setup_class(cls):
        cls.default_config = dict(
            module_keys=('module_a_ema', 'module_b_ema'),
            interval=1,
            interp_cfg=dict(momentum=0.5))
        cls.runner = SimpleRunner()

    @torch.no_grad()
    def test_ema_hook(self):
        cfg_ = deepcopy(self.default_config)
        cfg_['interval'] = -1
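        # a non-positive interval is expected to disable the hook, so the
        # EMA weights should stay untouched by after_train_iter below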
        ema = ExponentialMovingAverageHook(**cfg_)
        ema.before_run(self.runner)
        ema.after_train_iter(self.runner)

        module_a = self.runner.model.module_a
        module_a_ema = self.runner.model.module_a_ema

        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))

        ema = ExponentialMovingAverageHook(**self.default_config)
        ema.after_train_iter(self.runner)

        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))

        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        self.runner.iter += 1
        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
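        # linear interpolation with momentum=0.5:
        # new_ema = 0.5 * old_ema + 0.5 * current,
        # e.g. 0.5 * 1.0 + 0.5 * 0.5 = 0.75; the non-persistent
        # buffer 'c' never shows up in state_dict()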
        assert torch.equal(self.runner.model.module_a.a,
                           torch.tensor([0.5, 1.]))
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]))
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in ema_states

        # check the validity of the args
        with pytest.raises(AssertionError):
            # invalid module_keys type (list)
            _ = ExponentialMovingAverageHook(module_keys=['a'])

        with pytest.raises(AssertionError):
            # note: ('a') is a plain str, not a tuple, and 'a' does not
            # name an EMA module
            _ = ExponentialMovingAverageHook(module_keys=('a'))

        with pytest.raises(AssertionError):
            # invalid interp_mode
            _ = ExponentialMovingAverageHook(
                module_keys=('module_a_ema'), interp_mode='xxx')

        # test before_run: it should create the missing *_ema modules by
        # copying the corresponding source modules
        ema = ExponentialMovingAverageHook(**self.default_config)
        self.runner.model = SimpleModelNoEMA()
        self.runner.iter = 0
        ema.before_run(self.runner)
        assert hasattr(self.runner.model, 'module_a_ema')

        module_a = self.runner.model.module_a
        module_a_ema = self.runner.model.module_a_ema

        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]))

        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        self.runner.iter += 1
        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(self.runner.model.module_a.a,
                           torch.tensor([0.5, 1.]))
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]))
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]))
        assert 'c' not in ema_states

        # test ema with simple warm up
        runner = SimpleRunner()
        cfg_ = deepcopy(self.default_config)
        cfg_.update(dict(start_iter=3, interval=1))
        ema = ExponentialMovingAverageHook(**cfg_)
        ema.before_run(runner)

        module_a = runner.model.module_a
        module_a_ema = runner.model.module_a_ema

        module_a.a.data /= 2.
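
        # iter 1 is still below start_iter=3, so the hook is expected to
        # copy the source parameters into the EMA module rather than
        # interpolate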

        runner.iter += 1
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.5, 1.]))
        assert torch.equal(ema_states['a'], torch.tensor([0.5, 1.]))

        module_a.a.data /= 2
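        # at iter 3 (start_iter reached) interpolation applies:
        # 0.5 * 0.5 + 0.5 * 0.25 = 0.375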
        runner.iter += 2
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a, torch.tensor([0.25, 0.5]))
        assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]))

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_ema_hook_cuda(self):
        ema = ExponentialMovingAverageHook(**self.default_config)
        cuda_runner = SimpleRunner()
        cuda_runner.model = cuda_runner.model.cuda()
        ema.after_train_iter(cuda_runner)

        module_a = cuda_runner.model.module_a
        module_a_ema = cuda_runner.model.module_a_ema

        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())

        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        cuda_runner.iter += 1
        ema.after_train_iter(cuda_runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(cuda_runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
        assert 'c' not in ema_states

        # test before run
        ema = ExponentialMovingAverageHook(**self.default_config)
        self.runner.model = SimpleModelNoEMA().cuda()
        self.runner.model = DataParallel(self.runner.model)
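        # with DataParallel the hook should reach through the wrapper and
        # operate on model.module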
        self.runner.iter = 0
        ema.before_run(self.runner)
        assert hasattr(self.runner.model.module, 'module_a_ema')

        module_a = self.runner.model.module.module_a
        module_a_ema = self.runner.model.module.module_a_ema

        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())

        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        self.runner.iter += 1
        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(self.runner.model.module.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
        assert 'c' not in ema_states

        # test ema with simple warm up
        runner = SimpleRunner()
        runner.model = runner.model.cuda()
        cfg_ = deepcopy(self.default_config)
        cfg_.update(dict(start_iter=3, interval=1))
        ema = ExponentialMovingAverageHook(**cfg_)
        ema.before_run(runner)

        module_a = runner.model.module_a
        module_a_ema = runner.model.module_a_ema

        module_a.a.data /= 2.

        runner.iter += 1
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.5, 1.]).cuda())

        module_a.a.data /= 2
        runner.iter += 2
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.25, 0.5]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]).cuda())
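
For readers checking the arithmetic, the assertions above all encode one step of linear interpolation between the running average and the current weights. Below is a minimal, illustrative sketch of that update rule; the lerp_ema helper is hypothetical and not the actual ExponentialMovingAverageHook implementation:

import torch
import torch.nn as nn


@torch.no_grad()
def lerp_ema(src: nn.Module, ema: nn.Module, momentum: float = 0.5):
    """Illustrative EMA step matching the numbers asserted above:

        new_ema = momentum * old_ema + (1 - momentum) * current
    """
    src_states = src.state_dict()
    for name, ema_tensor in ema.state_dict().items():
        # state_dict() returns references, so copy_ updates the module
        ema_tensor.copy_(momentum * ema_tensor +
                         (1 - momentum) * src_states[name])

With momentum=0.5 and a source parameter halved from [1., 2.] to [0.5, 1.], this yields [0.75, 1.5], the value the tests expect for ema_states['a'].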