Skip to content

Commit

Permalink
Re-structure PEP-691 fingerprints db. (#2552)
Browse files Browse the repository at this point in the history
Use a `dbs` cache entry for all databases starting with PEP-691
fingerprints and hide this cache entry as a choice when purging
individual entries using `pex3 cache purge`.

Work towards #2528.
  • Loading branch information
jsirois authored Oct 5, 2024
1 parent 13826f1 commit 22c254c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 28 deletions.
10 changes: 10 additions & 0 deletions pex/cache/dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ def __init__(
version, # type: int
description, # type: str
dependencies=(), # type: Iterable[CacheDir.Value]
can_purge=True, # type: bool
):
Enum.Value.__init__(self, value)
self.name = name
self.version = version
self.description = description
self.dependencies = tuple(dependencies)
self.can_purge = can_purge

@property
def rel_path(self):
Expand Down Expand Up @@ -76,6 +78,14 @@ def iter_transitive_dependents(self):
description="Wheels built by Pex from resolved sdists when creating PEX files.",
)

DBS = Value(
"dbs",
version=0,
name="Pex Internal Databases",
description="Databases Pex uses for caches and to track cache structure.",
can_purge=False,
)

DOCS = Value(
"docs",
version=0,
Expand Down
2 changes: 1 addition & 1 deletion pex/cli/commands/cache/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _add_purge_arguments(cls, parser):
"--entries",
action="append",
type=CacheDir.for_value,
choices=CacheDir.values(),
choices=[cache_dir for cache_dir in CacheDir.values() if cache_dir.can_purge],
default=[],
help=(
"Specific cache entries to purge. By default, all entries are purged, but by "
Expand Down
41 changes: 27 additions & 14 deletions pex/resolve/pep_691/fingerprint_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import os
import sqlite3
from contextlib import closing
from contextlib import closing, contextmanager
from itertools import repeat
from multiprocessing.pool import ThreadPool

from pex import pex_warnings
from pex.atomic_directory import atomic_directory
from pex.cache.dirs import CacheDir
from pex.compatibility import cpu_count
from pex.fetcher import URLFetcher
from pex.resolve.pep_691.api import Client
Expand All @@ -19,7 +21,6 @@
from pex.result import Error, catch
from pex.tracer import TRACER
from pex.typing import TYPE_CHECKING
from pex.variables import ENV

if TYPE_CHECKING:
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
Expand All @@ -46,29 +47,41 @@ def create(
return cls(api=Client(url_fetcher=url_fetcher), max_parallel_jobs=max_parallel_jobs)

_api = attr.ib(factory=Client) # type: Client
_path = attr.ib(factory=lambda: os.path.join(ENV.PEX_ROOT, "fingerprints.db")) # type: str
_db_dir = attr.ib(factory=lambda: CacheDir.DBS.path("pep_691")) # type: str
_max_parallel_jobs = attr.ib(default=None) # type: Optional[int]

@property
def accept(self):
# type: () -> Tuple[str, ...]
return self._api.ACCEPT

_SCHEMA = """
PRAGMA journal_mode=WAL;
CREATE TABLE hashes (
url TEXT PRIMARY KEY ASC,
algorithm TEXT NOT NULL,
hash TEXT NOT NULL
) WITHOUT ROWID;
"""

@contextmanager
def _db_connection(self):
# type: () -> Iterator[sqlite3.Connection]
with atomic_directory(self._db_dir) as atomic_dir:
if not atomic_dir.is_finalized():
with sqlite3.connect(os.path.join(atomic_dir.work_dir, "fingerprints.db")) as conn:
conn.executescript(self._SCHEMA).close()
with sqlite3.connect(os.path.join(self._db_dir, "fingerprints.db")) as conn:
conn.execute("PRAGMA synchronous=NORMAL").close()
yield conn

def _iter_cached(self, urls_to_fingerprint):
# type: (Iterable[str]) -> Iterator[_FingerprintedURL]

urls = sorted(urls_to_fingerprint)
with TRACER.timed("Searching for {count} fingerprints in database".format(count=len(urls))):
with sqlite3.connect(self._path) as conn:
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS hashes (
url TEXT PRIMARY KEY ASC,
algorithm TEXT NOT NULL,
hash TEXT NOT NULL
) WITHOUT ROWID;
"""
).close()
with self._db_connection() as conn:
# N.B.: Maximum parameter count is 999 in pre-2020 versions of SQLite 3; so we limit
# to an even lower chunk size to be safe: https://www.sqlite.org/limits.html
chunk_size = 100
Expand All @@ -93,7 +106,7 @@ def _cache(self, fingerprinted_urls):
with TRACER.timed(
"Caching {count} fingerprints in database".format(count=len(fingerprinted_urls))
):
with sqlite3.connect(self._path) as conn:
with self._db_connection() as conn:
conn.executemany(
"INSERT OR REPLACE INTO hashes (url, algorithm, hash) VALUES (?, ?, ?)",
tuple(
Expand Down
27 changes: 14 additions & 13 deletions tests/resolve/pep_691/test_fingerprint_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import os.path
import shutil

import pytest

Expand All @@ -28,9 +29,9 @@


@pytest.fixture
def database_path(tmpdir):
def db_dir(tmpdir):
# type: (Any) -> str
return os.path.join(str(tmpdir), "fingerprints.db")
return os.path.join(str(tmpdir), "pep_691")


def file(
Expand All @@ -57,11 +58,11 @@ def create_project(
)


def test_no_fingerprints(database_path):
def test_no_fingerprints(db_dir):
# type: (str) -> None

with mock.patch.object(Client, "request", return_value=create_project("foo")) as request:
fingerprint_service = FingerprintService(path=database_path)
fingerprint_service = FingerprintService(db_dir=db_dir)
artifacts = list(
fingerprint_service.fingerprint(
endpoints={ENDPOINT},
Expand All @@ -72,7 +73,7 @@ def test_no_fingerprints(database_path):
request.assert_called_once_with(ENDPOINT)


def test_no_matching_fingerprints(database_path):
def test_no_matching_fingerprints(db_dir):
# type: (str) -> None

with mock.patch.object(
Expand All @@ -84,7 +85,7 @@ def test_no_matching_fingerprints(database_path):
file("https://files.example.org/foo-2.0.tar.gz", sha256="strong"),
),
) as request:
fingerprint_service = FingerprintService(path=database_path)
fingerprint_service = FingerprintService(db_dir=db_dir)
artifacts = list(
fingerprint_service.fingerprint(
endpoints={ENDPOINT},
Expand All @@ -95,7 +96,7 @@ def test_no_matching_fingerprints(database_path):
request.assert_called_once_with(ENDPOINT)


def test_cache_miss_retries(database_path):
def test_cache_miss_retries(db_dir):
# type: (Any) -> None

endpoint = Endpoint("https://example.org/simple/foo", "x/y")
Expand All @@ -110,7 +111,7 @@ def test_cache_miss_retries(database_path):
file("https://files.example.org/foo-2.0.tar.gz", sha256="strong"),
),
) as request:
fingerprint_service = FingerprintService(path=database_path)
fingerprint_service = FingerprintService(db_dir=db_dir)
for _ in range(attempts):

artifacts = list(
Expand All @@ -128,7 +129,7 @@ def test_cache_miss_retries(database_path):
def test_cache_hit(tmpdir):
# type: (Any) -> None

database_path = os.path.join(str(tmpdir), "fingerprints.db")
db_dir = os.path.join(str(tmpdir), "pep_691")
endpoint = Endpoint("https://example.org/simple/foo", "x/y")
initial_artifact = PartialArtifact(url="https://files.example.org/foo-1.1.tar.gz")
expected_artifact = PartialArtifact(
Expand All @@ -143,7 +144,7 @@ def test_cache_hit(tmpdir):
"foo", file("https://files.example.org/foo-1.1.tar.gz", md5="weak")
),
) as request:
fingerprint_service = FingerprintService(path=database_path)
fingerprint_service = FingerprintService(db_dir=db_dir)
for _ in range(3):
artifacts = list(
fingerprint_service.fingerprint(endpoints={endpoint}, artifacts=[initial_artifact])
Expand All @@ -154,15 +155,15 @@ def test_cache_hit(tmpdir):
request.assert_called_once_with(endpoint)

# Unless the cache is wiped out.
os.unlink(database_path)
shutil.rmtree(db_dir)
request.reset_mock()
assert [expected_artifact] == list(
fingerprint_service.fingerprint(endpoints={endpoint}, artifacts=[initial_artifact])
)
request.assert_called_once_with(endpoint)


def test_mixed(database_path):
def test_mixed(db_dir):
# type: (str) -> None

responses = {
Expand All @@ -177,7 +178,7 @@ def test_mixed(database_path):
}

with mock.patch.object(Client, "request", side_effect=responses.get) as request:
fingerprint_service = FingerprintService(path=database_path)
fingerprint_service = FingerprintService(db_dir=db_dir)
artifacts = sorted(
fingerprint_service.fingerprint(
endpoints=set(responses),
Expand Down

0 comments on commit 22c254c

Please sign in to comment.