Skip to content

Commit

Permalink
Add support for the sslnegotiation parameter
Browse files Browse the repository at this point in the history
Direct TLS connections are already supported via the `direct_tls`
argument, however PostgreSQL 17 added native support for this via
`sslnegotiation`, so recognize it in DSNs and the environment.  I
decided not to introduce the `sslnegotiation` connection constructor
argument for now, `direct_tls` should continue to be used instead.
  • Loading branch information
elprans committed Oct 17, 2024
1 parent 8f2be4c commit d8d0efb
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 7 deletions.
8 changes: 8 additions & 0 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import enum
import pathlib
import platform
import typing
Expand Down Expand Up @@ -78,3 +79,10 @@ def markcoroutinefunction(c): # type: ignore
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)

if sys.version_info < (3, 11):
class StrEnum(str, enum.Enum):
__str__ = str.__str__
__repr__ = enum.Enum.__repr__
else:
from enum import StrEnum as StrEnum # noqa: F401
41 changes: 36 additions & 5 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def parse(cls, sslmode):
return getattr(cls, sslmode.replace('-', '_'))


class SSLNegotiation(compat.StrEnum):
postgres = "postgres"
direct = "direct"


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
Expand All @@ -53,7 +58,7 @@ def parse(cls, sslmode):
'database',
'ssl',
'sslmode',
'direct_tls',
'ssl_negotiation',
'server_settings',
'target_session_attrs',
'krbsrvname',
Expand Down Expand Up @@ -269,6 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
auth_hosts = None
sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
ssl_min_protocol_version = ssl_max_protocol_version = None
sslnegotiation = None

if dsn:
parsed = urllib.parse.urlparse(dsn)
Expand Down Expand Up @@ -362,6 +368,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if 'sslrootcert' in query:
sslrootcert = query.pop('sslrootcert')

if 'sslnegotiation' in query:
sslnegotiation = query.pop('sslnegotiation')

if 'sslcrl' in query:
sslcrl = query.pop('sslcrl')

Expand Down Expand Up @@ -503,6 +512,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'

if direct_tls is not None:
sslneg = (
SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
)
else:
if sslnegotiation is None:
sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

if sslnegotiation is not None:
try:
sslneg = SSLNegotiation(sslnegotiation)
except ValueError:
modes = ', '.join(
m.name.replace('_', '-')
for m in SSLNegotiation
)
raise exceptions.ClientConfigurationError(
f'`sslnegotiation` parameter must be one of: {modes}'
)
else:
sslneg = SSLNegotiation.postgres

if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMode.parse(ssl)
Expand Down Expand Up @@ -676,7 +707,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
sslmode=sslmode, ssl_negotiation=sslneg,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
Expand Down Expand Up @@ -882,9 +913,9 @@ async def __connect_addr(
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)

elif params.ssl and params.direct_tls:
# if ssl and direct_tls are given, skip STARTTLS and perform direct
# SSL connection
elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
# if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
# direct SSL connection
connector = loop.create_connection(
proto_factory, *addr, ssl=params.ssl
)
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
direct_tls=False,
direct_tls=None,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ exclude_lines = [
show_missing = true

[tool.mypy]
exclude = [
"^.eggs",
"^.github",
"^.vscode",
"^build",
"^dist",
"^docs",
"^tests",
]
incremental = true
strict = true
implicit_reexport = true
Expand Down
59 changes: 58 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,58 @@ class TestConnectParams(tb.TestCase):
'target_session_attrs': 'any'})
},

{
'name': 'params_ssl_negotiation_dsn',
'env': {
'PGSSLNEGOTIATION': 'postgres'
},

'dsn': 'postgres://u:p@localhost/d?sslnegotiation=direct',

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'direct',
'target_session_attrs': 'any',
})
},

{
'name': 'params_ssl_negotiation_env',
'env': {
'PGSSLNEGOTIATION': 'direct'
},

'dsn': 'postgres://u:p@localhost/d',

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'direct',
'target_session_attrs': 'any',
})
},

{
'name': 'params_ssl_negotiation_params',
'env': {
'PGSSLNEGOTIATION': 'direct'
},

'dsn': 'postgres://u:p@localhost/d',
'direct_tls': False,

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'postgres',
'target_session_attrs': 'any',
})
},

{
'name': 'dsn_overrides_env_partially_ssl_prefer',
'env': {
Expand Down Expand Up @@ -1067,6 +1119,7 @@ def run_testcase(self, testcase):
passfile = testcase.get('passfile')
database = testcase.get('database')
sslmode = testcase.get('ssl')
direct_tls = testcase.get('direct_tls')
server_settings = testcase.get('server_settings')
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
Expand All @@ -1093,7 +1146,7 @@ def run_testcase(self, testcase):
addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=sslmode,
direct_tls=False,
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
Expand All @@ -1118,6 +1171,10 @@ def run_testcase(self, testcase):
# Avoid the hassle of specifying direct_tls
# unless explicitly tested for
params.pop('direct_tls', False)
if 'ssl_negotiation' not in expected[1]:
# Avoid the hassle of specifying sslnegotiation
# unless explicitly tested for
params.pop('ssl_negotiation', False)
if 'gsslib' not in expected[1]:
# Avoid the hassle of specifying gsslib
# unless explicitly tested for
Expand Down

0 comments on commit d8d0efb

Please sign in to comment.