diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 8e30827759b99..530fb58fabe5e 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -328,7 +328,7 @@ def __getstate__(self) -> Dict[str, Any]: @property # type: ignore[misc] @rank_zero_experiment - def experiment(self) -> Run: + def experiment(self) -> Union[Run, RunDisabled]: r""" Actual wandb object. To use wandb features in your @@ -361,11 +361,13 @@ def experiment(self) -> Run: self._experiment = wandb.init(**self._wandb_init) # define default x-axis - if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None): + if isinstance(self._experiment, (Run, RunDisabled)) and getattr( + self._experiment, "define_metric", None + ): self._experiment.define_metric("trainer/global_step") self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) - assert isinstance(self._experiment, Run) + assert isinstance(self._experiment, (Run, RunDisabled)) return self._experiment def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None: