From 913e7d27441b015aefd51772c12d278053f289a2 Mon Sep 17 00:00:00 2001 From: xcnick Date: Fri, 18 Nov 2022 04:00:46 +0000 Subject: [PATCH 01/15] add ApexOptimWrapper --- mmengine/optim/__init__.py | 13 ++- mmengine/optim/optimizer/__init__.py | 4 +- .../optim/optimizer/apex_optimizer_wrapper.py | 48 ++++++++ .../test_optimizer/test_optimizer_wrapper.py | 104 +++++++++++++++++- 4 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 mmengine/optim/optimizer/apex_optimizer_wrapper.py diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py index 72118b179f..38426ea847 100644 --- a/mmengine/optim/__init__.py +++ b/mmengine/optim/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, - AmpOptimWrapper, DefaultOptimWrapperConstructor, - OptimWrapper, OptimWrapperDict, build_optim_wrapper) + AmpOptimWrapper, ApexOptimWrapper, + DefaultOptimWrapperConstructor, OptimWrapper, + OptimWrapperDict, build_optim_wrapper) # yapf: disable from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, CosineAnnealingLR, CosineAnnealingMomentum, @@ -25,8 +26,8 @@ 'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', 'ExponentialParamScheduler', 'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler', - '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict', - 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum', - 'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum', - 'ReduceOnPlateauParamScheduler' + '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'ApexOptimWrapper', + 'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', + 'PolyMomentum', 'PolyParamScheduler', 'ReduceOnPlateauLR', + 'ReduceOnPlateauMomentum', 'ReduceOnPlateauParamScheduler' ] diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py index fdc77679ba..0116951438 100644 --- a/mmengine/optim/optimizer/__init__.py +++ b/mmengine/optim/optimizer/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .amp_optimizer_wrapper import AmpOptimWrapper +from .apex_optimizer_wrapper import ApexOptimWrapper from .builder import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, build_optim_wrapper) from .default_constructor import DefaultOptimWrapperConstructor @@ -10,5 +11,6 @@ __all__ = [ 'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimWrapperConstructor', 'build_optim_wrapper', 'OptimWrapper', - 'AmpOptimWrapper', 'OptimWrapperDict', 'ZeroRedundancyOptimizer' + 'AmpOptimWrapper', 'ApexOptimWrapper', 'OptimWrapperDict', + 'ZeroRedundancyOptimizer' ] diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py new file mode 100644 index 0000000000..f2c630df41 --- /dev/null +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmengine.registry import OPTIM_WRAPPERS +from .optimizer_wrapper import OptimWrapper + +try: + import apex.amp as apex_amp +except ImportError: + pass + + +@OPTIM_WRAPPERS.register_module() +class ApexOptimWrapper(OptimWrapper): + """A subclass of :class:`OptimWrapper` that supports automatic mixed + precision training based on apex.amp. 
+ + ``ApexOptimWrapper`` provides a unified interface with + ``OptimWrapper``, so ``ApexOptimWrapper`` can be used in the same way + as ``OptimWrapper``. + + Warnings: + ``ApexOptimWrapper`` requires + [nvidia apex](https://github.com/NVIDIA/apex). + + Args: + + **kwargs: Keyword arguments passed to OptimWrapper. + + Note: + If you use ``IterBasedRunner`` and enable gradient accumulation, + the original `max_iters` should be multiplied by + ``accumulative_counts``. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def backward(self, loss: torch.Tensor, **kwargs): + """Perform gradient back propagation with :attr:`loss_scaler`. + + Args: + loss (torch.Tensor): The loss of current iteration. + kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` + """ + with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + self._inner_count += 1 diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 35984ce37f..8d415a15c4 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -14,12 +14,19 @@ from mmengine.dist import all_gather from mmengine.logging import MessageHub, MMLogger -from mmengine.optim import AmpOptimWrapper, OptimWrapper +from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION +is_apex_available = False +try: + import apex.amp as apex_amp + is_apex_available = True +except ImportError: + pass + class ToyModel(nn.Module): @@ -283,6 +290,101 @@ def mock_methd(loss): optim_wrapper.zero_grad = MagicMock() +class TestApexOptimWrapper(TestCase): + + def setUp(self) -> None: + self.model = ToyModel().cuda() + self.optimizer = SGD(self.model.parameters(), lr=0.1) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_init(self): + self.model, self.optimizer = apex_amp.initialize( + self.model, self.optimizer, opt_level='O1', loss_scale=1) + # Test with default arguments. 
+ _ = ApexOptimWrapper(optimizer=self.optimizer) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_step(self): + optimizer = MagicMock(spec=Optimizer) + self.model, optimizer = apex_amp.initialize( + self.model, optimizer, opt_level='O1') + + apex_optim_wrapper = ApexOptimWrapper(optimizer=optimizer) + apex_optim_wrapper.optimizer.param_groups = MagicMock() + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) + apex_optim_wrapper.step() + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_backward(self): + self.model, self.optimizer = apex_amp.initialize( + self.model, self.optimizer, opt_level='O1', loss_scale=1) + + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_state_dict(self): + self.model, self.optimizer = apex_amp.initialize( + self.model, self.optimizer, opt_level='O1') + + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.update_params(loss) + state_dict = apex_optim_wrapper.state_dict() + optim_state_dict = state_dict + + self.assertDictEqual(optim_state_dict, + apex_optim_wrapper.optimizer.state_dict()) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_load_state_dict(self): + self.model, self.optimizer = apex_amp.initialize( + self.model, self.optimizer, opt_level='O1') + + apex_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + # Test load from optimizer + optimizer = SGD(self.model.parameters(), lr=0.1) + apex_optim_wrapper.load_state_dict(optimizer.state_dict()) + + self.assertDictEqual(optimizer.state_dict(), + apex_optim_wrapper.optimizer.state_dict()) + # Test load from optim_wrapper + apex_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper_ = AmpOptimWrapper( + optimizer=SGD(self.model.parameters(), lr=0.1)) + apex_optim_wrapper_.load_state_dict(apex_optim_wrapper.state_dict()) + self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), + apex_optim_wrapper_.optimizer.state_dict()) + + @unittest.skipIf( + not is_apex_available, + reason='`apex` is not available, Please install apex from ' + 'https://www.github.com/nvidia/apex') + def test_optim_context(self): + amp_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + with amp_optim_wrapper.optim_context(self.model): + x = torch.randn(1, 1, 1, 1).cuda() + y = nn.Conv2d(1, 1, 1).cuda()(x) + self.assertEqual(y.dtype, torch.float16) + + class TestAmpOptimWrapper(TestCase): def setUp(self) -> None: From 01c7b411889d730b38e3bbf55399d734a306c63d Mon Sep 17 00:00:00 2001 From: xcnick Date: Fri, 18 Nov 2022 06:34:08 +0000 Subject: [PATCH 02/15] typo fix --- tests/test_optim/test_optimizer/test_optimizer_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 8d415a15c4..7e5c0e9d0a 100644 --- 
a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -358,7 +358,7 @@ def test_load_state_dict(self): self.model, self.optimizer = apex_amp.initialize( self.model, self.optimizer, opt_level='O1') - apex_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) # Test load from optimizer optimizer = SGD(self.model.parameters(), lr=0.1) apex_optim_wrapper.load_state_dict(optimizer.state_dict()) @@ -366,8 +366,8 @@ def test_load_state_dict(self): self.assertDictEqual(optimizer.state_dict(), apex_optim_wrapper.optimizer.state_dict()) # Test load from optim_wrapper - apex_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) - apex_optim_wrapper_ = AmpOptimWrapper( + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper_ = ApexOptimWrapper( optimizer=SGD(self.model.parameters(), lr=0.1)) apex_optim_wrapper_.load_state_dict(apex_optim_wrapper.state_dict()) self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), From cdf612e43c767be01bf4d4e65bb0e2eb5aaa2b42 Mon Sep 17 00:00:00 2001 From: xcnick Date: Wed, 23 Nov 2022 02:03:45 +0000 Subject: [PATCH 03/15] add apex amp.initialize in optim_context --- .../optim/optimizer/apex_optimizer_wrapper.py | 60 +++++++++++- .../test_optimizer/test_optimizer_wrapper.py | 96 +++++++++---------- 2 files changed, 107 insertions(+), 49 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index f2c630df41..273ee34f79 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from contextlib import contextmanager + import torch +import torch.nn as nn from mmengine.registry import OPTIM_WRAPPERS from .optimizer_wrapper import OptimWrapper @@ -33,8 +36,10 @@ class ApexOptimWrapper(OptimWrapper): ``accumulative_counts``. """ - def __init__(self, **kwargs): + def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): super().__init__(**kwargs) + self.opt_level = opt_level + self.loss_scale = loss_scale def backward(self, loss: torch.Tensor, **kwargs): """Perform gradient back propagation with :attr:`loss_scaler`. @@ -46,3 +51,56 @@ def backward(self, loss: torch.Tensor, **kwargs): with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() self._inner_count += 1 + + def state_dict(self) -> dict: + """Get the state dictionary of :attr:`optimizer` and + :attr:`apex_amp`. + + Based on the state dictionary of the optimizer, the returned state + dictionary will add a key named "apex_amp". + + Returns: + dict: The merged state dict of :attr:`apex_amp` and + :attr:`optimizer`. + """ + state_dict = self.optimizer.state_dict() + state_dict['apex_amp'] = apex_amp.state_dict() + return state_dict + + def load_state_dict(self, state_dict: dict): + """Load and parse the state dictionary of :attr:`optimizer` and + :attr:`apex_amp`. + + If state_dict contains "apex_amp", the :attr:`apex_amp` will + load the corresponding keys. Otherwise, only the :attr:`optimizer` + will load the state dictionary. 
+ + Args: + state_dict (dict): The state dict of :attr:`optimizer` and + :attr:`apex_amp` + """ + if 'apex_amp' in state_dict: + apex_amp.load_state_dict(state_dict.pop('apex_amp')) + self.optimizer.load_state_dict(state_dict) + + @contextmanager + def optim_context(self, model: nn.Module): + """Enables the context for mixed precision training, and enables the + context for disabling gradient synchronization during gradient + accumulation context. + + Args: + model (nn.Module): The training model. + """ + if hasattr(self.optimizer, '_amp_stash'): + yield + else: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + with super().optim_context(model): + model, self.optimizer = apex_amp.initialize( + model, + self.optimizer, + opt_level=self.opt_level, + loss_scale=self.loss_scale) + yield diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 7e5c0e9d0a..d4c00fe7c8 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -301,10 +301,10 @@ def setUp(self) -> None: reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_init(self): - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, opt_level='O1', loss_scale=1) - # Test with default arguments. - _ = ApexOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + pass @unittest.skipIf( not is_apex_available, @@ -312,74 +312,74 @@ def test_init(self): 'https://www.github.com/nvidia/apex') def test_step(self): optimizer = MagicMock(spec=Optimizer) - self.model, optimizer = apex_amp.initialize( - self.model, optimizer, opt_level='O1') - - apex_optim_wrapper = ApexOptimWrapper(optimizer=optimizer) - apex_optim_wrapper.optimizer.param_groups = MagicMock() - loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) - apex_optim_wrapper.backward(loss) - apex_optim_wrapper.step() + apex_optim_wrapper = ApexOptimWrapper( + optimizer=optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + apex_optim_wrapper.optimizer.param_groups = MagicMock() + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) + apex_optim_wrapper.step() @unittest.skipIf( not is_apex_available, reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_backward(self): - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, opt_level='O1', loss_scale=1) - - apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) - apex_optim_wrapper.backward(loss) + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.backward(loss) @unittest.skipIf( not is_apex_available, reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_state_dict(self): - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, opt_level='O1') - - apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) - 
apex_optim_wrapper.update_params(loss) - state_dict = apex_optim_wrapper.state_dict() - optim_state_dict = state_dict - - self.assertDictEqual(optim_state_dict, - apex_optim_wrapper.optimizer.state_dict()) + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) + apex_optim_wrapper.update_params(loss) + state_dict = apex_optim_wrapper.state_dict() + amp_state_dict = state_dict.pop('apex_amp') + optim_state_dict = state_dict + + self.assertDictEqual(optim_state_dict, + apex_optim_wrapper.optimizer.state_dict()) + self.assertDictEqual(amp_state_dict, apex_amp.state_dict()) @unittest.skipIf( not is_apex_available, reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_load_state_dict(self): - self.model, self.optimizer = apex_amp.initialize( - self.model, self.optimizer, opt_level='O1') - - apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - # Test load from optimizer - optimizer = SGD(self.model.parameters(), lr=0.1) - apex_optim_wrapper.load_state_dict(optimizer.state_dict()) - - self.assertDictEqual(optimizer.state_dict(), - apex_optim_wrapper.optimizer.state_dict()) - # Test load from optim_wrapper - apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - apex_optim_wrapper_ = ApexOptimWrapper( - optimizer=SGD(self.model.parameters(), lr=0.1)) - apex_optim_wrapper_.load_state_dict(apex_optim_wrapper.state_dict()) - self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), - apex_optim_wrapper_.optimizer.state_dict()) + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): + # Test load from optimizer + optimizer = SGD(self.model.parameters(), lr=0.1) + apex_optim_wrapper.load_state_dict(optimizer.state_dict()) + + self.assertDictEqual(optimizer.state_dict(), + apex_optim_wrapper.optimizer.state_dict()) + # Test load from optim_wrapper + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) + apex_optim_wrapper_ = ApexOptimWrapper( + optimizer=SGD(self.model.parameters(), lr=0.1)) + apex_optim_wrapper_.load_state_dict( + apex_optim_wrapper.state_dict()) + self.assertDictEqual(apex_optim_wrapper.optimizer.state_dict(), + apex_optim_wrapper_.optimizer.state_dict()) @unittest.skipIf( not is_apex_available, reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_optim_context(self): - amp_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer) - with amp_optim_wrapper.optim_context(self.model): + apex_optim_wrapper = ApexOptimWrapper( + optimizer=self.optimizer, opt_level='O1', loss_scale=1) + with apex_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) self.assertEqual(y.dtype, torch.float16) From b31329007a61c70e145e9b7638ed2040411dc26d Mon Sep 17 00:00:00 2001 From: xcnick Date: Wed, 23 Nov 2022 08:49:37 +0000 Subject: [PATCH 04/15] assert apex_amp --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 273ee34f79..23a3e5c734 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -10,7 +10,7 @@ try: import 
apex.amp as apex_amp except ImportError: - pass + apex_amp = None @OPTIM_WRAPPERS.register_module() @@ -37,6 +37,9 @@ class ApexOptimWrapper(OptimWrapper): """ def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): + assert apex_amp is not None, \ + 'Apex is not installed. Please check ' \ + 'https://github.com/NVIDIA/apex#linux.' super().__init__(**kwargs) self.opt_level = opt_level self.loss_scale = loss_scale From 6f7a25f5dc77e9e8bd820db58b3aa347d251f1e7 Mon Sep 17 00:00:00 2001 From: xcnick Date: Mon, 9 Jan 2023 09:34:03 +0000 Subject: [PATCH 05/15] polish code --- .../optim/optimizer/apex_optimizer_wrapper.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 23a3e5c734..f17ffe195c 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager +from typing import Union import torch import torch.nn as nn @@ -23,11 +24,17 @@ class ApexOptimWrapper(OptimWrapper): as ``OptimWrapper``. Warnings: - ``ApexOptimWrapper`` requires - [nvidia apex](https://github.com/NVIDIA/apex). + ``ApexOptimWrapper`` requires `nvidia apex + `_ Args: + opt_level (str, default="O1"): Pure or mixed precision + optimization level. Accepted values are "O0", "O1", "O2", + and "O3". + loss_scale (float or str, default=None): If passed as + a string, must be a string representing a number, + e.g., "128.0", or the string "dynamic". **kwargs: Keyword arguments passed to OptimWrapper. Note: @@ -36,7 +43,10 @@ class ApexOptimWrapper(OptimWrapper): ``accumulative_counts``. """ - def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): + def __init__(self, + opt_level: str = 'O1', + loss_scale: Union[float, str] = 'dynamic', + **kwargs): assert apex_amp is not None, \ 'Apex is not installed. Please check ' \ 'https://github.com/NVIDIA/apex#linux.' @@ -44,7 +54,7 @@ def __init__(self, opt_level='O1', loss_scale='dynamic', **kwargs): self.opt_level = opt_level self.loss_scale = loss_scale - def backward(self, loss: torch.Tensor, **kwargs): + def backward(self, loss: torch.Tensor, **kwargs) -> None: """Perform gradient back propagation with :attr:`loss_scaler`. Args: @@ -70,7 +80,7 @@ def state_dict(self) -> dict: state_dict['apex_amp'] = apex_amp.state_dict() return state_dict - def load_state_dict(self, state_dict: dict): + def load_state_dict(self, state_dict: dict) -> None: """Load and parse the state dictionary of :attr:`optimizer` and :attr:`apex_amp`. @@ -95,15 +105,16 @@ def optim_context(self, model: nn.Module): Args: model (nn.Module): The training model. 
""" - if hasattr(self.optimizer, '_amp_stash'): - yield - else: + with super().optim_context(model): + # when a given optimizer be passed through apex_amp.initialize, + # the "_amp_stash" property will be added + if hasattr(self.optimizer, '_amp_stash'): + yield if isinstance(model, torch.nn.parallel.DistributedDataParallel): model = model.module - with super().optim_context(model): - model, self.optimizer = apex_amp.initialize( - model, - self.optimizer, - opt_level=self.opt_level, - loss_scale=self.loss_scale) - yield + model, self.optimizer = apex_amp.initialize( + model, + self.optimizer, + opt_level=self.opt_level, + loss_scale=self.loss_scale) + yield From bdf660f866a9643a0f5f40b6cbe77aa8314fff35 Mon Sep 17 00:00:00 2001 From: xcnick Date: Sat, 28 Jan 2023 07:23:17 +0000 Subject: [PATCH 06/15] add parameters of apex_amp.initialize --- .../optim/optimizer/apex_optimizer_wrapper.py | 65 ++++++++++++++++++- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index f17ffe195c..88d4ecaccd 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager -from typing import Union +from typing import Optional, Union import torch import torch.nn as nn @@ -35,6 +35,35 @@ class ApexOptimWrapper(OptimWrapper): loss_scale (float or str, default=None): If passed as a string, must be a string representing a number, e.g., "128.0", or the string "dynamic". + enabled (bool, default=True): If False, renders all Amp calls no-ops, + so your script should run as if Amp were not present. + cast_model_type (torch.dtype, default=None): Model's parameters and + buffers to the desired type. + patch_torch_functions (bool, default=None): Patch all Torch functions + and Tensor methods to perform Tensor Core-friendly ops like GEMMs + and convolutions in FP16, + and any ops that benefit from FP32 precision in FP32. + keep_batchnorm_fp32 (bool or str, default=None): To enhance precision + and enable cudnn batchnorm (which improves performance), + it's often beneficial to keep batchnorm weights in FP32 + even if the rest of the model is FP16. + If passed as a string, must be the string "True" or "False". + master_weights (bool, default=None): Maintain FP32 master weights to + accompany any FP16 model weights. FP32 master weights are stepped + by the optimizer to enhance precision and capture small gradients. + cast_model_outputs (torch.dtype, default=None): Option to ensure that + the outputs of your model(s) are always cast to a particular type + regardless of ``opt_level``. + num_losses (int, default=1): Option to tell Amp in advance how many + losses/backward passes you plan to use. + verbosity (int, default=1): Set to 0 to suppress Amp-related output. + min_loss_scale (float, default=None): Sets a floor for the loss scale + values that can be chosen by dynamic loss scaling. + The default value of None means that no floor is imposed. + If dynamic loss scaling is not used, `min_loss_scale` is ignored. + max_loss_scale (float, default=2.**24): Sets a ceiling for the + loss scale values that can be chosen by dynamic loss scaling. + If dynamic loss scaling is not used, `max_loss_scale` is ignored. **kwargs: Keyword arguments passed to OptimWrapper. 
Note: @@ -46,6 +75,16 @@ class ApexOptimWrapper(OptimWrapper): def __init__(self, opt_level: str = 'O1', loss_scale: Union[float, str] = 'dynamic', + enabled: Optional[bool] = True, + cast_model_type: Optional[torch.dtype] = None, + patch_torch_functions: Optional[bool] = None, + keep_batchnorm_fp32: Optional[Union[bool, str]] = None, + master_weights: Optional[bool] = None, + cast_model_outputs: Optional[torch.dtype] = None, + num_losses: Optional[int] = 1, + verbosity: Optional[int] = 1, + min_loss_scale: Optional[float] = None, + max_loss_scale: Optional[float] = 2.**24, **kwargs): assert apex_amp is not None, \ 'Apex is not installed. Please check ' \ @@ -53,6 +92,16 @@ def __init__(self, super().__init__(**kwargs) self.opt_level = opt_level self.loss_scale = loss_scale + self.enabled = enabled + self.cast_model_type = cast_model_type + self.patch_torch_functions = patch_torch_functions + self.keep_batchnorm_fp32 = keep_batchnorm_fp32 + self.master_weights = master_weights + self.cast_model_outputs = cast_model_outputs + self.num_losses = num_losses + self.verbosity = verbosity + self.min_loss_scale = min_loss_scale + self.max_loss_scale = max_loss_scale def backward(self, loss: torch.Tensor, **kwargs) -> None: """Perform gradient back propagation with :attr:`loss_scaler`. @@ -62,7 +111,7 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None: kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward` """ with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() + scaled_loss.backward(**kwargs) self._inner_count += 1 def state_dict(self) -> dict: @@ -116,5 +165,15 @@ def optim_context(self, model: nn.Module): model, self.optimizer, opt_level=self.opt_level, - loss_scale=self.loss_scale) + loss_scale=self.loss_scale, + enabled=self.enabled, + cast_model_type=self.cast_model_type, + patch_torch_functions=self.patch_torch_functions, + keep_batchnorm_fp32=self.keep_batchnorm_fp32, + master_weights=self.master_weights, + cast_model_outputs=self.cast_model_outputs, + num_losses=self.num_losses, + verbosity=self.verbosity, + min_loss_scale=self.min_loss_scale, + max_loss_scale=self.max_loss_scale) yield From 2b8ea950543f898d6cbef08022c68f9a7d61fab4 Mon Sep 17 00:00:00 2001 From: xcnick Date: Sun, 29 Jan 2023 01:57:06 +0000 Subject: [PATCH 07/15] add docs --- docs/en/api/optim.rst | 1 + docs/zh_cn/api/optim.rst | 1 + mmengine/optim/optimizer/apex_optimizer_wrapper.py | 3 +-- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/en/api/optim.rst b/docs/en/api/optim.rst index 3de2ade6a2..142d0f089f 100644 --- a/docs/en/api/optim.rst +++ b/docs/en/api/optim.rst @@ -20,6 +20,7 @@ Optimizer :template: classtemplate.rst AmpOptimWrapper + ApexOptimWrapper OptimWrapper OptimWrapperDict DefaultOptimWrapperConstructor diff --git a/docs/zh_cn/api/optim.rst b/docs/zh_cn/api/optim.rst index 3de2ade6a2..142d0f089f 100644 --- a/docs/zh_cn/api/optim.rst +++ b/docs/zh_cn/api/optim.rst @@ -20,6 +20,7 @@ Optimizer :template: classtemplate.rst AmpOptimWrapper + ApexOptimWrapper OptimWrapper OptimWrapperDict DefaultOptimWrapperConstructor diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 88d4ecaccd..8f8750d834 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -23,12 +23,11 @@ class ApexOptimWrapper(OptimWrapper): ``OptimWrapper``, so ``ApexOptimWrapper`` can be used in the same way as ``OptimWrapper``. 
- Warnings: + Warning: ``ApexOptimWrapper`` requires `nvidia apex `_ Args: - opt_level (str, default="O1"): Pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". From fe6247f1f6a99ad5961dca6c5b7911c4c978df6f Mon Sep 17 00:00:00 2001 From: xcnick Date: Thu, 2 Feb 2023 15:00:24 +0000 Subject: [PATCH 08/15] polish code --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 6 +++--- tests/test_optim/test_optimizer/test_optimizer_wrapper.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 8f8750d834..10c8e44a45 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -24,8 +24,7 @@ class ApexOptimWrapper(OptimWrapper): as ``OptimWrapper``. Warning: - ``ApexOptimWrapper`` requires `nvidia apex - `_ + ``ApexOptimWrapper`` requires `nvidia apex `_ Args: opt_level (str, default="O1"): Pure or mixed precision @@ -69,7 +68,7 @@ class ApexOptimWrapper(OptimWrapper): If you use ``IterBasedRunner`` and enable gradient accumulation, the original `max_iters` should be multiplied by ``accumulative_counts``. - """ + """ # noqa: E501 def __init__(self, opt_level: str = 'O1', @@ -158,6 +157,7 @@ def optim_context(self, model: nn.Module): # the "_amp_stash" property will be added if hasattr(self.optimizer, '_amp_stash'): yield + return if isinstance(model, torch.nn.parallel.DistributedDataParallel): model = model.module model, self.optimizer = apex_amp.initialize( diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index d4c00fe7c8..41da09664a 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -315,7 +315,6 @@ def test_step(self): apex_optim_wrapper = ApexOptimWrapper( optimizer=optimizer, opt_level='O1', loss_scale=1) with apex_optim_wrapper.optim_context(self.model): - apex_optim_wrapper.optimizer.param_groups = MagicMock() loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) apex_optim_wrapper.step() From f87ed69b383c7f2903524edc30b0ef7b509e011a Mon Sep 17 00:00:00 2001 From: xcnick Date: Thu, 2 Feb 2023 16:32:42 +0000 Subject: [PATCH 09/15] polish code --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 5 ++++- tests/test_optim/test_optimizer/test_optimizer_wrapper.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 10c8e44a45..a3caac5d91 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn +# a circular import will be caused by +# from mmengine.model.wrappers import is_model_wrapper +import mmengine from mmengine.registry import OPTIM_WRAPPERS from .optimizer_wrapper import OptimWrapper @@ -158,7 +161,7 @@ def optim_context(self, model: nn.Module): if hasattr(self.optimizer, '_amp_stash'): yield return - if isinstance(model, torch.nn.parallel.DistributedDataParallel): + if mmengine.model.wrappers.is_model_wrapper(model): model = model.module model, self.optimizer = apex_amp.initialize( model, diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 
41da09664a..d00033f210 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -290,6 +290,7 @@ def mock_methd(loss): optim_wrapper.zero_grad = MagicMock() +@unittest.skipIf(not torch.cuda.is_available(), reason='need gpu to test Apex') class TestApexOptimWrapper(TestCase): def setUp(self) -> None: From 3a76471cf7e64f3a6d956e91fbc55f5d623f09df Mon Sep 17 00:00:00 2001 From: xcnick Date: Thu, 2 Feb 2023 16:39:55 +0000 Subject: [PATCH 10/15] polish code --- .../optim/optimizer/apex_optimizer_wrapper.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index a3caac5d91..a0c955c567 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -158,24 +158,22 @@ def optim_context(self, model: nn.Module): with super().optim_context(model): # when a given optimizer be passed through apex_amp.initialize, # the "_amp_stash" property will be added - if hasattr(self.optimizer, '_amp_stash'): - yield - return - if mmengine.model.wrappers.is_model_wrapper(model): - model = model.module - model, self.optimizer = apex_amp.initialize( - model, - self.optimizer, - opt_level=self.opt_level, - loss_scale=self.loss_scale, - enabled=self.enabled, - cast_model_type=self.cast_model_type, - patch_torch_functions=self.patch_torch_functions, - keep_batchnorm_fp32=self.keep_batchnorm_fp32, - master_weights=self.master_weights, - cast_model_outputs=self.cast_model_outputs, - num_losses=self.num_losses, - verbosity=self.verbosity, - min_loss_scale=self.min_loss_scale, - max_loss_scale=self.max_loss_scale) + if not hasattr(self.optimizer, '_amp_stash'): + if mmengine.model.wrappers.is_model_wrapper(model): + model = model.module + model, self.optimizer = apex_amp.initialize( + model, + self.optimizer, + opt_level=self.opt_level, + loss_scale=self.loss_scale, + enabled=self.enabled, + cast_model_type=self.cast_model_type, + patch_torch_functions=self.patch_torch_functions, + keep_batchnorm_fp32=self.keep_batchnorm_fp32, + master_weights=self.master_weights, + cast_model_outputs=self.cast_model_outputs, + num_losses=self.num_losses, + verbosity=self.verbosity, + min_loss_scale=self.min_loss_scale, + max_loss_scale=self.max_loss_scale) yield From 5deeef18235855dab4455bbdcc09a46c9858518b Mon Sep 17 00:00:00 2001 From: xcnick Date: Fri, 3 Feb 2023 09:09:14 +0000 Subject: [PATCH 11/15] fix calling of apex amp load_state_dict --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index a0c955c567..074baa267a 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -103,6 +103,7 @@ def __init__(self, self.verbosity = verbosity self.min_loss_scale = min_loss_scale self.max_loss_scale = max_loss_scale + self._apex_amp_state_dict = None def backward(self, loss: torch.Tensor, **kwargs) -> None: """Perform gradient back propagation with :attr:`loss_scaler`. @@ -138,12 +139,18 @@ def load_state_dict(self, state_dict: dict) -> None: load the corresponding keys. Otherwise, only the :attr:`optimizer` will load the state dictionary. 
+ Note: + :meth:`load_state_dict` shuold be called after + `apex_amp.initialize` is called. Args: state_dict (dict): The state dict of :attr:`optimizer` and :attr:`apex_amp` """ if 'apex_amp' in state_dict: - apex_amp.load_state_dict(state_dict.pop('apex_amp')) + if hasattr(self.optimizer, '_amp_stash'): + apex_amp.load_state_dict(state_dict.pop('apex_amp')) + else: + self._apex_amp_state_dict = state_dict.pop('apex_amp') self.optimizer.load_state_dict(state_dict) @contextmanager @@ -176,4 +183,8 @@ def optim_context(self, model: nn.Module): verbosity=self.verbosity, min_loss_scale=self.min_loss_scale, max_loss_scale=self.max_loss_scale) + # loading apex_amp state_dict after initialization of apex_amp + if self._apex_amp_state_dict: + apex_amp.load_state_dict(self._apex_amp_state_dict) + self._apex_amp_state_dict = None yield From 9689f630a9f33417cbbe206603567dde52df5d27 Mon Sep 17 00:00:00 2001 From: xcnick Date: Fri, 3 Feb 2023 14:14:20 +0000 Subject: [PATCH 12/15] polish --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 074baa267a..ca15f6a7b0 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -184,7 +184,7 @@ def optim_context(self, model: nn.Module): min_loss_scale=self.min_loss_scale, max_loss_scale=self.max_loss_scale) # loading apex_amp state_dict after initialization of apex_amp - if self._apex_amp_state_dict: + if self._apex_amp_state_dict is not None: apex_amp.load_state_dict(self._apex_amp_state_dict) self._apex_amp_state_dict = None yield From 1b5882c59f79c5438f01a13de0109fa578f94d57 Mon Sep 17 00:00:00 2001 From: xcnick Date: Sun, 5 Feb 2023 08:11:36 +0000 Subject: [PATCH 13/15] add comments --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index ca15f6a7b0..6aee968f47 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -147,6 +147,10 @@ def load_state_dict(self, state_dict: dict) -> None: :attr:`apex_amp` """ if 'apex_amp' in state_dict: + # when `apex_amp` is not initialized, calling `load_state_dict` + # will raise an error, so we temporarily cache the apex_amp + # part, and then load it into `apex_amp` after completing + # the `apex_amp` initialization in `optim_context` method if hasattr(self.optimizer, '_amp_stash'): apex_amp.load_state_dict(state_dict.pop('apex_amp')) else: From 7a825d04b60bdbdeb839e0589f9dcfae86312534 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 6 Feb 2023 15:11:34 +0800 Subject: [PATCH 14/15] Update apex_optimizer_wrapper.py --- .../optim/optimizer/apex_optimizer_wrapper.py | 65 ++++++++++--------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 6aee968f47..96bb1c7d03 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -23,67 +23,72 @@ class ApexOptimWrapper(OptimWrapper): precision training based on apex.amp. 
``ApexOptimWrapper`` provides a unified interface with - ``OptimWrapper``, so ``ApexOptimWrapper`` can be used in the same way - as ``OptimWrapper``. + ``OptimWrapper``, so it can be used in the same way as ``OptimWrapper``. Warning: ``ApexOptimWrapper`` requires `nvidia apex `_ Args: - opt_level (str, default="O1"): Pure or mixed precision - optimization level. Accepted values are "O0", "O1", "O2", - and "O3". - loss_scale (float or str, default=None): If passed as - a string, must be a string representing a number, - e.g., "128.0", or the string "dynamic". - enabled (bool, default=True): If False, renders all Amp calls no-ops, - so your script should run as if Amp were not present. - cast_model_type (torch.dtype, default=None): Model's parameters and - buffers to the desired type. - patch_torch_functions (bool, default=None): Patch all Torch functions + opt_level (str): Pure or mixed precision optimization level. Accepted + values are "O0", "O1", "O2", and "O3". Defaults to "O1". + loss_scale (float or str, optional): If passed as a string, must be a + string representing a number, e.g., "128.0", or the string + "dynamic". Defaults to "dynamic". + enabled (bool): If False, renders all Amp calls no-ops, so your script + should run as if Amp were not present. Defaults to True. + cast_model_type (torch.dtype, optional): Model's parameters and + buffers to the desired type. Defaults to None. + patch_torch_functions (bool, optional): Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs - and convolutions in FP16, - and any ops that benefit from FP32 precision in FP32. - keep_batchnorm_fp32 (bool or str, default=None): To enhance precision + and convolutions in FP16, and any ops that benefit from FP32 + precision in FP32. Defaults to None. + keep_batchnorm_fp32 (bool or str, optional): To enhance precision and enable cudnn batchnorm (which improves performance), it's often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. If passed as a string, must be the string "True" or "False". - master_weights (bool, default=None): Maintain FP32 master weights to + Defaults to None. + master_weights (bool, optional): Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. - cast_model_outputs (torch.dtype, default=None): Option to ensure that + Defaults to None. + cast_model_outputs (torch.dtype, optional): Option to ensure that the outputs of your model(s) are always cast to a particular type - regardless of ``opt_level``. - num_losses (int, default=1): Option to tell Amp in advance how many - losses/backward passes you plan to use. - verbosity (int, default=1): Set to 0 to suppress Amp-related output. - min_loss_scale (float, default=None): Sets a floor for the loss scale + regardless of ``opt_level``. Defaults to None. + num_losses (int): Option to tell Amp in advance how many + losses/backward passes you plan to use. Defaults to 1. + verbosity (int): Set to 0 to suppress Amp-related output. + Defaults to 1. + min_loss_scale (float, optional): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, `min_loss_scale` is ignored. - max_loss_scale (float, default=2.**24): Sets a ceiling for the - loss scale values that can be chosen by dynamic loss scaling. 
- If dynamic loss scaling is not used, `max_loss_scale` is ignored. + Defaults to None. + max_loss_scale (float, optional): Sets a ceiling for the loss scale + values that can be chosen by dynamic loss scaling. If dynamic + loss scaling is not used, `max_loss_scale` is ignored. + Defaults to 2.**24. **kwargs: Keyword arguments passed to OptimWrapper. Note: If you use ``IterBasedRunner`` and enable gradient accumulation, the original `max_iters` should be multiplied by ``accumulative_counts``. + + `New in version 0.6.0.` """ # noqa: E501 def __init__(self, opt_level: str = 'O1', - loss_scale: Union[float, str] = 'dynamic', + loss_scale: Union[float, str, None] = 'dynamic', enabled: Optional[bool] = True, cast_model_type: Optional[torch.dtype] = None, patch_torch_functions: Optional[bool] = None, - keep_batchnorm_fp32: Optional[Union[bool, str]] = None, + keep_batchnorm_fp32: Union[bool, str, None] = None, master_weights: Optional[bool] = None, cast_model_outputs: Optional[torch.dtype] = None, - num_losses: Optional[int] = 1, - verbosity: Optional[int] = 1, + num_losses: int = 1, + verbosity: int = 1, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = 2.**24, **kwargs): From 955b5d6b6cd2bd807ea70a717536b406eaf6e17c Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 6 Feb 2023 15:22:39 +0800 Subject: [PATCH 15/15] Update apex_optimizer_wrapper.py --- mmengine/optim/optimizer/apex_optimizer_wrapper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 96bb1c7d03..5f2f6f4a1b 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -74,8 +74,9 @@ class ApexOptimWrapper(OptimWrapper): If you use ``IterBasedRunner`` and enable gradient accumulation, the original `max_iters` should be multiplied by ``accumulative_counts``. - - `New in version 0.6.0.` + + Note: + `New in version 0.6.0.` """ # noqa: E501 def __init__(self,
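
A minimal usage sketch of the wrapper added by this series, mirroring the test
cases above. It assumes apex is installed and a CUDA device is available;
``ToyModel`` here is an illustrative stand-in for any ``nn.Module``, not the
test helper.

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import ApexOptimWrapper


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        # return a scalar so the output can serve directly as the loss
        return self.conv(x).sum()


model = ToyModel().cuda()
optim_wrapper = ApexOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.1),
    opt_level='O1',
    loss_scale='dynamic')

for _ in range(3):
    # the first entry into optim_context runs apex_amp.initialize();
    # later entries are no-ops thanks to the `_amp_stash` check
    with optim_wrapper.optim_context(model):
        loss = model(torch.randn(4, 1, 8, 8).cuda())
    # backward() scales the loss through apex_amp.scale_loss before
    # stepping the optimizer and zeroing gradients
    optim_wrapper.update_params(loss)

When used through a ``Runner``, the same wrapper would typically be selected
from a config via the registry, e.g.
``optim_wrapper=dict(type='ApexOptimWrapper', optimizer=dict(type='SGD', lr=0.1), opt_level='O1', loss_scale='dynamic')``.
Deferring ``apex_amp.initialize`` to ``optim_context`` is what allows the
wrapper to be constructed from such a config before the model has been moved
to GPU or wrapped in ``DistributedDataParallel``.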