From b643544dc0bc88c26a0009ff77bd222b0de3d423 Mon Sep 17 00:00:00 2001
From: rentainhe <596106517@qq.com>
Date: Mon, 6 Feb 2023 10:45:24 +0800
Subject: [PATCH 1/4] add ema hook

---
 configs/common/train.py |  13 +-
 detrex/modeling/ema.py  | 266 ++++++++++++++++++++++++++++++++++++++++
 tools/train_net.py      |  51 ++++++--
 3 files changed, 317 insertions(+), 13 deletions(-)
 create mode 100644 detrex/modeling/ema.py

diff --git a/configs/common/train.py b/configs/common/train.py
index 2d9e854d..bd0f8dbb 100644
--- a/configs/common/train.py
+++ b/configs/common/train.py
@@ -29,11 +29,11 @@
     # after every `checkpointer.period` iterations,
     # and only `checkpointer.max_to_keep` number of checkpoint will be kept.
     checkpointer=dict(period=5000, max_to_keep=100),
-    # Run evaluation after every `eval_period` number of iterations
+    # run evaluation after every `eval_period` number of iterations
     eval_period=5000,
-    # Output log to console every `log_period` number of iterations.
+    # output log to console every `log_period` number of iterations.
     log_period=20,
-    # wandb logging params
+    # logging training info to Wandb
     # note that you should add wandb writer in `train_net.py``
     wandb=dict(
         enabled=False,
@@ -43,6 +43,13 @@
             name="detrex_experiment",
         )
     ),
+    # model ema
+    model_ema=dict(
+        enabled=False,
+        decay=0.999,
+        device="",
+        use_ema_weights_for_eval_only=True,
+    ),
     # the training device, choose from {"cuda", "cpu"}
     device="cuda",
     # ...
