
Commit

Get rid of monkey patching in LossScaler overflow handling (#18959) (#18973)

Co-authored-by: Vladimir Cherepanov <[email protected]>

mk-61 and Vladimir Cherepanov authored Aug 24, 2020
1 parent 9445a2d commit dfefe87
Showing 3 changed files with 20 additions and 40 deletions.
16 changes: 0 additions & 16 deletions python/mxnet/contrib/amp/amp.py
@@ -23,7 +23,6 @@
            'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
            'convert_symbol']
 
-from types import MethodType
 from array import array
 import ctypes
 import logging
@@ -341,21 +340,6 @@ def init_trainer(optimizer_or_trainer):
     if isinstance(optimizer_or_trainer, trainer.Trainer):
         optimizer_or_trainer._amp_loss_scaler = loss_scaler
         optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
-        skip_update = optimizer_or_trainer._amp_loss_scaler.wait_and_update
-        optimizer_or_trainer._optimizer.old_update_multi_precision = \
-            optimizer_or_trainer._optimizer.update_multi_precision
-        def new_update_multi_precision(self, index, weight, grad, state):
-            if not skip_update():
-                self.old_update_multi_precision(index, weight, grad, state)
-        optimizer_or_trainer._optimizer.update_multi_precision = \
-            MethodType(new_update_multi_precision, optimizer_or_trainer._optimizer)
-        launch_check_overflow = optimizer_or_trainer._amp_loss_scaler.launch_check_overflow
-        optimizer_or_trainer._old_update = optimizer_or_trainer._update
-        def new_update(self, ignore_stale_grad=False):
-            launch_check_overflow(self._params)
-            self._old_update(ignore_stale_grad)
-        optimizer_or_trainer._update = MethodType(new_update, optimizer_or_trainer)
-
     elif isinstance(optimizer_or_trainer, opt.Optimizer):
         # TODO(ptredak): make it work with the optimizer
         raise TypeError("AMP is currently only compatible with Gluon Trainer")
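The block deleted above relied on types.MethodType to rebind update_multi_precision and _update on live Optimizer and Trainer objects so that overflow handling could be injected around them. For readers unfamiliar with that pattern, here is a minimal, self-contained sketch of instance-level monkey patching; the toy class and names are hypothetical and are not MXNet code.

    from types import MethodType

    class Counter:
        """Toy stand-in for the patched Trainer/Optimizer objects; not MXNet code."""

        def step(self):
            return "stepped"

    c = Counter()

    # Keep a handle to the original bound method on this one instance...
    c.old_step = c.step

    def new_step(self):
        # ...and wrap it with extra behaviour (a guard, much like skip_update above).
        if getattr(self, "skip", False):
            return "skipped"
        return self.old_step()

    # Rebind the wrapper onto the single instance; the class itself is untouched.
    c.step = MethodType(new_step, c)

    print(c.step())          # "stepped"
    c.skip = True
    print(c.step())          # "skipped"
    print(Counter().step())  # "stepped": other instances keep the original method

This is the indirection the commit removes: it works, but the wrapped methods are hard to trace and easy to break, so the new code has Trainer._update ask the loss scaler directly (see the trainer.py hunk below).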
39 changes: 15 additions & 24 deletions python/mxnet/contrib/amp/loss_scaler.py
@@ -37,16 +37,13 @@ def __init__(self):
         self._max_loss_scale = 2.**24
         self._scale_seq_len = 2000
         self._unskipped = 0
-        self._has_overflow = False
 
     @property
     def loss_scale(self):
         return self._loss_scale
 
-    def launch_check_overflow(self, params):
-        """Launch overflow checking for gradients."""
-        self._wait_for_outputs = True
-        self._has_overflow = False
+    def has_overflow(self, params):
+        """Check gradients for overflow."""
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
@@ -56,22 +53,16 @@ def launch_check_overflow(self, params):
                 multi_all_finite(*valid_params[idx:idx+chunk_size],
                                  num_arrays=len(valid_params[idx:idx+chunk_size]),
                                  init_output=False, out=gpu_output)
-            self.output = gpu_output
-
-    def wait_and_update(self):
-        """Wait for the results of overflow checking and update the loss scale."""
-        if self._wait_for_outputs:
-            self._has_overflow = not bool(self.output.asnumpy())
-            self._loss_scale = self._next_loss_scale
-            if self._has_overflow:
-                self._next_loss_scale = self._loss_scale / 2.
-                self._unskipped = 0
-                logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
-            else:
-                self._unskipped += 1
-                if self._unskipped == self._scale_seq_len:
-                    self._unskipped = 0
-                    self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
-                    logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
-            self._wait_for_outputs = False
-        return self._has_overflow
+            has_overflow = not bool(gpu_output.asnumpy())
+        self._loss_scale = self._next_loss_scale
+        if has_overflow:
+            self._next_loss_scale = self._loss_scale / 2.
+            self._unskipped = 0
+            logging.info("AMP: decreasing loss scale to %f", self._next_loss_scale)
+        else:
+            self._unskipped += 1
+            if self._unskipped == self._scale_seq_len:
+                self._unskipped = 0
+                self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
+                logging.info("AMP: increasing loss scale to %f", self._next_loss_scale)
+        return has_overflow
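The refactor collapses the old two-phase protocol (launch_check_overflow before the update, wait_and_update inside the patched optimizer) into one synchronous has_overflow call, and keeps the scaling policy unchanged: halve the scale when any gradient is non-finite, and double it, capped at 2**24, after 2000 consecutive clean steps. Below is a framework-free sketch of just that policy; it assumes an initial scale of 2**16 (the real initial value is not shown in this diff) and is an illustration rather than the MXNet class.

    import logging

    class DynamicLossScalePolicy:
        """Minimal illustration of the scaling policy above; not the MXNet LossScaler."""

        def __init__(self, init_scale=2.**16, max_scale=2.**24, window=2000):
            # init_scale is an assumption; the diff does not show the real initial value.
            self._loss_scale = init_scale
            self._next_loss_scale = init_scale
            self._max_loss_scale = max_scale
            self._scale_seq_len = window
            self._unskipped = 0

        @property
        def loss_scale(self):
            return self._loss_scale

        def update(self, has_overflow):
            """Apply the pending scale, then choose the next one from this step's result."""
            self._loss_scale = self._next_loss_scale
            if has_overflow:
                # Non-finite gradients: back off and restart the clean-step count.
                self._next_loss_scale = self._loss_scale / 2.
                self._unskipped = 0
                logging.info("decreasing loss scale to %f", self._next_loss_scale)
            else:
                self._unskipped += 1
                if self._unskipped == self._scale_seq_len:
                    # A long run of clean steps: probe a larger scale, capped at max_scale.
                    self._unskipped = 0
                    self._next_loss_scale = min(self._max_loss_scale, self._loss_scale * 2.)
                    logging.info("increasing loss scale to %f", self._next_loss_scale)
            return has_overflow  # True means the optimizer step should be skipped

A training loop built on such a policy would typically multiply the loss by loss_scale before the backward pass, undo the scaling when applying gradients, and skip the optimizer step whenever update() reports an overflow.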
5 changes: 5 additions & 0 deletions python/mxnet/gluon/trainer.py
@@ -428,6 +428,11 @@ def update(self, batch_size, ignore_stale_grad=False):
         self._update(ignore_stale_grad)
 
     def _update(self, ignore_stale_grad=False):
+        loss_scaler = getattr(self, '_amp_loss_scaler', None)
+        if loss_scaler is not None:
+            if loss_scaler.has_overflow(self._params):
+                return  # skip on overflow
+
         updates = [[] for _ in self._updaters]
 
         for i, param in enumerate(self._params):
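With the overflow check now living in Trainer._update, user-facing AMP code does not change: the trainer consults its attached loss scaler and silently skips the weight update when gradients overflowed. The following is a hypothetical end-to-end sketch of the standard mxnet.contrib.amp flow; the network, data, and hyperparameters are invented for illustration, and a GPU context is assumed because the overflow check runs on GPU gradients.

    import mxnet as mx
    from mxnet import autograd, gluon
    from mxnet.contrib import amp

    amp.init()                                    # enable mixed-precision casting of operators
    ctx = mx.gpu()

    net = gluon.nn.Dense(10)
    net.initialize(ctx=ctx)
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    amp.init_trainer(trainer)                     # attaches the _amp_loss_scaler used by _update

    data = mx.nd.random.uniform(shape=(32, 64), ctx=ctx)
    label = mx.nd.random.randint(0, 10, shape=(32,), ctx=ctx).astype('float32')

    with autograd.record():
        out = net(data)
        loss = loss_fn(out, label)
        with amp.scale_loss(loss, trainer) as scaled_loss:
            autograd.backward(scaled_loss)

    # After this commit, trainer.step() -> _update() calls has_overflow() itself and
    # simply returns (skipping the weight update) when any gradient is non-finite.
    trainer.step(data.shape[0])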
