perf: don't clear the entire dependency cache when backtracking (python-poetry#7950)

(cherry picked from commit f54864e)
chriskuehl committed May 24, 2023
1 parent 3e7466b commit 00ec1ce
Showing 2 changed files with 83 additions and 28 deletions.
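
Before the diff, a minimal standalone sketch of the idea behind the change: cache search results per decision level, so that backtracking only has to drop the levels being unwound rather than the whole cache. The LeveledCache class and the string-keyed search callback here are simplified stand-ins, not Poetry's actual DependencyCache or Provider API.

import collections
from typing import Callable


class LeveledCache:
    """Cache search results per decision level so that backtracking to a
    lower level only drops the entries recorded above it."""

    def __init__(self, search: Callable[[str], list[str]]) -> None:
        self._search = search
        # level -> {dependency name -> search results}
        self._cache: dict[int, dict[str, list[str]]] = collections.defaultdict(dict)

    def search_for(self, name: str, level: int) -> list[str]:
        # Walk from the current level down to 0 and reuse any prior result.
        for check_level in range(level, -1, -1):
            packages = self._cache[check_level].get(name)
            if packages is not None:
                break
        else:
            packages = self._search(name)
        # Record at the current level so later levels can reuse it too.
        self._cache[level][name] = packages
        return packages

    def clear_level(self, level: int) -> None:
        # Backtracking past this level invalidates only its own entries.
        self._cache.pop(level, None)


if __name__ == "__main__":
    calls: list[str] = []

    def fake_search(name: str) -> list[str]:
        calls.append(name)
        return [f"{name}-1.0.0"]

    cache = LeveledCache(fake_search)
    cache.search_for("demo", 0)  # queries the fake provider once
    cache.search_for("demo", 3)  # served from the level-0 entry
    cache.clear_level(3)         # backtrack; the level-0 entry survives
    cache.search_for("demo", 1)  # still served from level 0, no new query
    assert calls == ["demo"]

The real DependencyCache in the diff below additionally filters the cached packages against the dependency's constraint and wraps search_for in functools.lru_cache, which is why clear_level also has to call cache_clear() on the memoized wrapper.
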
56 changes: 36 additions & 20 deletions src/poetry/mixology/version_solver.py
@@ -39,15 +39,20 @@ class DependencyCache:
     """
 
     def __init__(self, provider: Provider) -> None:
-        self.provider = provider
-        self.cache: dict[
-            tuple[str, str | None, str | None, str | None, str | None],
-            list[DependencyPackage],
-        ] = {}
+        self._provider = provider
+        self._cache: dict[
+            int,
+            dict[
+                tuple[str, str | None, str | None, str | None, str | None],
+                list[DependencyPackage],
+            ],
+        ] = collections.defaultdict(dict)
 
         self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
 
-    def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
+    def _search_for(
+        self, dependency: Dependency, level: int
+    ) -> list[DependencyPackage]:
         key = (
             dependency.complete_name,
             dependency.source_type,
@@ -56,12 +61,17 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
             dependency.source_subdirectory,
         )
 
-        packages = self.cache.get(key)
-
-        if packages:
-            packages = [
-                p for p in packages if dependency.constraint.allows(p.package.version)
-            ]
+        for check_level in range(level, -1, -1):
+            packages = self._cache[check_level].get(key)
+            if packages is not None:
+                packages = [
+                    p
+                    for p in packages
+                    if dependency.constraint.allows(p.package.version)
+                ]
+                break
+        else:
+            packages = None
 
         # provider.search_for() normally does not include pre-release packages
         # (unless requested), but will include them if there are no other
@@ -71,14 +81,14 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
         # nothing, we need to call provider.search_for() again as it may return
         # additional results this time.
         if not packages:
-            packages = self.provider.search_for(dependency)
-
-            self.cache[key] = packages
+            packages = self._provider.search_for(dependency)
 
+        self._cache[level][key] = packages
         return packages
 
-    def clear(self) -> None:
-        self.cache.clear()
+    def clear_level(self, level: int) -> None:
+        self.search_for.cache_clear()
+        self._cache.pop(level, None)
 
 
 class VersionSolver:
@@ -318,9 +328,9 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility
                     self._solution.decision_level, previous_satisfier_level, -1
                 ):
                     self._contradicted_incompatibilities.pop(level, None)
+                    self._dependency_cache.clear_level(level)
 
                 self._solution.backtrack(previous_satisfier_level)
-                self._dependency_cache.clear()
                 if new_incompatibility:
                     self._add_incompatibility(incompatibility)
 
Expand Down Expand Up @@ -418,7 +428,11 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
if locked:
return is_specific_marker, Preference.LOCKED, 1

num_packages = len(self._dependency_cache.search_for(dependency))
num_packages = len(
self._dependency_cache.search_for(
dependency, self._solution.decision_level
)
)

if num_packages < 2:
preference = Preference.NO_CHOICE
@@ -435,7 +449,9 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
 
         locked = self._provider.get_locked(dependency)
         if locked is None:
-            packages = self._dependency_cache.search_for(dependency)
+            packages = self._dependency_cache.search_for(
+                dependency, self._solution.decision_level
+            )
             package = next(iter(packages), None)
 
             if package is None:
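
For context on the _resolve_conflict hunk above, a hedged sketch of the backtracking bookkeeping, reusing the LeveledCache sketch from earlier; the function name and parameters here are illustrative, not Poetry's solver API. Only per-level state for the levels being unwound is dropped.

def backtrack_bookkeeping(
    cache: "LeveledCache",
    contradicted: dict[int, object],
    current_level: int,
    target_level: int,
) -> None:
    # Invalidate only the levels above the backtrack target; anything cached
    # at or below target_level stays available for the next attempt.
    for level in range(current_level, target_level, -1):
        contradicted.pop(level, None)
        cache.clear_level(level)

Previously a single self._dependency_cache.clear() after this loop discarded every level, so each conflict forced the provider to be queried again for dependencies that had already been resolved lower in the search.
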
55 changes: 47 additions & 8 deletions tests/mixology/version_solver/test_dependency_cache.py
@@ -2,6 +2,7 @@
 
 from copy import deepcopy
 from typing import TYPE_CHECKING
+from unittest import mock
 
 from poetry.factory import Factory
 from poetry.mixology.version_solver import DependencyCache
@@ -32,14 +33,14 @@ def test_solver_dependency_cache_respects_source_type(
     cache.search_for.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_pypi)
-    cache.search_for(dependency_git)
+    cache.search_for(dependency_pypi, 0)
+    cache.search_for(dependency_git, 0)
     assert not cache.search_for.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
-    packages_pypi = cache.search_for(deepcopy(dependency_pypi))
-    packages_git = cache.search_for(deepcopy(dependency_git))
+    packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
+    packages_git = cache.search_for(deepcopy(dependency_git), 0)
 
     assert cache.search_for.cache_info().hits == 2
     assert cache.search_for.cache_info().currsize == 2
@@ -60,6 +61,44 @@ def test_solver_dependency_cache_respects_source_type(
     assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION
 
 
+def test_solver_dependency_cache_pulls_from_prior_level_cache(
+    root: ProjectPackage, provider: Provider, repo: Repository
+) -> None:
+    dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
+    root.add_dependency(dependency_pypi)
+    add_to_repo(repo, "demo", "1.0.0")
+
+    wrapped_provider = mock.Mock(wraps=provider)
+    cache = DependencyCache(wrapped_provider)
+    cache.search_for.cache_clear()
+
+    # On first call, provider.search_for() should be called and the level-0
+    # cache populated.
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache[0]
+    assert cache.search_for.cache_info().hits == 0
+    assert cache.search_for.cache_info().misses == 1
+
+    # On second call at level 1, provider.search_for() should not be called
+    # again and the level-1 cache should be populated from the level-0 cache.
+    cache.search_for(dependency_pypi, 1)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache[1]
+    assert cache._cache[0] == cache._cache[1]
+    assert cache.search_for.cache_info().hits == 0
+    assert cache.search_for.cache_info().misses == 2
+
+    # Clearing the level 1 cache should invalidate the lru_cache on
+    # cache.search_for and wipe out the level 1 cache while preserving the
+    # level 0 cache.
+    cache.clear_level(1)
+    assert set(cache._cache.keys()) == {0}
+    assert ("demo", None, None, None, None) in cache._cache[0]
+    assert cache.search_for.cache_info().hits == 0
+    assert cache.search_for.cache_info().misses == 0
+
+
 def test_solver_dependency_cache_respects_subdirectories(
     root: ProjectPackage, provider: Provider, repo: Repository
 ) -> None:
@@ -87,14 +126,14 @@ def test_solver_dependency_cache_respects_subdirectories(
     cache.search_for.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_one)
-    cache.search_for(dependency_one_copy)
+    cache.search_for(dependency_one, 0)
+    cache.search_for(dependency_one_copy, 0)
    assert not cache.search_for.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
-    packages_one = cache.search_for(deepcopy(dependency_one))
-    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy))
+    packages_one = cache.search_for(deepcopy(dependency_one), 0)
+    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
 
     assert cache.search_for.cache_info().hits == 2
     assert cache.search_for.cache_info().currsize == 2