Resuming training resets the logged step number #12274

Closed
eladsegal opened this issue Mar 9, 2022 · 7 comments · Fixed by #13467

eladsegal commented Mar 9, 2022

🐛 Bug

The change introduced in #11805 causes the logged step number to be reset when training is resumed from a checkpoint.
https://github.com/PyTorchLightning/pytorch-lightning/blob/49a4a36ad45b937dd0124ecfb08eb7400dbf3950/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L122

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
        callbacks=ModelCheckpoint(dirpath="checkpoints", save_top_k=-1, filename="{epoch}", save_on_train_epoch_end=False),
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=ckpt_path)


if __name__ == "__main__":
    run()  # first run: logs to version_0 and saves checkpoints/epoch=0.ckpt
    run("checkpoints/epoch=0.ckpt")  # resume from the epoch-0 checkpoint: logs to version_1

The script creates two TensorBoard log versions:

  • version_0: steps 0 to 63
  • version_1: steps 0 to 31

Expected behavior

  • version_1: steps 31 to 63

This was the behavior before #11805.
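
For anyone reproducing this, the logged step ranges can also be checked without the TensorBoard UI by reading the event files directly. A rough sketch, assuming the default lightning_logs/version_* layout under the working directory and that the tensorboard package is installed:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

for version in ("version_0", "version_1"):
    acc = EventAccumulator(f"lightning_logs/{version}")
    acc.Reload()
    steps = [event.step for event in acc.Scalars("train_loss")]  # steps at which train_loss was logged
    print(version, "logged steps", min(steps), "to", max(steps))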

Environment

  • PyTorch Lightning Version: master (49a4a36)
  • Fault-tolerant training is off (PL_FAULT_TOLERANT_TRAINING=0)

cc @tchaton @rohitgr7 @akihironitta @awaelchli @ananthsub @ninginthecloud @carmocca

@ananthsub added the progress tracking (internal) and checkpointing labels on Mar 9, 2022
@carmocca self-assigned this on Mar 9, 2022
@carmocca added the priority: 0 (high priority task) label on Mar 9, 2022
@carmocca modified the milestones: 1.6 → 1.6.x on Mar 24, 2022

toriving commented Apr 18, 2022

Any progress on this issue?
Or does a workaround exist?

@ZENGYIMING-EAMON

Same bug here on PL 1.6.1. Any progress on this issue?

@ZENGYIMING-EAMON

Or is there a hack to work around it?

@rohitgr7

Just wondering: could you point _batches_that_stepped to global_step / number of optimizers?

@ZENGYIMING-EAMON

I don't know how to achieve this. What do you mean by global_step / number of optimizers, and why should _batches_that_stepped be pointed to it?


rbregier commented May 9, 2022

Hi, this workaround seems to work for my use case:

# Read the saved global step from the checkpoint and seed the logging counter
# with it before resuming (args.ckpt_path, experiment and datamodule come from
# the surrounding training script).
checkpoint = torch.load(args.ckpt_path, map_location='cpu')
global_step_offset = checkpoint["global_step"]
trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
del checkpoint
trainer.fit(experiment, datamodule=datamodule, ckpt_path=args.ckpt_path)
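
If loading the checkpoint a second time is undesirable, the same idea could be wrapped in a callback. This is only a sketch, not an official API: the RestoreLoggedStep name is made up here, it assumes a single optimizer (per the global_step / number of optimizers comment above), and it relies on the checkpoint state already being restored by the time on_train_start fires. It still touches the private _batches_that_stepped attribute:

from pytorch_lightning.callbacks import Callback


class RestoreLoggedStep(Callback):
    def on_train_start(self, trainer, pl_module):
        # Seed the logging counter from the restored global_step so resumed
        # runs keep counting from where the checkpoint left off.
        trainer.fit_loop.epoch_loop._batches_that_stepped = trainer.global_step

Passing RestoreLoggedStep() in the Trainer's callbacks keeps the logged step continuous on resume without a second torch.load, with the same caveats as the snippet above.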

@rohitgr7

cc: @carmocca wdyt?
