Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Fix/neptune ddp #8

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 69 additions & 53 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logging
import os
import uuid
import warnings
from argparse import Namespace
from functools import reduce
Expand All @@ -44,7 +45,7 @@
from neptune.new.types import File as NeptuneFile
except ModuleNotFoundError:
import neptune
from neptune.exceptions import NeptuneLegacyProjectException
from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
from neptune.run import Run
from neptune.types import File as NeptuneFile
else:
Expand Down Expand Up @@ -273,44 +274,21 @@ def __init__(
super().__init__()
self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix

self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)

self._run_short_id = self.run._short_id # skipcq: PYL-W0212
try:
self.run.wait()
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_name = "offline-name"

def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
if run is not None:
run_instance = run
else:
try:
run_instance = neptune.init(
project=project,
api_token=api_key,
name=name,
**neptune_run_kwargs,
)
except NeptuneLegacyProjectException as e:
raise TypeError(
f"""Project {project} has not been migrated to the new structure.
You can still integrate it with the Neptune logger using legacy Python API
available as part of neptune-contrib package:
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
"""
) from e

# make sure that we've log integration version for both newly created and outside `Run` instances
run_instance[_INTEGRATION_VERSION_KEY] = __version__

# keep api_key and project, they will be required when resuming Run for pickled logger
self._run_instance = run
self._api_key = api_key
self._project_name = run_instance._project_name # skipcq: PYL-W0212
self._run_name = name
self._project = project
self._custom_run_id = neptune_run_kwargs.get("custom_run_id", str(uuid.uuid4()))
if "custom_run_id" in neptune_run_kwargs:
del neptune_run_kwargs["custom_run_id"]
self._neptune_run_kwargs = neptune_run_kwargs

if self._run_instance:
self._run_name = self._preserve_name()
self._project = self._run_instance._project_name

return run_instance
# make sure that we've log integration version for outside `Run` instances
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

def _construct_path_with_prefix(self, *keys) -> str:
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
Expand Down Expand Up @@ -371,6 +349,13 @@ def _verify_input_arguments(
" you can't provide other neptune.init() parameters.\n"
)

def _preserve_name(self) -> str:
    """Fetch the run's name from the backend, or a placeholder when offline.

    ``wait()`` flushes any queued operations before reading ``sys/name``;
    in offline mode that raises ``NeptuneOfflineModeFetchException`` and we
    fall back to the fixed string ``"offline-name"``.
    """
    run = self._run_instance
    try:
        run.wait()
        fetched_name = run["sys/name"].fetch()
    except NeptuneOfflineModeFetchException:
        fetched_name = "offline-name"
    return fetched_name

def __getstate__(self):
state = self.__dict__.copy()
# Run instance can't be pickled
Expand All @@ -379,7 +364,13 @@ def __getstate__(self):

def __setstate__(self, state):
    """Restore pickled logger state and re-attach to the Neptune run.

    Run instances cannot be pickled (they are dropped in ``__getstate__``),
    so after restoring ``__dict__`` we re-initialize the connection using the
    preserved ``custom_run_id`` — this re-attaches to the SAME run instead of
    creating a new one, which is what makes the logger DDP/pickling safe.
    """
    self.__dict__ = state
    self._run_instance = neptune.init(
        name=self._run_name,
        project=self._project,
        api_token=self._api_key,
        custom_run_id=self._custom_run_id,
        **self._neptune_run_kwargs,
    )

@property
@rank_zero_experiment
Expand Down Expand Up @@ -412,8 +403,31 @@ def training_step(self, batch, batch_idx):
return self.run

