Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support model ema with EMAHook #201

Merged
merged 4 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions configs/common/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
# after every `checkpointer.period` iterations,
# and only `checkpointer.max_to_keep` number of checkpoint will be kept.
checkpointer=dict(period=5000, max_to_keep=100),
# Run evaluation after every `eval_period` number of iterations
# run evaluation after every `eval_period` number of iterations
eval_period=5000,
# Output log to console every `log_period` number of iterations.
# output log to console every `log_period` number of iterations.
log_period=20,
# wandb logging params
# logging training info to Wandb
# note that you should add wandb writer in `train_net.py``
wandb=dict(
enabled=False,
Expand All @@ -43,6 +43,13 @@
name="detrex_experiment",
)
),
# model ema
model_ema=dict(
enabled=False,
decay=0.999,
device="",
use_ema_weights_for_eval_only=False,
),
# the training device, choose from {"cuda", "cpu"}
device="cuda",
# ...
Expand Down
266 changes: 266 additions & 0 deletions detrex/modeling/ema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Model EMA
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/d2go/blob/main/d2go/modeling/ema.py
# ------------------------------------------------------------------------------------------------

import copy
import itertools
import logging
from contextlib import contextmanager
from typing import List

import torch
from detectron2.engine.train_loop import HookBase


logger = logging.getLogger(__name__)


class EMAState(object):
def __init__(self):
self.state = {}

@classmethod
def FromModel(cls, model: torch.nn.Module, device: str = ""):
ret = cls()
ret.save_from(model, device)
return ret

def save_from(self, model: torch.nn.Module, device: str = ""):
"""Save model state from `model` to this object"""
for name, val in self.get_model_state_iterator(model):
val = val.detach().clone()
self.state[name] = val.to(device) if device else val

def apply_to(self, model: torch.nn.Module):
"""Apply state to `model` from this object"""
with torch.no_grad():
for name, val in self.get_model_state_iterator(model):
assert (
name in self.state
), f"Name {name} not existed, available names {self.state.keys()}"
val.copy_(self.state[name])

@contextmanager
def apply_and_restore(self, model):
old_state = EMAState.FromModel(model, self.device)
self.apply_to(model)
yield old_state
old_state.apply_to(model)

def get_ema_model(self, model):
ret = copy.deepcopy(model)
self.apply_to(ret)
return ret

@property
def device(self):
if not self.has_inited():
return None
return next(iter(self.state.values())).device

def to(self, device):
for name in self.state:
self.state[name] = self.state[name].to(device)
return self

def has_inited(self):
return self.state

def clear(self):
self.state.clear()
return self

def get_model_state_iterator(self, model):
param_iter = model.named_parameters()
buffer_iter = model.named_buffers()
return itertools.chain(param_iter, buffer_iter)

def state_dict(self):
return self.state

def load_state_dict(self, state_dict, strict: bool = True):
self.clear()
for x, y in state_dict.items():
self.state[x] = y
return torch.nn.modules.module._IncompatibleKeys(
missing_keys=[], unexpected_keys=[]
)

def __repr__(self):
ret = f"EMAState(state=[{','.join(self.state.keys())}])"
return ret


class EMAUpdater(object):
"""Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and
buffers). This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
Note: It's very important to set EMA for ALL network parameters (instead of
parameters that require gradient), including batch-norm moving average mean
and variance. This leads to significant improvement in accuracy.
For example, for EfficientNetB3, with default setting (no mixup, lr exponential
decay) without bn_sync, the EMA accuracy with EMA on params that requires
gradient is 79.87%, while the corresponding accuracy with EMA on all params
is 80.61%.
Also, bn sync should be switched on for EMA.
"""

def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
self.decay = decay
self.device = device

self.state = state

def init_state(self, model):
self.state.clear()
self.state.save_from(model, self.device)

def update(self, model):
with torch.no_grad():
ema_param_list = []
param_list = []
for name, val in self.state.get_model_state_iterator(model):
ema_val = self.state.state[name]
if self.device:
val = val.to(self.device)
if val.dtype in [torch.float32, torch.float16]:
ema_param_list.append(ema_val)
param_list.append(val)
else:
ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
self._ema_avg(ema_param_list, param_list, self.decay)

def _ema_avg(
self,
averaged_model_parameters: List[torch.Tensor],
model_parameters: List[torch.Tensor],
decay: float,
) -> None:
"""
Function to perform exponential moving average:
x_avg = alpha * x_avg + (1-alpha)* x_t
"""
torch._foreach_mul_(averaged_model_parameters, decay)
torch._foreach_add_(
averaged_model_parameters, model_parameters, alpha=1 - decay
)


def _remove_ddp(model):
from torch.nn.parallel import DistributedDataParallel

if isinstance(model, DistributedDataParallel):
return model.module
return model


def may_build_model_ema(cfg, model):
if not cfg.train.model_ema.enabled:
return
model = _remove_ddp(model)
assert not hasattr(
model, "ema_state"
), "Name `ema_state` is reserved for model ema."
model.ema_state = EMAState()
logger.info("Using Model EMA.")


def may_get_ema_checkpointer(cfg, model):
if not cfg.train.model_ema.enabled:
return {}
model = _remove_ddp(model)
return {"ema_state": model.ema_state}


def get_model_ema_state(model):
"""Return the ema state stored in `model`"""
model = _remove_ddp(model)
assert hasattr(model, "ema_state")
ema = model.ema_state
return ema


def apply_model_ema(model, state=None, save_current=False):
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)

if state is None:
state = get_model_ema_state(model)

if save_current:
# save current model state
old_state = EMAState.FromModel(model, state.device)
state.apply_to(model)

if save_current:
return old_state
return None


@contextmanager
def apply_model_ema_and_restore(model, state=None):
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)

if state is None:
state = get_model_ema_state(model)

old_state = EMAState.FromModel(model, state.device)
state.apply_to(model)
yield old_state
old_state.apply_to(model)


class EMAHook(HookBase):
def __init__(self, cfg, model):
model = _remove_ddp(model)
assert cfg.train.model_ema.enabled
assert hasattr(
model, "ema_state"
), "Call `may_build_model_ema` first to initilaize the model ema"
self.model = model
self.ema = self.model.ema_state
self.device = cfg.train.model_ema.device or cfg.model.device
self.ema_updater = EMAUpdater(
self.model.ema_state, decay=cfg.train.model_ema.decay, device=self.device
)

def before_train(self):
if self.ema.has_inited():
self.ema.to(self.device)
else:
self.ema_updater.init_state(self.model)

def after_train(self):
pass

def before_step(self):
pass

def after_step(self):
if not self.model.train:
return
self.ema_updater.update(self.model)
Loading