Skip to content

Commit

Permalink
feat: per-source config of ca_certs and client_cert (#2754)
Browse files Browse the repository at this point in the history
* feat: per-source config of ca_certs and client_cert

Signed-off-by: Frost Ming <[email protected]>

* add news

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Apr 1, 2024
1 parent b22e95a commit 82bcc88
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 50 deletions.
1 change: 1 addition & 0 deletions news/2754.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Per-source configuration for ca-certs and client-cert.
2 changes: 2 additions & 0 deletions src/pdm/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class _RepositoryConfig:
verify_ssl: bool | None = None
type: str | None = None
ca_certs: str | None = None
client_cert: str | None = None
client_key: str | None = None
include_packages: list[str] = dc.field(default_factory=list)
exclude_packages: list[str] = dc.field(default_factory=list)

Expand Down
4 changes: 1 addition & 3 deletions src/pdm/cli/commands/publish/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pdm.exceptions import PdmUsageError
from pdm.project import Project
from pdm.project.config import DEFAULT_REPOSITORIES
from pdm.utils import get_trusted_hosts

if TYPE_CHECKING:
from typing import Callable, Self
Expand All @@ -39,8 +38,7 @@ def __iter__(self) -> Iterable[bytes]:
class Repository:
def __init__(self, project: Project, config: RepositoryConfig) -> None:
self.url = cast(str, config.url)
trusted_hosts = get_trusted_hosts([config])
self.session = project.environment._build_session(trusted_hosts, verify=config.ca_certs)
self.session = project.environment._build_session([config])

self._credentials_to_save: tuple[str, str, str] | None = None
self.ui = project.core.ui
Expand Down
47 changes: 13 additions & 34 deletions src/pdm/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
from contextlib import contextmanager
from functools import cached_property, partial
from pathlib import Path
from typing import TYPE_CHECKING, Generator, Mapping, no_type_check
from typing import TYPE_CHECKING, Generator, no_type_check

from pdm.exceptions import BuildError, PdmUsageError
from pdm.models.in_process import get_pep508_environment, get_python_abis, get_uname, sysconfig_get_platform
from pdm.models.python import PythonInfo
from pdm.models.working_set import WorkingSet
from pdm.utils import deprecation_warning, get_trusted_hosts, is_pip_compatible_with_python
from pdm.utils import deprecation_warning, is_pip_compatible_with_python

if TYPE_CHECKING:
import unearth
from httpx import BaseTransport
from httpx._types import CertTypes, VerifyTypes

from pdm._types import RepositoryConfig
from pdm.models.session import PDMPyPIClient
Expand Down Expand Up @@ -112,40 +111,20 @@ def target_python(self) -> unearth.TargetPython:
return tp

def _build_session(
self,
trusted_hosts: list[str] | None = None,
verify: VerifyTypes | None = None,
cert: CertTypes | None = None,
mounts: Mapping[str, BaseTransport | None] | None = None,
self, sources: list[RepositoryConfig] | None = None, mounts: dict[str, BaseTransport | None] | None = None
) -> PDMPyPIClient:
from pdm.models.session import PDMPyPIClient

if trusted_hosts is None:
trusted_hosts = get_trusted_hosts(self.project.sources)

if verify is None:
verify = self.project.config.get("pypi.ca_certs")

if cert is None:
certfn = self.project.config.get("pypi.client_cert")
keyfn = self.project.config.get("pypi.client_key")
if certfn:
cert = (certfn, keyfn)

session_args = {
"cache_dir": self.project.cache("http"),
"trusted_hosts": trusted_hosts,
"timeout": self.project.config["request_timeout"],
"auth": self.auth,
}
if verify is not None:
session_args["verify"] = verify
if cert is not None:
session_args["cert"] = cert
if mounts:
session_args["mounts"] = mounts

session = PDMPyPIClient(**session_args)
if sources is None:
sources = self.project.sources

session = PDMPyPIClient(
sources=sources,
cache_dir=self.project.cache("http"),
timeout=self.project.config["request_timeout"],
auth=self.auth,
mounts=mounts,
)
self.project.core.exit_stack.callback(session.close)
return session

Expand Down
56 changes: 43 additions & 13 deletions src/pdm/models/session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import sys
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import hishel
import httpx
import msgpack
from hishel._serializers import Metadata
from httpcore import Request, Response
Expand All @@ -16,7 +18,7 @@
if TYPE_CHECKING:
from ssl import SSLContext

from httpx import Response as HTTPXResponse
from pdm._types import RepositoryConfig


def _create_truststore_ssl_context() -> SSLContext | None:
Expand All @@ -26,7 +28,6 @@ def _create_truststore_ssl_context() -> SSLContext | None:
try:
import ssl
except ImportError:
logger.warning("Disabling truststore since ssl support is missing")
return None

try:
Expand All @@ -37,9 +38,17 @@ def _create_truststore_ssl_context() -> SSLContext | None:
return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)


