ModelEmaV3 and MESA experiments

rwightman committed Feb 2, 2024
1 parent d0ff315 commit dd84ef2
Showing 3 changed files with 153 additions and 18 deletions.
2 changes: 1 addition & 1 deletion timm/utils/__init__.py
@@ -10,6 +10,6 @@
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
-from .model_ema import ModelEma, ModelEmaV2
from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
from .random import random_seed
from .summary import update_summary, get_outdir
136 changes: 122 additions & 14 deletions timm/utils/model_ema.py
@@ -5,6 +5,7 @@
import logging
from collections import OrderedDict
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
@@ -102,32 +103,139 @@ class ModelEmaV2(nn.Module):
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
-    def __init__(self, model, decay=0.9999, device=None, foreach=False):
-        super(ModelEmaV2, self).__init__()
    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)


class ModelEmaV3(nn.Module):
    """ Model Exponential Moving Average V3
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V3 of this module leverages for_each and in-place operations for faster performance.
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(
            self,
            model,
            decay: float = 0.9999,
            min_decay: float = 0.0,
            update_after_step: int = 0,
            use_warmup: bool = False,
            warmup_gamma: float = 1.0,
            warmup_power: float = 2/3,
            device: Optional[torch.device] = None,
            foreach: bool = True,
            exclude_buffers: bool = False,
    ):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_warmup = use_warmup
        self.warmup_gamma = warmup_gamma
        self.warmup_power = warmup_power
        self.foreach = foreach
        self.device = device  # perform ema on different device from model if set
        self.exclude_buffers = exclude_buffers
        if self.device is not None and device != next(model.parameters()).device:
            self.foreach = False  # cannot use foreach methods with different devices
            self.module.to(device=device)

-    @torch.no_grad()
-    def update(self, model):
-        ema_params = tuple(self.module.parameters())
-        model_params = tuple(model.parameters())
-        if self.foreach:
-            torch._foreach_mul_(ema_params, scalar=self.decay)
-            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
    def get_decay(self, step: Optional[int] = None) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        if step is None:
            return self.decay

        step = max(0, step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        if self.use_warmup:
            decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
            decay = max(min(decay, self.decay), self.min_decay)
        else:
-            for ema_p, model_p in zip(ema_params, model_params):
-                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
            decay = self.decay

        return decay

-        # copy buffers instead of EMA
-        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
-            ema_b.copy_(model_b.to(device=self.device))
    @torch.no_grad()
    def update(self, model, step: Optional[int] = None):
        decay = self.get_decay(step)

        if self.exclude_buffers:
            # interpolate parameters
            ema_params = tuple(self.module.parameters())
            model_params = tuple(model.parameters())
            if self.foreach:
                if hasattr(torch, '_foreach_lerp_'):
                    torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
                else:
                    torch._foreach_mul_(ema_params, scalar=decay)
                    torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
            else:
                for ema_p, model_p in zip(ema_params, model_params):
                    ema_p.lerp_(model_p, weight=1. - decay)

            # copy buffers instead of EMA
            for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
                ema_b.copy_(model_b.to(device=self.device))
        else:
            # interpolate parameters and buffers
            if self.foreach:
                ema_lerp_values = []
                model_lerp_values = []
                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                    if ema_v.is_floating_point():
                        ema_lerp_values.append(ema_v)
                        model_lerp_values.append(model_v)
                    else:
                        ema_v.copy_(model_v)

                if hasattr(torch, '_foreach_lerp_'):
                    torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
                else:
                    torch._foreach_mul_(ema_lerp_values, scalar=decay)
                    torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
            else:
                for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                    if ema_v.is_floating_point():
                        ema_v.lerp_(model_v, weight=1. - decay)
                    else:
                        ema_v.copy_(model_v)

    @torch.no_grad()
    def set(self, model):
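
For orientation, a minimal usage sketch of the new class (illustrative, not part of the commit): the toy nn.Linear model, the SGD loop, and the approximate printed values are assumptions for the example, while ModelEmaV3, get_decay, update(model, step=...), and ema.module are the APIs added above. With use_warmup=True the decay ramps toward the 0.9999 cap following 1 - (1 + step / warmup_gamma) ** -warmup_power, clamped to [min_decay, decay].

# Illustrative sketch only -- assumes timm with this patch installed.
import torch
import torch.nn as nn
from timm.utils import ModelEmaV3

model = nn.Linear(16, 4)                      # stand-in for a real timm model
ema = ModelEmaV3(model, decay=0.9999, use_warmup=True)

# Warmup schedule with the defaults: roughly 0.785 at step 10, 0.990 at 1000, 0.998 at 10000
for step in (10, 1000, 10000):
    print(step, round(ema.get_decay(step), 4))

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(1, 101):
    x = torch.randn(8, 16)
    loss = model(x).square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model, step=step)              # pass the global update count so warmup applies

# Evaluate with the averaged weights via the wrapped copy
with torch.no_grad():
    out = ema.module(torch.randn(2, 16))
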
33 changes: 30 additions & 3 deletions train.py
@@ -586,8 +586,12 @@ def main():
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = utils.ModelEmaV2(
-            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=True,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

@@ -847,6 +851,7 @@ def main():
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -935,6 +940,7 @@ def train_one_epoch(
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        num_updates_total=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
@@ -981,6 +987,27 @@ def _forward():
            with amp_autocast():
                output = model(input)
                loss = loss_fn(output, target)

                if num_updates / num_updates_total > 0.25:
                    with torch.no_grad():
                        output_mesa = model_ema.module(input)

                    # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
                    #     output,
                    #     torch.sigmoid(output_mesa).detach(),
                    #     reduction='none',
                    # ).mean()

                    # loss_mesa = loss_fn(
                    #     output, torch.sigmoid(output_mesa).detach())

                    loss_mesa = torch.nn.functional.kl_div(
                        (output / 5).log_softmax(-1),
                        (output_mesa / 5).log_softmax(-1).detach(),
                        log_target=True,
                        reduction='none').sum(-1).mean()
                    loss += 5 * loss_mesa

            if accum_steps > 1:
                loss /= accum_steps
            return loss
@@ -1026,7 +1053,7 @@ def _backward(_loss):
        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
-            model_ema.update(model)
            model_ema.update(model, step=num_updates)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
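
The MESA experiment above adds a self-distillation term once training passes 25% of the total updates: the EMA copy is run forward under torch.no_grad(), and a temperature-scaled KL divergence between the model's logits and the EMA model's logits is added to the loss with weight 5 (temperature 5). A standalone sketch of that term follows; the mesa_distill_loss name and its packaging as a helper are illustrative, while the constants mirror the hard-coded values in the diff.

# Illustrative refactor of the MESA term into a helper -- not part of the commit.
import torch
import torch.nn.functional as F

def mesa_distill_loss(output: torch.Tensor, output_ema: torch.Tensor, temperature: float = 5.0) -> torch.Tensor:
    """KL divergence between temperature-scaled log-softmax of model and EMA-model logits."""
    return F.kl_div(
        (output / temperature).log_softmax(-1),
        (output_ema / temperature).log_softmax(-1).detach(),
        log_target=True,
        reduction='none',
    ).sum(-1).mean()

# Usage inside a training step, gated as in the diff:
# if num_updates / num_updates_total > 0.25:
#     with torch.no_grad():
#         output_ema = model_ema.module(input)
#     loss = loss + 5 * mesa_distill_loss(output, output_ema)
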
