FP16 optimizer automatically detects DeepSpeed compatibility (#18084)

### FP16 optimizer automatically detects DeepSpeed compatibility

Optimum/Transformers use the accelerate library to prepare models, so our
FP16 optimizer wrapper has not worked for a long time: the optimizer's
class is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`,
which underneath still calls into the DeepSpeed stage 1 and 2 optimizer.

This PR includes the following changes:
1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the
modifier registry, plus a check that its contained `optimizer` property
MUST be a DeepSpeed stage 1 and 2 optimizer (Stage 3 optimizer support
can be covered later).
2. For DeepSpeed versions > 0.9.1, we store the relevant source code in a
version list. As long as the related functions in DeepSpeed remain
unchanged in new releases, we no longer need to manually bump the
version check. If someday the source code does not match, a warning is
raised asking users to add a new version of the source code to the list
(a minimal sketch of this check follows below).

With the above changes, our FP16 optimizer works again in
Optimum.
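
Change (2) boils down to two steps: pick the newest stored snapshot whose version is not newer than the installed DeepSpeed, then compare the functions' source text via `inspect.getsource`. A minimal standalone sketch of that idea (the snapshot class, the checked function set, and the version key are illustrative placeholders, not the actual contents of `_ds_code_store.py`):

```python
import inspect
import warnings

from packaging.version import Version


class _Snapshot_0_9_2:  # illustrative stand-in for an entry in the code store
    def has_overflow_serial(self, params):
        return any(p.grad is not None for p in params)


_code_store = {"0.9.2": _Snapshot_0_9_2}  # stored version -> frozen source snapshot


def sources_match(installed_version: str, installed_cls) -> bool:
    cur = Version(installed_version)
    # Newest stored version that is <= the installed version.
    base = max((v for v in map(Version, _code_store) if v <= cur), default=None)
    if base is None:
        return False  # installed DeepSpeed predates every stored snapshot
    snapshot = _code_store[str(base)]
    for name in ("has_overflow_serial",):
        if inspect.getsource(getattr(installed_cls, name)) != inspect.getsource(getattr(snapshot, name)):
            warnings.warn(f"{name} changed after {base}; please add a snapshot for {cur}.", UserWarning)
            return False
    return True
```

In the PR itself this comparison runs inside `_dynamic_checks` (see the `_ds_modifier.py` diff below) against the installed stage 1/2 `DeepSpeedZeroOptimizer`.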


![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b)
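
For context, a hedged usage sketch of the flow this PR fixes: an accelerate-prepared DeepSpeed ZeRO stage 1/2 setup handing its wrapped optimizer to ORT's `FP16_Optimizer`. The model and optimizer below are placeholders, and the DeepSpeed plugin configuration is assumed to be supplied separately (e.g. via `accelerate config`):

```python
import torch
from accelerate import Accelerator

from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

accelerator = Accelerator()  # assumes a DeepSpeed ZeRO stage 1/2 plugin is configured
model, optimizer = accelerator.prepare(model, optimizer)

# With DeepSpeed enabled, `optimizer` is an accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper.
# FP16_Optimizer resolves it in the modifier registry, checks that the wrapped `.optimizer` is a
# stage 1/2 DeepSpeedZeroOptimizer, and patches it in place; the same object is returned.
optimizer = FP16_Optimizer(optimizer)
```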
pengwa authored Oct 25, 2023
1 parent ae85619 commit 2c6b31c
Showing 5 changed files with 223 additions and 31 deletions.
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -45,6 +45,7 @@ exclude_patterns = [
'cmake/external/**',
# ignore generated flatbuffers code
'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**',
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
@@ -76,6 +77,7 @@ exclude_patterns = [
'cmake/**',
'orttraining/*',
'onnxruntime/core/flatbuffers/**',
'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
81 changes: 81 additions & 0 deletions orttraining/orttraining/python/training/optim/_ds_code_store.py
@@ -0,0 +1,81 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Copyright 2020 The Microsoft DeepSpeed Team
#
# !!!IMPORTANT: This file is a copy of the original one in the DeepSpeed repo at the given version.
# It is used at runtime to compare against the source code of the currently installed DeepSpeed.
# Please don't modify it or do any code formatting on it.
# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is intentionally excluded from the lintrunner config.
# --------------------------------------------------------------------------

# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed.

class Stage1And2_DeepSpeedZeroOptimizer_0_9_2:

def has_overflow_serial(self, params, is_grad_list=False):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True

return False


def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)

# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

total_norm = total_norm_cuda[0].item()**(1. / norm_type)

if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1

return total_norm


def has_overflow_partitioned_grads_serial(self):
for i in range(len(self.bit16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None and self._has_inf_or_nan(grad.data, j):
return True
return False
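
Why the snapshot functions above are wrapped in a class (per the header comment): `inspect.getsource` preserves a method's indentation, so a module-level copy of the same `def` would never string-compare equal to the installed class method. A tiny illustration (hypothetical names, not part of the PR):

```python
import inspect


class Snapshot:  # same trick as Stage1And2_DeepSpeedZeroOptimizer_0_9_2 above
    def has_overflow_serial(self, params):
        return False


src = inspect.getsource(Snapshot.has_overflow_serial)
print(repr(src.splitlines()[0]))  # '    def has_overflow_serial(self, params):' -- note the 4-space indent
```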
85 changes: 79 additions & 6 deletions orttraining/orttraining/python/training/optim/_ds_modifier.py
@@ -10,39 +10,112 @@
# - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799
# --------------------------------------------------------------------------

from __future__ import annotations

import inspect
import types
import warnings

import torch
from numpy import inf
from packaging.version import Version

from ._ds_code_store import Stage1And2_DeepSpeedZeroOptimizer_0_9_2
from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads
from ._multi_tensor_apply import MultiTensorApply

multi_tensor_applier = MultiTensorApply(2048 * 32)


def _get_normalized_str(function) -> str:
return inspect.getsource(function)


def _dynamic_checks(cur_ds_version: Version, optimizer) -> bool:
_functions_to_override = ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"]

_version_to_source_code_map = {"0.9.2": Stage1And2_DeepSpeedZeroOptimizer_0_9_2}

    # Try to find the largest stored version that is smaller than or equal to cur_ds_version,
    # then compare the source code (the found version may be the latest supported one).
    # If the current code does not match the found version, raise a warning asking the user to
    # add the new version to the list, and return False.
versions = [Version(v) for v in _version_to_source_code_map]
sorted_versions = sorted(versions, reverse=True)
version_to_compare = None
for sv in sorted_versions:
if cur_ds_version >= sv:
version_to_compare = sv
break

if version_to_compare is None:
warnings.warn(
"Unable to find a DeepSpeed version that is smaller than or equal to the current version "
f"{cur_ds_version}. Skip modifying optimizer.",
UserWarning,
)
return False

v_optimizer_cls = _version_to_source_code_map[str(version_to_compare)]
all_match = True
for func_name in _functions_to_override:
        if not hasattr(optimizer, func_name):
            warnings.warn(
                f"DeepSpeed function {func_name} is not found in optimizer. Skip modifying optimizer.", UserWarning
            )
            all_match = False
            continue
cur_code_str = _get_normalized_str(getattr(optimizer, func_name))
v_code_str = _get_normalized_str(getattr(v_optimizer_cls, func_name))
if cur_code_str != v_code_str:
warnings.warn(
f"DeepSpeed function {func_name} has changed after version {version_to_compare}. "
f"Please append new version {cur_ds_version} in _version_to_source_code_map and _ds_code_store.py.\n"
f"---[{func_name}] Old Source Code Start----\n"
f"{v_code_str}\n"
f"---{func_name} Old Source Code End----\n"
f"---[{func_name}] New Source Code Start----\n"
f"{cur_code_str}\n"
f"---{func_name} New Source Code End----",
UserWarning,
)
all_match = False

return all_match


class DeepSpeedZeROModifier(FP16OptimizerModifier):
def __init__(self, optimizer, **kwargs) -> None:
super().__init__(optimizer)

def can_be_modified(self):
import deepspeed

# Note 1:
# This modifier relies on the implementation of has_overflow_serial, get_grad_norm_direct,
# and has_overflow_partitioned_grads_serial
        # in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py.
        # The minimum supported version is 0.4.0; all versions in [0.4.0, 0.9.1]
        # have been manually checked to make sure the implementation of these functions is "logically" unchanged.
        # The way we did the check is to review the history of that file: if there is no change between releases,
        # it's safe to extend the supported-version list. Otherwise, or if the file was moved or renamed,
        # we need to check the implementation of these functions in detail.
#
# Note 2:
        # Since version 0.9.2, we have added a dynamic source code check that compares the installed version of
        # the code with the snapshot in our code store. If the source code has changed, we raise a warning asking
        # the user to add the new version to the code store. Otherwise, we override the functions.

ds_version = Version(deepspeed.__version__)
if ds_version > Version("0.9.1") or ds_version < Version("0.4.0"):
if ds_version < Version("0.4.0"):
warnings.warn(
f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}, "
"minimum supported version: 0.4.0, current version",
UserWarning,
)
return False
if ds_version > Version("0.9.1") and not _dynamic_checks(ds_version, self._optimizer):
warnings.warn(
"Skip modifying optimizer because of unsupported DeepSpeed version {}, "
"supported version: 0.4.0 - 0.9.1.".format(deepspeed.__version__),
f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}.",
UserWarning,
)
return False
58 changes: 52 additions & 6 deletions orttraining/orttraining/python/training/optim/_modifier_registry.py
@@ -3,13 +3,59 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from __future__ import annotations

import warnings
from typing import ClassVar

from ._apex_amp_modifier import ApexAMPModifier
from ._ds_modifier import DeepSpeedZeROModifier
from ._megatron_modifier import LegacyMegatronLMModifier
from ._modifier import FP16OptimizerModifier


class _AccelerateDeepSpeedZeROModifier(DeepSpeedZeROModifier):
"""
    Modifier for the DeepSpeed optimizer wrapper used by accelerate:
    https://github.com/huggingface/accelerate/blob/7843286f2e1c50735d259fbc0084a7f1c85e00e3/src/accelerate/utils/deepspeed.py#L182C19-L182C19
"""

def __init__(self, accelerator_optimizer, **kwargs) -> None:
super().__init__(accelerator_optimizer.optimizer)


def get_full_qualified_type_name(o):
klass = o.__class__
module = klass.__module__
if module == "builtins":
return klass.__qualname__
return module + "." + klass.__qualname__


class OptimizerModifierTypeRegistry:
    _MAP: ClassVar[dict[str, type[FP16OptimizerModifier]]] = {
"megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier,
"deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
"deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
"apex.amp.optimizer.unique_name_as_id": ApexAMPModifier,
}

@staticmethod
def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None:
"""Create modifier for optimizer."""
if optimizer_full_qualified_name in OptimizerModifierTypeRegistry._MAP:
return OptimizerModifierTypeRegistry._MAP[optimizer_full_qualified_name](optimizer, **kwargs)

if optimizer_full_qualified_name == "accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper":
if (
hasattr(optimizer, "optimizer")
and get_full_qualified_type_name(optimizer.optimizer) in OptimizerModifierTypeRegistry._MAP
):
return _AccelerateDeepSpeedZeROModifier(optimizer, **kwargs)

warnings.warn(
"Skip modifying optimizer because of optimizer name not found in the registry: "
f"{optimizer_full_qualified_name}",
UserWarning,
)
return None
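
A quick sanity check of the name-resolution helper added above; the import path assumes the installed `onnxruntime-training` package mirrors this source layout, and the expected outputs are inferred from the function body rather than taken from the PR:

```python
from collections import OrderedDict

from onnxruntime.training.optim._modifier_registry import get_full_qualified_type_name

print(get_full_qualified_type_name(OrderedDict()))  # collections.OrderedDict
print(get_full_qualified_type_name(42))             # int  (builtins are returned unqualified)
```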
28 changes: 9 additions & 19 deletions orttraining/orttraining/python/training/optim/fp16_optimizer.py
@@ -3,9 +3,8 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name


def FP16_Optimizer(optimizer, **kwargs): # noqa: N802
@@ -80,22 +79,13 @@ def FP16_Optimizer(optimizer, **kwargs): # noqa: N802
"""

optimizer_full_qualified_name = (
"apex.amp.optimizer.unique_name_as_id"
if hasattr(optimizer, "_amp_stash")
else get_full_qualified_type_name(optimizer)
)
modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs)
if modifier is not None:
modifier.apply()

return optimizer
