Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
Support for connection timeout param
Browse files Browse the repository at this point in the history
Related to #184, it allows aioredis to configure a limited
time that will be used trying to open a connection, if it is
reached a `asyncio.TimeoutError` will be raised

By default any timeout is configured.
  • Loading branch information
pfreixes authored and popravich committed Jun 21, 2017
1 parent 6b13dd1 commit eaede3d
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 8 deletions.
15 changes: 10 additions & 5 deletions aioredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

@asyncio.coroutine
def create_connection(address, *, db=None, password=None, ssl=None,
encoding=None, loop=None):
encoding=None, loop=None, timeout=None):
"""Creates redis connection.
Opens connection to Redis server specified by address argument.
Expand All @@ -54,6 +54,10 @@ def create_connection(address, *, db=None, password=None, ssl=None,
SSL argument is passed through to asyncio.create_connection.
By default SSL/TLS is not used.
By default any timeout is applied at the connection stage, however
you can set a limitted time used trying to open a connection via
the `timeout` Kw.
Encoding argument can be used to decode byte-replies to strings.
By default no decoding is done.
Expand All @@ -66,17 +70,18 @@ def create_connection(address, *, db=None, password=None, ssl=None,
if isinstance(address, (list, tuple)):
host, port = address
logger.debug("Creating tcp connection to %r", address)
reader, writer = yield from asyncio.open_connection(
host, port, ssl=ssl, loop=loop)
reader, writer = yield from asyncio.wait_for(asyncio.open_connection(
host, port, ssl=ssl, loop=loop), timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
if sock is not None:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
address = sock.getpeername()
address = tuple(address[:2])
else:
logger.debug("Creating unix connection to %r", address)
reader, writer = yield from asyncio.open_unix_connection(
address, ssl=ssl, loop=loop)
reader, writer = yield from asyncio.wait_for(
asyncio.open_unix_connection(address, ssl=ssl, loop=loop),
timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
if sock is not None:
address = sock.getpeername()
Expand Down
10 changes: 8 additions & 2 deletions aioredis/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

@asyncio.coroutine
def create_pool(address, *, db=0, password=None, ssl=None, encoding=None,
minsize=1, maxsize=10, commands_factory=_NOTSET, loop=None):
minsize=1, maxsize=10, commands_factory=_NOTSET,
loop=None, timeout_create_connection=None):
# FIXME: rewrite docstring
"""Creates Redis Pool.
By default it creates pool of Redis instances, but it is
Expand All @@ -37,6 +39,7 @@ def create_pool(address, *, db=0, password=None, ssl=None, encoding=None,
pool = RedisPool(address, db, password, encoding,
minsize=minsize, maxsize=maxsize,
commands_factory=commands_factory,
timeout_create_connection=timeout_create_connection,
ssl=ssl, loop=loop)
try:
yield from pool._fill_free(override_min=False)
Expand All @@ -53,7 +56,8 @@ class RedisPool:
"""

def __init__(self, address, db=0, password=None, encoding=None,
*, minsize, maxsize, commands_factory, ssl=None, loop=None):
*, minsize, maxsize, commands_factory, ssl=None,
timeout_create_connection=None, loop=None):
assert isinstance(minsize, int) and minsize >= 0, (
"minsize must be int >= 0", minsize, type(minsize))
assert maxsize is not None, "Arbitrary pool size is disallowed."
Expand All @@ -70,6 +74,7 @@ def __init__(self, address, db=0, password=None, encoding=None,
self._encoding = encoding
self._minsize = minsize
self._factory = commands_factory
self._timeout_create_connection = timeout_create_connection
self._loop = loop
self._pool = collections.deque(maxlen=maxsize)
self._used = set()
Expand Down Expand Up @@ -251,6 +256,7 @@ def _fill_free(self, *, override_min):
# connection may be closed at yield point
self._drop_closed()

@asyncio.coroutine
def _create_new_connection(self):
return create_redis(self._address,
db=self._db,
Expand Down
18 changes: 17 additions & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ Connection usage is as simple as:
.. cofunction:: create_connection(address, \*, db=0, password=None, ssl=None,\
encoding=None, loop=None)
encoding=None, loop=None, timeout=None)

Creates Redis connection.

.. versionchanged:: v0.3.1
``timeout`` argument added.

:param address: An address where to connect. Can be a (host, port) tuple or
unix domain socket path string.
:type address: tuple or str
Expand All @@ -61,6 +64,11 @@ Connection usage is as simple as:
(uses :func:`asyncio.get_event_loop` if not specified).
:type loop: :ref:`EventLoop<asyncio-event-loop>`

:param timeout: Max time used to open a connection, otherwise
raise `asyncio.TimeoutError` exception.
``None`` by default
:type timeout: float or None

:return: :class:`RedisConnection` instance.


Expand Down Expand Up @@ -238,6 +246,9 @@ The library provides connections pool. The basic usage is as follows:
.. deprecated:: v0.2.9
*commands_factory* argument is deprecated and will be removed in *v0.3*.

.. versionchanged:: v0.3.1
``timeout_create_connection`` argument added.

:param address: An address where to connect. Can be a (host, port) tuple or
unix domain socket path string.
:type address: tuple or str
Expand Down Expand Up @@ -271,6 +282,11 @@ The library provides connections pool. The basic usage is as follows:
(uses :func:`asyncio.get_event_loop` if not specified).
:type loop: :ref:`EventLoop<asyncio-event-loop>`

:param timeout_create_connection: Max time used to open a connection,
otherwise raise an `asyncio.TimeoutError`.
``None`` by default.
:type timeout_create_connection: float or None

:return: :class:`RedisPool` instance.


Expand Down
16 changes: 16 additions & 0 deletions tests/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def test_connect_tcp(request, create_connection, loop, server):
assert str(conn) == '<RedisConnection [db:0]>'


@pytest.mark.run_loop
def test_connect_tcp_timeout(request, create_connection, loop, server):
with pytest.raises(asyncio.TimeoutError):
yield from create_connection(
server.tcp_address, loop=loop, timeout=0)


@pytest.mark.run_loop
@pytest.mark.skipif(sys.platform == 'win32',
reason="No unixsocket on Windows")
Expand All @@ -43,6 +50,15 @@ def test_connect_unixsocket(create_connection, loop, server):
assert str(conn) == '<RedisConnection [db:0]>'


@pytest.mark.run_loop
@pytest.mark.skipif(sys.platform == 'win32',
reason="No unixsocket on Windows")
def test_connect_unixsocket_timeout(create_connection, loop, server):
with pytest.raises(asyncio.TimeoutError):
yield from create_connection(
server.unixsocket, db=0, loop=loop, timeout=0)


def test_global_loop(create_connection, loop, server):
asyncio.set_event_loop(loop)

Expand Down
8 changes: 8 additions & 0 deletions tests/pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def test_maxsize(maxsize, create_pool, loop, server):
minsize=2, maxsize=maxsize, loop=loop)


@pytest.mark.run_loop
def test_create_connection_timeout(create_pool, loop, server):
with pytest.raises(asyncio.TimeoutError):
yield from create_pool(
server.tcp_address, loop=loop,
timeout_create_connection=0)


def test_no_yield_from(pool):
with pytest.raises(RuntimeError):
with pool:
Expand Down

0 comments on commit eaede3d

Please sign in to comment.