Skip to content

Commit

Permalink
Remove ModelCheckpoint.on_train_end
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 8, 2022
1 parent fa5abe2 commit 6356ef3
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 40 deletions.
13 changes: 0 additions & 13 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,19 +315,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
return
self.save_checkpoint(trainer)

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Write a final checkpoint when training stops.

    Only a ``save_last`` checkpoint is produced here: the monitor metrics logged
    during training/validation steps or at epoch ends are not guaranteed to still
    be available at this stage, so monitored saving is not attempted.
    """
    # Preserve short-circuit order: the skip check runs before the save_last test.
    skip = self._should_skip_saving_checkpoint(trainer) or not self.save_last
    if skip:
        return
    if self.verbose:
        rank_zero_info("Saving latest checkpoint...")
    candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step)
    self._save_last_checkpoint(trainer, candidates)

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)

if save_last:
# last epochs are saved every step (so double the save calls) and once `on_train_end`
expected = expected * 2 + 1
# last epochs are saved every step (so double the save calls)
expected = expected * 2
assert save_mock.call_count == expected


Expand Down
25 changes: 0 additions & 25 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import pickle
Expand Down Expand Up @@ -776,30 +775,6 @@ def test_default_checkpoint_behavior(tmpdir):
assert ckpts[0] == "epoch=2-step=15.ckpt"


@pytest.mark.parametrize("max_epochs", [1, 2])
@pytest.mark.parametrize("should_validate", [True, False])
@pytest.mark.parametrize("save_last", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
def test_model_checkpoint_save_last_warning(
    tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool
):
    """Tests 'Saving latest checkpoint...' log."""
    # Use a huge `every_n_epochs` so nothing is saved in `on_train_epoch_end`: the message is only
    # printed `on_train_end`, and it would be skipped there if a checkpoint for the same global
    # step had already been written in `on_train_epoch_end`.
    checkpoint_cb = ModelCheckpoint(
        dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose, every_n_epochs=123
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint_cb],
        max_epochs=max_epochs,
        limit_train_batches=1,
        limit_val_batches=int(should_validate),
    )
    with caplog.at_level(logging.INFO):
        trainer.fit(BoringModel())
    assert caplog.messages.count("Saving latest checkpoint...") == (verbose and save_last)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
"""Tests that the save_last checkpoint contains the latest information."""
seed_everything(100)
Expand Down

0 comments on commit 6356ef3

Please sign in to comment.