Support non-conventional optimizers #16143
Can you provide more details? This example shows it working:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        print(self.trainer.global_step)
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```
Thanks, for sure. I used your example with the custom optimizer (see below) and the global step is not increasing:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

from classifiers.sam import SAM


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(low=0, high=2, size=(length,))

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 2)
        self.automatic_optimization = False
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        data, labels = batch
        opt = self.optimizers()

        # first forward-backward pass
        pred = self.model(data)
        loss_1 = self.loss_fn(pred, labels)
        self.manual_backward(loss_1)
        opt.first_step(zero_grad=True)

        # second forward-backward pass
        pred = self.model(data)
        loss_2 = self.loss_fn(pred, labels)
        self.manual_backward(loss_2)
        opt.second_step(zero_grad=True)

        print(self.trainer.global_step)
        return loss_2

    def configure_optimizers(self):
        base_optimizer = torch.optim.Adam
        optimizer = SAM(self.parameters(), base_optimizer, rho=1, adaptive=True, lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


def run():
    train_data = DataLoader(RandomDataset(size=32, length=64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```

Where the SAM optimizer is from https://github.com/davda54/sam:

```python
import torch


class SAM(torch.optim.Optimizer):
    """
    SAM Optimizer
    https://github.com/davda54/sam
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
```
Okay. This happens because we assume there will be an `optimizer.step()` call: the call chain goes through our `LightningOptimizer` wrapper, and it is that wrapped `step()` which advances `trainer.global_step`. Your use of the optimizer's custom `first_step`/`second_step` methods bypasses that wrapper entirely, so the step count never moves. To resolve this, we would need some mechanism to indicate what method we should wrap. Another example of this issue is https://github.com/ludwigwinkler/JaxLightning/blob/8585863be636152b6adba77a0436ff7509fb92f3/BNN/JaxLightning_BNN.py#L215-L217 (cc @ludwigwinkler), which suffers from the same problem because the Jax optimizer performs its update through a method other than `step()`.
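For illustration, here is a minimal sketch of the general idea described above; this is not Lightning's actual implementation, and the names `CountingOptimizer` and `global_step` below are made up for the example.

```python
# Minimal sketch, NOT Lightning's real wrapper: the step counter is tied to
# `step()`, so methods like SAM's `first_step`/`second_step` never touch it.
import torch


class CountingOptimizer:
    """Hypothetical wrapper that binds a step counter to `step()` only."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.global_step = 0  # stands in for trainer.global_step

    def step(self, closure=None):
        out = self.optimizer.step(closure) if closure is not None else self.optimizer.step()
        self.global_step += 1  # only incremented here
        return out

    def __getattr__(self, name):
        # everything else (first_step, second_step, zero_grad, ...) is passed
        # straight through, so calling those does not advance the counter
        return getattr(self.optimizer, name)
```

Calling `wrapped.first_step(...)` resolves through `__getattr__` directly to the underlying SAM instance, which is why the counter (and, in Lightning, `trainer.global_step`) never moves.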
The closure-based usage works:

```python
def training_step(self, batch, batch_idx):
    data, labels = batch
    opt = self.optimizers()

    def closure():
        loss = self.loss_fn(self.model(data), labels)
        loss.backward()
        return loss

    loss = self.loss_fn(self.model(data), labels)
    loss.backward()
    opt.step(closure)
    opt.zero_grad()

    print(self.trainer.global_step)
    return loss
```

After that, PL is able to wrap the `step` call.
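One additional detail worth noting with this version (my own observation, not discussed above): because `automatic_optimization` is off, the `StepLR` returned from `configure_optimizers` is not stepped automatically either, so it would need to be stepped by hand, for example:

```python
def on_train_epoch_end(self):
    # with automatic_optimization = False, Lightning does not call the
    # scheduler for you; step it manually (here once per epoch)
    sch = self.lr_schedulers()
    sch.step()
```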
If I understand this here correctly, my proposal is to have a check in our LightningOptimizer wrapper that the step method is available. If not, raise an error with a suggestion of what the user should do instead.
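A rough sketch of what such a check might look like (hypothetical code, not an actual Lightning patch; the helper name `_validate_custom_optimizer` is made up):

```python
# Hypothetical validation for the wrapper: require a callable `step()` and
# hint at the closure-based usage in the error message.
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _validate_custom_optimizer(optimizer):
    if not callable(getattr(optimizer, "step", None)):
        raise MisconfigurationException(
            f"`{type(optimizer).__name__}` does not define a callable `step()` method. "
            "Lightning tracks `global_step` by wrapping `optimizer.step()`; either add a "
            "`step()` method or route your update logic through `optimizer.step(closure)`."
        )
```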
That suggestion is not foolproof: the SAM optimizer shown above does offer a `step()` method, it just isn't the method being called in the first snippet. But I don't have a better suggestion that doesn't involve a complex solution such as wrapping all optimizer methods and checking if parameters changed.
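For completeness, a sketch of the "complex solution" mentioned above, i.e. treating any parameter change after an arbitrary optimizer method call as a step (purely illustrative; `_params_changed` is a made-up helper):

```python
# Illustrative only: snapshot parameters before an optimizer method runs and
# treat any in-place change afterwards as an "optimizer step".
import torch


def _params_changed(optimizer, method, *args, **kwargs):
    before = [p.detach().clone() for group in optimizer.param_groups for p in group["params"]]
    result = method(*args, **kwargs)
    after = [p for group in optimizer.param_groups for p in group["params"]]
    changed = any(not torch.equal(b, a) for b, a in zip(before, after))
    return changed, result
```

Wrapping every public optimizer method with such a check would catch `first_step`/`second_step`-style APIs, but at the cost of an extra parameter copy per call, which is presumably why it is considered too complex.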
Bug description
I turned off automatic optimization because I am using the SAM optimizer (https://github.com/davda54/sam). After that, the global_step of the trainer is not updated on each training step, and therefore the checkpoint callback is never triggered even though it is passed to the trainer.
Used callback:
pl.callbacks.ModelCheckpoint(save_weights_only=True, save_top_k=3, monitor="val_acc", mode="max", save_on_train_epoch_end=False)
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
More info
No response
cc @tchaton @justusschock @awaelchli @Borda @carmocca