From 1ff7d6fc33a2da62668acb791be2ee25ded33c48 Mon Sep 17 00:00:00 2001 From: Bipin Krishnan Date: Fri, 11 Nov 2022 00:12:55 +0530 Subject: [PATCH 1/3] add axes argument to lr finder plot --- src/pytorch_lightning/tuner/lr_finder.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 63d7c09abb26e..4a5afd69ff79d 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -40,7 +40,7 @@ _MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib") if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING: import matplotlib.pyplot as plt - + from matplotlib.axes import Axes log = logging.getLogger(__name__) @@ -130,12 +130,14 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) - def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]: + def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]: """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point show: if True, will show figure + + ax: axes object to which the plot is to be drawn """ if not _MATPLOTLIB_AVAILABLE: raise MisconfigurationException( @@ -147,7 +149,10 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur lrs = self.results["lr"] losses = self.results["loss"] - fig, ax = plt.subplots() + if ax is None: + fig, ax = plt.subplots() + else: + fig = ax.figure # Plot loss as a function of the learning rate ax.plot(lrs, losses) From 0f1e4f27299f8ebb594ef61e60a6af18651aaa8a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 12 Nov 2022 18:08:39 +0100 Subject: [PATCH 2/3] add changelog --- src/pytorch_lightning/CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ec0043ae2ceca..d6fc566985964 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -19,8 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) -- - +- Add an axes argument `ax` to the `.lr_find().plot()` to enable writing to a user-defined axes in a matplotlib figure ([#15652](https://github.com/Lightning-AI/lightning/pull/15652)) - Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301)) From f39f8534587082db26ab622da3b56dec0393d0c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 12 Nov 2022 18:23:09 +0100 Subject: [PATCH 3/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/tuner/lr_finder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 4a5afd69ff79d..29a5d47776a9e 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -38,7 +38,7 @@ from tqdm import tqdm _MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib") -if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING: +if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt from matplotlib.axes import Axes log = logging.getLogger(__name__) @@ -137,7 +137,7 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = show: if True, will show figure - ax: axes object to which the plot is to be drawn + ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created. """ if not _MATPLOTLIB_AVAILABLE: raise MisconfigurationException(