Skip to content

Commit

Permalink
Moved create and verify ssl context function into the utils and inclu…
Browse files Browse the repository at this point in the history
…ded lru cache
  • Loading branch information
Can Sarigol committed May 25, 2020
1 parent 440b5ab commit 76bb86c
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 95 deletions.
116 changes: 22 additions & 94 deletions httpx/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,14 @@

from ._models import URL, Headers
from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
from ._utils import get_ca_bundle_from_env, get_logger, warn_deprecated

DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
from ._utils import (
create_and_verify_ssl_context,
get_ca_bundle_from_env,
get_logger,
load_client_certs,
warn_deprecated,
)


logger = get_logger(__name__)


Expand Down Expand Up @@ -89,11 +75,7 @@ def load_ssl_context_no_verify(self) -> ssl.SSLContext:
"""
Return an SSL context for unverified connections.
"""
context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_NONE
context.check_hostname = False
self._load_client_certs(context)
return context
return self._create_and_verify_ssl_context(ssl.CERT_NONE)

def load_ssl_context_verify(self) -> ssl.SSLContext:
"""
Expand All @@ -107,7 +89,8 @@ def load_ssl_context_verify(self) -> ssl.SSLContext:
if isinstance(self.verify, ssl.SSLContext):
# Allow passing in our own SSLContext object that's pre-configured.
context = self.verify
self._load_client_certs(context)
if self.cert is not None:
load_client_certs(context, self.cert)
return context
elif isinstance(self.verify, bool):
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
Expand All @@ -119,74 +102,19 @@ def load_ssl_context_verify(self) -> ssl.SSLContext:
"invalid path: {}".format(self.verify)
)

context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True

# Signal to server support for PHA in TLS 1.3. Raises an
# AttributeError if only read-only access is implemented.
try:
context.post_handshake_auth = True # type: ignore
except AttributeError: # pragma: nocover
pass

# Disable using 'commonName' for SSLContext.check_hostname
# when the 'subjectAltName' extension isn't available.
try:
context.hostname_checks_common_name = False # type: ignore
except AttributeError: # pragma: nocover
pass

if ca_bundle_path.is_file():
logger.trace(f"load_verify_locations cafile={ca_bundle_path!s}")
context.load_verify_locations(cafile=str(ca_bundle_path))
elif ca_bundle_path.is_dir():
logger.trace(f"load_verify_locations capath={ca_bundle_path!s}")
context.load_verify_locations(capath=str(ca_bundle_path))

self._load_client_certs(context)

return context

def _create_default_ssl_context(self) -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
context.options |= ssl.OP_NO_TLSv1_1
context.options |= ssl.OP_NO_COMPRESSION
context.set_ciphers(DEFAULT_CIPHERS)

if ssl.HAS_ALPN:
alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
context.set_alpn_protocols(alpn_idents)

if hasattr(context, "keylog_filename"): # pragma: nocover (Available in 3.8+)
keylogfile = os.environ.get("SSLKEYLOGFILE")
if keylogfile and self.trust_env:
context.keylog_filename = keylogfile # type: ignore

return context

def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
"""
Loads client certificates into our SSLContext object
"""
if self.cert is not None:
if isinstance(self.cert, str):
ssl_context.load_cert_chain(certfile=self.cert)
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
ssl_context.load_cert_chain(
certfile=self.cert[0],
keyfile=self.cert[1],
password=self.cert[2], # type: ignore
)
return self._create_and_verify_ssl_context(ssl.CERT_REQUIRED, ca_bundle_path)

def _create_and_verify_ssl_context(
self, verify_mode: int, ca_bundle_path: typing.Optional[Path] = None
) -> ssl.SSLContext:
return create_and_verify_ssl_context(
verify_mode,
self.cert,
self.http2,
self.trust_env,
os.environ.get("SSLKEYLOGFILE"),
ca_bundle_path,
)


class Timeout:
Expand Down
100 changes: 99 additions & 1 deletion httpx/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import codecs
import collections
import contextlib
import functools
import logging
import mimetypes
import netrc
import os
import re
import ssl
import sys
import typing
import warnings
Expand All @@ -16,7 +18,7 @@
from urllib.request import getproxies

from ._exceptions import NetworkError
from ._types import PrimitiveData, StrOrBytes
from ._types import CertTypes, PrimitiveData, StrOrBytes

if typing.TYPE_CHECKING: # pragma: no cover
from ._models import URL
Expand Down Expand Up @@ -405,3 +407,99 @@ def as_network_error(*exception_classes: type) -> typing.Iterator[None]:

def warn_deprecated(message: str) -> None:
warnings.warn(message, DeprecationWarning, stacklevel=2)


DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
)


@functools.lru_cache(1)
def create_and_verify_ssl_context(
verify_mode: int,
cert: CertTypes,
http2: bool,
trust_env: bool,
keylogfile: typing.Optional[str] = None,
ca_bundle_path: typing.Optional[Path] = None,
) -> ssl.SSLContext:
def create_default_ssl_context() -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
"""
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
context.options |= ssl.OP_NO_TLSv1_1
context.options |= ssl.OP_NO_COMPRESSION
context.set_ciphers(DEFAULT_CIPHERS)

