From f5cce75e7061a6736f3021fd5a5a0c744704ed4f Mon Sep 17 00:00:00 2001 From: Shaden Smith Date: Wed, 16 Sep 2020 14:26:30 -0700 Subject: [PATCH] Overflow fix (#416) * Switches fused_optimizer overflow calculation --- deepspeed/runtime/fp16/fused_optimizer.py | 43 +++++++++++++---------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index d2f7870008e1..8c1d2003cb1b 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -11,7 +11,7 @@ from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE -from deepspeed.utils import logger +from deepspeed.utils import logger, log_dist class FP16_Optimizer(object): @@ -204,9 +204,30 @@ def step(self, closure=None): UPDATE_FP16 = 'update_fp16' STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16] - # First compute norm for all group so we know if there is overflow - grads_groups_flat = [] + # First determine if there is overflow. + self.start_timers([OVERFLOW_CHECK]) + fp16_params = [] + for i, group in enumerate(self.fp16_groups): + fp16_params.extend([p for p in group if p.grad is not None]) + self.overflow = self.overflow_checker.has_overflow(fp16_params) + self.stop_timers([OVERFLOW_CHECK]) + prev_scale = self.cur_scale + self._update_scale(self.overflow) + if self.overflow: + if self.verbose: + log_dist( + "Overflow detected. Skipping step. Attempted loss " + f"scale: {prev_scale}, reducing to {self.cur_scale}", + ranks=[0]) + # Clear gradients + for i, group in enumerate(self.fp16_groups): + for p in group: + p.grad = None + + self.log_timers(OVERFLOW_TIMERS) + return self.overflow + grads_groups_flat = [] for i, group in enumerate(self.fp16_groups): data_type = self.fp32_groups_flat[i].dtype @@ -227,22 +248,6 @@ def step(self, closure=None): all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu) self.stop_timers([COMPUTE_NORM]) - self.start_timers([OVERFLOW_CHECK]) - self.overflow = self.overflow_checker.check_using_norm([all_groups_norm]) - self.stop_timers([OVERFLOW_CHECK]) - - prev_scale = self.cur_scale - self._update_scale(self.overflow) - - if self.overflow: - if self.verbose: - print("[deepspeed] OVERFLOW! Skipping step. Attempted loss " - "scale: {}, reducing to {} ".format(prev_scale, - self.cur_scale)) - self.log_timers(OVERFLOW_TIMERS) - grads_groups_flat = None - return self.overflow - self.start_timers([UNSCALE_AND_CLIP]) self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm]) self.stop_timers([UNSCALE_AND_CLIP])