Lazy import tensorboard #15762

Merged (3 commits, Nov 23, 2022). Showing changes from 2 commits.

src/pytorch_lightning/loggers/tensorboard.py (25 additions, 4 deletions)

@@ -19,12 +19,10 @@
 import logging
 import os
 from argparse import Namespace
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union

 import numpy as np
 from lightning_utilities.core.imports import RequirementCache
-from tensorboardX import SummaryWriter
-from tensorboardX.summary import hparams
 from torch import Tensor

 import pytorch_lightning as pl

@@ -40,6 +38,13 @@
 log = logging.getLogger(__name__)

 _TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
+_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
+if TYPE_CHECKING:
+    # assumes at least one will be installed when type checking
+    if _TENSORBOARD_AVAILABLE:
+        from torch.utils.tensorboard import SummaryWriter
+    else:
+        from tensorboardX import SummaryWriter  # type: ignore[no-redef]

 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf

[Review comment, Member, on the import preference] this is opposite what @lantiga said in #15728 (comment) 🦦
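
For reference, a minimal standalone sketch (not part of the diff; `LazyLogger` is an invented name) of the pattern this hunk sets up: `RequirementCache` probes availability once at module import, the `TYPE_CHECKING` block gives type checkers a `SummaryWriter` symbol without any runtime import, and the real import is deferred to first use. It assumes `lightning_utilities` plus at least one of the two backends is installed.

from typing import TYPE_CHECKING, Optional

from lightning_utilities.core.imports import RequirementCache

_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")

if TYPE_CHECKING:
    # seen only by static type checkers; never executed at runtime
    from torch.utils.tensorboard import SummaryWriter


class LazyLogger:  # hypothetical class, for illustration only
    def __init__(self, log_dir: str) -> None:
        self._log_dir = log_dir
        self._writer: Optional["SummaryWriter"] = None

    @property
    def writer(self) -> "SummaryWriter":
        if self._writer is None:
            # the heavy import happens here, on first access
            if _TENSORBOARD_AVAILABLE:
                from torch.utils.tensorboard import SummaryWriter
            else:
                from tensorboardX import SummaryWriter  # type: ignore[no-redef]
            self._writer = SummaryWriter(log_dir=self._log_dir)
        return self._writer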

@@ -109,6 +114,10 @@ def __init__(
         sub_dir: Optional[_PATH] = None,
         **kwargs: Any,
     ):
+        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
+            raise ModuleNotFoundError(
+                "Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
+            )
         super().__init__()
         save_dir = os.fspath(save_dir)
         self._save_dir = save_dir

[Review comment, Member, on the new availability check] this won't happen as TBX is a mandatory dependency

@@ -172,7 +181,7 @@ def sub_dir(self) -> Optional[str]:

     @property
     @rank_zero_experiment
-    def experiment(self) -> SummaryWriter:
+    def experiment(self) -> "SummaryWriter":
         r"""
         Actual tensorboard object. To use TensorBoard features in your
         :class:`~pytorch_lightning.core.module.LightningModule` do the following.

@@ -188,6 +197,12 @@ def experiment(self) -> SummaryWriter:
         assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0"
         if self.root_dir:
             self._fs.makedirs(self.root_dir, exist_ok=True)
+
+        if _TENSORBOARD_AVAILABLE:
+            from torch.utils.tensorboard import SummaryWriter
+        else:
+            from tensorboardX import SummaryWriter  # type: ignore[no-redef]
+
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

(carmocca marked this conversation as resolved.)

@@ -224,6 +239,12 @@ def log_hyperparams(

         if metrics:
             self.log_metrics(metrics, 0)
+
+        if _TENSORBOARD_AVAILABLE:
+            from torch.utils.tensorboard.summary import hparams
+        else:
+            from tensorboardX.summary import hparams  # type: ignore[no-redef]
+
         exp, ssi, sei = hparams(params, metrics)
         writer = self.experiment._get_file_writer()
         writer.add_summary(exp)

[Review comment, Member, on the `hparams` import] how many times is this called per epoch?
[Reply, Contributor/Author] This was answered in #15762 (comment)
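
On the reviewer's cost question: however often `log_hyperparams` runs, a function-level import only executes the module once; every later call is a `sys.modules` dictionary lookup. A self-contained sketch with a stdlib module (timings vary by machine):

import sys
import timeit


def lazy_use() -> str:
    # after the first call, this import is a dict lookup in sys.modules,
    # not a re-execution of the json module
    from json import dumps

    return dumps({"answer": 42})


lazy_use()
assert "json" in sys.modules  # the module object is cached after first import
print(timeit.timeit(lazy_use, number=100_000))  # per-call overhead is tiny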

tests/tests_pytorch/loggers/test_all.py (21 additions, 17 deletions)

@@ -15,22 +15,16 @@
 import inspect
 import pickle
 from unittest import mock
-from unittest.mock import ANY
+from unittest.mock import ANY, Mock

 import pytest
 import torch

 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.loggers import (
-    CometLogger,
-    CSVLogger,
-    MLFlowLogger,
-    NeptuneLogger,
-    TensorBoardLogger,
-    WandbLogger,
-)
+from pytorch_lightning.loggers import CometLogger, CSVLogger, MLFlowLogger, NeptuneLogger, WandbLogger
 from pytorch_lightning.loggers.logger import DummyExperiment
+from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE, TensorBoardLogger
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.loggers.test_comet import _patch_comet_atexit
 from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation

(carmocca marked this conversation as resolved.)

@@ -300,10 +294,15 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.__getitem__().log.assert_called_once_with(1.0)

     # TensorBoard
-    with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
-        logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
-        logger.log_metrics({"test": 1.0}, step=0)
-        logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
+    logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir, prefix=prefix)
+    logger.log_metrics({"test": 1.0}, step=0)
+    logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

     # WandB
     with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(

@@ -316,17 +315,22 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.log.assert_called_once_with({"tmp-test": 1.0, "trainer/global_step": 0})


-def test_logger_default_name(tmpdir):
+def test_logger_default_name(tmpdir, monkeypatch):
     """Test that the default logger name is lightning_logs."""

     # CSV
     logger = CSVLogger(save_dir=tmpdir)
     assert logger.name == "lightning_logs"

     # TensorBoard
-    with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):
-        logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
-        assert logger.name == "lightning_logs"
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
+    logger = _instantiate_logger(TensorBoardLogger, save_dir=tmpdir)
+    assert logger.name == "lightning_logs"

     # MLflow
     with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
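
Why the patch target moves in these tests: before this PR, `pytorch_lightning.loggers.tensorboard` bound `SummaryWriter` at import time, so `mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")` replaced that binding. With the import deferred into the methods, the name is resolved on its source module (`torch.utils.tensorboard` or `tensorboardX`) at call time, so the `Mock` must be installed there, which is what `monkeypatch.setattr(tb, "SummaryWriter", Mock())` does. A self-contained sketch of the rule, with an invented module name:

import sys
import types
from unittest.mock import Mock

# a stand-in source module; the name `fakelib` is invented
fakelib = types.ModuleType("fakelib")
fakelib.SummaryWriter = lambda: "real-writer"
sys.modules["fakelib"] = fakelib


def make_writer():
    # like the logger after this PR: the import happens at call time
    from fakelib import SummaryWriter

    return SummaryWriter()


assert make_writer() == "real-writer"

# the moral equivalent of monkeypatch.setattr(tb, "SummaryWriter", Mock()):
# replace the attribute on the source module; the deferred import
# resolves it anew on every call
fakelib.SummaryWriter = Mock(return_value="mock-writer")
assert make_writer() == "mock-writer"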

tests/tests_pytorch/loggers/test_tensorboard.py (11 additions, 5 deletions)

@@ -15,6 +15,7 @@
 import os
 from argparse import Namespace
 from unittest import mock
+from unittest.mock import Mock

 import numpy as np
 import pytest

@@ -278,23 +279,28 @@ def training_step(self, *args):
     assert count_steps == model.indexes


-@mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter")
-def test_tensorboard_finalize(summary_writer, tmpdir):
+def test_tensorboard_finalize(monkeypatch, tmpdir):
     """Test that the SummaryWriter closes in finalize."""
+    if _TENSORBOARD_AVAILABLE:
+        import torch.utils.tensorboard as tb
+    else:
+        import tensorboardX as tb
+
+    monkeypatch.setattr(tb, "SummaryWriter", Mock())
     logger = TensorBoardLogger(save_dir=tmpdir)
     assert logger._experiment is None
     logger.finalize("any")

     # no log calls, no experiment created -> nothing to flush
-    summary_writer.assert_not_called()
+    logger.experiment.assert_not_called()

     logger = TensorBoardLogger(save_dir=tmpdir)
     logger.log_metrics({"flush_me": 11.1})  # trigger creation of an experiment
     logger.finalize("any")

     # finalize flushes to experiment directory
-    summary_writer().flush.assert_called()
-    summary_writer().close.assert_called()
+    logger.experiment.flush.assert_called()
+    logger.experiment.close.assert_called()


 def test_tensorboard_save_hparams_to_yaml_once(tmpdir):

tests/tests_pytorch/test_cli.py (3 additions, 1 deletion)

@@ -1330,7 +1330,9 @@ def test_tensorboard_logger_init_args():
"TensorBoardLogger",
{
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
},
{
"comment": "tb", # Unsupported resolving from local imports
},
)

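The `comment` entry moves to the unsupported set because, per the test's own comments, the CLI resolves extra `**kwargs` against `tensorboard.writer.SummaryWriter.__init__`, and that signature is no longer reachable at parser-build time once the import is local to the methods. At runtime the argument still works, since the logger forwards `**self._kwargs` to the writer; a usage sketch (assuming a TensorBoard backend is installed):

from pytorch_lightning.loggers import TensorBoardLogger

# `comment` is stored in the logger's kwargs and handed to SummaryWriter
# when `experiment` is first accessed
logger = TensorBoardLogger(save_dir="tb", comment="tb")
writer = logger.experiment  # SummaryWriter(log_dir=..., comment="tb") created here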