diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py
index f61959a5dd60d..50d6e95add25b 100644
--- a/src/pytorch_lightning/loggers/tensorboard.py
+++ b/src/pytorch_lightning/loggers/tensorboard.py
@@ -28,6 +28,7 @@
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
+from lightning_lite.utilities.types import _PATH
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
@@ -87,20 +88,21 @@ class TensorBoardLogger(Logger):
 
     def __init__(
         self,
-        save_dir: str,
+        save_dir: _PATH,
         name: Optional[str] = "lightning_logs",
         version: Optional[Union[int, str]] = None,
         log_graph: bool = False,
         default_hp_metric: bool = True,
         prefix: str = "",
-        sub_dir: Optional[str] = None,
+        sub_dir: Optional[_PATH] = None,
         **kwargs: Any,
     ):
         super().__init__()
+        save_dir = os.fspath(save_dir)
         self._save_dir = save_dir
         self._name = name or ""
         self._version = version
-        self._sub_dir = sub_dir
+        self._sub_dir = None if sub_dir is None else os.fspath(sub_dir)
         self._log_graph = log_graph
         self._default_hp_metric = default_hp_metric
         self._prefix = prefix
diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py
index 55757487d7124..c61052762f868 100644
--- a/src/pytorch_lightning/loggers/wandb.py
+++ b/src/pytorch_lightning/loggers/wandb.py
@@ -24,6 +24,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 
+from lightning_lite.utilities.types import _PATH
 from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -292,10 +293,10 @@ def any_lightning_module_function_or_hook(self):
     def __init__(
         self,
         name: Optional[str] = None,
-        save_dir: str = ".",
+        save_dir: _PATH = ".",
         version: Optional[str] = None,
         offline: bool = False,
-        dir: Optional[str] = None,
+        dir: Optional[_PATH] = None,
         id: Optional[str] = None,
         anonymous: Optional[bool] = None,
         project: str = "lightning_logs",
@@ -331,6 +332,13 @@ def __init__(
         self._experiment = experiment
         self._logged_model_time: Dict[str, float] = {}
         self._checkpoint_callback: Optional[Checkpoint] = None
+
+        # paths are processed as strings
+        if save_dir is not None:
+            save_dir = os.fspath(save_dir)
+        elif dir is not None:
+            dir = os.fspath(dir)
+
         # set wandb init arguments
         self._wandb_init: Dict[str, Any] = dict(
             name=name,
@@ -521,7 +529,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
     @rank_zero_only
     def download_artifact(
         artifact: str,
-        save_dir: Optional[str] = None,
+        save_dir: Optional[_PATH] = None,
         artifact_type: Optional[str] = None,
         use_artifact: Optional[bool] = True,
     ) -> str:
@@ -542,6 +550,7 @@ def download_artifact(
 
         api = wandb.Api()
         artifact = api.artifact(artifact, type=artifact_type)
+        save_dir = None if save_dir is None else os.fspath(save_dir)
         return artifact.download(root=save_dir)
 
     def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "wandb.Artifact":
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index ea9ca7aef93b9..35606d71288a0 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -115,7 +115,7 @@ def __init__(
         logger: Union[Logger, Iterable[Logger], bool] = True,
         enable_checkpointing: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
-        default_root_dir: Optional[str] = None,
+        default_root_dir: Optional[_PATH] = None,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
         num_nodes: int = 1,
@@ -399,6 +399,9 @@ def __init__(
         log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
         self.state = TrainerState()
 
+        if default_root_dir is not None:
+            default_root_dir = os.fspath(default_root_dir)
+
         # init connectors
         self._data_connector = DataConnector(self, multiple_trainloader_mode)
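
For context, a minimal usage sketch of what this diff enables, not part of the patch itself: the annotated parameters should now accept os.PathLike values (e.g. pathlib.Path) in addition to plain strings, since each path is normalized with os.fspath() before use. The directory names below are hypothetical.

# Sketch: passing pathlib.Path objects to the APIs touched by this diff.
from pathlib import Path

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

root = Path("experiments") / "run_01"  # hypothetical output directory

# TensorBoardLogger: save_dir and sub_dir may now be Path objects.
tb_logger = TensorBoardLogger(save_dir=root / "tb", sub_dir=Path("images"))

# WandbLogger: save_dir (and the wandb-style `dir` alias) may now be Paths.
wandb_logger = WandbLogger(save_dir=root / "wandb", offline=True)

# Trainer: default_root_dir is converted with os.fspath() during __init__.
trainer = Trainer(default_root_dir=root, logger=[tb_logger, wandb_logger])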