Add --fp16-scale-tolerance (facebookresearch#397)
Summary:
Let's only decrease the loss scale if a large enough percentage of batches overflow.
Pull Request resolved: facebookresearch#397

Differential Revision: D13355159

Pulled By: myleott

fbshipit-source-id: e17dde73d34a639519b4348c013fdd19d2b314e6
myleott authored and facebook-github-bot committed Dec 7, 2018
1 parent 6c006a3 commit 03ef3ab
Showing 2 changed files with 15 additions and 2 deletions.
fairseq/optim/fp16_optimizer.py (13 additions, 2 deletions)
@@ -12,19 +12,29 @@

 class DynamicLossScaler:

-    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
+    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
         self.loss_scale = init_scale
         self.scale_factor = scale_factor
         self.scale_window = scale_window
+        self.tolerance = tolerance
         self._iter = 0
         self._last_overflow_iter = -1
+        self._last_rescale_iter = -1
+        self._overflows_since_rescale = 0

     def update_scale(self, overflow):
+        iter_since_rescale = self._iter - self._last_rescale_iter
         if overflow:
-            self.loss_scale /= self.scale_factor
             self._last_overflow_iter = self._iter
+            self._overflows_since_rescale += 1
+            pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
+            if pct_overflow >= self.tolerance:
+                self.loss_scale /= self.scale_factor
+                self._last_rescale_iter = self._iter
+                self._overflows_since_rescale = 0
         elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
             self.loss_scale *= self.scale_factor
+            self._last_rescale_iter = self._iter
         self._iter += 1

     @staticmethod
@@ -55,6 +65,7 @@ def __init__(self, args, params, fp32_optimizer, fp32_params):
         self.scaler = DynamicLossScaler(
             init_scale=args.fp16_init_scale,
             scale_window=scale_window,
+            tolerance=args.fp16_scale_tolerance,
         )

     @staticmethod
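
To make the new behavior concrete, here is a minimal usage sketch (not part of this commit) that drives DynamicLossScaler directly, assuming it is importable from fairseq.optim.fp16_optimizer as in the diff above. With the default tolerance of 0.0 every overflow still backs off the loss scale; with a nonzero tolerance, an occasional overflow is absorbed as long as the overflow rate since the last rescale stays below the tolerance (the very first overflow still rescales, because at that point 100% of the updates since the last rescale have overflowed).

# Usage sketch (hypothetical, not part of the commit).
from fairseq.optim.fp16_optimizer import DynamicLossScaler

def simulate(tolerance):
    scaler = DynamicLossScaler(init_scale=2.**15, scale_window=2000, tolerance=tolerance)
    # 100 updates, with overflows at updates 0 and 50 (~2% of updates overflow).
    for i in range(100):
        scaler.update_scale(overflow=(i in (0, 50)))
    return scaler.loss_scale

print(simulate(tolerance=0.0))   # 8192.0  -- both overflows halve the scale
print(simulate(tolerance=0.05))  # 16384.0 -- the second overflow (2% rate) is tolerated
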
fairseq/options.py (2 additions, 0 deletions)
@@ -133,6 +133,8 @@ def get_parser(desc, default_task='translation'):
                         help='default FP16 loss scale')
     parser.add_argument('--fp16-scale-window', type=int,
                         help='number of updates before increasing loss scale')
+    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
+                        help='pct of updates that can overflow before decreasing the loss scale')

     # Task definitions can be found under fairseq/tasks/
     parser.add_argument('--task', metavar='TASK', default=default_task,
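
The flag defaults to 0.0, which preserves the old behavior of backing off on every overflow. A small, hypothetical sanity check (not part of this commit) of the new option, assuming fairseq is installed and options.get_parser is callable as shown in the diff above; on the command line the equivalent would be passing something like --fp16 --fp16-scale-tolerance 0.25 to the usual train.py entry point:

# Hypothetical parse check of the new flag (not from the commit).
from fairseq import options

parser = options.get_parser('train')
args, _ = parser.parse_known_args(['--fp16-scale-tolerance', '0.25'])
print(args.fp16_scale_tolerance)  # 0.25; falls back to 0.0 (old behavior) when the flag is omitted
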