diff --git a/detrex/modeling/ema.py b/detrex/modeling/ema.py
new file mode 100644
index 00000000..b538acec
--- /dev/null
+++ b/detrex/modeling/ema.py
@@ -0,0 +1,266 @@
+# coding=utf-8
+# Copyright 2022 The IDEA Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------------------------
+# Model EMA
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------------------------------
+# Modified from:
+# https://github.com/facebookresearch/d2go/blob/main/d2go/modeling/ema.py
+# ------------------------------------------------------------------------------------------------
+
+import copy
+import itertools
+import logging
+from contextlib import contextmanager
+from typing import List
+
+import torch
+from detectron2.engine.train_loop import HookBase
+
+
+logger = logging.getLogger(__name__)
+
+
+class EMAState(object):
+    def __init__(self):
+        self.state = {}
+
+    @classmethod
+    def FromModel(cls, model: torch.nn.Module, device: str = ""):
+        ret = cls()
+        ret.save_from(model, device)
+        return ret
+
+    def save_from(self, model: torch.nn.Module, device: str = ""):
+        """Save model state from `model` to this object"""
+        for name, val in self.get_model_state_iterator(model):
+            val = val.detach().clone()
+            self.state[name] = val.to(device) if device else val
+
+    def apply_to(self, model: torch.nn.Module):
+        """Apply state to `model` from this object"""
+        with torch.no_grad():
+            for name, val in self.get_model_state_iterator(model):
+                assert (
+                    name in self.state
+                ), f"Name {name} does not exist, available names {self.state.keys()}"
+                val.copy_(self.state[name])
+
+    @contextmanager
+    def apply_and_restore(self, model):
+        old_state = EMAState.FromModel(model, self.device)
+        self.apply_to(model)
+        yield old_state
+        old_state.apply_to(model)
+
+    def get_ema_model(self, model):
+        ret = copy.deepcopy(model)
+        self.apply_to(ret)
+        return ret
+
+    @property
+    def device(self):
+        if not self.has_inited():
+            return None
+        return next(iter(self.state.values())).device
+
+    def to(self, device):
+        for name in self.state:
+            self.state[name] = self.state[name].to(device)
+        return self
+
+    def has_inited(self):
+        return self.state
+
+    def clear(self):
+        self.state.clear()
+        return self
+
+    def get_model_state_iterator(self, model):
+        param_iter = model.named_parameters()
+        buffer_iter = model.named_buffers()
+        return itertools.chain(param_iter, buffer_iter)
+
+    def state_dict(self):
+        return self.state
+
+    def load_state_dict(self, state_dict, strict: bool = True):
+        self.clear()
+        for x, y in state_dict.items():
+            self.state[x] = y
+        return torch.nn.modules.module._IncompatibleKeys(
+            missing_keys=[], unexpected_keys=[]
+        )
+
+    def __repr__(self):
+        ret = f"EMAState(state=[{','.join(self.state.keys())}])"
+        return ret
+
+
+class EMAUpdater(object):
+    """Model Exponential Moving Average
+    Keep a moving average of everything in the model state_dict (parameters and
+    buffers). This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    Note: It's very important to set EMA for ALL network parameters (instead of
+    parameters that require gradient), including batch-norm moving average mean
+    and variance. This leads to significant improvement in accuracy.
+    For example, for EfficientNetB3, with default setting (no mixup, lr exponential
+    decay) without bn_sync, the EMA accuracy with EMA on params that require
+    gradient is 79.87%, while the corresponding accuracy with EMA on all params
+    is 80.61%.
+    Also, bn sync should be switched on for EMA.
+    """
+
+    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
+        self.decay = decay
+        self.device = device
+
+        self.state = state
+
+    def init_state(self, model):
+        self.state.clear()
+        self.state.save_from(model, self.device)
+
+    def update(self, model):
+        with torch.no_grad():
+            ema_param_list = []
+            param_list = []
+            for name, val in self.state.get_model_state_iterator(model):
+                ema_val = self.state.state[name]
+                if self.device:
+                    val = val.to(self.device)
+                if val.dtype in [torch.float32, torch.float16]:
+                    ema_param_list.append(ema_val)
+                    param_list.append(val)
+                else:
+                    ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
+            self._ema_avg(ema_param_list, param_list, self.decay)
+
+    def _ema_avg(
+        self,
+        averaged_model_parameters: List[torch.Tensor],
+        model_parameters: List[torch.Tensor],
+        decay: float,
+    ) -> None:
+        """
+        Function to perform exponential moving average:
+        x_avg = alpha * x_avg + (1-alpha)* x_t
+        """
+        torch._foreach_mul_(averaged_model_parameters, decay)
+        torch._foreach_add_(
+            averaged_model_parameters, model_parameters, alpha=1 - decay
+        )
+
+
+def _remove_ddp(model):
+    from torch.nn.parallel import DistributedDataParallel
+
+    if isinstance(model, DistributedDataParallel):
+        return model.module
+    return model
+
+
+def may_build_model_ema(cfg, model):
+    if not cfg.train.model_ema.enabled:
+        return
+    model = _remove_ddp(model)
+    assert not hasattr(
+        model, "ema_state"
+    ), "Name `ema_state` is reserved for model ema."
+    model.ema_state = EMAState()
+    logger.info("Using Model EMA.")
+
+
+def may_get_ema_checkpointer(cfg, model):
+    if not cfg.train.model_ema.enabled:
+        return {}
+    model = _remove_ddp(model)
+    return {"ema_state": model.ema_state}
+
+
+def get_model_ema_state(model):
+    """Return the ema state stored in `model`"""
+    model = _remove_ddp(model)
+    assert hasattr(model, "ema_state")
+    ema = model.ema_state
+    return ema
+
+
+def apply_model_ema(model, state=None, save_current=False):
+    """Apply the ema state stored in `model` to the model and, if `save_current`
+    is True, return the previous state so that the weights can be restored.
+    """
+    model = _remove_ddp(model)
+
+    if state is None:
+        state = get_model_ema_state(model)
+
+    if save_current:
+        # save current model state
+        old_state = EMAState.FromModel(model, state.device)
+    state.apply_to(model)
+
+    if save_current:
+        return old_state
+    return None
+
+
+@contextmanager
+def apply_model_ema_and_restore(model, state=None):
+    """Apply the ema state stored in `model` to the model, then restore the
+    original weights when the context exits.
+    """
+    model = _remove_ddp(model)
+
+    if state is None:
+        state = get_model_ema_state(model)
+
+    old_state = EMAState.FromModel(model, state.device)
+    state.apply_to(model)
+    yield old_state
+    old_state.apply_to(model)
+
+
+class EMAHook(HookBase):
+    def __init__(self, cfg, model):
+        model = _remove_ddp(model)
+        assert cfg.train.model_ema.enabled
+        assert hasattr(
+            model, "ema_state"
+        ), "Call `may_build_model_ema` first to initialize the model ema"
+        self.model = model
+        self.ema = self.model.ema_state
+        self.device = cfg.train.model_ema.device or cfg.model.device
+        self.ema_updater = EMAUpdater(
+            self.model.ema_state, decay=cfg.train.model_ema.decay, device=self.device
+        )
+
+    def before_train(self):
+        if self.ema.has_inited():
+            self.ema.to(self.device)
+        else:
+            self.ema_updater.init_state(self.model)
+
+    def after_train(self):
+        pass
+
+    def before_step(self):
+        pass
+
+    def after_step(self):
+        if not self.model.training:
+            return
+        self.ema_updater.update(self.model)
\ No newline at end of file
diff --git a/tools/train_net.py b/tools/train_net.py
index 9e4ffccd..8404960f 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -39,6 +39,7 @@
 )

 from detrex.utils import WandbWriter
