Skip to content

Commit

Permalink
Add connect retries (#221)
Browse files Browse the repository at this point in the history
* Add connect retries

* Update tests/async_tests/test_retries.py

Co-authored-by: Jamie Hewland <[email protected]>

* Unasync

Co-authored-by: Jamie Hewland <[email protected]>
  • Loading branch information
florimondmanca and JayH5 authored Oct 12, 2020
1 parent 11f537e commit f9f59af
Show file tree
Hide file tree
Showing 15 changed files with 457 additions and 42 deletions.
51 changes: 34 additions & 17 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Optional, Tuple, cast

from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend
from .._exceptions import ConnectError, ConnectTimeout
from .._types import URL, Headers, Origin, TimeoutDict
from .._utils import get_logger, url_to_origin
from .._utils import exponential_backoff, get_logger, url_to_origin
from .base import (
AsyncByteStream,
AsyncHTTPTransport,
Expand All @@ -14,6 +15,8 @@

logger = get_logger(__name__)

RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.


class AsyncHTTPConnection(AsyncHTTPTransport):
def __init__(
Expand All @@ -24,6 +27,7 @@ def __init__(
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
local_address: str = None,
retries: int = 0,
backend: AsyncBackend = None,
):
self.origin = origin
Expand All @@ -32,6 +36,7 @@ def __init__(
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
self.retries = retries

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -103,22 +108,34 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
try:
if self.uds is None:
return await self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return await self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except Exception: # noqa: PIE786
self.connect_failed = True
raise

retries_left = self.retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

while True:
try:
if self.uds is None:
return await self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return await self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
self.connect_failed = True
raise
retries_left -= 1
delay = next(delays)
await self.backend.sleep(delay)
except Exception: # noqa: PIE786
self.connect_failed = True
raise

def _create_connection(self, socket: AsyncSocketStream) -> None:
http_version = socket.get_http_version()
Expand Down
26 changes: 22 additions & 4 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import warnings
from ssl import SSLContext
from typing import AsyncIterator, Callable, Dict, List, Optional, Set, Tuple, cast
from typing import (
AsyncIterator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)

from .._backends.auto import AsyncLock, AsyncSemaphore
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore
from .._backends.base import lookup_async_backend
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol
from .._threadlock import ThreadLock
Expand Down Expand Up @@ -84,6 +94,8 @@ class AsyncConnectionPool(AsyncHTTPTransport):
`local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4),
while using `local_address="::"` will connect using an `AF_INET6` address
(IPv6).
* **retries** - `int` - The maximum number of retries when trying to establish a
connection.
* **backend** - `str` - A name indicating which concurrency backend to use.
"""

Expand All @@ -96,8 +108,9 @@ def __init__(
http2: bool = False,
uds: str = None,
local_address: str = None,
retries: int = 0,
max_keepalive: int = None,
backend: str = "auto",
backend: Union[AsyncBackend, str] = "auto",
):
if max_keepalive is not None:
warnings.warn(
Expand All @@ -106,16 +119,20 @@ def __init__(
)
max_keepalive_connections = max_keepalive

if isinstance(backend, str):
backend = lookup_async_backend(backend)

self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive_connections = max_keepalive_connections
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._uds = uds
self._local_address = local_address
self._retries = retries
self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = lookup_async_backend(backend)
self._backend = backend
self._next_keepalive_check = 0.0

if http2:
Expand Down Expand Up @@ -157,6 +174,7 @@ def _create_connection(
uds=self._uds,
ssl_context=self._ssl_context,
local_address=self._local_address,
retries=self._retries,
backend=self._backend,
)

Expand Down
3 changes: 3 additions & 0 deletions httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:

async def time(self) -> float:
return await anyio.current_time()

async def sleep(self, seconds: float) -> None:
await anyio.sleep(seconds)
3 changes: 3 additions & 0 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
async def time(self) -> float:
loop = asyncio.get_event_loop()
return loop.time()

async def sleep(self, seconds: float) -> None:
await asyncio.sleep(seconds)
3 changes: 3 additions & 0 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:

async def time(self) -> float:
return await self.backend.time()

async def sleep(self, seconds: float) -> None:
await self.backend.sleep(seconds)
3 changes: 3 additions & 0 deletions httpcore/_backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:

async def time(self) -> float:
raise NotImplementedError() # pragma: no cover

async def sleep(self, seconds: float) -> None:
raise NotImplementedError() # pragma: no cover
3 changes: 3 additions & 0 deletions httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:

async def time(self) -> float:
return await curio.clock()

async def sleep(self, seconds: float) -> None:
await curio.sleep(seconds)
3 changes: 3 additions & 0 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> SyncSemaphore:

def time(self) -> float:
return time.monotonic()

def sleep(self, seconds: float) -> None:
time.sleep(seconds)
3 changes: 3 additions & 0 deletions httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,6 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:

async def time(self) -> float:
return trio.current_time()

async def sleep(self, seconds: float) -> None:
await trio.sleep(seconds)
51 changes: 34 additions & 17 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Optional, Tuple, cast

from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend
from .._exceptions import ConnectError, ConnectTimeout
from .._types import URL, Headers, Origin, TimeoutDict
from .._utils import get_logger, url_to_origin
from .._utils import exponential_backoff, get_logger, url_to_origin
from .base import (
SyncByteStream,
SyncHTTPTransport,
Expand All @@ -14,6 +15,8 @@

logger = get_logger(__name__)

RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.


class SyncHTTPConnection(SyncHTTPTransport):
def __init__(
Expand All @@ -24,6 +27,7 @@ def __init__(
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
local_address: str = None,
retries: int = 0,
backend: SyncBackend = None,
):
self.origin = origin
Expand All @@ -32,6 +36,7 @@ def __init__(
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
self.retries = retries

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -103,22 +108,34 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
try:
if self.uds is None:
return self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except Exception: # noqa: PIE786
self.connect_failed = True
raise

retries_left = self.retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

while True:
try:
if self.uds is None:
return self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
self.connect_failed = True
raise
retries_left -= 1
delay = next(delays)
self.backend.sleep(delay)
except Exception: # noqa: PIE786
self.connect_failed = True
raise

def _create_connection(self, socket: SyncSocketStream) -> None:
http_version = socket.get_http_version()
Expand Down
26 changes: 22 additions & 4 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import warnings
from ssl import SSLContext
from typing import Iterator, Callable, Dict, List, Optional, Set, Tuple, cast
from typing import (
Iterator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)

from .._backends.sync import SyncLock, SyncSemaphore
from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore
from .._backends.base import lookup_sync_backend
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol
from .._threadlock import ThreadLock
Expand Down Expand Up @@ -84,6 +94,8 @@ class SyncConnectionPool(SyncHTTPTransport):
`local_address="0.0.0.0"` will connect using an `AF_INET` address (IPv4),
while using `local_address="::"` will connect using an `AF_INET6` address
(IPv6).
* **retries** - `int` - The maximum number of retries when trying to establish a
connection.
* **backend** - `str` - A name indicating which concurrency backend to use.
"""

Expand All @@ -96,8 +108,9 @@ def __init__(
http2: bool = False,
uds: str = None,
local_address: str = None,
retries: int = 0,
max_keepalive: int = None,
backend: str = "sync",
backend: Union[SyncBackend, str] = "sync",
):
if max_keepalive is not None:
warnings.warn(
Expand All @@ -106,16 +119,20 @@ def __init__(
)
max_keepalive_connections = max_keepalive

if isinstance(backend, str):
backend = lookup_sync_backend(backend)

self._ssl_context = SSLContext() if ssl_context is None else ssl_context
self._max_connections = max_connections
self._max_keepalive_connections = max_keepalive_connections
self._keepalive_expiry = keepalive_expiry
self._http2 = http2
self._uds = uds
self._local_address = local_address
self._retries = retries
self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {}
self._thread_lock = ThreadLock()
self._backend = lookup_sync_backend(backend)
self._backend = backend
self._next_keepalive_check = 0.0

if http2:
Expand Down Expand Up @@ -157,6 +174,7 @@ def _create_connection(
uds=self._uds,
ssl_context=self._ssl_context,
local_address=self._local_address,
retries=self._retries,
backend=self._backend,
)

Expand Down
7 changes: 7 additions & 0 deletions httpcore/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import logging
import os
import sys
Expand Down Expand Up @@ -63,3 +64,9 @@ def origin_to_url_string(origin: Origin) -> str:
scheme, host, explicit_port = origin
port = f":{explicit_port}" if explicit_port != DEFAULT_PORTS[scheme] else ""
return f"{scheme.decode('ascii')}://{host.decode('ascii')}{port}"


def exponential_backoff(factor: float) -> typing.Iterator[float]:
yield 0
for n in itertools.count(2):
yield factor * (2 ** (n - 2))
Loading

0 comments on commit f9f59af

Please sign in to comment.