diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index 0588a07e0a..bedcb2c7d2 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -13,6 +13,7 @@ import tempfile import unittest from typing import Callable, ClassVar, List +from unittest.mock import MagicMock, patch from securesystemslib.interface import import_rsa_privatekey_from_file from securesystemslib.signer import SSlibSigner @@ -316,6 +317,25 @@ def test_missing_targetinfo(self) -> None: # Get targetinfo for non-existing file self.assertIsNone(self.updater.get_targetinfo("file33.txt")) + @patch.object(os, "replace", wraps=os.replace) + @patch.object(os, "remove", wraps=os.remove) + def test_persist_metadata_fails( + self, wrapped_remove: MagicMock, wrapped_replace: MagicMock + ) -> None: + # Testing that when write succeeds (the file is created) and replace + # fails by throwing OSError, then the file will be deleted. + wrapped_replace.side_effect = OSError() + with self.assertRaises(OSError): + self.updater._persist_metadata("target", b"data") + + wrapped_replace.assert_called_once() + wrapped_remove.assert_called_once() + + # Assert that the created tempfile during writing is eventually deleted + # or in other words, there is no temporary file left in the folder. + for filename in os.listdir(self.updater._dir): + self.assertFalse(filename.startswith("tmp")) + if __name__ == "__main__": utils.configure_test_logging(sys.argv) diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index ceadf927ec..f4c5b24249 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -303,15 +303,24 @@ def _load_local_metadata(self, rolename: str) -> bytes: def _persist_metadata(self, rolename: str, data: bytes) -> None: """Write metadata to disk atomically to avoid data loss.""" - - # encode the rolename to avoid issues with e.g. path separators - encoded_name = parse.quote(rolename, "") - filename = os.path.join(self._dir, f"{encoded_name}.json") - with tempfile.NamedTemporaryFile( - dir=self._dir, delete=False - ) as temp_file: - temp_file.write(data) - os.replace(temp_file.name, filename) + try: + # encode the rolename to avoid issues with e.g. path separators + encoded_name = parse.quote(rolename, "") + filename = os.path.join(self._dir, f"{encoded_name}.json") + with tempfile.NamedTemporaryFile( + dir=self._dir, delete=False + ) as temp_file: + temp_file.write(data) + os.replace(temp_file.name, filename) + except OSError as e: + # remove tempfile if we managed to create one, + # then let the exception happen + if temp_file: + try: + os.remove(temp_file.name) + except FileNotFoundError: + pass + raise e def _load_root(self) -> None: """Load remote root metadata.