diff --git a/docs/api/asyncio_con.rst b/docs/api/asyncio_con.rst index 90f42224..eaa5ffdb 100644 --- a/docs/api/asyncio_con.rst +++ b/docs/api/asyncio_con.rst @@ -44,7 +44,6 @@ Client connection pool .. py:function:: create_async_client(dsn=None, *, \ host=None, port=None, \ - admin=None, \ user=None, password=None, \ database=None, \ timeout=60, \ @@ -100,9 +99,6 @@ Client connection pool or the value of the ``EDGEDB_PORT`` environment variable, or ``5656`` if neither is specified. - :param admin: - If ``True``, try to connect to the special administration socket. - :param user: The name of the database role used for authentication. diff --git a/edgedb/__init__.py b/edgedb/__init__.py index 7c0eeb98..e082fb8a 100644 --- a/edgedb/__init__.py +++ b/edgedb/__init__.py @@ -26,15 +26,16 @@ ) from edgedb.datatypes.datatypes import Set, Object, Array, Link, LinkSet -from .abstract import Executor, AsyncIOExecutor +from .abstract import ( + Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor +) -from .asyncio_con import async_connect_raw, AsyncIOConnection -from .asyncio_pool import ( +from .asyncio_client import ( create_async_client, AsyncIOClient ) -from .blocking_con import connect, BlockingIOConnection +from .blocking_client import create_client, Client from .options import RetryCondition, IsolationLevel, default_backoff from .options import RetryOptions, TransactionOptions diff --git a/edgedb/_testbase.py b/edgedb/_testbase.py index 7b94054c..aeac40c6 100644 --- a/edgedb/_testbase.py +++ b/edgedb/_testbase.py @@ -33,6 +33,8 @@ import unittest import edgedb +from edgedb import asyncio_client +from edgedb import blocking_client log = logging.getLogger(__name__) @@ -163,10 +165,11 @@ def _start_cluster(*, cleanup_atexit=True): else: con_args['tls_ca_file'] = data['tls_cert_file'] - con = edgedb.connect(password='test', **con_args) + client = edgedb.create_client(password='test', **con_args) + client.ensure_connected() _default_cluster = { 'proc': p, - 'con': con, + 'client': client, 'con_args': con_args, } @@ -174,7 +177,7 @@ def _start_cluster(*, cleanup_atexit=True): # Keep the temp dir which we also copied the cert from WSL _default_cluster['_tmpdir'] = tmpdir - atexit.register(con.close) + atexit.register(client.close) except Exception as e: _default_cluster = e raise e @@ -225,7 +228,7 @@ def wrapper(self, *args, __meth__=meth, **kwargs): if try_no == 3: raise else: - self.loop.run_until_complete(self.con.execute( + self.loop.run_until_complete(self.client.execute( 'ROLLBACK;' )) try_no += 1 @@ -319,17 +322,43 @@ def setUpClass(cls): cls.cluster = _start_cluster(cleanup_atexit=True) +class TestAsyncIOClient(edgedb.AsyncIOClient): + def _clear_codecs_cache(self): + self._impl.codecs_registry.clear_cache() + + @property + def connection(self): + return self._impl._holders[0]._con + + @property + def dbname(self): + return self._impl._working_params.database + + +class TestClient(edgedb.Client): + @property + def connection(self): + return self._impl._holders[0]._con + + class ConnectedTestCaseMixin: @classmethod - async def connect(cls, *, - cluster=None, - database='edgedb', - user='edgedb', - password='test'): + def test_client( + cls, *, + cluster=None, + database='edgedb', + user='edgedb', + password='test', + connection_class=asyncio_client.AsyncIOConnection, + ): conargs = cls.get_connect_args( cluster=cluster, database=database, user=user, password=password) - return await edgedb.async_connect_raw(**conargs) + return TestAsyncIOClient( + connection_class=connection_class, + max_concurrency=1, + **conargs, + ) @classmethod def get_connect_args(cls, *, @@ -358,16 +387,17 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin): INTERNAL_TESTMODE = True BASE_TEST_CLASS = True + TEARDOWN_RETRY_DROP_DB = 1 def setUp(self): if self.INTERNAL_TESTMODE: self.loop.run_until_complete( - self.con.execute( + self.client.execute( 'CONFIGURE SESSION SET __internal_testmode := true;')) if self.SETUP_METHOD: self.loop.run_until_complete( - self.con.execute(self.SETUP_METHOD)) + self.client.execute(self.SETUP_METHOD)) super().setUp() @@ -375,16 +405,16 @@ def tearDown(self): try: if self.TEARDOWN_METHOD: self.loop.run_until_complete( - self.con.execute(self.TEARDOWN_METHOD)) + self.client.execute(self.TEARDOWN_METHOD)) finally: try: - if self.con.is_in_transaction(): + if self.client.connection.is_in_transaction(): raise AssertionError( 'test connection is still in transaction ' '*after* the test') self.loop.run_until_complete( - self.con.execute('RESET ALIAS *;')) + self.client.execute('RESET ALIAS *;')) finally: super().tearDown() @@ -394,18 +424,17 @@ def setUpClass(cls): super().setUpClass() dbname = cls.get_database_name() - cls.admin_conn = None - cls.con = None + cls.admin_client = None class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') # Only open an extra admin connection if necessary. if not class_set_up: script = f'CREATE DATABASE {dbname};' - cls.admin_conn = cls.loop.run_until_complete(cls.connect()) - cls.loop.run_until_complete(cls.admin_conn.execute(script)) + cls.admin_client = cls.test_client() + cls.loop.run_until_complete(cls.admin_client.execute(script)) - cls.con = cls.loop.run_until_complete(cls.connect(database=dbname)) + cls.client = cls.test_client(database=dbname) if not class_set_up: script = cls.get_setup_script() @@ -413,7 +442,7 @@ def setUpClass(cls): # The setup is expected to contain a CREATE MIGRATION, # which needs to be wrapped in a transaction. async def execute(): - async for tr in cls.con.transaction(): + async for tr in cls.client.transaction(): async with tr: await tr.execute(script) cls.loop.run_until_complete(execute()) @@ -482,17 +511,27 @@ def tearDownClass(cls): try: if script: cls.loop.run_until_complete( - cls.con.execute(script)) + cls.client.execute(script)) finally: try: - cls.loop.run_until_complete(cls.con.aclose()) + cls.loop.run_until_complete(cls.client.aclose()) if not class_set_up: dbname = cls.get_database_name() script = f'DROP DATABASE {dbname};' - cls.loop.run_until_complete( - cls.admin_conn.execute(script)) + retry = cls.TEARDOWN_RETRY_DROP_DB + for i in range(retry): + try: + cls.loop.run_until_complete( + cls.admin_client.execute(script)) + except edgedb.errors.ExecutionError: + if i < retry - 1: + time.sleep(0.1) + else: + raise + except edgedb.errors.UnknownDatabaseError: + break except Exception: log.exception('error running teardown') @@ -500,9 +539,9 @@ def tearDownClass(cls): # of finalizer error finally: try: - if cls.admin_conn is not None: + if cls.admin_client is not None: cls.loop.run_until_complete( - cls.admin_conn.aclose()) + cls.admin_client.aclose()) finally: super().tearDownClass() @@ -513,23 +552,28 @@ class AsyncQueryTestCase(DatabaseTestCase): class SyncQueryTestCase(DatabaseTestCase): BASE_TEST_CLASS = True + TEARDOWN_RETRY_DROP_DB = 5 def setUp(self): super().setUp() cls = type(self) - cls.async_con = cls.con + cls.async_client = cls.client conargs = cls.get_connect_args().copy() - conargs.update(dict(database=cls.async_con.dbname)) + conargs.update(dict(database=cls.async_client.dbname)) - cls.con = edgedb.connect(**conargs) + cls.client = TestClient( + connection_class=blocking_client.BlockingIOConnection, + max_concurrency=1, + **conargs + ) def tearDown(self): cls = type(self) - cls.con.close() - cls.con = cls.async_con - del cls.async_con + cls.client.close() + cls.client = cls.async_client + del cls.async_client _lock_cnt = 0 diff --git a/edgedb/abstract.py b/edgedb/abstract.py index cea8036d..0eb7d1c7 100644 --- a/edgedb/abstract.py +++ b/edgedb/abstract.py @@ -1,42 +1,147 @@ import abc import typing +from . import options from .datatypes import datatypes +from .protocol import protocol + +__all__ = ( + "QueryWithArgs", + "QueryCache", + "QueryOptions", + "QueryContext", + "Executor", + "AsyncIOExecutor", + "ReadOnlyExecutor", + "AsyncIOReadOnlyExecutor", +) + + +class QueryWithArgs(typing.NamedTuple): + query: str + args: typing.Tuple + kwargs: typing.Dict[str, typing.Any] + + +class QueryCache(typing.NamedTuple): + codecs_registry: protocol.CodecsRegistry + query_cache: protocol.QueryCodecsCache + + +class QueryOptions(typing.NamedTuple): + io_format: protocol.IoFormat + expect_one: bool + required_one: bool + + +class QueryContext(typing.NamedTuple): + query: QueryWithArgs + cache: QueryCache + query_options: QueryOptions + retry_options: typing.Optional[options.RetryOptions] + + +_query_opts = QueryOptions( + io_format=protocol.IoFormat.BINARY, + expect_one=False, + required_one=False, +) +_query_single_opts = QueryOptions( + io_format=protocol.IoFormat.BINARY, + expect_one=True, + required_one=False, +) +_query_required_single_opts = QueryOptions( + io_format=protocol.IoFormat.BINARY, + expect_one=True, + required_one=True, +) +_query_json_opts = QueryOptions( + io_format=protocol.IoFormat.JSON, + expect_one=False, + required_one=False, +) +_query_single_json_opts = QueryOptions( + io_format=protocol.IoFormat.JSON, + expect_one=True, + required_one=False, +) +_query_required_single_json_opts = QueryOptions( + io_format=protocol.IoFormat.JSON, + expect_one=True, + required_one=True, +) + + +class BaseReadOnlyExecutor(abc.ABC): + __slots__ = () + @abc.abstractmethod + def _get_query_cache(self) -> QueryCache: + ... -__all__ = ('Executor', 'AsyncIOExecutor') + def _get_retry_options(self) -> typing.Optional[options.RetryOptions]: + return None -class ReadOnlyExecutor(abc.ABC): +class ReadOnlyExecutor(BaseReadOnlyExecutor): """Subclasses can execute *at least* read-only queries""" __slots__ = () @abc.abstractmethod - def query(self, query: str, *args, **kwargs) -> datatypes.Set: + def _query(self, query_context: QueryContext): ... - @abc.abstractmethod + def query(self, query: str, *args, **kwargs) -> datatypes.Set: + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + )) + def query_single( self, query: str, *args, **kwargs ) -> typing.Union[typing.Any, None]: - ... + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: - ... + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod def query_json(self, query: str, *args, **kwargs) -> str: - ... + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod def query_single_json(self, query: str, *args, **kwargs) -> str: - ... + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod def query_required_single_json(self, query: str, *args, **kwargs) -> str: - ... + return self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + )) # TODO(tailhook) add *args, **kwargs, when they are supported @abc.abstractmethod @@ -50,44 +155,72 @@ class Executor(ReadOnlyExecutor): __slots__ = () -class AsyncIOReadOnlyExecutor(abc.ABC): +class AsyncIOReadOnlyExecutor(BaseReadOnlyExecutor): """Subclasses can execute *at least* read-only queries""" __slots__ = () @abc.abstractmethod - async def query(self, query: str, *args, **kwargs) -> datatypes.Set: + async def _query(self, query_context: QueryContext): ... - @abc.abstractmethod + async def query(self, query: str, *args, **kwargs) -> datatypes.Set: + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_opts, + retry_options=self._get_retry_options(), + )) + async def query_single(self, query: str, *args, **kwargs) -> typing.Any: - ... + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod async def query_required_single( self, query: str, *args, **kwargs ) -> typing.Any: - ... + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod async def query_json(self, query: str, *args, **kwargs) -> str: - ... + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_json_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod async def query_single_json(self, query: str, *args, **kwargs) -> str: - ... + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_single_json_opts, + retry_options=self._get_retry_options(), + )) - @abc.abstractmethod async def query_required_single_json( self, query: str, *args, **kwargs ) -> str: - ... + return await self._query(QueryContext( + query=QueryWithArgs(query, args, kwargs), + cache=self._get_query_cache(), + query_options=_query_required_single_json_opts, + retry_options=self._get_retry_options(), + )) # TODO(tailhook) add *args, **kwargs, when they are supported @abc.abstractmethod diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py new file mode 100644 index 00000000..1413f890 --- /dev/null +++ b/edgedb/asyncio_client.py @@ -0,0 +1,396 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import asyncio +import functools +import logging +import socket +import ssl +import typing + +from . import abstract +from . import base_client +from . import compat +from . import con_utils +from . import errors +from . import transaction +from .protocol import asyncio_proto + + +__all__ = ( + 'create_async_client', 'AsyncIOClient' +) + + +logger = logging.getLogger(__name__) + + +class AsyncIOConnection(base_client.BaseConnection): + __slots__ = ("_loop",) + _close_exceptions = (Exception, asyncio.CancelledError) + + def __init__(self, loop, *args, **kwargs): + super().__init__(*args, **kwargs) + self._loop = loop + + def is_closed(self): + protocol = self._protocol + return protocol is None or not protocol.connected + + async def connect_addr(self, addr, timeout): + try: + await compat.wait_for(self._connect_addr(addr), timeout) + except asyncio.TimeoutError as e: + raise TimeoutError from e + + async def sleep(self, seconds): + await asyncio.sleep(seconds) + + def _protocol_factory(self, tls_compat=False): + return asyncio_proto.AsyncIOProtocol( + self._params, self._loop, tls_compat=tls_compat + ) + + async def _connect_addr(self, addr): + tr = None + + try: + if isinstance(addr, str): + # UNIX socket + tr, pr = await self._loop.create_unix_connection( + self._protocol_factory, addr + ) + else: + try: + tr, pr = await self._loop.create_connection( + self._protocol_factory, *addr, ssl=self._params.ssl_ctx + ) + except ssl.CertificateError as e: + raise con_utils.wrap_error(e) from e + except ssl.SSLError as e: + if e.reason == 'CERTIFICATE_VERIFY_FAILED': + raise con_utils.wrap_error(e) from e + tr, pr = await self._loop.create_connection( + functools.partial( + self._protocol_factory, tls_compat=True + ), + *addr, + ) + else: + con_utils.check_alpn_protocol( + tr.get_extra_info('ssl_object') + ) + except socket.gaierror as e: + # All name resolution errors are considered temporary + raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e + except OSError as e: + raise con_utils.wrap_error(e) from e + except Exception: + if tr is not None: + tr.close() + raise + + pr.set_connection(self) + + try: + await pr.connect() + except OSError as e: + if tr is not None: + tr.close() + raise con_utils.wrap_error(e) from e + except Exception: + if tr is not None: + tr.close() + raise + + self._protocol = pr + self._addr = addr + + def _dispatch_log_message(self, msg): + for cb in self._log_listeners: + self._loop.call_soon(cb, self, msg) + + +class _PoolConnectionHolder(base_client.PoolConnectionHolder): + __slots__ = () + _event_class = asyncio.Event + + async def close(self, *, wait=True): + if self._con is None: + return + if wait: + await self._con.close() + else: + self._pool._loop.create_task(self._con.close()) + + async def wait_until_released(self, timeout=None): + await self._release_event.wait() + + +class _AsyncIOPoolImpl(base_client.BasePoolImpl): + __slots__ = ('_loop',) + _holder_class = _PoolConnectionHolder + + def __init__( + self, + connect_args, + *, + max_concurrency: typing.Optional[int], + connection_class, + ): + if not issubclass(connection_class, AsyncIOConnection): + raise TypeError( + f'connection_class is expected to be a subclass of ' + f'edgedb.asyncio_client.AsyncIOConnection, ' + f'got {connection_class}') + self._loop = None + super().__init__( + connect_args, + lambda *args: connection_class(self._loop, *args), + max_concurrency=max_concurrency, + ) + + def _ensure_initialized(self): + if self._loop is None: + self._loop = asyncio.get_event_loop() + self._queue = asyncio.LifoQueue(maxsize=self._max_concurrency) + self._first_connect_lock = asyncio.Lock() + self._resize_holder_pool() + + def _set_queue_maxsize(self, maxsize): + self._queue._maxsize = maxsize + + async def _maybe_get_first_connection(self): + async with self._first_connect_lock: + if self._working_addr is None: + return await self._get_first_connection() + + async def acquire(self, timeout=None): + self._ensure_initialized() + + async def _acquire_impl(): + ch = await self._queue.get() # type: _PoolConnectionHolder + try: + proxy = await ch.acquire() # type: AsyncIOConnection + except (Exception, asyncio.CancelledError): + self._queue.put_nowait(ch) + raise + else: + # Record the timeout, as we will apply it by default + # in release(). + ch._timeout = timeout + return proxy + + if self._closing: + raise errors.InterfaceError('pool is closing') + + if timeout is None: + return await _acquire_impl() + else: + return await compat.wait_for( + _acquire_impl(), timeout=timeout) + + async def _release(self, holder): + + if not isinstance(holder._con, AsyncIOConnection): + raise errors.InterfaceError( + f'release() received invalid connection: ' + f'{holder._con!r} does not belong to any connection pool' + ) + + timeout = None + + # Use asyncio.shield() to guarantee that task cancellation + # does not prevent the connection from being returned to the + # pool properly. + return await asyncio.shield(holder.release(timeout)) + + async def aclose(self): + """Attempt to gracefully close all connections in the pool. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + _AsyncIOPoolImpl.terminate() . + + It is advisable to use :func:`python:asyncio.wait_for` to set + a timeout. + """ + if self._closed: + return + + if not self._loop: + self._closed = True + return + + self._closing = True + + try: + warning_callback = self._loop.call_later( + 60, self._warn_on_long_close) + + release_coros = [ + ch.wait_until_released() for ch in self._holders] + await asyncio.gather(*release_coros) + + close_coros = [ + ch.close() for ch in self._holders] + await asyncio.gather(*close_coros) + + except (Exception, asyncio.CancelledError): + self.terminate() + raise + + finally: + warning_callback.cancel() + self._closed = True + self._closing = False + + def _warn_on_long_close(self): + logger.warning( + 'AsyncIOClient.aclose() is taking over 60 seconds to complete. ' + 'Check if you have any unreleased connections left. ' + 'Use asyncio.wait_for() to set a timeout for ' + 'AsyncIOClient.aclose().') + + +class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor): + + __slots__ = ("_managed",) + + def __init__(self, retry, client, iteration): + super().__init__(retry, client, iteration) + self._managed = False + + async def __aenter__(self): + if self._managed: + raise errors.InterfaceError( + 'cannot enter context: already in an `async with` block') + self._managed = True + return self + + async def __aexit__(self, extype, ex, tb): + self._managed = False + return await self._exit(extype, ex) + + async def _ensure_transaction(self): + if not self._managed: + raise errors.InterfaceError( + "Only managed retriable transactions are supported. " + "Use `async with transaction:`" + ) + await super()._ensure_transaction() + + +class AsyncIORetry(transaction.BaseRetry): + + def __aiter__(self): + return self + + async def __anext__(self): + # Note: when changing this code consider also + # updating Retry.__next__. + if self._done: + raise StopAsyncIteration + if self._next_backoff: + await asyncio.sleep(self._next_backoff) + self._done = True + iteration = AsyncIOIteration(self, self._owner, self._iteration) + self._iteration += 1 + return iteration + + +class AsyncIOClient(base_client.BaseClient, abstract.AsyncIOExecutor): + """A lazy connection pool. + + A Client can be used to manage a set of connections to the database. + Connections are first acquired from the pool, then used, and then released + back to the pool. Once a connection is released, it's reset to close all + open cursors and other resources *except* prepared statements. + + Clients are created by calling + :func:`~edgedb.asyncio_client.create_async_client`. + """ + + __slots__ = () + _impl_class = _AsyncIOPoolImpl + + async def ensure_connected(self): + await self._impl.ensure_connected() + return self + + async def aclose(self): + """Attempt to gracefully close all connections in the pool. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``aclose()`` the pool will terminate by calling + AsyncIOClient.terminate() . + + It is advisable to use :func:`python:asyncio.wait_for` to set + a timeout. + """ + await self._impl.aclose() + + def transaction(self) -> AsyncIORetry: + return AsyncIORetry(self) + + async def __aenter__(self): + return await self.ensure_connected() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + + +def create_async_client( + dsn=None, + *, + max_concurrency=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + database: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + wait_until_available: int = 30, + timeout: int = 10, +): + return AsyncIOClient( + connection_class=AsyncIOConnection, + max_concurrency=max_concurrency, + + # connect arguments + dsn=dsn, + host=host, + port=port, + credentials=credentials, + credentials_file=credentials_file, + user=user, + password=password, + database=database, + tls_ca=tls_ca, + tls_ca_file=tls_ca_file, + tls_security=tls_security, + wait_until_available=wait_until_available, + timeout=timeout, + ) diff --git a/edgedb/asyncio_con.py b/edgedb/asyncio_con.py deleted file mode 100644 index bf4c0ea0..00000000 --- a/edgedb/asyncio_con.py +++ /dev/null @@ -1,583 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import asyncio -import functools -import random -import socket -import ssl -import time -import typing - -from . import abstract -from . import base_con -from . import compat -from . import con_utils -from . import errors -from . import enums -from . import options -from . import retry as _retry - -from .datatypes import datatypes -from .protocol import asyncio_proto -from .protocol import protocol -from .protocol.protocol import CodecsRegistry as _CodecsRegistry -from .protocol.protocol import QueryCodecsCache as _QueryCodecsCache - - -class _AsyncIOConnectionImpl: - - def __init__(self, codecs_registry, query_cache): - self._addr = None - self._transport = None - self._protocol = None - self._codecs_registry = codecs_registry - self._query_cache = query_cache - - def is_closed(self): - protocol = self._protocol - return protocol is None or not protocol.connected - - async def connect(self, loop, addrs, config, params, *, - single_attempt=False, connection): - addr = None - start = time.monotonic() - if single_attempt: - max_time = 0 - else: - max_time = start + config.wait_until_available - iteration = 1 - - while True: - for addr in addrs: - try: - await compat.wait_for( - self._connect_addr(loop, addr, params, connection), - config.connect_timeout, - ) - except asyncio.TimeoutError as e: - if iteration == 1 or time.monotonic() < max_time: - continue - else: - raise errors.ClientConnectionTimeoutError( - f"connecting to {addr} failed in" - f" {config.connect_timeout} sec" - ) from e - except errors.ClientConnectionError as e: - if ( - e.has_tag(errors.SHOULD_RECONNECT) and - (iteration == 1 or time.monotonic() < max_time) - ): - continue - nice_err = e.__class__( - con_utils.render_client_no_connection_error( - e, - addr, - attempts=iteration, - duration=time.monotonic() - start, - )) - raise nice_err from e.__cause__ - else: - return - - iteration += 1 - await asyncio.sleep(0.01 + random.random() * 0.2) - - async def _connect_addr(self, loop, addr, params, connection): - - factory = functools.partial( - asyncio_proto.AsyncIOProtocol, params, loop - ) - tr = None - - try: - if isinstance(addr, str): - # UNIX socket - tr, pr = await loop.create_unix_connection(factory, addr) - else: - try: - tr, pr = await loop.create_connection( - factory, *addr, ssl=params.ssl_ctx - ) - except ssl.CertificateError as e: - raise con_utils.wrap_error(e) from e - except ssl.SSLError as e: - if e.reason == 'CERTIFICATE_VERIFY_FAILED': - raise con_utils.wrap_error(e) from e - tr, pr = await loop.create_connection( - functools.partial(factory, tls_compat=True), *addr - ) - else: - con_utils.check_alpn_protocol( - tr.get_extra_info('ssl_object') - ) - except socket.gaierror as e: - # All name resolution errors are considered temporary - raise errors.ClientConnectionFailedTemporarilyError(str(e)) from e - except OSError as e: - raise con_utils.wrap_error(e) from e - except Exception: - if tr is not None: - tr.close() - raise - - pr.set_connection(connection._inner) - - try: - await pr.connect() - except OSError as e: - if tr is not None: - tr.close() - raise con_utils.wrap_error(e) from e - except Exception: - if tr is not None: - tr.close() - raise - - self._transport = tr - self._protocol = pr - self._addr = addr - - async def privileged_execute(self, query): - await self._protocol.simple_query(query, enums.Capability.ALL) - - async def aclose(self): - """Send graceful termination message wait for connection to drop.""" - if not self.is_closed(): - try: - self._protocol.terminate() - await self._protocol.wait_for_disconnect() - except (Exception, asyncio.CancelledError): - self.terminate() - raise - - def terminate(self): - if not self.is_closed(): - self._protocol.abort() - - -class _AsyncIOInnerConnection(base_con._InnerConnection): - - def __init__(self, loop, addrs, config, params, *, - codecs_registry=None, query_cache=None): - super().__init__( - addrs, config, params, - codecs_registry=codecs_registry, query_cache=query_cache) - self._loop = loop - - def _detach(self): - impl = self._impl - self._impl = None - new_conn = self.__class__( - self._loop, self._addrs, self._config, self._params, - codecs_registry=self._codecs_registry, - query_cache=self._query_cache) - new_conn._impl = impl - impl._protocol.set_connection(new_conn) - return new_conn - - def _dispatch_log_message(self, msg): - for cb in self._log_listeners: - self._loop.call_soon(cb, self, msg) - - -class AsyncIOConnection( - base_con.BaseConnection, - abstract.AsyncIOExecutor, - options._OptionsMixin, -): - - def __init__(self, loop, addrs, config, params, *, - codecs_registry, query_cache): - self._inner = _AsyncIOInnerConnection( - loop, addrs, config, params, - codecs_registry=codecs_registry, - query_cache=query_cache) - super().__init__() - - def _shallow_clone(self): - if self._inner._borrowed_for: - raise base_con.borrow_error(self._inner._borrowed_for) - new_conn = self.__class__.__new__(self.__class__) - new_conn._inner = self._inner - return new_conn - - def __repr__(self): - if self.is_closed(): - return '<{classname} [closed] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) - else: - return '<{classname} [connected to {addr}] {id:#x}>'.format( - classname=self.__class__.__name__, - addr=self.connected_addr(), - id=id(self)) - - async def ensure_connected(self, *, single_attempt=False): - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect(single_attempt=single_attempt) - - # overriden by connection pool - async def _reconnect(self, single_attempt=False): - inner = self._inner - impl = _AsyncIOConnectionImpl( - inner._codecs_registry, inner._query_cache) - await impl.connect(inner._loop, inner._addrs, - inner._config, inner._params, - single_attempt=single_attempt, - connection=self) - inner._impl = impl - - async def _fetchall( - self, - query: str, - *args, - __limit__: int=0, - __typeids__: bool=False, - __typenames__: bool=False, - __allow_capabilities__: typing.Optional[int]=None, - **kwargs, - ) -> datatypes.Set: - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - result, _ = await inner._impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - implicit_limit=__limit__, - inline_typeids=__typeids__, - inline_typenames=__typenames__, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=__allow_capabilities__, - ) - return result - - async def _fetchall_with_headers( - self, - query: str, - *args, - __limit__: int=0, - __typeids__: bool=False, - __typenames__: bool=False, - __allow_capabilities__: typing.Optional[int]=None, - **kwargs, - ) -> typing.Tuple[datatypes.Set, typing.Dict[int, bytes]]: - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - return await inner._impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - implicit_limit=__limit__, - inline_typeids=__typeids__, - inline_typenames=__typenames__, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=__allow_capabilities__, - ) - - async def _fetchall_json( - self, - query: str, - *args, - __limit__: int=0, - **kwargs, - ) -> datatypes.Set: - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - result, _ = await inner._impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - implicit_limit=__limit__, - inline_typenames=False, - io_format=protocol.IoFormat.JSON, - ) - return result - - async def _execute( - self, - query: str, - args, - kwargs, - io_format, - expect_one=False, - required_one=False, - ): - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - - reconnect = False - capabilities = None - i = 0 - while True: - i += 1 - try: - if reconnect: - await self._reconnect(single_attempt=True) - result, _ = \ - await self._inner._impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - io_format=io_format, - expect_one=expect_one, - required_one=required_one, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result - except errors.EdgeDBError as e: - if not e.has_tag(errors.SHOULD_RETRY): - raise e - if capabilities is None: - cache_item = inner._query_cache.get( - query=query, - io_format=io_format, - implicit_limit=0, - inline_typenames=False, - inline_typeids=False, - expect_one=expect_one, - ) - if cache_item is not None: - _, _, _, capabilities = cache_item - # A query is read-only if it has no capabilities i.e. - # capabilities == 0. Read-only queries are safe to retry. - # Explicit transaction conflicts as well. - if ( - capabilities != 0 - and not isinstance(e, errors.TransactionConflictError) - ): - raise e - rule = self._options.retry_options.get_rule_for_exception(e) - if i >= rule.attempts: - raise e - await asyncio.sleep(rule.backoff(i)) - reconnect = self.is_closed() - - async def query(self, query: str, *args, **kwargs) -> datatypes.Set: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.BINARY, - ) - - async def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - expect_one=True, - io_format=protocol.IoFormat.BINARY, - ) - - async def query_required_single( - self, query: str, *args, **kwargs - ) -> typing.Any: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.BINARY, - ) - - async def query_json(self, query: str, *args, **kwargs) -> str: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.JSON, - ) - - async def _fetchall_json_elements( - self, query: str, *args, **kwargs) -> typing.List[str]: - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - result, _ = await inner._impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - io_format=protocol.IoFormat.JSON_ELEMENTS, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result - - async def query_single_json(self, query: str, *args, **kwargs) -> str: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.JSON, - expect_one=True, - ) - - async def query_required_single_json( - self, query: str, *args, **kwargs - ) -> str: - return await self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.JSON, - expect_one=True, - required_one=True - ) - - async def execute(self, query: str) -> None: - """Execute an EdgeQL command (or commands). - - Example: - - .. code-block:: pycon - - >>> await con.execute(''' - ... CREATE TYPE MyType { CREATE PROPERTY a -> int64 }; - ... FOR x IN {100, 200, 300} UNION INSERT MyType { a := x }; - ... ''') - """ - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - await self._reconnect() - await inner._impl._protocol.simple_query( - query, enums.Capability.EXECUTE) - - def transaction(self) -> _retry.AsyncIORetry: - return _retry.AsyncIORetry(self) - - async def aclose(self) -> None: - try: - await self._inner._impl.aclose() - finally: - self._cleanup() - - def terminate(self) -> None: - try: - self._inner._impl.terminate() - finally: - self._cleanup() - - def _set_proxy(self, proxy): - if self._proxy is not None and proxy is not None: - # Should not happen unless there is a bug in `Pool`. - raise errors.InterfaceError( - 'internal client error: connection is already proxied') - - self._proxy = proxy - - def is_closed(self) -> bool: - return self._inner._impl.is_closed() - - -async def async_connect_raw( - dsn: str = None, - *, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - database: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, - connection_class=None, - wait_until_available: int = 30, - timeout: int = 10, -) -> AsyncIOConnection: - - loop = asyncio.get_event_loop() - - if connection_class is None: - connection_class = AsyncIOConnection - - connect_config, client_config = con_utils.parse_connect_arguments( - dsn=dsn, - host=host, - port=port, - credentials=credentials, - credentials_file=credentials_file, - user=user, - password=password, - database=database, - timeout=timeout, - tls_ca=tls_ca, - tls_ca_file=tls_ca_file, - tls_security=tls_security, - wait_until_available=wait_until_available, - - # ToDos - command_timeout=None, - server_settings=None, - ) - - connection = connection_class( - loop, [connect_config.address], client_config, connect_config, - codecs_registry=_CodecsRegistry(), - query_cache=_QueryCodecsCache(), - ) - await connection.ensure_connected() - return connection - - -async def _connect_addr(loop, addrs, config, params, - query_cache, codecs_registry, connection_class): - - if connection_class is None: - connection_class = AsyncIOConnection - - connection = connection_class( - loop, addrs, config, params, - codecs_registry=codecs_registry, - query_cache=query_cache, - ) - await connection.ensure_connected() - return connection diff --git a/edgedb/asyncio_pool.py b/edgedb/asyncio_pool.py deleted file mode 100644 index 6397a247..00000000 --- a/edgedb/asyncio_pool.py +++ /dev/null @@ -1,685 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import asyncio -import logging -import typing - -from . import abstract -from . import asyncio_con -from . import compat -from . import errors -from . import options -from . import retry as _retry - - -__all__ = ( - 'create_async_client', 'AsyncIOClient' -) - - -logger = logging.getLogger(__name__) - - -class PoolConnection(asyncio_con.AsyncIOConnection): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._inner._holder = None - self._inner._detached = False - - async def _reconnect(self, single_attempt=False): - if self._inner._detached: - # initial connection - raise errors.InterfaceError( - "the underlying connection has been released back to the pool" - ) - return await super()._reconnect(single_attempt=single_attempt) - - def _detach(self): - new_conn = self._shallow_clone() - inner = self._inner - holder = inner._holder - inner._holder = None - inner._detached = True - new_conn._inner = self._inner._detach() - new_conn._inner._holder = holder - new_conn._inner._detached = False - return new_conn - - def _cleanup(self): - if self._inner._holder: - self._inner._holder._release_on_close() - super()._cleanup() - - def __repr__(self): - if self._inner._holder is None: - return '<{classname} [released] {id:#x}>'.format( - classname=self.__class__.__name__, id=id(self)) - else: - return super().__repr__() - - -class PoolConnectionHolder: - - __slots__ = ('_con', '_pool', - '_on_acquire', '_on_release', - '_in_use', '_timeout', '_generation') - - def __init__(self, pool, *, on_acquire, on_release): - - self._pool = pool - self._con = None - - self._on_acquire = on_acquire - self._on_release = on_release - self._in_use = None # type: asyncio.Future - self._timeout = None - self._generation = None - - async def connect(self): - if self._con is not None: - raise errors.InternalClientError( - 'PoolConnectionHolder.connect() called while another ' - 'connection already exists') - - self._con = await self._pool._get_new_connection() - assert self._con._inner._holder is None - self._con._inner._holder = self - self._generation = self._pool._generation - - async def acquire(self) -> PoolConnection: - if self._con is None or self._con.is_closed(): - self._con = None - await self.connect() - - elif self._generation != self._pool._generation: - # Connections have been expired, re-connect the holder. - self._pool._loop.create_task( - self._con.aclose(timeout=self._timeout)) - self._con = None - await self.connect() - - if self._on_acquire is not None: - try: - await self._on_acquire(self._con) - except (Exception, asyncio.CancelledError) as ex: - # If a user-defined `on_acquire` function fails, we don't - # know if the connection is safe for re-use, hence - # we close it. A new connection will be created - # when `acquire` is called again. - try: - # Use `close()` to close the connection gracefully. - # An exception in `on_acquire` isn't necessarily caused - # by an IO or a protocol error. close() will - # do the necessary cleanup via _release_on_close(). - await self._con.aclose() - finally: - raise ex - - self._in_use = self._pool._loop.create_future() - - return self._con - - async def release(self, timeout): - if self._in_use is None: - raise errors.InternalClientError( - 'PoolConnectionHolder.release() called on ' - 'a free connection holder') - - if self._con.is_closed(): - # This is usually the case when the connection is broken rather - # than closed by the user, so we need to call _release_on_close() - # here to release the holder back to the queue, because - # self._con._cleanup() was never called. On the other hand, it is - # safe to call self._release() twice - the second call is no-op. - self._release_on_close() - return - - self._timeout = None - - if self._generation != self._pool._generation: - # The connection has expired because it belongs to - # an older generation (AsyncIOPool.expire_connections() has - # been called.) - await self._con.aclose() - return - - if self._on_release is not None: - try: - await self._on_release(self._con) - except (Exception, asyncio.CancelledError) as ex: - # If a user-defined `on_release` function fails, we don't - # know if the connection is safe for re-use, hence - # we close it. A new connection will be created - # when `acquire` is called again. - try: - # Use `close()` to close the connection gracefully. - # An exception in `setup` isn't necessarily caused - # by an IO or a protocol error. close() will - # do the necessary cleanup via _release_on_close(). - await self._con.aclose() - finally: - raise ex - - # Free this connection holder and invalidate the - # connection proxy. - self._release() - - async def wait_until_released(self): - if self._in_use is None: - return - else: - await self._in_use - - async def aclose(self): - if self._con is not None: - # AsyncIOConnection.aclose() will call _release_on_close() to - # finish holder cleanup. - await self._con.aclose() - - def terminate(self): - if self._con is not None: - # AsyncIOConnection.terminate() will call _release_on_close() to - # finish holder cleanup. - self._con.terminate() - - def _release_on_close(self): - self._release() - self._con = None - - def _release(self): - """Release this connection holder.""" - if self._in_use is None: - # The holder is not checked out. - return - - if not self._in_use.done(): - self._in_use.set_result(None) - self._in_use = None - - self._con = self._con._detach() - - # Put ourselves back to the pool queue. - self._pool._queue.put_nowait(self) - - -class _AsyncIOPoolImpl: - __slots__ = ('_queue', '_loop', '_user_concurrency', '_concurrency', - '_on_connect', '_connect_args', '_connect_kwargs', - '_working_addr', '_working_config', '_working_params', - '_codecs_registry', '_query_cache', - '_holders', '_initialized', '_initializing', '_closing', - '_closed', '_connection_class', '_generation', - '_on_acquire', '_on_release') - - def __init__(self, *connect_args, - concurrency: typing.Optional[int], - on_acquire, - on_release, - on_connect, - connection_class, - **connect_kwargs): - super().__init__() - - loop = asyncio.get_event_loop() - self._loop = loop - - if concurrency is not None and concurrency <= 0: - raise ValueError('concurrency is expected to be greater than zero') - - if not issubclass(connection_class, PoolConnection): - raise TypeError( - f'connection_class is expected to be a subclass of ' - f'edgedb.asyncio_pool.PoolConnection, ' - f'got {connection_class}') - - self._user_concurrency = concurrency - self._concurrency = concurrency if concurrency else 1 - - self._on_acquire = on_acquire - self._on_release = on_release - - self._holders = [] - self._queue = asyncio.LifoQueue(maxsize=self._concurrency) - - self._working_addr = None - self._working_config = None - self._working_params = None - - self._connection_class = connection_class - - self._closing = False - self._closed = False - self._generation = 0 - self._on_connect = on_connect - self._connect_args = connect_args - self._connect_kwargs = connect_kwargs - - self._resize_holder_pool() - - def _resize_holder_pool(self): - resize_diff = self._concurrency - len(self._holders) - - if (resize_diff > 0): - if self._queue.maxsize != self._concurrency: - self._queue._maxsize = self._concurrency - - for _ in range(resize_diff): - ch = PoolConnectionHolder( - self, - on_acquire=self._on_acquire, - on_release=self._on_release) - - self._holders.append(ch) - self._queue.put_nowait(ch) - elif resize_diff < 0: - # TODO: shrink the pool - pass - - def set_connect_args(self, dsn=None, **connect_kwargs): - r"""Set the new connection arguments for this pool. - - The new connection arguments will be used for all subsequent - new connection attempts. Existing connections will remain until - they expire. Use AsyncIOPool.expire_connections() to expedite - the connection expiry. - - :param str dsn: - Connection arguments specified using as a single string in - the following format: - ``edgedb://user:pass@host:port/database?option=value``. - - :param \*\*connect_kwargs: - Keyword arguments for the :func:`~edgedb.asyncio_con.connect` - function. - """ - - self._connect_args = [dsn] - self._connect_kwargs = connect_kwargs - self._working_addr = None - self._working_config = None - self._working_params = None - self._codecs_registry = None - self._query_cache = None - - async def _get_new_connection(self): - if self._working_addr is None: - # First connection attempt on this pool. - con = await asyncio_con.async_connect_raw( - *self._connect_args, - connection_class=self._connection_class, - **self._connect_kwargs) - - self._working_addr = con.connected_addr() - self._working_config = con._inner._config - self._working_params = con._inner._params - self._codecs_registry = con._inner._codecs_registry - self._query_cache = con._inner._query_cache - - if self._user_concurrency is None: - suggested_concurrency = con.get_settings().get( - 'suggested_pool_concurrency') - if suggested_concurrency: - self._concurrency = suggested_concurrency - self._resize_holder_pool() - - else: - # We've connected before and have a resolved address, - # and parsed options and config. - con = await asyncio_con._connect_addr( - loop=self._loop, - addrs=[self._working_addr], - config=self._working_config, - params=self._working_params, - query_cache=self._query_cache, - codecs_registry=self._codecs_registry, - connection_class=self._connection_class) - - if self._on_connect is not None: - try: - await self._on_connect(con) - except (Exception, asyncio.CancelledError) as ex: - # If a user-defined `connect` function fails, we don't - # know if the connection is safe for re-use, hence - # we close it. A new connection will be created - # when `acquire` is called again. - try: - # Use `close()` to close the connection gracefully. - # An exception in `init` isn't necessarily caused - # by an IO or a protocol error. close() will - # do the necessary cleanup via _release_on_close(). - await con.aclose() - finally: - raise ex - - return con - - async def _acquire(self, timeout, options): - async def _acquire_impl(): - ch = await self._queue.get() # type: PoolConnectionHolder - try: - proxy = await ch.acquire() # type: PoolConnection - except (Exception, asyncio.CancelledError): - self._queue.put_nowait(ch) - raise - else: - # Record the timeout, as we will apply it by default - # in release(). - ch._timeout = timeout - proxy._options = options - return proxy - - if self._closing: - raise errors.InterfaceError('pool is closing') - - if timeout is None: - return await _acquire_impl() - else: - return await compat.wait_for( - _acquire_impl(), timeout=timeout) - - async def release(self, connection): - - if not isinstance(connection, PoolConnection): - raise errors.InterfaceError( - f'AsyncIOPool.release() received invalid connection: ' - f'{connection!r} does not belong to any connection pool' - ) - - ch = connection._inner._holder - if ch is None: - # Already released, do nothing. - return - - if ch._pool is not self: - raise errors.InterfaceError( - f'AsyncIOPool.release() received invalid connection: ' - f'{connection!r} is not a member of this pool' - ) - - timeout = None - - # Use asyncio.shield() to guarantee that task cancellation - # does not prevent the connection from being returned to the - # pool properly. - return await asyncio.shield(ch.release(timeout)) - - async def aclose(self): - """Attempt to gracefully close all connections in the pool. - - Wait until all pool connections are released, close them and - shut down the pool. If any error (including cancellation) occurs - in ``close()`` the pool will terminate by calling - AsyncIOPool.terminate() . - - It is advisable to use :func:`python:asyncio.wait_for` to set - a timeout. - """ - if self._closed: - return - - self._closing = True - - try: - warning_callback = self._loop.call_later( - 60, self._warn_on_long_close) - - release_coros = [ - ch.wait_until_released() for ch in self._holders] - await asyncio.gather(*release_coros) - - close_coros = [ - ch.aclose() for ch in self._holders] - await asyncio.gather(*close_coros) - - except (Exception, asyncio.CancelledError): - self.terminate() - raise - - finally: - warning_callback.cancel() - self._closed = True - self._closing = False - - def _warn_on_long_close(self): - logger.warning( - 'AsyncIOPool.aclose() is taking over 60 seconds to complete. ' - 'Check if you have any unreleased connections left. ' - 'Use asyncio.wait_for() to set a timeout for ' - 'AsyncIOPool.aclose().') - - def terminate(self): - """Terminate all connections in the pool.""" - if self._closed: - return - for ch in self._holders: - ch.terminate() - self._closed = True - - async def expire_connections(self): - """Expire all currently open connections. - - Cause all currently open connections to get replaced on the - next AsyncIOPool.acquire() call. - """ - self._generation += 1 - - def _drop_statement_cache(self): - # Drop statement cache for all connections in the pool. - for ch in self._holders: - if ch._con is not None: - ch._con._drop_local_statement_cache() - - def _drop_type_cache(self): - # Drop type codec cache for all connections in the pool. - for ch in self._holders: - if ch._con is not None: - ch._con._drop_local_type_cache() - - -class AsyncIOClient(abstract.AsyncIOExecutor, options._OptionsMixin): - """A lazy connection pool. - - A Client can be used to manage a set of connections to the database. - Connections are first acquired from the pool, then used, and then released - back to the pool. Once a connection is released, it's reset to close all - open cursors and other resources *except* prepared statements. - - Clients are created by calling - :func:`~edgedb.asyncio_pool.create_async_client`. - """ - - __slots__ = ('_impl', '_options') - - def __init__(self, *connect_args, - concurrency: int, - on_acquire, - on_release, - on_connect, - connection_class, - **connect_kwargs): - super().__init__() - self._impl = _AsyncIOPoolImpl( - *connect_args, - concurrency=concurrency, - on_acquire=on_acquire, - on_release=on_release, - on_connect=on_connect, - connection_class=connection_class, - **connect_kwargs, - ) - - @property - def concurrency(self) -> int: - """Max number of connections in the pool.""" - - return self._impl._concurrency - - async def ensure_connected(self): - for ch in self._impl._holders: - if ch._con is not None and ch._con.is_closed(): - return self - - ch = self._impl._holders[0] - ch._con = None - await ch.connect() - - return self - - async def query(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query(query, *args, **kwargs) - - async def query_single(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query_single(query, *args, **kwargs) - - async def query_required_single(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query_required_single(query, *args, **kwargs) - - async def query_json(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query_json(query, *args, **kwargs) - - async def query_single_json(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query_single_json(query, *args, **kwargs) - - async def query_required_single_json(self, query, *args, **kwargs): - async with self._acquire() as con: - return await con.query_required_single_json(query, *args, **kwargs) - - async def execute(self, query): - async with self._acquire() as con: - return await con.execute(query) - - def _acquire(self): - return PoolAcquireContext(self, timeout=None, options=self._options) - - async def _release(self, connection): - await self._impl.release(connection) - - async def aclose(self): - """Attempt to gracefully close all connections in the pool. - - Wait until all pool connections are released, close them and - shut down the pool. If any error (including cancellation) occurs - in ``close()`` the pool will terminate by calling - AsyncIOPool.terminate() . - - It is advisable to use :func:`python:asyncio.wait_for` to set - a timeout. - """ - await self._impl.aclose() - - def terminate(self): - """Terminate all connections in the pool.""" - self._impl.terminate() - - def transaction(self) -> _retry.AsyncIORetry: - return _retry.AsyncIORetry(self) - - def _shallow_clone(self): - new_pool = self.__class__.__new__(self.__class__) - new_pool._options = self._options - new_pool._impl = self._impl - return new_pool - - -class PoolAcquireContext: - - __slots__ = ('timeout', 'connection', 'done', 'pool') - - def __init__(self, pool, timeout, options): - self.pool = pool - self.timeout = timeout - self.connection = None - self.done = False - - async def __aenter__(self): - if self.connection is not None or self.done: - raise errors.InterfaceError('a connection is already acquired') - self.connection = await self.pool._impl._acquire( - self.timeout, - self.pool._options, - ) - return self.connection - - async def __aexit__(self, *exc): - self.done = True - con = self.connection - self.connection = None - await self.pool._release(con) - - def __await__(self): - self.done = True - return self.pool._impl._acquire( - self.timeout, - self.pool._options, - ).__await__() - - -def create_async_client( - dsn=None, - *, - concurrency=None, - **connect_kwargs -): - r"""Create an AsyncIOClient with a lazy connection pool. - - .. code-block:: python - - client = edgedb.create_async_client(user='edgedb') - con = await client.acquire() - try: - await con.fetchall('SELECT {1, 2, 3}') - finally: - await pool.release(con) - - :param str dsn: - If this parameter does not start with ``edgedb://`` then this is - a :ref:`name of an instance `. - - Otherwies it specifies as a single string in the following format: - ``edgedb://user:pass@host:port/database?option=value``. - - :param \*\*connect_kwargs: - Keyword arguments for the async_connect() function. - - :param Connection connection_class: - The class to use for connections. Must be a subclass of - :class:`~edgedb.asyncio_con.AsyncIOConnection`. - - :param int concurrency: - Max number of connections in the pool. If not set, the suggested - concurrency value sent by the server will be used. - - :return: An instance of :class:`~edgedb.AsyncIOPool`. - """ - return AsyncIOClient( - dsn, - connection_class=PoolConnection, - concurrency=concurrency, - on_acquire=None, - on_release=None, - on_connect=None, - **connect_kwargs - ) diff --git a/edgedb/base_client.py b/edgedb/base_client.py new file mode 100644 index 00000000..eb097728 --- /dev/null +++ b/edgedb/base_client.py @@ -0,0 +1,697 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import abc +import random +import time +import typing + +from . import abstract +from . import con_utils +from . import enums +from . import errors +from . import options as _options +from .protocol import protocol + + +BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') + + +class BaseConnection(metaclass=abc.ABCMeta): + _protocol: typing.Any + _addr: typing.Optional[typing.Union[str, typing.Tuple[str, int]]] + _addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]] + _config: con_utils.ClientConfiguration + _params: con_utils.ResolvedConnectConfig + _log_listeners: typing.Set[ + typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None] + ] + _close_exceptions = (Exception,) + __slots__ = ( + "__weakref__", + "_protocol", + "_addr", + "_addrs", + "_config", + "_params", + "_log_listeners", + "_holder", + ) + + def __init__( + self, + addrs: typing.Iterable[typing.Union[str, typing.Tuple[str, int]]], + config: con_utils.ClientConfiguration, + params: con_utils.ResolvedConnectConfig, + ): + self._addr = None + self._protocol = None + self._addrs = addrs + self._config = config + self._params = params + self._log_listeners = set() + self._holder = None + + @abc.abstractmethod + def _dispatch_log_message(self, msg): + ... + + def _on_log_message(self, msg): + if self._log_listeners: + self._dispatch_log_message(msg) + + def connected_addr(self): + return self._addr + + def _get_last_status(self) -> typing.Optional[str]: + if self._protocol is None: + return None + status = self._protocol.last_status + if status is not None: + status = status.decode() + return status + + def _cleanup(self): + self._log_listeners.clear() + if self._holder: + self._holder._release_on_close() + self._holder = None + + def add_log_listener( + self: BaseConnection_T, + callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], + None] + ) -> None: + """Add a listener for EdgeDB log messages. + + :param callable callback: + A callable receiving the following arguments: + **connection**: a Connection the callback is registered with; + **message**: the `edgedb.EdgeDBMessage` message. + """ + self._log_listeners.add(callback) + + def remove_log_listener( + self: BaseConnection_T, + callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], + None] + ) -> None: + """Remove a listening callback for log messages.""" + self._log_listeners.discard(callback) + + @property + def dbname(self) -> str: + return self._params.database + + @abc.abstractmethod + def is_closed(self) -> bool: + ... + + @abc.abstractmethod + async def connect_addr(self, addr, timeout): + ... + + @abc.abstractmethod + async def sleep(self, seconds): + ... + + async def connect(self, *, single_attempt=False): + start = time.monotonic() + if single_attempt: + max_time = 0 + else: + max_time = start + self._config.wait_until_available + iteration = 1 + + while True: + for addr in self._addrs: + try: + await self.connect_addr(addr, self._config.connect_timeout) + except TimeoutError as e: + if iteration == 1 or time.monotonic() < max_time: + continue + else: + raise errors.ClientConnectionTimeoutError( + f"connecting to {addr} failed in" + f" {self._config.connect_timeout} sec" + ) from e + except errors.ClientConnectionError as e: + if ( + e.has_tag(errors.SHOULD_RECONNECT) and + (iteration == 1 or time.monotonic() < max_time) + ): + continue + nice_err = e.__class__( + con_utils.render_client_no_connection_error( + e, + addr, + attempts=iteration, + duration=time.monotonic() - start, + )) + raise nice_err from e.__cause__ + else: + return + + iteration += 1 + await self.sleep(0.01 + random.random() * 0.2) + + async def privileged_execute(self, query): + await self._protocol.simple_query(query, enums.Capability.ALL) + + def is_in_transaction(self) -> bool: + """Return True if Connection is currently inside a transaction. + + :return bool: True if inside transaction, False otherwise. + """ + return self._protocol.is_in_transaction() + + def get_settings(self) -> typing.Dict[str, typing.Any]: + return self._protocol.get_settings() + + async def raw_query(self, query_context: abstract.QueryContext): + if self.is_closed(): + await self.connect() + + reconnect = False + capabilities = None + i = 0 + while True: + i += 1 + try: + if reconnect: + await self.connect(single_attempt=True) + return await self._protocol.execute_anonymous( + query=query_context.query.query, + args=query_context.query.args, + kwargs=query_context.query.kwargs, + reg=query_context.cache.codecs_registry, + qc=query_context.cache.query_cache, + io_format=query_context.query_options.io_format, + expect_one=query_context.query_options.expect_one, + required_one=query_context.query_options.required_one, + allow_capabilities=enums.Capability.EXECUTE, + ) + except errors.EdgeDBError as e: + if query_context.retry_options is None: + raise + if not e.has_tag(errors.SHOULD_RETRY): + raise e + if capabilities is None: + cache_item = query_context.cache.query_cache.get( + query=query_context.query.query, + io_format=query_context.query_options.io_format, + implicit_limit=0, + inline_typenames=False, + inline_typeids=False, + expect_one=query_context.query_options.expect_one, + ) + if cache_item is not None: + _, _, _, capabilities = cache_item + # A query is read-only if it has no capabilities i.e. + # capabilities == 0. Read-only queries are safe to retry. + # Explicit transaction conflicts as well. + if ( + capabilities != 0 + and not isinstance(e, errors.TransactionConflictError) + ): + raise e + rule = query_context.retry_options.get_rule_for_exception(e) + if i >= rule.attempts: + raise e + await self.sleep(rule.backoff(i)) + reconnect = self.is_closed() + + async def execute(self, query: str) -> None: + await self._protocol.simple_query( + query, enums.Capability.EXECUTE + ) + + def terminate(self): + if not self.is_closed(): + try: + self._protocol.abort() + finally: + self._cleanup() + + async def close(self): + """Send graceful termination message wait for connection to drop.""" + if not self.is_closed(): + try: + self._protocol.terminate() + await self._protocol.wait_for_disconnect() + except self._close_exceptions: + self.terminate() + raise + finally: + self._cleanup() + + def __repr__(self): + if self.is_closed(): + return '<{classname} [closed] {id:#x}>'.format( + classname=self.__class__.__name__, id=id(self)) + else: + return '<{classname} [connected to {addr}] {id:#x}>'.format( + classname=self.__class__.__name__, + addr=self.connected_addr(), + id=id(self)) + + +class PoolConnectionHolder(abc.ABC): + __slots__ = ( + "_con", + "_pool", + "_release_event", + "_timeout", + "_generation", + ) + _event_class = NotImplemented + + def __init__(self, pool): + + self._pool = pool + self._con = None + + self._timeout = None + self._generation = None + + self._release_event = self._event_class() + self._release_event.set() + + @abc.abstractmethod + async def close(self, *, wait=True): + ... + + @abc.abstractmethod + async def wait_until_released(self, timeout=None): + ... + + async def connect(self): + if self._con is not None: + raise errors.InternalClientError( + 'PoolConnectionHolder.connect() called while another ' + 'connection already exists') + + self._con = await self._pool._get_new_connection() + assert self._con._holder is None + self._con._holder = self + self._generation = self._pool._generation + + async def acquire(self) -> BaseConnection: + if self._con is None or self._con.is_closed(): + self._con = None + await self.connect() + + elif self._generation != self._pool._generation: + # Connections have been expired, re-connect the holder. + self._con._holder = None # don't release the connection + await self.close(wait=False) + self._con = None + await self.connect() + + self._release_event.clear() + + return self._con + + async def release(self, timeout): + if self._release_event.is_set(): + raise errors.InternalClientError( + 'PoolConnectionHolder.release() called on ' + 'a free connection holder') + + if self._con.is_closed(): + # This is usually the case when the connection is broken rather + # than closed by the user, so we need to call _release_on_close() + # here to release the holder back to the queue, because + # self._con._cleanup() was never called. On the other hand, it is + # safe to call self._release() twice - the second call is no-op. + self._release_on_close() + return + + self._timeout = None + + if self._generation != self._pool._generation: + # The connection has expired because it belongs to + # an older generation (BasePoolImpl.expire_connections() has + # been called.) + await self.close() + return + + # Free this connection holder and invalidate the + # connection proxy. + self._release() + + def terminate(self): + if self._con is not None: + # AsyncIOConnection.terminate() will call _release_on_close() to + # finish holder cleanup. + self._con.terminate() + + def _release_on_close(self): + self._release() + self._con = None + + def _release(self): + """Release this connection holder.""" + if self._release_event.is_set(): + # The holder is not checked out. + return + + self._release_event.set() + + # Put ourselves back to the pool queue. + self._pool._queue.put_nowait(self) + + +class BasePoolImpl(abc.ABC): + __slots__ = ( + "_connect_args", + "_codecs_registry", + "_query_cache", + "_connection_factory", + "_queue", + "_user_max_concurrency", + "_max_concurrency", + "_first_connect_lock", + "_working_addr", + "_working_config", + "_working_params", + "_holders", + "_initialized", + "_initializing", + "_closing", + "_closed", + "_generation", + ) + + _holder_class = NotImplemented + + def __init__( + self, + connect_args, + connection_factory, + *, + max_concurrency: typing.Optional[int], + ): + self._connection_factory = connection_factory + self._connect_args = connect_args + self._codecs_registry = protocol.CodecsRegistry() + self._query_cache = protocol.QueryCodecsCache() + + if max_concurrency is not None and max_concurrency <= 0: + raise ValueError( + 'max_concurrency is expected to be greater than zero' + ) + + self._user_max_concurrency = max_concurrency + self._max_concurrency = max_concurrency if max_concurrency else 1 + + self._holders = [] + self._queue = None + + self._first_connect_lock = None + self._working_addr = None + self._working_config = None + self._working_params = None + + self._closing = False + self._closed = False + self._generation = 0 + + @abc.abstractmethod + def _ensure_initialized(self): + ... + + @abc.abstractmethod + def _set_queue_maxsize(self, maxsize): + ... + + @abc.abstractmethod + async def _maybe_get_first_connection(self): + ... + + @abc.abstractmethod + async def acquire(self, timeout=None): + ... + + @abc.abstractmethod + async def _release(self, connection): + ... + + @property + def codecs_registry(self): + return self._codecs_registry + + @property + def query_cache(self): + return self._query_cache + + def _resize_holder_pool(self): + resize_diff = self._max_concurrency - len(self._holders) + + if (resize_diff > 0): + if self._queue.maxsize != self._max_concurrency: + self._set_queue_maxsize(self._max_concurrency) + + for _ in range(resize_diff): + ch = self._holder_class(self) + + self._holders.append(ch) + self._queue.put_nowait(ch) + elif resize_diff < 0: + # TODO: shrink the pool + pass + + def get_max_concurrency(self): + return self._max_concurrency + + def get_free_size(self): + if self._queue is None: + # Queue has not been initialized yet + return self._max_concurrency + + return self._queue.qsize() + + def set_connect_args(self, dsn=None, **connect_kwargs): + r"""Set the new connection arguments for this pool. + + The new connection arguments will be used for all subsequent + new connection attempts. Existing connections will remain until + they expire. Use BasePoolImpl.expire_connections() to expedite + the connection expiry. + + :param str dsn: + Connection arguments specified using as a single string in + the following format: + ``edgedb://user:pass@host:port/database?option=value``. + + :param \*\*connect_kwargs: + Keyword arguments for the + :func:`~edgedb.asyncio_client.create_async_client` function. + """ + + connect_kwargs["dsn"] = dsn + self._connect_args = connect_kwargs + self._codecs_registry = protocol.CodecsRegistry() + self._query_cache = protocol.QueryCodecsCache() + self._working_addr = None + self._working_config = None + self._working_params = None + + async def _get_first_connection(self): + # First connection attempt on this pool. + connect_config, client_config = con_utils.parse_connect_arguments( + **self._connect_args, + # ToDos + command_timeout=None, + server_settings=None, + ) + con = self._connection_factory( + [connect_config.address], client_config, connect_config + ) + await con.connect() + self._working_addr = con.connected_addr() + self._working_config = client_config + self._working_params = connect_config + + if self._user_max_concurrency is None: + suggested_concurrency = con.get_settings().get( + 'suggested_pool_concurrency') + if suggested_concurrency: + self._max_concurrency = suggested_concurrency + self._resize_holder_pool() + return con + + async def _get_new_connection(self): + con = None + if self._working_addr is None: + con = await self._maybe_get_first_connection() + if con is None: + assert self._working_addr is not None + # We've connected before and have a resolved address, + # and parsed options and config. + con = self._connection_factory( + [self._working_addr], + self._working_config, + self._working_params, + ) + await con.connect() + + return con + + async def release(self, connection): + + if not isinstance(connection, BaseConnection): + raise errors.InterfaceError( + f'BasePoolImpl.release() received invalid connection: ' + f'{connection!r} does not belong to any connection pool' + ) + + ch = connection._holder + if ch is None: + # Already released, do nothing. + return + + if ch._pool is not self: + raise errors.InterfaceError( + f'BasePoolImpl.release() received invalid connection: ' + f'{connection!r} is not a member of this pool' + ) + + return await self._release(ch) + + def terminate(self): + """Terminate all connections in the pool.""" + if self._closed: + return + for ch in self._holders: + ch.terminate() + self._closed = True + + def expire_connections(self): + """Expire all currently open connections. + + Cause all currently open connections to get replaced on the + next query. + """ + self._generation += 1 + + async def ensure_connected(self): + self._ensure_initialized() + + for ch in self._holders: + if ch._con is not None and not ch._con.is_closed(): + return + + ch = self._holders[0] + ch._con = None + await ch.connect() + + +class BaseClient(abstract.BaseReadOnlyExecutor, _options._OptionsMixin): + __slots__ = ("_impl", "_options") + _impl_class = NotImplemented + + def __init__( + self, + *, + connection_class, + max_concurrency: typing.Optional[int], + dsn=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + database: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + wait_until_available: int = 30, + timeout: int = 10, + **kwargs, + ): + super().__init__() + connect_args = { + "dsn": dsn, + "host": host, + "port": port, + "credentials": credentials, + "credentials_file": credentials_file, + "user": user, + "password": password, + "database": database, + "timeout": timeout, + "tls_ca": tls_ca, + "tls_ca_file": tls_ca_file, + "tls_security": tls_security, + "wait_until_available": wait_until_available, + } + + self._impl = self._impl_class( + connect_args, + connection_class=connection_class, + max_concurrency=max_concurrency, + **kwargs, + ) + + def _shallow_clone(self): + new_client = self.__class__.__new__(self.__class__) + new_client._impl = self._impl + return new_client + + def _get_query_cache(self) -> abstract.QueryCache: + return abstract.QueryCache( + codecs_registry=self._impl.codecs_registry, + query_cache=self._impl.query_cache, + ) + + def _get_retry_options(self) -> typing.Optional[_options.RetryOptions]: + return self._options.retry_options + + @property + def max_concurrency(self) -> int: + """Max number of connections in the pool.""" + + return self._impl.get_max_concurrency() + + @property + def free_size(self) -> int: + """Number of available connections in the pool.""" + + return self._impl.get_free_size() + + async def _query(self, query_context: abstract.QueryContext): + con = await self._impl.acquire() + try: + result, _ = await con.raw_query(query_context) + return result + finally: + await self._impl.release(con) + + async def execute(self, query: str) -> None: + con = await self._impl.acquire() + try: + await con.execute(query) + finally: + await self._impl.release(con) + + def terminate(self): + """Terminate all connections in the pool.""" + self._impl.terminate() diff --git a/edgedb/base_con.py b/edgedb/base_con.py deleted file mode 100644 index 4702d941..00000000 --- a/edgedb/base_con.py +++ /dev/null @@ -1,164 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import itertools -import typing -import uuid - -from . import errors - -from .protocol.protocol import CodecsRegistry as _CodecsRegistry -from .protocol.protocol import QueryCodecsCache as _QueryCodecsCache - - -BaseConnection_T = typing.TypeVar('BaseConnection_T', bound='BaseConnection') - - -class BorrowReason: - TRANSACTION = 'transaction' - - -BORROW_ERRORS = { - BorrowReason.TRANSACTION: - "Connection object is borrowed for a transaction. " - "Use the methods on transaction object instead.", -} - - -def borrow_error(condition): - raise errors.InterfaceError(BORROW_ERRORS[condition]) - - -class _InnerConnection: - - def __init__(self, addrs, config, params, *, - codecs_registry=None, query_cache=None): - super().__init__() - self._log_listeners = set() - - self._addrs = addrs - self._config = config - self._params = params - - if codecs_registry is not None: - self._codecs_registry = codecs_registry - else: - self._codecs_registry = _CodecsRegistry() - - if query_cache is not None: - self._query_cache = query_cache - else: - self._query_cache = _QueryCodecsCache() - - self._top_xact = None - self._borrowed_for = None - self._impl = None - - def _dispatch_log_message(self, msg): - for cb in self._log_listeners: - cb(self, msg) - - def _on_log_message(self, msg): - if self._log_listeners: - self._dispatch_log_message(msg) - - def _get_unique_id(self, prefix): - return f'_edgedb_{prefix}_{_uid_counter():x}_' - - -class BaseConnection: - _inner: _InnerConnection - - def connected_addr(self): - return self._inner._impl._addr - - def _clear_codecs_cache(self): - self._inner._codecs_registry.clear_cache() - - def _set_type_codec( - self, - typeid: uuid.UUID, - *, - encoder: typing.Callable[[typing.Any], typing.Any], - decoder: typing.Callable[[typing.Any], typing.Any], - format: str - ): - self._inner._codecs_registry.set_type_codec( - typeid, - encoder=encoder, - decoder=decoder, - format=format, - ) - - def _get_last_status(self) -> typing.Optional[str]: - impl = self._inner._impl - if impl is None: - return None - if impl._protocol is None: - return None - status = impl._protocol.last_status - if status is not None: - status = status.decode() - return status - - def _cleanup(self): - self._inner._log_listeners.clear() - - def add_log_listener( - self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] - ) -> None: - """Add a listener for EdgeDB log messages. - - :param callable callback: - A callable receiving the following arguments: - **connection**: a Connection the callback is registered with; - **message**: the `edgedb.EdgeDBMessage` message. - """ - self._inner._log_listeners.add(callback) - - def remove_log_listener( - self: BaseConnection_T, - callback: typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], - None] - ) -> None: - """Remove a listening callback for log messages.""" - self._inner._log_listeners.discard(callback) - - @property - def dbname(self) -> str: - return self._inner._params.database - - def is_closed(self) -> bool: - raise NotImplementedError - - def is_in_transaction(self) -> bool: - """Return True if Connection is currently inside a transaction. - - :return bool: True if inside transaction, False otherwise. - """ - return self._inner._impl._protocol.is_in_transaction() - - def get_settings(self) -> typing.Dict[str, typing.Any]: - return self._inner._impl._protocol.get_settings() - - -# Thread-safe "+= 1" counter. -_uid_counter = itertools.count(1).__next__ diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py new file mode 100644 index 00000000..95848f33 --- /dev/null +++ b/edgedb/blocking_client.py @@ -0,0 +1,380 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2022-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import queue +import socket +import ssl +import threading +import time +import typing + +from . import abstract +from . import base_client +from . import con_utils +from . import errors +from . import transaction +from .protocol import blocking_proto + + +class BlockingIOConnection(base_client.BaseConnection): + __slots__ = () + + async def connect_addr(self, addr, timeout): + deadline = time.monotonic() + timeout + tls_compat = False + + if isinstance(addr, str): + # UNIX socket + sock = socket.socket(socket.AF_UNIX) + else: + sock = socket.socket(socket.AF_INET) + + try: + sock.settimeout(timeout) + + try: + sock.connect(addr) + + if not isinstance(addr, str): + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + + # Upgrade to TLS + if self._params.ssl_ctx.check_hostname: + server_hostname = addr[0] + else: + server_hostname = None + sock.settimeout(time_left) + try: + sock = self._params.ssl_ctx.wrap_socket( + sock, server_hostname=server_hostname + ) + except ssl.CertificateError as e: + raise con_utils.wrap_error(e) from e + except ssl.SSLError as e: + if e.reason == 'CERTIFICATE_VERIFY_FAILED': + raise con_utils.wrap_error(e) from e + + # Retry in plain text + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + sock.close() + sock = socket.socket(socket.AF_INET) + sock.settimeout(time_left) + sock.connect(addr) + tls_compat = True + else: + con_utils.check_alpn_protocol(sock) + except socket.gaierror as e: + # All name resolution errors are considered temporary + err = errors.ClientConnectionFailedTemporarilyError(str(e)) + raise err from e + except OSError as e: + raise con_utils.wrap_error(e) from e + + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + + if not isinstance(addr, str): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + proto = blocking_proto.BlockingIOProtocol( + self._params, sock, tls_compat + ) + proto.set_connection(self) + + try: + sock.settimeout(time_left) + await proto.connect() + sock.settimeout(None) + except OSError as e: + raise con_utils.wrap_error(e) from e + + self._protocol = proto + self._addr = addr + + except Exception: + sock.close() + raise + + async def sleep(self, seconds): + time.sleep(seconds) + + def is_closed(self): + proto = self._protocol + return not (proto and proto.sock is not None and + proto.sock.fileno() >= 0 and proto.connected) + + def _dispatch_log_message(self, msg): + for cb in self._log_listeners: + cb(self, msg) + + +class _PoolConnectionHolder(base_client.PoolConnectionHolder): + __slots__ = () + _event_class = threading.Event + + async def close(self, *, wait=True): + if self._con is None: + return + await self._con.close() + + async def wait_until_released(self, timeout=None): + self._release_event.wait(timeout) + + +class _PoolImpl(base_client.BasePoolImpl): + _holder_class = _PoolConnectionHolder + + def __init__( + self, + connect_args, + *, + max_concurrency: typing.Optional[int], + connection_class, + ): + if not issubclass(connection_class, BlockingIOConnection): + raise TypeError( + f'connection_class is expected to be a subclass of ' + f'edgedb.blocking_client.BlockingIOConnection, ' + f'got {connection_class}') + super().__init__( + connect_args, + connection_class, + max_concurrency=max_concurrency, + ) + + def _ensure_initialized(self): + if self._queue is None: + self._queue = queue.LifoQueue(maxsize=self._max_concurrency) + self._first_connect_lock = threading.Lock() + self._resize_holder_pool() + + def _set_queue_maxsize(self, maxsize): + with self._queue.mutex: + self._queue.maxsize = maxsize + + async def _maybe_get_first_connection(self): + with self._first_connect_lock: + if self._working_addr is None: + return await self._get_first_connection() + + async def acquire(self, timeout=None): + self._ensure_initialized() + + if self._closing: + raise errors.InterfaceError('pool is closing') + + ch = self._queue.get(timeout=timeout) + try: + con = await ch.acquire() + except Exception: + self._queue.put_nowait(ch) + raise + else: + # Record the timeout, as we will apply it by default + # in release(). + ch._timeout = timeout + return con + + async def _release(self, holder): + if not isinstance(holder._con, BlockingIOConnection): + raise errors.InterfaceError( + f'release() received invalid connection: ' + f'{holder._con!r} does not belong to any connection pool' + ) + + timeout = None + return await holder.release(timeout) + + async def close(self, timeout=None): + if self._closed: + return + self._closing = True + try: + if timeout is None: + for ch in self._holders: + await ch.wait_until_released() + else: + remaining = timeout + for ch in self._holders: + start = time.monotonic() + await ch.wait_until_released(remaining) + remaining -= time.monotonic() - start + if remaining <= 0: + self.terminate() + return + for ch in self._holders: + await ch.close() + except Exception: + self.terminate() + raise + finally: + self._closed = True + self._closing = False + + +class Iteration(transaction.BaseTransaction, abstract.Executor): + + __slots__ = ("_managed",) + + def __init__(self, retry, client, iteration): + super().__init__(retry, client, iteration) + self._managed = False + + def __enter__(self): + if self._managed: + raise errors.InterfaceError( + 'cannot enter context: already in a `with` block') + self._managed = True + return self + + def __exit__(self, extype, ex, tb): + self._managed = False + return self._client._iter_coroutine(self._exit(extype, ex)) + + async def _ensure_transaction(self): + if not self._managed: + raise errors.InterfaceError( + "Only managed retriable transactions are supported. " + "Use `with transaction:`" + ) + await super()._ensure_transaction() + + def _query(self, query_context: abstract.QueryContext): + return self._client._iter_coroutine(super()._query(query_context)) + + def execute(self, query: str) -> None: + self._client._iter_coroutine(super().execute(query)) + + +class Retry(transaction.BaseRetry): + + def __iter__(self): + return self + + def __next__(self): + # Note: when changing this code consider also + # updating AsyncIORetry.__anext__. + if self._done: + raise StopIteration + if self._next_backoff: + time.sleep(self._next_backoff) + self._done = True + iteration = Iteration(self, self._owner, self._iteration) + self._iteration += 1 + return iteration + + +class Client(base_client.BaseClient, abstract.Executor): + """A lazy connection pool. + + A Client can be used to manage a set of connections to the database. + Connections are first acquired from the pool, then used, and then released + back to the pool. Once a connection is released, it's reset to close all + open cursors and other resources *except* prepared statements. + + Clients are created by calling + :func:`~edgedb.blocking_client.create_client`. + """ + + __slots__ = () + _impl_class = _PoolImpl + + def _iter_coroutine(self, coro): + try: + coro.send(None) + except StopIteration as ex: + if ex.args: + result = ex.args[0] + else: + result = None + finally: + coro.close() + return result + + def _query(self, query_context: abstract.QueryContext): + return self._iter_coroutine(super()._query(query_context)) + + def execute(self, query: str) -> None: + self._iter_coroutine(super().execute(query)) + + def ensure_connected(self): + self._iter_coroutine(self._impl.ensure_connected()) + return self + + def transaction(self) -> Retry: + return Retry(self) + + def close(self, timeout=None): + """Attempt to gracefully close all connections in the client. + + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + Client.terminate() . + """ + self._iter_coroutine(self._impl.close(timeout)) + + def __enter__(self): + return self.ensure_connected() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +def create_client( + dsn=None, + *, + max_concurrency=None, + host: str = None, + port: int = None, + credentials: str = None, + credentials_file: str = None, + user: str = None, + password: str = None, + database: str = None, + tls_ca: str = None, + tls_ca_file: str = None, + tls_security: str = None, + wait_until_available: int = 30, + timeout: int = 10, +): + return Client( + connection_class=BlockingIOConnection, + max_concurrency=max_concurrency, + + # connect arguments + dsn=dsn, + host=host, + port=port, + credentials=credentials, + credentials_file=credentials_file, + user=user, + password=password, + database=database, + tls_ca=tls_ca, + tls_ca_file=tls_ca_file, + tls_security=tls_security, + wait_until_available=wait_until_available, + timeout=timeout, + ) diff --git a/edgedb/blocking_con.py b/edgedb/blocking_con.py deleted file mode 100644 index 356aee2e..00000000 --- a/edgedb/blocking_con.py +++ /dev/null @@ -1,486 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import random -import socket -import ssl -import time -import typing - -from . import abstract -from . import base_con -from . import con_utils -from . import enums -from . import errors -from . import options -from . import retry as _retry - -from .datatypes import datatypes -from .protocol import blocking_proto, protocol -from .protocol.protocol import CodecsRegistry as _CodecsRegistry -from .protocol.protocol import QueryCodecsCache as _QueryCodecsCache - - -class _BlockingIOConnectionImpl: - - def __init__(self, codecs_registry, query_cache): - self._addr = None - self._protocol = None - self._codecs_registry = codecs_registry - self._query_cache = query_cache - - def connect(self, addrs, config, params, *, - single_attempt=False, connection): - addr = None - start = time.monotonic() - if single_attempt: - max_time = 0 - else: - max_time = start + config.wait_until_available - iteration = 1 - - while True: - for addr in addrs: - try: - self._connect_addr(addr, config, params, connection) - except TimeoutError as e: - if iteration == 1 or time.monotonic() < max_time: - continue - else: - raise errors.ClientConnectionTimeoutError( - f"connecting to {addr} failed in" - f" {config.connect_timeout} sec" - ) from e - except errors.ClientConnectionError as e: - if ( - e.has_tag(errors.SHOULD_RECONNECT) and - (iteration == 1 or time.monotonic() < max_time) - ): - continue - nice_err = e.__class__( - con_utils.render_client_no_connection_error( - e, - addr, - attempts=iteration, - duration=time.monotonic() - start, - )) - raise nice_err from e.__cause__ - else: - assert self._protocol - return - - iteration += 1 - time.sleep(0.01 + random.random() * 0.2) - - def _connect_addr(self, addr, config, params, connection): - timeout = config.connect_timeout - deadline = time.monotonic() + timeout - tls_compat = False - - if isinstance(addr, str): - # UNIX socket - sock = socket.socket(socket.AF_UNIX) - else: - sock = socket.socket(socket.AF_INET) - - try: - sock.settimeout(timeout) - - try: - sock.connect(addr) - - if not isinstance(addr, str): - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - - # Upgrade to TLS - if params.ssl_ctx.check_hostname: - server_hostname = addr[0] - else: - server_hostname = None - sock.settimeout(time_left) - try: - sock = params.ssl_ctx.wrap_socket( - sock, server_hostname=server_hostname - ) - except ssl.CertificateError as e: - raise con_utils.wrap_error(e) from e - except ssl.SSLError as e: - if e.reason == 'CERTIFICATE_VERIFY_FAILED': - raise con_utils.wrap_error(e) from e - - # Retry in plain text - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - sock.close() - sock = socket.socket(socket.AF_INET) - sock.settimeout(time_left) - sock.connect(addr) - tls_compat = True - else: - con_utils.check_alpn_protocol(sock) - except socket.gaierror as e: - # All name resolution errors are considered temporary - err = errors.ClientConnectionFailedTemporarilyError(str(e)) - raise err from e - except OSError as e: - raise con_utils.wrap_error(e) from e - - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError - - if not isinstance(addr, str): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - proto = blocking_proto.BlockingIOProtocol( - params, sock, tls_compat - ) - proto.set_connection(connection) - - try: - sock.settimeout(time_left) - proto.sync_connect() - sock.settimeout(None) - except OSError as e: - raise con_utils.wrap_error(e) from e - - self._protocol = proto - self._addr = addr - - except Exception: - sock.close() - raise - - def privileged_execute(self, query): - self._protocol.sync_simple_query(query, enums.Capability.ALL) - - def is_closed(self): - proto = self._protocol - return not (proto and proto.sock is not None and - proto.sock.fileno() >= 0 and proto.connected) - - def close(self): - if self._protocol: - self._protocol.abort() - - -class BlockingIOConnection( - base_con.BaseConnection, - abstract.Executor, - options._OptionsMixin, -): - - def __init__(self, addrs, config, params, *, - codecs_registry, query_cache): - self._inner = base_con._InnerConnection( - addrs, config, params, - codecs_registry=codecs_registry, - query_cache=query_cache) - super().__init__() - - def _shallow_clone(self): - if self._inner._borrowed_for: - raise base_con.borrow_error(self._inner._borrowed_for) - new_conn = self.__class__.__new__(self.__class__) - new_conn._inner = self._inner - return new_conn - - def ensure_connected(self, single_attempt=False): - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - self._reconnect(single_attempt=single_attempt) - - def _reconnect(self, single_attempt=False): - inner = self._inner - inner._impl = _BlockingIOConnectionImpl( - inner._codecs_registry, inner._query_cache) - inner._impl.connect(inner._addrs, inner._config, inner._params, - single_attempt=single_attempt, connection=inner) - assert inner._impl._protocol - - def _get_protocol(self): - inner = self._inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - if not inner._impl or inner._impl.is_closed(): - self._reconnect() - return inner._impl._protocol - - def _dump( - self, - *, - on_header: typing.Callable[[bytes], None], - on_data: typing.Callable[[bytes], None], - ) -> None: - self._get_protocol().sync_dump( - header_callback=on_header, - block_callback=on_data) - - def _restore( - self, - *, - header: bytes, - data_gen: typing.Iterable[bytes], - ) -> None: - self._get_protocol().sync_restore( - header=header, - data_gen=data_gen - ) - - def _dispatch_log_message(self, msg): - for cb in self._inner._log_listeners: - cb(self, msg) - - def _fetchall( - self, - query: str, - *args, - __limit__: int=0, - __typenames__: bool=False, - **kwargs, - ) -> datatypes.Set: - inner = self._inner - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - implicit_limit=__limit__, - inline_typenames=__typenames__, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def _fetchall_json( - self, - query: str, - *args, - __limit__: int=0, - **kwargs, - ) -> datatypes.Set: - inner = self._inner - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - implicit_limit=__limit__, - inline_typenames=False, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def _execute( - self, - *, - query: str, - args, - kwargs, - io_format, - expect_one=False, - required_one=False, - ): - inner = self._inner - reconnect = False - capabilities = None - i = 0 - while True: - try: - if reconnect: - self._reconnect(single_attempt=True) - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - expect_one=expect_one, - required_one=required_one, - io_format=io_format, - allow_capabilities=enums.Capability.EXECUTE, - ) - except errors.EdgeDBError as e: - if not e.has_tag(errors.SHOULD_RETRY): - raise e - if capabilities is None: - cache_item = inner._query_cache.get( - query=query, - io_format=io_format, - implicit_limit=0, - inline_typenames=False, - inline_typeids=False, - expect_one=expect_one, - ) - if cache_item is not None: - _, _, _, capabilities = cache_item - # A query is read-only if it has no capabilities i.e. - # capabilities == 0. Read-only queries are safe to retry. - # Explicit transaction conflicts as well. - if ( - capabilities != 0 - and not isinstance(e, errors.TransactionConflictError) - ): - raise e - rule = self._options.retry_options.get_rule_for_exception(e) - if i >= rule.attempts: - raise e - time.sleep(rule.backoff(i)) - reconnect = self.is_closed() - - def query(self, query: str, *args, **kwargs) -> datatypes.Set: - return self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.BINARY, - ) - - def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - return self._execute( - query=query, - args=args, - kwargs=kwargs, - expect_one=True, - io_format=protocol.IoFormat.BINARY, - ) - - def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: - return self._execute( - query=query, - args=args, - kwargs=kwargs, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.BINARY, - ) - - def query_json(self, query: str, *args, **kwargs) -> str: - return self._execute( - query=query, - args=args, - kwargs=kwargs, - io_format=protocol.IoFormat.JSON, - ) - - def _fetchall_json_elements( - self, query: str, *args, **kwargs) -> typing.List[str]: - inner = self._inner - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - io_format=protocol.IoFormat.JSON_ELEMENTS, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_single_json(self, query: str, *args, **kwargs) -> str: - inner = self._inner - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - expect_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_required_single_json(self, query: str, *args, **kwargs) -> str: - inner = self._inner - return self._get_protocol().sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=inner._codecs_registry, - qc=inner._query_cache, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def execute(self, query: str) -> None: - self._get_protocol().sync_simple_query(query, enums.Capability.EXECUTE) - - def transaction(self) -> _retry.Retry: - return _retry.Retry(self) - - def close(self) -> None: - if not self.is_closed(): - self._inner._impl.close() - - def is_closed(self) -> bool: - return self._inner._impl is None or self._inner._impl.is_closed() - - -def connect( - dsn: str = None, - *, - host: str = None, - port: int = None, - credentials: str = None, - credentials_file: str = None, - user: str = None, - password: str = None, - database: str = None, - tls_ca: str = None, - tls_ca_file: str = None, - tls_security: str = None, - timeout: int = 10, - wait_until_available: int = 30, -) -> BlockingIOConnection: - - connect_config, client_config = con_utils.parse_connect_arguments( - dsn=dsn, - host=host, - port=port, - credentials=credentials, - credentials_file=credentials_file, - user=user, - password=password, - database=database, - timeout=timeout, - wait_until_available=wait_until_available, - tls_ca=tls_ca, - tls_ca_file=tls_ca_file, - tls_security=tls_security, - - # ToDos - command_timeout=None, - server_settings=None) - - conn = BlockingIOConnection( - addrs=[connect_config.address], params=connect_config, - config=client_config, - codecs_registry=_CodecsRegistry(), - query_cache=_QueryCodecsCache()) - conn.ensure_connected() - return conn diff --git a/edgedb/protocol/blocking_proto.pxd b/edgedb/protocol/blocking_proto.pxd index 4981426b..5bad598d 100644 --- a/edgedb/protocol/blocking_proto.pxd +++ b/edgedb/protocol/blocking_proto.pxd @@ -27,4 +27,4 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): cdef: readonly object sock - cdef _iter_coroutine(self, coro) + cdef _disconnect(self) diff --git a/edgedb/protocol/blocking_proto.pyx b/edgedb/protocol/blocking_proto.pyx index 2d14b22b..36a64856 100644 --- a/edgedb/protocol/blocking_proto.pyx +++ b/edgedb/protocol/blocking_proto.pyx @@ -38,6 +38,9 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): cpdef abort(self): self.terminate() + self._disconnect() + + cdef _disconnect(self): self.connected = False if self.sock is not None: self.sock.close() @@ -47,7 +50,7 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): try: self.sock.send(buf) except OSError as e: - self.connected = False + self._disconnect() raise con_utils.wrap_error(e) from e async def wait_for_message(self): @@ -55,10 +58,10 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): try: data = self.sock.recv(RECV_BUF) except OSError as e: - self.connected = False + self._disconnect() raise con_utils.wrap_error(e) from e if not data: - self.connected = False + self._disconnect() raise errors.ClientConnectionClosedError() self.buffer.feed_data(data) @@ -71,14 +74,14 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): while not self.buffer.take_message(): data = self.sock.recv(RECV_BUF) if not data: - self.connected = False + self._disconnect() raise errors.ClientConnectionClosedError() self.buffer.feed_data(data) except BlockingIOError: # No data in the socket net buffer. return except OSError as e: - self.connected = False + self._disconnect() raise con_utils.wrap_error(e) from e finally: self.sock.settimeout(None) @@ -86,46 +89,13 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocol): async def wait_for_connect(self): return True - cdef _iter_coroutine(self, coro): + async def wait_for_disconnect(self): + if self.cancelled or not self.connected: + return try: - coro.send(None) - except StopIteration as ex: - if ex.args: - result = ex.args[0] - else: - result = None - finally: - coro.close() - return result - - def sync_connect(self): - return self._iter_coroutine(self.connect()) - - def sync_execute_anonymous(self, *args, **kwargs): - result, _headers = self._iter_coroutine( - self.execute_anonymous(*args, **kwargs), - ) - # don't expose headers to blocking client for now - return result - - def sync_simple_query(self, *args, **kwargs): - return self._iter_coroutine(self.simple_query(*args, **kwargs)) - - def sync_dump(self, *, header_callback, block_callback): - async def header_wrapper(data): - header_callback(data) - async def block_wrapper(data): - block_callback(data) - return self._iter_coroutine(self.dump(header_wrapper, block_wrapper)) - - def sync_restore(self, *, header, data_gen): - async def wrapper(): while True: - try: - block = next(data_gen) - except StopIteration: - return - yield block - - return self._iter_coroutine(self.restore( - header, wrapper())) + if not self.buffer.take_message(): + await self.wait_for_message() + self.fallthrough() + except errors.ClientConnectionClosedError: + pass diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 1024c727..9402b6eb 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -73,12 +73,12 @@ cpython.datetime.import_datetime() _QUERY_SINGLE_METHOD = { True: { IoFormat.JSON: 'query_required_single_json', - IoFormat.JSON_ELEMENTS: '_fetchall_json_elements', + IoFormat.JSON_ELEMENTS: 'raw_query', IoFormat.BINARY: 'query_required_single', }, False: { IoFormat.JSON: 'query_single_json', - IoFormat.JSON_ELEMENTS: '_fetchall_json_elements', + IoFormat.JSON_ELEMENTS: 'raw_query', IoFormat.BINARY: 'query_single', }, } diff --git a/edgedb/retry.py b/edgedb/retry.py deleted file mode 100644 index 242e7855..00000000 --- a/edgedb/retry.py +++ /dev/null @@ -1,189 +0,0 @@ -import asyncio -import time - -from . import errors -from . import transaction as _transaction - - -class AsyncIOIteration(_transaction.BaseAsyncIOTransaction): - def __init__(self, retry, owner, iteration): - super().__init__(owner, retry._options.transaction_options) - self.__retry = retry - self.__iteration = iteration - self.__started = False - - async def _ensure_transaction(self): - if not self._managed: - raise errors.InterfaceError( - "Only managed retriable transactions are supported. " - "Use `async with transaction:`" - ) - if not self.__started: - self.__started = True - await self._start(single_connect=self.__iteration != 0) - if self._pool is not None: - # Having a pool means we just acquired the connection in - # _start() - let's mark it as borrowed for transaction anyways - # just in case the connection is somehow accessed separately. - self._borrow() - - async def __aenter__(self): - if self._managed: - raise errors.InterfaceError( - 'cannot enter context: already in an `async with` block') - self._managed = True - if self._pool is None: - # Borrow the connection for transaction now if it's not on a pool, - # because that means we already have the connection now, and - # further use of the connection like this should be prevented: - # async for tx in conn.transaction(): - # async with tx: - # await conn.query("...") # <- wrong use after borrow - self._borrow() - return self - - async def __aexit__(self, extype, ex, tb): - self._managed = False - if not self.__started: - self._maybe_return() - return False - - try: - if extype is not None: - await self._rollback() - else: - await self._commit() - except errors.EdgeDBError as err: - if ex is None: - # On commit we don't know if commit is succeeded before the - # database have received it or after it have been done but - # network is dropped before we were able to receive a response - # TODO(tailhook) retry on some errors - raise err - # If we were going to rollback, look at original error - # to find out whether we want to retry, regardless of - # the rollback error. - # In this case we ignore rollback issue as original error is more - # important, e.g. in case `CancelledError` it's important - # to propagate it to cancel the whole task. - # NOTE: rollback error is always swallowed, should we use - # on_log_message for it? - - if ( - extype is not None and - issubclass(extype, errors.EdgeDBError) and - ex.has_tag(errors.SHOULD_RETRY) - ): - return self.__retry._retry(ex) - - -class BaseRetry: - - def __init__(self, owner): - self._owner = owner - self._iteration = 0 - self._done = False - self._next_backoff = 0 - self._options = owner._options - - def _retry(self, exc): - self._last_exception = exc - rule = self._options.retry_options.get_rule_for_exception(exc) - if self._iteration >= rule.attempts: - return False - self._done = False - self._next_backoff = rule.backoff(self._iteration) - return True - - -class AsyncIORetry(BaseRetry): - - def __aiter__(self): - return self - - async def __anext__(self): - # Note: when changing this code consider also - # updating Retry.__next__. - if self._done: - raise StopAsyncIteration - if self._next_backoff: - await asyncio.sleep(self._next_backoff) - self._done = True - iteration = AsyncIOIteration(self, self._owner, self._iteration) - self._iteration += 1 - return iteration - - -class Retry(BaseRetry): - - def __iter__(self): - return self - - def __next__(self): - # Note: when changing this code consider also - # updating AsyncIORetry.__anext__. - if self._done: - raise StopIteration - if self._next_backoff: - time.sleep(self._next_backoff) - self._done = True - iteration = Iteration(self, self._owner, self._iteration) - self._iteration += 1 - return iteration - - -class Iteration(_transaction.BaseBlockingIOTransaction): - def __init__(self, retry, owner, iteration): - super().__init__(owner, retry._options.transaction_options) - self.__retry = retry - self.__iteration = iteration - self.__started = False - - def _ensure_transaction(self): - if not self._managed: - raise errors.InterfaceError( - "Only managed retriable transactions are supported. " - "Use `with transaction:`" - ) - if not self.__started: - self.__started = True - self._start(single_connect=self.__iteration != 0) - - def __enter__(self): - if self._managed: - raise errors.InterfaceError( - 'cannot enter context: already in a `with` block') - self._managed = True - self._borrow() - return self - - def __exit__(self, extype, ex, tb): - self._managed = False - if not self.__started: - self._maybe_return() - return False - - try: - if extype is not None: - self._rollback() - else: - self._commit() - except errors.EdgeDBError as err: - if ex is None: - # On commit we don't know if commit is succeeded before the - # database have received it or after it have been done but - # network is dropped before we were able to receive a response - # TODO(tailhook) retry on some errors - raise err - # If we were going to rollback, look at original error - # to find out whether we want to retry, regardless of - # the rollback error. - # In this case we ignore rollback issue as original error is more - # important. - - if ( - extype is not None and - issubclass(extype, errors.EdgeDBError) and - ex.has_tag(errors.SHOULD_RETRY) - ): - return self.__retry._retry(ex) diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 87124f74..3acf3b6c 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -18,15 +18,9 @@ import enum -import typing from . import abstract -from . import base_con -from . import enums from . import errors -from . import options -from .datatypes import datatypes -from .protocol import protocol class TransactionState(enum.Enum): @@ -40,28 +34,23 @@ class TransactionState(enum.Enum): class BaseTransaction: __slots__ = ( + '_client', '_connection', - '_connection_inner', - '_connection_impl', - '_pool', '_options', '_state', - '_managed', + '__retry', + '__iteration', + '__started', ) - def __init__(self, owner, options: options.TransactionOptions): - if isinstance(owner, base_con.BaseConnection): - self._connection = owner - self._connection_inner = owner._inner - self._pool = None - else: - self._connection = None - self._connection_inner = None - self._pool = owner - self._connection_impl = None - self._options = options + def __init__(self, retry, client, iteration): + self._client = client + self._connection = None + self._options = retry._options.transaction_options self._state = TransactionState.NEW - self._managed = False + self.__retry = retry + self.__iteration = iteration + self.__started = False def is_active(self) -> bool: return self._state is TransactionState.STARTED @@ -104,16 +93,6 @@ def _make_rollback_query(self): self.__check_state('rollback') return 'ROLLBACK;' - def _borrow(self): - inner = self._connection_inner - if inner._borrowed_for: - raise base_con.borrow_error(inner._borrowed_for) - inner._borrowed_for = base_con.BorrowReason.TRANSACTION - - def _maybe_return(self): - if self._connection_inner is not None: - self._connection_inner._borrowed_for = None - def __repr__(self): attrs = [] attrs.append('state:{}'.format(self._state.name.lower())) @@ -127,154 +106,72 @@ def __repr__(self): return '<{}.{} {} {:#x}>'.format( mod, self.__class__.__name__, ' '.join(attrs), id(self)) - -class BaseAsyncIOTransaction(BaseTransaction, abstract.AsyncIOExecutor): - __slots__ = () - - async def _start(self, single_connect=False) -> None: - query = self._make_start_query() - if self._pool is not None: - self._connection = await self._pool._acquire() - self._connection_inner = self._connection._inner - inner = self._connection_inner - if not inner._impl or inner._impl.is_closed(): - await self._connection._reconnect(single_attempt=single_connect) - self._connection_impl = self._connection._inner._impl - try: - await self._connection_impl.privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - raise - else: - self._state = TransactionState.STARTED - - async def _commit(self): - try: - query = self._make_commit_query() + async def _ensure_transaction(self): + if not self.__started: + self.__started = True + query = self._make_start_query() + self._connection = await self._client._impl.acquire() + if self._connection.is_closed(): + await self._connection.connect( + single_attempt=self.__iteration != 0 + ) try: - await self._connection_impl.privileged_execute(query) + await self._connection.privileged_execute(query) except BaseException: self._state = TransactionState.FAILED raise else: - self._state = TransactionState.COMMITTED - finally: - self._maybe_return() - if self._pool is not None: - await self._pool._release(self._connection) + self._state = TransactionState.STARTED + + async def _exit(self, extype, ex): + if not self.__started: + return False - async def _rollback(self): try: - query = self._make_rollback_query() + if extype is None: + query = self._make_commit_query() + state = TransactionState.COMMITTED + else: + query = self._make_rollback_query() + state = TransactionState.ROLLEDBACK try: - await self._connection_impl.privileged_execute(query) + await self._connection.privileged_execute(query) except BaseException: self._state = TransactionState.FAILED raise else: - self._state = TransactionState.ROLLEDBACK + self._state = state + except errors.EdgeDBError as err: + if ex is None: + # On commit we don't know if commit is succeeded before the + # database have received it or after it have been done but + # network is dropped before we were able to receive a response + # TODO(tailhook) retry on some errors + raise err + # If we were going to rollback, look at original error + # to find out whether we want to retry, regardless of + # the rollback error. + # In this case we ignore rollback issue as original error is more + # important, e.g. in case `CancelledError` it's important + # to propagate it to cancel the whole task. + # NOTE: rollback error is always swallowed, should we use + # on_log_message for it? finally: - self._maybe_return() - if self._pool is not None: - await self._pool._release(self._connection) + await self._client._impl.release(self._connection) - async def _ensure_transaction(self): - pass - - async def query(self, query: str, *args, **kwargs) -> datatypes.Set: - await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result - - async def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result - - async def query_required_single( - self, query: str, *args, **kwargs - ) -> typing.Any: - await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result + if ( + extype is not None and + issubclass(extype, errors.EdgeDBError) and + ex.has_tag(errors.SHOULD_RETRY) + ): + return self.__retry._retry(ex) - async def query_json(self, query: str, *args, **kwargs) -> str: - await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result + def _get_query_cache(self) -> abstract.QueryCache: + return self._client._get_query_cache() - async def query_single_json(self, query: str, *args, **kwargs) -> str: + async def _query(self, query_context: abstract.QueryContext): await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - return result - - async def query_required_single_json( - self, query: str, *args, **kwargs - ) -> str: - await self._ensure_transaction() - con = self._connection_inner - result, _ = await self._connection_impl._protocol.execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) + result, _ = await self._connection.raw_query(query_context) return result async def execute(self, query: str) -> None: @@ -290,145 +187,23 @@ async def execute(self, query: str) -> None: ... ''') """ await self._ensure_transaction() - await self._connection_impl._protocol.simple_query( - query, enums.Capability.EXECUTE) - - -class BaseBlockingIOTransaction(BaseTransaction, abstract.Executor): - __slots__ = () - - def _start(self, single_connect=False) -> None: - query = self._make_start_query() - # no pools supported for blocking con - inner = self._connection_inner - if not inner._impl or inner._impl.is_closed(): - self._connection._reconnect(single_attempt=single_connect) - self._connection_inner = self._connection._inner - self._connection_impl = self._connection_inner._impl - try: - self._connection_impl.privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - raise - else: - self._state = TransactionState.STARTED - - def _commit(self): - try: - query = self._make_commit_query() - try: - self._connection_impl.privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - raise - else: - self._state = TransactionState.COMMITTED - finally: - self._maybe_return() - - def _rollback(self): - try: - query = self._make_rollback_query() - try: - self._connection_impl.privileged_execute(query) - except BaseException: - self._state = TransactionState.FAILED - raise - else: - self._state = TransactionState.ROLLEDBACK - finally: - self._maybe_return() - - def _ensure_transaction(self): - pass - - def query(self, query: str, *args, **kwargs) -> datatypes.Set: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_single( - self, query: str, *args, **kwargs - ) -> typing.Union[typing.Any, None]: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_required_single(self, query: str, *args, **kwargs) -> typing.Any: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.BINARY, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_json(self, query: str, *args, **kwargs) -> str: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_single_json(self, query: str, *args, **kwargs) -> str: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def query_required_single_json(self, query: str, *args, **kwargs) -> str: - self._ensure_transaction() - con = self._connection_inner - return self._connection_impl._protocol.sync_execute_anonymous( - query=query, - args=args, - kwargs=kwargs, - reg=con._codecs_registry, - qc=con._query_cache, - expect_one=True, - required_one=True, - io_format=protocol.IoFormat.JSON, - allow_capabilities=enums.Capability.EXECUTE, - ) - - def execute(self, query: str) -> None: - self._ensure_transaction() - self._connection_impl._protocol.sync_simple_query( - query, enums.Capability.EXECUTE) + await self._connection.execute(query) + + +class BaseRetry: + + def __init__(self, owner): + self._owner = owner + self._iteration = 0 + self._done = False + self._next_backoff = 0 + self._options = owner._options + + def _retry(self, exc): + self._last_exception = exc + rule = self._options.retry_options.get_rule_for_exception(exc) + if self._iteration >= rule.attempts: + return False + self._done = False + self._next_backoff = rule.backoff(self._iteration) + return True diff --git a/tests/test_async_query.py b/tests/test_async_query.py index 7561dd88..71df993a 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -26,10 +26,12 @@ import asyncio import edgedb +from edgedb import abstract from edgedb import compat from edgedb import _taskgroup as tg from edgedb import _testbase as tb from edgedb.options import RetryOptions +from edgedb.protocol import protocol class TestAsyncQuery(tb.AsyncQueryTestCase): @@ -48,106 +50,106 @@ class TestAsyncQuery(tb.AsyncQueryTestCase): def setUp(self): super().setUp() - self.con._clear_codecs_cache() + self.client._clear_codecs_cache() async def test_async_parse_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.con.query('select syntax error') + await self.client.query('select syntax error') with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.con.query('select syntax error') + await self.client.query('select syntax error') with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, 'Unexpected end of line'): - await self.con.query('select (') + await self.client.query('select (') with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, 'Unexpected end of line'): - await self.con.query_json('select (') + await self.client.query_json('select (') for _ in range(10): self.assertEqual( - await self.con.query('select 1;'), + await self.client.query('select 1;'), edgedb.Set((1,))) - self.assertFalse(self.con.is_closed()) + self.assertFalse(self.client.connection.is_closed()) async def test_async_parse_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.con.execute('select syntax error') + await self.client.execute('select syntax error') with self.assertRaises(edgedb.EdgeQLSyntaxError): - await self.con.execute('select syntax error') + await self.client.execute('select syntax error') for _ in range(10): - await self.con.execute('select 1; select 2;'), + await self.client.execute('select 1; select 2;'), async def test_async_exec_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.query('select 1 / 0;') + await self.client.query('select 1 / 0;') with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.query('select 1 / 0;') + await self.client.query('select 1 / 0;') for _ in range(10): self.assertEqual( - await self.con.query('select 1;'), + await self.client.query('select 1;'), edgedb.Set((1,))) async def test_async_exec_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.execute('select 1 / 0;') + await self.client.execute('select 1 / 0;') with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.execute('select 1 / 0;') + await self.client.execute('select 1 / 0;') for _ in range(10): - await self.con.execute('select 1;') + await self.client.execute('select 1;') async def test_async_exec_error_recover_03(self): query = 'select 10 // $0;' for i in [1, 2, 0, 3, 1, 0, 1]: if i: self.assertEqual( - await self.con.query(query, i), + await self.client.query(query, i), edgedb.Set([10 // i])) else: with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.query(query, i) + await self.client.query(query, i) async def test_async_exec_error_recover_04(self): for i in [1, 2, 0, 3, 1, 0, 1]: if i: - await self.con.execute(f'select 10 // {i};') + await self.client.execute(f'select 10 // {i};') else: with self.assertRaises(edgedb.DivisionByZeroError): - await self.con.query(f'select 10 // {i};') + await self.client.query(f'select 10 // {i};') async def test_async_exec_error_recover_05(self): with self.assertRaisesRegex(edgedb.QueryError, 'cannot accept parameters'): - await self.con.execute(f'select $0') + await self.client.execute(f'select $0') self.assertEqual( - await self.con.query('SELECT "HELLO"'), + await self.client.query('SELECT "HELLO"'), ["HELLO"]) async def test_async_query_single_01(self): - res = await self.con.query_single("SELECT 1") + res = await self.client.query_single("SELECT 1") self.assertEqual(res, 1) - res = await self.con.query_single("SELECT {}") + res = await self.client.query_single("SELECT {}") self.assertEqual(res, None) - res = await self.con.query_required_single("SELECT 1") + res = await self.client.query_required_single("SELECT 1") self.assertEqual(res, 1) with self.assertRaises(edgedb.NoDataError): - await self.con.query_required_single("SELECT {}") + await self.client.query_required_single("SELECT {}") async def test_async_query_single_command_01(self): - r = await self.con.query(''' + r = await self.client.query(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -155,12 +157,12 @@ async def test_async_query_single_command_01(self): ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -168,12 +170,12 @@ async def test_async_query_single_command_01(self): ''') self.assertEqual(r, []) - r = await self.con.query_json(''' + r = await self.client.query_json(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, '[]') - r = await self.con.query_json(''' + r = await self.client.query_json(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -184,49 +186,51 @@ async def test_async_query_single_command_01(self): with self.assertRaisesRegex( edgedb.InterfaceError, r'query cannot be executed with query_required_single_json\('): - await self.con.query_required_single_json(''' + await self.client.query_required_single_json(''' DROP TYPE test::server_query_single_command_01; ''') - r = await self.con.query_json(''' + r = await self.client.query_json(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, '[]') - self.assertTrue(self.con._get_last_status().startswith('DROP')) + self.assertTrue( + self.client.connection._get_last_status().startswith('DROP') + ) async def test_async_query_single_command_02(self): - r = await self.con.query(''' + r = await self.client.query(''' SET MODULE default; ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' RESET ALIAS *; ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' SET ALIAS bar AS MODULE std; ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' SET MODULE default; ''') self.assertEqual(r, []) - r = await self.con.query(''' + r = await self.client.query(''' SET ALIAS bar AS MODULE std; ''') self.assertEqual(r, []) - r = await self.con.query_json(''' + r = await self.client.query_json(''' SET MODULE default; ''') self.assertEqual(r, '[]') - r = await self.con.query_json(''' + r = await self.client.query_json(''' SET ALIAS foo AS MODULE default; ''') self.assertEqual(r, '[]') @@ -236,32 +240,32 @@ async def test_async_query_single_command_03(self): edgedb.InterfaceError, r'cannot be executed with query_required_single\(\).*' r'not return'): - await self.con.query_required_single('set module default') + await self.client.query_required_single('set module default') with self.assertRaisesRegex( edgedb.InterfaceError, r'cannot be executed with query_required_single_json\(\).*' r'not return'): - await self.con.query_required_single_json('set module default') + await self.client.query_required_single_json('set module default') async def test_async_query_single_command_04(self): with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - await self.con.query(''' + await self.client.query(''' SELECT 1; SET MODULE blah; ''') with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - await self.con.query_single(''' + await self.client.query_single(''' SELECT 1; SET MODULE blah; ''') with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - await self.con.query_json(''' + await self.client.query_json(''' SELECT 1; SET MODULE blah; ''') @@ -269,22 +273,22 @@ async def test_async_query_single_command_04(self): async def test_async_basic_datatypes_01(self): for _ in range(10): self.assertEqual( - await self.con.query_single( + await self.client.query_single( 'select ()'), ()) self.assertEqual( - await self.con.query( + await self.client.query( 'select (1,)'), edgedb.Set([(1,)])) self.assertEqual( - await self.con.query( + await self.client.query( 'select ["a", "b"]'), edgedb.Set([["a", "b"]])) self.assertEqual( - await self.con.query(''' + await self.client.query(''' SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; '''), @@ -296,61 +300,61 @@ async def test_async_basic_datatypes_01(self): with self.assertRaisesRegex( edgedb.InterfaceError, r'query_single\(\) as it returns a multiset'): - await self.con.query_single('SELECT {1, 2}') + await self.client.query_single('SELECT {1, 2}') with self.assertRaisesRegex( edgedb.InterfaceError, r'query_required_single\(\) as it returns a multiset'): - await self.con.query_required_single('SELECT {1, 2}') + await self.client.query_required_single('SELECT {1, 2}') with self.assertRaisesRegex( edgedb.NoDataError, r'\bquery_required_single\('): - await self.con.query_required_single('SELECT {}') + await self.client.query_required_single('SELECT {}') async def test_async_basic_datatypes_02(self): self.assertEqual( - await self.con.query( + await self.client.query( r'''select [b"\x00a", b"b", b'', b'\na']'''), edgedb.Set([[b"\x00a", b"b", b'', b'\na']])) self.assertEqual( - await self.con.query( + await self.client.query( r'select $0', b'he\x00llo'), edgedb.Set([b'he\x00llo'])) async def test_async_basic_datatypes_03(self): for _ in range(10): # test opportunistic execute self.assertEqual( - await self.con.query_json( + await self.client.query_json( 'select ()'), '[[]]') self.assertEqual( - await self.con.query_json( + await self.client.query_json( 'select (1,)'), '[[1]]') self.assertEqual( - await self.con.query_json( + await self.client.query_json( 'select >[]'), '[[]]') self.assertEqual( json.loads( - await self.con.query_json( + await self.client.query_json( 'select ["a", "b"]')), [["a", "b"]]) self.assertEqual( json.loads( - await self.con.query_single_json( + await self.client.query_single_json( 'select ["a", "b"]')), ["a", "b"]) self.assertEqual( json.loads( - await self.con.query_json(''' + await self.client.query_json(''' SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; ''')), @@ -361,25 +365,27 @@ async def test_async_basic_datatypes_03(self): self.assertEqual( json.loads( - await self.con.query_json('SELECT {1, 2}')), + await self.client.query_json('SELECT {1, 2}')), [1, 2]) self.assertEqual( - json.loads(await self.con.query_json('SELECT {}')), + json.loads(await self.client.query_json('SELECT {}')), []) with self.assertRaises(edgedb.NoDataError): - await self.con.query_required_single_json('SELECT {}') + await self.client.query_required_single_json( + 'SELECT {}' + ) self.assertEqual( json.loads( - await self.con.query_single_json('SELECT {}') + await self.client.query_single_json('SELECT {}') ), None ) async def test_async_basic_datatypes_04(self): - val = await self.con.query_single( + val = await self.client.query_single( ''' SELECT schema::ObjectType { foo := { @@ -407,28 +413,28 @@ async def test_async_basic_datatypes_04(self): async def test_async_args_01(self): self.assertEqual( - await self.con.query( + await self.client.query( 'select (>$foo)[0] ++ (>$bar)[0];', foo=['aaa'], bar=['bbb']), edgedb.Set(('aaabbb',))) async def test_async_args_02(self): self.assertEqual( - await self.con.query( + await self.client.query( 'select (>$0)[0] ++ (>$1)[0];', ['aaa'], ['bbb']), edgedb.Set(('aaabbb',))) async def test_async_args_03(self): with self.assertRaisesRegex(edgedb.QueryError, r'missing \$0'): - await self.con.query('select $1;') + await self.client.query('select $1;') with self.assertRaisesRegex(edgedb.QueryError, r'missing \$1'): - await self.con.query('select $0 + $2;') + await self.client.query('select $0 + $2;') with self.assertRaisesRegex(edgedb.QueryError, 'combine positional and named parameters'): - await self.con.query('select $0 + $bar;') + await self.client.query('select $0 + $bar;') async def test_async_args_04(self): aware_datetime = datetime.datetime.now(datetime.timezone.utc) @@ -439,56 +445,56 @@ async def test_async_args_04(self): aware_time = datetime.time(hour=11, tzinfo=datetime.timezone.utc) self.assertEqual( - await self.con.query_single( + await self.client.query_single( 'select $0;', aware_datetime), aware_datetime) self.assertEqual( - await self.con.query_single( + await self.client.query_single( 'select $0;', naive_datetime), naive_datetime) self.assertEqual( - await self.con.query_single( + await self.client.query_single( 'select $0;', date), date) self.assertEqual( - await self.con.query_single( + await self.client.query_single( 'select $0;', naive_time), naive_time) with self.assertRaisesRegex(edgedb.InvalidArgumentError, r'a timezone-aware.*expected'): - await self.con.query_single( + await self.client.query_single( 'select $0;', naive_datetime) with self.assertRaisesRegex(edgedb.InvalidArgumentError, r'a naive time object.*expected'): - await self.con.query_single( + await self.client.query_single( 'select $0;', aware_time) with self.assertRaisesRegex(edgedb.InvalidArgumentError, r'a naive datetime object.*expected'): - await self.con.query_single( + await self.client.query_single( 'select $0;', aware_datetime) with self.assertRaisesRegex(edgedb.InvalidArgumentError, r'datetime.datetime object was expected'): - await self.con.query_single( + await self.client.query_single( 'select $0;', date) with self.assertRaisesRegex(edgedb.InvalidArgumentError, r'datetime.datetime object was expected'): - await self.con.query_single( + await self.client.query_single( 'select $0;', date) @@ -499,11 +505,11 @@ async def _test_async_args_05(self): # XXX move to edgedb/edgedb # which would make it fail. self.assertEqual( - await self.con.query('select $a', a=1), + await self.client.query('select $a', a=1), [1] ) self.assertEqual( - await self.con.query('select $a', a=None), + await self.client.query('select $a', a=None), [] ) @@ -513,7 +519,7 @@ async def _test_async_args_06(self): # XXX move to edgedb/edgedb # client side too. self.assertEqual( - await self.con.query('select $a', a=1), + await self.client.query('select $a', a=1), [1] ) @@ -521,7 +527,7 @@ async def _test_async_args_06(self): # XXX move to edgedb/edgedb edgedb.InvalidArgumentError, r'argument \$a is required, but received None'): self.assertEqual( - await self.con.query('select $a', a=None), + await self.client.query('select $a', a=None), [] ) @@ -533,7 +539,7 @@ async def test_async_mismatched_args_01(self): "got {'[bc]', '[bc]'}, " r"missed {'a'}, extra {'[bc]', '[bc]'}"): - await self.con.query("""SELECT $a;""", b=1, c=2) + await self.client.query("""SELECT $a;""", b=1, c=2) async def test_async_mismatched_args_02(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -543,7 +549,7 @@ async def test_async_mismatched_args_02(self): r"got {'[acd]', '[acd]', '[acd]'}, " r"missed {'b'}, extra {'[cd]', '[cd]'}"): - await self.con.query(""" + await self.client.query(""" SELECT $a + $b; """, a=1, c=2, d=3) @@ -554,7 +560,7 @@ async def test_async_mismatched_args_03(self): "expected {'a'} (?:keyword )?arguments, got {'b'}, " "missed {'a'}, extra {'b'}"): - await self.con.query("""SELECT $a;""", b=1) + await self.client.query("""SELECT $a;""", b=1) async def test_async_mismatched_args_04(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -564,7 +570,7 @@ async def test_async_mismatched_args_04(self): r"got {'a'}, " r"missed {'b'}"): - await self.con.query("""SELECT $a + $b;""", a=1) + await self.client.query("""SELECT $a + $b;""", a=1) async def test_async_mismatched_args_05(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -574,34 +580,34 @@ async def test_async_mismatched_args_05(self): r"got {'[ab]', '[ab]'}, " r"extra {'b'}"): - await self.con.query("""SELECT $a;""", a=1, b=2) + await self.client.query("""SELECT $a;""", a=1, b=2) async def test_async_args_uuid_pack(self): - obj = await self.con.query_single( + obj = await self.client.query_single( 'select schema::Object {id, name} limit 1') # Test that the custom UUID that our driver uses can be # passed back as a parameter. - ot = await self.con.query_single( + ot = await self.client.query_single( 'select schema::Object {name} filter .id=$id', id=obj.id) self.assertEqual(obj, ot) # Test that a string UUID is acceptable. - ot = await self.con.query_single( + ot = await self.client.query_single( 'select schema::Object {name} filter .id=$id', id=str(obj.id)) self.assertEqual(obj, ot) # Test that a standard uuid.UUID is acceptable. - ot = await self.con.query_single( + ot = await self.client.query_single( 'select schema::Object {name} filter .id=$id', id=uuid.UUID(bytes=obj.id.bytes)) self.assertEqual(obj, ot) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'invalid UUID.*length must be'): - await self.con.query( + await self.client.query( 'select schema::Object {name} filter .id=$id', id='asdasas') @@ -674,51 +680,51 @@ async def test_async_args_bigint_basic(self): num += random.choice("0000000012") testar.append(int(num)) - val = await self.con.query_single( + val = await self.client.query_single( 'select >$arg', arg=testar) self.assertEqual(testar, val) async def test_async_args_bigint_pack(self): - val = await self.con.query_single( + val = await self.client.query_single( 'select $arg', arg=10) self.assertEqual(val, 10) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query( + await self.client.query( 'select $arg', arg='bad int') with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query( + await self.client.query( 'select $arg', arg=10.11) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query( + await self.client.query( 'select $arg', arg=decimal.Decimal('10.0')) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query( + await self.client.query( 'select $arg', arg=decimal.Decimal('10.11')) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query( + await self.client.query( 'select $arg', arg='10') with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=decimal.Decimal('10')) @@ -728,7 +734,7 @@ class IntLike: def __int__(self): return 10 - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=IntLike()) @@ -739,19 +745,19 @@ def __int__(self): with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=IntLike()) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=IntLike()) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=IntLike()) @@ -760,24 +766,25 @@ class IntLike: def __int__(self): return 10 - val = await self.con.query_single('select $0', - decimal.Decimal("10.0")) + val = await self.client.query_single( + 'select $0', decimal.Decimal("10.0") + ) self.assertEqual(val, 10) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected a Decimal or an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg=IntLike()) with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'expected a Decimal or an int'): - await self.con.query_single( + await self.client.query_single( 'select $arg', arg="10.2") async def test_async_wait_cancel_01(self): - underscored_lock = await self.con.query_single(""" + underscored_lock = await self.client.query_single(""" SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_advisory_lock' ) @@ -789,11 +796,15 @@ async def test_async_wait_cancel_01(self): # by closing. lock_key = tb.gen_lock_key() - con = self.con.with_retry_options(RetryOptions(attempts=1)) - _con2 = await self.connect(database=self.con.dbname) - con2 = _con2.with_retry_options(RetryOptions(attempts=1)) + client = self.client.with_retry_options(RetryOptions(attempts=1)) + client2 = self.test_client( + database=self.client.dbname + ).with_retry_options( + RetryOptions(attempts=1) + ) + await client2.ensure_connected() - async for tx in con.transaction(): + async for tx in client.transaction(): async with tx: self.assertTrue(await tx.query_single( 'select sys::_advisory_lock($0)', @@ -809,7 +820,7 @@ async def exec_to_fail(): edgedb.ClientConnectionClosedError, ConnectionResetError, )): - async for tx2 in con2.transaction(): + async for tx2 in client2.transaction(): async with tx2: # start the lazy transaction await tx2.query('SELECT 42;') @@ -833,7 +844,9 @@ async def exec_to_fail(): # cancelled, which, in turn, will terminate the # connection rudely, and exec_to_fail() will get # ConnectionResetError. - await compat.wait_for(con2.aclose(), timeout=0.5) + await compat.wait_for( + client2.aclose(), timeout=0.5 + ) finally: self.assertEqual( @@ -843,7 +856,7 @@ async def exec_to_fail(): [True]) async def test_empty_set_unpack(self): - await self.con.query_single(''' + await self.client.query_single(''' select schema::Function { name, params: { @@ -856,45 +869,59 @@ async def test_empty_set_unpack(self): ''') async def test_enum_argument_01(self): - A = await self.con.query_single('SELECT $0', 'A') + A = await self.client.query_single('SELECT $0', 'A') self.assertEqual(str(A), 'A') with self.assertRaisesRegex( edgedb.InvalidValueError, 'invalid input value for enum'): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.query_single('SELECT $0', 'Oups') self.assertEqual( - await self.con.query_single('SELECT $0', 'A'), + await self.client.query_single('SELECT $0', 'A'), A) self.assertEqual( - await self.con.query_single('SELECT $0', A), + await self.client.query_single('SELECT $0', A), A) with self.assertRaisesRegex( edgedb.InvalidValueError, 'invalid input value for enum'): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.query_single('SELECT $0', 'Oups') with self.assertRaisesRegex( edgedb.InvalidArgumentError, 'a str or edgedb.EnumValue'): - await self.con.query_single('SELECT $0', 123) + await self.client.query_single('SELECT $0', 123) async def test_json(self): self.assertEqual( - await self.con.query_json('SELECT {"aaa", "bbb"}'), + await self.client.query_json('SELECT {"aaa", "bbb"}'), '["aaa", "bbb"]') async def test_json_elements(self): + result, _ = await self.client.connection.raw_query( + abstract.QueryContext( + query=abstract.QueryWithArgs( + 'SELECT {"aaa", "bbb"}', (), {} + ), + cache=self.client._get_query_cache(), + query_options=abstract.QueryOptions( + io_format=protocol.IoFormat.JSON_ELEMENTS, + expect_one=False, + required_one=False, + ), + retry_options=None, + ) + ) self.assertEqual( - await self.con._fetchall_json_elements('SELECT {"aaa", "bbb"}'), + result, edgedb.Set(['"aaa"', '"bbb"'])) async def test_async_cancel_01(self): - has_sleep = await self.con.query_single(""" + has_sleep = await self.client.query_single(""" SELECT EXISTS( SELECT schema::Function FILTER .name = 'sys::_sleep' ) @@ -902,24 +929,26 @@ async def test_async_cancel_01(self): if not has_sleep: self.skipTest("No sys::_sleep function") - con = await self.connect(database=self.con.dbname) + client = self.test_client(database=self.client.dbname) try: - self.assertEqual(await con.query_single('SELECT 1'), 1) + self.assertEqual(await client.query_single('SELECT 1'), 1) - conn_before = con._inner._impl + protocol_before = client._impl._holders[0]._con._protocol with self.assertRaises(asyncio.TimeoutError): await compat.wait_for( - con.query_single('SELECT sys::_sleep(10)'), + client.query_single('SELECT sys::_sleep(10)'), timeout=0.1) - await con.query('SELECT 2') + await client.query('SELECT 2') - conn_after = con._inner._impl - self.assertIsNot(conn_before, conn_after, "Reconnect expected") + protocol_after = client._impl._holders[0]._con._protocol + self.assertIsNot( + protocol_before, protocol_after, "Reconnect expected" + ) finally: - await con.aclose() + await client.aclose() async def test_async_log_message(self): msgs = [] @@ -927,13 +956,13 @@ async def test_async_log_message(self): def on_log(con, msg): msgs.append(msg) - self.con.add_log_listener(on_log) + self.client.connection.add_log_listener(on_log) try: - await self.con.query( + await self.client.query( 'configure system set __internal_restart := true;') await asyncio.sleep(0.01) # allow the loop to call the callback finally: - self.con.remove_log_listener(on_log) + self.client.connection.remove_log_listener(on_log) for msg in msgs: if (msg.get_severity_name() == 'NOTICE' and @@ -946,9 +975,9 @@ async def test_async_banned_transaction(self): with self.assertRaisesRegex( edgedb.CapabilityError, r'cannot execute transaction control commands'): - await self.con.query('start transaction') + await self.client.query('start transaction') with self.assertRaisesRegex( edgedb.CapabilityError, r'cannot execute transaction control commands'): - await self.con.execute('start transaction') + await self.client.execute('start transaction') diff --git a/tests/test_async_retry.py b/tests/test_async_retry.py index 8b269423..cc0924b4 100644 --- a/tests/test_async_retry.py +++ b/tests/test_async_retry.py @@ -65,7 +65,7 @@ class TestAsyncRetry(tb.AsyncQueryTestCase): ''' async def test_async_retry_01(self): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.execute(''' INSERT test::Counter { @@ -75,7 +75,7 @@ async def test_async_retry_01(self): async def test_async_retry_02(self): with self.assertRaises(ZeroDivisionError): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.execute(''' INSERT test::Counter { @@ -84,13 +84,15 @@ async def test_async_retry_02(self): ''') 1 / 0 with self.assertRaises(edgedb.NoDataError): - await self.con.query_required_single(''' + await self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_02' ''') async def test_async_retry_begin(self): - patcher = unittest.mock.patch("edgedb.retry.AsyncIOIteration._start") + patcher = unittest.mock.patch( + "edgedb.base_client.BaseConnection.privileged_execute" + ) _start = patcher.start() def cleanup(): @@ -104,7 +106,7 @@ def cleanup(): _start.side_effect = errors.BackendUnavailableError() with self.assertRaises(errors.BackendUnavailableError): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.execute(''' INSERT test::Counter { @@ -112,7 +114,7 @@ def cleanup(): }; ''') with self.assertRaises(edgedb.NoDataError): - await self.con.query_required_single(''' + await self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' ''') @@ -124,7 +126,7 @@ async def recover_after_first_error(*_, **__): _start.side_effect = recover_after_first_error call_count = _start.call_count - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.execute(''' INSERT test::Counter { @@ -132,7 +134,7 @@ async def recover_after_first_error(*_, **__): }; ''') self.assertEqual(_start.call_count, call_count + 1) - await self.con.query_single(''' + await self.client.query_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' ''') @@ -148,15 +150,15 @@ async def test_async_conflict_no_retry(self): ) async def execute_conflict(self, name='counter2', options=None): - con2 = await self.connect(database=self.get_database_name()) - self.addCleanup(con2.aclose) + client2 = self.test_client(database=self.get_database_name()) + self.addCleanup(client2.aclose) barrier = Barrier(2) lock = asyncio.Lock() iterations = 0 - async def transaction1(con): - async for tx in con.transaction(): + async def transaction1(client): + async for tx in client.transaction(): nonlocal iterations iterations += 1 async with tx: @@ -188,14 +190,14 @@ async def transaction1(con): lock.release() return res - con = self.con + client = self.client if options: - con = con.with_retry_options(options) - con2 = con2.with_retry_options(options) + client = client.with_retry_options(options) + client2 = client2.with_retry_options(options) results = await compat.wait_for(asyncio.gather( - transaction1(con), - transaction1(con2), + transaction1(client), + transaction1(client2), return_exceptions=True, ), 10) for e in results: @@ -210,7 +212,7 @@ async def test_async_transaction_interface_errors(self): AttributeError, "'AsyncIOIteration' object has no attribute 'start'", ): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.start() @@ -218,7 +220,7 @@ async def test_async_transaction_interface_errors(self): AttributeError, "'AsyncIOIteration' object has no attribute 'rollback'", ): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: await tx.rollback() @@ -226,24 +228,19 @@ async def test_async_transaction_interface_errors(self): AttributeError, "'AsyncIOIteration' object has no attribute 'start'", ): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): await tx.start() with self.assertRaisesRegex(edgedb.InterfaceError, r'.*Use `async with transaction:`'): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): await tx.execute("SELECT 123") with self.assertRaisesRegex( edgedb.InterfaceError, r"already in an `async with` block", ): - async for tx in self.con.transaction(): + async for tx in self.client.transaction(): async with tx: async with tx: pass - - with self.assertRaisesRegex(edgedb.InterfaceError, r".*is borrowed.*"): - async for tx in self.con.transaction(): - async with tx: - await self.con.execute("SELECT 123") diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index f0cbc791..a378a5b8 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -38,24 +38,12 @@ class TestAsyncTx(tb.AsyncQueryTestCase): ''' async def test_async_transaction_regular_01(self): - self.assertIsNone(self.con._inner._borrowed_for) - tr = self.con.with_retry_options( + tr = self.client.with_retry_options( RetryOptions(attempts=1)).transaction() - self.assertIsNone(self.con._inner._borrowed_for) with self.assertRaises(ZeroDivisionError): async for with_tr in tr: async with with_tr: - self.assertIs(self.con._inner._borrowed_for, 'transaction') - - with self.assertRaisesRegex(edgedb.InterfaceError, - '.*is borrowed.*'): - await self.con.execute(''' - INSERT test::TransactionTest { - name := 'Test Transaction' - }; - ''') - await with_tr.execute(''' INSERT test::TransactionTest { name := 'Test Transaction' @@ -64,9 +52,7 @@ async def test_async_transaction_regular_01(self): 1 / 0 - self.assertIsNone(self.con._inner._borrowed_for) - - result = await self.con.query(''' + result = await self.client.query(''' SELECT test::TransactionTest FILTER @@ -90,7 +76,9 @@ async def test_async_transaction_kinds(self): ) # skip None opt = {k: v for k, v in opt.items() if v is not None} - con = self.con.with_transaction_options(TransactionOptions(**opt)) - async for tx in con.transaction(): + client = self.client.with_transaction_options( + TransactionOptions(**opt) + ) + async for tx in client.transaction(): async with tx: pass diff --git a/tests/test_asyncio_client.py b/tests/test_asyncio_client.py new file mode 100644 index 00000000..a2b98183 --- /dev/null +++ b/tests/test_asyncio_client.py @@ -0,0 +1,489 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import random + +import edgedb + +from edgedb import compat +from edgedb import _testbase as tb +from edgedb import errors +from edgedb import asyncio_client + + +class TestAsyncIOClient(tb.AsyncQueryTestCase): + def create_client(self, **kwargs): + conargs = self.get_connect_args().copy() + conargs["database"] = self.get_database_name() + conargs["timeout"] = 120 + conargs.update(kwargs) + conargs.setdefault( + "connection_class", asyncio_client.AsyncIOConnection + ) + conargs.setdefault("max_concurrency", None) + + return tb.TestAsyncIOClient(**conargs) + + async def test_client_01(self): + for n in {1, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + client = self.create_client(max_concurrency=10) + + async def worker(): + self.assertEqual(await client.query_single("SELECT 1"), 1) + + tasks = [worker() for _ in range(n)] + await asyncio.gather(*tasks) + await client.aclose() + + async def test_client_02(self): + for n in {1, 3, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + async with self.create_client(max_concurrency=5) as client: + + async def worker(): + self.assertEqual( + await client.query_single("SELECT 1"), 1 + ) + + tasks = [worker() for _ in range(n)] + await asyncio.gather(*tasks) + + async def test_client_05(self): + for n in {1, 3, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + client = self.create_client(max_concurrency=10) + + async def worker(): + self.assertEqual(await client.query('SELECT 1'), [1]) + self.assertEqual(await client.query_single('SELECT 1'), 1) + self.assertEqual( + await client.query_json('SELECT 1'), '[1]') + self.assertEqual( + await client.query_single_json('SELECT 1'), '1') + + tasks = [worker() for _ in range(n)] + await asyncio.gather(*tasks) + await client.aclose() + + async def test_client_transaction(self): + client = self.create_client(max_concurrency=1) + + async for tx in client.transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 7*8"), 56) + + await client.aclose() + + async def test_client_options(self): + client = self.create_client(max_concurrency=1) + + client.with_transaction_options( + edgedb.TransactionOptions(readonly=True)) + client.with_retry_options( + edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff)) + async for tx in client.transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 7*8"), 56) + + await client.aclose() + + async def test_client_init_run_until_complete(self): + client = self.create_client() + self.assertIsInstance(client, asyncio_client.AsyncIOClient) + await client.aclose() + + async def test_client_no_acquire_deadlock(self): + async with self.create_client( + max_concurrency=1, + ) as client: + + has_sleep = await client.query_single(""" + SELECT EXISTS( + SELECT schema::Function FILTER .name = 'sys::_sleep' + ) + """) + if not has_sleep: + self.skipTest("No sys::_sleep function") + + async def sleep_and_release(): + await client.execute("SELECT sys::_sleep(1)") + + asyncio.ensure_future(sleep_and_release()) + await asyncio.sleep(0.5) + + await client.query_single("SELECT 1") + + async def test_client_config_persistence(self): + N = 100 + + class MyConnection(asyncio_client.AsyncIOConnection): + async def raw_query(self, query_context): + res, h = await super().raw_query(query_context) + return res + 1, h + + async def test(client): + async for tx in client.transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 1"), 2) + + async with self.create_client( + max_concurrency=10, + connection_class=MyConnection, + ) as client: + + await asyncio.gather(*[test(client) for _ in range(N)]) + + self.assertEqual( + sum( + 1 + for ch in client._impl._holders + if ch._con and not ch._con.is_closed() + ), + 10, + ) + + async def test_client_connection_methods(self): + async def test_query(client): + i = random.randint(0, 20) + await asyncio.sleep(random.random() / 100) + r = await client.query("SELECT {}".format(i)) + self.assertEqual(list(r), [i]) + return 1 + + async def test_query_single(client): + i = random.randint(0, 20) + await asyncio.sleep(random.random() / 100) + r = await client.query_single("SELECT {}".format(i)) + self.assertEqual(r, i) + return 1 + + async def test_execute(client): + await asyncio.sleep(random.random() / 100) + await client.execute("SELECT {1, 2, 3, 4}") + return 1 + + async def run(N, meth): + async with self.create_client(max_concurrency=10) as client: + + coros = [meth(client) for _ in range(N)] + res = await asyncio.gather(*coros) + self.assertEqual(res, [1] * N) + + methods = [ + test_query, + test_query_single, + test_execute, + ] + + with tb.silence_asyncio_long_exec_warning(): + for method in methods: + with self.subTest(method=method.__name__): + await run(200, method) + + async def test_client_handles_transaction_exit_in_asyncgen_1(self): + client = self.create_client(max_concurrency=1) + + async def iterate(): + async for tx in client.transaction(): + async with tx: + for record in await tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + agen = iterate() + try: + async for _ in agen: # noqa + raise MyException() + finally: + await agen.aclose() + + await client.aclose() + + async def test_client_handles_transaction_exit_in_asyncgen_2(self): + client = self.create_client(max_concurrency=1) + + async def iterate(): + async for tx in client.transaction(): + async with tx: + for record in await tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + iterator = iterate() + try: + async for _ in iterator: # noqa + raise MyException() + finally: + await iterator.aclose() + + del iterator + + await client.aclose() + + async def test_client_handles_asyncgen_finalization(self): + client = self.create_client(max_concurrency=1) + + async def iterate(tx): + for record in await tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + async for tx in client.transaction(): + async with tx: + agen = iterate(tx) + try: + async for _ in agen: # noqa + raise MyException() + finally: + await agen.aclose() + + await client.aclose() + + async def test_client_close_waits_for_release(self): + client = self.create_client(max_concurrency=1) + + flag = self.loop.create_future() + conn_released = False + + async def worker(): + nonlocal conn_released + + async for tx in client.transaction(): + async with tx: + await tx.query("SELECT 42") + flag.set_result(True) + await asyncio.sleep(0.1) + + conn_released = True + + self.loop.create_task(worker()) + + await flag + await client.aclose() + self.assertTrue(conn_released) + + async def test_client_close_timeout(self): + client = self.create_client(max_concurrency=1) + + flag = self.loop.create_future() + + async def worker(): + async for tx in client.transaction(): + async with tx: + await tx.query_single("SELECT 42") + flag.set_result(True) + await asyncio.sleep(0.5) + + task = self.loop.create_task(worker()) + + with self.assertRaises(asyncio.TimeoutError): + await flag + await compat.wait_for(client.aclose(), timeout=0.1) + + with self.assertRaises(errors.ClientConnectionClosedError): + await task + + async def test_client_expire_connections(self): + class SlowCloseConnection(asyncio_client.AsyncIOConnection): + async def close(self): + await asyncio.sleep(0.2) + await super().close() + + client = self.create_client( + max_concurrency=1, connection_class=SlowCloseConnection + ) + + async for tx in client.transaction(): + async with tx: + await tx.query("SELECT 42") + self.assertIsNotNone(client._impl._holders[0]._con) + client._impl.expire_connections() + + self.assertIsNone(client._impl._holders[0]._con) + + await client.query("SELECT 42") + self.assertIsNotNone(client._impl._holders[0]._con) + + client._impl.expire_connections() + async for tx in client.transaction(): + async with tx: + await tx.query("SELECT 42") + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(client.query("SELECT 42"), 1) + + await client.aclose() + + async def test_client_properties(self): + max_concurrency = 2 + + client = self.create_client(max_concurrency=max_concurrency) + self.assertEqual(client.max_concurrency, max_concurrency) + self.assertEqual(client.max_concurrency, max_concurrency) + + async for tx in client.transaction(): + async with tx: + await tx.query("SELECT 42") + self.assertEqual(client.free_size, max_concurrency - 1) + + self.assertEqual(client.free_size, max_concurrency) + + await client.aclose() + + async def _test_connection_broken(self, executor, broken_evt): + broken_evt.set() + + with self.assertRaises(errors.ClientConnectionError): + await executor.query_single("SELECT 123") + + broken_evt.clear() + + self.assertEqual(await executor.query_single("SELECT 123"), 123) + broken_evt.set() + with self.assertRaises(errors.ClientConnectionError): + await executor.query_single("SELECT 123") + broken_evt.clear() + self.assertEqual(await executor.query_single("SELECT 123"), 123) + + tested = False + async for tx in executor.transaction(): + async with tx: + self.assertEqual(await tx.query_single("SELECT 123"), 123) + if tested: + break + tested = True + broken_evt.set() + try: + await tx.query_single("SELECT 123") + except errors.ClientConnectionError: + broken_evt.clear() + raise + else: + self.fail("ConnectionError not raised!") + + async def test_client_connection_broken(self): + con_args = self.get_connect_args() + broken = asyncio.Event() + done = asyncio.Event() + + async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): + while True: + reader = self.loop.create_task(r.read(65536)) + waiter = self.loop.create_task(broken.wait()) + await asyncio.wait( + [reader, waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + if waiter.done(): + reader.cancel() + w.close() + break + else: + waiter.cancel() + data = await reader + if not data: + w.close() + break + w.write(data) + + async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): + ur, uw = await asyncio.open_connection( + con_args['host'], con_args['port'] + ) + done.clear() + task = self.loop.create_task(proxy(r, uw)) + try: + await proxy(ur, w) + finally: + try: + await task + finally: + done.set() + w.close() + uw.close() + + server = await asyncio.start_server( + cb, '127.0.0.1', 0 + ) + port = server.sockets[0].getsockname()[1] + client = self.create_client( + host='127.0.0.1', + port=port, + max_concurrency=1, + wait_until_available=5, + ) + try: + await self._test_connection_broken(client, broken) + finally: + server.close() + await server.wait_closed() + await asyncio.wait_for(client.aclose(), 5) + broken.set() + await done.wait() + + async def test_client_suggested_concurrency(self): + conargs = self.get_connect_args().copy() + conargs["database"] = self.get_database_name() + conargs["timeout"] = 120 + + client = edgedb.create_async_client(**conargs) + + self.assertEqual(client.max_concurrency, 1) + + await client.ensure_connected() + self.assertGreater(client.max_concurrency, 1) + + await client.aclose() + + client = edgedb.create_async_client(**conargs, max_concurrency=5) + + self.assertEqual(client.max_concurrency, 5) + + await client.ensure_connected() + self.assertEqual(client.max_concurrency, 5) + + await client.aclose() + + def test_client_with_different_loop(self): + conargs = self.get_connect_args() + client = edgedb.create_async_client(**conargs) + + async def test(): + self.assertIsNot(asyncio.get_event_loop(), self.loop) + result = await client.query_single("SELECT 42") + self.assertEqual(result, 42) + await asyncio.gather( + client.query_single("SELECT 42"), + client.query_single("SELECT 42"), + ) + await client.aclose() + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(test()) + asyncio.set_event_loop(self.loop) diff --git a/tests/test_blocking_client.py b/tests/test_blocking_client.py new file mode 100644 index 00000000..8686246d --- /dev/null +++ b/tests/test_blocking_client.py @@ -0,0 +1,479 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import queue +import random +import threading +import time + +import edgedb + +from edgedb import _testbase as tb +from edgedb import errors +from edgedb import blocking_client + + +class TestBlockingClient(tb.SyncQueryTestCase): + def create_client(self, **kwargs): + conargs = self.get_connect_args().copy() + conargs["database"] = self.get_database_name() + conargs["timeout"] = 120 + conargs.update(kwargs) + conargs.setdefault( + "connection_class", blocking_client.BlockingIOConnection + ) + conargs.setdefault("max_concurrency", None) + + return tb.TestClient(**conargs) + + def test_client_01(self): + for n in {1, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + client = self.create_client(max_concurrency=10) + + def worker(): + self.assertEqual(client.query_single("SELECT 1"), 1) + + tasks = [threading.Thread(target=worker) for _ in range(n)] + for task in tasks: + task.start() + for task in tasks: + task.join() + client.close() + + def test_client_02(self): + for n in {1, 3, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + with self.create_client(max_concurrency=5) as client: + + def worker(): + self.assertEqual(client.query_single("SELECT 1"), 1) + + tasks = [threading.Thread(target=worker) for _ in range(n)] + for task in tasks: + task.start() + for task in tasks: + task.join() + + def test_client_05(self): + for n in {1, 3, 5, 10, 20, 100}: + with self.subTest(tasksnum=n): + client = self.create_client(max_concurrency=10) + + def worker(): + self.assertEqual(client.query('SELECT 1'), [1]) + self.assertEqual(client.query_single('SELECT 1'), 1) + self.assertEqual(client.query_json('SELECT 1'), '[1]') + self.assertEqual(client.query_single_json('SELECT 1'), '1') + + tasks = [threading.Thread(target=worker) for _ in range(n)] + for task in tasks: + task.start() + for task in tasks: + task.join() + client.close() + + def test_client_transaction(self): + client = self.create_client(max_concurrency=1) + + for tx in client.transaction(): + with tx: + self.assertEqual(tx.query_single("SELECT 7*8"), 56) + + client.close() + + def test_client_options(self): + client = self.create_client(max_concurrency=1) + + client.with_transaction_options( + edgedb.TransactionOptions(readonly=True)) + client.with_retry_options( + edgedb.RetryOptions(attempts=1, backoff=edgedb.default_backoff)) + for tx in client.transaction(): + with tx: + self.assertEqual(tx.query_single("SELECT 7*8"), 56) + + client.close() + + def test_client_init_run_until_complete(self): + client = self.create_client() + self.assertIsInstance(client, blocking_client.Client) + client.close() + + def test_client_no_acquire_deadlock(self): + with self.create_client( + max_concurrency=1, + ) as client: + + has_sleep = client.query_single(""" + SELECT EXISTS( + SELECT schema::Function FILTER .name = 'sys::_sleep' + ) + """) + if not has_sleep: + self.skipTest("No sys::_sleep function") + + def sleep_and_release(): + client.execute("SELECT sys::_sleep(1)") + + task = threading.Thread(target=sleep_and_release) + task.start() + time.sleep(0.5) + + client.query_single("SELECT 1") + task.join() + + def test_client_config_persistence(self): + N = 100 + + class MyConnection(blocking_client.BlockingIOConnection): + async def raw_query(self, query_context): + res, h = await super().raw_query(query_context) + return res + 1, h + + def test(): + for tx in client.transaction(): + with tx: + self.assertEqual(tx.query_single("SELECT 1"), 2) + # make this test more reliable by spending more time in + # the transaction to hit max_concurrency sooner + time.sleep(random.random() / 100) + + with self.create_client( + max_concurrency=10, + connection_class=MyConnection, + ) as client: + + tasks = [threading.Thread(target=test) for _ in range(N)] + for task in tasks: + task.start() + for task in tasks: + task.join() + + self.assertEqual( + sum( + 1 + for ch in client._impl._holders + if ch._con and not ch._con.is_closed() + ), + 10, + ) + + def test_client_connection_methods(self): + def test_query(client, q): + i = random.randint(0, 20) + time.sleep(random.random() / 100) + r = client.query("SELECT {}".format(i)) + self.assertEqual(list(r), [i]) + q.put(1) + + def test_query_single(client, q): + i = random.randint(0, 20) + time.sleep(random.random() / 100) + r = client.query_single("SELECT {}".format(i)) + self.assertEqual(r, i) + q.put(1) + + def test_execute(client, q): + time.sleep(random.random() / 100) + client.execute("SELECT {1, 2, 3, 4}") + q.put(1) + + def run(N, meth): + with self.create_client(max_concurrency=10) as client: + q = queue.Queue() + coros = [ + threading.Thread(target=meth, args=(client, q)) + for _ in range(N) + ] + for coro in coros: + coro.start() + for coro in coros: + coro.join() + res = [] + while not q.empty(): + res.append(q.get_nowait()) + self.assertEqual(res, [1] * N) + + methods = [ + test_query, + test_query_single, + test_execute, + ] + + for method in methods: + with self.subTest(method=method.__name__): + run(200, method) + + def test_client_handles_transaction_exit_in_gen_1(self): + client = self.create_client(max_concurrency=1) + + def iterate(): + for tx in client.transaction(): + with tx: + for record in tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + agen = iterate() + try: + for _ in agen: # noqa + raise MyException() + finally: + agen.close() + + client.close() + + def test_client_handles_transaction_exit_in_gen_2(self): + client = self.create_client(max_concurrency=1) + + def iterate(): + for tx in client.transaction(): + with tx: + for record in tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + iterator = iterate() + try: + for _ in iterator: # noqa + raise MyException() + finally: + iterator.close() + + del iterator + + client.close() + + def test_client_handles_gen_finalization(self): + client = self.create_client(max_concurrency=1) + + def iterate(tx): + for record in tx.query("SELECT {1, 2, 3}"): + yield record + + class MyException(Exception): + pass + + with self.assertRaises(MyException): + for tx in client.transaction(): + with tx: + agen = iterate(tx) + try: + for _ in agen: # noqa + raise MyException() + finally: + agen.close() + + client.close() + + def test_client_close_waits_for_release(self): + client = self.create_client(max_concurrency=1) + + flag = threading.Event() + conn_released = False + + def worker(): + nonlocal conn_released + + for tx in client.transaction(): + with tx: + tx.query("SELECT 42") + flag.set() + time.sleep(0.1) + + conn_released = True + + task = threading.Thread(target=worker) + task.start() + + flag.wait() + client.close() + self.assertTrue(conn_released) + task.join() + + def test_client_close_timeout(self): + client = self.create_client(max_concurrency=1) + + flag = threading.Event() + + def worker(): + with self.assertRaises(errors.ClientConnectionClosedError): + for tx in client.transaction(): + with tx: + tx.query_single("SELECT 42") + flag.set() + time.sleep(0.5) + + task = threading.Thread(target=worker) + task.start() + + flag.wait() + client.close(timeout=0.1) + + task.join() + + def test_client_expire_connections(self): + client = self.create_client(max_concurrency=1) + + for tx in client.transaction(): + with tx: + tx.query("SELECT 42") + client._impl.expire_connections() + + self.assertIsNone(client._impl._holders[0]._con) + client.close() + + def test_client_properties(self): + max_concurrency = 2 + + client = self.create_client(max_concurrency=max_concurrency) + self.assertEqual(client.max_concurrency, max_concurrency) + self.assertEqual(client.max_concurrency, max_concurrency) + + for tx in client.transaction(): + with tx: + tx.query("SELECT 42") + self.assertEqual(client.free_size, max_concurrency - 1) + + self.assertEqual(client.free_size, max_concurrency) + + client.close() + + def _test_connection_broken(self, executor, broken_evt): + self.loop.call_soon_threadsafe(broken_evt.set) + + with self.assertRaises(errors.ClientConnectionError): + executor.query_single("SELECT 123") + + self.loop.call_soon_threadsafe(broken_evt.clear) + + self.assertEqual(executor.query_single("SELECT 123"), 123) + self.loop.call_soon_threadsafe(broken_evt.set) + with self.assertRaises(errors.ClientConnectionError): + executor.query_single("SELECT 123") + self.loop.call_soon_threadsafe(broken_evt.clear) + self.assertEqual(executor.query_single("SELECT 123"), 123) + + tested = False + for tx in executor.transaction(): + with tx: + self.assertEqual(tx.query_single("SELECT 123"), 123) + if tested: + break + tested = True + self.loop.call_soon_threadsafe(broken_evt.set) + try: + tx.query_single("SELECT 123") + except errors.ClientConnectionError: + self.loop.call_soon_threadsafe(broken_evt.clear) + raise + else: + self.fail("ConnectionError not raised!") + + async def test_client_connection_broken(self): + con_args = self.get_connect_args() + broken = asyncio.Event() + done = asyncio.Event() + + async def proxy(r: asyncio.StreamReader, w: asyncio.StreamWriter): + while True: + reader = self.loop.create_task(r.read(65536)) + waiter = self.loop.create_task(broken.wait()) + await asyncio.wait( + [reader, waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + if waiter.done(): + reader.cancel() + w.close() + break + else: + waiter.cancel() + data = await reader + if not data: + w.close() + break + w.write(data) + + async def cb(r: asyncio.StreamReader, w: asyncio.StreamWriter): + ur, uw = await asyncio.open_connection( + con_args['host'], con_args['port'] + ) + done.clear() + task = self.loop.create_task(proxy(r, uw)) + try: + await proxy(ur, w) + finally: + try: + await task + finally: + done.set() + w.close() + uw.close() + + server = await asyncio.start_server( + cb, '127.0.0.1', 0 + ) + port = server.sockets[0].getsockname()[1] + client = self.create_client( + host='127.0.0.1', + port=port, + max_concurrency=1, + wait_until_available=5, + ) + try: + await self.loop.run_in_executor( + None, self._test_connection_broken, client, broken + ) + finally: + server.close() + await server.wait_closed() + await self.loop.run_in_executor(None, client.close, 5) + broken.set() + await done.wait() + + def test_client_suggested_concurrency(self): + conargs = self.get_connect_args().copy() + conargs["database"] = self.get_database_name() + conargs["timeout"] = 120 + + client = edgedb.create_client(**conargs) + + self.assertEqual(client.max_concurrency, 1) + + client.ensure_connected() + self.assertGreater(client.max_concurrency, 1) + + client.close() + + client = edgedb.create_client(**conargs, max_concurrency=5) + + self.assertEqual(client.max_concurrency, 5) + + client.ensure_connected() + self.assertEqual(client.max_concurrency, 5) + + client.close() diff --git a/tests/test_connect.py b/tests/test_connect.py index 7be828a7..e22f1958 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -52,13 +52,13 @@ async def test_connect_async_01(self): edgedb.ClientConnectionError, f'(?s).*Is the server running.*port {self.port}.*'): conn_args['host'] = '127.0.0.1' - await edgedb.async_connect_raw(**conn_args) + await edgedb.create_async_client(**conn_args).ensure_connected() with self.assertRaisesRegex( edgedb.ClientConnectionError, f'(?s).*Is the server running.*port {self.port}.*'): conn_args['host'] = orig_conn_args['host'] - await edgedb.async_connect_raw(**conn_args) + await edgedb.create_async_client(**conn_args).ensure_connected() def test_connect_sync_01(self): orig_conn_args = self.get_connect_args() @@ -70,10 +70,10 @@ def test_connect_sync_01(self): edgedb.ClientConnectionError, f'(?s).*Is the server running.*port {self.port}.*'): conn_args['host'] = '127.0.0.1' - edgedb.connect(**conn_args) + edgedb.create_client(**conn_args).ensure_connected() with self.assertRaisesRegex( edgedb.ClientConnectionError, f'(?s).*Is the server running.*port {self.port}.*'): conn_args['host'] = orig_conn_args['host'] - edgedb.connect(**conn_args) + edgedb.create_client(**conn_args).ensure_connected() diff --git a/tests/test_datetime.py b/tests/test_datetime.py index 8aca2845..e66571e5 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -54,7 +54,7 @@ async def test_duration_01(self): durs = [timedelta(**d) for d in duration_kwargs] # Test encode/decode roundtrip - durs_from_db = self.con.query(''' + durs_from_db = self.client.query(''' WITH args := array_unpack(>$0) SELECT args; ''', durs) @@ -62,7 +62,7 @@ async def test_duration_01(self): async def test_relative_duration_01(self): try: - self.con.query("SELECT '1y'") + self.client.query("SELECT '1y'") except errors.InvalidReferenceError: self.skipTest("feature not implemented") @@ -92,13 +92,13 @@ async def test_relative_duration_01(self): # Test that RelativeDuration.__str__ formats the # same as - durs_as_text = self.con.query(''' + durs_as_text = self.client.query(''' WITH args := array_unpack(>$0) SELECT args; ''', durs) # Test encode/decode roundtrip - durs_from_db = self.con.query(''' + durs_from_db = self.client.query(''' WITH args := array_unpack(>$0) SELECT args; ''', durs) diff --git a/tests/test_enum.py b/tests/test_enum.py index a81d537b..64c955a5 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -31,9 +31,9 @@ class TestEnum(tb.AsyncQueryTestCase): ''' async def test_enum_01(self): - ct_red = await self.con.query_single('SELECT "red"') - ct_white = await self.con.query_single('SELECT "white"') - c_red = await self.con.query_single('SELECT "red"') + ct_red = await self.client.query_single('SELECT "red"') + ct_white = await self.client.query_single('SELECT "white"') + c_red = await self.client.query_single('SELECT "red"') self.assertTrue(isinstance(ct_red, edgedb.EnumValue)) self.assertTrue(isinstance(ct_red.__tid__, uuid.UUID)) diff --git a/tests/test_memory.py b/tests/test_memory.py index d89c2365..63c032e2 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -23,7 +23,7 @@ class TestConfigMemory(tb.SyncQueryTestCase): async def test_config_memory_01(self): if ( - self.con.query_required_single( + self.client.query_required_single( "select exists " "(select schema::Type filter .name = 'cfg::memory')" ) is False @@ -44,7 +44,7 @@ async def test_config_memory_01(self): # Test that ConfigMemory.__str__ formats the # same as - mem_tuples = self.con.query(''' + mem_tuples = self.client.query(''' WITH args := array_unpack(>$0) SELECT ( args, @@ -56,7 +56,7 @@ async def test_config_memory_01(self): mem_vals = [t[0] for t in mem_tuples] # Test encode/decode roundtrip - roundtrip = self.con.query(''' + roundtrip = self.client.query(''' WITH args := array_unpack(>$0) SELECT args; ''', mem_vals) diff --git a/tests/test_pool.py b/tests/test_pool.py deleted file mode 100644 index 0cecfd21..00000000 --- a/tests/test_pool.py +++ /dev/null @@ -1,48 +0,0 @@ -# -# This source file is part of the EdgeDB open source project. -# -# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -import edgedb - -from edgedb import _testbase as tb - - -class TestClient(tb.AsyncQueryTestCase): - - async def test_client_suggested_concurrency(self): - conargs = self.get_connect_args().copy() - conargs["database"] = self.con.dbname - conargs["timeout"] = 120 - - client = edgedb.create_async_client(**conargs) - - self.assertEqual(client.concurrency, 1) - - await client.ensure_connected() - self.assertGreater(client.concurrency, 1) - - await client.aclose() - - client = edgedb.create_async_client(**conargs, concurrency=5) - - self.assertEqual(client.concurrency, 5) - - await client.ensure_connected() - self.assertEqual(client.concurrency, 5) - - await client.aclose() diff --git a/tests/test_proto.py b/tests/test_proto.py index 349a0bf5..34d3716a 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -21,18 +21,37 @@ import edgedb from edgedb import _testbase as tb +from edgedb import abstract +from edgedb.protocol import protocol class TestProto(tb.SyncQueryTestCase): def test_json(self): self.assertEqual( - self.con.query_json('SELECT {"aaa", "bbb"}'), + self.client.query_json('SELECT {"aaa", "bbb"}'), '["aaa", "bbb"]') def test_json_elements(self): + self.client.ensure_connected() + result, _ = self.client._iter_coroutine( + self.client.connection.raw_query( + abstract.QueryContext( + query=abstract.QueryWithArgs( + 'SELECT {"aaa", "bbb"}', (), {} + ), + cache=self.client._get_query_cache(), + query_options=abstract.QueryOptions( + io_format=protocol.IoFormat.JSON_ELEMENTS, + expect_one=False, + required_one=False, + ), + retry_options=None, + ) + ) + ) self.assertEqual( - self.con._fetchall_json_elements('SELECT {"aaa", "bbb"}'), + result, edgedb.Set(['"aaa"', '"bbb"'])) # std::datetime is now in range of Python datetime, @@ -50,7 +69,7 @@ async def test_proto_codec_error_recovery_01(self): # we know that the codec will fail. # The test will be rewritten once it's possible to override # default codecs. - self.con.query(""" + self.client.query(""" SELECT cal::to_local_date('0001-01-01 BC', 'YYYY-MM-DD AD'); """) @@ -58,7 +77,7 @@ async def test_proto_codec_error_recovery_01(self): # The protocol, though, shouldn't be in some inconsistent # state; it should allow new queries to execute successfully. self.assertEqual( - self.con.query('SELECT {"aaa", "bbb"}'), + self.client.query('SELECT {"aaa", "bbb"}'), ['aaa', 'bbb']) @unittest.skip(""" @@ -74,7 +93,7 @@ async def test_proto_codec_error_recovery_02(self): # we know that the codec will fail. # The test will be rewritten once it's possible to override # default codecs. - self.con.query(r""" + self.client.query(r""" SELECT cal::to_local_date( { '2010-01-01 AD', @@ -91,5 +110,5 @@ async def test_proto_codec_error_recovery_02(self): # The protocol, though, shouldn't be in some inconsistent # state; it should allow new queries to execute successfully. self.assertEqual( - self.con.query('SELECT {"aaa", "bbb"}'), + self.client.query('SELECT {"aaa", "bbb"}'), ['aaa', 'bbb']) diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index a5f8ce93..8f2de250 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -38,101 +38,101 @@ class TestSyncQuery(tb.SyncQueryTestCase): def test_sync_parse_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.con.query('select syntax error') + self.client.query('select syntax error') with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.con.query('select syntax error') + self.client.query('select syntax error') with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, 'Unexpected end of line'): - self.con.query('select (') + self.client.query('select (') with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, 'Unexpected end of line'): - self.con.query_json('select (') + self.client.query_json('select (') for _ in range(10): self.assertEqual( - self.con.query('select 1;'), + self.client.query('select 1;'), edgedb.Set((1,))) - self.assertFalse(self.con.is_closed()) + self.assertFalse(self.client.connection.is_closed()) def test_sync_parse_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.con.execute('select syntax error') + self.client.execute('select syntax error') with self.assertRaises(edgedb.EdgeQLSyntaxError): - self.con.execute('select syntax error') + self.client.execute('select syntax error') for _ in range(10): - self.con.execute('select 1; select 2;'), + self.client.execute('select 1; select 2;'), def test_sync_exec_error_recover_01(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - self.con.query('select 1 / 0;') + self.client.query('select 1 / 0;') with self.assertRaises(edgedb.DivisionByZeroError): - self.con.query('select 1 / 0;') + self.client.query('select 1 / 0;') for _ in range(10): self.assertEqual( - self.con.query('select 1;'), + self.client.query('select 1;'), edgedb.Set((1,))) def test_sync_exec_error_recover_02(self): for _ in range(2): with self.assertRaises(edgedb.DivisionByZeroError): - self.con.execute('select 1 / 0;') + self.client.execute('select 1 / 0;') with self.assertRaises(edgedb.DivisionByZeroError): - self.con.execute('select 1 / 0;') + self.client.execute('select 1 / 0;') for _ in range(10): - self.con.execute('select 1;') + self.client.execute('select 1;') def test_sync_exec_error_recover_03(self): query = 'select 10 // $0;' for i in [1, 2, 0, 3, 1, 0, 1]: if i: self.assertEqual( - self.con.query(query, i), + self.client.query(query, i), edgedb.Set([10 // i])) else: with self.assertRaises(edgedb.DivisionByZeroError): - self.con.query(query, i) + self.client.query(query, i) def test_sync_exec_error_recover_04(self): for i in [1, 2, 0, 3, 1, 0, 1]: if i: - self.con.execute(f'select 10 // {i};') + self.client.execute(f'select 10 // {i};') else: with self.assertRaises(edgedb.DivisionByZeroError): - self.con.query(f'select 10 // {i};') + self.client.query(f'select 10 // {i};') def test_sync_exec_error_recover_05(self): with self.assertRaisesRegex(edgedb.QueryError, 'cannot accept parameters'): - self.con.execute(f'select $0') + self.client.execute(f'select $0') self.assertEqual( - self.con.query('SELECT "HELLO"'), + self.client.query('SELECT "HELLO"'), ["HELLO"]) async def test_async_query_single_01(self): - res = self.con.query_single("SELECT 1") + res = self.client.query_single("SELECT 1") self.assertEqual(res, 1) - res = self.con.query_single("SELECT {}") + res = self.client.query_single("SELECT {}") self.assertEqual(res, None) - res = self.con.query_required_single("SELECT 1") + res = self.client.query_required_single("SELECT 1") self.assertEqual(res, 1) with self.assertRaises(edgedb.NoDataError): - self.con.query_required_single("SELECT {}") + self.client.query_required_single("SELECT {}") def test_sync_query_single_command_01(self): - r = self.con.query(''' + r = self.client.query(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -140,12 +140,12 @@ def test_sync_query_single_command_01(self): ''') self.assertEqual(r, []) - r = self.con.query(''' + r = self.client.query(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, []) - r = self.con.query(''' + r = self.client.query(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -153,12 +153,12 @@ def test_sync_query_single_command_01(self): ''') self.assertEqual(r, []) - r = self.con.query(''' + r = self.client.query(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, []) - r = self.con.query_json(''' + r = self.client.query_json(''' CREATE TYPE test::server_query_single_command_01 { CREATE REQUIRED PROPERTY server_query_single_command_01 -> std::str; @@ -166,25 +166,27 @@ def test_sync_query_single_command_01(self): ''') self.assertEqual(r, '[]') - r = self.con.query_json(''' + r = self.client.query_json(''' DROP TYPE test::server_query_single_command_01; ''') self.assertEqual(r, '[]') - self.assertTrue(self.con._get_last_status().startswith('DROP')) + self.assertTrue( + self.client.connection._get_last_status().startswith('DROP') + ) def test_sync_query_single_command_02(self): - r = self.con.query(''' + r = self.client.query(''' SET MODULE default; ''') self.assertEqual(r, []) - r = self.con.query(''' + r = self.client.query(''' SET ALIAS foo AS MODULE default; ''') self.assertEqual(r, []) - r = self.con.query(''' + r = self.client.query(''' SET MODULE default; ''') self.assertEqual(r, []) @@ -192,21 +194,21 @@ def test_sync_query_single_command_02(self): with self.assertRaisesRegex( edgedb.InterfaceError, r'query_required_single\(\)'): - self.con.query_required_single(''' + self.client.query_required_single(''' SET ALIAS bar AS MODULE std; ''') - self.con.query(''' + self.client.query(''' SET ALIAS bar AS MODULE std; ''') self.assertEqual(r, []) - r = self.con.query_json(''' + r = self.client.query_json(''' SET MODULE default; ''') self.assertEqual(r, '[]') - r = self.con.query_json(''' + r = self.client.query_json(''' SET ALIAS bar AS MODULE std; ''') self.assertEqual(r, '[]') @@ -216,32 +218,32 @@ def test_sync_query_single_command_03(self): edgedb.InterfaceError, r'cannot be executed with query_required_single\(\).*' r'not return'): - self.con.query_required_single('set module default') + self.client.query_required_single('set module default') with self.assertRaisesRegex( edgedb.InterfaceError, r'cannot be executed with query_required_single_json\(\).*' r'not return'): - self.con.query_required_single_json('set module default') + self.client.query_required_single_json('set module default') def test_sync_query_single_command_04(self): with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - self.con.query(''' + self.client.query(''' SELECT 1; SET MODULE blah; ''') with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - self.con.query_single(''' + self.client.query_single(''' SELECT 1; SET MODULE blah; ''') with self.assertRaisesRegex(edgedb.ProtocolError, 'expected one statement'): - self.con.query_json(''' + self.client.query_json(''' SELECT 1; SET MODULE blah; ''') @@ -249,27 +251,27 @@ def test_sync_query_single_command_04(self): def test_sync_basic_datatypes_01(self): for _ in range(10): self.assertEqual( - self.con.query_single( + self.client.query_single( 'select ()'), ()) self.assertEqual( - self.con.query( + self.client.query( 'select (1,)'), edgedb.Set([(1,)])) self.assertEqual( - self.con.query_single( + self.client.query_single( 'select >[]'), []) self.assertEqual( - self.con.query( + self.client.query( 'select ["a", "b"]'), edgedb.Set([["a", "b"]])) self.assertEqual( - self.con.query(''' + self.client.query(''' SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; '''), @@ -281,55 +283,55 @@ def test_sync_basic_datatypes_01(self): with self.assertRaisesRegex( edgedb.InterfaceError, r'query cannot be executed with query_single\('): - self.con.query_single('SELECT {1, 2}') + self.client.query_single('SELECT {1, 2}') with self.assertRaisesRegex(edgedb.NoDataError, r'\bquery_required_single_json\('): - self.con.query_required_single_json('SELECT {}') + self.client.query_required_single_json('SELECT {}') def test_sync_basic_datatypes_02(self): self.assertEqual( - self.con.query( + self.client.query( r'''select [b"\x00a", b"b", b'', b'\na']'''), edgedb.Set([[b"\x00a", b"b", b'', b'\na']])) self.assertEqual( - self.con.query( + self.client.query( r'select $0', b'he\x00llo'), edgedb.Set([b'he\x00llo'])) def test_sync_basic_datatypes_03(self): for _ in range(10): self.assertEqual( - self.con.query_json( + self.client.query_json( 'select ()'), '[[]]') self.assertEqual( - self.con.query_json( + self.client.query_json( 'select (1,)'), '[[1]]') self.assertEqual( - self.con.query_json( + self.client.query_json( 'select >[]'), '[[]]') self.assertEqual( json.loads( - self.con.query_json( + self.client.query_json( 'select ["a", "b"]')), [["a", "b"]]) self.assertEqual( json.loads( - self.con.query_single_json( + self.client.query_single_json( 'select ["a", "b"]')), ["a", "b"]) self.assertEqual( json.loads( - self.con.query_json(''' + self.client.query_json(''' SELECT {(a := 1 + 1 + 40, world := ("hello", 32)), (a:=1, world := ("yo", 10))}; ''')), @@ -340,45 +342,45 @@ def test_sync_basic_datatypes_03(self): self.assertEqual( json.loads( - self.con.query_json('SELECT {1, 2}')), + self.client.query_json('SELECT {1, 2}')), [1, 2]) self.assertEqual( - json.loads(self.con.query_json('SELECT {}')), + json.loads(self.client.query_json('SELECT {}')), []) with self.assertRaises(edgedb.NoDataError): - self.con.query_required_single_json('SELECT {}') + self.client.query_required_single_json('SELECT {}') self.assertEqual( - json.loads(self.con.query_single_json('SELECT {}')), + json.loads(self.client.query_single_json('SELECT {}')), None ) def test_sync_args_01(self): self.assertEqual( - self.con.query( + self.client.query( 'select (>$foo)[0] ++ (>$bar)[0];', foo=['aaa'], bar=['bbb']), edgedb.Set(('aaabbb',))) def test_sync_args_02(self): self.assertEqual( - self.con.query( + self.client.query( 'select (>$0)[0] ++ (>$1)[0];', ['aaa'], ['bbb']), edgedb.Set(('aaabbb',))) def test_sync_args_03(self): with self.assertRaisesRegex(edgedb.QueryError, r'missing \$0'): - self.con.query('select $1;') + self.client.query('select $1;') with self.assertRaisesRegex(edgedb.QueryError, r'missing \$1'): - self.con.query('select $0 + $2;') + self.client.query('select $0 + $2;') with self.assertRaisesRegex(edgedb.QueryError, 'combine positional and named parameters'): - self.con.query('select $0 + $bar;') + self.client.query('select $0 + $bar;') def test_sync_mismatched_args_01(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -388,7 +390,7 @@ def test_sync_mismatched_args_01(self): "got {'[bc]', '[bc]'}, " r"missed {'a'}, extra {'[bc]', '[bc]'}"): - self.con.query("""SELECT $a;""", b=1, c=2) + self.client.query("""SELECT $a;""", b=1, c=2) def test_sync_mismatched_args_02(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -398,7 +400,7 @@ def test_sync_mismatched_args_02(self): r"got {'[acd]', '[acd]', '[acd]'}, " r"missed {'b'}, extra {'[cd]', '[cd]'}"): - self.con.query(""" + self.client.query(""" SELECT $a + $b; """, a=1, c=2, d=3) @@ -409,7 +411,7 @@ def test_sync_mismatched_args_03(self): "expected {'a'} (?:keyword )?arguments, got {'b'}, " "missed {'a'}, extra {'b'}"): - self.con.query("""SELECT $a;""", b=1) + self.client.query("""SELECT $a;""", b=1) def test_sync_mismatched_args_04(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -419,7 +421,7 @@ def test_sync_mismatched_args_04(self): r"got {'a'}, " r"missed {'b'}"): - self.con.query("""SELECT $a + $b;""", a=1) + self.client.query("""SELECT $a + $b;""", a=1) def test_sync_mismatched_args_05(self): # XXX: remove (?:keyword )? once protocol version 0.12 is stable @@ -429,7 +431,7 @@ def test_sync_mismatched_args_05(self): r"got {'[ab]', '[ab]'}, " r"extra {'b'}"): - self.con.query("""SELECT $a;""", a=1, b=2) + self.client.query("""SELECT $a;""", a=1, b=2) async def test_sync_log_message(self): msgs = [] @@ -437,12 +439,16 @@ async def test_sync_log_message(self): def on_log(con, msg): msgs.append(msg) - self.con.add_log_listener(on_log) + self.client.ensure_connected() + con = self.client.connection + con.add_log_listener(on_log) try: - self.con.query('configure system set __internal_restart := true;') + self.client.query( + 'configure system set __internal_restart := true;' + ) # self.con.query('SELECT 1') finally: - self.con.remove_log_listener(on_log) + con.remove_log_listener(on_log) for msg in msgs: if (msg.get_severity_name() == 'NOTICE' and diff --git a/tests/test_sync_retry.py b/tests/test_sync_retry.py index 8045c403..831f0964 100644 --- a/tests/test_sync_retry.py +++ b/tests/test_sync_retry.py @@ -62,7 +62,7 @@ class TestSyncRetry(tb.SyncQueryTestCase): ''' def test_sync_retry_01(self): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.execute(''' INSERT test::Counter { @@ -72,7 +72,7 @@ def test_sync_retry_01(self): def test_sync_retry_02(self): with self.assertRaises(ZeroDivisionError): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.execute(''' INSERT test::Counter { @@ -81,12 +81,12 @@ def test_sync_retry_02(self): ''') 1 / 0 with self.assertRaises(edgedb.NoDataError): - self.con.query_required_single(''' + self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_02' ''') self.assertEqual( - self.con.query_single(''' + self.client.query_single(''' SELECT test::Counter FILTER .name = 'counter_retry_02' '''), @@ -94,7 +94,9 @@ def test_sync_retry_02(self): ) def test_sync_retry_begin(self): - patcher = unittest.mock.patch("edgedb.retry.Iteration._start") + patcher = unittest.mock.patch( + "edgedb.base_client.BaseConnection.privileged_execute" + ) _start = patcher.start() def cleanup(): @@ -108,7 +110,7 @@ def cleanup(): _start.side_effect = errors.BackendUnavailableError() with self.assertRaises(errors.BackendUnavailableError): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.execute(''' INSERT test::Counter { @@ -116,12 +118,12 @@ def cleanup(): }; ''') with self.assertRaises(edgedb.NoDataError): - self.con.query_required_single(''' + self.client.query_required_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' ''') self.assertEqual( - self.con.query_single(''' + self.client.query_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' '''), @@ -135,7 +137,7 @@ def recover_after_first_error(*_, **__): _start.side_effect = recover_after_first_error call_count = _start.call_count - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.execute(''' INSERT test::Counter { @@ -143,7 +145,7 @@ def recover_after_first_error(*_, **__): }; ''') self.assertEqual(_start.call_count, call_count + 1) - self.con.query_single(''' + self.client.query_single(''' SELECT test::Counter FILTER .name = 'counter_retry_begin' ''') @@ -161,16 +163,16 @@ def test_sync_conflict_no_retry(self): def execute_conflict(self, name='counter2', options=None): con_args = self.get_connect_args().copy() con_args.update(database=self.get_database_name()) - con2 = edgedb.connect(**con_args) - self.addCleanup(con2.close) + client2 = edgedb.create_client(**con_args) + self.addCleanup(client2.close) barrier = Barrier(2) lock = threading.Lock() iterations = 0 - def transaction1(con): - for tx in con.transaction(): + def transaction1(client): + for tx in client.transaction(): nonlocal iterations iterations += 1 with tx: @@ -202,14 +204,14 @@ def transaction1(con): lock.release() return res - con = self.con + client = self.client if options: - con = con.with_retry_options(options) - con2 = con2.with_retry_options(options) + client = client.with_retry_options(options) + client2 = client2.with_retry_options(options) with futures.ThreadPoolExecutor(2) as pool: - f1 = pool.submit(transaction1, con) - f2 = pool.submit(transaction1, con2) + f1 = pool.submit(transaction1, client) + f2 = pool.submit(transaction1, client2) results = {f1.result(), f2.result()} self.assertEqual(results, {1, 2}) @@ -220,7 +222,7 @@ def test_sync_transaction_interface_errors(self): AttributeError, "'Iteration' object has no attribute 'start'", ): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.start() @@ -228,7 +230,7 @@ def test_sync_transaction_interface_errors(self): AttributeError, "'Iteration' object has no attribute 'rollback'", ): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: tx.rollback() @@ -236,24 +238,19 @@ def test_sync_transaction_interface_errors(self): AttributeError, "'Iteration' object has no attribute 'start'", ): - for tx in self.con.transaction(): + for tx in self.client.transaction(): tx.start() with self.assertRaisesRegex(edgedb.InterfaceError, r'.*Use `with transaction:`'): - for tx in self.con.transaction(): + for tx in self.client.transaction(): tx.execute("SELECT 123") with self.assertRaisesRegex( edgedb.InterfaceError, r"already in a `with` block", ): - for tx in self.con.transaction(): + for tx in self.client.transaction(): with tx: with tx: pass - - with self.assertRaisesRegex(edgedb.InterfaceError, r".*is borrowed.*"): - for tx in self.con.transaction(): - with tx: - self.con.execute("SELECT 123") diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index baf8720c..e672f2ad 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -37,9 +37,7 @@ class TestSyncTx(tb.SyncQueryTestCase): ''' def test_sync_transaction_regular_01(self): - self.assertIsNone(self.con._inner._borrowed_for) - tr = self.con.transaction() - self.assertIsNone(self.con._inner._borrowed_for) + tr = self.client.transaction() with self.assertRaises(ZeroDivisionError): for with_tr in tr: @@ -52,9 +50,7 @@ def test_sync_transaction_regular_01(self): 1 / 0 - self.assertIsNone(self.con._inner._borrowed_for) - - result = self.con.query(''' + result = self.client.query(''' SELECT test::TransactionTest FILTER @@ -78,9 +74,11 @@ async def test_sync_transaction_kinds(self): ) # skip None opt = {k: v for k, v in opt.items() if v is not None} - con = self.con.with_transaction_options(TransactionOptions(**opt)) + client = self.client.with_transaction_options( + TransactionOptions(**opt) + ) try: - for tx in con.transaction(): + for tx in client.transaction(): with tx: tx.execute( 'INSERT test::TransactionTest {name := "test"}') @@ -89,6 +87,6 @@ async def test_sync_transaction_kinds(self): else: self.assertFalse(readonly) - for tx in con.transaction(): + for tx in client.transaction(): with tx: pass