-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
lr_finder.py
468 lines (363 loc) · 17.1 KB
/
lr_finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import os
import uuid
from copy import deepcopy
from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
# NOTE(review): only `import importlib` is visible at the top of the file; `importlib.util` is normally already
# loaded by the interpreter, but an explicit `import importlib.util` would not rely on that — confirm.
if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

# Availability flag for matplotlib; `_LRFinder.plot` raises a MisconfigurationException when it is falsy.
_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
# Imported only for static type checking (the `plt.Figure` return annotation); `plot()` imports pyplot lazily.
if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING:
    import matplotlib.pyplot as plt

# Module-level logger used for LR finder diagnostics.
log = logging.getLogger(__name__)
def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str:
    """Pick the model attribute that should receive the suggested learning rate.

    If ``trainer.auto_lr_find`` is a string, that exact field name is required. Otherwise the first of
    ``lr`` / ``learning_rate`` found on the model (or its ``hparams``) is used.

    Raises:
        MisconfigurationException: if the requested field, or both default fields, cannot be found.
    """
    requested = trainer.auto_lr_find
    if isinstance(requested, str):
        if lightning_hasattr(model, requested):
            return requested
        raise MisconfigurationException(
            f"`auto_lr_find` was set to {requested}, however"
            " could not find this as a field in `model` or `model.hparams`."
        )

    candidates = ("lr", "learning_rate")
    found = next((name for name in candidates if lightning_hasattr(model, name)), None)
    if found is not None:
        return found

    raise MisconfigurationException(
        "When `auto_lr_find=True`, either `model` or `model.hparams` should"
        f" have one of these fields: {candidates} overridden."
    )
