From b9a8f74b2bd598efa61cd8989e2e84d22f227c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 8 Nov 2021 13:00:19 +0100 Subject: [PATCH] Fix pickling error with CSVLogger (#10388) * Don't store csv.Dictwriter in ExperimentWriter * Add test for pickle after .save() * Add entry in changelog --- CHANGELOG.md | 1 + pytorch_lightning/loggers/csv_logs.py | 6 +++--- tests/loggers/test_all.py | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8683520c8d57..1c5e5dd8141a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316)) - Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345)) - Fixed interception of `__init__` arguments for sub-classed DataLoader re-instantiation in Lite ([#10334](https://github.com/PyTorchLightning/pytorch-lightning/issues/10334)) +- Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388)) ## [1.5.0] - 2021-11-02 diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 77adfe551f72d..454a17905c529 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -95,9 +95,9 @@ def save(self) -> None: metrics_keys = list(last_m.keys()) with open(self.metrics_file_path, "w", newline="") as f: - self.writer = csv.DictWriter(f, fieldnames=metrics_keys) - self.writer.writeheader() - self.writer.writerows(self.metrics) + writer = csv.DictWriter(f, fieldnames=metrics_keys) + writer.writeheader() + writer.writerows(self.metrics) class CSVLogger(LightningLoggerBase): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 67838e219fcfb..271ffce811fe5 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -263,6 +263,10 @@ def _test_loggers_pickle(tmpdir, monkeypatch, logger_class): # the logger needs to remove it from the state before pickle _ = logger.experiment + # logger also has to avoid adding un-picklable attributes to self in .save + logger.log_metrics({"a": 1}) + logger.save() + # test pickling loggers pickle.dumps(logger)