Skip to content

Commit

Permalink
Add special logic for 'step' in _optimizer_to_device (#20019)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
corwinjoy and awaelchli authored Aug 5, 2024
1 parent 345450b commit 631911c
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))


-
- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))



Expand Down
14 changes: 11 additions & 3 deletions src/lightning/fabric/utilities/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import MutableMapping
from typing import Iterable

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.optim import Optimizer

from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.apply_func import apply_to_collection, move_data_to_device
from lightning.fabric.utilities.types import _DEVICE


Expand All @@ -31,4 +31,12 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
"""Moves the state of a single optimizer to the device."""
for p, v in optimizer.state.items():
optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
if not isinstance(v, MutableMapping):
# Support for custom optimizers
optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
continue
for key, val in v.items():
# The 'step' parameter needs to remain unmoved (possibly on the CPU) since that is where the optimizer
# needs it. See https://github.com/pytorch/pytorch/issues/74424
if key != "step":
v[key] = move_data_to_device(val, device)
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))

- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))

- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))


Expand Down
98 changes: 74 additions & 24 deletions tests/tests_fabric/utilities/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,86 @@
import collections
import dataclasses

import pytest
import torch
from lightning.fabric.utilities.optimizer import _optimizer_to_device
from torch import Tensor

from tests_fabric.helpers.runif import RunIf

def test_optimizer_to_device():
@dataclasses.dataclass(frozen=True)

@pytest.mark.parametrize(
"optimizer_class",
[
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.SGD,
torch.optim.RMSprop,
torch.optim.Adagrad,
torch.optim.Adadelta,
torch.optim.Adamax,
],
)
@pytest.mark.parametrize(
"src_device",
[
torch.device("cpu"),
pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
],
)
@pytest.mark.parametrize(
"dst_device",
[
torch.device("cpu"),
pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
],
)
def test_optimizer_to_device(optimizer_class, src_device, dst_device):
# Optimizer with no state initialized
model = torch.nn.Linear(2, 2, device=src_device)
optimizer = optimizer_class(model.parameters(), lr=0.1)
_optimizer_to_device(optimizer, dst_device)
_assert_opt_parameters_on_device(optimizer, dst_device)

# Optimizer with state initialized
model = torch.nn.Linear(2, 2, device=src_device)
optimizer = optimizer_class(model.parameters(), lr=0.1)
model(torch.randn(2, 2, device=src_device)).sum().backward()
optimizer.step()
_optimizer_to_device(optimizer, dst_device)
_assert_opt_parameters_on_device(optimizer, dst_device)


def _assert_opt_parameters_on_device(opt, device):
for _, v in opt.state.items():
for key, item in v.items():
if not isinstance(item, Tensor):
continue
if key == "step":
# The "step" tensor needs to remain on CPU
assert item.device.type == "cpu"
else:
assert item.device.type == device.type


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("frozen", [True, False])
def test_optimizer_to_device_with_dataclass_in_state(frozen):
src_device = torch.device("cpu")
dst_device = torch.device("cuda")
model = torch.nn.Linear(32, 2, device=src_device)

@dataclasses.dataclass(frozen=frozen)
class FooState:
bar: int
integer: int
tensor: Tensor

class TestOptimizer(torch.optim.SGD):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.state["dummy"] = torch.tensor(0)
self.state["frozen"] = FooState(0)

layer = torch.nn.Linear(32, 2)
opt = TestOptimizer(layer.parameters(), lr=0.1)
_optimizer_to_device(opt, "cpu")
if torch.cuda.is_available():
_optimizer_to_device(opt, "cuda")
assert_opt_parameters_on_device(opt, "cuda")


def assert_opt_parameters_on_device(opt, device: str):
for param in opt.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, Tensor):
assert param.data.device.type == device
elif isinstance(param, collections.abc.Mapping):
for subparam in param.values():
if isinstance(subparam, Tensor):
assert param.data.device.type == device
self.state[model.weight] = {"dummy": torch.tensor(0)}
self.state[model.bias] = FooState(0, torch.tensor(0))

optimizer = TestOptimizer(model.parameters(), lr=0.1)
_optimizer_to_device(optimizer, dst_device)
assert optimizer.state[model.weight]["dummy"].device.type == dst_device.type
assert optimizer.state[model.bias].tensor.device.type == ("cpu" if frozen else dst_device.type)

0 comments on commit 631911c

Please sign in to comment.