Skip to content

Commit

Permalink
Fix missing eof when writer cancelled (#7764) (#7781)
Browse files Browse the repository at this point in the history
Fixes #5220.

I believe this is a better fix than #5238. That PR detects that we
didn't finish sending a chunked response and then closes the connection.
This PR ensures that we simply complete the chunked response by sending
the EOF bytes, allowing the connection to remain open and be reused
normally.

(cherry picked from commit 9c07121)
Dreamsorcerer authored Nov 3, 2023
1 parent cdfed8b commit 79f5266
Showing 7 changed files with 160 additions and 82 deletions.
1 change: 1 addition & 0 deletions CHANGES/7764.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue when a client request is closed before completing a chunked payload -- by :user:`Dreamsorcerer`
1 change: 1 addition & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
@@ -1203,6 +1203,7 @@ async def __aexit__(
# explicitly. Otherwise connection error handling should kick in
# and close/recycle the connection as required.
self._resp.release()
await self._resp.wait_for_close()


class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
59 changes: 36 additions & 23 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
@@ -584,8 +584,11 @@ async def write_bytes(
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
await writer.drain()
await self._continue
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
return

protocol = conn.protocol
assert protocol is not None
@@ -598,8 +601,6 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]

await writer.write_eof()
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
@@ -610,12 +611,12 @@ async def write_bytes(
new_exc.__context__ = exc
new_exc.__cause__ = exc
protocol.set_exception(new_exc)
except asyncio.CancelledError as exc:
if not conn.closed:
protocol.set_exception(exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None
@@ -704,7 +705,8 @@ async def send(self, conn: "Connection") -> "ClientResponse":
async def close(self) -> None:
if self._writer is not None:
try:
await self._writer
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None

@@ -973,8 +975,7 @@ def _response_eof(self) -> None:
):
return

self._connection.release()
self._connection = None
self._release_connection()

self._closed = True
self._cleanup_writer()
@@ -986,30 +987,22 @@ def closed(self) -> bool:
def close(self) -> None:
if not self._released:
self._notify_content()
if self._closed:
return

self._closed = True
if self._loop is None or self._loop.is_closed():
return

if self._connection is not None:
self._connection.close()

This comment has been minimized.

Copy link
@bdraco

bdraco Nov 22, 2023

Member

This change replaced close with release

This comment has been minimized.

Copy link
@Dreamsorcerer

Dreamsorcerer Nov 22, 2023

Author Member

Hmm, yeah, but do we need to close the connection here?

Looking at client.py, maybe this method is only used when some kind of exception occurs, in which case it probably makes sense to close the connection...

self._connection = None
self._cleanup_writer()
self._release_connection()

def release(self) -> Any:
if not self._released:
self._notify_content()
if self._closed:
return noop()

self._closed = True
if self._connection is not None:
self._connection.release()
self._connection = None

self._cleanup_writer()
self._release_connection()
return noop()

@property
@@ -1034,10 +1027,28 @@ def raise_for_status(self) -> None:
headers=self.headers,
)

def _release_connection(self) -> None:
if self._connection is not None:
if self._writer is None:
self._connection.release()
self._connection = None
else:
self._writer.add_done_callback(lambda f: self._release_connection())

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
self._writer.cancel()
self._writer = None
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
@@ -1066,9 +1077,10 @@ async def read(self) -> bytes:
except BaseException:
self.close()
raise
elif self._released:
elif self._released: # Response explicity released
raise ClientConnectionError("Connection closed")

await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]

def get_encoding(self) -> str:
@@ -1151,3 +1163,4 @@ async def __aexit__(
# for exceptions, response object can close connection
# if state is broken
self.release()
await self.wait_for_close()
59 changes: 57 additions & 2 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
@@ -205,6 +205,7 @@ async def handler(request):
client = await aiohttp_client(app)
resp = await client.get("/")
assert resp.closed
await resp.wait_for_close()
assert 1 == len(client._session.connector._conns)


@@ -224,6 +225,60 @@ async def handler(request):
assert content == b""


async def test_stream_request_on_server_eof(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
assert len(client.session.connector._acquired) == 1
conn = next(iter(client.session.connector._acquired))

async with client.get("/") as resp:
assert 200 == resp.status

# Connection should have been reused
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 1
assert conns[0][0] is conn


async def test_stream_request_on_server_eof_nested(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
async with client.get("/") as resp:
assert 200 == resp.status

# Should be 2 separate connections
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 2


async def test_HTTP_304_WITH_BODY(aiohttp_client) -> None:
async def handler(request):
body = await request.read()
@@ -306,8 +361,8 @@ async def handler(request):
client = await aiohttp_client(app)

with io.BytesIO(data) as file_handle:
resp = await client.post("/", data=file_handle)
assert 200 == resp.status
async with client.post("/", data=file_handle) as resp:
assert 200 == resp.status


async def test_post_data_with_bytesio_file(aiohttp_client) -> None:
Loading

0 comments on commit 79f5266

Please sign in to comment.