Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect EOF signaling remote server closed connection #143

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions httpx/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ async def read(

return data

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should clarify this seemingly simplistic implementation. The StreamReaderProtocol.connection_lost callback calls stream_reader.feed_eof which in turn sets the EOF flag which at_eof returns.



class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
Expand Down
7 changes: 7 additions & 0 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,10 @@ def is_closed(self) -> bool:
else:
assert self.h11_connection is not None
return self.h11_connection.is_closed

def is_connection_dropped(self) -> bool:
if self.h2_connection is not None:
return self.h2_connection.is_connection_dropped()
else:
assert self.h11_connection is not None
return self.h11_connection.is_connection_dropped()
5 changes: 2 additions & 3 deletions httpx/dispatch/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
TimeoutTypes,
VerifyTypes,
)
from ..exceptions import NotConnected
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
Expand Down Expand Up @@ -121,7 +120,7 @@ async def send(
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()
if isinstance(exc, NotConnected) and allow_connection_reuse:
if allow_connection_reuse:
connection = None
allow_connection_reuse = False
else:
Expand All @@ -138,7 +137,7 @@ async def acquire_connection(
if connection is None:
connection = self.keepalive_connections.pop_by_origin(origin)

if connection is None:
if connection is None or connection.is_connection_dropped():
await self.max_connections.acquire()
connection = HTTPConnection(
origin,
Expand Down
11 changes: 4 additions & 7 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse

Expand Down Expand Up @@ -46,12 +45,7 @@ async def send(
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)

try:
await self._send_request(request, timeout)
except ConnectionResetError: # pragma: nocover
# We're currently testing this case in HTTP/2.
# Really we should test it here too, but this'll do in the meantime.
raise NotConnected() from None
await self._send_request(request, timeout)

task, args = self._send_request_data, [request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
Expand Down Expand Up @@ -188,3 +182,6 @@ async def response_closed(self) -> None:
@property
def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
9 changes: 4 additions & 5 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from ..concurrency import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import NotConnected
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse

Expand Down Expand Up @@ -39,10 +38,7 @@ async def send(
if not self.initialized:
self.initiate_connection()

try:
stream_id = await self.send_headers(request, timeout)
except ConnectionResetError:
raise NotConnected() from None
stream_id = await self.send_headers(request, timeout)

self.events[stream_id] = []
self.timeout_flags[stream_id] = TimeoutFlag()
Expand Down Expand Up @@ -176,3 +172,6 @@ async def response_closed(self, stream_id: int) -> None:
@property
def is_closed(self) -> bool:
return False

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
7 changes: 0 additions & 7 deletions httpx/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ class PoolTimeout(Timeout):
# HTTP exceptions...


class NotConnected(Exception):
"""
A connection was lost at the point of starting a request,
prior to any writes succeeding.
"""


class HttpError(Exception):
"""
An HTTP error occurred.
Expand Down
3 changes: 3 additions & 0 deletions httpx/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ async def read(
) -> bytes:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseWriter:
"""
Expand Down
36 changes: 36 additions & 0 deletions tests/dispatch/test_connection_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,39 @@ async def test_premature_response_close(server):
await response.close()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0


@pytest.mark.asyncio
async def test_keepalive_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection should be reestablished.
"""
async with httpx.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

await server.shutdown() # shutdown the server to close the keep-alive connection
await server.startup()

response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1


@pytest.mark.asyncio
async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server):
"""
Upon keep-alive connection closed by remote a new connection should be reestablished.
"""
async with httpx.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

await server.shutdown() # shutdown the server to close the keep-alive connection
await server.startup()

response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
19 changes: 19 additions & 0 deletions tests/dispatch/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,22 @@ def test_http2_reconnect():

assert response_2.status_code == 200
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}


def test_http2_reconnect_after_remote_closed_connection():
"""
If a connection has been closed between requests, then we should
be seemlessly reconnected.
"""
backend = MockHTTP2Backend(app=app)

with Client(backend=backend) as client:
response_1 = client.get("http://example.org/1")
backend.server.close_connection = True
response_2 = client.get("http://example.org/2")

assert response_1.status_code == 200
assert json.loads(response_1.content) == {"method": "GET", "path": "/1", "body": ""}

assert response_2.status_code == 200
assert json.loads(response_2.content) == {"method": "GET", "path": "/2", "body": ""}
4 changes: 4 additions & 0 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, app):
self.buffer = b""
self.requests = {}
self.raise_disconnect = False
self.close_connection = False

# BaseReader interface

Expand Down Expand Up @@ -74,6 +75,9 @@ async def write(self, data: bytes, timeout) -> None:
async def close(self) -> None:
pass

def is_connection_dropped(self) -> bool:
return self.close_connection

# Server implementation

def request_received(self, headers, stream_id):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
CertTypes,
Client,
Dispatcher,
multipart,
Request,
Response,
TimeoutTypes,
VerifyTypes,
multipart,
)


Expand Down