diff --git a/src/poetry/mixology/version_solver.py b/src/poetry/mixology/version_solver.py
index cb08ec61fe0..e3c122838d1 100644
--- a/src/poetry/mixology/version_solver.py
+++ b/src/poetry/mixology/version_solver.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import collections
 import functools
 import time
 
 from typing import TYPE_CHECKING
+from typing import Optional
+from typing import Tuple
 
 from poetry.core.packages.dependency import Dependency
 
@@ -28,6 +31,11 @@
 _conflict = object()
 
 
+DependencyCacheKey = Tuple[
+    str, Optional[str], Optional[str], Optional[str], Optional[str]
+]
+
+
 class DependencyCache:
     """
     A cache of the valid dependencies.
@@ -38,29 +46,40 @@ class DependencyCache:
     """
 
     def __init__(self, provider: Provider) -> None:
-        self.provider = provider
-        self.cache: dict[
-            tuple[str, str | None, str | None, str | None, str | None],
-            list[DependencyPackage],
-        ] = {}
-
-        self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
+        self._provider = provider
 
-    def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
-        key = (
-            dependency.complete_name,
-            dependency.source_type,
-            dependency.source_url,
-            dependency.source_reference,
-            dependency.source_subdirectory,
+        # self._cache maps a package name to a stack of cached package lists,
+        # ordered by the decision level which added them to the cache. This is
+        # done so that when backtracking we can maintain cache entries from
+        # previous decision levels, while clearing cache entries from only the
+        # rolled back levels.
+        #
+        # In order to maintain the integrity of the cache, `clear_level()`
+        # needs to be called in descending order as decision levels are
+        # backtracked so that the correct items can be popped from the stack.
+        self._cache: dict[DependencyCacheKey, list[list[DependencyPackage]]] = (
+            collections.defaultdict(list)
+        )
+        self._cached_dependencies_by_level: dict[int, list[DependencyCacheKey]] = (
+            collections.defaultdict(list)
         )
 
-        packages = self.cache.get(key)
+        self._search_for_cached = functools.lru_cache(maxsize=128)(self._search_for)
 
-        if packages:
+    def _search_for(
+        self,
+        dependency: Dependency,
+        key: DependencyCacheKey,
+    ) -> list[DependencyPackage]:
+        cache_entries = self._cache[key]
+        if cache_entries:
             packages = [
-                p for p in packages if dependency.constraint.allows(p.package.version)
+                p
+                for p in cache_entries[-1]
+                if dependency.constraint.allows(p.package.version)
             ]
+        else:
+            packages = None
 
         # provider.search_for() normally does not include pre-release packages
         # (unless requested), but will include them if there are no other
@@ -70,14 +89,35 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
         # nothing, we need to call provider.search_for() again as it may return
         # additional results this time.
         if not packages:
-            packages = self.provider.search_for(dependency)
+            packages = self._provider.search_for(dependency)
+
+        return packages
+
+    def search_for(
+        self,
+        dependency: Dependency,
+        decision_level: int,
+    ) -> list[DependencyPackage]:
+        key = (
+            dependency.complete_name,
+            dependency.source_type,
+            dependency.source_url,
+            dependency.source_reference,
+            dependency.source_subdirectory,
+        )
 
-        self.cache[key] = packages
+        packages = self._search_for_cached(dependency, key)
+        if not self._cache[key] or self._cache[key][-1] is not packages:
+            self._cache[key].append(packages)
+            self._cached_dependencies_by_level[decision_level].append(key)
 
         return packages
 
-    def clear(self) -> None:
-        self.cache.clear()
+    def clear_level(self, level: int) -> None:
+        if level in self._cached_dependencies_by_level:
+            self._search_for_cached.cache_clear()
+            for key in self._cached_dependencies_by_level.pop(level):
+                self._cache[key].pop()
 
 
 class VersionSolver:
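The comment added to `DependencyCache.__init__` above describes the core idea: each cache key now holds a stack of package lists, one entry per decision level that computed (or narrowed) the result, and `_cached_dependencies_by_level` remembers which keys grew at which level so that backtracking can pop exactly those entries. Below is a minimal, self-contained sketch of that bookkeeping pattern; `LevelCache` and its string payloads are made-up illustrations, not the real poetry classes.

```python
from __future__ import annotations

import collections


class LevelCache:
    """Per-key stacks of results, indexed by the decision level that added them."""

    def __init__(self) -> None:
        # key -> stack of result lists; the top of the stack is the most
        # recently narrowed result for that key.
        self._cache: dict[str, list[list[str]]] = collections.defaultdict(list)
        # decision level -> keys whose stacks were pushed at that level.
        self._keys_by_level: dict[int, list[str]] = collections.defaultdict(list)

    def store(self, key: str, results: list[str], level: int) -> None:
        # Push only when the result object actually changed, mirroring the
        # `is not packages` check in DependencyCache.search_for().
        if not self._cache[key] or self._cache[key][-1] is not results:
            self._cache[key].append(results)
            self._keys_by_level[level].append(key)

    def latest(self, key: str) -> list[str] | None:
        return self._cache[key][-1] if self._cache[key] else None

    def clear_level(self, level: int) -> None:
        # Must be called in descending level order during backtracking so that
        # each pop() removes the entry pushed by the level being rolled back.
        for key in self._keys_by_level.pop(level, []):
            self._cache[key].pop()


cache = LevelCache()
cache.store("demo", ["demo 1.0.0", "demo 2.0.0"], level=0)
cache.store("demo", ["demo 1.0.0"], level=2)  # narrowed at a deeper level
cache.clear_level(2)                          # roll that level back
assert cache.latest("demo") == ["demo 1.0.0", "demo 2.0.0"]
```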
@@ -95,6 +135,9 @@ def __init__(self, root: ProjectPackage, provider: Provider) -> None:
         self._dependency_cache = DependencyCache(provider)
         self._incompatibilities: dict[str, list[Incompatibility]] = {}
         self._contradicted_incompatibilities: set[Incompatibility] = set()
+        self._contradicted_incompatibilities_by_level: dict[
+            int, set[Incompatibility]
+        ] = collections.defaultdict(set)
         self._solution = PartialSolution()
 
     @property
@@ -193,6 +236,9 @@ def _propagate_incompatibility(
                 # incompatibility is contradicted as well and there's nothing new we
                 # can deduce from it.
                 self._contradicted_incompatibilities.add(incompatibility)
+                self._contradicted_incompatibilities_by_level[
+                    self._solution.decision_level
+                ].add(incompatibility)
                 return None
             elif relation == SetRelation.OVERLAPPING:
                 # If more than one term is inconclusive, we can't deduce anything about
@@ -211,6 +257,9 @@ def _propagate_incompatibility(
             return _conflict
 
         self._contradicted_incompatibilities.add(incompatibility)
+        self._contradicted_incompatibilities_by_level[
+            self._solution.decision_level
+        ].add(incompatibility)
 
         adverb = "not " if unsatisfied.is_positive() else ""
         self._log(f"derived: {adverb}{unsatisfied.dependency}")
@@ -304,9 +353,16 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility
                 previous_satisfier_level < most_recent_satisfier.decision_level
                 or most_recent_satisfier.cause is None
             ):
+                for level in range(
+                    self._solution.decision_level, previous_satisfier_level, -1
+                ):
+                    if level in self._contradicted_incompatibilities_by_level:
+                        self._contradicted_incompatibilities.difference_update(
+                            self._contradicted_incompatibilities_by_level.pop(level),
+                        )
+                    self._dependency_cache.clear_level(level)
+
                 self._solution.backtrack(previous_satisfier_level)
-                self._contradicted_incompatibilities.clear()
-                self._dependency_cache.clear()
 
                 if new_incompatibility:
                     self._add_incompatibility(incompatibility)
@@ -404,7 +460,11 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
             if locked:
                 return is_specific_marker, Preference.LOCKED, 1
 
-            num_packages = len(self._dependency_cache.search_for(dependency))
+            num_packages = len(
+                self._dependency_cache.search_for(
+                    dependency, self._solution.decision_level
+                )
+            )
 
             if num_packages < 2:
                 preference = Preference.NO_CHOICE
@@ -421,7 +481,9 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
 
         locked = self._provider.get_locked(dependency)
         if locked is None:
-            packages = self._dependency_cache.search_for(dependency)
+            packages = self._dependency_cache.search_for(
+                dependency, self._solution.decision_level
+            )
             package = next(iter(packages), None)
 
             if package is None:
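The `_resolve_conflict` hunk above replaces the old wholesale `clear()` calls with a per-level rollback: it walks from the current decision level down to, but not including, `previous_satisfier_level`, dropping only the contradicted incompatibilities and cache entries recorded at the discarded levels. A self-contained sketch of that ordering, using a stand-in `FakeSolver` (illustrative only, not the real `VersionSolver`):

```python
import collections


class FakeSolver:
    def __init__(self) -> None:
        self.decision_level = 3
        self.contradicted = {"a", "b", "c"}
        self.contradicted_by_level = collections.defaultdict(set)
        self.contradicted_by_level[2] = {"b"}
        self.contradicted_by_level[3] = {"c"}
        self.cleared_cache_levels: list[int] = []

    def rollback(self, previous_satisfier_level: int) -> None:
        # range() counts down and stops before previous_satisfier_level, so
        # the level we backtrack to keeps its cached knowledge.
        for level in range(self.decision_level, previous_satisfier_level, -1):
            if level in self.contradicted_by_level:
                self.contradicted.difference_update(
                    self.contradicted_by_level.pop(level)
                )
            # Stands in for self._dependency_cache.clear_level(level); the
            # descending order keeps each per-key stack aligned with the
            # levels that pushed onto it.
            self.cleared_cache_levels.append(level)
        self.decision_level = previous_satisfier_level


solver = FakeSolver()
solver.rollback(1)
assert solver.cleared_cache_levels == [3, 2]  # newest level first
assert solver.contradicted == {"a"}           # level-1 knowledge survives
```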
diff --git a/tests/mixology/version_solver/test_dependency_cache.py b/tests/mixology/version_solver/test_dependency_cache.py
index dffa7e535be..96f1a3b6329 100644
--- a/tests/mixology/version_solver/test_dependency_cache.py
+++ b/tests/mixology/version_solver/test_dependency_cache.py
@@ -2,6 +2,7 @@
 
 from copy import deepcopy
 from typing import TYPE_CHECKING
+from unittest import mock
 
 from poetry.factory import Factory
 from poetry.mixology.version_solver import DependencyCache
@@ -29,20 +30,20 @@ def test_solver_dependency_cache_respects_source_type(
     add_to_repo(repo, "demo", "1.0.0")
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_pypi)
-    cache.search_for(dependency_git)
-    assert not cache.search_for.cache_info().hits
+    cache.search_for(dependency_pypi, 0)
+    cache.search_for(dependency_git, 0)
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
-    packages_pypi = cache.search_for(deepcopy(dependency_pypi))
-    packages_git = cache.search_for(deepcopy(dependency_git))
+    packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
+    packages_git = cache.search_for(deepcopy(dependency_git), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
     assert len(packages_pypi) == len(packages_git) == 1
     assert packages_pypi != packages_git
 
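The test added below wraps the real provider in `unittest.mock.Mock(wraps=...)` so that `provider.search_for()` keeps its normal behaviour while the number of calls stays observable. A minimal sketch of that wrapping pattern (the `Greeter` class is a made-up example, not poetry code):

```python
from unittest import mock


class Greeter:
    def greet(self, name: str) -> str:
        return f"hello {name}"


wrapped = mock.Mock(wraps=Greeter())
assert wrapped.greet("poetry") == "hello poetry"  # real behaviour preserved
assert len(wrapped.greet.mock_calls) == 1         # ...and the call was recorded
```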
@@ -60,6 +61,65 @@ def test_solver_dependency_cache_respects_source_type(
     assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION
 
 
+def test_solver_dependency_cache_pulls_from_prior_level_cache(
+    root: ProjectPackage, provider: Provider, repo: Repository
+) -> None:
+    dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
+    dependency_pypi_constrained = Factory.create_dependency("demo", ">=0.1.0,<2.0.0")
+    root.add_dependency(dependency_pypi)
+    root.add_dependency(dependency_pypi_constrained)
+    add_to_repo(repo, "demo", "1.0.0")
+
+    wrapped_provider = mock.Mock(wraps=provider)
+    cache = DependencyCache(wrapped_provider)
+    cache._search_for_cached.cache_clear()
+
+    # On first call, provider.search_for() should be called and the cache
+    # populated.
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # On second call at level 1, neither provider.search_for() nor
+    # cache._search_for_cached() should have been called again, and the cache
+    # should remain the same.
+    cache.search_for(dependency_pypi, 1)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # A third call at level 2 with an updated constraint for the `demo`
+    # package should not call provider.search_for(), but should call
+    # cache._search_for_cached() and update the cache.
+    cache.search_for(dependency_pypi_constrained, 2)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[2]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0, 2}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 2
+
+    # Clearing the level 2 and level 1 caches should invalidate the lru_cache
+    # on cache.search_for and wipe out the level 2 cache while preserving the
+    # level 0 cache.
+    cache.clear_level(2)
+    cache.clear_level(1)
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
+
+
 def test_solver_dependency_cache_respects_subdirectories(
     root: ProjectPackage, provider: Provider, repo: Repository
 ) -> None:
@@ -84,20 +144,20 @@ def test_solver_dependency_cache_respects_subdirectories(
     root.add_dependency(dependency_one_copy)
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_one)
-    cache.search_for(dependency_one_copy)
-    assert not cache.search_for.cache_info().hits
+    cache.search_for(dependency_one, 0)
+    cache.search_for(dependency_one_copy, 0)
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
    # (when searching for the exact same object, __eq__ is never called)
-    packages_one = cache.search_for(deepcopy(dependency_one))
-    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy))
+    packages_one = cache.search_for(deepcopy(dependency_one), 0)
+    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
     assert len(packages_one) == len(packages_one_copy) == 1
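The final block of assertions in the new test leans on `clear_level()` calling `cache_clear()` on the memoized helper: once any level is rolled back, previously memoized results may refer to popped stack entries, so the LRU statistics start over (hits back to 0, a fresh miss on the next lookup). A small sketch of the `functools.lru_cache` behaviour relied on here, wrapping a bound method per instance the same way `DependencyCache.__init__` does (`Squarer` is a made-up example):

```python
import functools


class Squarer:
    def __init__(self) -> None:
        # Each instance gets its own LRU cache around its bound method; the
        # wrapper exposes cache_info() and cache_clear().
        self._square_cached = functools.lru_cache(maxsize=128)(self._square)

    def _square(self, n: int) -> int:
        return n * n


s = Squarer()
s._square_cached(3)
s._square_cached(3)
assert s._square_cached.cache_info().hits == 1
s._square_cached.cache_clear()
assert s._square_cached.cache_info().hits == 0
assert s._square_cached(3) == 9  # recomputed: a fresh miss after the clear
```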