Skip to content

Commit

Permalink
Overflow fix (#416)
Browse files: browse the repository at this point in the history
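* Switches the fused_optimizer overflow calculation: overflow is now detected up front by checking the fp16 gradients directly (`overflow_checker.has_overflow`), before the gradient norm is computed, instead of being inferred afterwards from the norm (`overflow_checker.check_using_norm`). On overflow the step is skipped and the gradients are cleared.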
  • Loading branch information
Shaden Smith authored Sep 16, 2020
1 parent 7d91be9 commit f5cce75
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions deepspeed/runtime/fp16/fused_optimizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
from deepspeed.utils import logger, log_dist


class FP16_Optimizer(object):
Expand Down Expand Up @@ -204,9 +204,30 @@ def step(self, closure=None):
UPDATE_FP16 = 'update_fp16'
STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

# First compute norm for all group so we know if there is overflow
grads_groups_flat = []
# First determine if there is overflow.
self.start_timers([OVERFLOW_CHECK])
fp16_params = []
for i, group in enumerate(self.fp16_groups):
fp16_params.extend([p for p in group if p.grad is not None])
self.overflow = self.overflow_checker.has_overflow(fp16_params)
self.stop_timers([OVERFLOW_CHECK])
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
log_dist(
"Overflow detected. Skipping step. Attempted loss "
f"scale: {prev_scale}, reducing to {self.cur_scale}",
ranks=[0])
# Clear gradients
for i, group in enumerate(self.fp16_groups):
for p in group:
p.grad = None

self.log_timers(OVERFLOW_TIMERS)
return self.overflow

grads_groups_flat = []
for i, group in enumerate(self.fp16_groups):
data_type = self.fp32_groups_flat[i].dtype

Expand All @@ -227,22 +248,6 @@ def step(self, closure=None):
all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
self.stop_timers([COMPUTE_NORM])

self.start_timers([OVERFLOW_CHECK])
self.overflow = self.overflow_checker.check_using_norm([all_groups_norm])
self.stop_timers([OVERFLOW_CHECK])

prev_scale = self.cur_scale
self._update_scale(self.overflow)

if self.overflow:
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {} ".format(prev_scale,
self.cur_scale))
self.log_timers(OVERFLOW_TIMERS)
grads_groups_flat = None
return self.overflow

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
self.stop_timers([UNSCALE_AND_CLIP])
Expand Down

0 comments on commit f5cce75

Please sign in to comment.