diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index 9ef08113ff01d..ddc6f6121fe7d 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The LoadBalancer now uses internal ip + port instead of URL exposed ([#16119](https://github.com/Lightning-AI/lightning/pull/16119))
 
+- Added support for logging in different trainer stages with `DeviceStatsMonitor`
+([#16002](https://github.com/Lightning-AI/lightning/pull/16002))
+
 ### Deprecated
 
diff --git a/src/pytorch_lightning/callbacks/device_stats_monitor.py b/src/pytorch_lightning/callbacks/device_stats_monitor.py
index 0bc014290f271..db4057dabab20 100644
--- a/src/pytorch_lightning/callbacks/device_stats_monitor.py
+++ b/src/pytorch_lightning/callbacks/device_stats_monitor.py
@@ -30,8 +30,9 @@
 
 class DeviceStatsMonitor(Callback):
     r"""
-    Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor``
-    is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``.
+    Automatically monitors and logs device stats during the training, validation and testing stages.
+    ``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to be passed as an
+    argument to the ``Trainer``.
 
     Args:
         cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
@@ -109,6 +110,38 @@ def on_train_batch_end(
     ) -> None:
         self._get_and_log_device_stats(trainer, "on_train_batch_end")
 
+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        self._get_and_log_device_stats(trainer, "on_validation_batch_start")
+
+    def on_validation_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Optional[STEP_OUTPUT],
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        self._get_and_log_device_stats(trainer, "on_validation_batch_end")
+
+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        self._get_and_log_device_stats(trainer, "on_test_batch_start")
+
+    def on_test_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: Optional[STEP_OUTPUT],
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        self._get_and_log_device_stats(trainer, "on_test_batch_end")
+
 
 def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
     return {prefix + separator + k: v for k, v in metrics_dict.items()}
diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
index 826fa0f088f28..3466b98636c4c 100644
--- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import csv
 import os
+import re
 from typing import Dict, Optional
 from unittest import mock
 from unittest.mock import Mock
@@ -166,3 +168,51 @@ def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch, tmp
     # TODO: raise an exception from v1.9
     with pytest.warns(UserWarning, match="psutil` is not installed"):
         monitor.setup(trainer, Mock(), "fit")
+
+
+def test_device_stats_monitor_logs_for_different_stages(tmpdir):
+    """Test that metrics are logged for the training, validation and testing stages."""
+    model = BoringModel()
+    device_stats = DeviceStatsMonitor()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        limit_test_batches=1,
+        log_every_n_steps=1,
+        accelerator="cpu",
+        devices=1,
+        callbacks=[device_stats],
+        logger=CSVLogger(tmpdir),
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+
+    # the training and validation stages run during `fit`
+    trainer.fit(model)
+
+    with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
+        content = csv.reader(csvfile, delimiter=",")
+        # the CSV header row lists the logged metric keys
+        header = next(content)
+
+        # search for training stage logs
+        train_stage_results = [re.match(r".+on_train_batch", name) for name in header]
+        assert any(train_stage_results), "training stage logs not found"
+
+        # search for validation stage logs
+        validation_stage_results = [re.match(r".+on_validation_batch", name) for name in header]
+        assert any(validation_stage_results), "validation stage logs not found"
+
+    # the testing stage runs during `test`
+    trainer.test(model)
+
+    with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
+        content = csv.reader(csvfile, delimiter=",")
+        header = next(content)
+
+        # search for testing stage logs
+        test_stage_results = [re.match(r".+on_test_batch", name) for name in header]
+        assert any(test_stage_results), "testing stage logs not found"
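
For context, a minimal usage sketch of the extended callback (not part of this patch). It assumes pytorch_lightning with this change applied; the `BoringModel` demo class, the "logs" directory, and the `limit_*_batches` values are illustrative choices, not anything mandated by the patch:

# Minimal usage sketch, not part of this patch. Assumes pytorch_lightning
# with this change applied; `BoringModel` and the "logs" dir are illustrative.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import CSVLogger

model = BoringModel()
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    accelerator="cpu",
    devices=1,
    callbacks=[DeviceStatsMonitor(cpu_stats=True)],  # cpu_stats requires `psutil`
    logger=CSVLogger("logs"),
)
trainer.fit(model)   # with this patch, device stats are logged for train *and* validation batches
trainer.test(model)  # ...and for test batches

Per `_prefix_metric_keys` in the diff, each logged key is prefixed with the hook name that produced it, which is why the test above can find the validation and test stage metrics by matching `on_validation_batch` and `on_test_batch` against the CSV header.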