
Support non-conventional optimizers #16143

Open
simonpokorny opened this issue Dec 20, 2022 · 6 comments
Labels: design (Includes a design discussion), feature (Is an improvement or enhancement), optimizer

Comments

@simonpokorny

simonpokorny commented Dec 20, 2022

Bug description

I turned off automatic optimization because I am using the SAM optimizer (https://github.com/davda54/sam). After that, the trainer's global_step no longer updates on each training step, so the checkpoint callback is never called even though it is passed to the Trainer.

Used callback:

pl.callbacks.ModelCheckpoint(save_weights_only=True, save_top_k=3, monitor="val_acc", mode="max", save_on_train_epoch_end=False)

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment

- PyTorch Lightning Version: 1.8.4
- PyTorch Version: 1.13
- Python version: 3.9

More info

No response

cc @tchaton @justusschock @awaelchli @Borda @carmocca

simonpokorny added the "needs triage" label on Dec 20, 2022
carmocca added this to the v1.8.x milestone on Dec 21, 2022
carmocca added the "bug" and "loops" labels and removed the "needs triage" label on Dec 21, 2022
carmocca self-assigned this on Dec 21, 2022
carmocca added the "waiting on author" label on Dec 21, 2022
@carmocca
Contributor

Can you provide more details? This example shows it working:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        print(self.trainer.global_step)
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

@simonpokorny
Author

Thanks, sure.

I used your example with the custom optimizer (see below), and the global step is not increasing:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from classifiers.sam import SAM


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(low=0, high=2, size=(length,))  # one label per sample

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 2)
        self.automatic_optimization = False
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):

        data, labels = batch

        opt = self.optimizers()

        # first forward-backward pass
        pred = self.model(data)
        loss_1 = self.loss_fn(pred, labels)
        self.manual_backward(loss_1)
        opt.first_step(zero_grad=True)

        # second forward-backward pass
        pred = self.model(data)
        loss_2 = self.loss_fn(pred, labels)
        self.manual_backward(loss_2)
        opt.second_step(zero_grad=True)

        print(self.trainer.global_step)
        return loss_2

    def configure_optimizers(self):
        base_optimizer = torch.optim.Adam
        optimizer = SAM(self.parameters(), base_optimizer, rho=1, adaptive=True, lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


def run():
    train_data = DataLoader(RandomDataset(size=32, length=64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Where the SAM optimizer is from https://github.com/davda54/sam.

class SAM(torch.optim.Optimizer):
    """
    SAM Optimizer
    https://github.com/davda54/sam
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][
            0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

@carmocca
Contributor

carmocca commented Dec 21, 2022

Okay. This happens because we assume there will be an optimizer.step() call, which is what we wrap to inject the strategy-specific logic (e.g. DDP): https://github.com/Lightning-AI/lightning/blob/50331e08e111d6b9ebb25a21a86b7170b46c5f1f/src/pytorch_lightning/core/optimizer.py#L101-L173

The call chain is LightningModule.training_step() -> _LightningOptimizer.step() -> Strategy.optimizer_step() -> PrecisionPlugin.optimizer_step() -> Optimizer.step()

Your use of the SAM optimizer violates this assumption: you are calling two different step methods ({first,second}_step), which are not wrapped like .step() is. It's also not clear to me whether you would expect global_step to increase after each of them or only after second_step().
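
For illustration only (this is not the actual Lightning source, just a toy sketch of the wrapping idea described above), the wrapper can be pictured as a proxy whose step() runs the injected logic and advances the progress counter, while every other attribute is forwarded untouched:

class WrappedOptimizer:  # toy stand-in for the wrapper, not the real implementation
    def __init__(self, optimizer, on_step):
        self._optimizer = optimizer
        self._on_step = on_step  # e.g. strategy/precision hooks + "global_step += 1"

    def step(self, closure=None):
        # only this entry point runs the injected logic and progress tracking
        output = self._optimizer.step(closure)
        self._on_step()
        return output

    def __getattr__(self, name):
        # first_step(), second_step(), update(), ... fall through untouched,
        # so calling them directly never reaches the tracking above
        return getattr(self._optimizer, name)

In this picture, opt.first_step() and opt.second_step() reach the underlying SAM optimizer through __getattr__ and therefore never advance global_step.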

To resolve this, we would need some mechanism to indicate what method we should wrap.
cc @awaelchli @justusschock in case they have suggestions in this regard.

Another example is https://github.com/ludwigwinkler/JaxLightning/blob/8585863be636152b6adba77a0436ff7509fb92f3/BNN/JaxLightning_BNN.py#L215-L217 (cc @ludwigwinkler), which suffers from the same issue because the Jax optimizer uses .update() instead of .step().

carmocca added the "design" and "optimizer" labels and removed the "waiting on author" and "loops" labels on Dec 21, 2022
@simonpokorny
Author

The SAM optimizer training step can be rewritten in the classical form with a single closure-based step() call:

    def training_step(self, batch, batch_idx):

        data, labels = batch
        opt = self.optimizers()

        def closure():
            loss = self.loss_fn(self.model(data), labels)
            loss.backward()
            return loss

        loss = self.loss_fn(self.model(data), labels)
        loss.backward()
        opt.step(closure)
        opt.zero_grad()

        print(self.trainer.global_step)
        return loss

After that, PL is able to wrap the .step() call and self.trainer.global_step increases. This works because SAM's step(closure) internally calls first_step(), the closure, and second_step() (see the SAM code above), so the wrapped .step() is the single entry point.

carmocca changed the milestone from v1.8.x to future on Dec 21, 2022
carmocca removed the "bug" label on Mar 17, 2023
carmocca removed their assignment on Mar 17, 2023
carmocca added the "bug" and "feature" labels and removed the "bug" label on Mar 17, 2023
carmocca changed the title from "global step is not updating if the automatic optimisation is not enable.." to "Support non-conventional optimizers" on Mar 17, 2023
@awaelchli
Contributor

awaelchli commented Sep 20, 2023

If I understand this correctly, my proposal is to have a check in our LightningOptimizer wrapper that the step method is available. If not, raise an error suggesting that the user set optimizer.step = optimizer.real_step_method, e.g. in the configure_optimizers hook, to have it supported in Lightning. IMO this is the easiest approach and doesn't require new APIs.
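
As a rough sketch of that proposal (MyUpdateOnlyOptimizer is a made-up name standing in for any optimizer whose update method is not called step, e.g. a Jax-style .update()):

    def configure_optimizers(self):
        optimizer = MyUpdateOnlyOptimizer(self.parameters(), lr=1e-3)  # hypothetical optimizer
        # alias the real update method onto .step so that Lightning's wrapper
        # (and therefore global_step and checkpointing) can pick it up
        optimizer.step = optimizer.update
        return optimizer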

@carmocca
Contributor

The suggestion to

have a check in our LightningOptimizer wrapper that the step method is available

is not foolproof: the SAM optimizer shown above offers first_step, second_step, and step. If the user didn't know about this limitation and called first_step and second_step, they would hit this issue, but such a check wouldn't trigger because the optimizer also defines a step method.

But I don't have a better suggestion that doesn't involve a complex solution, such as wrapping all optimizer methods and checking whether the parameters changed.
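
For completeness, that complex solution could look roughly like the sketch below (purely illustrative, not an existing or proposed Lightning API): wrap every public optimizer method and treat any call that mutates the parameters as a step.

import torch

def wrap_and_detect_steps(optimizer, on_param_update):
    # Illustrative only: intercept every public method and, after each call,
    # compare the parameters against a snapshot taken before the call.
    params = [p for group in optimizer.param_groups for p in group["params"]]

    def make_wrapper(fn):
        def wrapper(*args, **kwargs):
            before = [p.detach().clone() for p in params]
            output = fn(*args, **kwargs)
            if any(not torch.equal(b, p) for b, p in zip(before, params)):
                on_param_update()  # e.g. advance global_step
            return output
        return wrapper

    for name in dir(optimizer):
        attr = getattr(optimizer, name)
        if callable(attr) and not name.startswith("_"):
            setattr(optimizer, name, make_wrapper(attr))
    return optimizer

Besides the cost of snapshotting all parameters on every call, this would count both first_step and second_step of SAM as updates (first_step perturbs the weights, second_step restores and then updates them), which circles back to the ambiguity above about when global_step should advance.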
