From dd84ef2cd5212ff8a7c1ea141e2f8bc159c1788a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 2 Feb 2024 09:45:04 -0800
Subject: [PATCH] ModelEmaV3 and MESA experiments

---
 timm/utils/__init__.py  |   2 +-
 timm/utils/model_ema.py | 136 +++++++++++++++++++++++++++++++++++-----
 train.py                |  33 +++++++++-
 3 files changed, 153 insertions(+), 18 deletions(-)

diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py
index 63fcf4c5b4..4c6a00cad5 100644
--- a/timm/utils/__init__.py
+++ b/timm/utils/__init__.py
@@ -10,6 +10,6 @@
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg, ParseKwargs
 from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
-from .model_ema import ModelEma, ModelEmaV2
+from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
 from .random import random_seed
 from .summary import update_summary, get_outdir
diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py
index 5cefe08bf3..968f4f580c 100644
--- a/timm/utils/model_ema.py
+++ b/timm/utils/model_ema.py
@@ -5,6 +5,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -102,32 +103,139 @@ class ModelEmaV2(nn.Module):
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
     """
-    def __init__(self, model, decay=0.9999, device=None, foreach=False):
-        super(ModelEmaV2, self).__init__()
+    def __init__(self, model, decay=0.9999, device=None):
+        super().__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if self.device is not None:
+            self.module.to(device=device)
+
+    def _update(self, model, update_fn):
+        with torch.no_grad():
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                if self.device is not None:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
+
+
+class ModelEmaV3(nn.Module):
+    """ Model Exponential Moving Average V3
+
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    V3 of this module leverages for_each and in-place operations for faster performance.
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+    def __init__(
+            self,
+            model,
+            decay: float = 0.9999,
+            min_decay: float = 0.0,
+            update_after_step: int = 0,
+            use_warmup: bool = False,
+            warmup_gamma: float = 1.0,
+            warmup_power: float = 2/3,
+            device: Optional[torch.device] = None,
+            foreach: bool = True,
+            exclude_buffers: bool = False,
+    ):
+        super().__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+        self.min_decay = min_decay
+        self.update_after_step = update_after_step
+        self.use_warmup = use_warmup
+        self.warmup_gamma = warmup_gamma
+        self.warmup_power = warmup_power
         self.foreach = foreach
         self.device = device  # perform ema on different device from model if set
+        self.exclude_buffers = exclude_buffers
         if self.device is not None and device != next(model.parameters()).device:
             self.foreach = False  # cannot use foreach methods with different devices
             self.module.to(device=device)
 
-    @torch.no_grad()
-    def update(self, model):
-        ema_params = tuple(self.module.parameters())
-        model_params = tuple(model.parameters())
-        if self.foreach:
-            torch._foreach_mul_(ema_params, scalar=self.decay)
-            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
+    def get_decay(self, step: Optional[int] = None) -> float:
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        if step is None:
+            return self.decay
+
+        step = max(0, step - self.update_after_step - 1)
+        if step <= 0:
+            return 0.0
+
+        if self.use_warmup:
+            decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
+            decay = max(min(decay, self.decay), self.min_decay)
         else:
-            for ema_p, model_p in zip(ema_params, model_params):
-                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
+            decay = self.decay
+
+        return decay
 
-        # copy buffers instead of EMA
-        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
-            ema_b.copy_(model_b.to(device=self.device))
+    @torch.no_grad()
+    def update(self, model, step: Optional[int] = None):
+        decay = self.get_decay(step)
+
+        if self.exclude_buffers:
+            # interpolate parameters
+            ema_params = tuple(self.module.parameters())
+            model_params = tuple(model.parameters())
+            if self.foreach:
+                if hasattr(torch, '_foreach_lerp_'):
+                    torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
+                else:
+                    torch._foreach_mul_(ema_params, scalar=decay)
+                    torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
+            else:
+                for ema_p, model_p in zip(ema_params, model_params):
+                    ema_p.lerp_(model_p, weight=1. - decay)
+
+            # copy buffers instead of EMA
+            for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+                ema_b.copy_(model_b.to(device=self.device))
+        else:
+            # interpolate parameters and buffers
+            if self.foreach:
+                ema_lerp_values = []
+                model_lerp_values = []
+                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                    if ema_v.is_floating_point():
+                        ema_lerp_values.append(ema_v)
+                        model_lerp_values.append(model_v)
+                    else:
+                        ema_v.copy_(model_v)
+
+                if hasattr(torch, '_foreach_lerp_'):
+                    torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
+                else:
+                    torch._foreach_mul_(ema_lerp_values, scalar=decay)
+                    torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
+            else:
+                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                    if ema_v.is_floating_point():
+                        ema_v.lerp_(model_v, weight=1. - decay)
+                    else:
+                        ema_v.copy_(model_v)
 
     @torch.no_grad()
     def set(self, model):
diff --git a/train.py b/train.py
index ba917773a0..e3b3e037a6 100755
--- a/train.py
+++ b/train.py
@@ -586,8 +586,12 @@ def main():
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = utils.ModelEmaV2(
-            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        model_ema = utils.ModelEmaV3(
+            model,
+            decay=args.model_ema_decay,
+            use_warmup=True,
+            device='cpu' if args.model_ema_force_cpu else None,
+        )
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
 
@@ -847,6 +851,7 @@ def main():
                 loss_scaler=loss_scaler,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
+                num_updates_total=num_epochs * updates_per_epoch,
             )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -935,6 +940,7 @@ def train_one_epoch(
         loss_scaler=None,
         model_ema=None,
         mixup_fn=None,
+        num_updates_total=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -981,6 +987,27 @@ def _forward():
             with amp_autocast():
                 output = model(input)
                 loss = loss_fn(output, target)
+
+                if num_updates / num_updates_total > 0.25:
+                    with torch.no_grad():
+                        output_mesa = model_ema.module(input)
+
+                    # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
+                    #     output,
+                    #     torch.sigmoid(output_mesa).detach(),
+                    #     reduction='none',
+                    # ).mean()
+
+                    # loss_mesa = loss_fn(
+                    #     output, torch.sigmoid(output_mesa).detach())
+
+                    loss_mesa = torch.nn.functional.kl_div(
+                        (output / 5).log_softmax(-1),
+                        (output_mesa / 5).log_softmax(-1).detach(),
+                        log_target=True,
+                        reduction='none').sum(-1).mean()
+                    loss += 5 * loss_mesa
+
             if accum_steps > 1:
                 loss /= accum_steps
             return loss
@@ -1026,7 +1053,7 @@ def _backward(_loss):
         num_updates += 1
         optimizer.zero_grad()
         if model_ema is not None:
-            model_ema.update(model)
+            model_ema.update(model, step=num_updates)
 
         if args.synchronize_step and device.type == 'cuda':
             torch.cuda.synchronize()
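
Below is a minimal usage sketch of how the new ModelEmaV3 and the MESA-style auxiliary loss wired into train.py above fit together in a bare training loop. It is not part of the patch: it assumes the patch has been applied so ModelEmaV3 is importable from timm.utils; the toy model, synthetic batches, optimizer, and loop length are illustrative placeholders, while the temperature of 5, the loss weight of 5, and the 25%-of-training gate mirror the hard-coded experimental values in the diff.

# Editorial sketch, not part of the patch. Assumes the ModelEmaV3 changes above are applied.
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.utils import ModelEmaV3

# Toy model and optimizer (placeholders for the real timm model / optimizer setup).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# EMA copy with step-based decay warmup: decay ramps from 0 toward 0.9999
# following 1 - (1 + step / warmup_gamma) ** -warmup_power, as in get_decay().
model_ema = ModelEmaV3(model, decay=0.9999, use_warmup=True)

num_updates_total = 100  # stand-in for num_epochs * updates_per_epoch
for num_updates in range(1, num_updates_total + 1):
    input = torch.randn(8, 3, 32, 32)     # synthetic batch
    target = torch.randint(0, 10, (8,))

    output = model(input)
    loss = F.cross_entropy(output, target)

    # MESA-style regularization after 25% of training: pull the online model's
    # temperature-smoothed predictions toward those of its own EMA weights.
    if num_updates / num_updates_total > 0.25:
        with torch.no_grad():
            output_mesa = model_ema.module(input)
        loss_mesa = F.kl_div(
            (output / 5).log_softmax(-1),
            (output_mesa / 5).log_softmax(-1),
            log_target=True,
            reduction='none',
        ).sum(-1).mean()
        loss = loss + 5 * loss_mesa

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The step argument drives the warmup schedule; early steps copy the weights
    # almost directly (decay near 0), later steps approach the 0.9999 decay.
    model_ema.update(model, step=num_updates)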