FP16 optimizer automatically detect DeepSpeed compatibility (#18084)
### FP16 optimizer automatically detect DeepSpeed compatibility

Optimum/Transformers use the accelerate library to prepare models, so our FP16 optimizer wrapper has not worked for a long time: the optimizer they hand us is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`, which under the hood still calls into the DeepSpeed stage 1 and 2 optimizer.

This PR includes the following changes:

1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the modifier registry, plus a check that its contained `optimizer` property MUST be a DeepSpeed stage 1 and 2 optimizer (stage 3 will be covered later).
2. For DeepSpeed versions > 0.9.1, store the relevant source code in a version list. As long as the related functions in DeepSpeed remain unchanged in new releases, the version check no longer needs to be bumped manually. If the source code ever stops matching, a warning is raised asking users to add a new version of the source code to the list.

With the above changes, our FP16 optimizer works again in Optimum.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b)
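As a rough illustration of change (1): detection amounts to recognizing the accelerate wrapper and then validating the optimizer it wraps. The sketch below is not the actual modifier-registry code; the helper name is made up, but the two class paths come straight from the description above.

```python
from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer


def can_apply_fp16_optimizer_modifier(optimizer) -> bool:
    """Return True only when the accelerate wrapper holds a ZeRO stage 1/2 optimizer."""
    # The accelerate wrapper is what Optimum/Transformers hand us after accelerator.prepare().
    if not isinstance(optimizer, DeepSpeedOptimizerWrapper):
        return False
    # Stage 3 uses a different optimizer class and is deliberately not covered yet.
    return isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
```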
Showing 5 changed files with 223 additions and 31 deletions.
81 changes: 81 additions & 0 deletions
orttraining/orttraining/python/training/optim/_ds_code_store.py
@@ -0,0 +1,81 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Copyright 2020 The Microsoft DeepSpeed Team
#
# !!!IMPORTANT: This file is a copy of the original one in DeepSpeed repo at given version,
# It is used to compare with the source code of current installed DeepSpeed during runtime.
# Please don't modify it or do any code formatting for it.
# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is removed from lintrunner config by intention.
# --------------------------------------------------------------------------

# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed.

class Stage1And2_DeepSpeedZeroOptimizer_0_9_2:

    def has_overflow_serial(self, params, is_grad_list=False):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True

        return False


    def get_grad_norm_direct(self, gradients, params, norm_type=2):
        """Clips gradient norm of an iterable of parameters.
        This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
        added functionality to handle model parallel parameters. Note that
        the gradients are modified in place.
        Arguments:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(g.data.abs().max() for g in gradients)
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)

            # Take max across all GPUs.
            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
            total_norm = total_norm_cuda[0].item()
        else:
            total_norm = 0.0
            # if dist.get_rank() == 0:
            #    logger.info(f"Total Norm beginning {total_norm}")
            for g, p in zip(gradients, params):
                # Pipeline parallelism may replicate parameters. Avoid multi-counting.
                if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                    continue
                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                    param_norm = g.data.double().norm(2)
                    total_norm += param_norm.item()**2
            # Sum across all model parallel GPUs.
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

            total_norm = total_norm_cuda[0].item()**(1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm


    def has_overflow_partitioned_grads_serial(self):
        for i in range(len(self.bit16_groups)):
            for j, grad in enumerate(self.averaged_gradients[i]):
                if grad is not None and self._has_inf_or_nan(grad.data, j):
                    return True
        return False
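Change (2) in the commit message describes a runtime comparison of the installed DeepSpeed functions against this stored copy. A minimal sketch of how such a check could look follows, assuming plain `inspect.getsource` string matching and assuming the module imports as `onnxruntime.training.optim._ds_code_store`; the helper names and warning text are illustrative, not the PR's actual code. The wrapper class above exists precisely so that the stored methods carry the same 4-space indentation as the real `DeepSpeedZeroOptimizer` methods, which lets a straight string comparison work.

```python
import inspect
import warnings

from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

# Assumed import path, derived from the file location shown above.
from onnxruntime.training.optim._ds_code_store import Stage1And2_DeepSpeedZeroOptimizer_0_9_2

# The three functions copied into the code store.
_CHECKED_FUNCTIONS = ("has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial")
# One entry per known-good DeepSpeed release; extend when a new release still matches.
_KNOWN_COPIES = [Stage1And2_DeepSpeedZeroOptimizer_0_9_2]


def installed_deepspeed_matches_known_source() -> bool:
    """Compare installed DeepSpeed functions against the stored copies; warn on mismatch."""
    for copy in _KNOWN_COPIES:
        if all(
            inspect.getsource(getattr(DeepSpeedZeroOptimizer, name)) == inspect.getsource(getattr(copy, name))
            for name in _CHECKED_FUNCTIONS
        ):
            return True
    warnings.warn(
        "Installed DeepSpeed ZeRO stage 1/2 source does not match any stored version; "
        "consider adding a new copy to _ds_code_store.py."
    )
    return False
```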