added support for logging in different trainer stages #16002

Merged
merged 13 commits on Jan 9, 2023

Changes from 12 commits
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The LoadBalancer now uses internal ip + port instead of URL exposed ([#16119](https://github.com/Lightning-AI/lightning/pull/16119))

- Added support for logging in different trainer stages with `DeviceStatsMonitor` ([#16002](https://github.com/Lightning-AI/lightning/pull/16002))


### Deprecated

37 changes: 35 additions & 2 deletions src/pytorch_lightning/callbacks/device_stats_monitor.py
@@ -30,8 +30,9 @@

class DeviceStatsMonitor(Callback):
r"""
Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor``
is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``.
Automatically monitors and logs device stats during training, validation and testing stage.
``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to passed as argument
to the ``Trainer``.

Args:
cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
@@ -109,6 +110,38 @@ def on_train_batch_end(
    ) -> None:
        self._get_and_log_device_stats(trainer, "on_train_batch_end")

    def on_validation_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        self._get_and_log_device_stats(trainer, "on_validation_batch_start")

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        self._get_and_log_device_stats(trainer, "on_validation_batch_end")

    def on_test_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        self._get_and_log_device_stats(trainer, "on_test_batch_start")

    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        self._get_and_log_device_stats(trainer, "on_test_batch_end")


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
return {prefix + separator + k: v for k, v in metrics_dict.items()}
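Taken together, these hooks mean device stats are now emitted during validation and testing, not just training. Below is a minimal usage sketch (not part of this diff), assuming the import paths current at the time of this PR and `BoringModel` from the demos module; the exact metric key names depend on how `_prefix_metric_keys` is invoked elsewhere in the callback, and the `"."` separator shown in the comment is illustrative:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import CSVLogger

model = BoringModel()
trainer = Trainer(
    max_epochs=1,
    accelerator="cpu",
    devices=1,
    callbacks=[DeviceStatsMonitor(cpu_stats=True)],  # the callback requires a logger
    logger=CSVLogger("logs"),
)
trainer.fit(model)   # stats logged from on_train_* and on_validation_* hooks
trainer.test(model)  # stats logged from the new on_test_* hooks

# _prefix_metric_keys namespaces each stat by the hook that produced it, e.g.
# _prefix_metric_keys({"cpu_percent": 12.0}, "on_test_batch_end", ".")
#   -> {"on_test_batch_end.cpu_percent": 12.0}
```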
56 changes: 56 additions & 0 deletions tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import os
import re
from typing import Dict, Optional
from unittest import mock
from unittest.mock import Mock
@@ -166,3 +168,57 @@ def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch, tmp
    # TODO: raise an exception from v1.9
    with pytest.warns(UserWarning, match="psutil` is not installed"):
        monitor.setup(trainer, Mock(), "fit")


def test_device_stats_monitor_logs_for_different_stages(tmpdir):
    """Test that metrics are logged for all stages: training, validation, and testing."""

    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=4,
        limit_val_batches=4,
        limit_test_batches=1,
        log_every_n_steps=1,
        accelerator="cpu",
        devices=1,
        callbacks=[device_stats],
        logger=CSVLogger(tmpdir),
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    # training and validation stages will run
    trainer.fit(model)

    with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
        content = csv.reader(csvfile, delimiter=",")
        # the first row of metrics.csv holds the metric column names
        header = next(content)

        # searching for training stage logs
        train_stage_results = [re.match(r".+on_train_batch", name) for name in header]
        assert any(train_stage_results), "training stage logs not found"

        # searching for validation stage logs
        validation_stage_results = [re.match(r".+on_validation_batch", name) for name in header]
        assert any(validation_stage_results), "validation stage logs not found"

    # testing stage will run
    trainer.test(model)

    with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
        content = csv.reader(csvfile, delimiter=",")
        header = next(content)

        # searching for testing stage logs
        test_stage_results = [re.match(r".+on_test_batch", name) for name in header]
        assert any(test_stage_results), "testing stage logs not found"