+from detrex.modeling import ema

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

@@ -145,12 +146,23 @@ def load_state_dict(self, state_dict):


 def do_test(cfg, model):
-    if "evaluator" in cfg.dataloader:
-        ret = inference_on_dataset(
-            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
-        )
-        print_csv_format(ret)
-        return ret
+    logger = logging.getLogger("detectron2")
+    if cfg.train.model_ema.enabled:
+        logger.info("Run evaluation with EMA.")
+        with ema.apply_model_ema_and_restore(model):
+            if "evaluator" in cfg.dataloader:
+                ret = inference_on_dataset(
+                    model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+                )
+                print_csv_format(ret)
+                return ret
+    else:
+        if "evaluator" in cfg.dataloader:
+            ret = inference_on_dataset(
+                model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+            )
+            print_csv_format(ret)
+            return ret


 def do_train(args, cfg):
@@ -176,14 +188,20 @@ def do_train(args, cfg):
     logger = logging.getLogger("detectron2")
     logger.info("Model:\n{}".format(model))
     model.to(cfg.train.device)
-
+
+    # instantiate optimizer
     cfg.optimizer.params.model = model
     optim = instantiate(cfg.optimizer)

+    # build training loader
     train_loader = instantiate(cfg.dataloader.train)
-
+
+    # create ddp model
     model = create_ddp_model(model, **cfg.train.ddp)

+    # build model ema
+    ema.may_build_model_ema(cfg, model)
+
     trainer = Trainer(
         model=model,
         dataloader=train_loader,
@@ -191,11 +209,15 @@ def do_train(args, cfg):
         amp=cfg.train.amp.enabled,
         clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
     )
-
+
+    kwargs = {}
+    kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
     checkpointer = DetectionCheckpointer(
         model,
         cfg.train.output_dir,
         trainer=trainer,
+        # save model ema
+        **kwargs
     )

     if comm.is_main_process():
