diff --git a/httpx/concurrency.py b/httpx/concurrency.py index b8246b5350..2ee45a8503 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -108,6 +108,9 @@ async def read( return data + def is_connection_dropped(self) -> bool: + return self.stream_reader.at_eof() + class Writer(BaseWriter): def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index b1400afdf4..4a303f2779 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -102,3 +102,10 @@ def is_closed(self) -> bool: else: assert self.h11_connection is not None return self.h11_connection.is_closed + + def is_connection_dropped(self) -> bool: + if self.h2_connection is not None: + return self.h2_connection.is_connection_dropped() + else: + assert self.h11_connection is not None + return self.h11_connection.is_connection_dropped() diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 0b827c128f..8e271d72b4 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -9,7 +9,6 @@ TimeoutTypes, VerifyTypes, ) -from ..exceptions import NotConnected from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse, Origin from .connection import HTTPConnection @@ -121,7 +120,7 @@ async def send( except BaseException as exc: self.active_connections.remove(connection) self.max_connections.release() - if isinstance(exc, NotConnected) and allow_connection_reuse: + if allow_connection_reuse: connection = None allow_connection_reuse = False else: @@ -138,7 +137,7 @@ async def acquire_connection( if connection is None: connection = self.keepalive_connections.pop_by_origin(origin) - if connection is None: + if connection is None or connection.is_connection_dropped(): await self.max_connections.acquire() connection = HTTPConnection( origin, diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index c23953562f..dd0e24ad10 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -4,7 +4,6 @@ from ..concurrency import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes -from ..exceptions import NotConnected from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -46,12 +45,7 @@ async def send( ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) - try: - await self._send_request(request, timeout) - except ConnectionResetError: # pragma: nocover - # We're currently testing this case in HTTP/2. - # Really we should test it here too, but this'll do in the meantime. - raise NotConnected() from None + await self._send_request(request, timeout) task, args = self._send_request_data, [request.stream(), timeout] async with self.backend.background_manager(task, args=args): @@ -188,3 +182,6 @@ async def response_closed(self) -> None: @property def is_closed(self) -> bool: return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) + + def is_connection_dropped(self) -> bool: + return self.reader.is_connection_dropped() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 35d487ad6b..331f82df38 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -6,7 +6,6 @@ from ..concurrency import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes -from ..exceptions import NotConnected from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -39,10 +38,7 @@ async def send( if not self.initialized: self.initiate_connection() - try: - stream_id = await self.send_headers(request, timeout) - except ConnectionResetError: - raise NotConnected() from None + stream_id = await self.send_headers(request, timeout) self.events[stream_id] = [] self.timeout_flags[stream_id] = TimeoutFlag() @@ -176,3 +172,6 @@ async def response_closed(self, stream_id: int) -> None: @property def is_closed(self) -> bool: return False + + def is_connection_dropped(self) -> bool: + return self.reader.is_connection_dropped() diff --git a/httpx/exceptions.py b/httpx/exceptions.py index 3305e2d7f7..19af3e6be4 100644 --- a/httpx/exceptions.py +++ b/httpx/exceptions.py @@ -34,13 +34,6 @@ class PoolTimeout(Timeout): # HTTP exceptions... -class NotConnected(Exception): - """ - A connection was lost at the point of starting a request, - prior to any writes succeeding. - """ - - class HttpError(Exception): """ An HTTP error occurred. diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 02d11ce5b5..f058edeb6d 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -130,6 +130,9 @@ async def read( ) -> bytes: raise NotImplementedError() # pragma: no cover + def is_connection_dropped(self) -> bool: + raise NotImplementedError() # pragma: no cover + class BaseWriter: """ diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index 5ef884ee58..14fd40c62e 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -131,3 +131,39 @@ async def test_premature_response_close(server): await response.close() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 + + +@pytest.mark.asyncio +async def test_keepalive_connection_closed_by_server_is_reestablished(server): + """ + Upon keep-alive connection closed by remote a new connection should be reestablished. + """ + async with httpx.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + + await server.shutdown() # shutdown the server to close the keep-alive connection + await server.startup() + + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 + + +@pytest.mark.asyncio +async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server): + """ + Upon keep-alive connection closed by remote a new connection should be reestablished. + """ + async with httpx.ConnectionPool() as http: + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + + await server.shutdown() # shutdown the server to close the keep-alive connection + await server.startup() + + response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() + assert len(http.active_connections) == 0 + assert len(http.keepalive_connections) == 1 diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py index b19452427d..4883cf97aa 100644 --- a/tests/dispatch/test_http2.py +++ b/tests/dispatch/test_http2.py @@ -76,3 +76,22 @@ def test_http2_reconnect(): assert response_2.status_code == 200 assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""} + + +def test_http2_reconnect_after_remote_closed_connection(): + """ + If a connection has been closed between requests, then we should + be seemlessly reconnected. + """ + backend = MockHTTP2Backend(app=app) + + with Client(backend=backend) as client: + response_1 = client.get("http://example.org/1") + backend.server.close_connection = True + response_2 = client.get("http://example.org/2") + + assert response_1.status_code == 200 + assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""} + + assert response_2.status_code == 200 + assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""} diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index a1e02a4c18..31fc0996ea 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -44,6 +44,7 @@ def __init__(self, app): self.buffer = b"" self.requests = {} self.raise_disconnect = False + self.close_connection = False # BaseReader interface @@ -74,6 +75,9 @@ async def write(self, data: bytes, timeout) -> None: async def close(self) -> None: pass + def is_connection_dropped(self) -> bool: + return self.close_connection + # Server implementation def request_received(self, headers, stream_id): diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9a8f12c7c2..354bccee73 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -10,11 +10,11 @@ CertTypes, Client, Dispatcher, - multipart, Request, Response, TimeoutTypes, VerifyTypes, + multipart, )