@property
@rank_zero_experiment
def run(self) -> Run:
    """Return the Neptune ``Run``, lazily creating it on first access.

    Lazy creation (instead of connecting in ``__init__``) means non-zero
    ranks under DDP never open a connection; ``rank_zero_experiment``
    guards the accessor. The integration-version key is (re)written on
    every access so both user-supplied and newly created runs carry it.

    Raises:
        TypeError: if the target project still uses the legacy Neptune
            structure (wrapping ``NeptuneLegacyProjectException``).
    """
    try:
        if not self._run_instance:
            self._run_instance = neptune.init(
                name=self._run_name,
                project=self._project,
                api_token=self._api_key,
                custom_run_id=self._custom_run_id,
                **self._neptune_run_kwargs,
            )
            self._run_name = self._preserve_name()

        # make sure that we've logged the integration version for newly created runs
        self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

        return self._run_instance
    except NeptuneLegacyProjectException as e:
        raise TypeError(
            f"""Project {self._project} has not been migrated to the new structure.
            You can still integrate it with the Neptune logger using legacy Python API
            available as part of neptune-contrib package:
              - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
            """
        ) from e

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
Expand Down Expand Up @@ -474,12 +488,12 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which
# Lightning does not always guarantee.
self.experiment[key].log(val)
self.run[key].log(val)

@rank_zero_only
def finalize(self, status: str) -> None:
    """Record the final run ``status`` (e.g. "success") under the prefixed
    "status" key, then delegate to the base logger's finalization."""
    if status:
        self.run[self._construct_path_with_prefix("status")] = status

    super().finalize(status)

Expand All @@ -493,12 +507,14 @@ def save_dir(self) -> Optional[str]:
"""
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
    """Upload a plain-text summary of ``model`` to the run.

    Args:
        model: the LightningModule to summarize via ``ModelSummary``.
        max_depth: maximum depth of layer nesting to include; -1 means
            unlimited (the ``ModelSummary`` default semantics).
    """
    model_str = str(ModelSummary(model=model, max_depth=max_depth))
    self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
        content=model_str, extension="txt"
    )

@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.

Expand All @@ -515,35 +531,35 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
if checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)

# log best model path and checkpoint
if checkpoint_callback.best_model_path:
self.experiment[
self.run[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.experiment.exists(checkpoints_namespace):
exp_structure = self.experiment.get_structure()
if self.run.exists(checkpoints_namespace):
exp_structure = self.run.get_structure()
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)

for file_to_drop in list(uploaded_model_names - file_names):
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]

# log best model score
if checkpoint_callback.best_model_score:
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)

Expand Down Expand Up @@ -574,6 +590,8 @@ def _dict_paths(cls, d: dict, path_in_build: str = None) -> Generator:
@property
def name(self) -> str:
    """Return the experiment name or 'offline-name' when exp is run in offline mode."""
    if self._run_name:
        return self._run_name
    # Name not cached yet (e.g. lazily created run) — fetch it now.
    self._run_name = self._preserve_name()
    return self._run_name

@property
Expand All @@ -582,7 +600,7 @@ def version(self) -> str:

It's Neptune Run's short_id
"""
return self._run_short_id
return self.run._short_id # skipcq: PYL-W0212

@staticmethod
def _signal_deprecated_api_usage(f_name, sample_code, raise_exception=False):
Expand Down Expand Up @@ -634,13 +652,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
self.run[key].log(destination)

@rank_zero_only
def set_property(self, *args, **kwargs):
    """Deprecated legacy-API method; always raises via the deprecation signal,
    pointing users at the new ``logger.run[...]`` assignment style."""
    self._signal_deprecated_api_usage(
        "log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
    )

@rank_zero_only
def append_tags(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
Expand Down
9 changes: 6 additions & 3 deletions tests/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
class TestNeptuneLogger(unittest.TestCase):
def test_neptune_online(self, neptune):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger._run_instance
created_run_mock = logger.run

self.assertEqual(logger._run_instance, created_run_mock)
self.assertEqual(logger.name, "Run test name")
Expand Down Expand Up @@ -109,7 +109,10 @@ def test_neptune_pickling(self, neptune):
pickled_logger = pickle.dumps(logger)
unpickled = pickle.loads(pickled_logger)

neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
neptune.init.assert_called_once_with(name='Test name',
project="test-project",
api_token=None,
custom_run_id=logger._custom_run_id)
self.assertIsNotNone(unpickled.experiment)

@patch("pytorch_lightning.loggers.neptune.Run", Run)
Expand Down Expand Up @@ -360,7 +363,7 @@ def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
logger = NeptuneLogger(api_key="test", project="project")

# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
attr_mock = logger._run_instance.__getitem__
attr_mock = logger.run.__getitem__
attr_mock.reset_mock()
fake_image = {}

Expand Down