From 3488a8ade4790fe326ea259752eaf81baa79124c Mon Sep 17 00:00:00 2001 From: Bryn Lloyd Date: Sun, 6 Nov 2022 22:21:37 +0100 Subject: [PATCH] use _PATH in annotations and convert to str if Path --- src/pytorch_lightning/loggers/tensorboard.py | 8 +++++--- src/pytorch_lightning/loggers/wandb.py | 15 ++++++++++++--- src/pytorch_lightning/trainer/trainer.py | 5 ++++- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py index f61959a5dd60d..50d6e95add25b 100644 --- a/src/pytorch_lightning/loggers/tensorboard.py +++ b/src/pytorch_lightning/loggers/tensorboard.py @@ -28,6 +28,7 @@ import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.types import _PATH from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE @@ -87,20 +88,21 @@ class TensorBoardLogger(Logger): def __init__( self, - save_dir: str, + save_dir: _PATH, name: Optional[str] = "lightning_logs", version: Optional[Union[int, str]] = None, log_graph: bool = False, default_hp_metric: bool = True, prefix: str = "", - sub_dir: Optional[str] = None, + sub_dir: Optional[_PATH] = None, **kwargs: Any, ): super().__init__() + save_dir = os.fspath(save_dir) self._save_dir = save_dir self._name = name or "" self._version = version - self._sub_dir = sub_dir + self._sub_dir = None if sub_dir is None else os.fspath(sub_dir) self._log_graph = log_graph self._default_hp_metric = default_hp_metric self._prefix = prefix diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 5d60989c65961..0706fbeec960a 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -23,6 +23,7 @@ import torch.nn as nn from lightning_utilities.core.imports import RequirementCache +from lightning_lite.utilities.types import _PATH from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -291,10 +292,10 @@ def any_lightning_module_function_or_hook(self): def __init__( self, name: Optional[str] = None, - save_dir: str = ".", + save_dir: _PATH = ".", version: Optional[str] = None, offline: bool = False, - dir: Optional[str] = None, + dir: Optional[_PATH] = None, id: Optional[str] = None, anonymous: Optional[bool] = None, project: str = "lightning_logs", @@ -330,6 +331,13 @@ def __init__( self._experiment = experiment self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[Checkpoint] = None + + # paths are processed as strings + if save_dir is not None: + save_dir = os.fspath(save_dir) + elif dir is not None: + dir = os.fspath(dir) + # set wandb init arguments self._wandb_init: Dict[str, Any] = dict( name=name, @@ -520,7 +528,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: @rank_zero_only def download_artifact( artifact: str, - save_dir: Optional[str] = None, + save_dir: Optional[_PATH] = None, artifact_type: Optional[str] = None, use_artifact: Optional[bool] = True, ) -> str: @@ -541,6 +549,7 @@ def download_artifact( api = wandb.Api() artifact = api.artifact(artifact, type=artifact_type) + save_dir = None if save_dir is None else os.fspath(save_dir) return artifact.download(root=save_dir) def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "wandb.Artifact": diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index ea9ca7aef93b9..35606d71288a0 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -115,7 +115,7 @@ def __init__( logger: Union[Logger, Iterable[Logger], bool] = True, enable_checkpointing: bool = True, callbacks: Optional[Union[List[Callback], Callback]] = None, - default_root_dir: Optional[str] = None, + default_root_dir: Optional[_PATH] = None, gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, num_nodes: int = 1, @@ -399,6 +399,9 @@ def __init__( log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}") self.state = TrainerState() + if default_root_dir is not None: + default_root_dir = os.fspath(default_root_dir) + # init connectors self._data_connector = DataConnector(self, multiple_trainloader_mode)