diff --git a/CHANGES/9839.misc.rst b/CHANGES/9839.misc.rst
new file mode 100644
index 00000000000..8bdd50268a7
--- /dev/null
+++ b/CHANGES/9839.misc.rst
@@ -0,0 +1 @@
+Implemented zero copy writes for ``StreamWriter`` -- by :user:`bdraco`.
diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py
index a1a9860b48d..c6c80edc3c4 100644
--- a/aiohttp/http_writer.py
+++ b/aiohttp/http_writer.py
@@ -2,7 +2,16 @@
 
 import asyncio
 import zlib
-from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union  # noqa
+from typing import (  # noqa
+    Any,
+    Awaitable,
+    Callable,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Union,
+)
 
 from multidict import CIMultiDict
 
@@ -76,6 +85,22 @@ def _write(self, chunk: bytes) -> None:
             raise ClientConnectionResetError("Cannot write to closing transport")
         transport.write(chunk)
 
+    def _writelines(self, chunks: Iterable[bytes]) -> None:
+        """Hand all chunks to ``transport.writelines()`` in a single call."""
+        # ``chunks`` is walked twice (size accounting, then the transport);
+        # materialize one-shot iterators so the second pass is not empty.
+        if not isinstance(chunks, (list, tuple)):
+            chunks = tuple(chunks)
+        size = 0
+        for chunk in chunks:
+            size += len(chunk)
+        self.buffer_size += size
+        self.output_size += size
+        transport = self._protocol.transport
+        if transport is None or transport.is_closing():
+            raise ClientConnectionResetError("Cannot write to closing transport")
+        transport.writelines(chunks)
+
     async def write(
         self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
     ) -> None:
@@ -110,10 +135,11 @@ async def write(
 
         if chunk:
             if self.chunked:
-                chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
-                chunk = chunk_len_pre + chunk + b"\r\n"
-
-            self._write(chunk)
+                self._writelines(
+                    (f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
+                )
+            else:
+                self._write(chunk)
 
             if self.buffer_size > LIMIT and drain:
                 self.buffer_size = 0
@@ -142,22 +168,31 @@ async def write_eof(self, chunk: bytes = b"") -> None:
             await self._on_chunk_sent(chunk)
 
         if self._compress:
-            if chunk:
-                chunk = await self._compress.compress(chunk)
+            chunks: List[bytes] = []
+            chunks_len = 0
+            if chunk and (compressed_chunk := await self._compress.compress(chunk)):
+                chunks_len = len(compressed_chunk)
+                chunks.append(compressed_chunk)
 
-            chunk += self._compress.flush()
-            if chunk and self.chunked:
-                chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
-                chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
-        else:
-            if self.chunked:
-                if chunk:
-                    chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
-                    chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
-                else:
-                    chunk = b"0\r\n\r\n"
+            flush_chunk = self._compress.flush()
+            chunks_len += len(flush_chunk)
+            chunks.append(flush_chunk)
+            assert chunks_len
 
-        if chunk:
+            if self.chunked:
+                chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
+                self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
+            elif len(chunks) > 1:
+                self._writelines(chunks)
+            else:
+                self._write(chunks[0])
+        elif self.chunked:
+            if chunk:
+                chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
+                self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
+            else:
+                self._write(b"0\r\n\r\n")
+        elif chunk:
             self._write(chunk)
 
         await self.drain()
diff --git a/tests/test_client_request.py b/tests/test_client_request.py
index 213b9562ee1..d20374ee163 100644
--- a/tests/test_client_request.py
+++ b/tests/test_client_request.py
@@ -5,7 +5,16 @@
 import sys
 import zlib
 from http.cookies import BaseCookie, Morsel, SimpleCookie
-from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Protocol
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Protocol,
+)
 from unittest import mock
 
 import pytest
@@ -73,12 +82,17 @@ def protocol(
 
 @pytest.fixture
 def transport(buf: bytearray) -> mock.Mock:
-    transport = mock.create_autospec(asyncio.Transport, spec_set=True)
+    transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)
 
     def write(chunk: bytes) -> None:
         buf.extend(chunk)
 
+    def writelines(chunks: Iterable[bytes]) -> None:
+        for chunk in chunks:
+            buf.extend(chunk)
+
     transport.write.side_effect = write
+    transport.writelines.side_effect = writelines
     transport.is_closing.return_value = False
     return transport  # type: ignore[no-any-return]
 
diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py
index 856b1c4ce6a..f6828e19385 100644
--- a/tests/test_http_writer.py
+++ b/tests/test_http_writer.py
@@ -1,7 +1,8 @@
 # Tests for aiohttp/http_writer.py
 import array
 import asyncio
-from typing import Any
+import zlib
+from typing import Any, Iterable
 from unittest import mock
 
 import pytest
@@ -24,7 +25,12 @@ def transport(buf: bytearray) -> Any:
     def write(chunk: bytes) -> None:
         buf.extend(chunk)
 
+    def writelines(chunks: Iterable[bytes]) -> None:
+        for chunk in chunks:
+            buf.extend(chunk)
+
     transport.write.side_effect = write
+    transport.writelines.side_effect = writelines
     transport.is_closing.return_value = False
     return transport
 
@@ -105,6 +111,32 @@ async def test_write_payload_length(
     assert b"da" == content.split(b"\r\n\r\n", 1)[-1]
 
 
+async def test_write_large_payload_deflate_compression_data_in_eof(
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    msg = http.StreamWriter(protocol, loop)
+    msg.enable_compression("deflate")
+
+    await msg.write(b"data" * 4096)
+    assert transport.write.called  # type: ignore[attr-defined]
+    chunks = [c[1][0] for c in list(transport.write.mock_calls)]  # type: ignore[attr-defined]
+    transport.write.reset_mock()  # type: ignore[attr-defined]
+    assert not transport.writelines.called  # type: ignore[attr-defined]
+
+    # This payload compresses to 20447 bytes
+    payload = b"".join(
+        [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
+    )
+    await msg.write_eof(payload)
+    assert not transport.write.called  # type: ignore[attr-defined]
+    assert transport.writelines.called  # type: ignore[attr-defined]
+    chunks.extend(transport.writelines.mock_calls[0][1][0])  # type: ignore[attr-defined]
+    content = b"".join(chunks)
+    assert zlib.decompress(content) == (b"data" * 4096) + payload
+
+
 async def test_write_payload_chunked_filter(
     protocol: BaseProtocol,
     transport: asyncio.Transport,
@@ -116,11 +148,12 @@ async def test_write_payload_chunked_filter(
     await msg.write(b"ta")
     await msg.write_eof()
 
-    content = b"".join([c[1][0] for c in list(transport.write.mock_calls)])  # type: ignore[attr-defined]
+    content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)])  # type: ignore[attr-defined]
+    content += b"".join([c[1][0] for c in list(transport.write.mock_calls)])  # type: ignore[attr-defined]
     assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n")
 
 
-async def test_write_payload_chunked_filter_mutiple_chunks(
+async def test_write_payload_chunked_filter_multiple_chunks(
     protocol: BaseProtocol,
     transport: asyncio.Transport,
     loop: asyncio.AbstractEventLoop,
@@ -133,7 +166,8 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
     await msg.write(b"at")
     await msg.write(b"a2")
     await msg.write_eof()
-    content = b"".join([c[1][0] for c in list(transport.write.mock_calls)])  # type: ignore[attr-defined]
+    content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)])  # type: ignore[attr-defined]
+    content += b"".join([c[1][0] for c in list(transport.write.mock_calls)])  # type: ignore[attr-defined]
     assert content.endswith(
         b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n"
     )
@@ -156,6 +190,24 @@ async def test_write_payload_deflate_compression(
     assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1]
 
 
+async def test_write_payload_deflate_compression_chunked(
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
+    msg = http.StreamWriter(protocol, loop)
+    msg.enable_compression("deflate")
+    msg.enable_chunking()
+    await msg.write(b"data")
+    await msg.write_eof()
+
+    chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]  # type: ignore[attr-defined]
+    assert all(chunks)
+    content = b"".join(chunks)
+    assert content == expected
+
+
 async def test_write_payload_deflate_and_chunked(
     buf: bytearray,
     protocol: BaseProtocol,
@@ -174,6 +226,77 @@ async def test_write_payload_deflate_and_chunked(
     assert thing == buf
 
 
+async def test_write_payload_deflate_compression_chunked_data_in_eof(
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
+    msg = http.StreamWriter(protocol, loop)
+    msg.enable_compression("deflate")
+    msg.enable_chunking()
+    await msg.write(b"data")
+    await msg.write_eof(b"end")
+
+    chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]  # type: ignore[attr-defined]
+    assert all(chunks)
+    content = b"".join(chunks)
+    assert content == expected
+
+
+async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    msg = http.StreamWriter(protocol, loop)
+    msg.enable_compression("deflate")
+    msg.enable_chunking()
+
+    await msg.write(b"data" * 4096)
+    # This payload compresses to 1111 bytes
+    payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
+    await msg.write_eof(payload)
+    assert not transport.write.called  # type: ignore[attr-defined]
+
+    chunks = []
+    for write_lines_call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
+        chunked_payload = list(write_lines_call[1][0])[1:]
+        chunked_payload.pop()
+        chunks.extend(chunked_payload)
+
+    assert all(chunks)
+    content = b"".join(chunks)
+    assert zlib.decompress(content) == (b"data" * 4096) + payload
+
+
+async def test_write_payload_deflate_compression_chunked_connection_lost(
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    msg = http.StreamWriter(protocol, loop)
+    msg.enable_compression("deflate")
+    msg.enable_chunking()
+    await msg.write(b"data")
+    with pytest.raises(
+        ClientConnectionResetError, match="Cannot write to closing transport"
+    ), mock.patch.object(transport, "is_closing", return_value=True):
+        await msg.write_eof(b"end")
+
+
+async def test_writelines_accepts_one_shot_iterator(
+    buf: bytearray,
+    protocol: BaseProtocol,
+    transport: asyncio.Transport,
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    msg = http.StreamWriter(protocol, loop)
+    msg._writelines(iter((b"da", b"ta")))
+    assert buf == b"data"
+    assert msg.output_size == 4
+
+
 async def test_write_payload_bytes_memoryview(
     buf: bytearray,
     protocol: BaseProtocol,