class _LRFinder:
    """LR finder object. This object stores the results of lr_find().

    Args:
        mode: either `linear` or `exponential`, how to increase lr after each step
        lr_min: lr to start search from
        lr_max: lr to stop search
        num_training: number of steps to take between lr_min and lr_max

    Example::

        # Run lr finder
        lr_finder = trainer.lr_find(model)

        # Results stored in
        lr_finder.results

        # Plot using
        lr_finder.plot()

        # Get suggestion
        lr = lr_finder.suggestion()
    """

    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
        assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"

        self.mode = mode
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.num_training = num_training

        # Filled by `lr_find` with the recorded "lr" and "loss" lists.
        self.results: Dict[str, Any] = {}
        # Position in `results` of the last suggestion; None until `suggestion()` succeeds.
        self._optimal_idx: Optional[int] = None
        self._total_batch_idx = 0  # for debug purpose

    def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
        """Install the LR-search optimizer/scheduler setup on ``trainer.strategy``.

        Every param group of the single optimizer is reset to ``lr_min``, and a step-interval
        ``_LinearLR``/``_ExponentialLR`` scheduler ramping towards ``lr_max`` over ``num_training`` steps is set
        as the only scheduler config.

        Raises:
            MisconfigurationException: if the strategy does not hold exactly one optimizer.
        """
        from pytorch_lightning.core.optimizer import _set_scheduler_opt_idx

        optimizers = trainer.strategy.optimizers

        if len(optimizers) != 1:
            raise MisconfigurationException(
                f"`model.configure_optimizers()` returned {len(optimizers)}, but"
                " learning rate finder only works with single optimizer"
            )

        optimizer = optimizers[0]

        # Also set `initial_lr` so the scheduler's base learning rate is `lr_min`, not the user's lr.
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr_min
            param_group["initial_lr"] = self.lr_min

        args = (optimizer, self.lr_max, self.num_training)
        scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
        scheduler = cast(pl.utilities.types._LRScheduler, scheduler)

        trainer.strategy.optimizers = [optimizer]
        trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
        _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)

    def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
        """Plot results from lr_find run

        Args:
            suggest: if True, will mark suggested lr to use with a red point
            show: if True, will show figure

        Returns:
            The matplotlib figure containing the loss-vs-lr curve.

        Raises:
            MisconfigurationException: if matplotlib is not installed.
        """
        if not _MATPLOTLIB_AVAILABLE:
            raise MisconfigurationException(
                "To use the `plot` method, you must have Matplotlib installed."
                " Install it by running `pip install -U matplotlib`."
            )
        import matplotlib.pyplot as plt

        lrs = self.results["lr"]
        losses = self.results["loss"]

        fig, ax = plt.subplots()

        # Plot loss as a function of the learning rate
        ax.plot(lrs, losses)
        if self.mode == "exponential":
            ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")

        if suggest:
            _ = self.suggestion()
            # Compare against None explicitly: index 0 is a valid suggestion position.
            if self._optimal_idx is not None:
                ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red")

        if show:
            plt.show()

        return fig

    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
        """This will propose a suggestion for an initial learning rate based on the point with the steepest
        negative gradient.

        Args:
            skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
            skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates
                (``0`` keeps every trailing sample)

        Returns:
            The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
            loss samples.
        """
        all_losses = np.asarray(self.results["loss"], dtype=float)
        # Work on positions into the full results so that dropping non-finite losses below cannot shift the
        # reported index. `len - skip_end` (clamped at 0) instead of `-skip_end` keeps `skip_end=0` meaningful.
        stop = max(len(all_losses) - skip_end, 0)
        candidate_idx = np.arange(len(all_losses))[skip_begin:stop]
        candidate_idx = candidate_idx[np.isfinite(all_losses[candidate_idx])]
        losses = all_losses[candidate_idx]

        if len(losses) < 2:
            # computing np.gradient requires at least 2 points
            log.error(
                "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
                " iteration limits or the size of your dataset/dataloader."
            )
            self._optimal_idx = None
            return None

        min_grad = int(np.gradient(losses).argmin())
        self._optimal_idx = int(candidate_idx[min_grad])
        return self.results["lr"][self._optimal_idx]
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`

    Runs a short fit in which the learning rate is ramped from ``min_lr`` to ``max_lr`` over ``num_training``
    steps while the callback records the smoothed loss. Afterwards, the overridden trainer settings and the
    checkpointed model state are restored. ``early_stop_threshold=None`` disables the divergence stop.

    Returns:
        The ``_LRFinder`` holding the recorded ``lr``/``loss`` lists, or ``None`` if ``fast_dev_run`` is enabled.
    """
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
        return None

    # Determine lr attr up front, so a misconfiguration fails before any trainer state is modified
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    # Save initial model, that is loaded after learning rate is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)

    # Arguments we adjust during the lr finder, save for restoring
    params = __lr_finder_dump_params(trainer)

    # Set to values that are required by the algorithm
    __lr_finder_reset_params(trainer, num_training, early_stop_threshold)

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Configure optimizer and scheduler
    lr_finder._exchange_scheduler(trainer)

    # Fit, lr & loss logged in callback
    _try_loop_run(trainer, params)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

    # Transfer results from callback to lr finder object
    # (`__lr_finder_reset_params` installed the `_LRCallback` as the only callback)
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

    __lr_finder_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        if lr is not None:
            lightning_setattr(model, lr_attr_name, lr)
            log.info(f"Learning rate set to {lr}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    # Restoring the checkpoint may reload the fit-loop state, which leaves the loop flagged as restarting. The
    # other `load_state_dict` call sites in this module reset the flag for the same reason. Clear it so a
    # subsequent `fit` is not treated as a resumed run.
    trainer.fit_loop.restarting = False

    return lr_finder
def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
    """Snapshot the trainer settings that the LR finder overrides, so they can be put back afterwards.

    Returns:
        A dict keyed by setting name. ``loop_state_dict`` is a deep copy of the fit loop's state; every other
        value is stored as-is (not copied).
    """
    snapshot: Dict[str, Any] = {
        name: getattr(trainer.strategy, name)
        for name in ("optimizers", "lr_scheduler_configs", "optimizer_frequencies")
    }
    snapshot["callbacks"] = trainer.callbacks
    snapshot["loggers"] = trainer.loggers
    # TODO: check if this is required
    snapshot["auto_lr_find"] = trainer.auto_lr_find
    fit_loop = trainer.fit_loop
    snapshot["max_steps"] = fit_loop.max_steps
    snapshot["limit_val_batches"] = trainer.limit_val_batches
    snapshot["loop_state_dict"] = deepcopy(fit_loop.state_dict())
    return snapshot
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
    """Put the trainer into the configuration the LR search needs.

    Clears scheduler configs and optimizer frequencies, disables ``auto_lr_find``, swaps in the LR-finder
    callback, replaces an existing logger with a ``DummyLogger``, and caps the run at ``num_training`` steps.
    """
    from pytorch_lightning.loggers.logger import DummyLogger

    strategy = trainer.strategy
    strategy.lr_scheduler_configs = []
    strategy.optimizer_frequencies = []

    # Keep the search from triggering another LR search.
    trainer.auto_lr_find = False

    # The finder's own callback records lr/loss and replaces every user callback for the search.
    lr_callback = _LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)
    trainer.callbacks = [lr_callback]

    # Silence logging during the search: a real logger is swapped for a no-op one.
    if trainer.logger is None:
        trainer.logger = None
    else:
        trainer.logger = DummyLogger()

    # Stop after `num_training` steps. NOTE(review): validation batches are capped to the same number,
    # presumably just to bound validation cost during the search.
    trainer.fit_loop.max_steps = num_training
    trainer.limit_val_batches = num_training
def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
    """Write back the trainer settings captured before the LR search.

    Args:
        trainer: the trainer to restore.
        params: the snapshot dict produced by ``__lr_finder_dump_params``.
    """
    strategy = trainer.strategy
    for name in ("optimizers", "lr_scheduler_configs", "optimizer_frequencies"):
        setattr(strategy, name, params[name])

    for name in ("auto_lr_find", "callbacks", "loggers"):
        setattr(trainer, name, params[name])

    fit_loop = trainer.fit_loop
    fit_loop.max_steps = params["max_steps"]
    trainer.limit_val_batches = params["limit_val_batches"]

    # Rewind loop progress from a fresh copy, so the snapshot itself is never mutated by the loop.
    fit_loop.load_state_dict(deepcopy(params["loop_state_dict"]))
    fit_loop.restarting = False
class _LRCallback(Callback):
    """Special callback used by the learning rate finder. This callback logs the learning rate before each batch
    and logs the corresponding loss after each batch.

    Args:
        num_training: number of iterations done by the learning rate finder
        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold*best_loss``
            then the search is stopped. To disable, set to ``None``.
        progress_bar_refresh_rate: rate to refresh the progress bar for
            the learning rate finder
        beta: smoothing value, the loss being logged is a running average of
            loss values logged until now. ``beta`` controls the forget rate i.e.
            if ``beta=0`` all past information is ignored.
    """

    def __init__(
        self,
        num_training: int,
        early_stop_threshold: Optional[float] = 4.0,
        progress_bar_refresh_rate: int = 0,
        beta: float = 0.98,
    ) -> None:
        self.num_training = num_training
        self.early_stop_threshold = early_stop_threshold
        self.beta = beta
        self.losses: List[float] = []
        self.lrs: List[float] = []
        self.avg_loss = 0.0
        self.best_loss = 0.0
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None
        # Number of losses folded into `avg_loss` so far; drives the bias correction.
        self._num_logged = 0

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Called before each training batch, logs the lr that will be used."""
        # Only record on the last batch of each gradient-accumulation window.
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore[union-attr]

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the training batch ends, logs the calculated loss."""
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar:
            self.progress_bar.update()

        loss_tensor = trainer.fit_loop.running_loss.last()
        assert loss_tensor is not None
        current_loss = loss_tensor.item()
        current_step = trainer.global_step

        # Avg loss (loss with momentum) + bias correction. The correction must use the number of samples averaged
        # so far (1 on the first logged batch). `global_step` is already 1 here on the first batch (see the
        # best-loss initialisation below), so `beta ** (global_step + 1)` over-corrected by one step and biased
        # the early smoothed losses low.
        self._num_logged += 1
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
        smoothed_loss = self.avg_loss / (1 - self.beta**self._num_logged)

        # Check if we are diverging
        if self.early_stop_threshold is not None:
            if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
                trainer.fit_loop.max_steps = current_step  # stop signal
                if self.progress_bar:
                    self.progress_bar.close()

        # Save best loss for diverging checking
        if smoothed_loss < self.best_loss or current_step == 1:
            self.best_loss = smoothed_loss

        self.losses.append(smoothed_loss)
class _LinearLR(_LRScheduler):
    """Linearly increases the learning rate between two boundaries over a number of iterations.

    While ``last_epoch <= 0`` the base learning rates are returned unchanged. After that, each base lr moves
    towards ``end_lr`` by the fraction ``(last_epoch + 1) / num_iter``.

    Args:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before the base constructor runs: torch's scheduler __init__ performs an initial step,
        # which calls get_lr.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:  # type: ignore[override]
        fraction = (self.last_epoch + 1) / self.num_iter
        if self.last_epoch <= 0:
            schedule = list(self.base_lrs)
        else:
            schedule = [start + fraction * (self.end_lr - start) for start in self.base_lrs]
        # Cached so the LR finder callback can read the value via `lr`.
        self._lr = schedule
        return schedule

    @property
    def lr(self) -> Union[float, List[float]]:
        return self._lr
class _ExponentialLR(_LRScheduler):
    """Exponentially increases the learning rate between two boundaries over a number of iterations.

    While ``last_epoch <= 0`` the base learning rates are returned unchanged. After that, each base lr is
    scaled by ``(end_lr / base_lr) ** ((last_epoch + 1) / num_iter)``.

    Arguments:
        optimizer: wrapped optimizer.
        end_lr: the final learning rate.
        num_iter: the number of iterations over which the test occurs.
        last_epoch: the index of last epoch. Default: -1.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        # Set before the base constructor runs: torch's scheduler __init__ performs an initial step,
        # which calls get_lr.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:  # type: ignore[override]
        exponent = (self.last_epoch + 1) / self.num_iter
        if self.last_epoch <= 0:
            schedule = list(self.base_lrs)
        else:
            schedule = [start * (self.end_lr / start) ** exponent for start in self.base_lrs]
        # Cached so the LR finder callback can read the value via `lr`.
        self._lr = schedule
        return schedule

    @property
    def lr(self) -> Union[float, List[float]]:
        return self._lr
def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
    """Rewind the fit loop to the pre-search snapshot in ``params`` and run it for the LR search."""
    fit_loop = trainer.fit_loop
    # Load from a copy so the snapshot stays intact for the later restore.
    pristine_state = deepcopy(params["loop_state_dict"])
    fit_loop.load_state_dict(pristine_state)
    # Make sure the search run is not treated as a resumed run.
    fit_loop.restarting = False
    fit_loop.run()