From 82bcc880b082a3e386d1921740c3a7040f3675d6 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Mon, 1 Apr 2024 12:20:48 +0800 Subject: [PATCH] feat: per-source config of ca_certs and client_cert (#2754) * feat: per-source config of ca_certs and client_cert Signed-off-by: Frost Ming * add news Signed-off-by: Frost Ming --- news/2754.bugfix.md | 1 + src/pdm/_types.py | 2 + src/pdm/cli/commands/publish/repository.py | 4 +- src/pdm/environments/base.py | 47 +++++------------- src/pdm/models/session.py | 56 +++++++++++++++++----- src/pdm/project/core.py | 3 ++ 6 files changed, 63 insertions(+), 50 deletions(-) create mode 100644 news/2754.bugfix.md diff --git a/news/2754.bugfix.md b/news/2754.bugfix.md new file mode 100644 index 0000000000..552c64bc10 --- /dev/null +++ b/news/2754.bugfix.md @@ -0,0 +1 @@ +Per-source configuration for ca-certs and client-cert. diff --git a/src/pdm/_types.py b/src/pdm/_types.py index 06fb469d92..3b632e537d 100644 --- a/src/pdm/_types.py +++ b/src/pdm/_types.py @@ -25,6 +25,8 @@ class _RepositoryConfig: verify_ssl: bool | None = None type: str | None = None ca_certs: str | None = None + client_cert: str | None = None + client_key: str | None = None include_packages: list[str] = dc.field(default_factory=list) exclude_packages: list[str] = dc.field(default_factory=list) diff --git a/src/pdm/cli/commands/publish/repository.py b/src/pdm/cli/commands/publish/repository.py index 3f24b51d79..25c508ba37 100644 --- a/src/pdm/cli/commands/publish/repository.py +++ b/src/pdm/cli/commands/publish/repository.py @@ -12,7 +12,6 @@ from pdm.exceptions import PdmUsageError from pdm.project import Project from pdm.project.config import DEFAULT_REPOSITORIES -from pdm.utils import get_trusted_hosts if TYPE_CHECKING: from typing import Callable, Self @@ -39,8 +38,7 @@ def __iter__(self) -> Iterable[bytes]: class Repository: def __init__(self, project: Project, config: RepositoryConfig) -> None: self.url = cast(str, config.url) - trusted_hosts = get_trusted_hosts([config]) - self.session = project.environment._build_session(trusted_hosts, verify=config.ca_certs) + self.session = project.environment._build_session([config]) self._credentials_to_save: tuple[str, str, str] | None = None self.ui = project.core.ui diff --git a/src/pdm/environments/base.py b/src/pdm/environments/base.py index b7da071dc8..dd87cac84e 100644 --- a/src/pdm/environments/base.py +++ b/src/pdm/environments/base.py @@ -12,18 +12,17 @@ from contextlib import contextmanager from functools import cached_property, partial from pathlib import Path -from typing import TYPE_CHECKING, Generator, Mapping, no_type_check +from typing import TYPE_CHECKING, Generator, no_type_check from pdm.exceptions import BuildError, PdmUsageError from pdm.models.in_process import get_pep508_environment, get_python_abis, get_uname, sysconfig_get_platform from pdm.models.python import PythonInfo from pdm.models.working_set import WorkingSet -from pdm.utils import deprecation_warning, get_trusted_hosts, is_pip_compatible_with_python +from pdm.utils import deprecation_warning, is_pip_compatible_with_python if TYPE_CHECKING: import unearth from httpx import BaseTransport - from httpx._types import CertTypes, VerifyTypes from pdm._types import RepositoryConfig from pdm.models.session import PDMPyPIClient @@ -112,40 +111,20 @@ def target_python(self) -> unearth.TargetPython: return tp def _build_session( - self, - trusted_hosts: list[str] | None = None, - verify: VerifyTypes | None = None, - cert: CertTypes | None = None, - mounts: Mapping[str, BaseTransport | None] | None = None, + self, sources: list[RepositoryConfig] | None = None, mounts: dict[str, BaseTransport | None] | None = None ) -> PDMPyPIClient: from pdm.models.session import PDMPyPIClient - if trusted_hosts is None: - trusted_hosts = get_trusted_hosts(self.project.sources) - - if verify is None: - verify = self.project.config.get("pypi.ca_certs") - - if cert is None: - certfn = self.project.config.get("pypi.client_cert") - keyfn = self.project.config.get("pypi.client_key") - if certfn: - cert = (certfn, keyfn) - - session_args = { - "cache_dir": self.project.cache("http"), - "trusted_hosts": trusted_hosts, - "timeout": self.project.config["request_timeout"], - "auth": self.auth, - } - if verify is not None: - session_args["verify"] = verify - if cert is not None: - session_args["cert"] = cert - if mounts: - session_args["mounts"] = mounts - - session = PDMPyPIClient(**session_args) + if sources is None: + sources = self.project.sources + + session = PDMPyPIClient( + sources=sources, + cache_dir=self.project.cache("http"), + timeout=self.project.config["request_timeout"], + auth=self.auth, + mounts=mounts, + ) self.project.core.exit_stack.callback(session.close) return session diff --git a/src/pdm/models/session.py b/src/pdm/models/session.py index 80641ef862..46a3381914 100644 --- a/src/pdm/models/session.py +++ b/src/pdm/models/session.py @@ -1,10 +1,12 @@ from __future__ import annotations import sys +from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, cast import hishel +import httpx import msgpack from hishel._serializers import Metadata from httpcore import Request, Response @@ -16,7 +18,7 @@ if TYPE_CHECKING: from ssl import SSLContext - from httpx import Response as HTTPXResponse + from pdm._types import RepositoryConfig def _create_truststore_ssl_context() -> SSLContext | None: @@ -26,7 +28,6 @@ def _create_truststore_ssl_context() -> SSLContext | None: try: import ssl except ImportError: - logger.warning("Disabling truststore since ssl support is missing") return None try: @@ -37,9 +38,17 @@ def _create_truststore_ssl_context() -> SSLContext | None: return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +_ssl_context = _create_truststore_ssl_context() CACHES_TTL = 7 * 24 * 60 * 60 # 7 days +@lru_cache(maxsize=None) +def _get_transport( + verify: bool | SSLContext | str = True, cert: tuple[str, str | None] | None = None +) -> httpx.BaseTransport: + return httpx.HTTPTransport(verify=verify, cert=cert, trust_env=True) + + class MsgPackSerializer(hishel.BaseSerializer): KNOWN_REQUEST_EXTENSIONS = ("timeout", "sni_hostname") KNOWN_RESPONSE_EXTENSIONS = ("http_version", "reason_phrase") @@ -115,22 +124,43 @@ def is_binary(self) -> bool: class PDMPyPIClient(PyPIClient): - def __init__(self, *, cache_dir: Path, **kwargs: Any) -> None: + def __init__(self, *, sources: list[RepositoryConfig], cache_dir: Path, **kwargs: Any) -> None: + from unearth.fetchers.sync import LocalFSTransport + storage = hishel.FileStorage(serializer=MsgPackSerializer(), base_path=cache_dir, ttl=CACHES_TTL) controller = hishel.Controller() - kwargs.setdefault("verify", _create_truststore_ssl_context() or True) - kwargs.setdefault("follow_redirects", True) - super().__init__(**kwargs) + mounts: dict[str, httpx.BaseTransport] = {"file://": LocalFSTransport()} + self._trusted_host_ports: set[tuple[str, int | None]] = set() + transport: httpx.BaseTransport | None = None + for s in sources: + if s.name == "pypi": + transport = self._transport_for(s) + continue + assert s.url is not None + url = httpx.URL(s.url) + mounts[f"{url.scheme}://{url.netloc.decode('ascii')}/"] = hishel.CacheTransport( + self._transport_for(s), storage, controller + ) + self._trusted_host_ports.add((url.host, url.port)) + mounts.update(kwargs.pop("mounts", None) or {}) + + httpx.Client.__init__(self, mounts=mounts, follow_redirects=True, transport=transport, **kwargs) + self.headers["User-Agent"] = self._make_user_agent() self.event_hooks["response"].append(self.on_response) - self._transport = hishel.CacheTransport(self._transport, storage, controller) # type: ignore[has-type] - for name, transport in self._mounts.items(): - if name.scheme == "file" or transport is None: - # don't cache file:// transport - continue - self._mounts[name] = hishel.CacheTransport(transport, storage, controller) + + def _transport_for(self, source: RepositoryConfig) -> httpx.BaseTransport: + if source.ca_certs: + verify: str | bool | SSLContext = source.ca_certs + else: + verify = source.verify_ssl is not False and (_ssl_context or True) + if source.client_cert: + cert = (source.client_cert, source.client_key) + else: + cert = None + return _get_transport(verify=verify, cert=cert) def _make_user_agent(self) -> str: import platform @@ -143,7 +173,7 @@ def _make_user_agent(self) -> str: platform.release(), ) - def on_response(self, response: HTTPXResponse) -> None: + def on_response(self, response: httpx.Response) -> None: from unearth.utils import ARCHIVE_EXTENSIONS if response.extensions.get("from_cache") and response.url.path.endswith(ARCHIVE_EXTENSIONS): diff --git a/src/pdm/project/core.py b/src/pdm/project/core.py index 4de3a75e5c..da539c11d2 100644 --- a/src/pdm/project/core.py +++ b/src/pdm/project/core.py @@ -360,6 +360,9 @@ def default_source(self) -> RepositoryConfig: verify_ssl=self.config["pypi.verify_ssl"], username=self.config.get("pypi.username"), password=self.config.get("pypi.password"), + ca_certs=self.config.get("pypi.ca_certs"), + client_cert=self.config.get("pypi.client_cert"), + client_key=self.config.get("pypi.client_key"), ) @property