Skip to content

Commit

Permalink
Proper support for Remote Stop and Remote Abort with NeptuneLogger (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Feb 23, 2024
1 parent 0235543 commit f2f3ef5
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))


- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130))


-


Expand Down
30 changes: 22 additions & 8 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import logging
import os
from argparse import Namespace
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union

from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
Expand Down Expand Up @@ -48,6 +49,19 @@
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"


# Neptune client throws `InactiveRunException` when trying to log to an inactive run.
# This may happen when the run was stopped through the UI and the logger is still trying to log to it.
def _catch_inactive(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
from neptune.exceptions import InactiveRunException

with contextlib.suppress(InactiveRunException):
return func(*args, **kwargs)

return wrapper


class NeptuneLogger(Logger):
r"""Log using `Neptune <https://neptune.ai>`_.
Expand Down Expand Up @@ -245,10 +259,7 @@ def __init__(
if self._run_instance is not None:
self._retrieve_run_data()

if _NEPTUNE_AVAILABLE:
from neptune.handler import Handler
else:
from neptune.new.handler import Handler
from neptune.handler import Handler

# make sure that we've log integration version for outside `Run` instances
root_obj = self._run_instance
Expand Down Expand Up @@ -383,6 +394,7 @@ def run(self) -> "Run":

@override
@rank_zero_only
@_catch_inactive
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
r"""Log hyperparameters to the run.
Expand Down Expand Up @@ -430,9 +442,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:

@override
@rank_zero_only
def log_metrics( # type: ignore[override]
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
) -> None:
@_catch_inactive
def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in Neptune runs.
Args:
Expand All @@ -450,6 +461,7 @@ def log_metrics( # type: ignore[override]

@override
@rank_zero_only
@_catch_inactive
def finalize(self, status: str) -> None:
if not self._run_instance:
# When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
Expand All @@ -473,6 +485,7 @@ def save_dir(self) -> Optional[str]:
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
@_catch_inactive
def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None:
from neptune.types import File

Expand All @@ -483,6 +496,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) ->

@override
@rank_zero_only
@_catch_inactive
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
Expand Down
4 changes: 4 additions & 0 deletions tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def __setitem__(self, key, value):
neptune_utils.stringify_unsupported = Mock()
monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils)

neptune_exceptions = ModuleType("exceptions")
neptune_exceptions.InactiveRunException = Exception
monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions)

neptune.handler = neptune_handler
neptune.types = neptune_types
neptune.utils = neptune_utils
Expand Down
10 changes: 10 additions & 0 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,13 @@ def test_get_full_model_names_from_exp_structure():
}
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys


def test_inactive_run(neptune_mock, tmp_path):
from neptune.exceptions import InactiveRunException

logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
run_instance_mock.__setitem__.side_effect = InactiveRunException

# this should work without any exceptions
_fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)

0 comments on commit f2f3ef5

Please sign in to comment.