Implement zero copy writes in StreamWriter #9839

Merged · 26 commits · Nov 13, 2024
1 change: 1 addition & 0 deletions CHANGES/9839.misc.rst
@@ -0,0 +1 @@
Implemented zero copy writes for ``StreamWriter`` -- by :user:`bdraco`.
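
The changelog entry is terse, so here is a rough sketch of what "zero copy" means in this context: instead of concatenating the chunk-length prefix, payload, and CRLF into one freshly allocated bytes object, the writer hands the separate buffers to ``transport.writelines()``. The mock transport and the ``parts`` list below are illustrative only, not aiohttp code.

```python
import asyncio
from unittest import mock

# Stand-in for a real asyncio.Transport, mirroring the fixtures in this PR's tests.
transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)

parts = [b"7\r\n", b"payload", b"\r\n"]

# Before: joining copies every part into a brand-new bytes object first.
transport.write(b"".join(parts))

# After: the parts are passed through as-is; no intermediate joined buffer
# has to be built inside the writer.
transport.writelines(parts)

print(transport.write.call_args)
print(transport.writelines.call_args)
```
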
68 changes: 49 additions & 19 deletions aiohttp/http_writer.py
@@ -2,7 +2,16 @@

import asyncio
import zlib
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa
from typing import ( # noqa
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)

from multidict import CIMultiDict

@@ -76,6 +85,17 @@ def _write(self, chunk: bytes) -> None:
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)

def _writelines(self, chunks: Iterable[bytes]) -> None:
size = 0
for chunk in chunks:
size += len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.writelines(chunks)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
) -> None:
@@ -110,10 +130,11 @@ async def write(

if chunk:
if self.chunked:
chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len_pre + chunk + b"\r\n"

self._write(chunk)
self._writelines(
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
)
else:
self._write(chunk)

if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
@@ -142,22 +163,31 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self._on_chunk_sent(chunk)

if self._compress:
if chunk:
chunk = await self._compress.compress(chunk)
chunks: List[bytes] = []
chunks_len = 0
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
chunks_len = len(compressed_chunk)
chunks.append(compressed_chunk)

chunk += self._compress.flush()
if chunk and self.chunked:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
else:
if self.chunked:
if chunk:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
else:
chunk = b"0\r\n\r\n"
flush_chunk = self._compress.flush()
chunks_len += len(flush_chunk)
chunks.append(flush_chunk)
assert chunks_len

if chunk:
if self.chunked:
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
elif len(chunks) > 1:
self._writelines(chunks)
else:
self._write(chunks[0])
elif self.chunked:
if chunk:
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
else:
self._write(b"0\r\n\r\n")
elif chunk:
self._write(chunk)

await self.drain()
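
Taken together, the new ``write()`` and ``write_eof()`` paths frame chunked output as a few separate buffers instead of one concatenated blob. A minimal sketch of that framing, using a stand-alone helper that is not part of aiohttp:

```python
from typing import List


def chunked_frame(chunk: bytes, eof: bool = False) -> List[bytes]:
    # One HTTP/1.1 chunked frame as separate buffers: hex length prefix,
    # payload, then either a bare CRLF or the CRLF plus the zero-length
    # terminating chunk, ready to be handed to transport.writelines().
    tail = b"\r\n0\r\n\r\n" if eof else b"\r\n"
    return [f"{len(chunk):x}\r\n".encode("ascii"), chunk, tail]


assert b"".join(chunked_frame(b"data")) == b"4\r\ndata\r\n"
assert b"".join(chunked_frame(b"end", eof=True)) == b"3\r\nend\r\n0\r\n\r\n"
```
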
18 changes: 16 additions & 2 deletions tests/test_client_request.py
@@ -5,7 +5,16 @@
import sys
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Protocol
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
Iterator,
List,
Protocol,
)
from unittest import mock

import pytest
@@ -73,12 +82,17 @@ def protocol(

@pytest.fixture
def transport(buf: bytearray) -> mock.Mock:
transport = mock.create_autospec(asyncio.Transport, spec_set=True)
transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)

def write(chunk: bytes) -> None:
buf.extend(chunk)

def writelines(chunks: Iterable[bytes]) -> None:
for chunk in chunks:
buf.extend(chunk)

transport.write.side_effect = write
transport.writelines.side_effect = writelines
transport.is_closing.return_value = False

return transport # type: ignore[no-any-return]
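
As an aside on the fixture change above (not part of the PR itself): ``create_autospec`` with ``spec_set=True`` rejects anything a real ``asyncio.Transport`` does not expose, and ``instance=True`` makes the mock behave like a transport instance rather than the class. A small sketch of what that buys the tests:

```python
import asyncio
from unittest import mock

transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)

try:
    transport.wirte(b"typo")  # deliberately misspelled method -> AttributeError
except AttributeError:
    pass

try:
    transport()  # with instance=True the mock is not callable like the class
except TypeError:
    pass
```
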
119 changes: 115 additions & 4 deletions tests/test_http_writer.py
@@ -1,7 +1,8 @@
# Tests for aiohttp/http_writer.py
import array
import asyncio
from typing import Any
import zlib
from typing import Any, Iterable
from unittest import mock

import pytest
@@ -24,7 +25,12 @@ def transport(buf: bytearray) -> Any:
def write(chunk: bytes) -> None:
buf.extend(chunk)

def writelines(chunks: Iterable[bytes]) -> None:
for chunk in chunks:
buf.extend(chunk)

transport.write.side_effect = write
transport.writelines.side_effect = writelines
transport.is_closing.return_value = False
return transport

@@ -105,6 +111,32 @@ async def test_write_payload_length(
assert b"da" == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_large_payload_deflate_compression_data_in_eof(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")

await msg.write(b"data" * 4096)
assert transport.write.called # type: ignore[attr-defined]
chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined]
transport.write.reset_mock() # type: ignore[attr-defined]
assert not transport.writelines.called # type: ignore[attr-defined]

# This payload compresses to 20447 bytes
payload = b"".join(
[bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
)
await msg.write_eof(payload)
assert not transport.write.called # type: ignore[attr-defined]
assert transport.writelines.called # type: ignore[attr-defined]
chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined]
content = b"".join(chunks)
assert zlib.decompress(content) == (b"data" * 4096) + payload
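
The ``c[1][0]`` indexing used in this test and the ones below is worth spelling out: each ``mock_calls`` entry unpacks as ``(name, args, kwargs)``, so ``[1][0]`` is the first positional argument, i.e. the buffer iterable a single ``writelines()`` call received. A standalone sketch with hypothetical data:

```python
from unittest import mock

writer = mock.Mock()
writer.writelines([b"4\r\n", b"data", b"\r\n"])
writer.writelines([b"0\r\n\r\n"])

# call[1] is the positional-argument tuple, so call[1][0] is the iterable
# that one writelines() call received.
chunks = [b"".join(call[1][0]) for call in writer.writelines.mock_calls]
assert b"".join(chunks) == b"4\r\ndata\r\n0\r\n\r\n"
```
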


async def test_write_payload_chunked_filter(
protocol: BaseProtocol,
transport: asyncio.Transport,
@@ -116,11 +148,12 @@ async def test_write_payload_chunked_filter(
await msg.write(b"ta")
await msg.write_eof()

content = b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n")


async def test_write_payload_chunked_filter_mutiple_chunks(
async def test_write_payload_chunked_filter_multiple_chunks(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
@@ -133,7 +166,8 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
await msg.write(b"at")
await msg.write(b"a2")
await msg.write_eof()
content = b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
assert content.endswith(
b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n"
)
@@ -156,6 +190,24 @@ async def test_write_payload_deflate_compression(
assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_deflate_compression_chunked(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
await msg.write(b"data")
await msg.write_eof()

chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
assert all(chunks)
content = b"".join(chunks)
assert content == expected


async def test_write_payload_deflate_and_chunked(
buf: bytearray,
protocol: BaseProtocol,
@@ -174,6 +226,65 @@ async def test_write_payload_deflate_and_chunked(
assert thing == buf


async def test_write_payload_deflate_compression_chunked_data_in_eof(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
await msg.write(b"data")
await msg.write_eof(b"end")

chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
assert all(chunks)
content = b"".join(chunks)
assert content == expected


async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()

await msg.write(b"data" * 4096)
# This payload compresses to 1111 bytes
payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
await msg.write_eof(payload)
assert not transport.write.called # type: ignore[attr-defined]

chunks = []
for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined]
chunked_payload = list(write_lines_call[1][0])[1:]
chunked_payload.pop()
chunks.extend(chunked_payload)

assert all(chunks)
content = b"".join(chunks)
assert zlib.decompress(content) == (b"data" * 4096) + payload


async def test_write_payload_deflate_compression_chunked_connection_lost(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
await msg.write(b"data")
with pytest.raises(
ClientConnectionResetError, match="Cannot write to closing transport"
), mock.patch.object(transport, "is_closing", return_value=True):
await msg.write_eof(b"end")


async def test_write_payload_bytes_memoryview(
buf: bytearray,
protocol: BaseProtocol,