Skip to content

Commit

Permalink
Fixed NeptuneLogger when using with DDP
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Dec 11, 2021
1 parent 5576fbc commit 37affb2
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 52 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `NeptuneLogger` when using DDP ([#11030](https://github.com/PyTorchLightning/pytorch-lightning/pull/11030))


-
Expand Down
119 changes: 71 additions & 48 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from neptune.new.types import File as NeptuneFile
except ModuleNotFoundError:
import neptune
from neptune.exceptions import NeptuneLegacyProjectException
from neptune.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
from neptune.run import Run
from neptune.types import File as NeptuneFile
else:
Expand Down Expand Up @@ -273,44 +273,53 @@ def __init__(
super().__init__()
self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix
self._run_name = name
self._project_name = project
self._api_key = api_key
self._run_instance = run
self._neptune_run_kwargs = neptune_run_kwargs
self._run_short_id = None

if self._run_instance is not None:
self._retrieve_run_data()

self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)
# make sure that we've log integration version for outside `Run` instances
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

self._run_short_id = self.run._short_id # skipcq: PYL-W0212
def _retrieve_run_data(self):
try:
self.run.wait()
self._run_instance.wait()
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_name = "offline-name"

def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
if run is not None:
run_instance = run
else:
try:
run_instance = neptune.init(
project=project,
api_token=api_key,
name=name,
**neptune_run_kwargs,
)
except NeptuneLegacyProjectException as e:
raise TypeError(
f"""Project {project} has not been migrated to the new structure.
You can still integrate it with the Neptune logger using legacy Python API
available as part of neptune-contrib package:
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
"""
) from e

# make sure that we've log integration version for both newly created and outside `Run` instances
run_instance[_INTEGRATION_VERSION_KEY] = __version__

# keep api_key and project, they will be required when resuming Run for pickled logger
self._api_key = api_key
self._project_name = run_instance._project_name # skipcq: PYL-W0212
@property
def _neptune_init_args(self):
args = {}
# Backward compatibility in case of previous version retrieval
try:
args = self._neptune_run_kwargs
except AttributeError:
pass

if self._project_name is not None:
args["project"] = self._project_name

return run_instance
if self._api_key is not None:
args["api_token"] = self._api_key

if self._run_short_id is not None:
args["run"] = self._run_short_id

# Backward compatibility in case of previous version retrieval
try:
if self._run_name is not None:
args["name"] = self._run_name
except AttributeError:
pass

return args

def _construct_path_with_prefix(self, *keys) -> str:
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
Expand Down Expand Up @@ -379,7 +388,7 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
self._run_instance = neptune.init(project=self._project_name, api_token=self._api_key, run=self._run_short_id)
self._run_instance = neptune.init(**self._neptune_init_args)

@property
@rank_zero_experiment
Expand Down Expand Up @@ -412,8 +421,24 @@ def training_step(self, batch, batch_idx):
return self.run

@property
@rank_zero_experiment
def run(self) -> Run:
return self._run_instance
try:
if not self._run_instance:
self._run_instance = neptune.init(**self._neptune_init_args)
self._retrieve_run_data()
# make sure that we've log integration version for newly created
self._run_instance[_INTEGRATION_VERSION_KEY] = __version__

return self._run_instance
except NeptuneLegacyProjectException as e:
raise TypeError(
f"""Project {self._project_name} has not been migrated to the new structure.
You can still integrate it with the Neptune logger using legacy Python API
available as part of neptune-contrib package:
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
"""
) from e

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
Expand Down Expand Up @@ -474,12 +499,12 @@ def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Opti
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which
# Lightning does not always guarantee.
self.experiment[key].log(val)
self.run[key].log(val)

@rank_zero_only
def finalize(self, status: str) -> None:
if status:
self.experiment[self._construct_path_with_prefix("status")] = status
self.run[self._construct_path_with_prefix("status")] = status

super().finalize(status)

Expand All @@ -493,12 +518,14 @@ def save_dir(self) -> Optional[str]:
"""
return os.path.join(os.getcwd(), ".neptune")

@rank_zero_only
def log_model_summary(self, model, max_depth=-1):
model_str = str(ModelSummary(model=model, max_depth=max_depth))
self.experiment[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
content=model_str, extension="txt"
)

@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
Expand All @@ -515,35 +542,33 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
if checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)

# log best model path and checkpoint
if checkpoint_callback.best_model_path:
self.experiment[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.experiment.exists(checkpoints_namespace):
exp_structure = self.experiment.get_structure()
if self.run.exists(checkpoints_namespace):
exp_structure = self.run.get_structure()
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)

for file_to_drop in list(uploaded_model_names - file_names):
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]

# log best model score
if checkpoint_callback.best_model_score:
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)

Expand Down Expand Up @@ -634,13 +659,11 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
self.run[key].log(destination)

@rank_zero_only
def set_property(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
)

@rank_zero_only
def append_tags(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
Expand Down
6 changes: 3 additions & 3 deletions tests/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
class TestNeptuneLogger(unittest.TestCase):
def test_neptune_online(self, neptune):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger._run_instance
created_run_mock = logger.run

self.assertEqual(logger._run_instance, created_run_mock)
self.assertEqual(logger.name, "Run test name")
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_neptune_pickling(self, neptune):
pickled_logger = pickle.dumps(logger)
unpickled = pickle.loads(pickled_logger)

neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
neptune.init.assert_called_once_with(name="Test name", run=unpickleable_run._short_id)
self.assertIsNotNone(unpickled.experiment)

@patch("pytorch_lightning.loggers.neptune.Run", Run)
Expand Down Expand Up @@ -360,7 +360,7 @@ def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
logger = NeptuneLogger(api_key="test", project="project")

# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
attr_mock = logger._run_instance.__getitem__
attr_mock = logger.run.__getitem__
attr_mock.reset_mock()
fake_image = {}

Expand Down

0 comments on commit 37affb2

Please sign in to comment.