if ssl.HAS_ALPN:
alpn_idents = ["http/1.1", "h2"] if http2 else ["http/1.1"]
context.set_alpn_protocols(alpn_idents)

if (
hasattr(context, "keylog_filename") and trust_env and keylogfile
): # pragma: nocover (Available in 3.8+)
context.keylog_filename = keylogfile # type: ignore

return context

context = create_default_ssl_context()
context.verify_mode = verify_mode
context.check_hostname = verify_mode != ssl.CERT_NONE

if context.check_hostname and ca_bundle_path is not None:
# Signal to server support for PHA in TLS 1.3. Raises an
# AttributeError if only read-only access is implemented.
try:
context.post_handshake_auth = True # type: ignore
except AttributeError: # pragma: nocover
pass

# Disable using 'commonName' for SSLContext.check_hostname
# when the 'subjectAltName' extension isn't available.
try:
context.hostname_checks_common_name = False # type: ignore
except AttributeError: # pragma: nocover
pass

if ca_bundle_path.is_file():
context.load_verify_locations(cafile=str(ca_bundle_path))
elif ca_bundle_path.is_dir():
context.load_verify_locations(capath=str(ca_bundle_path))

load_client_certs(context, cert)
return context


def load_client_certs(ssl_context: ssl.SSLContext, cert: CertTypes) -> None:
"""
Loads client certificates into our SSLContext object
"""
if cert is not None:
if isinstance(cert, str):
ssl_context.load_cert_chain(certfile=cert)
elif isinstance(cert, tuple) and len(cert) == 2:
ssl_context.load_cert_chain(certfile=cert[0], keyfile=cert[1])
elif isinstance(cert, tuple) and len(cert) == 3:
ssl_context.load_cert_chain(
certfile=cert[0], keyfile=cert[1], password=cert[2], # type: ignore
)
34 changes: 34 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,40 @@ def test_ssl_eq():
assert ssl == SSLConfig(verify=False)


def test_ssl_context_cache():
ssl1 = SSLConfig()
ssl2 = SSLConfig()
assert id(ssl1.ssl_context) == id(ssl2.ssl_context)

ssl3 = SSLConfig(http2=True)
assert id(ssl1.ssl_context) != id(ssl3.ssl_context)

ssl4 = SSLConfig(http2=True, trust_env=None)
assert id(ssl3.ssl_context) == id(ssl4.ssl_context)

ssl5 = SSLConfig(http2=True, trust_env=True)
assert id(ssl4.ssl_context) != id(ssl5.ssl_context)


def test_ssl_context_cache_for_path_param(
cert_pem_file, cert_private_key_file, cert_encrypted_private_key_file
):
ssl1 = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
ssl2 = SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password"))
ssl3 = SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, b"password"))

assert id(ssl1.ssl_context) != id(ssl2.ssl_context) != id(ssl3.ssl_context)


def test_ssl_context_cache_for_cert_param():
path = Path(certifi.where()).parent
ssl1 = SSLConfig()
ssl2 = SSLConfig(verify=path)
ssl3 = SSLConfig(verify="/")

assert id(ssl1.ssl_context) != id(ssl2.ssl_context) != id(ssl3.ssl_context)


def test_limits_repr():
limits = httpx.PoolLimits(max_connections=100)
assert repr(limits) == "PoolLimits(max_keepalive=None, max_connections=100)"
Expand Down

0 comments on commit 76bb86c

Please sign in to comment.