Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle set_to_none when using DeepSpeed optimizer in Lite #16275

Merged
merged 15 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))


- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))


### Changed

- Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932), [#15938](https://github.com/Lightning-AI/lightning/issues/15938))
Expand Down
16 changes: 15 additions & 1 deletion src/lightning_fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union

import torch
Expand Down Expand Up @@ -44,7 +45,9 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
"""
# `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
# not want to call on destruction of the `_FabricOptimizer
self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
self.__dict__ = {
k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
}
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
self._optimizer = optimizer
self._strategy = strategy
Expand All @@ -68,6 +71,10 @@ def step(self, closure: Optional[Callable] = None) -> Any:
**kwargs,
)

def zero_grad(self, **kwargs: Any) -> None:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
self.optimizer.zero_grad(**kwargs)


class _FabricModule(_DeviceDtypeModuleMixin):
def __init__(
Expand Down Expand Up @@ -175,3 +182,10 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:

for item in iterator:
yield move_data_to_device(item, self._device)


def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
# Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
return kwargs
32 changes: 32 additions & 0 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import call, Mock

import pytest
Expand Down Expand Up @@ -291,3 +292,34 @@ def test_lite_optimizer_steps():
lite_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
lite_optimizer.step()
strategy.optimizer_step.assert_called_once_with(strategy.model)


def test_fabric_optimizer_zero_grad_kwargs():
"""Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""

# Test PyTorch's standard `.zero_grad()` signature
with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
fabric_optimizer.zero_grad()
zero_grad_mock.assert_called_with()
fabric_optimizer.zero_grad(set_to_none=False)
zero_grad_mock.assert_called_with(set_to_none=False)
fabric_optimizer.zero_grad(set_to_none=True)
zero_grad_mock.assert_called_with(set_to_none=True)

# Test weird `.zero_grad()` signatures from other libraries
custom_zero_grad = Mock()

class CustomSGD(torch.optim.SGD):
def zero_grad(self, set_grads_to_None=False):
custom_zero_grad(set_grads_to_None=set_grads_to_None)

optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
fabric_optimizer.zero_grad()
custom_zero_grad.assert_called_with(set_grads_to_None=False)
fabric_optimizer.zero_grad(set_to_none=False)
custom_zero_grad.assert_called_with(set_grads_to_None=False)
fabric_optimizer.zero_grad(set_to_none=True)
custom_zero_grad.assert_called_with(set_grads_to_None=True)