From 8e89a8d4f3a73a964a660a6cb0db65d36d8f0e1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Randy=20D=C3=B6ring?= <30527984+radoering@users.noreply.github.com>
Date: Sat, 7 Oct 2023 10:10:35 +0200
Subject: [PATCH] fix race condition to avoid downloading the same artifact in multiple threads and trying to store it in the same location of the artifact cache

---
 src/poetry/utils/cache.py           | 22 ++++++++++++-----
 tests/installation/test_executor.py | 14 ++++++-----
 tests/utils/test_cache.py           | 37 +++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/src/poetry/utils/cache.py b/src/poetry/utils/cache.py
index 5bd6a8dc7ef..99955e4a131 100644
--- a/src/poetry/utils/cache.py
+++ b/src/poetry/utils/cache.py
@@ -5,8 +5,10 @@
 import json
 import logging
 import shutil
+import threading
 import time
 
+from collections import defaultdict
 from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Any
@@ -188,6 +190,9 @@ def _deserialize(self, data_raw: bytes) -> CacheItem[T]:
 class ArtifactCache:
     def __init__(self, *, cache_dir: Path) -> None:
         self._cache_dir = cache_dir
+        self._archive_locks: defaultdict[Path, threading.Lock] = defaultdict(
+            threading.Lock
+        )
 
     def get_cache_directory_for_link(self, link: Link) -> Path:
         key_parts = {"url": link.url_without_fragment}
@@ -253,13 +258,18 @@ def get_cached_archive_for_link(
             cache_dir, strict=strict, filename=link.filename, env=env
         )
         if cached_archive is None and strict and download_func is not None:
-            cache_dir.mkdir(parents=True, exist_ok=True)
             cached_archive = cache_dir / link.filename
-            try:
-                download_func(link.url, cached_archive)
-            except BaseException:
-                cached_archive.unlink(missing_ok=True)
-                raise
+            with self._archive_locks[cached_archive]:
+                # Check again if the archive exists (under the lock) to avoid
+                # duplicate downloads because it may have already been downloaded
+                # by another thread in the meantime
+                if not cached_archive.exists():
+                    cache_dir.mkdir(parents=True, exist_ok=True)
+                    try:
+                        download_func(link.url, cached_archive)
+                    except BaseException:
+                        cached_archive.unlink(missing_ok=True)
+                        raise
 
         return cached_archive
 
diff --git a/tests/installation/test_executor.py b/tests/installation/test_executor.py
index 216eb2ab8f2..3b2aec4e315 100644
--- a/tests/installation/test_executor.py
+++ b/tests/installation/test_executor.py
@@ -582,14 +582,16 @@ def test_executor_should_delete_incomplete_downloads(
     pool: RepositoryPool,
     mock_file_downloads: None,
     env: MockEnv,
-    fixture_dir: FixtureDirGetter,
 ) -> None:
-    fixture = fixture_dir("distributions") / "demo-0.1.0-py2.py3-none-any.whl"
-    destination_fixture = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"
-    shutil.copyfile(str(fixture), str(destination_fixture))
+    cached_archive = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"
+
+    def download_fail(*_: Any) -> None:
+        cached_archive.touch()  # broken archive
+        raise Exception("Download error")
+
     mocker.patch(
         "poetry.installation.executor.Executor._download_archive",
-        side_effect=Exception("Download error"),
+        side_effect=download_fail,
     )
     mocker.patch(
         "poetry.utils.cache.ArtifactCache._get_cached_archive",
@@ -607,7 +609,7 @@ def test_executor_should_delete_incomplete_downloads(
     with pytest.raises(Exception, match="Download error"):
         executor._download(Install(Package("tomlkit", "0.5.3")))
 
-    assert not destination_fixture.exists()
+    assert not cached_archive.exists()
 
 
 def verify_installed_distribution(
diff --git a/tests/utils/test_cache.py b/tests/utils/test_cache.py
index af125d0c7d1..3e73e832ca1 100644
--- a/tests/utils/test_cache.py
+++ b/tests/utils/test_cache.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import concurrent.futures
 import shutil
+import traceback
 
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -322,6 +324,41 @@ def test_get_found_cached_archive_for_link(
     assert Path(cached) == archive
 
 
+def test_get_cached_archive_for_link_no_race_condition(
+    tmp_path: Path, mocker: MockerFixture
+) -> None:
+    cache = ArtifactCache(cache_dir=tmp_path)
+    link = Link("https://files.python-poetry.org/demo-0.1.0.tar.gz")
+
+    def replace_file(_: str, dest: Path) -> None:
+        dest.unlink(missing_ok=True)
+        # write some data (so it takes a while) to provoke possible race conditions
+        dest.write_text("a" * 2**20)
+
+    download_mock = mocker.Mock(side_effect=replace_file)
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        tasks = []
+        for _ in range(4):
+            tasks.append(
+                executor.submit(
+                    cache.get_cached_archive_for_link,
+                    link,
+                    strict=True,
+                    download_func=download_mock,
+                )
+            )
+        concurrent.futures.wait(tasks)
+        results = set()
+        for task in tasks:
+            try:
+                results.add(task.result())
+            except Exception:
+                pytest.fail(traceback.format_exc())
+        assert results == {cache.get_cache_directory_for_link(link) / link.filename}
+        download_mock.assert_called_once()
+
+
 def test_get_cached_archive_for_git() -> None:
     """Smoke test that checks that no assertion is raised."""
     cache = ArtifactCache(cache_dir=Path())