From 313b2b2bedcc10baf5871124ee915fdc48f5c4b7 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Oct 2023 10:27:56 -0700 Subject: [PATCH] Use the `timeout` context manager in the connection path (#1087) Drop timeout management gymnastics from the `connect()` path and use the `timeout` context manager instead. --- asyncpg/compat.py | 6 ++++++ asyncpg/connect_utils.py | 45 +++++++++++----------------------------- asyncpg/connection.py | 43 +++++++++++++++++++------------------- tests/test_adversity.py | 17 +++++++++++++++ tests/test_connect.py | 2 +- 5 files changed, 58 insertions(+), 55 deletions(-) diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 532c197a..3eec9eb7 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -53,3 +53,9 @@ async def wait_closed(stream): from ._asyncio_compat import wait_for as wait_for # noqa: F401 else: from asyncio import wait_for as wait_for # noqa: F401 + + +if sys.version_info < (3, 11): + from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 +else: + from asyncio import timeout as timeout # noqa: F401 diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 9feef139..760e1297 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -20,7 +20,6 @@ import stat import struct import sys -import time import typing import urllib.parse import warnings @@ -55,7 +54,6 @@ def parse(cls, sslmode): 'ssl', 'sslmode', 'direct_tls', - 'connect_timeout', 'server_settings', 'target_session_attrs', ]) @@ -262,7 +260,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, - direct_tls, connect_timeout, server_settings, + direct_tls, server_settings, target_session_attrs): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. @@ -655,14 +653,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, - connect_timeout=connect_timeout, server_settings=server_settings, + server_settings=server_settings, target_session_attrs=target_session_attrs) return addrs, params def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, - database, timeout, command_timeout, + database, command_timeout, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, @@ -695,7 +693,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, - connect_timeout=timeout, server_settings=server_settings, + server_settings=server_settings, target_session_attrs=target_session_attrs) config = _ClientConfiguration( @@ -799,7 +797,6 @@ async def _connect_addr( *, addr, loop, - timeout, params, config, connection_class, @@ -807,9 +804,6 @@ async def _connect_addr( ): assert loop is not None - if timeout <= 0: - raise asyncio.TimeoutError - params_input = params if callable(params.password): password = params.password() @@ -827,21 +821,16 @@ async def _connect_addr( params_retry = params._replace(ssl=None) else: # skip retry if we don't have to - return await __connect_addr(params, timeout, False, *args) + return await __connect_addr(params, False, *args) # first attempt - before = time.monotonic() try: - return await __connect_addr(params, timeout, True, *args) + return await __connect_addr(params, True, *args) except _RetryConnectSignal: pass # second attempt - timeout -= time.monotonic() - before - if timeout <= 0: - raise asyncio.TimeoutError - else: - return await __connect_addr(params_retry, timeout, False, *args) + return await __connect_addr(params_retry, False, *args) class _RetryConnectSignal(Exception): @@ -850,7 +839,6 @@ class _RetryConnectSignal(Exception): async def __connect_addr( params, - timeout, retry, addr, loop, @@ -882,15 +870,10 @@ async def __connect_addr( else: connector = loop.create_connection(proto_factory, *addr) - connector = asyncio.ensure_future(connector) - before = time.monotonic() - tr, pr = await compat.wait_for(connector, timeout=timeout) - timeout -= time.monotonic() - before + tr, pr = await connector try: - if timeout <= 0: - raise asyncio.TimeoutError - await compat.wait_for(connected, timeout=timeout) + await connected except ( exceptions.InvalidAuthorizationSpecificationError, exceptions.ConnectionDoesNotExistError, # seen on Windows @@ -993,23 +976,21 @@ async def _can_use_connection(connection, attr: SessionAttribute): return await can_use(connection) -async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): +async def _connect(*, loop, connection_class, record_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() - addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs) + addrs, params, config = _parse_connect_arguments(**kwargs) target_attr = params.target_session_attrs candidates = [] chosen_connection = None last_error = None for addr in addrs: - before = time.monotonic() try: conn = await _connect_addr( addr=addr, loop=loop, - timeout=timeout, params=params, config=config, connection_class=connection_class, @@ -1019,10 +1000,8 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): if await _can_use_connection(conn, target_attr): chosen_connection = conn break - except (OSError, asyncio.TimeoutError, ConnectionError) as ex: + except OSError as ex: last_error = ex - finally: - timeout -= time.monotonic() - before else: if target_attr == SessionAttribute.prefer_standby and candidates: chosen_connection = random.choice(candidates) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 06e4ce23..810227c7 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -20,6 +20,7 @@ import warnings import weakref +from . import compat from . import connect_utils from . import cursor from . import exceptions @@ -2184,27 +2185,27 @@ async def connect(dsn=None, *, if loop is None: loop = asyncio.get_event_loop() - return await connect_utils._connect( - loop=loop, - timeout=timeout, - connection_class=connection_class, - record_class=record_class, - dsn=dsn, - host=host, - port=port, - user=user, - password=password, - passfile=passfile, - ssl=ssl, - direct_tls=direct_tls, - database=database, - server_settings=server_settings, - command_timeout=command_timeout, - statement_cache_size=statement_cache_size, - max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size, - target_session_attrs=target_session_attrs - ) + async with compat.timeout(timeout): + return await connect_utils._connect( + loop=loop, + connection_class=connection_class, + record_class=record_class, + dsn=dsn, + host=host, + port=port, + user=user, + password=password, + passfile=passfile, + ssl=ssl, + direct_tls=direct_tls, + database=database, + server_settings=server_settings, + command_timeout=command_timeout, + statement_cache_size=statement_cache_size, + max_cached_statement_lifetime=max_cached_statement_lifetime, + max_cacheable_statement_size=max_cacheable_statement_size, + target_session_attrs=target_session_attrs + ) class _StatementCacheEntry: diff --git a/tests/test_adversity.py b/tests/test_adversity.py index 71532317..a6e03feb 100644 --- a/tests/test_adversity.py +++ b/tests/test_adversity.py @@ -26,6 +26,23 @@ async def test_connection_close_timeout(self): with self.assertRaises(asyncio.TimeoutError): await con.close(timeout=0.5) + @tb.with_timeout(30.0) + async def test_pool_acquire_timeout(self): + pool = await self.create_pool( + database='postgres', min_size=2, max_size=2) + try: + self.proxy.trigger_connectivity_loss() + for _ in range(2): + with self.assertRaises(asyncio.TimeoutError): + async with pool.acquire(timeout=0.5): + pass + self.proxy.restore_connectivity() + async with pool.acquire(timeout=0.5): + pass + finally: + self.proxy.restore_connectivity() + pool.terminate() + @tb.with_timeout(30.0) async def test_pool_release_timeout(self): pool = await self.create_pool( diff --git a/tests/test_connect.py b/tests/test_connect.py index 171c2644..f61db61a 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -891,7 +891,7 @@ def run_testcase(self, testcase): addrs, params = connect_utils._parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, database=database, ssl=sslmode, - direct_tls=False, connect_timeout=None, + direct_tls=False, server_settings=server_settings, target_session_attrs=target_session_attrs)