From afdb05c7bffc6b199388148f0497ce8a5bc77e25 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 17 Oct 2024 17:43:32 -0700 Subject: [PATCH] Add support for the `sslnegotiation` parameter (#1187) Direct TLS connections are already supported via the `direct_tls` argument, however PostgreSQL 17 added native support for this via `sslnegotiation`, so recognize it in DSNs and the environment. I decided not to introduce the `sslnegotiation` connection constructor argument for now, `direct_tls` should continue to be used instead. --- asyncpg/compat.py | 8 ++++++ asyncpg/connect_utils.py | 44 ++++++++++++++++++++++++++---- asyncpg/connection.py | 2 +- pyproject.toml | 9 ++++++ tests/test_connect.py | 59 +++++++++++++++++++++++++++++++++++++++- 5 files changed, 114 insertions(+), 8 deletions(-) diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 881873a2..57eec650 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -6,6 +6,7 @@ from __future__ import annotations +import enum import pathlib import platform import typing @@ -78,3 +79,10 @@ def markcoroutinefunction(c): # type: ignore from collections.abc import ( # noqa: F401 Awaitable as Awaitable, ) + +if sys.version_info < (3, 11): + class StrEnum(str, enum.Enum): + __str__ = str.__str__ + __repr__ = enum.Enum.__repr__ +else: + from enum import StrEnum as StrEnum # noqa: F401 diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 0631f976..4890d007 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -45,6 +45,11 @@ def parse(cls, sslmode): return getattr(cls, sslmode.replace('-', '_')) +class SSLNegotiation(compat.StrEnum): + postgres = "postgres" + direct = "direct" + + _ConnectionParameters = collections.namedtuple( 'ConnectionParameters', [ @@ -53,7 +58,7 @@ def parse(cls, sslmode): 'database', 'ssl', 'sslmode', - 'direct_tls', + 'ssl_negotiation', 'server_settings', 'target_session_attrs', 'krbsrvname', @@ -269,6 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, auth_hosts = None sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None ssl_min_protocol_version = ssl_max_protocol_version = None + sslnegotiation = None if dsn: parsed = urllib.parse.urlparse(dsn) @@ -362,6 +368,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'sslrootcert' in query: sslrootcert = query.pop('sslrootcert') + if 'sslnegotiation' in query: + sslnegotiation = query.pop('sslnegotiation') + if 'sslcrl' in query: sslcrl = query.pop('sslcrl') @@ -503,13 +512,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if ssl is None and have_tcp_addrs: ssl = 'prefer' + if direct_tls is not None: + sslneg = ( + SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres + ) + else: + if sslnegotiation is None: + sslnegotiation = os.environ.get("PGSSLNEGOTIATION") + + if sslnegotiation is not None: + try: + sslneg = SSLNegotiation(sslnegotiation) + except ValueError: + modes = ', '.join( + m.name.replace('_', '-') + for m in SSLNegotiation + ) + raise exceptions.ClientConfigurationError( + f'`sslnegotiation` parameter must be one of: {modes}' + ) from None + else: + sslneg = SSLNegotiation.postgres + if isinstance(ssl, (str, SSLMode)): try: sslmode = SSLMode.parse(ssl) except AttributeError: modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) raise exceptions.ClientConfigurationError( - '`sslmode` parameter must be one of: {}'.format(modes)) + '`sslmode` parameter must be one of: {}'.format(modes) + ) from None # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html if sslmode < SSLMode.allow: @@ -676,7 +708,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, - sslmode=sslmode, direct_tls=direct_tls, + sslmode=sslmode, ssl_negotiation=sslneg, server_settings=server_settings, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib) @@ -882,9 +914,9 @@ async def __connect_addr( # UNIX socket connector = loop.create_unix_connection(proto_factory, addr) - elif params.ssl and params.direct_tls: - # if ssl and direct_tls are given, skip STARTTLS and perform direct - # SSL connection + elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct: + # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform + # direct SSL connection connector = loop.create_connection( proto_factory, *addr, ssl=params.ssl ) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 18892cfd..6ac2a09d 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2001,7 +2001,7 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, - direct_tls=False, + direct_tls=None, connection_class=Connection, record_class=protocol.Record, server_settings=None, diff --git a/pyproject.toml b/pyproject.toml index d7a6ebcb..15c034f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,15 @@ exclude_lines = [ show_missing = true [tool.mypy] +exclude = [ + "^.eggs", + "^.github", + "^.vscode", + "^build", + "^dist", + "^docs", + "^tests", +] incremental = true strict = true implicit_reexport = true diff --git a/tests/test_connect.py b/tests/test_connect.py index 049aea26..517f05f9 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -592,6 +592,58 @@ class TestConnectParams(tb.TestCase): 'target_session_attrs': 'any'}) }, + { + 'name': 'params_ssl_negotiation_dsn', + 'env': { + 'PGSSLNEGOTIATION': 'postgres' + }, + + 'dsn': 'postgres://u:p@localhost/d?sslnegotiation=direct', + + 'result': ([('localhost', 5432)], { + 'user': 'u', + 'password': 'p', + 'database': 'd', + 'ssl_negotiation': 'direct', + 'target_session_attrs': 'any', + }) + }, + + { + 'name': 'params_ssl_negotiation_env', + 'env': { + 'PGSSLNEGOTIATION': 'direct' + }, + + 'dsn': 'postgres://u:p@localhost/d', + + 'result': ([('localhost', 5432)], { + 'user': 'u', + 'password': 'p', + 'database': 'd', + 'ssl_negotiation': 'direct', + 'target_session_attrs': 'any', + }) + }, + + { + 'name': 'params_ssl_negotiation_params', + 'env': { + 'PGSSLNEGOTIATION': 'direct' + }, + + 'dsn': 'postgres://u:p@localhost/d', + 'direct_tls': False, + + 'result': ([('localhost', 5432)], { + 'user': 'u', + 'password': 'p', + 'database': 'd', + 'ssl_negotiation': 'postgres', + 'target_session_attrs': 'any', + }) + }, + { 'name': 'dsn_overrides_env_partially_ssl_prefer', 'env': { @@ -1067,6 +1119,7 @@ def run_testcase(self, testcase): passfile = testcase.get('passfile') database = testcase.get('database') sslmode = testcase.get('ssl') + direct_tls = testcase.get('direct_tls') server_settings = testcase.get('server_settings') target_session_attrs = testcase.get('target_session_attrs') krbsrvname = testcase.get('krbsrvname') @@ -1093,7 +1146,7 @@ def run_testcase(self, testcase): addrs, params = connect_utils._parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, database=database, ssl=sslmode, - direct_tls=False, + direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib) @@ -1118,6 +1171,10 @@ def run_testcase(self, testcase): # Avoid the hassle of specifying direct_tls # unless explicitly tested for params.pop('direct_tls', False) + if 'ssl_negotiation' not in expected[1]: + # Avoid the hassle of specifying sslnegotiation + # unless explicitly tested for + params.pop('ssl_negotiation', False) if 'gsslib' not in expected[1]: # Avoid the hassle of specifying gsslib # unless explicitly tested for