_ssl_context = _create_truststore_ssl_context()
CACHES_TTL = 7 * 24 * 60 * 60 # 7 days


@lru_cache(maxsize=None)
def _get_transport(
verify: bool | SSLContext | str = True, cert: tuple[str, str | None] | None = None
) -> httpx.BaseTransport:
return httpx.HTTPTransport(verify=verify, cert=cert, trust_env=True)


class MsgPackSerializer(hishel.BaseSerializer):
KNOWN_REQUEST_EXTENSIONS = ("timeout", "sni_hostname")
KNOWN_RESPONSE_EXTENSIONS = ("http_version", "reason_phrase")
Expand Down Expand Up @@ -115,22 +124,43 @@ def is_binary(self) -> bool:


class PDMPyPIClient(PyPIClient):
def __init__(self, *, cache_dir: Path, **kwargs: Any) -> None:
def __init__(self, *, sources: list[RepositoryConfig], cache_dir: Path, **kwargs: Any) -> None:
from unearth.fetchers.sync import LocalFSTransport

storage = hishel.FileStorage(serializer=MsgPackSerializer(), base_path=cache_dir, ttl=CACHES_TTL)
controller = hishel.Controller()
kwargs.setdefault("verify", _create_truststore_ssl_context() or True)
kwargs.setdefault("follow_redirects", True)

super().__init__(**kwargs)
mounts: dict[str, httpx.BaseTransport] = {"file://": LocalFSTransport()}
self._trusted_host_ports: set[tuple[str, int | None]] = set()
transport: httpx.BaseTransport | None = None
for s in sources:
if s.name == "pypi":
transport = self._transport_for(s)
continue
assert s.url is not None
url = httpx.URL(s.url)
mounts[f"{url.scheme}://{url.netloc.decode('ascii')}/"] = hishel.CacheTransport(
self._transport_for(s), storage, controller
)
self._trusted_host_ports.add((url.host, url.port))
mounts.update(kwargs.pop("mounts", None) or {})

httpx.Client.__init__(self, mounts=mounts, follow_redirects=True, transport=transport, **kwargs)

self.headers["User-Agent"] = self._make_user_agent()
self.event_hooks["response"].append(self.on_response)

self._transport = hishel.CacheTransport(self._transport, storage, controller) # type: ignore[has-type]
for name, transport in self._mounts.items():
if name.scheme == "file" or transport is None:
# don't cache file:// transport
continue
self._mounts[name] = hishel.CacheTransport(transport, storage, controller)

def _transport_for(self, source: RepositoryConfig) -> httpx.BaseTransport:
if source.ca_certs:
verify: str | bool | SSLContext = source.ca_certs
else:
verify = source.verify_ssl is not False and (_ssl_context or True)
if source.client_cert:
cert = (source.client_cert, source.client_key)
else:
cert = None
return _get_transport(verify=verify, cert=cert)

def _make_user_agent(self) -> str:
import platform
Expand All @@ -143,7 +173,7 @@ def _make_user_agent(self) -> str:
platform.release(),
)

def on_response(self, response: HTTPXResponse) -> None:
def on_response(self, response: httpx.Response) -> None:
from unearth.utils import ARCHIVE_EXTENSIONS

if response.extensions.get("from_cache") and response.url.path.endswith(ARCHIVE_EXTENSIONS):
Expand Down
3 changes: 3 additions & 0 deletions src/pdm/project/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ def default_source(self) -> RepositoryConfig:
verify_ssl=self.config["pypi.verify_ssl"],
username=self.config.get("pypi.username"),
password=self.config.get("pypi.password"),
ca_certs=self.config.get("pypi.ca_certs"),
client_cert=self.config.get("pypi.client_cert"),
client_key=self.config.get("pypi.client_key"),
)

@property
Expand Down

0 comments on commit 82bcc88

Please sign in to comment.