
MisconfigurationException "MultiStepLR" with torch 2.0 #15912

Closed

teamclouday opened this issue Dec 5, 2022 · 12 comments · Fixed by #15940

Comments


teamclouday commented Dec 5, 2022

Bug description

Got the following error after upgrading to torch 2.0:

lightning_lite.utilities.exceptions.MisconfigurationException: The provided lr scheduler MultiStepLR doesn't follow PyTorch's LRScheduler API. You should override the LightningModule.lr_scheduler_step hook with your own logic if you are using a custom LR scheduler.

How to reproduce the bug

https://colab.research.google.com/drive/1Nks7DPZlrxUDW0UGa_2-17uOMbH0aWle?usp=sharing

Error messages and logs


Environment

Current environment

More info

The error came from

https://github.com/Lightning-AI/lightning/blob/6cc493360d9dfdd132665343d6611e66e9760885/src/pytorch_lightning/core/optimizer.py#L351

LRSchedulerTypeTuple is using torch.optim.lr_scheduler._LRScheduler to identify torch schedulers

https://github.com/Lightning-AI/lightning/blob/6cc493360d9dfdd132665343d6611e66e9760885/src/pytorch_lightning/utilities/types.py#L114

However, PyTorch has changed _LRScheduler to be a subclass of LRScheduler (link), and MultiStepLR is a subclass of LRScheduler as well (link).

A fix could be to change LRSchedulerTypeTuple to use torch.optim.lr_scheduler.LRScheduler instead.
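
For illustration, a minimal sketch of that idea (the LRSchedulerBase name and the getattr fallback below are only illustrative, not necessarily how Lightning implements it):

import torch
from torch.optim.lr_scheduler import MultiStepLR

# Prefer the public LRScheduler base class (PyTorch >= 2.0) and fall back to
# the private _LRScheduler on older PyTorch versions.
LRSchedulerBase = getattr(
    torch.optim.lr_scheduler, "LRScheduler", torch.optim.lr_scheduler._LRScheduler
)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 20])

# Passes on both old and new PyTorch, whereas checking against _LRScheduler
# alone is what raises the MisconfigurationException reported above.
assert isinstance(scheduler, LRSchedulerBase)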


Updated by @akihironitta: This issue is related to #15894.

@teamclouday teamclouday added the needs triage label Dec 5, 2022

lantiga commented Dec 7, 2022

Thank you @teamclouday, this is something I bumped into as well. I'll look into it shortly

@lantiga lantiga self-assigned this Dec 7, 2022
@lantiga lantiga added the lr scheduler and torch.compile labels and removed the needs triage label Dec 7, 2022
@teamclouday (Author)

Thank you for resolving this! Closing


mviti commented Dec 21, 2022

The same issue exists for CosineAnnealingLR

lightning_lite.utilities.exceptions.MisconfigurationException: The provided lr scheduler CosineAnnealingLR doesn't follow PyTorch's LRScheduler API. You should override the LightningModule.lr_scheduler_step hook with your own logic if you are using a custom LR scheduler.

@mactavish91

@lantiga
Hello, when I use the lr schedulers from the transformers library, I get the same error. Do you know the reason?

from transformers import (
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)


lantiga commented Jan 23, 2023

@mviti seeing this now
CosineAnnealingLR is working for me on Lightning 1.9 and PyTorch nightly; please let me know if this is not the case for you.

If you're a Lightning Lite user, the project has evolved into Fabric since 1.9.0: https://pytorch-lightning.readthedocs.io/en/stable/fabric/fabric.html
Converting your code will be very quick. We'll roll out tutorials etc in the next few weeks.
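
For reference, here is a rough sketch of what a converted training loop can look like (based on the Fabric docs linked above; the tiny model and data below are just placeholders, not anyone's actual code):

import torch
from lightning.fabric import Fabric  # Fabric replaces Lightning Lite as of 1.9.0

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)

# Fabric wraps the model, optimizer, and dataloader for the chosen device/strategy
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).mean()  # placeholder loss for the sketch
    fabric.backward(loss)       # replaces loss.backward()
    optimizer.step()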


lantiga commented Jan 23, 2023

@mactavish91 just to understand the context, are you using transformers with the lightning Trainer? Can you post minimal code to reproduce?

The reason these issues happen is that PyTorch 2.0 changed the hierarchy of scheduler classes. They were all deriving from _LRScheduler before, but they are deriving from LRScheduler now. Since somewhere in our code we were checking whether something was an instance of _LRScheduler, that check started to fail.
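
To make that concrete, a small sketch of the failing check on torch 2.0 (LambdaLR is what the transformers warmup helpers return, per the error message above):

import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler, _LRScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

# On torch 2.0 the built-in schedulers derive from the public LRScheduler class,
# so a check against the private _LRScheduler no longer matches them.
print(isinstance(scheduler, LRScheduler))   # True
print(isinstance(scheduler, _LRScheduler))  # False -> triggers the exception above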


mactavish91 commented Feb 10, 2023

@lantiga Yes, the code is

from transformers import (
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
)

The error message is:
lightning_fabric.utilities.exceptions.MisconfigurationException: The provided lr scheduler LambdaLR doesn't follow PyTorch's LRScheduler API. You should override the LightningModule.lr_scheduler_step hook with your own logic if you are using a custom LR scheduler.


mactavish91 commented Feb 10, 2023

When I use lr_scheduler_step rather than configuring the scheduler in configure_optimizers, the lr remains unchanged at 0.0, and I don't know why.

def set_schedule(pl_module):
    optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
    scheduler = get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
        lr_end=end_lr,
        power=decay_power,
    )

    sched = {"scheduler": scheduler, "interval": "step"}

    return (
        [optimizer],
        [sched],
    )

def configure_optimizers(self):
    return meter_utils.set_schedule(self)

def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    scheduler.step()


mimbres commented Feb 17, 2023

I have the same issue with torch==1.14 from the NGC Docker image nvcr.io/nvidia/pytorch:23.01-py3.


@TalhaUsuf

Any further update on this issue? I am using torch==1.14.0a0+44dac51 from the NGC PyTorch image and am still facing this issue.

@KevinSONG729

@mviti The same fix works for CosineAnnealingLR. Modify the file optimizer.py in pytorch_lightning (my version is 1.8.1):

if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler) and not is_overridden("lr_scheduler_step", model):
    raise MisconfigurationException(
        f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
        " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
        " you are using a custom LR scheduler."
    )


Zialo commented Jun 23, 2023

I have the same problem with OneCycleLR.

This is part of my code:

def configure_optimizers(self):
    # Get Adam optimizer 
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 
    # Set OneCycleLR policy 
    lr_scheduler = { 
        'scheduler': torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.lr, total_steps=self.total_iterations), 
        'interval': 'step', 'frequency': 1, 'name': 'lr_logger' } 
    return [optimizer], [lr_scheduler]

And when I run the training code block, I get the following error:
MisconfigurationException: The provided lr scheduler OneCycleLR doesn't follow PyTorch's LRScheduler API. You should override the LightningModule.lr_scheduler_step hook with your own logic if you are using a custom LR scheduler.

My PyTorch version can be seen here:
Name: torch
Version: 2.0.1+cu117
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: c:\users\hodei.zia\anaconda3\lib\site-packages
Requires: networkx, typing-extensions, jinja2, sympy, filelock
Required-by: torchvision, torchaudio, torchmetrics, torch-lr-finder, pytorch-lightning, pytorch-ignite

How can I fix this issue?

EDIT: I changed LRSchedulerTypeTuple in lightning/src/pytorch_lightning/utilities/types.py to use torch.optim.lr_scheduler.LRScheduler, but I still get the same issue.

What can I do?


8 participants