
TensorBoardLogger and WandbLogger do not track global_step when resuming training from a checkpoint (both manually, and with fault tolerant) #13163

Closed
mirandrom opened this issue May 26, 2022 · 8 comments
Labels: needs triage

Comments

@mirandrom

🐛 Bug

When resuming model training from a checkpoint, the TensorBoardLogger and WandbLogger log metrics as if the global_step had been reset to 0 (even though the global_step in the trainer and pl_module is accurate). This happens both when manually resuming training from a checkpoint with the ckpt_path argument of Trainer.fit, and when doing fault-tolerant training as shown here: https://github.com/PyTorchLightning/pytorch-lightning/blob/1.6.3/pl_examples/fault_tolerant/automatic.py

To Reproduce

I've adapted the script linked above to test this, running pytorch-lightning v1.6.3:

import os
import random as python_random
from argparse import ArgumentParser
from time import sleep

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb


class RandomGetItemDataset(Dataset):
    """A dataset with random elements generated using global rng from torch, numpy and python."""

    def __init__(self, length, size):
        self.size = size
        self.len = length

    def __getitem__(self, index):
        t = torch.rand(self.size)
        n = torch.from_numpy(np.random.rand(self.size))
        p = torch.tensor([python_random.random() for _ in range(self.size)])
        sample = (index + (t + n + p) / 10).float()
        return sample

    def __len__(self):
        return self.len


class SimpleMLP(LightningModule):
    def __init__(self, fail_on_step: int = -1):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)
        self.seen_batches = []
        self.fail_on_step = fail_on_step

    def training_step(self, batch, batch_idx):
        if self.global_step == self.fail_on_step:
            log.info(
                f"READY TO BE KILLED WITH SIGTERM SIGNAL. " f"Run `kill -SIGTERM {os.getpid()}` in another terminal."
            )
            # this line is used to wait for you to send the signal to exit gracefully.
            while not self.trainer._terminate_gracefully:
                sleep(0.1)
        batch = batch["data"] if isinstance(batch, dict) else batch
        self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
        loss = sum(self.layer(b).sum() for b in batch)
        self.log("loss", loss.item())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomGetItemDataset(3, 1))


def _run_training(default_root_dir=".", max_epochs=3, fail_on_step: int = -1, ckpt_path=None, logger=True):
    model = SimpleMLP(fail_on_step=fail_on_step)
    trainer = Trainer(default_root_dir=default_root_dir, max_epochs=max_epochs,
                      logger=logger, log_every_n_steps=1)
    trainer.fit(model, ckpt_path=ckpt_path)
    wandb.finish()
    return model.seen_batches, model.parameters()


def main(args):
    seed_everything(42)
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "automatic"  # activate automatic fault-tolerant training

    ckpt_path = ".pl_auto_save.ckpt"
    auto_restart_ckpt_path_exists = os.path.exists(ckpt_path)

    if args.emulate_kill_signal:
        fail_on_step = -1 if auto_restart_ckpt_path_exists else 4
        completed_batches = 4 if auto_restart_ckpt_path_exists else 5
    else:
        fail_on_step = -1
        completed_batches = 9

    if args.use_tb:
        logger = True
    else:
        logger = WandbLogger(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run,
            id=args.wandb_run,
        )

    complete_batches, weights = _run_training(fail_on_step=fail_on_step, logger=logger)
    assert len(complete_batches) == completed_batches

    if not auto_restart_ckpt_path_exists and args.emulate_kill_signal:
        assert os.path.exists(ckpt_path)

    if auto_restart_ckpt_path_exists or not args.emulate_kill_signal:
        log.info([w for w in weights])


if __name__ == "__main__":
    parser = ArgumentParser(description="Fault Tolerant Under Signal Example")
    parser.add_argument(
        "--emulate_kill_signal",
        action="store_true",
        help="Whether you should gracefully kill the process with a `SIGTERM` signal.",
    )
    parser.add_argument(
        "--use_tb",
        action="store_true",
        help="Use TensorBoard instead of WandB.",
    )
    parser.add_argument(
        "-e", "--wandb_entity",
        type=str,
        default=None,
        help="Wandb entity.",
    )
    parser.add_argument(
        "-p", "--wandb_project",
        type=str,
        default=None,
        help="Wandb project.",
    )
    parser.add_argument(
        "-r", "--wandb_run",
        type=str,
        default=None,
        help="Wandb run.",
    )
    main(parser.parse_args())

With TensorBoard, running these commands:
python automatic.py --use_tb (without fault)
python automatic.py --use_tb --emulate_kill_signal (with fault)
python automatic.py --use_tb --emulate_kill_signal (resume from fault)

Results in the following, where the epoch is properly logged, but not the step:

[TensorBoard screenshot]

With wandb, running these commands:
python automatic.py -e [wandb_entity] -p [wandb_project] -r no_fault (without fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal (with fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal (resume from fault)

Results in the following, where the step is properly logged (because I'm only logging once per step, see #13016), but the global_step is reset.

[wandb screenshot]

Expected behavior

The trainer/global_step in WandbLogger and step in TensorBoardLogger should properly reflect the global_step state of the trainer/pl_module when resuming from checkpoints (either manually or automatically with fault-tolerant training).

Environment

* CUDA:
        - GPU:
        - available:         False
        - version:           10.2
* Packages:
        - numpy:             1.22.4
        - pyTorch_debug:     False
        - pyTorch_version:   1.11.0+cu102
        - pytorch-lightning: 1.6.3
        - tqdm:              4.64.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.10.4
        - version:           #171-Ubuntu SMP Fri Nov 5 11:55:11 UTC 2021
@mirandrom added the needs triage label May 26, 2022
@mirandrom changed the title from "TensorBoardLogger and WandbLogger do not track global_step when resuming from a checkpoint" to "TensorBoardLogger and WandbLogger do not track global_step when resuming training from a checkpoint (both manually, and with fault tolerant)" May 28, 2022
@mirandrom
Author

Any updates on this? I am currently working around this by explicitly logging the global_step from the module's attribute, e.g. https://github.com/mirandrom/lightning-transformer-pretraining/blob/72491177a13482b6b7e3e0e38f420c79e950c55a/ltp/hf_mlm/model.py#L124
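For reference, here is a minimal sketch of that workaround (the module and metric names below are illustrative, not taken from the linked repo): log the trainer's global_step explicitly as a metric, so charts can be re-keyed to it even when the logger's own step counter restarts after a resume.

import torch
from pytorch_lightning import LightningModule


class WorkaroundModule(LightningModule):
    """Illustrative module: logs the real global_step alongside each metric so
    that, in wandb, the chart's x-axis can be set to 'trainer/global_step'."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Log the module's own global_step as a regular metric on every step.
        self.log("trainer/global_step", float(self.global_step))
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)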

@samgelman

I am also experiencing this problem. Here's a plot of the Weights & Biases step (incremented on each call to .log()) vs. the trainer's global_step.
[wandb screenshot: logged step vs. trainer global_step]

@mirandrom
Author

Related issues I missed when first opening this issue: #12991, #12274, #13069

@manangoel99
Contributor

manangoel99 commented Jun 24, 2022

Hey guys! Engineer from W&B here! Sorry I'm a little late, but I managed to track this down to one line:

https://github.com/Lightning-AI/lightning/blob/5572797bc80b564286f111861e3d4b408344ae84/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L102

The solution to this is to set step = self.trainer.global_step

I'm not entirely sure whether the current behaviour is intentional, but I've pushed the fix anyway. It has caused some tests to break, so I'm looking into those; in the meantime this change should get things up and running.
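For anyone who needs a stopgap before that fix lands, here is a hedged sketch of what that one-line change amounts to, expressed as a user-side monkey-patch (the module path and log_metrics signature are assumed from pytorch-lightning 1.6.x; this is not the upstream PR itself). Apply it before building the Trainer.

from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector

_original_log_metrics = LoggerConnector.log_metrics


def _patched_log_metrics(self, metrics, step=None):
    # When no explicit step is passed, fall back to the trainer's restored
    # global_step instead of the internal counter that restarts at 0 on resume.
    if step is None:
        step = self.trainer.global_step
    return _original_log_metrics(self, metrics, step=step)


LoggerConnector.log_metrics = _patched_log_metrics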

@manangoel99
Contributor

Also, adding resume=True as an argument to your WandbLogger initialization might give you much cleaner-looking plots!
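A short sketch of that suggestion (the project and id values are placeholders; resume is forwarded to wandb.init through WandbLogger's extra keyword arguments), so a resumed training job continues the same wandb run instead of starting a new one:

from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(
    project="my-project",  # placeholder
    id="my-run",           # reuse the same run id across restarts
    resume=True,           # forwarded to wandb.init to continue the existing run
)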

@BraveDistribution

Any updates on this? I am currently working around this by explicitly logging the global_step from the module's attribute, e.g. https://github.com/mirandrom/lightning-transformer-pretraining/blob/72491177a13482b6b7e3e0e38f420c79e950c55a/ltp/hf_mlm/model.py#L124

I am verifying this workaround, will let you know whether it's enough.

@BraveDistribution

Workaround didn't help. I still get the same ridiculous charts:

[wandb screenshot]

@carmocca
Contributor

Duplicate of #12274

@carmocca marked this as a duplicate of #12274 on Jul 11, 2022
@carmocca closed this as not planned (duplicate) on Jul 11, 2022