diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index b597bf59b6eada..13175b7c311507 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -152,6 +152,6 @@ def step(self, closure=None):
                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(grad, 1.0 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1.0 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
@@ -173,6 +173,6 @@ def step(self, closure=None):
                 # of the weights to the loss with plain (non-momentum) SGD.
                 # Add weight decay at the end (fixed version)
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p.data, -group["lr"] * group["weight_decay"])
+                    p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
 
         return loss