Skip to content

Commit

Permalink
Fixed uploading best model checkpoint in NeptuneLogger (#10369)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Dec 1, 2021
1 parent 72cc8b7 commit c647841
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))


- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))


## [1.5.4] - 2021-11-30

### Fixed
Expand Down
18 changes: 12 additions & 6 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,16 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)

# log best model path and checkpoint
if checkpoint_callback.best_model_path:
self.experiment[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.experiment.exists(checkpoints_namespace):
exp_structure = self.experiment.get_structure()
Expand All @@ -531,11 +541,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
for file_to_drop in list(uploaded_model_names - file_names):
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]

# log best model path and best model score
if checkpoint_callback.best_model_path:
self.experiment[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path
# log best model score
if checkpoint_callback.best_model_score:
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
Expand All @@ -544,7 +550,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo
@staticmethod
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
"""Returns model name which is string `modle_path` appended to `checkpoint_callback.dirpath`."""
expected_model_path = f"{checkpoint_callback.dirpath}/"
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
return model_path[len(expected_model_path) :]
Expand Down
33 changes: 20 additions & 13 deletions tests/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,15 @@ def test_after_save_checkpoint(self, neptune):
logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
api_key="test", project="project", **prefix
)
models_root_dir = os.path.join("path", "to", "models")
cb_mock = MagicMock(
dirpath="path/to/models",
last_model_path="path/to/models/last",
dirpath=models_root_dir,
last_model_path=os.path.join(models_root_dir, "last"),
best_k_models={
"path/to/models/model1": None,
"path/to/models/model2/with/slashes": None,
f"{os.path.join(models_root_dir, 'model1')}": None,
f"{os.path.join(models_root_dir, 'model2/with/slashes')}": None,
},
best_model_path="path/to/models/best_model",
best_model_path=os.path.join(models_root_dir, "best_model"),
best_model_score=None,
)

Expand All @@ -292,19 +293,21 @@ def test_after_save_checkpoint(self, neptune):

# then:
self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
self.assertEqual(run_instance_mock.__getitem__.call_count, 3)
self.assertEqual(run_attr_mock.upload.call_count, 3)
self.assertEqual(run_instance_mock.__getitem__.call_count, 4)
self.assertEqual(run_attr_mock.upload.call_count, 4)
run_instance_mock.__setitem__.assert_called_once_with(
f"{model_key_prefix}/best_model_path", "path/to/models/best_model"
f"{model_key_prefix}/best_model_path", os.path.join(models_root_dir, "best_model")
)
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/best_model")
run_attr_mock.upload.assert_has_calls(
[
call("path/to/models/last"),
call("path/to/models/model1"),
call("path/to/models/model2/with/slashes"),
call(os.path.join(models_root_dir, "last")),
call(os.path.join(models_root_dir, "model1")),
call(os.path.join(models_root_dir, "model2/with/slashes")),
call(os.path.join(models_root_dir, "best_model")),
]
)

Expand Down Expand Up @@ -394,8 +397,12 @@ def test__get_full_model_name(self):
# given:
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
test_input_data = [
("key.ext", "foo/bar/key.ext", SimpleCheckpoint(dirpath="foo/bar")),
("key/in/parts.ext", "foo/bar/key/in/parts.ext", SimpleCheckpoint(dirpath="foo/bar")),
("key.ext", os.path.join("foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar"))),
(
"key/in/parts.ext",
os.path.join("foo", "bar", "key/in/parts.ext"),
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
),
]

# expect:
Expand Down

0 comments on commit c647841

Please sign in to comment.