From eaede3d0a2e5897f2728ceb9c7bf8b1d694688d2 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 May 2017 16:40:24 +0200 Subject: [PATCH] Support for connection timeout param Related to #184, it allows aioredis to configure a limited time that will be used trying to open a connection, if it is reached a `asyncio.TimeoutError` will be raised By default any timeout is configured. --- aioredis/connection.py | 15 ++++++++++----- aioredis/pool.py | 10 ++++++++-- docs/api_reference.rst | 18 +++++++++++++++++- tests/connection_test.py | 16 ++++++++++++++++ tests/pool_test.py | 8 ++++++++ 5 files changed, 59 insertions(+), 8 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index cd8249ad0..7aa82da1a 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -42,7 +42,7 @@ @asyncio.coroutine def create_connection(address, *, db=None, password=None, ssl=None, - encoding=None, loop=None): + encoding=None, loop=None, timeout=None): """Creates redis connection. Opens connection to Redis server specified by address argument. @@ -54,6 +54,10 @@ def create_connection(address, *, db=None, password=None, ssl=None, SSL argument is passed through to asyncio.create_connection. By default SSL/TLS is not used. + By default any timeout is applied at the connection stage, however + you can set a limitted time used trying to open a connection via + the `timeout` Kw. + Encoding argument can be used to decode byte-replies to strings. By default no decoding is done. @@ -66,8 +70,8 @@ def create_connection(address, *, db=None, password=None, ssl=None, if isinstance(address, (list, tuple)): host, port = address logger.debug("Creating tcp connection to %r", address) - reader, writer = yield from asyncio.open_connection( - host, port, ssl=ssl, loop=loop) + reader, writer = yield from asyncio.wait_for(asyncio.open_connection( + host, port, ssl=ssl, loop=loop), timeout, loop=loop) sock = writer.transport.get_extra_info('socket') if sock is not None: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -75,8 +79,9 @@ def create_connection(address, *, db=None, password=None, ssl=None, address = tuple(address[:2]) else: logger.debug("Creating unix connection to %r", address) - reader, writer = yield from asyncio.open_unix_connection( - address, ssl=ssl, loop=loop) + reader, writer = yield from asyncio.wait_for( + asyncio.open_unix_connection(address, ssl=ssl, loop=loop), + timeout, loop=loop) sock = writer.transport.get_extra_info('socket') if sock is not None: address = sock.getpeername() diff --git a/aioredis/pool.py b/aioredis/pool.py index 63be43811..9187e0ed3 100644 --- a/aioredis/pool.py +++ b/aioredis/pool.py @@ -14,7 +14,9 @@ @asyncio.coroutine def create_pool(address, *, db=0, password=None, ssl=None, encoding=None, - minsize=1, maxsize=10, commands_factory=_NOTSET, loop=None): + minsize=1, maxsize=10, commands_factory=_NOTSET, + loop=None, timeout_create_connection=None): + # FIXME: rewrite docstring """Creates Redis Pool. By default it creates pool of Redis instances, but it is @@ -37,6 +39,7 @@ def create_pool(address, *, db=0, password=None, ssl=None, encoding=None, pool = RedisPool(address, db, password, encoding, minsize=minsize, maxsize=maxsize, commands_factory=commands_factory, + timeout_create_connection=timeout_create_connection, ssl=ssl, loop=loop) try: yield from pool._fill_free(override_min=False) @@ -53,7 +56,8 @@ class RedisPool: """ def __init__(self, address, db=0, password=None, encoding=None, - *, minsize, maxsize, commands_factory, ssl=None, loop=None): + *, minsize, maxsize, commands_factory, ssl=None, + timeout_create_connection=None, loop=None): assert isinstance(minsize, int) and minsize >= 0, ( "minsize must be int >= 0", minsize, type(minsize)) assert maxsize is not None, "Arbitrary pool size is disallowed." @@ -70,6 +74,7 @@ def __init__(self, address, db=0, password=None, encoding=None, self._encoding = encoding self._minsize = minsize self._factory = commands_factory + self._timeout_create_connection = timeout_create_connection self._loop = loop self._pool = collections.deque(maxlen=maxsize) self._used = set() @@ -251,6 +256,7 @@ def _fill_free(self, *, override_min): # connection may be closed at yield point self._drop_closed() + @asyncio.coroutine def _create_new_connection(self): return create_redis(self._address, db=self._db, diff --git a/docs/api_reference.rst b/docs/api_reference.rst index ff89d0287..56cf1185c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -36,10 +36,13 @@ Connection usage is as simple as: .. cofunction:: create_connection(address, \*, db=0, password=None, ssl=None,\ - encoding=None, loop=None) + encoding=None, loop=None, timeout=None) Creates Redis connection. + .. versionchanged:: v0.3.1 + ``timeout`` argument added. + :param address: An address where to connect. Can be a (host, port) tuple or unix domain socket path string. :type address: tuple or str @@ -61,6 +64,11 @@ Connection usage is as simple as: (uses :func:`asyncio.get_event_loop` if not specified). :type loop: :ref:`EventLoop` + :param timeout: Max time used to open a connection, otherwise + raise `asyncio.TimeoutError` exception. + ``None`` by default + :type timeout: float or None + :return: :class:`RedisConnection` instance. @@ -238,6 +246,9 @@ The library provides connections pool. The basic usage is as follows: .. deprecated:: v0.2.9 *commands_factory* argument is deprecated and will be removed in *v0.3*. + .. versionchanged:: v0.3.1 + ``timeout_create_connection`` argument added. + :param address: An address where to connect. Can be a (host, port) tuple or unix domain socket path string. :type address: tuple or str @@ -271,6 +282,11 @@ The library provides connections pool. The basic usage is as follows: (uses :func:`asyncio.get_event_loop` if not specified). :type loop: :ref:`EventLoop` + :param timeout_create_connection: Max time used to open a connection, + otherwise raise an `asyncio.TimeoutError`. + ``None`` by default. + :type timeout_create_connection: float or None + :return: :class:`RedisPool` instance. diff --git a/tests/connection_test.py b/tests/connection_test.py index 05c231998..540a59f17 100644 --- a/tests/connection_test.py +++ b/tests/connection_test.py @@ -32,6 +32,13 @@ def test_connect_tcp(request, create_connection, loop, server): assert str(conn) == '' +@pytest.mark.run_loop +def test_connect_tcp_timeout(request, create_connection, loop, server): + with pytest.raises(asyncio.TimeoutError): + yield from create_connection( + server.tcp_address, loop=loop, timeout=0) + + @pytest.mark.run_loop @pytest.mark.skipif(sys.platform == 'win32', reason="No unixsocket on Windows") @@ -43,6 +50,15 @@ def test_connect_unixsocket(create_connection, loop, server): assert str(conn) == '' +@pytest.mark.run_loop +@pytest.mark.skipif(sys.platform == 'win32', + reason="No unixsocket on Windows") +def test_connect_unixsocket_timeout(create_connection, loop, server): + with pytest.raises(asyncio.TimeoutError): + yield from create_connection( + server.unixsocket, db=0, loop=loop, timeout=0) + + def test_global_loop(create_connection, loop, server): asyncio.set_event_loop(loop) diff --git a/tests/pool_test.py b/tests/pool_test.py index 8f3721a05..6c96ac9b3 100644 --- a/tests/pool_test.py +++ b/tests/pool_test.py @@ -57,6 +57,14 @@ def test_maxsize(maxsize, create_pool, loop, server): minsize=2, maxsize=maxsize, loop=loop) +@pytest.mark.run_loop +def test_create_connection_timeout(create_pool, loop, server): + with pytest.raises(asyncio.TimeoutError): + yield from create_pool( + server.tcp_address, loop=loop, + timeout_create_connection=0) + + def test_no_yield_from(pool): with pytest.raises(RuntimeError): with pool: