Regression: CSVLogger not working on version 2.2.0 #19432
Comments
@ramon-adalia-lmd Would you be able to provide a code example that produces this error?
Upon further testing, it does not seem to be specific to 2.2.0, but the bug is still there. Interestingly, the bug happens every other time I run the code. Here is an example script that triggers it:

    import torch
    from lightning import LightningModule
    from lightning import Trainer
    from lightning.pytorch.loggers import CSVLogger
    from torchmetrics import MeanSquaredError


    class Model(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(100, 1)
            self.train_mse = MeanSquaredError()
            self.val_mse = MeanSquaredError()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = torch.nn.functional.mse_loss(y_hat, y)
            self.train_mse.update(y_hat, y)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            self.val_mse.update(y_hat, y)

        def on_train_epoch_end(self):
            self.log("train_mse", self.train_mse.compute(), prog_bar=True)
            self.train_mse.reset()

        def on_validation_epoch_end(self):
            self.log("val_mse", self.val_mse.compute(), prog_bar=True)
            self.val_mse.reset()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())


    def main():
        X = torch.randn(100, 100)
        y = torch.randn(100, 1)
        Z = torch.randn(100, 100)
        t = torch.randn(100, 1)
        train_loader = torch.utils.data.DataLoader(list(zip(X, y)))
        val_loader = torch.utils.data.DataLoader(list(zip(Z, t)))
        model = Model()
        # The logger version is pinned to 0, so every run writes to the same directory.
        trainer = Trainer(
            max_epochs=5, logger=CSVLogger("test_logs", name="test", version=0)
        )
        trainer.fit(model, train_loader, val_loader)


    if __name__ == "__main__":
        main()

Run it the first time: works. The second time: error. The third time: works. And so on...
@ramon-adalia-lmd Ah ok, this is because you fixed the version to 0. The second time the script is executed, the file is already there, so the logger tries to append to it but sees different keys. In this case, I think the best we can do is delete the file at the start if it exists, since the user explicitly asks
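Until the logger handles this itself, one possible user-side workaround is to clear the previous run's directory before constructing the logger. The sketch below is not part of Lightning: `clear_stale_run` is a hypothetical helper, and it assumes the logger writes to `<save_dir>/<name>/version_<version>/`, which matches the `test_logs/test/version_0` layout implied by the script above.

```python
import shutil
from pathlib import Path


def clear_stale_run(save_dir: str, name: str, version: int) -> Path:
    """Delete a previous run's log directory so a pinned version starts clean.

    Hypothetical helper, not a Lightning API.
    Assumed layout: <save_dir>/<name>/version_<version>/
    """
    run_dir = Path(save_dir) / name / f"version_{version}"
    if run_dir.exists():
        # Removes the old metrics file along with anything else in the run directory.
        shutil.rmtree(run_dir)
    return run_dir


# Intended usage, before building the Trainer:
# clear_stale_run("test_logs", "test", 0)
# logger = CSVLogger("test_logs", name="test", version=0)
```

Note that this discards everything in the old run directory, not just the metrics file. If earlier logs matter, leaving `version` unset so that each run gets its own directory avoids the clash without deleting anything.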
Bug description
CSVLogger throws the following error when used in version 2.2.0:
Going back to 2.1.4 solves the issue.
What version are you seeing the problem on?
master
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
More info
No response
cc @Borda