wrappers.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union

import torch
from torch import nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


def _do_nothing_closure() -> None:
    return None

class _LiteOptimizer:
    def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
        """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
        step calls to the accelerator/strategy plugin.

        The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.

        Args:
            optimizer: The optimizer to wrap
            accelerator: Reference to the accelerator for handling the optimizer step
        """
        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
        # not want to call on destruction of the `_LiteOptimizer`
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
        self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
        self._optimizer = optimizer
        self._accelerator = accelerator

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    def step(self, closure: Optional[Callable] = None) -> None:
        closure = closure or _do_nothing_closure
        self._accelerator.optimizer_step(
            self.optimizer,
            opt_idx=0,
            closure=closure,
            model=self._accelerator.model,
        )
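
# A minimal usage sketch (an assumption for illustration only: `model` and `accelerator` stand in for the
# objects Lite normally sets up, and SGD is just an example optimizer):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     lite_optimizer = _LiteOptimizer(optimizer, accelerator=accelerator)
#     lite_optimizer.step()  # delegated to accelerator.optimizer_step(...) rather than optimizer.step()
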
class _LiteModule(nn.Module):
    def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
        """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
        automatically for the forward pass.

        The underlying wrapped module can be accessed via the property :attr:`module`.

        Args:
            module: The module to wrap
            precision_plugin: Reference to the precision plugin for handling precision context
        """
        super().__init__()
        self._module = module
        self._precision_plugin = precision_plugin

    @property
    def module(self) -> nn.Module:
        return self._module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Casts all inputs to the right precision and handles autocast for operations in the module forward
        method."""
        precision = self._precision_plugin.precision
        precision_to_type = {
            "bf16": torch.bfloat16,
            16: torch.float16,
            32: torch.float32,
            64: torch.float64,
        }
        # TODO (@awaelchli): let the precision plugin handle the conversion
        to_type = precision_to_type[precision]
        args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor)

        with self._precision_plugin.forward_context():
            output = self.module(*args, **kwargs)

        output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor)
        return output
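
# A rough illustration of the cast-in / cast-out behaviour above (hypothetical names: `my_module` is any
# nn.Module and `precision_plugin` a plugin whose `precision` attribute is 16):
#
#     lite_module = _LiteModule(my_module, precision_plugin)
#     x = torch.rand(4, 3)      # float32 input
#     out = lite_module(x)      # forward runs inside the plugin's autocast context; inputs are cast to
#                               # torch.float16, and the output is cast back to the default dtype (float32)
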
def _wrap_init(f: Callable) -> Callable:
    @functools.wraps(f)
    def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
        params = dict(inspect.signature(module._old_init).parameters)
        params.pop("args")
        params.pop("kwargs")
        for init_name, init_arg in chain(zip(params, args), kwargs.items()):
            setattr(module, init_name, init_arg)
        f(module, *args, **kwargs)

    return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
    subclasses = set()

    def recurse(cl: Type[Any]) -> None:
        for subclass in cl.__subclasses__():
            subclasses.add(subclass)
            recurse(subclass)

    recurse(cls)
    return subclasses


def _enable_class(cls: Type[Any]) -> None:
    cls._old_init = cls.__init__
    cls.__init__ = _wrap_init(cls.__init__)


def _disable_class(cls: Type[Any]) -> None:
    cls.__init__ = cls._old_init
    del cls._old_init


@contextmanager
def _replace_dataloader_init_method() -> Generator:
    """This context manager is used to support custom :class:`~torch.utils.data.DataLoader`."""
    for subclass in _get_all_subclasses(DataLoader):
        _enable_class(subclass)
    yield
    for subclass in _get_all_subclasses(DataLoader):
        _disable_class(subclass)
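
# A rough sketch of the effect (hypothetical `MyLoader` subclass and `my_dataset`): while the context manager
# is active, constructor arguments are recorded as attributes on the instance, so the loader can later be
# re-created with the same arguments:
#
#     class MyLoader(DataLoader):
#         def __init__(self, dataset, window_size=1, *args, **kwargs):
#             super().__init__(dataset, *args, **kwargs)
#
#     with _replace_dataloader_init_method():
#         loader = MyLoader(my_dataset, window_size=5, batch_size=4)
#
#     loader.window_size  # -> 5, recorded by the patched __init__ even though MyLoader never stored it
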
class _LiteDataLoader:
    def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None:
        """The LiteDataLoader is an extension of an Iterator. It moves the data to the given device automatically
        if a device is specified.

        Args:
            dataloader: The current dataloader to be used.
            device: The device to which the data should be moved. By default the device is `None` and no data
                transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
        """
        super().__init__()
        self.__dict__.update(getattr(dataloader, "__dict__", {}))
        self._dataloader = dataloader
        self._device = device

    def __len__(self) -> Union[int, float]:
        if isinstance(self._dataloader, Sized):
            return len(self._dataloader)
        return float("inf")

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        iterator = iter(self._dataloader)
        if self._device is None:
            yield from iterator
            return

        for item in iterator:
            yield move_data_to_device(item, self._device)
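
# A minimal usage sketch (hypothetical `my_dataset`; the CUDA device is just an example target):
#
#     loader = DataLoader(my_dataset, batch_size=4)
#     lite_loader = _LiteDataLoader(loader, device=torch.device("cuda", 0))
#     for batch in lite_loader:
#         ...  # each batch has already been moved to cuda:0 via move_data_to_device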