
Commit

Merge remote-tracking branch 'emav3/faster_ema' into mesa_ema
rwightman committed Jan 27, 2024
2 parents 88889de + 3491506 commit d0ff315
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions timm/utils/model_ema.py
@@ -102,25 +102,34 @@ class ModelEmaV2(nn.Module):
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
     """
-    def __init__(self, model, decay=0.9999, device=None):
+    def __init__(self, model, decay=0.9999, device=None, foreach=False):
         super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
+        self.foreach = foreach
         self.device = device  # perform ema on different device from model if set
-        if self.device is not None:
+        if self.device is not None and device != next(model.parameters()).device:
+            self.foreach = False  # cannot use foreach methods with different devices
             self.module.to(device=device)
 
-    def _update(self, model, update_fn):
-        with torch.no_grad():
-            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                if self.device is not None:
-                    model_v = model_v.to(device=self.device)
-                ema_v.copy_(update_fn(ema_v, model_v))
-
+    @torch.no_grad()
     def update(self, model):
-        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            torch._foreach_mul_(ema_params, scalar=self.decay)
+            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
+
+        # copy buffers instead of EMA
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))
 
+    @torch.no_grad()
     def set(self, model):
-        self._update(model, update_fn=lambda e, m: m)
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.copy_(model_v.to(device=self.device))
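
For context, a minimal usage sketch of the updated class follows. The toy model, optimizer, and training loop below are illustrative assumptions and not part of this commit; foreach=True simply opts into the new torch._foreach_* update path added above, while the default foreach=False keeps the per-tensor loop.

import torch
import torch.nn as nn
from timm.utils.model_ema import ModelEmaV2

model = nn.Linear(10, 2)                                  # stand-in model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = ModelEmaV2(model, decay=0.9999, foreach=True)       # foreach=True uses the fused update

for step in range(100):                                   # hypothetical training loop
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                                     # EMA update: e = decay * e + (1 - decay) * m

# evaluate with the averaged weights held in ema.module
ema.module.eval()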
