
Integrate global step with progress tracking #11805

Merged: 12 commits into master on Mar 7, 2022

Conversation

carmocca (Contributor) commented Feb 8, 2022

What does this PR do?

  • trainer.global_step now returns the progress tracking's optimizer step count.
  • ModelCheckpoint and other checkpoint logic now use this attribute.
  • The previous trainer.fit_loop.epoch_loop.global_step counter is now renamed to trainer.fit_loop.epoch_loop._batches_that_stepped, a name that matches its actual behaviour.
    • Loggers still use this value as the x-axis step value (see the sketch below).
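
For illustration, here is a minimal sketch of how the two counters can be inspected after this change; the callback below is hypothetical and the exact hook signature may differ slightly between Lightning versions:

import pytorch_lightning as pl


class StepCounterProbe(pl.Callback):
    """Hypothetical callback comparing the two counters described above."""

    def on_train_epoch_end(self, trainer, pl_module):
        # Number of optimizer steps taken, now backed by progress tracking.
        optimizer_steps = trainer.global_step
        # Internal counter previously named `global_step`; loggers still use it
        # as the x-axis step value.
        batches_stepped = trainer.fit_loop.epoch_loop._batches_that_stepped
        print(f"optimizer steps: {optimizer_steps}, batches that stepped: {batches_stepped}")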

Fixes #7406

Does your PR introduce any breaking changes? If yes, please list them.

  • Any access to trainer.global_step during an intra-training validation hook will now correctly return the number of optimizer steps taken so far. That is, new_global_step == master_global_step + 1, where "master" refers to the value on the current master branch. In pseudocode (see also the callback sketch after this list):

Before:

training_step()
validation_if_necessary()
global_step++

Now:

training_step()
global_step++
validation_if_necessary()
  • Saved checkpoints that use the global step value as part of the filename now have that value increased by 1, for the reason in the previous bullet.

  • If users were using TBPTT or multiple optimizers, the trainer.global_step value will account for those and be different from the value in current master.

  • The Trainer arguments {min,max}_steps compare against the new global_step value, so they suffer from the same breaking changes. In the case of multiple optimizers or TBPTT, users will need to adjust them.
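
To make the first breaking change concrete, here is a hedged sketch (the callback is made up for illustration) of what trainer.global_step now returns inside a validation run triggered during training:

import pytorch_lightning as pl


class GlobalStepProbe(pl.Callback):
    """Hypothetical callback observing the new hook ordering."""

    def on_validation_start(self, trainer, pl_module):
        # After this PR, the optimizer step for the batch that triggered this
        # validation has already been counted, so this prints N; on the
        # previous master it would have printed N - 1.
        print(f"global_step as seen from validation: {trainer.global_step}")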

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

cc @tchaton @rohitgr7 @akihironitta @awaelchli @ananthsub @ninginthecloud @Borda @carmocca

carmocca added the checkpointing (Related to checkpointing), breaking change (Includes a breaking change), and progress tracking (internal) (Related to the progress tracking dataclasses) labels on Feb 8, 2022
carmocca added this to the 1.6 milestone on Feb 8, 2022
carmocca self-assigned this on Feb 8, 2022
carmocca force-pushed the bugfix/saved-global-step branch from a5c9580 to 70e64f9 on February 8, 2022 00:38
carmocca changed the base branch from master to bugfix/5007 on February 8, 2022 00:48
carmocca force-pushed the bugfix/saved-global-step branch from 70e64f9 to 6356ef3 on February 8, 2022 03:23
carmocca force-pushed the bugfix/saved-global-step branch 3 times, most recently from ca4c65d to f0dfd23, on February 8, 2022 14:12
carmocca added the bug (Something isn't working) label on Feb 9, 2022
carmocca force-pushed the bugfix/saved-global-step branch from f0dfd23 to e218bfc on February 10, 2022 16:38
Base automatically changed from bugfix/5007 to master on February 10, 2022 16:56
carmocca force-pushed the bugfix/saved-global-step branch from e218bfc to 87bf779 on February 10, 2022 17:06
carmocca changed the base branch from master to feat/ptracking-manual-opt-step on February 10, 2022 17:07
carmocca force-pushed the bugfix/saved-global-step branch from 87bf779 to dec43b9 on February 10, 2022 17:07
carmocca force-pushed the feat/ptracking-manual-opt-step branch from b543100 to dc28ad0 on February 10, 2022 17:18
carmocca force-pushed the bugfix/saved-global-step branch 4 times, most recently from 4843cda to d95f2b7, on February 11, 2022 00:47
carmocca force-pushed the feat/ptracking-manual-opt-step branch from 5afe552 to 124e9ce on February 11, 2022 16:51
Base automatically changed from feat/ptracking-manual-opt-step to master on February 16, 2022 21:27
mergify bot added the ready (PRs ready to be merged) label on Mar 2, 2022
carmocca enabled auto-merge (squash) on March 2, 2022 13:46
mergify bot added the has conflicts label on Mar 5, 2022
mergify bot removed the has conflicts label on Mar 5, 2022
tchaton (Contributor) left a comment:

Great work!

carmocca merged commit aea96e4 into master on Mar 7, 2022
carmocca deleted the bugfix/saved-global-step branch on March 7, 2022 19:21
Borda pushed a commit to jerome-habana/lightning that referenced this pull request on Mar 10, 2022
yifuwang (Contributor) commented Mar 14, 2022

Hi @carmocca, it seems that last.ckpt is no longer always generated (due to the removal of ModelCheckpoint.on_train_end)? Some of Meta's internal tests broke because they expected last.ckpt to exist, but it is no longer produced.

Can you help us understand under what circumstances last.ckpt was generated before this change but is no longer generated after it? What should users do if their load logic looks for last.ckpt?

@jjenniferdai @ananthsub

jjenniferdai (Contributor) commented:

I think on_train_end should be brought back. An example of when last.ckpt used to be generated but now is not: if the ModelCheckpoint trigger conditions never pass (e.g. every_n_epochs=10 but the trainer's max_epochs=1). Before, we would still save last.ckpt from on_train_end; now we save nothing. If save_last is set, users would expect last.ckpt to be saved.

yifuwang (Contributor) commented Mar 15, 2022

Hi @carmocca, this change broke the tests of many of Meta's internal use cases. To summarize how it's affecting us:

  • last.ckpt is no longer always generated, even when save_last is set (due to the removal of ModelCheckpoint.on_train_end). This seems to have broken the contract "save_last: When True, always saves the model at the end of the epoch to a file last.ckpt".
  • As you mentioned, the step count in checkpoint filenames is incremented by 1.
  • The steps at which loggers are flushed are shifted by 1.

I wonder:

  • Are all these changes intended?
  • If yes, are they all necessary? (e.g., can we keep the checkpointing/logging behavior intact by compensating in the filename/flushing conditions?)
  • If all these changes are intended and necessary, can we introduce them separately over time?

awaelchli (Contributor) commented Mar 15, 2022

@yifuwang After reading your quote:

This seems to have broken the contract "save_last: When True, always saves the model at the end of the epoch to a file last.ckpt".

I'm updating the docs for that argument in #12332. The description is wrong; the behavior differs in general. With the default checkpointing behavior, though, the statement still holds. Have you found a case where it does not?

yifuwang (Contributor) commented:

Thanks for the prompt response @awaelchli! Please see the example provided by @jjenniferdai above for how the behavior has differed after this change:

An example of when last.ckpt used to be generated and now is not - if the ModelCheckpoint trigger qualifications do not pass (e.g. every_n_epochs = 10 but trainer max_epochs=1). Before we would still save last.ckpt from on_train_end, now we save nothing. If save_last is set users would expect last.ckpt to get saved.

carmocca (Contributor, Author) commented:

ModelCheckpoint checks whether it has already saved at the current global step with https://github.com/PyTorchLightning/pytorch-lightning/blob/7ee690758ccad7f702460d056f6369c1d4371a46/pytorch_lightning/callbacks/model_checkpoint.py#L387

Before this PR, with all the manual decrements and increments to global_step, we had a "bug" where we saved a "duplicate" last checkpoint even though we had already saved at that global step. After fixing it, the on_train_end override became redundant: #11805 (comment)

As Adrian mentioned, the existing implementation meant the docs were not strictly correct.

One could avoid the issues you are seeing by instantiating a separate ModelCheckpoint that only saves "last" checkpoints. Then, the state of each "saving mode" will not impact the others: https://github.com/PyTorchLightning/pytorch-lightning/blob/7ee690758ccad7f702460d056f6369c1d4371a46/pytorch_lightning/callbacks/model_checkpoint.py#L366-L373

Following this structure would also make sense if one wants to checkpoint with every_n_epochs=N and save_last=True but wants save_last to check every epoch. If setting both makes "last" checkpoints check both conditions during training (the case before this PR), then the same logic should apply at the end of training. A minimal sketch of this two-callback setup is shown below.
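
A hedged sketch of that suggestion, with illustrative directory paths; save_top_k=0 is used here under the assumption that it disables periodic saving for the instance dedicated to last.ckpt:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# One instance handles the periodic checkpoints ...
periodic_ckpt = ModelCheckpoint(dirpath="checkpoints/periodic", every_n_epochs=10)
# ... while a separate instance only maintains last.ckpt, so its saving state
# is not coupled to the periodic instance's trigger conditions.
last_ckpt = ModelCheckpoint(dirpath="checkpoints/last", save_last=True, save_top_k=0)

trainer = Trainer(max_epochs=1, callbacks=[periodic_ckpt, last_ckpt])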

The steps at which loggers are flushed are shifted by 1.

This might not have been intended. Can you elaborate on the observed behaviour versus the expected behaviour?

jjenniferdai (Contributor) commented:

An example of when last.ckpt used to be generated and now is not - if the ModelCheckpoint trigger qualifications do not pass (e.g. every_n_epochs = 10 but trainer max_epochs=1). Before we would still save last.ckpt from on_train_end, now we save nothing. If save_last is set users would expect last.ckpt to get saved.

I think users would still expect last.ckpt to be saved in this example, right? It used to be, but no longer is: ModelCheckpoint.save_checkpoint() never gets called (@carmocca, i.e. all the logic in your comment above) both before and after this change, but before, last.ckpt was still correctly generated in on_train_end.

yifuwang (Contributor) commented:

@carmocca, for the following snippet, last.ckpt is generated before this PR but no longer after:

import uuid

import torch
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    tmpdir = f"/tmp/{uuid.uuid4()}"
    print(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
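        # Note: every_n_epochs=10 never triggers within max_epochs=1, so no
        # periodic checkpoint is saved; before this PR, save_last=True still
        # produced last.ckpt via on_train_end, afterwards nothing is written.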
        callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_epochs=10, save_last=True)],
        enable_checkpointing=True,
    )
    model = BoringModel()
    trainer.fit(
        model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2)
    )

The following contract was respected prior to this PR, but no longer is:

save_last: When True, always saves the model at the end of the epoch to a file last.ckpt

This is a BC-breaking change that alters behavior some users rely on, regardless of whether that behavior is considered a "bug".

carmocca (Contributor, Author) commented:

Yes, I'm fine with restoring the previous behaviour for every_n_epochs, as it wasn't an intended change of this PR. It will be done before the release.

zhong-yy commented Feb 6, 2023

I think this can cause confusion, especially when using the default filename ModelCheckpoint(filename="{epoch}-{step}"), where epoch is numbered from 0 but step is numbered from 1. See #16636

Also, it is counter-intuitive to number something from 1 instead of 0 in Python. I can't find any explicit statement about this numbering strategy in the ModelCheckpoint documentation.
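
For reference, a hedged sketch of what the default naming produces under these rules; the 32-batch epoch and the exact pre-PR filename are assumptions for illustration:

from pytorch_lightning.callbacks import ModelCheckpoint

# Default-style filename discussed above. Assuming one optimizer step per batch
# and 32 training batches per epoch, the checkpoint saved at the end of the
# first epoch is named roughly:
#   epoch=0-step=32.ckpt   (epoch counts from 0, step counts completed optimizer steps)
# whereas before this PR the same checkpoint would have been epoch=0-step=31.ckpt.
checkpoint = ModelCheckpoint(filename="{epoch}-{step}")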

Labels
breaking change (Includes a breaking change), bug (Something isn't working), checkpointing (Related to checkpointing), priority: 0 (High priority task), progress tracking (internal) (Related to the progress tracking dataclasses), ready (PRs ready to be merged)
Development
Successfully merging this pull request may close these issues: global_step/current_epoch issues
7 participants