@@ -214,6 +236,7 @@ def do_train(args, cfg):
     trainer.register_hooks(
         [
             hooks.IterationTimer(),
+            ema.EMAHook(cfg, model) if cfg.train.model_ema.enabled else None,
             hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
             hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
             if comm.is_main_process()
@@ -253,7 +276,15 @@ def main(args):
         model = instantiate(cfg.model)
         model.to(cfg.train.device)
         model = create_ddp_model(model)
-        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
+
+        # using ema for evaluation
+        ema.may_build_model_ema(cfg, model)
+        kwargs = {}
+        kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
+        DetectionCheckpointer(model, **kwargs).load(cfg.train.init_checkpoint)
+        # Apply ema state for evaluation
+        if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
+            ema.apply_model_ema(model)
         print(do_test(cfg, model))
     else:
         do_train(args, cfg)

From c8ec8edaaaf6afe98515293f8e5ad72a34d6958a Mon Sep 17 00:00:00 2001
From: rentainhe <596106517@qq.com>
Date: Mon, 6 Feb 2023 14:52:48 +0800
Subject: [PATCH 2/4] refine ema usage

---
 configs/common/train.py |  2 +-
 tools/train_net.py      | 49 +++++++++++++++++++++++++++--------------
 2 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/configs/common/train.py b/configs/common/train.py
index bd0f8dbb..f729fc89 100644
--- a/configs/common/train.py
+++ b/configs/common/train.py
@@ -48,7 +48,7 @@
         enabled=False,
         decay=0.999,
         device="",
-        use_ema_weights_for_eval_only=True,
+        use_ema_weights_for_eval_only=False,
     ),
     # the training device, choose from {"cuda", "cpu"}
     device="cuda",
diff --git a/tools/train_net.py b/tools/train_net.py
index 8404960f..de4bfcae 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -145,24 +145,39 @@ def load_state_dict(self, state_dict):
         self.grad_scaler.load_state_dict(state_dict["grad_scaler"])


-def do_test(cfg, model):
+def do_test(cfg, model, eval_only=False):
     logger = logging.getLogger("detectron2")
-    if cfg.train.model_ema.enabled:
-        logger.info("Run evaluation with EMA.")
-        with ema.apply_model_ema_and_restore(model):
-            if "evaluator" in cfg.dataloader:
-                ret = inference_on_dataset(
-                    model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
-                )
-                print_csv_format(ret)
-                return ret
-    else:
+
+    if eval_only:
+        logger.info("Run evaluation under eval-only mode")
+        if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
+            logger.info("Run evaluation with EMA.")
+        else:
+            logger.info("Run evaluation without EMA.")
         if "evaluator" in cfg.dataloader:
-            ret = inference_on_dataset(
-                model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
-            )
-            print_csv_format(ret)
-            return ret
+            ret = inference_on_dataset(
+                model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+            )
+            print_csv_format(ret)
+            return ret
+
+    logger.info("Run evaluation without EMA.")
+    if "evaluator" in cfg.dataloader:
+        ret = inference_on_dataset(
+            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+        )
+        print_csv_format(ret)
+
+    if cfg.train.model_ema.enabled:
+        logger.info("Run evaluation with EMA.")
+        with ema.apply_model_ema_and_restore(model):
+            if "evaluator" in cfg.dataloader:
+                ema_ret = inference_on_dataset(
+                    model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
+                )
+                print_csv_format(ema_ret)
+                ret.update(ema_ret)
+    return ret


 def do_train(args, cfg):
@@ -285,7 +300,7 @@ def main(args):
         # Apply ema state for evaluation
         if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
             ema.apply_model_ema(model)
-        print(do_test(cfg, model))
+        print(do_test(cfg, model, eval_only=True))
     else:
         do_train(args, cfg)


From d81090ebc0711bcf7f786f7eba3c0960b77f532b Mon Sep 17 00:00:00 2001
From: rentainhe <596106517@qq.com>
Date: Tue, 7 Feb 2023 10:43:18 +0800
Subject: [PATCH 3/4] refine code

---
 tools/train_net.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tools/train_net.py b/tools/train_net.py
index de4bfcae..ff39672a 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -225,14 +225,12 @@ def do_train(args, cfg):
         clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
     )

-    kwargs = {}
-    kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
     checkpointer = DetectionCheckpointer(
         model,
         cfg.train.output_dir,
         trainer=trainer,
         # save model ema
-        **kwargs
+        **ema.may_get_ema_checkpointer(cfg, model)
     )

     if comm.is_main_process():
@@ -294,9 +292,7 @@ def main(args):

         # using ema for evaluation
         ema.may_build_model_ema(cfg, model)
-        kwargs = {}
-        kwargs.update(ema.may_get_ema_checkpointer(cfg, model))
-        DetectionCheckpointer(model, **kwargs).load(cfg.train.init_checkpoint)
+        DetectionCheckpointer(model, **ema.may_get_ema_checkpointer(cfg, model)).load(cfg.train.init_checkpoint)
         # Apply ema state for evaluation
         if cfg.train.model_ema.enabled and cfg.train.model_ema.use_ema_weights_for_eval_only:
             ema.apply_model_ema(model)

From fdd78b37262accbf407aebc4211cd729dfd0bdc8 Mon Sep 17 00:00:00 2001
From: rentainhe <596106517@qq.com>
Date: Wed, 8 Feb 2023 01:35:04 +0800
Subject: [PATCH 4/4] update ema links

---
 docs/source/tutorials/Model_Zoo.md | 8 ++++++++
 projects/dino/README.md            | 9 +++++++++
 2 files changed, 17 insertions(+)

diff --git a/docs/source/tutorials/Model_Zoo.md b/docs/source/tutorials/Model_Zoo.md
index 5eaa635e..ceaacf78 100644
--- a/docs/source/tutorials/Model_Zoo.md
+++ b/docs/source/tutorials/Model_Zoo.md
@@ -146,6 +146,14 @@ Here we provides our pretrained baselines with **detrex**. And more pretrained w