Skip to content

Commit

Permalink
Use _PATH in annotations and convert to str if Path (#15560)
Browse files Browse the repository at this point in the history
Co-authored-by: Bryn Lloyd <[email protected]>
(cherry picked from commit 18f7f2d)
  • Loading branch information
dyollb authored and lexierule committed Nov 10, 2022
1 parent 6c199e9 commit efd55be
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
8 changes: 5 additions & 3 deletions src/pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
Expand Down Expand Up @@ -87,20 +88,21 @@ class TensorBoardLogger(Logger):

def __init__(
self,
save_dir: str,
save_dir: _PATH,
name: Optional[str] = "lightning_logs",
version: Optional[Union[int, str]] = None,
log_graph: bool = False,
default_hp_metric: bool = True,
prefix: str = "",
sub_dir: Optional[str] = None,
sub_dir: Optional[_PATH] = None,
**kwargs: Any,
):
super().__init__()
save_dir = os.fspath(save_dir)
self._save_dir = save_dir
self._name = name or ""
self._version = version
self._sub_dir = sub_dir
self._sub_dir = None if sub_dir is None else os.fspath(sub_dir)
self._log_graph = log_graph
self._default_hp_metric = default_hp_metric
self._prefix = prefix
Expand Down
15 changes: 12 additions & 3 deletions src/pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor

from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -292,10 +293,10 @@ def any_lightning_module_function_or_hook(self):
def __init__(
self,
name: Optional[str] = None,
save_dir: str = ".",
save_dir: _PATH = ".",
version: Optional[str] = None,
offline: bool = False,
dir: Optional[str] = None,
dir: Optional[_PATH] = None,
id: Optional[str] = None,
anonymous: Optional[bool] = None,
project: str = "lightning_logs",
Expand Down Expand Up @@ -331,6 +332,13 @@ def __init__(
self._experiment = experiment
self._logged_model_time: Dict[str, float] = {}
self._checkpoint_callback: Optional[Checkpoint] = None

# paths are processed as strings
if save_dir is not None:
save_dir = os.fspath(save_dir)
elif dir is not None:
dir = os.fspath(dir)

# set wandb init arguments
self._wandb_init: Dict[str, Any] = dict(
name=name,
Expand Down Expand Up @@ -521,7 +529,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
@rank_zero_only
def download_artifact(
artifact: str,
save_dir: Optional[str] = None,
save_dir: Optional[_PATH] = None,
artifact_type: Optional[str] = None,
use_artifact: Optional[bool] = True,
) -> str:
Expand All @@ -542,6 +550,7 @@ def download_artifact(
api = wandb.Api()
artifact = api.artifact(artifact, type=artifact_type)

save_dir = None if save_dir is None else os.fspath(save_dir)
return artifact.download(root=save_dir)

def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "wandb.Artifact":
Expand Down
5 changes: 4 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
logger: Union[Logger, Iterable[Logger], bool] = True,
enable_checkpointing: bool = True,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
default_root_dir: Optional[_PATH] = None,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
num_nodes: int = 1,
Expand Down Expand Up @@ -399,6 +399,9 @@ def __init__(
log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
self.state = TrainerState()

if default_root_dir is not None:
default_root_dir = os.fspath(default_root_dir)

# init connectors
self._data_connector = DataConnector(self, multiple_trainloader_mode)

Expand Down

0 comments on commit efd55be

Please sign in to comment.